【知识总结】多项式全家桶(四)(快速幂和开根)

上一篇:【知识总结】多项式全家桶(三)(任意模数NTT)

推荐小恐龙的博客(参考资料):多项式开根

(本文中一切多项式运算默认在模 \(x_n\) 意义下进行)

一、快速幂

多项式快速幂?首先有一种很显然的方式是把整数快速幂里面的整数乘法替换成多项式乘法 NTT ,复杂度 \(O(n\log^2n)\) 。

然而还有一种 \(O(n\log n)\) 的做法:要求 \(B=A^k\) ,相当于求 \(\log_A B=k\) ,用换底公式得 \(\log_A B=\frac{\ln B}{\ln A}=k\) ,所以 \(B=e^{k\ln A}\) 。

然后写个多项式对数函数和指数函数(参见【知识总结】多项式全家桶(二)(ln和exp) )就完了。

注意,多项式运算是对 \(x^n\) 取模而不是对 \(998244353\) 取模!对质数取模仅仅是为了避免出现高精度或小数,对多项式运算没有任何影响。所以不要像我一样傻以为指数上要对 \(998244352\) 取模 QAQ 。

代码:

#include <cstdio>
#include <cctype>
#include <cstring>
#include <algorithm>
#include <climits>
using namespace std;

namespace zyt
{
    template<typename T>
    inline bool read(T &x, const int p = INT_MAX)
    {
        char c;
        bool f = false;
        x = 0;
        do
            c = getchar();
        while (c != EOF && c != '-' && !isdigit(c));
        if (c == EOF)
            return false;
        if (c == '-')
            f = true, c = getchar();
        do
            x = (x * 10LL + c - '0') % p, c = getchar();
        while (isdigit(c));
        if (f)
            x = -x;
        return true;
    }
    template<typename T>
    inline void write(T x)
    {
        static char buf[20];
        char *pos = buf;
        if (x < 0)
            putchar('-'), x = -x;
        do
            *pos++ = x % 10 + '0';
        while (x /= 10);
        while (pos > buf)
            putchar(*--pos);
    }
    typedef long long ll;
    const int N = 1e5 + 10, p = 998244353, g = 3;
    inline int power(int a, int b)
    {
        int ans = 1;
        while (b)
        {
            if (b & 1)
                ans = (ll)ans * a % p;
            a = (ll)a * a % p;
            b >>= 1;
        }
        return ans;
    }
    inline int get_inv(const int a)
    {
        return power(a, p - 2);
    }
    namespace Polynomial
    {
        const int LEN = N << 2;
        int omega[LEN], winv[LEN], rev[LEN];
        void init(const int n, const int lg2)
        {
            int w = power(g, (p - 1) / n), wi = get_inv(w);
            omega[0] = winv[0] = 1;
            for (int i = 1; i < n; i++)
            {
                omega[i] = (ll)omega[i - 1] * w % p;
                winv[i] = (ll)winv[i - 1] * wi % p;
            }
            for (int i = 0; i < n; i++)
                rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (lg2 - 1)));
        }
        void ntt(int *a, const int *w, const int n)
        {
            for (int i = 0; i < n; i++)
                if (i < rev[i])
                    swap(a[i], a[rev[i]]);
            for (int l = 1; l < n; l <<= 1)
                for (int i = 0; i < n; i += (l << 1))
                    for (int k = 0; k < l; k++)
                    {
                        int x = a[i + k], y = (ll)a[i + l + k] * w[n / (l << 1) * k] % p;
                        a[i + k] = (x + y) % p;
                        a[i + l + k] = (x - y + p) % p;
                    }
        }
        void mul(const int *a, const int *b, int *c, const int n)
        {
            static int x[LEN], y[LEN];
            int m = 1, lg2 = 0;
            while (m < (n << 1) - 1)
                m <<= 1, ++lg2;
            init(m, lg2);
            memcpy(x, a, sizeof(int[n]));
            memset(x + n, 0, sizeof(int[m - n]));
            memcpy(y, b, sizeof(int[n]));
            memset(y + n, 0, sizeof(int[m - n]));
            ntt(x, omega, m), ntt(y, omega, m);
            for (int i = 0; i < m; i++)
                x[i] = (ll)x[i] * y[i] % p;
            ntt(x, winv, m);
            int invm = get_inv(m);
            for (int i = 0; i < n; i++)
                c[i] = (ll)x[i] * invm % p;
        }
        void _inv(const int *A, int *B, const int n)
        {
            if (n == 1)
                B[0] = 1;
            else
            {
                static int tmp[LEN];
                _inv(A, B, (n + 1) >> 1);
                int m = 1, lg2 = 0;
                while (m < (n << 1) - 1)
                    m <<= 1, ++lg2;
                init(m, lg2);
                memcpy(tmp, A, sizeof(int[n]));
                memset(tmp + n, 0, sizeof(int[m - n]));
                memset(B + ((n + 1) >> 1), 0, sizeof(int[m - ((n + 1) >> 1)]));
                ntt(tmp, omega, m), ntt(B, omega, m);
                for (int i = 0; i < m; i++)
                    B[i] = (ll)(B[i] * 2LL % p - (ll)tmp[i] * B[i] % p * B[i] % p + p) % p;
                ntt(B, winv, m);
                int invm = get_inv(m);
                for (int i = 0; i < n; i++)
                    B[i] = (ll)B[i] * invm % p;
                memset(B + n, 0, sizeof(int[m - n]));
            }
        }
        void inv(const int *A, int *B, const int n)
        {
            static int x[LEN];
            memcpy(x, A, sizeof(int[n]));
            _inv(x, B, n);
        }
        void derivative(const int *A, int *B, const int n)
        {
            for (int i = 1; i < n; i++)
                B[i - 1] = (ll)A[i] * i % p;
            B[n - 1] = 0;
        }
        void integrate(const int *A, int *B, const int n)
        {
            for (int i = n - 1; i >= 0; i--)
                B[i + 1] = (ll)A[i] * get_inv(i + 1) % p;
            B[0] = 0;
        }
        void ln(const int *A, int *B, const int n)
        {
            static int tmp1[LEN], tmp2[LEN];
            derivative(A, tmp1, n);
            inv(A, tmp2, n - 1);
            mul(tmp1, tmp2, B, n - 1);
            integrate(B, B, n - 1);
        }
        void _exp(const int *A, int *B, const int n)
        {
            if (n == 1)
                B[0] = 1;
            else
            {
                static int tmp[LEN];
                _exp(A, B, (n + 1) >> 1);
                ln(B, tmp, n);
                for (int i = 0; i < n; i++)
                    tmp[i] = (-tmp[i] + A[i] + p) % p;
                tmp[0] = (tmp[0] + 1) % p;
                mul(B, tmp, B, n);
            }
        }
        void exp(const int *a, int *b, const int n)
        {
            static int tmp[LEN];
            memcpy(tmp, a, sizeof(int[n]));
            _exp(tmp, b, n);
        }
    }
    int work()
    {
        using namespace Polynomial;
        static int a[LEN];
        int n, k;
        read(n), read(k, p);
        for (int i = 0; i < n; i++)
            read(a[i], p);
        ln(a, a, n);
        for (int i = 0; i < n; i++)
            a[i] = (ll)a[i] * k % p;
        exp(a, a, n);
        for (int i = 0; i < n; i++)
            write(a[i]), putchar(' ');
        return 0;
    }
}
int main()
{
    freopen("5245.in", "r", stdin);
    return zyt::work();
}

二、多项式开根

首先当然可以直接用上面的快速幂,相当于计算 \(k\) 是 \(2\) 的逆元时的情况。但有一种常数更小也更好写的方法:

求 \(B=\sqrt A\) ,即 \(B^2=A\) 。和求逆、求指数函数类似,采用分治的思想。假设已经求出 \(B_0\) 满足 \(B_0^2=A \mod x^{\lceil\frac{n}{2}\rceil}\) ,求 \(B^2=A\mod x^n\) 。

显然有 \(B^2=A \mod x^{\lceil\frac{n}{2}\rceil}\) ,所以 \(B_0-B=0 \mod x^{\lceil\frac{n}{2}\rceil}\) 。

类似求逆,两边同时平方,得到:

\[(B_0-B)^2=0 \mod x^n\]

展开,得到:

\[B_0^2-2B_0B+B^2=0 \mod x^n\]

即:

\[B_0^2-2B_0B+A=0 \mod x^n\]

移项:

\[B=\frac{A+B_0^2}{2B_0} \mod x^n\]

多项式求逆即可。

代码:

#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cctype>
using namespace std;

namespace zyt
{
    template<typename T>
    inline bool read(T &x)
    {
        char c;
        bool f = false;
        x = 0;
        do
            c = getchar();
        while (c != EOF && c != '-' && !isdigit(c));
        if (c == EOF)
            return false;
        if (c == '-')
            f = true, c = getchar();
        do
            x = x * 10 + c - '0', c = getchar();
        while (isdigit(c));
        if (f)
            x = -x;
        return true;
    }
    template<typename T>
    inline void write(T x)
    {
        static char buf[20];
        char *pos = buf;
        if (x < 0)
            putchar('-'), x = -x;
        do
            *pos++ = x % 10 + '0';
        while (x /= 10);
        while (pos > buf)
            putchar(*--pos);
    }
    typedef long long ll;
    const int N = 1e5 + 10, p = 998244353, g = 3;
    inline int power(int a, int b)
    {
        int ans = 1;
        while (b)
        {
            if (b & 1)
                ans = (ll)ans * a % p;
            a = (ll)a * a % p;
            b >>= 1;
        }
        return ans;
    }
    inline int get_inv(const int a)
    {
        return power(a, p - 2);
    }
    namespace Polynomial
    {
        const int LEN = N << 2;
        int rev[LEN], omega[LEN], winv[LEN];
        void init(const int n, const int lg2)
        {
            int w = power(g, (p - 1) / n), wi = get_inv(w);
            omega[0] = winv[0] = 1;
            for (int i = 1; i < n; i++)
            {
                omega[i] = (ll)omega[i - 1] * w % p;
                winv[i] = (ll)winv[i - 1] * wi % p;
            }
            for (int i = 0; i < n; i++)
                rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (lg2 - 1)));
        }
        void ntt(int *a, const int *w, const int n)
        {
            for (int i = 0; i < n; i++)
                if (i < rev[i])
                    swap(a[i], a[rev[i]]);
            for (int l = 1; l < n; l <<= 1)
                for (int i = 0; i < n; i += (l << 1))
                    for (int k = 0; k < l; k++)
                    {
                        int x = a[i + k], y = (ll)a[i + l + k] * w[n / (l << 1) * k] % p;
                        a[i + k] = (x + y) % p;
                        a[i + l + k] = (x - y + p) % p;
                    }
        }
        void mul(const int *a, const int *b, int *c, const int n)
        {
            static int x[LEN], y[LEN];
            int m = 1, lg2 = 0;
            while (m < (n << 1) - 1)
                m <<= 1, ++lg2;
            memcpy(x, a, sizeof(int[n]));
            memset(x + n, 0, sizeof(int[m - n]));
            memcpy(y, b, sizeof(int[n]));
            memset(y + n, 0, sizeof(int[m - n]));
            init(m, lg2);
            ntt(x, omega, m), ntt(y, omega, m);
            for (int i = 0; i < m; i++)
                x[i] = (ll)x[i] * y[i] % p;
            ntt(x, winv, m);
            int invm = get_inv(m);
            for (int i = 0; i < n; i++)
                c[i] = (ll)x[i] * invm % p;
        }
        void _inv(const int *a, int *b, const int n)
        {
            if (n == 1)
                b[0] = get_inv(a[0]);
            else
            {
                static int tmp[LEN];
                int m = 1, lg2 = 0;
                _inv(a, b, (n + 1) >> 1);
                while (m < (n << 1) - 1)
                    m <<= 1, ++lg2;
                init(m, lg2);
                memcpy(tmp, a, sizeof(int[n]));
                memset(tmp + n, 0, sizeof(int[m - n]));
                memset(b + ((n + 1) >> 1), 0, sizeof(int[m - ((n + 1) >> 1)]));
                ntt(tmp, omega, m), ntt(b, omega, m);
                for (int i = 0; i < m; i++)
                    b[i] = (b[i] * 2LL % p - (ll)tmp[i] * b[i] % p * b[i] % p + p) % p;
                ntt(b, winv, m);
                int invm = get_inv(m);
                for (int i = 0; i < n; i++)
                    b[i] = (ll)b[i] * invm % p;
                memset(b + n, 0, sizeof(int[m - n]));
            }
        }
        void inv(const int *a, int *b, const int n)
        {
            static int tmp[LEN];
            memcpy(tmp, a, sizeof(int[n]));
            _inv(tmp, b, n);
        }
        void _sqrt(const int *a, int *b, const int n)
        {
            if (n == 1)
                b[0] = 1;
            else
            {
                static int tmp1[LEN], tmp2[LEN];
                _sqrt(a, b, (n + 1) >> 1);
                memset(b + ((n + 1) >> 1), 0, sizeof(int[n - ((n + 1) >> 1)]));
                mul(b, b, tmp1, n);
                for (int i = 0; i < n; i++)
                    b[i] = b[i] * 2LL % p;
                inv(b, tmp2, n);
                for (int i = 0; i < n; i++)
                    tmp1[i] = (tmp1[i] + a[i]) % p;
                mul(tmp1, tmp2, b, n);
            }
        }
        void sqrt(const int *a, int *b, const int n)
        {
            static int tmp[LEN];
            memcpy(tmp, a, sizeof(int[n]));
            _sqrt(tmp, b, n);
        }
    }
    int n, a[N << 2];
    int work()
    {
        read(n);
        for (int i = 0; i < n; i++)
            read(a[i]);
        Polynomial::sqrt(a, a, n);
        for (int i = 0; i < n; i++)
            write(a[i]), putchar(' ');
        return 0;
    }
}
int main()
{
    return zyt::work();
}

原文地址:https://www.cnblogs.com/zyt1253679098/p/10657548.html

时间: 2025-01-08 20:57:32

【知识总结】多项式全家桶(四)(快速幂和开根)的相关文章

【知识总结】多项式全家桶(一)(NTT、加减乘除和求逆)

我这种数学一窍不通的菜鸡终于开始学多项式全家桶了-- 必须要会的前置技能:FFT(不会?戳我:[知识总结]快速傅里叶变换(FFT)) 一.NTT 跟FFT功能差不多,只是把复数域变成了模域(计算复数系数多项式相乘变成计算在模意义下整数系数多项式相乘).你看FFT里的单位圆是循环的,模一个质数也是循环的嘛qwq.\(n\)次单位根\(w_n\)怎么搞?看这里:[BZOJ3328]PYXFIB(数学)(内含相关证明.只看与原根和单位根相关的内容即可.) 注意裸的NTT要求模数\(p\)存在原根并且\

[模板] 多项式全家桶

注意:以下所有说明均以帮助理解模板为目的,不保证正确性. 多项式求逆 已知$A(x)$,求满足$A(x)B(x)=1\ (mod\ x^n)$的B(以下为了方便假设n是2的幂) 考虑倍增,假设已经求出$A(x)B_0(x)=1\ (mod\ x^{n/2})$ $$A(x)(B(x)-B_0(x))=0\ (mod\ x^{n/2})$$ $$(B(x)-B_0(x))=0\ (mod\ x^{n/2})$$ $$(B(x)-B_0(x))^2=0\ (mod\ x^n)$$ $$B^2(x)-

[算法学习] 多项式全家桶

多项式 一个\(n\)次多项式可以表示为\(A(x)=\sum_{i=0}^{n}a_i x^i\),另一个\(n\)次多项式可以表示为\(B(x)=\sum_{i=0}^{n}b_i x^i\). 多项式加法 将\(A(x)\)和\(B(x)\)相加,得到多项式\(C(x)=\sum_{i=0}^{n} (a_i+b_i) x^i\). 复杂度是\(O(n)\)的. 多项式乘法 将\(A(x)\)和\(B(x)\)相乘,得到多项式\(C(x)=\sum_{i=0}^{n}\sum_{j=0}^

多项式全家桶

Include 多项式乘法 多项式求逆 多项式除法 多项式取模 多项式对数函数 多项式指数函数 多项式正弦函数 多项式余弦函数 #include<bits/stdc++.h> #define reg register int #define il inline #define fi first #define se second #define mk(a,b) make_pair(a,b) #define numb (ch^'0') using namespace std; typedef l

Uva10870 Recurrences(矩阵快速幂)

题目 考虑递推关系式\(f(n)=a1*f(n-1)+a2*f(n-2)+....+ad*f(n-d)\),计算f(n)%m [输入格式] 输入包含多组测试数据.每组数据第一行为三个整数d,n,m(1<=d<=15,1<=n<=2^31-1,1<=m<=46340).第二行包含d个非负整数a1,a2.....ad.第三行为d个非负整数f(1),f(2).....f(d).这些数字均不超过2^31-1.输入结束的标志是d=n=m=0. [输出格式] 对于每组数据,输出f(

【BZOJ3992】[SDOI2015]序列统计 NTT+多项式快速幂

[BZOJ3992][SDOI2015]序列统计 Description 小C有一个集合S,里面的元素都是小于M的非负整数.他用程序编写了一个数列生成器,可以生成一个长度为N的数列,数列中的每个数都属于集合S. 小C用这个生成器生成了许多这样的数列.但是小C有一个问题需要你的帮助:给定整数x,求所有可以生成出的,且满足数列中所有数的乘积mod M的值等于x的不同的数列的有多少个.小C认为,两个数列{Ai}和{Bi}不同,当且仅当至少存在一个整数i,满足Ai≠Bi.另外,小C认为这个问题的答案可能

矩阵快速幂基础知识

一. 先介绍以下矩阵的基础知识 矩阵:有 n 行 m 列组成一个 n*m 的矩阵 1. 矩阵的加减运算满足的条件:两个矩阵的行.列 必须相同 2. 矩阵的乘运算 满足的条件:  A矩阵的列数为 B矩阵的行数 A(ms)*B(sn)=C(mn) 得到的矩阵 C 是 m 行 n 列的 其中 c[i][j] 为A 的第 i 行与B的第j 列对应乘积的和 即:  代码: 1 const int N=100; 2 int c[N][N]; 3 void multi(int a[m][s],int b[s]

多项式求ln,求exp,开方,快速幂 学习总结

按理说Po姐姐三月份来讲课的时候我就应该学了 但是当时觉得比较难加上自己比较懒,所以就QAQ了 现在不得不重新弄一遍了 首先说多项式求ln 设G(x)=lnF(x) 我们两边求导可以得到G'(x)=F‘(x)/F(x) 则G(x)就是F’(x)/F(x)的积分 我们知道多项式求导和积分是O(n)的,多项式求逆是O(nlogn)的 所以总时间复杂度O(nlogn) 多项式求ln一般解决的问题是这样的 设多项式f表示一些奇怪的东西,由一些奇怪的东西有序组成的方案为 f^1+f^2+f^3…… 化简之

luoguP5219 无聊的水题 I 多项式快速幂

有一个幼儿园容斥:最大次数恰好为 $m=$  最大次数最多为 $m$ - 最大次数最多为 $m-1$. 然后来一个多项式快速幂就好了. code: #include <cmath> #include <cstring> #include <algorithm> #include <cstdio> #include <string> #define ll long long #define ull unsigned long long using