简介
对于一颗静态树,O(nlogn)时间内处理子树的统计问题。是一种优雅的暴力。
算法思想
很显然,朴素做法下,对于每颗子树对其进行统计的时间复杂度是平方级别的。考虑对树进行一个重链剖分。虽然都基于重链剖分,但不同于树剖,我们维护的不是树链。
对于每个节点,我们先处理其轻儿子所在子树,轻子树在处理完后消除其影响。然后处理重儿子所在子树,保留其贡献。然后再暴力跑该点的轻子树,统计该点子树的最终答案。如果该点子树是轻子树,则消除该子树的影响,否则保留。用代码描述的话,大概是这个流程:
void dfs(int u,int fa,int hvy)
{
for(v :G[u])//处理轻子树
{
if(v==f||v==son[u])
continue;
dfs(v,u,0);
}
if(son[u])//处理重子树
dfs(son[u],u,1);
calc(u,fa,1);//暴力统计轻子树对该点答案的贡献
ans[u]=res;
if(!hvy)
calc(u,fa,-1);//若点u所在子树是轻子树,则逆着原来统计的操作来消除其影响。
}
以上体现大概思想,但遇到具体题目可能有很多细节需要思考。
复杂度分析
这个可能不能很容易的明白其为何高效,如何达到O(nlogn)。因此我们考虑每个节点对时间复杂度的贡献。如果真的明白上述的算法流程,可以知道我们执行暴力统计的都是对轻边所连的子树,因此每个点被遍历到的次数与它往上到根的轻边数量有关。而任一点到根的路径上,轻边的数量不会超过logn。因此每个点最多被遍历logn次。这样想应该好理解很多。
举例
Lomsat gelral
这是一道比较经典的入门题,有兴趣的可以练手,感受一下算法的思想,再做下一题。在此不给出代码。
下面稍微讲一下D. Arpa’s letter-marked tree and Mehrdad’s Dokhtar-kosh paths
感觉这道题还是挺难的,要考虑不少细节。
题意大概就是每条边有一个字符(a-v),求每颗子树下最长的一条简单路径,其上的字符可重组成回文串。显然就是要至多只有一个字符出现奇数次。
我们把每种字符看作二进制上的一个位,即2的幂。则满足条件的简单路径,其边权异或结果必须为0或2的幂。
因此用到dp和dsu on tree的思想。a[i]表示点i到根的路径异或值,dp[i]表示a[x]=i的点中,深度最大的x的深度。
对于一颗以u为根的子树,它的答案路径(该路径默认包含u,因此可能不是最终答案)可能是1.u到其子树中某点的简单路径;2.u的两颗不同子树中的两点间的路径。前者直接判断来更新答案;对于后者两颗子树间的情况,需要不断更新每个异或值下的最大深度,方便对于跑到的点可以知道此时与它满足条件的另一点的最大深度,从而得知路径长来更新答案。然后若该子树为重子树,则保留dp信息,否则重置。
附上代码
#include<bits/stdc++.h>
#define dd(x) cout<<#x<<" = "<<x<<" "
#define de(x) cout<<#x<<" = "<<x<<"\n"
#define sz(x) int(x.size())
#define All(x) x.begin(),x.end()
#define pb push_back
#define mp make_pair
#define fi first
#define se second
using namespace std;
typedef long long ll;
typedef long double ld;
typedef pair<int,int> P;
typedef priority_queue<int> BQ;
typedef priority_queue<int,vector<int>,greater<int> > SQ;
const int maxn=5e5+10,mod=1e9+7,INF=0x3f3f3f3f;
int a[maxn],dp[maxn*10],sz[maxn],d[maxn],son[maxn],ans[maxn];
vector<int> G[maxn];
void dfs1(int u,int fa)
{
sz[u]=1;
d[u]=d[fa]+1;
a[u]^=a[fa];
for (auto& v:G[u])
{
dfs1(v,u);
sz[u]+=sz[v];
if (sz[v]>sz[son[u]])
son[u]=v;
}
}
int mx;
bool check(int x,int y)
{
int t=x^y,cnt=0;
for (int i=0;i<='v'-'a';++i)
cnt+=(t>>i)&1;
return cnt<=1;
}
void cal(int rt,int u)
{
if (check(a[u],a[rt]))
mx=max(mx,d[u]-d[rt]);
mx=max(mx,dp[a[u]]+d[u]-2*d[rt]);
for (int i=0;i<='v'-'a';++i)
mx=max(mx,dp[a[u]^(1<<i)]+d[u]-2*d[rt]);
for (auto& v:G[u])
cal(rt,v);
}
void upd(int u,int ty)
{
if (ty)
dp[a[u]]=max(dp[a[u]],d[u]);
else
dp[a[u]]=-INF;
for (auto& v:G[u])
upd(v,ty);
}
void dfs2(int u,int hvy)
{
for (auto&v :G[u])
{
if (v==son[u])
continue;
dfs2(v,0);
}
if (son[u])
dfs2(son[u],1);
mx=0;
mx=max(mx,dp[a[u]]-d[u]);
for (int i=0;i<='v'-'a';++i)
mx=max(mx,dp[a[u]^(1<<i)]-d[u]);
for (auto& v:G[u])
{
if (v==son[u])
continue;
cal(u,v);
upd(v,1);
}
ans[u]=mx;
if (hvy)
dp[a[u]]=max(dp[a[u]],d[u]);
else
{
for (auto& v:G[u])
upd(v,0);
dp[a[u]]=-INF;
}
}
void solve(int u)
{
for (auto& v:G[u])
{
solve(v);
ans[u]=max(ans[u],ans[v]);
}
}
int main()
{
int n;
cin>>n;
char c[2];
for (int i=2;i<=n;++i)
{
int f;
scanf("%d%s",&f,c);
G[f].pb(i);
a[i]=1<<(c[0]-'a');
}
for (int i=1;i<maxn*10;++i)
dp[i]=-INF;
dfs1(1,0);
dfs2(1,1);
solve(1);
for (int i=1;i<=n;++i)
printf("%d ",ans[i]);
return 0;
}
原文地址:https://www.cnblogs.com/orangee/p/10463899.html