设$f(x)$为树的生成函数,即$x^i$的系数为根节点权值为$i$的树的个数。
不难得出$f(x)=\sum_{k\in D}f(x)^k+x$
我们要求这个多项式的第$n$项,由拉格朗日反演可得
$[x^n]f(x)=\frac1n[x^{n-1}](\frac x{g(x)})^n$
其中$[x^n]f(x)$表示$f(x)$的$n$次项系数。
$f(x)$是$g(x)$的复合逆,即$g(f(x))=x$
在本题中,$g(x)=x-\sum_{k\in D}x^k$
我们需要多项式求逆和多项式快速幂。
多项式求逆就不介绍了,多项式快速幂一种朴素的做法是倍增+NTT,复杂度是$O(n\log n\log k)$
有没有更快的做法呢?
观察到$f(x)^n=e^{n\ln(f(x))}$,所以我们只需要快速算$\ln(f(x))$及$e^{f(x)}$即可。
注意$f(x)$的常数项要为1,还好出题人良心保证了这一点。
Part 1:如何算$\ln(f(x))$?
设$g(x)=\ln(f(x))$,那么$g‘(x)=\frac{f‘(x)}{f(x)}$,所以$g(x)=\int\frac{f‘(x)}{f(x)}$,时间复杂度$O(n\log n)$
Part 2:如何算$e^{f(x)}$?
还是考虑倍增,假设我已经求出$g_0(x)=e^{f(x)}(mod\;x^n)$,要求$g(x)=e^{f(x)}(mod\;x^{2n})$
根据泰勒展开,有$$0=h(g(x))=\sum_{i=0}^{\infty}\frac{h^{(i)}(g_0(x))}{i!}(g(x)-g_0(x))^i$$当$i>1$时,上式$mod\;x^{2n}$为$0$
所以$0=h(g_0(x))+h‘(g_0(x))(g(x)-g_0(x))\;(mod\;x^{2n})$
即$g(x)=g_0(x)-\frac{h(g_0(x))}{h‘(g_0(x))}(mod\;x^{2n})$
其中$h(g(x))=\ln(g(x))-f(x)$
所以$g(x)=g_0(x)-\frac{\ln(g_0(x))-f(x)}{\frac 1{g_0(x)}}=g_0(x)(1-\ln(g_0(x))+f(x))\;(mod\;x^{2n})$
时间复杂度$O(n\log n)$
#include <cstdio> #include <cstring> #include <algorithm> #define pre m=n<<1; for(int i=0;i<m;i++) r[i]=(r[i>>1]>>1)|((i&1)<<l) typedef long long ll; const int p=950009857,N=300000; int n,m,l,x,nn,f[N],g[N],t1[N],t2[N],t3[N],r[N],ni[N]; ll pw(ll a,int b) {ll r=1; for(;b;b>>=1,a=a*a%p) if(b&1) r=r*a%p; return r;} void ntt(int *a,int n,int f) { for(int i=0;i<n;i++) if(r[i]>i) std::swap(a[i],a[r[i]]); for(int i=1;i<n;i<<=1) for(int j=0,wn=pw(7,((p-1)/(i*2)*f+p-1)%(p-1));j<n;j+=i<<1) for(int k=0,w=1;k<i;k++,w=(ll)w*wn%p) { int x=a[j+k],y=(ll)a[j+k+i]*w%p; a[j+k]=(x+y)%p,a[j+k+i]=(x-y+p)%p; } if(!~f) for(int i=0;i<n;i++) a[i]=(ll)a[i]*ni[n]%p; } void inv(int *f,int *g,int *t,int n,int l) { if(n==1) {g[0]=pw(f[0],p-2); return;} inv(f,g,t,n>>1,l-1),memcpy(t,f,sizeof(int)*n),memset(t+n,0,sizeof(int)*n),pre;ntt(t,m,1),ntt(g,m,1); for(int i=0;i<m;i++) g[i]=(ll)g[i]*(2-(ll)t[i]*g[i]%p+p)%p; ntt(g,m,-1),memset(g+n,0,sizeof(int)*n); } void ln(int *f,int *g,int *t,int n,int l) { memset(t,0,sizeof(int)*n*2),inv(f,t,t1,n,l); for(int i=0;i+1<n;i++) g[i]=(ll)f[i+1]*(i+1)%p; pre;ntt(g,m,1),ntt(t,m,1); for(int i=0;i<m;i++) g[i]=(ll)g[i]*t[i]%p; ntt(g,m,-1); for(int i=m-1;i;i--) g[i]=(ll)g[i-1]*ni[i]%p; g[0]=0,memset(g+n,0,sizeof(int)*n); } void ex(int *f,int *g,int *t,int n,int l) { if(n==1) {g[0]=1; return;} ex(f,g,t,n>>1,l-1),memset(t,0,sizeof(int)*n*2),ln(g,t,t2,n,l); for(int i=0;i<n;i++) t[i]=(f[i]-t[i]+p)%p; t[0]=(t[0]+1)%p,pre;ntt(t,m,1),ntt(g,m,1); for(int i=0;i<m;i++) g[i]=(ll)g[i]*t[i]%p; ntt(g,m,-1),memset(g+n,0,sizeof(int)*n); } int main() { scanf("%d%d",&n,&m),f[0]++,ni[1]=1,nn=n; for(int i=1;i<=m;i++) scanf("%d",&x),f[x-1]=p-1; for(m=n,n=1,l=0;n<=m;n<<=1) l++; for(int i=2;i<=n*2;i++) ni[i]=(ll)(p-p/i)*ni[p%i]%p; inv(f,g,t1,n,l),memset(f,0,sizeof f),ln(g,f,t2,n,l); for(int i=0;i<n;i++) f[i]=(ll)f[i]*nn%p; memset(g,0,sizeof g),ex(f,g,t3,n,l),printf("%lld",(ll)g[nn-1]*ni[nn]%p); return 0; }