FFT/NTT/FWT: relearning or learning anew?

Before I knew it, I had started learning about polynomials. Time flies.

Here I pick up what I learned before (\(FFT\)) and learn something new (\(NTT\)/\(FWT\)).

Reference blog for learning:

Reference blog for FWT:

FFT: Fast Fourier transform

In short, it quickly computes the product of two polynomials.

Here is the template, for Luogu P3803 (polynomial multiplication).

code
#include<bits/stdc++.h>
using namespace std;
// Loop helpers: fo runs i from x up to y inclusive, fu runs i from x down to y inclusive.
#define fo(i,x,y) for(int i=(x);i<=(y);i++)
#define fu(i,x,y) for(int i=(x);i>=(y);i--)
// pi = acos(-1); used for the unit-root angles 2*i*pi/lim in main.
const double pi=acos(-1.0);
// Capacity of the global arrays (2^21); the transform length lim chosen in main must not exceed it.
const int N=1<<21;
// Minimal complex number used by the FFT: r is the real part, i the imaginary part.
// The operators are const-qualified and take their operand by const reference so they
// also work on const values; the default constructor zero-initializes (and, being
// constexpr, keeps the large global arrays below statically zero-initialized).
struct coex{
    double r,i;
    constexpr coex():r(0),i(0){}
    constexpr coex(double x,double y):r(x),i(y){}
    coex operator + (const coex &a)const{return coex(r+a.r,i+a.i);}
    coex operator - (const coex &a)const{return coex(r-a.r,i-a.i);}
    // (r + i*I) * (a.r + a.i*I) = (r*a.r - i*a.i) + (r*a.i + i*a.r)*I
    coex operator * (const coex &a)const{return coex(r*a.r-i*a.i,r*a.i+i*a.r);}
}a[N],b[N],w[N];
int af[N];   // af[i]: bit-reversal of i over len bits (filled by main)
int la,lb;   // degrees of the two input polynomials
int lim,len; // transform length (a power of two) and log2(lim)

// In-place iterative Cooley-Tukey FFT of length lim (a power of two).
// Uses the bit-reversal table af[] and the unit-root table w[] filled by main;
// main conjugates w[] before the last call, which turns this into the
// unscaled inverse transform.
void fft(coex *a,int lim){
    // Permute into bit-reversed order, swapping each pair only once.
    for(int k=0;k<lim;k++)
        if(k<af[k])swap(a[k],a[af[k]]);
    // half: length of the sub-transforms being merged; step: stride into w[] (half*step == lim/2).
    int step=lim>>1;
    for(int half=1;half<lim;half<<=1){
        for(int base=0;base<lim;base+=half<<1){
            for(int k=0;k<half;k++){
                coex &lo=a[base+k];
                coex &hi=a[base+k+half];
                coex t=w[step*k]*hi;
                hi=lo-t;
                lo=lo+t;
            }
        }
        step>>=1;
    }
}
// Reads the degrees la, lb, then the la+1 and lb+1 coefficients (index = power of x)
// of two polynomials, multiplies them with FFT and prints the la+lb+1 product
// coefficients rounded to the nearest integer.
signed main(){
    if(scanf("%d%d",&la,&lb)!=2)return 0;
    fo(i,0,la)scanf("%lf",&a[i].r);
    fo(i,0,lb)scanf("%lf",&b[i].r);
    // lim: smallest power of two strictly greater than la+lb (the product degree); len = log2(lim).
    for(lim=1,len=0;lim<=la+lb;lim<<=1,len++);
    fo(i,0,lim-1){
        // Bit reversal over len bits. For la=lb=0 we get lim=1, len=0, where the shift
        // count len-1 would be negative (undefined behaviour), so skip the shift then.
        af[i]=(af[i>>1]>>1)|(len?((i&1)<<(len-1)):0);
        // w[i] = e^{2*pi*I*i/lim}: the i-th power of the primitive lim-th unit root.
        w[i]=coex(cos(2.0*i*pi/lim),sin(2.0*i*pi/lim));
    }
    fft(a,lim);fft(b,lim);
    // Pointwise product; conjugating w[] makes the next fft an (unscaled) inverse transform.
    fo(i,0,lim-1)a[i]=a[i]*b[i],w[i].i=-w[i].i;
    fft(a,lim);
    // Divide by lim and round to nearest; lround also rounds negative coefficients correctly,
    // unlike truncating x+0.5.
    fo(i,0,la+lb)printf("%d ",(int)lround(a[i].r/lim));
}

The following points should be noted:

1. \(lim\) must be strictly greater than the degree of the product polynomial, i.e. \(lim > la + lb\); equality is not enough.

2. Get the relationship between \(d\) and \(t\) right (at every level \(d \cdot t = lim/2\)), as well as all the factors of 2.

3. Memorize the template — once it is memorized, you are done.

4. Don't forget to divide the final answer by \(lim\). Also, when precomputing the roots of unity the angle is \(2i\pi/lim\) — don't forget the division by \(lim\) there either.

NTT: Number Theoretic Transform

code
// In-place iterative NTT of length lim (a power of two) modulo `mod`.
// af[] must hold the bit-reversal permutation for lim, and g[k] the k-th power of the
// root of unity in use (forward root, or its inverse for the inverse transform).
// The butterfly product is formed in 64-bit arithmetic (1LL*), so it cannot overflow
// even when `int` is a plain 32-bit int (i.e. without `#define int long long`).
void ntt(int *a,int lim){
    fo(i,0,lim-1)if(af[i]>i)swap(a[i],a[af[i]]);
    for(int t=lim>>1,d=1;d<lim;d<<=1,t>>=1)
        for(int i=0;i<lim;i+=(d<<1))
            fo(j,0,d-1){
                int tmp=1LL*g[t*j]*a[i+j+d]%mod;
                // Both sums stay below 2*mod, which fits in a 32-bit int for mod < 2^30.
                a[i+j+d]=(a[i+j]-tmp+mod)%mod;
                a[i+j]=(a[i+j]+tmp)%mod;
            }
}
int bas[M],ans[M];
int a[M],b[M];
// Multiplies polynomials y and z with NTT and writes into x the first m-1 coefficients
// of the product after folding coefficient i+m-1 onto coefficient i.
// NOTE(review): relies on lim, af[], m, M, mod, g[] and a 3-argument ksm being set up
// elsewhere (lim a power of two large enough for the product) — confirm in the full program.
// All products use 1LL* so they do not overflow when `int` is a plain 32-bit int.
void mul(int *x,int *y,int *z){
    fo(i,0,lim-1)a[i]=y[i];
    fo(i,0,lim-1)b[i]=z[i];
    // Forward roots: g[k] = w^k with w = 3^((mod-1)/lim) (3 assumed to be a primitive root of mod).
    g[0]=1;g[1]=ksm(3,(mod-1)/lim,mod);
    fo(i,2,lim-1)g[i]=1LL*g[i-1]*g[1]%mod;
    ntt(a,lim);ntt(b,lim);
    // Inverse roots for the inverse transform: w^{-1} = w^(mod-2).
    g[0]=1;g[1]=ksm(g[1],mod-2,mod);
    fo(i,2,lim-1)g[i]=1LL*g[i-1]*g[1]%mod;
    fo(i,0,lim-1)a[i]=1LL*a[i]*b[i]%mod;
    ntt(a,lim);int inv=ksm(lim,mod-2,mod);
    // Undo the factor lim introduced by the inverse transform.
    fo(i,0,lim-1)a[i]=1LL*a[i]*inv%mod;
    // Single cyclic fold of indices m-1..2m-3 onto 0..m-2.
    fo(i,0,m-2)a[i]=(a[i]+a[i+m-1])%mod;
    fo(i,0,m-2)x[i]=a[i];
}

Be careful:

1. Pay attention to the modulus, and find its primitive root yourself.

2. After the final (inverse) transform, remember to divide by \(lim\), i.e. multiply by its modular inverse.

3. Don't forget to switch to the inverse of the primitive root for the inverse transform.

Polynomial inversion

That is, given a polynomial \(A\), find a polynomial \(B\) such that \(AB \equiv 1 \pmod{x^n}\).

The main difficulty is that this template is hard to memorize.

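Recursive derivation: suppose we already have \(B'\) with \(AB' \equiv 1 \pmod{x^{\lceil n/2\rceil}}\). Since \(AB \equiv 1\) as well, and inverses are unique, \(B - B' \equiv 0 \pmod{x^{\lceil n/2\rceil}}\). Squaring gives \((B-B')^2 \equiv 0 \pmod{x^n}\), i.e. \(B^2 - 2BB' + B'^2 \equiv 0\). Multiplying by \(A\) and using \(AB \equiv 1\) yields \(B - 2B' + AB'^2 \equiv 0 \pmod{x^n}\).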

So the conclusion is \(B \equiv 2B' - AB'^2 \pmod{x^n}\), where \(B'\) denotes the inverse of \(A\) modulo \(x^{\lceil n/2 \rceil}\).

code
#include<bits/stdc++.h>
using namespace std;
// Every `int` below (including parameters and return types) is really `long long`,
// so products of two residues below mod, as in ntt(), fit without overflow.
// This is also why main is declared as `signed main()`.
#define int long long
// Loop helpers: fo runs i from x up to y inclusive, fu runs i from x down to y inclusive.
#define fo(i,x,y) for(int i=(x);i<=(y);i++)
#define fu(i,x,y) for(int i=(x);i>=(y);i--)
// Reads one decimal integer from stdin, skipping any non-digit characters before it;
// a '-' among the skipped characters makes the result negative.
// Stops at EOF instead of looping forever, returning 0 if no digits were found.
// ch is an int (not char) so EOF stays distinguishable and isdigit never sees a
// negative char value.
int read(){
    int s=0,t=1;int ch=getchar();
    while(ch!=EOF&&!isdigit(ch)){if(ch=='-')t=-1;ch=getchar();}
    while(isdigit(ch)){s=s*10+ch-'0';ch=getchar();}
    return s*t;
}
const int mod=998244353;
const int N=1<<18;
// Binary exponentiation: returns x^y modulo `mod`, squaring x once per bit of y.
int ksm(int x,int y){
    int res=1;
    for(;y;y>>=1){
        if(y&1)res=res*x%mod;
        x=x*x%mod;
    }
    return res;
}
int af[N],g[N],lim,len;
void ntt(int *a,int lim){
    fo(i,0,lim-1)if(af[i]>i)swap(a[i],a[af[i]]);
    for(int t=lim>>1,d=1;d<lim;d<<=1,t>>=1)
        for(int i=0;i<lim;i+=(d<<1))
            fo(j,0,d-1){
                int tmp=g[t*j]*a[i+j+d]%mod;
                a[i+j+d]=(a[i+j]-tmp+mod)%mod;
                a[i+j]=(a[i+j]+tmp)%mod;
            }
}
// a: input polynomial; b: its inverse (result of calc); c: scratch buffer for the NTT.
int a[N],b[N],c[N];
// Computes b = a^{-1} mod x^deg by Newton iteration: first the inverse b' modulo
// x^ceil(deg/2) recursively, then b = 2b' - a*b'^2 (mod x^deg).
// Relies on b[] and c[] being zero beyond the coefficients in use on entry: true for the
// global arrays initially, and maintained by the clearing loops at the end.
// NOTE(review): deg must be >= 1 (calc(0) calls calc(0) forever), and a[0] must be
// invertible modulo `mod`.
void calc(int deg){//deg: the current modulus is x^deg
    if(deg==1){return b[0]=ksm(a[0],mod-2),void();} // base case: b[0] = a[0]^(mod-2), the inverse of a[0] (mod is prime)
    calc((deg+1)>>1);
    // lim >= 2*deg, so c*b*b (degree <= 2*deg-2) does not wrap around in the cyclic transform.
    for(lim=1,len=0;lim<deg*2;lim<<=1,len++);
    fo(i,0,lim-1)af[i]=(af[i>>1]>>1)|((i&1)<<(len-1));
    // Forward roots: g[k] = (3^((mod-1)/lim))^k.
    g[0]=1;g[1]=ksm(3,(mod-1)/lim);
    fo(i,2,lim-1)g[i]=g[i-1]*g[1]%mod;
    fo(i,0,deg-1)c[i]=a[i];//copy only coefficients 0..deg-1 of a: higher ones do not matter modulo x^deg
    //c now holds deg coefficients, while b holds only (deg+1)>>1 (the inverse from the recursive call); b is squared in the formula
    ntt(c,lim);ntt(b,lim);
    // Switch g[] to powers of the inverse root for the inverse transform.
    g[0]=1;g[1]=ksm(g[1],mod-2);
    fo(i,2,lim-1)g[i]=g[i-1]*g[1]%mod;
    // Pointwise b = (2 - c*b) * b, i.e. 2b' - a*b'^2 in point-value form.
    fo(i,0,lim-1)b[i]=(2-c[i]*b[i]%mod+mod)*b[i]%mod;
    ntt(b,lim);int iv=ksm(lim,mod-2);
    fo(i,0,lim-1)b[i]=b[i]*iv%mod; // divide by lim to finish the inverse transform
    fo(i,deg,lim-1)b[i]=0;//clear coefficients >= deg so they do not affect the next (larger) level
    fo(i,0,lim-1)c[i]=0;//clear the scratch buffer for the next level
}
int n;
// Reads n and the n coefficients of a, then prints the n coefficients of a^{-1} mod x^n.
signed main(){
    n=read();
    // calc(deg) only terminates for deg >= 1 (calc(0) would recurse on itself forever),
    // so there is nothing to compute for n <= 0.
    if(n<=0)return 0;
    fo(i,0,n-1)a[i]=read();
    calc(n);
    fo(i,0,n-1)printf("%lld ",b[i]);
}

FWT: fast Walsh transform

Notice that the convolutions we have computed so far have the form \(C_i = \sum_{j=0}^{i} A_j B_{i-j}\), i.e. the indices of the two factors add up to \(i\).

That is the additive convolution handled by \(FFT\) and \(NTT\). What if the indices are instead combined with AND, OR, or XOR, i.e. \(C_k = \sum_{i \circ j = k} A_i B_j\) where \(\circ\) is a bitwise operation?

Then we need \(FWT\), i.e. fast bitwise convolution.

|: OR convolution

Computing it directly looks hard. However, the high-dimensional prefix sum (the subset sum \(\hat{A}_S = \sum_{T \subseteq S} A_T\)) turns OR convolution into pointwise multiplication, \(\hat{C}_S = \hat{A}_S \hat{B}_S\). So we can use the high-dimensional prefix sum to solve OR convolution.

However, the high-dimensional prefix sum is not directly supported by \(FWT\).

Like \(FFT\), \(FWT\) has a forward transform and an inverse transform; see the blog linked at the top for details.

The transformed sequence cannot be called a point-value representation; it is better called a subset-sum representation.

It is equivalent to enumerating each binary bit and accumulating subset sums. The inverse transform undoes this the same way, bit by bit, subtracting instead of adding.

code
// OR-convolution transform (subset sums, i.e. a high-dimensional prefix sum over the bits).
// tp = 1: forward transform (add a[i] into a[i | bit] for every i without that bit).
// tp = mod-1 (i.e. -1 modulo mod): inverse transform (subtract instead of add).
// Note: tp = 2 does NOT invert the transform.
// The product uses 1LL* so it cannot overflow when `int` is a plain 32-bit int.
void ort(int *a,int lim,int tp){
    for(int d=1;d<lim;d<<=1)
        for(int i=0;i<lim;i+=(d<<1))
            fo(j,0,d-1){
                a[i+j+d]=(a[i+j+d]+1LL*a[i+j]*tp)%mod;
            }
}

&: AND convolution

This corresponds to the high-dimensional suffix sum (superset sum), which we can use to solve AND convolution.

The transform is a suffix sum; the inverse transform is the same procedure with subtraction instead of addition.

code
// AND-convolution transform (superset sums, i.e. a high-dimensional suffix sum over the bits).
// tp = 1: forward transform (add a[i | bit] into a[i] for every i without that bit).
// tp = mod-1 (i.e. -1 modulo mod): inverse transform. Note: tp = 2 does NOT invert it.
// The product uses 1LL* so it cannot overflow when `int` is a plain 32-bit int.
void andt(int *a,int lim,int tp){
    for(int d=1;d<lim;d<<=1)
        for(int i=0;i<lim;i+=(d<<1))
            fo(j,0,d-1){
                a[i+j]=(a[i+j]+1LL*a[i+j+d]*tp)%mod;
            }
}

^: XOR convolution

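This could loosely be called a high-dimensional sum-and-difference transform. For each bit, every pair \((x, y)\) of entries differing only in that bit becomes \((x + y, x - y)\).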

It is a bit hard to summarize in words, so just look at the code.

code
// XOR-convolution transform (Walsh-Hadamard): for each bit, every pair (x, y) of entries
// differing only in that bit becomes ((x+y)*tp, (x-y)*tp) modulo mod.
// tp = 1: forward transform.
// tp = modular inverse of 2 (e.g. ksm(2,mod-2)): inverse transform. It is applied once per
// level, giving the overall 1/lim factor. tp is an integer, so 0.5 would truncate to 0.
// Products use 1LL* so they cannot overflow when `int` is a plain 32-bit int.
void xort(int *a,int lim,int tp){
    for(int d=1;d<lim;d<<=1)
        for(int i=0;i<lim;i+=(d<<1))
            fo(j,0,d-1){
                int tmp=a[i+j+d];
                a[i+j+d]=1LL*(a[i+j]-tmp+mod)*tp%mod;
                a[i+j]=1LL*(a[i+j]+tmp)*tp%mod;
            }
}

How to compute a high-dimensional prefix sum: enumerate the dimensions one at a time and do an ordinary prefix sum along each.

With bits, that means enumerating each binary bit and adding the value at every index where that bit is \(0\) to the corresponding index where the bit is \(1\).

Although this high-dimensional prefix sum is simple, it is provably correct: every subset is counted exactly once, with no duplicates and no omissions.

Luogu P4717

code
#include<bits/stdc++.h>
using namespace std;
// Every `int` below is really `long long` (hence `signed main`), so products of two
// residues below mod, e.g. a[i]*b[i]%mod in mix(), fit without overflow.
#define int long long
// Loop helpers: fo runs i from x up to y inclusive, fu runs i from x down to y inclusive.
#define fo(i,x,y) for(int i=(x);i<=(y);i++)
#define fu(i,x,y) for(int i=(x);i>=(y);i--)
// Reads one decimal integer from stdin, skipping any non-digit characters before it;
// a '-' among the skipped characters makes the result negative.
// Stops at EOF instead of looping forever, returning 0 if no digits were found.
// ch is an int (not char) so EOF stays distinguishable and isdigit never sees a
// negative char value.
int read(){
    int s=0,t=1;int ch=getchar();
    while(ch!=EOF&&!isdigit(ch)){if(ch=='-')t=-1;ch=getchar();}
    while(isdigit(ch)){s=s*10+ch-'0';ch=getchar();}
    return s*t;
}
const int mod=998244353;
const int N=1<<18;
// Modular power by repeated squaring: x^y modulo `mod` (y is expected to be non-negative).
int ksm(int x,int y){
    int acc=1;
    while(y!=0){
        if(y%2!=0)acc=acc*x%mod;
        x=x*x%mod;
        y>>=1;
    }
    return acc;
}
int n,lim,a[N],b[N],c[N],aa[N],bb[N];
void init(int lim){fo(i,0,lim-1)a[i]=aa[i],b[i]=bb[i];}
void ort(int *a,int lim,int tp){
    for(int d=1;d<lim;d<<=1)
        for(int i=0;i<lim;i+=(d<<1))
            fo(j,0,d-1){
                a[i+j+d]=(a[i+j+d]+a[i+j]*tp)%mod;
            }
}
void andt(int *a,int lim,int tp){
    for(int d=1;d<lim;d<<=1)
        for(int i=0;i<lim;i+=(d<<1))
            fo(j,0,d-1){
                a[i+j]=(a[i+j]+a[i+j+d]*tp)%mod;
            }
}
void xort(int *a,int lim,int tp){
    for(int d=1;d<lim;d<<=1)
        for(int i=0;i<lim;i+=(d<<1))
            fo(j,0,d-1){
                int tmp=a[i+j+d];
                a[i+j+d]=(a[i+j]-tmp+mod)*tp%mod;
                a[i+j]=(a[i+j]+tmp)*tp%mod;
            }
}
void mix(int lim){fo(i,0,lim-1)c[i]=a[i]*b[i]%mod;}
void pt(int *a,int lim){fo(i,0,lim-1)printf("%lld ",a[i]);printf("\n");}
// Reads n and two sequences of length 2^n, then prints their OR, AND and XOR
// convolutions modulo mod, one line each, in that order.
signed main(){
    n=read();
    lim=1<<n;
    fo(i,0,lim-1)aa[i]=read();
    fo(i,0,lim-1)bb[i]=read();
    // Transform for each convolution, and the tp that inverts it
    // (-1 mod p for OR/AND, 1/2 mod p applied once per level for XOR).
    void (*const transform[3])(int*,int,int)={ort,andt,xort};
    const int inverse_tp[3]={mod-1,mod-1,ksm(2,mod-2)};
    for(int k=0;k<3;k++){
        init(lim);
        transform[k](a,lim,1);
        transform[k](b,lim,1);
        mix(lim);
        transform[k](c,lim,inverse_tp[k]);
        pt(c,lim);
    }
}

Tips

1. Symmetry in generating functions: if two positions are symmetric about a center \(c\), their indices sum to \(2c\). So after polynomial convolution, all pairs symmetric about the same axis contribute to the same coefficient (the coefficient of \(x^{2c}\)).

2. When using generating functions, we only need the coefficients; we never actually substitute a value for \(x\).

For example, in \(\sum_{i=0}^{n}\sum_{j=0}^{i} j! \cdot 2^{i-j}\), the factorials and the powers of two can be regarded as the coefficients of two separate polynomials, \(\sum_j j!\,x^j\) and \(\sum_k 2^k x^k\).

Then, by the rule of polynomial multiplication, the inner sum for each \(i\) is exactly the coefficient of \(x^i\) in the product of these two polynomials. So one convolution gives all the inner sums at once, and we just add them up.

Added by SiriusB on Mon, 03 Jan 2022 08:32:13 +0200