我这种数学一窍不通的菜鸡终于开始学多项式全家桶了……
必须要会的前置技能:FFT(不会?戳我:【知识总结】快速傅里叶变换(FFT))
一、NTT
跟FFT功能差不多,只是把复数域变成了模域(计算复数系数多项式相乘变成计算在模意义下整数系数多项式相乘)。你看FFT里的单位圆是循环的,模一个质数也是循环的嘛qwq。\(n\)次单位根\(w_n\)怎么搞?看这里:【BZOJ3328】PYXFIB(数学)(内含相关证明。只看与原根和单位根相关的内容即可。)
注意裸的NTT要求模数\(p\)存在原根并且\(p-1\)是\(2\)的若干次幂的倍数(这个次数要大于多项式次数\(n\))。于是通常就会用著名的NTT模数:\(998244353=2^{23}\times 7\times 17+1\)。
节约篇幅,代码先不放了。后面所有代码里都有NTT模板……
二、多项式求逆
对于\(n\)次多项式\(A\),如果有多项式\(B\)满足\(AB\equiv 1 \mod x^{n+1}\),则称\(B\)是\(A\)在模\(x^{n+1}\)意义下的逆元(和整数逆元差不多)。通常采用倍增的方法求逆元。通常都会规定多项式系数在模\(p\)的意义下。
首先,\(A\)在模\(x\)的意义下就只有一个常数项,所以此时的逆元\(B\)也只有一个常数项,就是\(A\)的常数项模\(p\)的逆元。
如果我们知道\(B_0\)是\(A\)在模\(x^{\lceil\frac{n}{2}\rceil}\)意义下的逆元,现在要求\(B\)是\(A\)在模\(x^n\)意义下的逆元。根据题设,显然有:
\[AB=1\mod x^n\]
很明显,\(AB\)的\(1\)到\(n-1\)次项系数全是\(0\),所以模一个\(x\)的低于\(n\)次幂也一定是\(1\)。所以
\[AB_0=AB=1\mod x^{\lceil\frac{n}{2}\rceil}\]
那么
\[B-B_0=0\mod x^{\lceil\frac{n}{2}\rceil}\]
两边和模数同时平方:
\[B^2+B_0^2-2BB_0=0\mod x^n\]
两边同时乘\(A\),得到(别忘了\(AB=1\mod x^n\)):
\[B+AB_0^2-2B_0=0\mod x^n\]
然后移项,得到:
\[B=2B_0-AB_0^2\mod x^n\]
照着这个式子递归算就行了。
代码:
注意代码里面的\(n\)是项数不是次数。一定要把没用的数组清空,以及进行NTT时把多项式项数写对。
代码最开始是防机惨护身符。
#include <cstdio>
#include <algorithm>
#include <cctype>
#include <cstring>
#undef i
#undef j
#undef k
#undef max
#undef min
#undef swap
#undef sort
#undef true
#undef false
#undef if
#undef for
#undef while
#define _ 0
using namespace std;
namespace zyt
{
template<typename T>
inline bool read(T &x)
{
char c;
bool f = false;
x = 0;
do
c = getchar();
while (c != EOF && c != '-' && !isdigit(c));
if (c == EOF)
return false;
if (c == '-')
f = true, c = getchar();
do
x = x * 10 + c - '0', c = getchar();
while (isdigit(c));
if (f)
x = -x;
return true;
}
template<typename T>
inline void write(T x)
{
static char buf[20];
char *pos = buf;
if (x < 0)
putchar('-'), x = -x;
do
*pos++ = x % 10 + '0';
while (x /= 10);
while (pos > buf)
putchar(*--pos);
}
typedef long long ll;
const int N = 1e5 + 10, B = 17, LEN = 1 << (B + 2) | 11, p = 998244353;
inline int power(int a, int b)
{
a %= p, b %= p - 1;
int ans = 1;
while (b)
{
if (b & 1)
ans = (ll)ans * a % p;
a = (ll)a * a % p;
b >>= 1;
}
return ans;
}
inline int get_inv(const int a)
{
return power(a, p - 2);
}
namespace Polynomial
{
int omega[LEN], winv[LEN], rev[LEN];
namespace Primitive_Root
{
int cnt;
pair<int, int> prime[20];
inline void get_prime(int n)
{
cnt = 0;
for (int i = 2; i * i <= n; i++)
{
if (n % i == 0)
prime[cnt++] = make_pair(i, 0);
while (n % i == 0)
++prime[cnt - 1].second, n /= i;
}
if (n > 1)
prime[cnt++] = make_pair(n, 1);
}
inline int get_g(const int n)
{
get_prime(n - 1);
for (int i = 2; i < n; i++)
{
bool flag = true;
for (int j = 0; j < cnt && flag; j++)
flag &= (power(i, (n - 1) / prime[j].first) != 1);
if (flag)
return i;
}
return -1;
}
}
void ntt(int *a, const int *w, const int n)
{
for (int i = 0; i < n; i++)
if (i < rev[i])
swap(a[i], a[rev[i]]);
for (int l = 1; l < n; l <<= 1)
for (int i = 0; i < n; i += (l << 1))
for (int k = 0; k < l; k++)
{
int tmp = (a[i + k] - (ll)w[n / (l << 1) * k] * a[i + l + k] % p + p) % p;
a[i + k] = (a[i + k] + (ll)w[n / (l << 1) * k] * a[i + l + k] % p) % p;
a[i + l + k] = tmp;
}
}
void init(const int n, const int lg2)
{
static int g = 0;
if (!g)
g = Primitive_Root::get_g(p);
int w = power(g, (p - 1) / n), wi = get_inv(w);
omega[0] = winv[0] = 1;
for (int i = 1; i < n; i++)
{
omega[i] = (ll)omega[i - 1] * w % p;
winv[i] = (ll)winv[i - 1] * wi % p;
}
for (int i = 0; i < n; i++)
rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (lg2 - 1)));
}
void inv(const int *a, int *ans, const int n)
{
if (n == 1)
ans[0] = get_inv(a[0]);
else
{
static int tmp[LEN];
inv(a, ans, (n + 1) >> 1);
int m = 1, lg2 = 0;
while (m < (n << 1) - 1)
m <<= 1, ++lg2;
memcpy(tmp, a, sizeof(int[n]));
init(m, lg2);
ntt(tmp, omega, m);
ntt(ans, omega, m);
for (int i = 0; i < m; i++)
ans[i] = (ans[i] * 2LL % p - (ll)tmp[i] * ans[i] % p * ans[i] % p + p) % p;
ntt(ans, winv, m);
int invm = get_inv(m);
for (int i = 0; i < m; i++)
ans[i] = (ll)ans[i] * invm % p;
memset(ans + n, 0, sizeof(int[m - n]));
memset(tmp, 0, sizeof(int[m]));
}
}
}
int a[LEN], b[LEN], n;
int work()
{
read(n);
for (int i = 0; i < n; i++)
read(a[i]);
Polynomial::inv(a, b, n);
for (int i = 0; i < n; i++)
write(b[i]), putchar(' ');
return (0^_^0);
}
}
int main()
{
return zyt::work();
}
三、加减乘除
加减法:直接每项对应相加减。
乘法:这就是NTT的目的啊喂!
除法:如果不是带余除法直接乘逆元。下面着重介绍带余除法。
已知\(n\)次多项式\(F\)和\(m\)次多项式\(G\),求\(n-m\)次多项式\(Q\)和多项式\(R\)(\(R\)的次数\(deg_R\)小于\(m\)),满足:
\[F=QG+R\]
(未完待续咕咕咕……
原文地址:https://www.cnblogs.com/zyt1253679098/p/10226915.html