【题意】
有若干个询问,询问路径u,v上的颜色总数,另外有要求a,b,意为将a颜色看作b颜色。
【思路】
vfk真是神系列233。
Quote:
用S(v, u)代表 v到u的路径上的结点的集合。
用root来代表根结点,用lca(v, u)来代表v、u的最近公共祖先。
那么
S(v, u) = S(root, v) xor S(root, u) xor lca(v, u)
其中xor是集合的对称差。
简单来说就是节点出现两次消掉。
lca很讨厌,于是再定义
T(v, u) = S(root, v) xor S(root, u)
观察将curV移动到targetV前后T(curV, curU)变化:
T(curV, curU) = S(root, curV) xor S(root, curU)
T(targetV, curU) = S(root, targetV) xor S(root, curU)
取对称差:
T(curV, curU) xor T(targetV, curU)= (S(root, curV) xor S(root, curU)) xor (S(root, targetV) xor S(root, curU))
由于对称差的交换律、结合律:
T(curV, curU) xor T(targetV, curU)= S(root, curV) xor S(root, targetV)
两边同时xor T(curV, curU):
T(targetV, curU)= T(curV, curU) xor S(root, curV) xor S(root, targetV)
发现最后两项很爽……哇哈哈
T(targetV, curU)= T(curV, curU) xor T(curV, targetV)
(有公式恐惧症的不要走啊 T_T)
也就是说,更新的时候,xor T(curV, targetV)就行了。
即,对curV到targetV路径(除开lca(curV, targetV))上的结点,将它们的存在性取反即可。
from vfleaking
设目前指针处于a,b,且now为已经得到的T(curV,curU)的统计答案,此次询问为u,v,则我们使a->u,b->v,设计一个vis标记,路上取反标记并更新now,即xor T(curV,targetV),最后还要取反lca(u,v)才是一个完整的u->v路径。
【代码】
1 #include<set> 2 #include<cmath> 3 #include<queue> 4 #include<vector> 5 #include<cstdio> 6 #include<cstring> 7 #include<iostream> 8 #include<algorithm> 9 #define trav(u,i) for(int i=front[u];i;i=e[i].nxt) 10 #define FOR(a,b,c) for(int a=(b);a<=(c);a++) 11 using namespace std; 12 13 typedef long long ll; 14 const int N = 4e5+10; 15 const int D = 21; 16 17 ll read() { 18 char c=getchar(); 19 ll f=1,x=0; 20 while(!isdigit(c)) { 21 if(c==‘-‘) f=-1; c=getchar(); 22 } 23 while(isdigit(c)) 24 x=x*10+c-‘0‘,c=getchar(); 25 return x*f; 26 } 27 28 struct Edge { 29 int v,nxt; 30 }e[N]; 31 int en=1,front[N]; 32 void adde(int u,int v) 33 { 34 e[++en]=(Edge){v,front[u]}; front[u]=en; 35 } 36 37 int n,m,B,B_cnt,now,dfsc,dfn[N],ans[N]; 38 int pos[N],a[N],cnt[N],vis[N],fa[N][D],dep[N],st[N],top; 39 40 struct Node 41 { 42 int id,l,r,a,b; 43 bool operator < (const Node& rhs) const 44 { 45 return pos[l]<pos[rhs.l] || (pos[l]==pos[rhs.l]&&dfn[r]<dfn[rhs.r]); 46 } 47 } q[N]; 48 49 int dfs(int u) 50 { 51 FOR(i,1,D-1) 52 fa[u][i]=fa[fa[u][i-1]][i-1]; 53 int size=0; 54 dfn[u]=++dfsc; 55 trav(u,i) 56 { 57 int v=e[i].v; 58 if(v!=fa[u][0]) { 59 fa[v][0]=u; 60 dep[v]=dep[u]+1; 61 size+=dfs(v); 62 if(size>=B) { 63 B_cnt++; 64 while(size--) 65 pos[st[top--]]=B_cnt; 66 } 67 } 68 } 69 st[++top]=u; 70 return size+1; 71 } 72 int lca(int u,int v) 73 { 74 if(dep[u]<dep[v]) swap(u,v); 75 int t=dep[u]-dep[v]; 76 FOR(i,0,D-1) 77 if((1<<i)&t) u=fa[u][i]; 78 if(u==v) return u; 79 for(int i=D-1;i>=0;i--) 80 if(fa[u][i]!=fa[v][i]) 81 u=fa[u][i],v=fa[v][i]; 82 return fa[u][0]; 83 } 84 void upd(int u) 85 { 86 if(!vis[u]) { 87 vis[u]=1; 88 now+=(++cnt[a[u]])==1; 89 } else { 90 vis[u]=0; 91 now-=(--cnt[a[u]])==0; 92 } 93 } 94 void work(int u,int v) 95 { 96 while(u!=v) 97 { 98 if(dep[u]<dep[v]) swap(u,v); 99 upd(u); u=fa[u][0]; 100 } 101 } 102 int main() 103 { 104 // freopen("in.in","r",stdin); 105 // freopen("out.out","w",stdout); 106 n=read(),m=read(); 107 B=sqrt(n); 108 FOR(i,1,n) a[i]=read(); 109 int rt,u,v; 110 FOR(i,1,n) { 111 u=read(),v=read(); 112 if(!u) rt=v; 113 else if(!v) rt=u; 114 else adde(u,v),adde(v,u); 115 } 116 dfs(rt); 117 B_cnt++; 118 while(top) pos[st[top--]]=B_cnt; 119 FOR(i,1,m) { 120 q[i].l=read(),q[i].r=read(); 121 q[i].a=read(),q[i].b=read(); 122 q[i].id=i; 123 if(dfn[q[i].l]>dfn[q[i].r]) swap(q[i].l,q[i].r); 124 } 125 sort(q+1,q+m+1); 126 127 work(q[1].l,q[1].r); 128 int lc=lca(q[1].l,q[1].r); 129 upd(lc); 130 ans[q[1].id]=now; 131 ans[q[1].id]-=(q[1].a!=q[1].b)&&(cnt[q[1].a]&&cnt[q[1].b]); 132 upd(lc); 133 134 FOR(i,2,m) { 135 work(q[i-1].l,q[i].l); 136 work(q[i-1].r,q[i].r); 137 lc=lca(q[i].l,q[i].r); 138 upd(lc); 139 ans[q[i].id]=now; 140 ans[q[i].id]-=(q[i].a!=q[i].b)&&(cnt[q[i].a]&&cnt[q[i].b]); 141 upd(lc); 142 } 143 144 FOR(i,1,m) 145 printf("%d\n",ans[i]); 146 return 0; 147 }