题意:给定一棵树,每个节点有一个颜色,问树上有多少种子串(定义子串为某两个点上的路径),保证叶子节点数<=20。n<=10^5
题解:
叶子节点小于等于20,考虑将每个叶子节点作为根把树给提起来形成一棵trie,然后定义这棵树的子串为从上到下的一个串(深度从浅到深)。
这样做我们可以发现每个子串必定是某棵trie上的一段直线。统计20棵树的不同子串只需要把它们建到一个自动机上就行了,相当于把20棵trie合并成一棵大的。
对于每个节点x,它贡献的子串数量是max[x]-min[x],又因为min[x]=max[fa]+1,则=max[x]-max[fa],就是step[x]-step[fa];
学会了怎样在sam上插入一颗trie,就直接记录一下父亲在sam上的节点作为p。注意每次都要新开一个点,不然会导致无意义的子串出现。
例如一棵树 (括号内为i颜色)
1(0)
2(1)
3(2) 4(3)
2是1的孩子,3和4都是2的孩子。在以1为根节点的时候插入了这棵trie,在以3为根节点的时候son[root][2]已经存在,如果用它来当现在的点的话就会让一棵trie接在另一棵的末位,导致无意义的子串出现,答案偏大。
1 #include<cstdio> 2 #include<cstdlib> 3 #include<cstring> 4 #include<iostream> 5 using namespace std; 6 7 typedef long long LL; 8 const int N=20*100010; 9 int n,c,tot,len,last; 10 int w[N],son[N][15],step[N],pre[N],first[N],cnt[N],id[N]; 11 struct node{ 12 int x,y,next; 13 }a[2*N]; 14 15 void ins(int x,int y) 16 { 17 a[++len].x=x;a[len].y=y; 18 a[len].next=first[x];first[x]=len; 19 } 20 21 int add_node(int x) 22 { 23 step[++tot]=x; 24 return tot; 25 } 26 27 int extend(int p,int ch) 28 { 29 // int np; 30 // if(son[p][ch]) return son[p][ch]; 31 // else np=add_node(step[p]+1); 32 int np=add_node(step[p]+1);//debug 每次都要新开一个点 33 34 while(p && !son[p][ch]) son[p][ch]=np,p=pre[p]; 35 if(p==0) pre[np]=1; 36 else 37 { 38 int q=son[p][ch]; 39 if(step[q]==step[p]+1) pre[np]=q; 40 else 41 { 42 int nq=add_node(step[p]+1); 43 memcpy(son[nq],son[q],sizeof(son[q])); 44 pre[nq]=pre[q]; 45 pre[np]=pre[q]=nq; 46 while(son[p][ch]==q) son[p][ch]=nq,p=pre[p]; 47 } 48 } 49 last=np; 50 return np; 51 } 52 53 void dfs(int x,int fa,int now) 54 { 55 int nt=extend(now,w[x]); 56 // printf("%d\n",nt); 57 for(int i=first[x];i;i=a[i].next) 58 { 59 int y=a[i].y; 60 if(y!=fa) dfs(y,x,nt); 61 } 62 } 63 64 int main() 65 { 66 freopen("a.in","r",stdin); 67 scanf("%d%d",&n,&c); 68 for(int i=1;i<=n;i++) scanf("%d",&w[i]); 69 tot=0;len=0; 70 memset(son,0,sizeof(son)); 71 memset(pre,0,sizeof(pre)); 72 memset(cnt,0,sizeof(cnt)); 73 memset(first,0,sizeof(first)); 74 step[++tot]=0;last=1; 75 for(int i=1;i<n;i++) 76 { 77 int x,y; 78 scanf("%d%d",&x,&y); 79 ins(x,y);ins(y,x); 80 cnt[x]++;cnt[y]++; 81 } 82 // for(int i=1;i<=len;i++) printf("%d -- > %d\n",a[i].x,a[i].y); 83 for(int i=1;i<=n;i++) 84 { 85 if(cnt[i]==1) dfs(i,0,1); 86 } 87 // for(int i=1;i<=tot;i++) printf("%d ",id[i]);printf("\n"); 88 LL ans=0; 89 for(int i=1;i<=tot;i++) ans+=(LL)(step[i]-step[pre[i]]); 90 printf("%lld\n",ans); 91 return 0; 92 }
时间: 2024-10-07 00:27:45