最近好颓废,什么都学不进去...
感谢两篇:AKMer - 浅谈树分治 言简意赅
LadyLex - 点分治&动态点分治小结 讲解+例题,学到很多东西
点分治
动态点分治
~ 点分治 ~
经常遇见一类树上的计数题,问的是在某些条件下,选择一些点的方案数
若对于每个点的统计都需要遍历以其为根节点的子树,普通的做法就是$O(n^2)$的,在很多时候是不满足要求的
而这是点分治的专长
点分治是这样进行的:
1. 找到当前树的重心
2. 将重心及重心连出的边全部删去,那么就能将原来的树分割成森林
3. 对于森林中的每棵树,继续找重心;不断地这样递归下去
其中,树的重心$x$表示,以$x$作为树的根,使得(以$x$的儿子为根的)最大子树大小最小
分析一下复杂度
一般来说,用到点分治的时候,需要对于当前子树$O(n)$进行dfs
得出重心也需要一个$O(n)$的dfs
而由于我们选择删去树的重心,所以分裂出的树中最大的不会超过原树大小的一半
所以整个算法的复杂度是$O(n\cdot logn)$
这样看来,点分治相当于从每一个点开始、对子树做一次dfs,但时间复杂度为$O(n\cdot logn)$;于是可以在很多问题中将一个$n$降成一个$logn$
模板题:Luogu P3806 (【模板】点分治1)
这道题可以用点分治这样解决:
若存在一条路径的长度为$k$,则对于当前子树的根$x$,要不在路径上,要不不在路径上
1. 若$x$在路径上,则相当于两个点$u,v$到$x$的距离之和为$k$,且$LCA(u,v)=x$
2. 若$x$不在路径上,则对于每个 $x$的儿子为根的子树 递归下去
由于$k<1\times 10^7$,所以可以开一个$cnt$数组,记录到$x$距离为$dist$的节点数$cnt[dist]$,记得适时清空
这题里面,对$cnt$数组的统计和赋值最好分开做,以免产生影响
#include <cstdio> #include <vector> #include <cstring> using namespace std; typedef pair<int,int> pii; const int N=10005; const int K=10000005; int n,m,val; vector<pii> v[N]; bool flag; bool vis[N]; int root; int sz[N],mx[N]; inline void Find(int x,int fa,int tot) { sz[x]=1; mx[x]=0; for(int i=0;i<v[x].size();i++) { int nxt=v[x][i].first; if(vis[nxt] || nxt==fa) continue; Find(nxt,x,tot); sz[x]+=sz[nxt]; mx[x]=max(mx[x],sz[nxt]); } mx[x]=max(mx[x],tot-sz[x]); if(!root || mx[x]<mx[root]) root=x; } int cnt[K]; inline void dfs(int x,int fa,int sum,int type) { if(sum<K) { cnt[sum]+=type; if(type==0 && val-sum>=0 && cnt[val-sum]) flag=true; } for(int i=0;i<v[x].size();i++) { int nxt=v[x][i].first,len=v[x][i].second; if(vis[nxt] || nxt==fa) continue; dfs(nxt,x,sum+len,type); } } inline void Calc(int x,int tot) { root=0; Find(x,0,tot); int cur=root; Find(cur,0,tot); cnt[0]++; for(int i=0;i<v[cur].size();i++) { int nxt=v[cur][i].first,len=v[cur][i].second; if(vis[nxt]) continue; dfs(nxt,cur,len,0); dfs(nxt,cur,len,1); } dfs(cur,0,0,-1); vis[cur]=true; for(int i=0;i<v[cur].size();i++) { int nxt=v[cur][i].first; if(vis[nxt]) continue; Calc(nxt,sz[nxt]); } vis[cur]=false; } int main() { scanf("%d%d",&n,&m); for(int i=1;i<n;i++) { int x,y,w; scanf("%d%d%d",&x,&y,&w); v[x].push_back(pii(y,w)); v[y].push_back(pii(x,w)); } while(m--) { scanf("%d",&val); flag=false; memset(vis,false,sizeof(vis)); Calc(1,n); printf(flag?"AYE\n":"NAY\n"); } return 0; }
稍微复杂一点的题:Luogu P4178 ($Tree$)
上题是等于$k$,这题是小于等于$k$,用树状数组求个和就行了
这题可以检验上题的计数是否不重不漏
#include <cstdio> #include <vector> #include <cstring> using namespace std; typedef pair<int,int> pii; const int N=40005; const int K=20005; int n,m,val; vector<pii> v[N]; int t[K]; inline int lowbit(int x) { return x&(-x); } inline void Add(int i,int x) { for(;i<=val+1;i+=lowbit(i)) t[i]+=x; } inline int Query(int i) { int res=0; for(;i;i-=lowbit(i)) res+=t[i]; return res; } int ans; bool vis[N]; int root; int sz[N],mx[N]; inline void Find(int x,int fa,int tot) { sz[x]=1; mx[x]=0; for(int i=0;i<v[x].size();i++) { int nxt=v[x][i].first; if(vis[nxt] || nxt==fa) continue; Find(nxt,x,tot); sz[x]+=sz[nxt]; mx[x]=max(mx[x],sz[nxt]); } mx[x]=max(mx[x],tot-sz[x]); if(!root || mx[x]<mx[root]) root=x; } inline void dfs(int x,int fa,int sum,int type) { if(sum<K) { Add(sum+1,type); if(type==0 && val-sum>=0) ans+=Query(val-sum+1); } for(int i=0;i<v[x].size();i++) { int nxt=v[x][i].first,len=v[x][i].second; if(vis[nxt] || nxt==fa) continue; dfs(nxt,x,sum+len,type); } } inline void Calc(int x,int tot) { root=0; Find(x,0,tot); int cur=root; Find(cur,0,tot); Add(1,1); for(int i=0;i<v[cur].size();i++) { int nxt=v[cur][i].first,len=v[cur][i].second; if(vis[nxt]) continue; dfs(nxt,cur,len,0); dfs(nxt,cur,len,1); } dfs(cur,0,0,-1); vis[cur]=true; for(int i=0;i<v[cur].size();i++) { int nxt=v[cur][i].first; if(vis[nxt]) continue; Calc(nxt,sz[nxt]); } vis[cur]=false; } int main() { scanf("%d",&n); for(int i=1;i<n;i++) { int x,y,w; scanf("%d%d%d",&x,&y,&w); v[x].push_back(pii(y,w)); v[y].push_back(pii(x,w)); } scanf("%d",&val); Calc(1,n); printf("%d\n",ans); return 0; }
感觉也许难点不在点分治上?BZOJ 4016 (最短路径树问题,$FJOI2014$)
首先根据定义,建立最短路径最小字典序树
建树的过程如下:
1. 对原图跑起点为$1$的Dijkstra
2. 若通过$x\rightarrow y$的一条边能够使得$1$到$y$的距离更小,那么将$x$作为最短路径树中$y$的父亲(最短路径为第一优先);若通过$x\rightarrow y$的一条边到达$y$与$1$到$y$的距离相同,那么比较在最短路径树中,$x$ 与 当前$y$在最短路径树中的父亲$fa$ 的LCA的下一层节点(这样能直接比较两条路径第一个不同的位置,从而使得字典序为第二优先),选择字典序更小的作为父亲
这样一波操作能够用$O(n\cdot (logn)^2)$建立这棵树
然后,枚举包含$K$个点的路径就是点分治的专长了:
记$len_i$表示,当前子树中深度为$i$的点的最大路径长度;$num_i$表示,该长度的路径有多少条
于是就可以用跟上一题完全一样的方法统计答案了,只不过细节稍微多一点
#include <queue> #include <vector> #include <cstdio> #include <cstring> #include <algorithm> using namespace std; typedef pair<int,int> pii; const int N=30005; const int INF=1<<30; int n,m,K; vector<pii> v[N]; int h[N]; int to[N][20]; inline bool LCA(int x,int y) { for(int i=19;i>=0;i--) if(h[to[x][i]]>=h[y]) x=to[x][i]; for(int i=19;i>=0;i--) if(h[to[y][i]]>=h[x]) y=to[y][i]; for(int i=19;i>=0;i--) if(to[x][i]!=to[y][i]) x=to[x][i],y=to[y][i]; return x<y; } vector<pii> nv[N]; int d[N],rev[N]; priority_queue<pii,vector<pii>,greater<pii> > Q; void Build() { for(int i=1;i<=n;i++) d[i]=INF; d[1]=0; Q.push(pii(0,1)); while(!Q.empty()) { int x=Q.top().second,D=Q.top().first; Q.pop(); if(D>d[x]) continue; for(int i=0;i<v[x].size();i++) { int nxt=v[x][i].first,w=v[x][i].second; if(D+w==d[nxt]) if(!to[nxt][0] || LCA(x,to[nxt][0])) { h[nxt]=h[x]+1; to[nxt][0]=x,rev[nxt]=i; for(int j=1;j<20;j++) to[nxt][j]=to[to[nxt][j-1]][j-1]; } if(D+w<d[nxt]) { d[nxt]=D+w; h[nxt]=h[x]+1; to[nxt][0]=x,rev[nxt]=i; for(int j=1;j<20;j++) to[nxt][j]=to[to[nxt][j-1]][j-1]; Q.push(pii(d[nxt],nxt)); } } } for(int i=2;i<=n;i++) { int fa=to[i][0],cost=v[fa][rev[i]].second; nv[fa].push_back(pii(i,cost)); nv[i].push_back(pii(fa,cost)); } } int root; int sz[N],mx[N]; bool vis[N]; inline void Find(int x,int fa,int tot) { sz[x]=1; mx[x]=0; for(int i=0;i<nv[x].size();i++) { int nxt=nv[x][i].first; if(nxt==fa || vis[nxt]) continue; Find(nxt,x,tot); sz[x]+=sz[nxt]; mx[x]=max(mx[x],sz[nxt]); } mx[x]=max(mx[x],tot-sz[x]); if(!root || mx[root]>mx[x]) root=x; } int len[N],num[N]; int ans,cnt; inline void dfs(int x,int fa,int dep,int sum,int type) { if(dep<=K) { if(type==0 && (K==dep || len[K-dep+1])) { if(ans<sum+len[K-dep+1]) ans=sum+len[K-dep+1],cnt=0; if(ans==sum+len[K-dep+1]) cnt+=num[K-dep+1]; } if(type==1) { if(len[dep]<sum) len[dep]=sum,num[dep]=0; if(len[dep]==sum) num[dep]++; } if(type==-1) len[dep]=num[dep]=0; } for(int i=0;i<nv[x].size();i++) { int nxt=nv[x][i].first,w=nv[x][i].second; if(nxt==fa || vis[nxt]) continue; dfs(nxt,x,dep+1,sum+w,type); } } inline void Calc(int x,int tot) { root=0; Find(x,0,tot); int cur=root; Find(cur,0,tot); len[1]=0,num[1]=1; for(int i=0;i<nv[cur].size();i++) { int nxt=nv[cur][i].first,w=nv[cur][i].second; if(vis[nxt]) continue; dfs(nxt,cur,2,w,0); dfs(nxt,cur,2,w,1); } dfs(cur,0,1,0,-1); vis[cur]=true; for(int i=0;i<nv[cur].size();i++) { int nxt=nv[cur][i].first; if(vis[nxt]) continue; Calc(nxt,sz[nxt]); } } int main() { scanf("%d%d%d",&n,&m,&K); for(int i=1;i<=m;i++) { int x,y,w; scanf("%d%d%d",&x,&y,&w); v[x].push_back(pii(y,w)); v[y].push_back(pii(x,w)); } Build(); Calc(1,n); printf("%d %d\n",ans,cnt); return 0; }
~ 动态点分治 ~
在学习这个之前,需要先了解下欧拉序求LCA
可以参考这篇:Little_Fall - 【笔记】dfs序,欧拉序,LCA的RMQ解法
简单地说,若将 到达一节点(无论是从父节点还是从子节点来的) 记为事件,那么在dfs的同时记录每一次事件的时间戳
记$st_i$为最早到达$i$节点的时间,$ed_i$为最后达到$i$节点的时间
对于 事件的时间戳 建立ST表,存的是某一段时间到达过的深度最浅的节点编号
要查询$LCA(u,v)$,就相当于求$[min(st_u,st_v),max(ed_u,ed_v)]$这段时间中所到达过的深度最浅的点
由于每发生一次事件都相当于经过一条边,而一条边只会被正向、反向各经过一次,所以总事件数是$2n$级别的(所以在实际用在动态点分治时一般取$LOG=logn+1$,不要开小)
需要$O(nlogn)$的预处理($2$倍常数),但查询是$O(1)$的
模板题:Luogu P3379 (【模板】最近公共祖先)
因为此题数据量比较大、卡常数,这种做法在T的边缘;在动态点分治的题目中不会这样卡
#include <cstdio> #include <vector> #include <cstring> #include <algorithm> using namespace std; const int LOG=20; const int N=500005; int n,m,root; vector<int> v[N]; int id; int dep[N]; int st[N],ed[N]; int rmq[N<<1][LOG]; void dfs(int x,int fa) { dep[x]=dep[fa]+1; rmq[++id][0]=x; st[x]=id; for(int i=0;i<v[x].size();i++) { int nxt=v[x][i]; if(nxt==fa) continue; dfs(nxt,x); rmq[++id][0]=x; } ed[x]=id; } inline int cmp(int x,int y) { return (dep[x]<dep[y]?x:y); } int log[N<<1]; void ST() { int pw=-1; for(int i=1;i<=id;i++) { if(i==(1<<(pw+1))) pw++; log[i]=pw; } for(int i=0,t=1;i<LOG-1;i++,t<<=1) for(int j=1;j<=id;j++) { int l=rmq[j][i],r=(j+t>id?rmq[j][i]:rmq[j+t][i]); rmq[j][i+1]=cmp(l,r); } } inline int LCA(int x,int y) { int lb=min(st[x],st[y]),rb=max(ed[x],ed[y]); int k=log[rb-lb+1]; return cmp(rmq[lb][k],rmq[rb-(1<<k)+1][k]); } int main() { scanf("%d%d%d",&n,&m,&root); for(int i=1;i<n;i++) { int x,y; scanf("%d%d",&x,&y); v[x].push_back(y); v[y].push_back(x); } dfs(root,0); ST(); for(int i=1;i<=m;i++) { int x,y; scanf("%d%d",&x,&y); printf("%d\n",LCA(x,y)); } return 0; }
在很多动态点分治的题目中,修改操作都需要高效查询LCA,算是传统艺能了
然后考虑引入动态点分治
在点分治中,树上的信息是不会被修改的;但是在另一些题目中,存在修改树上信息的操作
大体上的解决办法仍然是,对于每个点分别计算子树中的贡献;但是显然不能在原树中进行,否则修改时经过一条长链就直接升天
点分治时,我们采用了 对子树不断求重心、递归 的策略
如果把每次求得的子树重心连边,就可以将原树重构成一棵新树(称为分治树),且深度是$logn$级别的
分治树具有很好的性质:若将修改操作 限定在该树的一条链上,就可以做到单次$O(logn)$
分治树与原树有一些不同:
1. 分治树上的边在原树中不一定存在,不过分治树中的一个子树必然对应了原树中的一个子树
2. 分治树上父节点的信息不一定是直接通过子节点计算的(这个坑了我好久...)
举个例子说明2:若想保存分治树中节点$x$ 到其子树中每个节点(在原树中)的距离之和,并不能直接由子节点$v_1,v_2,...,v_m$得到——虽然在分治树中它们靠的很近,也许在原树中能差上十万八千里
而正确的做法是,将每个节点的贡献加到分治树中它的祖先($logn$级别)上
具体的修改操作因题而异,不过整体思路都是在分治树的链上修改/查询
先来一道经典题:Luogu P2056 / BZOJ 1095 (捉迷藏,$ZJOI2007$)
想要查询 整体的最远未开灯房间的距离,显然可以在分治树上处理:相当于对于分治树上的每个点$x$,求 以其为根节点的分治树子树中 的相同问题
求这个子问题,需要知道 以$x$为根的子树中 所有未开灯点到$x$ 在原树中的距离
可以考虑用堆来维护:一个房间的灯被开启或关闭时,向其在分治树中的祖先中都删除/插入 在原树中到该节点的距离(可删除堆见代码实现)
由于树高为$logn$,所以最多也就插入$n\cdot logn$级别的信息,在时间空间上都很ok
有了大致思路,就可以比较深入的考虑细节了:
对于每个节点$i$以及其分治树上的父节点$fa_i$,保存这些信息
1. 可删除堆$up[i]$,表示分治树上 以$i$为根的子树中,每个点到$fa_i$(在原树上)的距离
2. 可删除堆$down[i]$,表示对于每个分治树上 以$i$的儿子为根 的子树中,(在原树上)到$i$的最远距离
这样保存的思路是,用$up[i]$来更新$down[fa_i]$;因为$fa_i$对整体答案 只贡献跨子树的最远点对距离,所以两个备选点必须在不同儿子的子树中
整体的答案由可删除堆$Q$维护,由$down[i]$更新(选择前两个备选点间的路径)
用ST表求LCA压一下常数就可以过了
#include <queue> #include <cstdio> #include <vector> #include <cstring> #include <algorithm> using namespace std; struct Queue { priority_queue<int> add,del; inline void push(int x) { add.push(x); } inline void pop(int x) { del.push(x); } inline int size() { return add.size()-del.size(); } inline int top() { while(!del.empty() && add.top()==del.top()) add.pop(),del.pop(); return add.top(); } inline int merge() { int val=top(),res=0; add.pop(); res=val+top(); add.push(val); return res; } }; typedef pair<int,int> pii; const int N=100005; const int LOG=18; int n,m; vector<int> v[N]; int root; bool vis[N]; int sz[N],mx[N]; void Find(int x,int f,int tot) { sz[x]=1,mx[x]=0; for(int i=0;i<v[x].size();i++) { int nxt=v[x][i]; if(nxt==f || vis[nxt]) continue; Find(nxt,x,tot); sz[x]+=sz[nxt]; mx[x]=max(mx[x],sz[nxt]); } mx[x]=max(mx[x],tot-sz[x]); if(!root || mx[root]>mx[x]) root=x; } int fa[N]; void Build(int x,int f,int tot) { root=0; Find(x,0,tot); x=root; Find(x,0,tot); fa[x]=f; vis[x]=true; for(int i=0;i<v[x].size();i++) { int nxt=v[x][i]; if(vis[nxt]) continue; Build(nxt,x,sz[nxt]); } } int id; int dep[N]; int st[N],ed[N]; int rmq[N<<1][LOG]; void dfs(int x,int f) { dep[x]=dep[f]+1; rmq[++id][0]=x; st[x]=id; for(int i=0;i<v[x].size();i++) { int nxt=v[x][i]; if(nxt==f) continue; dfs(nxt,x); rmq[++id][0]=x; } ed[x]=id; } int log[N<<1]; inline int cmp(int x,int y) { return (dep[x]<dep[y]?x:y); } void ST() { int pw=-1; for(int i=1;i<=id;i++) { if(i==(1<<(pw+1))) pw++; log[i]=pw; } for(int i=0,t=1;i<LOG-1;i++,t<<=1) for(int j=1;j<=id;j++) { int l=rmq[j][i],r=(j+t>id?rmq[j][i]:rmq[j+t][i]); rmq[j][i+1]=cmp(l,r); } } inline int LCA(int x,int y) { int lb=min(st[x],st[y]),rb=max(ed[x],ed[y]); int k=log[rb-lb+1]; return cmp(rmq[lb][k],rmq[rb-(1<<k)+1][k]); } inline int Dist(int x,int y) { return dep[x]+dep[y]-dep[LCA(x,y)]*2; } int open[N]; Queue up[N],down[N]; Queue ans; inline void Add(int x) { if(down[x].size()>1) ans.pop(down[x].merge()); down[x].push(0); if(down[x].size()>1) ans.push(down[x].merge()); int i=x; while(fa[i]) { if(down[fa[i]].size()>1) ans.pop(down[fa[i]].merge()); if(up[i].size()>0) down[fa[i]].pop(up[i].top()); up[i].push(Dist(x,fa[i])); down[fa[i]].push(up[i].top()); if(down[fa[i]].size()>1) ans.push(down[fa[i]].merge()); i=fa[i]; } } inline void Delete(int x) { ans.pop(down[x].merge()); down[x].pop(0); if(down[x].size()>1) ans.push(down[x].merge()); int i=x; while(fa[i]) { if(down[fa[i]].size()>1) ans.pop(down[fa[i]].merge()); down[fa[i]].pop(up[i].top()); up[i].pop(Dist(x,fa[i])); if(up[i].size()>0) down[fa[i]].push(up[i].top()); if(down[fa[i]].size()>1) ans.push(down[fa[i]].merge()); i=fa[i]; } } int main() { scanf("%d",&n); for(int i=1;i<n;i++) { int x,y; scanf("%d%d",&x,&y); v[x].push_back(y); v[y].push_back(x); } dfs(1,0); ST(); Build(1,0,n); for(int i=1;i<=n;i++) { open[i]=1; Add(i); } int tot=n; scanf("%d",&m); while(m--) { char op[10]; scanf("%s",op); if(op[0]==‘C‘) { int x; scanf("%d",&x); open[x]^=1; if(open[x]) Add(x),tot++; else Delete(x),tot--; } else { if(tot<2) { printf("%d\n",tot-1); continue; } printf("%d\n",ans.top()); } } return 0; }
跟上一题比较类似的一题:HDU 5571 ($tree$)
参考了鸟神的题解:poursoul - 【HDU】5571 tree【动态点分治】
看到对于xor的统计,可以考虑对点权$a_i$拆位(以后一定要长记性= =)
拆位后题目就变成,统计所有01点对之间的路径长度之和
这可以在分治树上这样实现:
1. $cnt[p][i][dig]$表示,对于点权第$p$位,在以$i$为根节点的子树中,有多少个点值为$dig,dig\in \{0,1\}$
2. $sum[p][i][dig]$表示,对于点权第$p$位,在以$i$为根节点的子树中,点值为$dig$的所有节点到$i$(在原树中)的距离之和
3. $sub[p][i][dig]$表示,对于点权第$p$位,在以$i$为根节点的子树中,点值为$dig$的所有节点到$fa_i$(在原树中)的距离之和
4. $res[p][i]$表示,对于点权第$p$为,在以$i$为根节点的子树中,所有跨子树的01点对路径长度之和
对于点权$a_x$的修改,可以对每一位 先消除原$a_x$的贡献、再加上新$a_x$的贡献
对于点权第$p$位、待修改点$x$、新填入的值$dig$,可以这样更新其对分治树上某祖先$i$的父节点$fa_i$的贡献
$cnt,sum,sub$的更新是比较显然的
而对于$res[p][fa_i]$,增加了这些贡献:$x$到 不以$i$为根节点的子树中 点值为$1-dig$的节点 (在原树中)的距离之和
可以拆成两部分,一部分是$x$到$fa_i$的路径,一部分是$fa_i$到那些节点的路径;第一部分借助$cnt$,第二部分借助$sum,sub$之差,就可以解决
消除贡献就是这个的逆操作,不多赘述
#include <cstdio> #include <vector> #include <cstring> #include <algorithm> using namespace std; typedef long long ll; typedef pair<int,int> pii; const int N=30005; const int LOG=17; const int M=14; int n,m; int a[N]; vector<pii> v[N]; int id; int dep[N],dist[N]; int st[N],ed[N]; int rmq[N<<1][LOG]; void dfs(int x,int f) { dep[x]=dep[f]+1; rmq[++id][0]=x; st[x]=id; for(int i=0;i<v[x].size();i++) { int nxt=v[x][i].first,w=v[x][i].second; if(nxt==f) continue; dist[nxt]=dist[x]+w; dfs(nxt,x); rmq[++id][0]=x; } ed[x]=id; } int Log[N<<1]; inline int cmp(int x,int y) { return (dep[x]<dep[y]?x:y); } void ST() { int pw=-1; for(int i=1;i<=id;i++) { if(i==(1<<(pw+1))) pw++; Log[i]=pw; } for(int i=0,t=1;i<LOG-1;i++,t<<=1) for(int j=1;j<=id;j++) { int l=rmq[j][i],r=(j+t>id?rmq[j][i]:rmq[j+t][i]); rmq[j][i+1]=cmp(l,r); } } inline int LCA(int x,int y) { int lb=min(st[x],st[y]),rb=max(ed[x],ed[y]); int k=Log[rb-lb+1]; return cmp(rmq[lb][k],rmq[rb-(1<<k)+1][k]); } inline int Dist(int x,int y) { return dist[x]+dist[y]-dist[LCA(x,y)]*2; } int root; bool vis[N]; int sz[N],mx[N]; void Find(int x,int f,int tot) { sz[x]=1,mx[x]=0; for(int i=0;i<v[x].size();i++) { int nxt=v[x][i].first; if(vis[nxt] || nxt==f) continue; Find(nxt,x,tot); sz[x]+=sz[nxt]; mx[x]=max(mx[x],sz[nxt]); } mx[x]=max(mx[x],tot-sz[x]); if(!root || mx[root]>mx[x]) root=x; } int fa[N]; void Build(int x,int f,int tot) { root=0; Find(x,0,tot); x=root; Find(x,0,tot); fa[x]=f; vis[x]=true; for(int i=0;i<v[x].size();i++) { int nxt=v[x][i].first; if(vis[nxt]) continue; Build(nxt,x,sz[nxt]); } vis[x]=false; } int cnt[M][N][2]; ll sum[M][N][2],sub[M][N][2]; ll res[M][N]; ll ans; inline void Add(int p,int x,int dig) { ll w=1LL<<p; ans-=w*res[p][x]; res[p][x]+=sum[p][x][1-dig]; ans+=w*res[p][x]; cnt[p][x][dig]++; int i=x; while(fa[i]) { ll D=Dist(x,fa[i]); ans-=w*res[p][fa[i]]; res[p][fa[i]]+=(sum[p][fa[i]][1-dig]-sub[p][i][1-dig]); res[p][fa[i]]+=D*(cnt[p][fa[i]][1-dig]-cnt[p][i][1-dig]); ans+=w*res[p][fa[i]]; sum[p][fa[i]][dig]+=D; sub[p][i][dig]+=D; cnt[p][fa[i]][dig]++; i=fa[i]; } } inline void Delete(int p,int x,int dig) { ll w=1LL<<p; ans-=w*res[p][x]; res[p][x]-=sum[p][x][1-dig]; ans+=w*res[p][x]; cnt[p][x][dig]--; int i=x; while(fa[i]) { ll D=Dist(x,fa[i]); ans-=w*res[p][fa[i]]; res[p][fa[i]]-=(sum[p][fa[i]][1-dig]-sub[p][i][1-dig]); res[p][fa[i]]-=D*(cnt[p][fa[i]][1-dig]-cnt[p][i][1-dig]); ans+=w*res[p][fa[i]]; sum[p][fa[i]][dig]-=D; sub[p][i][dig]-=D; cnt[p][fa[i]][dig]--; i=fa[i]; } } int main() { while(~scanf("%d",&n)) { memset(cnt,0,sizeof(cnt)); memset(sum,0LL,sizeof(sum)); memset(sub,0LL,sizeof(sub)); memset(res,0LL,sizeof(res)); ans=0,id=0; for(int i=1;i<=n;i++) v[i].clear(); for(int i=1;i<=n;i++) scanf("%d",&a[i]); for(int i=1;i<n;i++) { int x,y,w; scanf("%d%d%d",&x,&y,&w); v[x].push_back(pii(y,w)); v[y].push_back(pii(x,w)); } dfs(1,0); ST(); Build(1,0,n); for(int i=1;i<=n;i++) for(int j=0;j<M;j++) Add(j,i,(a[i]>>j)&1); scanf("%d",&m); while(m--) { int x,y; scanf("%d%d",&x,&y); for(int i=0;i<M;i++) Delete(i,x,(a[x]>>i)&1); a[x]=y; for(int i=0;i<M;i++) Add(i,x,(a[x]>>i)&1); printf("%lld\n",ans); } } return 0; }
稍稍总结一下
通过上面两题可以发现,动态点分治最重要的部分就是如何保证只统计跨子树方案
第一题在这方面并不是很明显(靠的是维护$down[i]$,从而保证跨子树)
第二题有一个很套路性的操作,就是靠$sub$来消除$sum$的一部分,从而保证只计算跨子树的贡献;这种方法在动态点分治中会经常用到
顺着上题的思路,有一个建分治树后的处理复杂一些的题目:BZOJ 3730 (震波)
这道题算是比较充分地利用了分治树的性能
由于询问的是距离$x$小于等于$k$的点权和,所以还是能够想到用线段树/树状数组维护的
对于分治树上的每一个点$x$:
1. 建立树状数组$sum[x]$,其中$\sum_{j=i}^{j-=lowbit(j)} sum[x][j]$表示,以$x$为根节点的子树中 与$x$(在原树中)距离小于等于$i$的点权之和
2. 考虑上面总结的“消除子树以保证只计算跨子树贡献”
建立树状数组$sub[x]$,其中$\sum_{j=i}^{j-=lowbit(j)} sub[x][j]$表示,以$x$为根节点的子树中 与$fa_x$(在原树中)距离小于等于$i$的点权之和
对于一次 将$x$点权值在原基础上加$dlt$ 的修改,记当前位置为$i$,父节点为$fa_i$,那么更新$Dist(x,fa_i)$处的$sum[fa_i]$和$sub[i]$
对于一次 距$x$小于等于$k$ 的查询,记当前位置为$i$,父节点为$fa_i$,那么$sum[fa_i]-sub[i]$就可以表示 去除以$i$为根的子树后 的点权和信息
根据分治树的结构(最好思考一下正确性),对于所有被统计的点$j$,都有$LCA(x,j)=fa_i$;所以前缀和$\sum_{j=k-Dist(x,fa_i)}^{j-=lowbit(j)} sum[fa_i][j]-sub[i][j]$就是 与$x$的路径经过$fa_i$ 的所有点的贡献
(注意当$k-Dist(x,fa_i)<0$时,要直接走向$fa_i$)
如果评测时遇到RE,其实就是WA;因为之前答案错误,强制在线后的$x$就被异或成奇怪的值了
加了快速读入输出后,跟网上的不少AC程序都差不多的速度, 不过还是TLE了;好像唯一明显比我快的是动态开点线段树,但那个常数到底是怎么卡的...
我的TLE代码如下(应该只多了$0.5$倍常数的样子,不过主要出在$sort$和操作上,没法优化了)
#include <ctime> #include <locale> #include <cstdio> #include <vector> #include <cstring> #include <algorithm> using namespace std; struct Edge { int to,nxt; Edge(int a=0,int b=0) { to=a,nxt=b; } }; const int N=100005; const int LOG=20; inline char nc() { static char buf[N],*p1,*p2; return p1==p2&&(p2=(p1=buf)+fread(buf,1,N,stdin),p1==p2)?EOF:*p1++; } inline void read(int &x) { char ch=nc(); while(!isdigit(ch)) ch=nc(); x=0; while(isdigit(ch)) { x=x*10+ch-‘0‘; ch=nc(); } } inline void out(int &x) { static char buf[10]; int tmp=x,s=0; while(tmp) { buf[s++]=tmp%10+‘0‘; tmp/=10; } while(s>0) putchar(buf[--s]); putchar(‘\n‘); } int n,m; int a[N]; int cnt; int v[N]; Edge e[N<<1]; inline void AddEdge(int x,int y) { e[++cnt]=Edge(y,v[x]); v[x]=cnt; } int dep[N]; int st[N],ed[N]; int id,rmq[N<<1][LOG]; void dfs(int x,int f) { dep[x]=dep[f]+1; rmq[++id][0]=x; st[x]=id; for(int i=v[x];i;i=e[i].nxt) { int nxt=e[i].to; if(nxt==f) continue; dfs(nxt,x); rmq[++id][0]=x; } ed[x]=id; } inline int cmp(int x,int y) { return (dep[x]<dep[y]?x:y); } int log[N<<1]; void ST() { log[0]=-1; for(int i=1;i<=id;i++) log[i]=log[i>>1]+1; for(int i=0,t=1;i<LOG-1;i++,t<<=1) for(int j=1;j<=id;j++) { int l=rmq[j][i],r=(j+t>id?rmq[j][i]:rmq[j+t][i]); rmq[j][i+1]=cmp(l,r); } } inline int LCA(int x,int y) { int lb=min(st[x],st[y]),rb=max(ed[x],ed[y]); int k=log[rb-lb+1]; return cmp(rmq[lb][k],rmq[rb-(1<<k)+1][k]); } inline int Dist(int x,int y) { return dep[x]+dep[y]-(dep[LCA(x,y)]<<1); } int root; bool vis[N]; int sz[N],mx[N]; void Find(int x,int f,int tot) { sz[x]=1,mx[x]=0; for(int i=v[x];i;i=e[i].nxt) { int nxt=e[i].to; if(nxt==f || vis[nxt]) continue; Find(nxt,x,tot); sz[x]+=sz[nxt]; mx[x]=max(mx[x],sz[nxt]); } mx[x]=max(mx[x],tot-sz[x]); if(!root || mx[root]>mx[x]) root=x; } int fa[N]; void Build(int x,int f,int tot) { root=0; Find(x,0,tot); x=root; Find(x,0,tot); fa[x]=f; vis[x]=true; for(int i=v[x];i;i=e[i].nxt) { int nxt=e[i].to; if(vis[nxt]) continue; Build(nxt,x,sz[nxt]); } } vector<int> p1[N],p2[N]; int *sum[N],*sub[N]; int sz1[N],sz2[N]; inline void Insert(int x) { int i=x; while(i) { p1[i].push_back(Dist(x,i)); if(fa[i]) p2[i].push_back(Dist(x,fa[i])); i=fa[i]; } } inline int Place1(int x,int k) { if(k<p1[x][0]) return 0; if(k>p1[x].back()) return sz1[x]-1; return k-p1[x][0]+1; } inline int Place2(int x,int k) { if(k<p2[x][0]) return 0; if(k>p2[x].back()) return sz2[x]-1; return k-p2[x][0]+1; } inline int lowbit(int x) { return x&(-x); } void Modify(int x,int dlt) { for(int j=1;j<sz1[x];j+=lowbit(j)) sum[x][j]+=dlt; int i=x; while(fa[i]) { int D=Dist(x,fa[i]); int pos=Place1(fa[i],D); for(int j=pos;j<sz1[fa[i]];j+=lowbit(j)) sum[fa[i]][j]+=dlt; pos=Place2(i,D); for(int j=pos;j<sz2[i];j+=lowbit(j)) sub[i][j]+=dlt; i=fa[i]; } } inline int Query(int x,int k) { int res=0; int pos=Place1(x,k); for(int j=pos;j;j-=lowbit(j)) res+=sum[x][j]; int i=x; while(fa[i]) { int D=k-Dist(x,fa[i]); if(D<0) { i=fa[i]; continue; } pos=Place1(fa[i],D); for(int j=pos;j;j-=lowbit(j)) res+=sum[fa[i]][j]; pos=Place2(i,D); for(int j=pos;j;j-=lowbit(j)) res-=sub[i][j]; i=fa[i]; } return res; } int main() { // freopen("input.txt","r",stdin); // freopen("my.txt","w",stdout); read(n),read(m); for(int i=1;i<=n;i++) read(a[i]); for(int i=1;i<n;i++) { int x,y; read(x),read(y); AddEdge(x,y); AddEdge(y,x); } dfs(1,0); ST(); Build(1,0,n); for(int i=1;i<=n;i++) Insert(i); for(int i=1;i<=n;i++) { sort(p1[i].begin(),p1[i].end()); sort(p2[i].begin(),p2[i].end()); p1[i].resize(unique(p1[i].begin(),p1[i].end())-p1[i].begin()); p2[i].resize(unique(p2[i].begin(),p2[i].end())-p2[i].begin()); sz1[i]=p1[i].size()+1; sum[i]=new int[sz1[i]]; memset(sum[i],0,sizeof(int)*sz1[i]); sz2[i]=p2[i].size()+1; sub[i]=new int[sz2[i]]; memset(sub[i],0,sizeof(int)*sz2[i]); } for(int i=1;i<=n;i++) Modify(i,a[i]); int lastans=0; while(m--) { int op,x,y; read(op),read(x),read(y); x^=lastans; y^=lastans; if(op==1) { Modify(x,y-a[x]); a[x]=y; } else { lastans=Query(x,y); out(lastans); } } return 0; }
另附上数据生成器(不异或上次答案的那种),只要没有拍挂就问题不大
#include <ctime> #include <cmath> #include <cstdio> #include <vector> #include <cstring> #include <cstdlib> #include <algorithm> using namespace std; typedef long long ll; const int N=100005; inline int rnd(int lim) { return (((ll)rand()*rand()+rand())%lim*rand()+rand())%lim+1; } typedef pair<int,int> pii; vector<pii> edge; void Generate_Tree(int n) { int lim=sqrt(n); vector<int> cur,nxt; int tot=1; cur.push_back(1); while(tot<n) { nxt.clear(); for(int i=0;i<cur.size() && tot<n;i++) { int x=cur[i]; int sz=rnd(lim); if(tot+sz>=n) sz=n-tot; for(int j=1;j<=sz;j++) { edge.push_back(pii(x,++tot)); nxt.push_back(tot); } } cur=nxt; } } int cor[N]; void Shuffle(int n) { for(int i=1;i<=n;i++) cor[i]=i; random_shuffle(cor+1,cor+n+1); } int main() { srand(time(NULL)); freopen("input.txt","w",stdout); int SZ=100000; int n=SZ,m=SZ; printf("%d %d\n",n,m); for(int i=1;i<=n;i++) { int x=rnd(10000); printf("%d ",x); } printf("\n"); Generate_Tree(n); Shuffle(n); for(int i=0;i<n-1;i++) printf("%d %d\n",cor[edge[i].first],cor[edge[i].second]); for(int i=1;i<=m;i++) { int op=rnd(2)-1,x,y; if(op==1) x=rnd(n),y=rnd(10000); else x=rnd(n),y=rnd(n); printf("%d %d %d\n",op,x,y); } return 0; }
最终大BOSS:Luogu P3920 (紫荆花之恋,$WC2014$)
待续...暂时还没搞懂用替罪羊树重构的实现办法
慢慢补题
HDU 6268 ($Master\ of\ Subgraph$,$2017\ CCPC$杭州)
题目pdf:http://acm.hdu.edu.cn/downloads/CCPC2018-Hangzhou-ProblemSet.pdf
读完这题,一个比较显然的性质是,一个连通子图中的所有节点的共同LCA是唯一的;于是考虑枚举这个LCA
这就是说,我们可以对原树中每一个点的子树进行一次计算(必选子树的根),而总体的答案就是每个点答案的并
一开始想的是对每个点做背包,不过很明显一次背包的复杂度是$O(m^2)$、且一共需要做$n$次,根本无法接受
于是学到了一个trick,就是把上面过程中的对某子树的背包,改为对子树中一个点的背包
什么叫对一个点的背包呢?
我们可以对每个点用一个bitset来表示所有可能被选到的值;现对$x$的子树进行计算(注意,这趟计算中,只有$x$点的bitset是对最终答案有贡献的,其余点的bitset仅用于辅助计算$x$点的bitset)
我们希望做到一件事情:我们依次dfs $x$的儿子$son_i$,并将$x$的bitset并上$son_i$的bitset以获得贡献
那么$son_i$的bitset所表示的就是,在已统计过$son_1\text{~} son_{i-1}$的基础上,加入$son_i$的子树后所能选出的可能权值和情况
要能体现$son_1\text{~}son_{i-1}$的贡献,我们就需要把当前$x$的bitset通过某种方式传给$son_i$的bitset(因为$x$的bitset已获得了$son_1\text{~} son_{i-1}$的贡献)
既然当前统计的是$son_i$的子树,所以$son_i$是必须在连通子图中的,否则无法让子树中的其他点在连通子图中($son_i$不在连通子图中的情况就是当前$x$的bitset,不需要担心)
那么连通子图的权值和必然要加上$w[son_i]$,即之前所有可能的取值都要加上$w[son_i]$;这在bitset的表示上恰好为 将$x$的bitset左移$w[son_i]$位,可以比较快地做到
然后可以单选一个$son_i$点,即将第$w[son_i]$位变成$1$
这样一来,选$son_i$点的贡献已经全部统计出来了,这就是对于子树中一点的背包;若$son_i$是叶子节点,直接返回$x$、异或上$son_i$的bitset就可以获得贡献
若$son_i$不是叶子节点,之后就是一样的步骤,继续向下递归;不过对于某个$y$的来说,初始是将其父亲的bitset左移$w[y]$位,但选$y$是将第$\sum_{j\text{在}son_i\text{到}y\text{的路径上} }\ \ \ \ w[j]$位变成$1$
如果简单的采用这种方法,复杂度是$O(\frac{n^2m}{x})$($x$为bitset压位削减的常数),仍然不是很稳
但是能够注意到,总共$n$次的 对于每个子树的dfs 恰好是点分治的经典应用,于是可以把外层的dfs过程扔到点分治上进行,总复杂度就是$O(\frac{nm\cdot logn}{x})$了
这样一来,外层是在分治树上dfs,内层是在原树上dfs,有点奇妙
(虽然点分治改变了内层dfs所遍历的子树,但是再这道题目中,子树的划分是任意的;举个例子说,对于选定根$root$、原树中的儿子$x$、分治树中的儿子$y$来说,在暴力统计中选$x$又选$y$的情况在$x$点被统计,在点分治统计中该情况在$y$点被统计,其余的情况互不干扰,所以并不会产生任何重复或遗漏)
#include <cstdio> #include <bitset> #include <vector> #include <cstring> #include <algorithm> using namespace std; const int N=3005; const int M=100005; int n,m; int w[N]; vector<int> v[N]; int root; bool vis[N]; int sz[N],mx[N]; void Find(int x,int fa,int tot) { sz[x]=1,mx[x]=0; for(int i=0;i<v[x].size();i++) { int nxt=v[x][i]; if(vis[nxt] || nxt==fa) continue; Find(nxt,x,tot); sz[x]+=sz[nxt]; mx[x]=max(mx[x],sz[nxt]); } mx[x]=max(mx[x],tot-sz[x]); if(!root || mx[root]>mx[x]) root=x; } bitset<M> val[N]; bitset<M> ans; int sum[N]; void dfs(int x,int fa) { sum[x]=sum[fa]+w[x]; val[x]=val[fa]<<w[x]; if(sum[x]<=m) val[x][sum[x]]=1; for(int i=0;i<v[x].size();i++) { int nxt=v[x][i]; if(!vis[nxt] && nxt!=fa) { dfs(nxt,x); val[x]|=val[nxt]; } } } void Solve(int x,int tot) { root=0; Find(x,0,tot); x=root; Find(x,0,tot); dfs(x,0); ans|=val[x]; vis[x]=true; for(int i=0;i<v[x].size();i++) { int nxt=v[x][i]; if(vis[nxt]) continue; Solve(nxt,sz[nxt]); } vis[x]=false; } int main() { int T; scanf("%d",&T); while(T--) { ans.reset(); for(int i=1;i<=n;i++) { v[i].clear(); val[i].reset(); } scanf("%d%d",&n,&m); for(int i=1;i<n;i++) { int x,y; scanf("%d%d",&x,&y); v[x].push_back(y); v[y].push_back(x); } for(int i=1;i<=n;i++) scanf("%d",&w[i]); Solve(1,n); for(int i=1;i<=m;i++) printf("%d",(int)ans[i]); printf("\n"); } return 0; }
(待续)
原文地址:https://www.cnblogs.com/LiuRunky/p/Vertex_Partition.html