计数套路题?但是我连套路都不会,,,
拿到这道题我一脸蒙彼,,,感谢@poorpool 大佬的博客的指点
先将第\(i\)位上的数字\(p_i\)向\(i\)连无向边,然后构成了一个有若干环组成的无向图,可以知道某个点包含它的有且仅有一个环,因为所有点度数都为2(自环的点度数也是2吧qwq)
那么我们的最终目标就是把这个图转换成有\(n\)个自环的图,相当于把排列排好顺序
考虑对于一个\(m\)个点的环,把它变成\(m\)个自环至少需要\(m-1\)步(相当于排列\((2,3...m,1)\)变成\((1,2,3...m)\),这至少需要\(m-1\)次交换)
设把一个\(m\)个点的环变成\(m\)个自环的方案数为\(f_m\).因为一次操作可以把一个环拆成两个环,设拆出来两个环大小为\(x,y\),那么拆环的方案数\(g(x,y)\)为\[g(x,y)=\begin{cases}x,\quad (x=y)\\x+y, \quad others\end{cases} \]
就是对于每一个点,有两个点使得这一个点和那两个点中的一个进行交换操作可以分出大小为\(x,y\)的环,因为点对会被计算2次,所以方案为\(\frac{(x+y)*2}{2}=x+y\),如果\(x=y\),那么那两个点是同一点,所以方案要除以2
每个环之间的操作是不会相互影响的,所以可以交错进行操作,两个环贡献答案给一个大环时,两个环的操作交错所构成的排列数就是可重集的排列数(总元素个数阶乘除以每一中元素个数阶乘).综上,我们可以知道\[f_m=\sum_{i=1}^{\lfloor\frac{m}{2}\rfloor} g(i,m-i)f_if_{m-i}\frac{(m-2)!}{(i-1)!(m-i-1)!}\]
我们可以找出\(k\)个大小为\(a_1,a_2...a_k\)的环
把所有环的答案合并,同样每个环之间的操作是不会相互影响的,所以方案数要乘上一个可重集的排列数,可以得到\[ans=\prod_{i=1}^{k} f_{a_i}*\frac{(n-k)!}{\prod_{i=1}^{k}(a_i-1)!}\]
(注意边界,\(f_1=f_2=1\))
然而求\(f\)要\(O(n)\)的时间,总复杂度为\(O(nk)\)
考虑优化此算法,通过打表找规律可以发现\(f_i=i^{i-2}\)
然后求\(f\)只要要\(O(logn)\)了,爽!
然后就可以偷税的写代码了
// luogu-judger-enable-o2
#include<bits/stdc++.h>
#define LL long long
#define il inline
#define re register
using namespace std;
const LL mod=1000000009;
il LL rd()
{
re LL x=0,w=1;re char ch;
while(ch<'0'||ch>'9') {if(ch=='-') w=-1;ch=getchar();}
while(ch>='0'&&ch<='9') {x=(x<<3)+(x<<1)+(ch^48);ch=getchar();}
return x*w;
}
int to[200010],nt[200010],hd[100010],tot=1;
int vis[200010],a[100010],tt;
LL jc[100010],cj[100010]; //不要问我为什么把阶乘数组的逆元写成cj(逃
bool v[100010];
il void add(int x,int y)
{
++tot;to[tot]=y;nt[tot]=hd[x];hd[x]=tot;
++tot;to[tot]=x;nt[tot]=hd[y];hd[y]=tot;
}
il void dfs(int x,int o,int h)
{
if(!v[x]) ++a[o];v[x]=true;
for(int i=hd[x];i;i=nt[i])
{
if(vis[i]==h) continue;
vis[i]=vis[i^1]=h;
dfs(to[i],o,h);
}
}
il LL ksm(LL a,LL b)
{
if(b<=0) return 1; //i=1时,i-2会为-1 qwq
LL an=1;
while(b)
{
if(b&1) an=(an*a)%mod;
a=(a*a)%mod;
b>>=1;
}
return an;
}
int main()
{
jc[0]=cj[0]=1;
for(int i=1;i<=100000;i++) jc[i]=(jc[i-1]*i)%mod,cj[i]=ksm(jc[i],mod-2);
int T=rd();
for(int h=1;h<=T;h++)
{
tot=1;tt=0;
memset(a,0,sizeof(a));
memset(v,0,sizeof(v));
memset(hd,0,sizeof(hd));
int n=rd();
for(int i=1;i<=n;i++) add(i,rd());
for(int i=1;i<=n;i++)
if(!v[i]) dfs(i,++tt,h);
LL ans=jc[n-tt];
for(int i=1;i<=tt;i++) ans=((ans*ksm(a[i],a[i]-2))%mod*cj[a[i]-1])%mod;
printf("%lld\n",ans);
}
return 0;
}
原文地址:https://www.cnblogs.com/smyjr/p/9467183.html