WQS二分,一种优化一类特殊DP的方法。
很多最优化问题都是形如“一堆物品,取与不取之间有限制。现在规定只取k个,最大/小化总收益”。
这类问题最自然的想法是:设f[i][j]表示前i个取j个的最大收益,转移即可。复杂度O(n^2)。
那么,如果在某些情况下,可以通过将问题稍作转化,变成一个不强制选k个的DP,而最后DP出来的最优解一定正好选了k个,那么问题就会简化很多。
WQS二分就是基于这个思想。
首先考虑建一个二维坐标系,x轴是选的数的个数,y轴是最大收益,如果这个x-y图像有凸性,那么就可能通过给每个被选的数一个偏差值,将复杂度中的一个n变成log。因此,WQS二分又叫作凸优化/带权二分。
来看一个题:[BZOJ2654]Tree
按照上面所说建立坐标系,发现x-y图像的斜率单调递增。是一个下凸函数。
我们考虑给每一条白边减去某个值(一些地方是加上某个值,本质是一样的)cost,那么如果最终解选了x条边,则得到的值为实际值-cost*x。考虑这个式子的几何意义,就相当于将凸包通过斜率为cost的直线投影到y轴上。
可以发现,如果合适的选取cost值,可以使凸包上横坐标为k的这个投影后的纵坐标最大,这时就可以直接得出这个点的值了。
我们二分cost,于是问题转化为,求一棵每条白边都减去cost的图中的最小生成树,直接求MST即可。
每次根据哪个点投影后的纵坐标最大调整二分边界,这个类似于用一条直线去切这个凸包,根据切点横坐标调整。
这里需要注意一个问题,可能会存在k-1,k,k+1三点共线的情况,这时如果当前二分的直线正好与这三点平行。这是我们要保证它返回的切点一定在我们当前枚举的二分区间之内。具体到这道题就是通过给等长的边按颜色排序控制最终收益相同的方案中白边的个数。
1 #include<cstdio> 2 #include<algorithm> 3 #define rep(i,l,r) for (int i=l; i<=r; i++) 4 typedef long long ll; 5 using namespace std; 6 7 const int N=100100; 8 int n,m,cnt,tot,k,ans,u[N],v[N],w[N],c[N],fa[N]; 9 struct E{ int u,v,w,c; }e[N]; 10 11 bool operator<(E a,E b){ return a.w==b.w ? a.c>b.c : a.w<b.w; } 12 int find(int x){ return x==fa[x] ? x : fa[x]=find(fa[x]); } 13 14 bool check(int x){ 15 tot=cnt=0; 16 rep(i,1,n) fa[i]=i; 17 rep(i,1,m){ 18 e[i].u=u[i]; e[i].v=v[i]; e[i].w=w[i]; e[i].c=c[i]; 19 if (!c[i]) e[i].w-=x; 20 } 21 sort(e+1,e+m+1); 22 rep(i,1,m){ 23 int p=find(e[i].u),q=find(e[i].v); 24 if (p!=q){ 25 fa[p]=q; tot+=e[i].w; 26 if (!e[i].c) cnt++; 27 } 28 } 29 return cnt<=k; 30 } 31 32 int main(){ 33 freopen("bzoj2654.in","r",stdin); 34 freopen("bzoj2654.out","w",stdout); 35 scanf("%d%d%d",&n,&m,&k); 36 rep(i,1,m) scanf("%d%d%d%d",&u[i],&v[i],&w[i],&c[i]),u[i]++,v[i]++; 37 int L=-105,R=105; 38 while(L<=R){ 39 int mid=(L+R)>>1; 40 if (check(mid)) L=mid+1,ans=tot+k*mid; else R=mid-1; 41 } 42 printf("%d\n",ans); 43 return 0; 44 }
这题的经典做法是可撤销贪心,但也可以用WQS做。
首先同样建出坐标系,发现是一个斜率单增的上凸包。先二分斜率去掉只选K个的限制,问题简化成普通DP。
f[i][0/1]表示前i个数,第i个数选了/没选,的最小代价。
1 #include<cstdio> 2 #include<cstring> 3 #include<algorithm> 4 #define rep(i,l,r) for (int i=(l); i<=(r); i++) 5 typedef long long ll; 6 using namespace std; 7 8 const int N=100010; 9 const ll inf=1e15; 10 int n,k,s[N]; 11 ll L,R,ans; 12 struct P{ ll v,x; }f[N][2]; 13 P min(P a,P b){ if (a.v<b.v || (a.v==b.v && a.x<b.x)) return a; else return b; } 14 15 bool jud(ll cost){ 16 memset(f,0x7f,sizeof(f)); 17 f[1][0]=(P){0,0}; 18 rep(i,2,n){ 19 f[i][0]=min(f[i-1][0],f[i-1][1]); 20 f[i][1]=(P){f[i-1][0].v+s[i]-s[i-1]-cost,f[i-1][0].x+1}; 21 } 22 f[n][0]=min(f[n][0],f[n][1]); 23 if (f[n][0].x<=k) { ans=f[n][0].v+k*cost; return 1; } else return 0; 24 } 25 26 int main(){ 27 freopen("bzoj1150.in","r",stdin); 28 freopen("bzoj1150.out","w",stdout); 29 scanf("%d%d",&n,&k); 30 rep(i,1,n) scanf("%d",&s[i]),R+=s[i]; 31 while (L<=R){ 32 ll mid=(L+R)>>1; 33 if (jud(mid)) L=mid+1; else R=mid-1; 34 } 35 printf("%lld\n",ans); 36 return 0; 37 }
同上题,设f[i][0/1][0/1]表示前i个数,第一个数选了/没选,第i个数选了/没选,的最大收益。
1 #include<cstdio> 2 #include<algorithm> 3 #include<cstring> 4 #define rep(i,l,r) for (int i=(l); i<=(r); i++) 5 using namespace std; 6 7 const int N=200010,inf=1e9; 8 int n,m,ans,a[N],L,R; 9 struct P{ int v,x; }f[N][2][2]; 10 P max(P a,P b){ if (a.v>b.v || (a.v==b.v && a.x<b.x)) return a; else return b; } 11 P add(P s,int b){ return (P){s.v+b,s.x+1}; } 12 13 bool jud(int cost){ 14 memset(f,0,sizeof(f)); 15 f[1][1][1]=(P){a[1]-cost,1}; f[1][0][1]=f[1][1][0]=(P){-inf,0}; 16 rep(i,2,n){ 17 f[i][0][0]=max(f[i-1][0][0],f[i-1][0][1]); 18 f[i][1][0]=max(f[i-1][1][0],f[i-1][1][1]); 19 f[i][0][1]=add(f[i-1][0][0],a[i]-cost); 20 f[i][1][1]=add(f[i-1][1][0],a[i]-cost); 21 } 22 P s=max(max(f[n][0][0],f[n][0][1]),f[n][1][0]); 23 if (s.x<=m) { ans=s.v+m*cost; return 1; } else return 0; 24 } 25 26 int main(){ 27 freopen("bzoj2151.in","r",stdin); 28 freopen("bzoj2151.out","w",stdout); 29 scanf("%d%d",&n,&m); 30 if (m>n/2) { puts("Error!"); return 0; } 31 rep(i,1,n) scanf("%d",&a[i]); 32 L=-1001; R=1001; 33 while (L<=R){ 34 int mid=(L+R)>>1; 35 if (jud(mid)) R=mid-1; else L=mid+1; 36 } 37 printf("%d\n",ans); 38 return 0; 39 }
同样先二分斜率去掉K的限制,问题变为求最小冲突。
f[i]表示前i个人的最小冲突,s[i][j]表示冲突表的二维前缀和,则有f[i]=max{f[j]+(s[i][i]-s[i][j]-s[j][i]+2*s[j][j])/2}。
同时这个DP是有决策单调性的,于是问题就由O(n^2*k)优化到了O(nlognlogk)。
1 #include<cstdio> 2 #include<algorithm> 3 #include<cstring> 4 #define rep(i,l,r) for (int i=(l); i<=(r); i++) 5 using namespace std; 6 7 const int N=4010; 8 int n,k,st,ed,s[N][N],f[N],g[N]; 9 struct P{ int x,l,r; }q[N]; 10 11 int rd(){ 12 int x=0; char ch=getchar(); 13 while (ch<‘0‘ || ch>‘9‘) ch=getchar(); 14 while (ch>=‘0‘ && ch<=‘9‘) x=(x<<3)+(x<<1)+(ch^48),ch=getchar(); 15 return x; 16 } 17 18 int cal(int j,int i){ return f[j]+((s[i][i]-s[i][j]-s[j][i]+s[j][j])>>1); } 19 20 bool chk(int i,int j,int k){ 21 int x=cal(i,k),y=cal(j,k); 22 return (x<y) || (x==y && g[i]<g[j]); 23 } 24 25 int find(int i,int j){ 26 int l=q[ed].l,r=n,res=0; 27 while (l<=r){ 28 int mid=(l+r)>>1; 29 if (chk(i,j,mid)) res=mid,r=mid-1; else l=mid+1; 30 } 31 return res; 32 } 33 34 void solve(int c){ 35 st=ed=1; q[1]=(P){0,0,n}; 36 rep(i,1,n){ 37 ++q[st].l; if (q[st].l>q[st].r) st++; 38 f[i]=cal(q[st].x,i)-c; g[i]=g[q[st].x]+1; 39 if (st>ed || chk(i,q[ed].x,n)){ 40 while (st<=ed && chk(i,q[ed].x,q[ed].l)) ed--; 41 if (st>ed) q[++ed]=(P){i,i,n}; 42 else{ 43 int x=find(i,q[ed].x); 44 q[ed].r=x-1; q[++ed]=(P){i,x,n}; 45 } 46 } 47 } 48 } 49 50 int main(){ 51 freopen("bzoj5311.in","r",stdin); 52 freopen("bzoj5311.out","w",stdout); 53 scanf("%d%d",&n,&k); 54 rep(i,1,n) rep(j,1,n) s[i][j]=s[i-1][j]+s[i][j-1]-s[i-1][j-1]+rd(); 55 int l=-s[n][n],r=0,res=0; 56 while (l<=r){ 57 int mid=(l+r)>>1; solve(mid); 58 if (g[n]<=k) res=mid,l=mid+1; else r=mid-1; 59 } 60 solve(res); printf("%d\n",f[n]+k*res); 61 return 0; 62 }
https://www.cnblogs.com/HocRiser/p/9055203.html
1 #include<cstdio> 2 #include<cstring> 3 #include<algorithm> 4 #define rep(i,l,r) for (int i=l; i<=r; i++) 5 typedef long long ll; 6 using namespace std; 7 8 const int N=300010; 9 int n,k,u,v,w,cnt,to[N<<1],nxt[N<<1],val[N<<1],h[N]; 10 ll mid,tot; 11 void add(int u,int v,int w){ to[++cnt]=v; val[cnt]=w; nxt[cnt]=h[u]; h[u]=cnt; } 12 struct P{ 13 ll x,y; 14 bool operator < (const P &b) const {return x==b.x? y>b.y : x<b.x;} 15 P operator + (const P &b) const {return (P){x+b.x,y+b.y};} 16 P operator + (int b) {return (P){x+b,y};} 17 }dp[3][N]; 18 P upd(P a){ return (P){a.x-mid,a.y+1}; } 19 20 void dfs(int u,int fa){ 21 dp[2][u]=max(dp[2][u],(P){-mid,1}); 22 for (int i=h[u],v; i; i=nxt[i]) 23 if ((v=to[i])!=fa){ 24 dfs(v,u); 25 dp[2][u]=max(dp[2][u]+dp[0][v],upd(dp[1][u]+dp[1][v]+val[i])); 26 dp[1][u]=max(dp[1][u]+dp[0][v],dp[0][u]+dp[1][v]+val[i]); 27 dp[0][u]=dp[0][u]+dp[0][v]; 28 } 29 dp[0][u]=max(dp[0][u],max(upd(dp[1][u]),dp[2][u])); 30 } 31 32 int main(){ 33 freopen("lct.in","r",stdin); 34 freopen("lct.out","w",stdout); 35 scanf("%d%d",&n,&k); k++; 36 rep(i,2,n) scanf("%d%d%d",&u,&v,&w),tot+=abs(w),add(u,v,w),add(v,u,w); 37 ll L=-tot,R=tot; 38 while (L<=R){ 39 mid=(L+R)>>1; memset(dp,0,sizeof(dp)); dfs(1,0); 40 if (dp[0][1].y<=k) R=mid-1; else L=mid+1; 41 } 42 memset(dp,0,sizeof(dp)); mid=L; dfs(1,0); printf("%lld\n",L*k+dp[0][1].x); 43 return 0; 44 }
WQS二分的另外两个题:CF958E2,CF739E
原文地址:https://www.cnblogs.com/HocRiser/p/9834069.html