T1:给出n个正整数a1,a2…an和一个质数mod.一个变量x初始为1.
进行m次操作.每次在n个数中随机选一个ai,然后x=x∗ai.问m次操作之后x的取值的期望.
(1<=ai<mod mod为质数 1<=mod<=1000 1<=n<=105 , 1<=m<=109)
第一眼康上去感觉是个数论加期望,完全不想做,想想后发现是个假期望,但好像是真数论,还是不想做
于是直接跳过,最后也没什么思路,打了个n×m×mod的暴力,然后喜暴 0
考完后,听大佬们分享,才想到这是个矩阵优化dp
那说说思路:
首先会发现mod极小,而m极大,那么最基本的思路就是用mod复杂度增加的代价将m的复杂度降低
猜测m复杂度为log级别,那么就想log级的算法
想到如果设计一个状态 f [ i ][ j ] 表示操作 i 此后变成 j 的期望
那么每次转移的系数矩阵都是相同的!
我们就可以用矩阵快速幂来优化dp的转移,于是复杂度为:O ( mod3 log(m) )
可是还是过不了,思路有问题吗?显然没有,那怎么优化呢?
关于矩乘的优化,想到循环矩阵的优化,于是看能不能转变系数矩阵的定义,让其成为循环矩阵
发现题目一个很妙的性质,1<=ai<mod,于是想到用原根优化(???)
设原根为 rt ,若 i = xp[i] 那么,我们最开始的式子是 i × ak → j (mod mod)
那么用原根就可以将式子转化为 xp[i] × xp[ak] = xp[j] (mod mod)
因为底数都是 x ,所以我们可以将其转化为指数间的运算
即:p[ i ] + p[ ak ] = p[ j ] (mod φ(mod))
哇,加法!
好了,它循环了。
为什么? 加法的转移相当与一种等距离的定向的转移,所以必循环
时间复杂度变为 O ( mod2 log(m) ) ,同时还优化了空间的复杂度
so,code
1 #include<cstdio> 2 #include<iostream> 3 #include<cmath> 4 #include<algorithm> 5 #include<cstring> 6 #include<vector> 7 #include<queue> 8 #define ll long long 9 using namespace std; 10 const int MAXMOD=1005,MAXN=100005,D=1e9+7; 11 int n,m,mod; 12 ll val[MAXN],ans,prt,pos[MAXMOD],cnt[MAXMOD],stk[MAXMOD],tp,invn; 13 struct Matrix { 14 ll s[MAXMOD]; 15 Matrix() {memset(s,0,sizeof(s));} 16 }P,R; 17 Matrix operator * (const Matrix &AA,const Matrix &BB) { 18 Matrix CC; 19 for(int i=0;i<mod;i++) 20 for(int j=0;j<mod;j++) 21 CC.s[(i+j)%(mod-1)]=(CC.s[(i+j)%(mod-1)]+AA.s[i]*BB.s[j]%D)%D; 22 return CC; 23 } 24 ll qpow(ll x,ll k,ll dd) { 25 ll ret=1; 26 while(k) { 27 if(k&1) ret=(ret*x)%dd; 28 x=(x*x)%dd,k>>=1; 29 } 30 return ret%dd; 31 } 32 void get_prt() { 33 int tmp=mod-1; 34 for(int i=2;i*i<mod;i++) 35 if(tmp%i==0) { 36 stk[++tp]=i; 37 while(tmp%i==0) tmp/=i; 38 } 39 if(tmp>1) stk[++tp]=tmp; 40 for(int i=2;i<=mod;i++) { 41 bool flag=0; 42 for(int j=1;j<=tp;j++) 43 if(qpow(i,(mod-1)/stk[j],mod)==1) {flag=1;break;} 44 if(!flag) {prt=i;break;} 45 } 46 for(int i=0;i<mod-1;i++) pos[qpow(prt,i,mod)]=i; 47 } 48 int main() { 49 scanf("%d%d%d",&n,&m,&mod); 50 for(int i=1;i<=n;i++) scanf("%lld",&val[i]),++cnt[val[i]]; 51 get_prt(); 52 invn=qpow(n,D-2,D); 53 for(int i=1;i<mod;i++) 54 P.s[pos[i]]=cnt[i]*invn%D; 55 R.s[0]=1; 56 while(m) { 57 if(m&1) R=R*P; 58 P=P*P,m>>=1; 59 } 60 for(int i=1;i<mod;i++) ans=(ans+R.s[pos[i]]*i)%D; 61 printf("%lld\n",ans); 62 return 0; 63 }
t1 Code
T2:给一颗树,每个节点都有两个权值a和b,a和b有如下关系:
b[x]=a[1]dis(1,x)+a[2]dis(2,x)+....+a[n]*dis(n,x) (dis ( i , j ) 表示i 到j 的最短路径经过的边数)
现在给你a与b数组中的一个,求另一个数组。
( 2<=n<=100000 )
首先很容易就可以在 O ( n ) 的复杂度下利用换根dp用a求出b,dp方程显然,就不赘述了
主要看如何用b来求a:
不妨设dp的根为1
我们看a求b时换根的方程 b[ v ] = b[ u ] + (siz[ 1 ] - siz [ v ])- siz [ v ]
(siz [ i ] 表示以i为根的子树内a的和)(v是u的儿子)
移项后得 b[ v ] - b[ u ] = siz[ 1 ] - 2×siz[ v ] 记作 g[ v ]
于是我们可以对除1以外的所以点进行上述计算
然后会发现其实并算不出什么东西???
可是我们还没有用 b[ 1 ] 啊?
于是将 b[ 1 ]的表达式写出,然后再合并
发现,b[ 1 ] = ∑siz [ u ] (u!=1)
那就简单了,用 ∑g[ v ] + 2 × b[ 1 ] 就可以计算出 siz[ 1 ] 的值
然后再用 g 和 siz[ 1 ] 算出所有点的 siz
最后dfs一遍求出a即可
so,code
1 #include<cstdio> 2 #include<iostream> 3 #include<cmath> 4 #include<algorithm> 5 #include<cstring> 6 #include<vector> 7 #include<queue> 8 #define ll long long 9 using namespace std; 10 const int MAXN=200005; 11 int T,n,o; 12 ll val[MAXN],siz[MAXN],f[MAXN],g[MAXN],sum; 13 struct node { 14 int to,nxt; 15 }mp[MAXN*2]; 16 int h[MAXN],tot; 17 void add(int x,int y) { 18 mp[++tot].nxt=h[x]; 19 mp[tot].to=y; 20 h[x]=tot; 21 } 22 void dfs1(int u,int fa) { 23 siz[u]=val[u]; 24 for(int i=h[u];i;i=mp[i].nxt) { 25 int v=mp[i].to; 26 if(v==fa) continue; 27 dfs1(v,u); 28 siz[u]+=siz[v]; 29 f[u]+=f[v]+siz[v]; 30 } 31 } 32 void dfs2(int u,int fa) { 33 for(int i=h[u];i;i=mp[i].nxt) { 34 int v=mp[i].to; 35 if(v==fa) continue; 36 g[v]=g[u]-siz[v]+(siz[1]-siz[v]); 37 dfs2(v,u); 38 } 39 } 40 void work0() { 41 dfs1(1,0); 42 g[1]=f[1]; 43 dfs2(1,0); 44 for(int i=1;i<=n;i++) printf("%lld ",g[i]); 45 printf("\n"); 46 } 47 void dfs3(int u,int fa) { 48 for(int i=h[u];i;i=mp[i].nxt) { 49 int v=mp[i].to; 50 if(v==fa) continue; 51 g[v]=val[v]-val[u]; 52 dfs3(v,u); 53 } 54 } 55 void dfs4(int u,int fa) { 56 f[u]=siz[u]; 57 for(int i=h[u];i;i=mp[i].nxt) { 58 int v=mp[i].to; 59 if(v==fa) continue; 60 dfs4(v,u); 61 f[u]-=siz[v]; 62 } 63 } 64 void work1() { 65 dfs3(1,0); 66 for(int i=2;i<=n;i++) sum+=g[i]; 67 siz[1]=(2*val[1]+sum)/(n-1); 68 for(int i=2;i<=n;i++) siz[i]=(siz[1]-g[i])/2; 69 dfs4(1,0); 70 for(int i=1;i<=n;i++) printf("%lld ",f[i]); 71 printf("\n"); 72 } 73 int main() { 74 scanf("%d",&T); 75 while(T--) { 76 scanf("%d",&n); 77 for(int i=1,aa,bb;i<n;i++) scanf("%d%d",&aa,&bb),add(aa,bb),add(bb,aa); 78 scanf("%d",&o); 79 for(int i=1;i<=n;i++) scanf("%lld",&val[i]); 80 if(!o) work0(); 81 else work1(); 82 memset(h,0,sizeof(h)); 83 memset(f,0,sizeof(f)); 84 memset(g,0,sizeof(g)); 85 memset(siz,0,sizeof(siz)); 86 memset(val,0,sizeof(val)); 87 tot=0;sum=0; 88 } 89 return 0; 90 }
t2 Code
T3:先写会儿插头dp,咕着。。。
原文地址:https://www.cnblogs.com/Gkeng/p/11259008.html