题目链接:Click here
Solution:
容易得到这样一个\(dp\),设\(f[i][j]\)表示已经选了\(i\)个数,乘积\(mod \,\,m\)后为\(j\)的方案
\[
f[2\times i][j]=\sum_{a\times b\equiv j\,\,(mod\,\, m)} f[i][a]\times f[i][b]
\]
考虑如何优化这个转移方程,注意到这道题的模数为\(1004535809\),这是一个很大的提示,考虑\(NTT\)
但是我们知道\(NTT\)处理的是卷积,而这里是指数相乘,该怎么转化呢?
乘法变加法,不由得让我们想到对数,模意义下的对数,又不由得让我们想到原根
考虑原根的定义,\(g\)是\(p\)的原根,则\(g^0 ,g^1\dots g^{p-2}\)在\(mod \,\,p\)下恰好取到\([1,p-1]\)的所有整数
原根也很好求,把\(\varphi(p)=\prod pr_i^{a_i}\),然后对于所有的\(pr_i\)都有\(g^{\varphi(p) \over pr_i} \not \equiv 1 \,\, ( mod \,\, p)\)
那么我们令\(j=\log_g j,a=\log_g a,b=\log_g b\),同时根据扩展欧拉定理,转移方程就变成了
\[
f[2\times i][j]=\sum_{a+b \equiv j \,\, (mod \,\, \varphi(m))}f[i][a]\times f[i][b]
\]
注意每次转移后\(f[i][j]+=f[i][j+\varphi(m)]\),我们就可以愉快的\(NTT\)啦!
等等,好像还不行,\(n=1e9\),仔细观察这个转移,发现可以直接多项式快速幂,那么再快速幂即可,时间复杂度\(O(m\log m \log n)\)
Code:
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=3e4+11;
const int mod=1004535809;
int n,m,S,X,len=1,tim;
int f[N],v[N],p[N],pr[N],lg[N];
int read(){
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=-f;ch=getchar();}
while(isdigit(ch)){x=x*10+ch-48;ch=getchar();}
return x*f;
}
int qpow(int x,int y,int p){
int re=1;
while(y>0){
if(y&1) re=re*x%p;
y>>=1;x=x*x%p;
}return re;
}
int GetRt(int x){
int tot=0,phi=x-1;
for(int i=2;i*i<=phi;i++)
if(phi%i==0){
pr[++tot]=i;
while(phi%i==0) phi/=i;
}
if(phi>1) pr[++tot]=phi;
phi=x-1;
for(int i=2;i<=phi;i++){
int flag=1;
for(int j=1;j<=tot&&flag;j++)
if(qpow(i,phi/pr[j],x)==1) flag=0;
if(flag) return i;
}
return -1;
}
void NTT(int *a,int flag){
for(int i=0;i<len;i++)
if(i<p[i]) swap(a[i],a[p[i]]);
for(int l=2;l<=len;l<<=1){
int wn=qpow(3,(mod-1)/l,mod);
if(flag==-1) wn=qpow(wn,mod-2,mod);
for(int st=0;st<len;st+=l){
int w=1;
for(int u=st;u<st+(l>>1);u++,w=w*wn%mod){
int x=a[u],y=w*a[u+(l>>1)]%mod;
a[u]=(x+y)%mod;a[u+(l>>1)]=(x+mod-y)%mod;
}
}
}
}
void Mul(int *A,int *B,int *C){
static int a[N],b[N];
for(int i=0;i<len;i++) a[i]=A[i];
for(int i=0;i<len;i++) b[i]=B[i];
NTT(a,1);NTT(b,1);
for(int i=0;i<len;i++) a[i]=a[i]*b[i]%mod;
NTT(a,-1);int inv=qpow(len,mod-2,mod);
for(int i=0;i<len;i++) a[i]=a[i]*inv%mod;
for(int i=0;i<m-1;i++) C[i]=(a[i]+a[i+m-1])%mod;
}
signed main(){
n=read(),m=read(),X=read(),S=read();
while(len<(m<<1)) len<<=1,++tim;
for(int i=0;i<len;i++)
p[i]=(p[i>>1]>>1)|((i&1)<<(tim-1));
int g=GetRt(m);
for(int i=0;i<m-1;i++) lg[qpow(g,i,m)]=i;
for(int i=1;i<=S;i++){
int x=read()%m;
if(x) ++v[lg[x]];
}
f[lg[1]]=1;
while(n){
if(n&1) Mul(f,v,f);
n>>=1;Mul(v,v,v);
}
printf("%lld\n",f[lg[X]]);
return 0;
}
原文地址:https://www.cnblogs.com/NLDQY/p/12242664.html