今天介绍一个神仙算法:Dsu On Tree[ 树上启发式合并 ]
这个算法用于离线处理询问子树信息,而且很好写。
但是在你没有理解它之前,这是个很鬼畜的算法。
理解后你才能真心感到它的美妙之处。
关键是它是有着媲美线段树合并的时间复杂度的“暴力”算法。
这里说一件事,我学这个东西时找了很多篇博客,它们无一例外地给出了这样一个流程:
1. 先统计一个节点所有的轻儿子 然后删除它的答案
2. 再统计这个节点的重儿子 保留他的答案
3. 再算一遍所有轻儿子 加到答案中上传
我当时就看的很懵逼,算一遍所有轻儿子,删掉,再算一遍,这不闲的?
直接统计它的重儿子再算轻儿子不就好了?很疑惑,问了身边很多人也都觉得迷惑。
人类迷惑行为大赏.jpg
后面我搞懂了,为了不让其他学习dsu on tree的人也觉得迷惑,我就写了这一篇博客。
在这里非常感谢洛谷两个dalao的帮助,现在理解了这个东西。
我们每次进入一棵子树计算答案时,都要把计算上一棵子树的数据清除。
为什么,如果我们带着上次计算后的结果去计算新子树,答案肯定是不对的。
但是,我们最后一棵子树不需要清除,因为我们不用再进入新子树了(没了)。
那我们再回到上面说的,为什么一开始要算一遍轻儿子?
从最纯粹的暴力开始,我们有两个函数dfs1和dfs2,dfs1函数作为主体函数,dfs2作为辅助函数。
先dfs1到每一个点,dfs1它的后代,计算后代的信息,再dfs2它的后代,计算自己的答案。
也就是说,开始算轻儿子是要把它后代的信息计算出来,而不是理解为之前提到那些博客里面的算出来答案后“删除答案”。删除答案是为了不让计算的数据冲突。
按照dalao说的,保留重儿子的信息可以优化复杂度。从原先暴力的O(n^2)优化到O(nlogn)。
所以为了不用清除重儿子的信息,先dfs1轻儿子,再dfs1重儿子,最后dfs2轻儿子更新自己的答案。
如果一开始没有dfs1轻儿子,我们就没有得到后代的信息,所谓“删除”答案,其实是从统计信息的数组把dfs1轻儿子时存进去的用于计算的“缓存”清理了。
下面给出代码:
#include<bits/stdc++.h> #define N 100010 using namespace std; inline int read(){ int data=0,w=1;char ch=0; while(ch!=‘-‘ && (ch<‘0‘||ch>‘9‘))ch=getchar(); if(ch==‘-‘)w=-1,ch=getchar(); while(ch>=‘0‘ && ch<=‘9‘)data=data*10+ch-‘0‘,ch=getchar(); return data*w; } struct Edge{ int nxt,to; #define nxt(x) e[x].nxt #define to(x) e[x].to }e[N<<1]; int head[N],tot; inline void addedge(int f,int t){ nxt(++tot)=head[f];to(tot)=t;head[f]=tot; } int cnt[N],siz[N],son[N],c[N],max_val,n,child; long long sum,ans[N]; void add(int x,int f,int val){ cnt[c[x]]+=val; if(cnt[c[x]]>max_val)max_val=cnt[c[x]],sum=c[x]; else if(cnt[c[x]]==max_val)sum+=1LL*c[x]; for(int i=head[x];i;i=nxt(i)){ int y=to(i); if(y==f||y==child)continue; add(y,x,val); } } void dfs1(int x,int f){//重链剖分 siz[x]=1;int maxson=-1; for(int i=head[x];i;i=nxt(i)){ int y=to(i); if(y==f)continue; dfs1(y,x); siz[x]+=siz[y]; if(siz[y]>maxson){ maxson=siz[y];son[x]=y; } } } void dfs2(int x,int f,int opt){//opt为0表示统计后的答案要删掉,opt为1则不用删 for(int i=head[x];i;i=nxt(i)){ int y=to(i); if(y==f)continue; if(y!=son[x])dfs2(y,x,0); }if(son[x])dfs2(son[x],x,1),child=son[x]; add(x,f,1);child=0; ans[x]=sum; if(!opt)add(x,f,-1),sum=0,max_val=0; } int main(){ n=read(); for(int i=1;i<=n;i++)c[i]=read(); for(int i=1;i<n;i++){ int x=read(),y=read(); addedge(x,y);addedge(y,x); } dfs1(1,0);dfs2(1,0,0); for(int i=1;i<=n;i++) printf("%lld ",ans[i]); return 0; }
原文地址:https://www.cnblogs.com/light-house/p/11779076.html