大意: 两人轮流操作一个长$n$, 只含前$k$种小写字母的串, 每次操作删除一个字符或者将整个串重排, 每次操作后得到的串不能和之前出现过的串相同, 求多少种串能使先手必胜.
找下规律发现$n$为奇数必胜, 否则假设$a_i$为字符$i$出现次数, 如果$\frac{n!}{a_1!a_2!...a_k!}$为奇数则必败
$n!$中$2$的幂次为n-__builtin_popcount(n)
所以必败就等价于$a_1+...+a_n=a_1|...|a_n$
设$f_{i,j}$表示前$i$个字符, 状态为$j$的方案数除以总字符数的阶乘
可以得到转移为$f_{i,S}=\sum \frac{1}{x!} f_{i-1,S\oplus x}$
做$O(\log k)$次子集卷积即可, 复杂度是$O(n\log ^2n\log k)$
我写的好像常数太大的没卡过去, 先这样吧
#include <iostream> #include <sstream> #include <algorithm> #include <cstdio> #include <cmath> #include <set> #include <map> #include <queue> #include <string> #include <cstring> #include <bitset> #include <functional> #include <random> #define REP(i,a,n) for(int i=a;i<=n;++i) #define PER(i,a,n) for(int i=n;i>=a;--i) #define hr putchar(10) #define pb push_back #define lc (o<<1) #define rc (lc|1) #define mid ((l+r)>>1) #define ls lc,l,mid #define rs rc,mid+1,r #define x first #define y second #define io std::ios::sync_with_stdio(false) #define endl ‘\n‘ #define DB(a) ({REP(__i,1,n) cout<<a[__i]<<‘,‘;hr;}) using namespace std; typedef long long ll; const int N = 1e6+10; int n,k,P,fac[N],ifac[N],cnt[N]; int dp[N],f[20][N],g[20][N],h[20][N]; ll inv(ll x){return x<=1?1:inv(P%x)*(P-P/x)%P;} ll qpow(ll a,ll n) {ll r=1%P;for (a%=P;n;a=a*a%P,n>>=1)if(n&1)r=r*a%P;return r;} void FMT(int *a, int n, int tp) { int mx = (1<<n)-1; REP(i,0,n-1) REP(j,0,mx) { if (j>>i&1) a[j]=(a[j]+tp*a[j^1<<i])%P; } } void mul(int *a, int *b, int *c, int n) { int mx = (1<<n)-1; REP(i,0,n) REP(j,0,mx) f[i][j]=g[i][j]=h[i][j]=0; REP(i,0,mx) { f[cnt[i]][i] = a[i]; g[cnt[i]][i] = b[i]; } REP(i,0,n) FMT(f[i],n,1),FMT(g[i],n,1); REP(i,0,n) { REP(j,0,i) REP(k,0,mx) { h[i][k] = (h[i][k]+(ll)f[j][k]*g[i-j][k])%P; } FMT(h[i],n,-1); REP(k,0,mx) if (cnt[k]==i) c[k] = h[i][k]; } } int main() { REP(i,0,N-1) cnt[i] = __builtin_popcount(i); scanf("%d%d%d",&n,&k,&P); fac[0] = 1; REP(i,1,N-1) fac[i]=(ll)fac[i-1]*i%P; ifac[N-1] = inv(fac[N-1]); PER(i,0,N-2) ifac[i]=(ll)ifac[i+1]*(i+1)%P; int tot = qpow(k,n); if (n&1) return printf("%d\n",tot),0; int len = 1; while ((1<<len)<=n) ++len; dp[0] = 1; for (; k; mul(ifac,ifac,ifac,len),k>>=1) { if (k&1) mul(dp,ifac,dp,len); } int ans = (tot-(ll)dp[n]*fac[n])%P; if (ans<0) ans += P; printf("%d\n", ans); }
原文地址:https://www.cnblogs.com/uid001/p/11625021.html
时间: 2024-10-02 00:57:04