知识点-树链剖分
“在一棵树上进行路径的修改、求极值、求和”:乍一看只要线段树就能轻松解决,实际上,仅凭线段树是不能搞定它的。我们需要用到一种貌似高级的复杂算法——树链剖分。
树链,就是树上的路径。剖分,就是把路径分类为重链和轻链。
记siz[v]表示以v为根的子树的节点数,dep[v]表示v的深度(根深度为1),top[v]表示v所在的链的顶端节点,f[v]表示v的父亲,son[v]表示v的子节点中siz[]最大的节点编号(即重儿子),id[v]表示v的父边在线段树中的位置。只要把这些东西求出来,就能用log(n)的时间完成原问题中的操作。
重儿子:如果siz[u]为v的子节点中siz值最大的,那么u就是v的重儿子。
轻儿子:v的其它子节点。
重边:点v与其重儿子的连边。
轻边:点v与其轻儿子的连边。
重链:由重边连成的路径。
轻链:轻边。
算法实现
首先用一个dfs函数求出siz,dep,f,son的值。
紧接着用一个build函数建树建链并求出id和top的值。对于节点v,如果有son[v]的存在,显然可得top[son[v]]=top[v]。在线段树中,节点的重边应当在节点的父边之后,所以id[son[v]]=++tot;在对重儿子进行深搜之后,枚举所有轻儿子。对于每个轻儿子u,显然有top[u]=u;同时也要id[top[u]]=++tot;
当建树建链的步骤完成后,可以根据题意进行相应的操作。
例题
题目描述
一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成一些操作:
I. CHANGE u t : 把结点u的权值改为t
II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值
III. QSUM u v: 询问从点u到点v的路径上的节点的权值和
注意:从点u到点v的路径上的节点包括u和v本身
输入
输入的第一行为一个整数n,表示节点的个数。
接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有一条边相连。
接下来n行,每行一个整数,第i行的整数wi表示节点i的权值。
接下来1行,为一个整数 q,表示操作的总数。
接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。
输出
对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。
样例输入
4 1 2 2 3 4 1 4 2 1 3 12 QMAX 3 4 QMAX 3 3 QMAX 3 2 QMAX 2 3 QSUM 3 4 QSUM 2 1 CHANGE 1 5 QMAX 3 4 CHANGE 3 6 QMAX 3 4 QMAX 2 4 QSUM 3 4
样例输出
4 1 2 2 10 6 5 6 5 16
提示
对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。
算法操作
1.修改操作update:本题只是单点更新,所以只需用普通的线段树修改方法即可。
2.求极值以及和问题fmax和fsum:属于同一种类型的题目。
首先记fu=top[u],fv=top[v]。
若fu!=fv,即它们不处于同一条重链上,则不妨设dep[u]>=dep[v],此时u的深度比v的要大。更新u到fu父边的答案后,使u=f[fu],即将u类似LCA一样向上“跳”。
若fu=fv,则它们处于同一条重链上。若u!=v,则继续更新答案,否则得到答案。
代码
1 #include<cstring> 2 #include<cmath> 3 #include<algorithm> 4 #include<cstdio> 5 #define mid ((L+R)>>1) 6 #define ls (node<<1) 7 #define rs ls+1 8 using namespace std; 9 struct edge{ 10 int to,nxt; 11 }edge[60001]; 12 int cnt,n,tot; 13 int tsum[300001],tmax[300001],dep[30001],id[30001],head[30001],f[30001],son[30001],siz[30001],top[30001]; 14 char ord[20]; 15 void add(int u,int v)//链式前向星 16 { 17 edge[++cnt].to=v; 18 edge[cnt].nxt=head[u]; 19 head[u]=cnt; 20 } 21 void dfs(int node,int father,int deep) 22 { 23 f[node]=father;dep[node]=deep;siz[node]=1; 24 for(int i=head[node];i;i=edge[i].nxt) 25 { 26 int to=edge[i].to; 27 if(to==father)continue; 28 dfs(to,node,deep+1); 29 siz[node]+=siz[to]; 30 if(siz[to]>siz[son[node]])son[node]=to; 31 } 32 } 33 void build(int node,int num)//建树建链 34 { 35 id[node]=++tot; 36 top[node]=num; 37 if(son[node])build(son[node],num); 38 for(int i=head[node];i;i=edge[i].nxt) 39 { 40 int to=edge[i].to; 41 if(to!=son[node]&&to!=f[node])build(to,to); 42 } 43 } 44 void update(int node,int L,int R,int u,int w)//更新节点 45 { 46 if(u>R||u<L)return; 47 if(L==R) 48 { 49 tsum[node]=tmax[node]=w; 50 return; 51 } 52 update(ls,L,mid,u,w); 53 update(rs,mid+1,R,u,w); 54 tmax[node]=max(tmax[ls],tmax[rs]); 55 tsum[node]=tsum[ls]+tsum[rs]; 56 } 57 int qmax(int node,int L,int R,int l,int r) 58 { 59 if(l>R||r<L)return -10000000; 60 if(l<=L&&R<=r)return tmax[node]; 61 int ans=max(qmax(ls,L,mid,l,r),qmax(rs,mid+1,R,l,r)); 62 return ans; 63 } 64 int qsum(int node,int L,int R,int l,int r) 65 { 66 if(l>R||r<L)return 0; 67 if(l<=L&&R<=r)return tsum[node]; 68 int ans=qsum(ls,L,mid,l,r)+qsum(rs,mid+1,R,l,r); 69 return ans; 70 } 71 int fmax(int u,int v) 72 { 73 int ans=-10000000; 74 while(top[u]!=top[v]) 75 { 76 if(dep[top[u]]<dep[top[v]])swap(u,v); 77 ans=max(ans,qmax(1,1,tot,id[top[u]],id[u])); 78 u=f[top[u]]; 79 } 80 if(dep[u]>dep[v])swap(u,v); 81 ans=max(ans,qmax(1,1,tot,id[u],id[v])); 82 return ans; 83 } 84 int fsum(int u,int v) 85 { 86 int ans=0; 87 while(top[u]!=top[v]) 88 { 89 if(dep[top[u]]<dep[top[v]])swap(u,v); 90 ans+=qsum(1,1,tot,id[top[u]],id[u]); 91 u=f[top[u]]; 92 } 93 if(dep[u]>dep[v])swap(u,v); 94 ans+=qsum(1,1,tot,id[u],id[v]); 95 return ans; 96 } 97 int main() 98 { 99 int a,b,w,u,v,q; 100 scanf("%d",&n); 101 for(int i=1;i<n;i++) 102 { 103 scanf("%d%d",&a,&b); 104 add(b,a);add(a,b); 105 } 106 dfs(1,0,1); 107 build(1,1); 108 for(int i=1;i<=n;i++) 109 { 110 scanf("%d",&w); 111 update(1,1,tot,id[i],w); 112 } 113 scanf("%d",&q); 114 while(q--) 115 { 116 scanf("%s%d%d",ord,&u,&v); 117 if(ord[0]==‘C‘)update(1,1,tot,id[u],v); 118 else 119 { 120 if(ord[1]==‘M‘)printf("%d\n",fmax(u,v)); 121 else printf("%d\n",fsum(u,v)); 122 } 123 } 124 return 0; 125 }