Before I knew it, I had started learning polynomial algorithms — time really flies.
Picking up what I learned before (\(FFT\)) and learning something new (\(NTT\), \(FWT\)).
Recommended learning blog:
FFT: Fast Fourier transform
In short, it computes the product of two polynomials quickly, in \(O(n\log n)\) time.
Here is my template for Luogu P3803 (polynomial multiplication):
#include<bits/stdc++.h>
using namespace std;
#define fo(i,x,y) for(int i=(x);i<=(y);i++)
#define fu(i,x,y) for(int i=(x);i>=(y);i--)
// Luogu P3803 template: multiply two polynomials with a complex-number FFT.
// Input: degrees la and lb, then la+1 and lb+1 coefficients (constant term first).
// Output: the la+lb+1 coefficients of the product, rounded to the nearest integer.
const double pi=acos(-1.0);
const int N=1<<21;  // assumes la+lb < N (P3803: n,m <= 1e6) — TODO confirm for other uses

// Minimal complex number: r is the real part, i the imaginary part.
struct coex{
    double r,i;
    coex():r(0),i(0){}
    coex(double x,double y):r(x),i(y){}
    coex operator+(const coex &o)const{return coex(r+o.r,i+o.i);}
    coex operator-(const coex &o)const{return coex(r-o.r,i-o.i);}
    coex operator*(const coex &o)const{return coex(r*o.r-i*o.i,r*o.i+i*o.r);}
}a[N],b[N],w[N];
// af[i]: i with its low `len` bits reversed; w[k]: k-th power of the root of unity used by fft().
int af[N],la,lb,lim,len;

// In-place iterative FFT of length lim (a power of two), using the global af and w tables.
// With conjugated roots in w it computes the inverse transform, still unscaled by 1/lim.
void fft(coex *a,int lim){
    fo(i,0,lim-1)if(af[i]>i)swap(a[i],a[af[i]]);
    // d: half-length of the blocks being merged; t = lim/(2d), so w[t*j] is the j-th power
    // of the primitive 2d-th root of unity.
    for(int t=lim>>1,d=1;d<lim;d<<=1,t>>=1)
        for(int i=0;i<lim;i+=(d<<1))
            fo(j,0,d-1){
                coex tmp=w[t*j]*a[i+j+d];
                a[i+j+d]=a[i+j]-tmp;
                a[i+j]=a[i+j]+tmp;
            }
}

signed main(){
    scanf("%d%d",&la,&lb);
    fo(i,0,la)scanf("%lf",&a[i].r);
    fo(i,0,lb)scanf("%lf",&b[i].r);
    // lim must be strictly greater than la+lb, otherwise the cyclic convolution wraps around.
    for(lim=1,len=0;lim<=la+lb;lim<<=1,len++);
    fo(i,0,lim-1){
        // When la+lb==0 we get lim==1 and len==0: shifting by len-1 == -1 would be undefined.
        af[i]=(af[i>>1]>>1)|(len?((i&1)<<(len-1)):0);
        w[i]=coex(cos(2.0*i*pi/lim),sin(2.0*i*pi/lim));
    }
    fft(a,lim);fft(b,lim);
    // Pointwise product; conjugating the roots turns the next fft() into the inverse transform.
    fo(i,0,lim-1)a[i]=a[i]*b[i],w[i].i=-w[i].i;
    fft(a,lim);
    // Scale by 1/lim and round to nearest. llround also rounds negative values correctly,
    // whereas (int)(x+0.5) truncates toward zero and is off by one for e.g. x = -2.9999.
    fo(i,0,la+lb)printf("%lld ",llround(a[i].r/lim));
    return 0;
}
The following points should be noted:
1. \(lim\) must be strictly greater than the degree of the product (\(la+lb\)); being equal to it is not enough.
2. Understand the relationship between \(d\) and \(t\) (\(t\cdot 2d=lim\)) and where the various factors of 2 come from.
3. Memorize the template — once it is memorized, you are done.
4. Don't forget to divide the final answer by \(lim\). When precomputing the roots of unity the angle is \(2\pi i/lim\) — don't forget the division by \(lim\) there either.
NTT: Number Theoretic Transform
codevoid ntt(int *a,int lim){ fo(i,0,lim-1)if(af[i]>i)swap(a[i],a[af[i]]); for(int t=lim>>1,d=1;d<lim;d<<=1,t>>=1) for(int i=0;i<lim;i+=(d<<1)) fo(j,0,d-1){ int tmp=g[t*j]*a[i+j+d]%mod; a[i+j+d]=(a[i+j]-tmp+mod)%mod; a[i+j]=(a[i+j]+tmp)%mod; } } int bas[M],ans[M]; int a[M],b[M]; void mul(int *x,int *y,int *z){ fo(i,0,lim-1)a[i]=y[i]; fo(i,0,lim-1)b[i]=z[i]; g[0]=1;g[1]=ksm(3,(mod-1)/lim,mod); fo(i,2,lim-1)g[i]=g[i-1]*g[1]%mod; ntt(a,lim);ntt(b,lim); g[0]=1;g[1]=ksm(g[1],mod-2,mod); fo(i,2,lim-1)g[i]=g[i-1]*g[1]%mod; fo(i,0,lim-1)a[i]=a[i]*b[i]%mod; ntt(a,lim);int inv=ksm(lim,mod-2,mod); fo(i,0,lim-1)a[i]=a[i]*inv%mod; fo(i,0,m-2)a[i]=(a[i]+a[i+m-1])%mod; fo(i,0,m-2)x[i]=a[i]; }
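Notes:
1. Pay attention to the modulus, and find its primitive root yourself (for \(998244353\) it is \(3\)).
2. Remember that the result of the final inverse transform must be divided by \(lim\), i.e. multiplied by \(lim^{-1}\) modulo the prime.
3. Don't forget to switch to the inverse of the primitive root before the inverse transform.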
Polynomial inversion
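That is, find the inverse of a polynomial \(A\) modulo \(x^n\): a polynomial \(B\) with \(AB\equiv1\pmod{x^n}\).
The main difficulty is that this template is hard to memorize.
Recursive proof: let \(B'\) be the inverse of \(A\) modulo \(x^{\lceil\frac{n}{2}\rceil}\). Then \(AB'-1\equiv0\pmod{x^{\lceil\frac{n}{2}\rceil}}\), so squaring gives \((AB'-1)^2\equiv0\pmod{x^n}\), since \(2\lceil\frac{n}{2}\rceil\ge n\). Expanding gives \(A^2B'^2-2AB'+1\equiv0\). Multiplying by \(B\) and using \(AB\equiv1\pmod{x^n}\) gives \(AB'^2-2B'+B\equiv0\pmod{x^n}\).
So the conclusion is \(B\equiv2B'-AB'^2\pmod{x^n}\), where \(B'\) denotes the inverse of \(A\) modulo \(x^{\lceil\frac{n}{2}\rceil}\).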
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define fo(i,x,y) for(int i=(x);i<=(y);i++)
#define fu(i,x,y) for(int i=(x);i>=(y);i--)
// Polynomial inversion template: reads n and a[0..n-1], prints b[0..n-1] with
// a*b == 1 (mod x^n), all coefficients modulo 998244353.

// Reads one signed integer from stdin, skipping any non-digit characters before it.
int read(){
    int s=0,t=1,ch=getchar();
    // Stop at EOF: with a `char ch` and no EOF check, truncated input looped forever.
    while(ch!=EOF&&!isdigit(ch)){if(ch=='-')t=-1;ch=getchar();}
    while(isdigit(ch)){s=s*10+ch-'0';ch=getchar();}
    return s*t;
}
const int mod=998244353;  // = 119*2^23+1; 3 is a primitive root
const int N=1<<18;        // assumes 2n <= N — TODO confirm against the input bounds

// Fast modular exponentiation: returns x^y mod `mod` (x expected in [0,mod)).
int ksm(int x,int y){
    int ret=1;
    while(y){
        if(y&1)ret=ret*x%mod;
        x=x*x%mod;y>>=1;
    }
    return ret;
}

// af: bit-reversal permutation for the current lim; g: twiddle table g[k] = root^k.
int af[N],g[N],lim,len;

// In-place iterative NTT of length lim. Fill g with powers of a primitive lim-th root of
// unity for the forward transform, or with powers of its inverse for the inverse transform
// (the inverse is left unscaled; the caller multiplies by lim^{-1}).
void ntt(int *a,int lim){
    fo(i,0,lim-1)if(af[i]>i)swap(a[i],a[af[i]]);
    for(int t=lim>>1,d=1;d<lim;d<<=1,t>>=1)
        for(int i=0;i<lim;i+=(d<<1))
            fo(j,0,d-1){
                int tmp=g[t*j]*a[i+j+d]%mod;
                a[i+j+d]=(a[i+j]-tmp+mod)%mod;
                a[i+j]=(a[i+j]+tmp)%mod;
            }
}

int a[N],b[N],c[N];
// Computes b[0..deg-1] = a^{-1} mod x^deg by Newton iteration:
// if b' = a^{-1} mod x^ceil(deg/2), then b = 2b' - a*b'^2 (mod x^deg).
// Requires a[0] != 0; otherwise no inverse exists (ksm then yields 0).
// On return b[deg..lim-1] and c[0..lim-1] are zero, ready for the next level.
void calc(int deg){
    if(deg<=0)return;   // empty request; without this, calc(0) recursed forever
    if(deg==1){b[0]=ksm(a[0],mod-2);return;}
    calc((deg+1)>>1);   // b now holds the inverse modulo x^ceil(deg/2)
    // lim >= 2*deg keeps the degree-(2deg-2) product a*b'^2 from wrapping around.
    for(lim=1,len=0;lim<deg*2;lim<<=1,len++);
    fo(i,0,lim-1)af[i]=(af[i>>1]>>1)|((i&1)<<(len-1));
    g[0]=1;g[1]=ksm(3,(mod-1)/lim);
    fo(i,2,lim-1)g[i]=g[i-1]*g[1]%mod;
    // Only the first deg coefficients a[0..deg-1] matter modulo x^deg.
    fo(i,0,deg-1)c[i]=a[i];
    ntt(c,lim);ntt(b,lim);
    g[0]=1;g[1]=ksm(g[1],mod-2);   // inverse root for the inverse transform
    fo(i,2,lim-1)g[i]=g[i-1]*g[1]%mod;
    // Pointwise: b = b*(2 - c*b) = 2b - c*b^2.
    fo(i,0,lim-1)b[i]=(2-c[i]*b[i]%mod+mod)*b[i]%mod;
    ntt(b,lim);int iv=ksm(lim,mod-2);
    fo(i,0,lim-1)b[i]=b[i]*iv%mod;
    fo(i,deg,lim-1)b[i]=0;   // truncate to mod x^deg so the next level starts clean
    fo(i,0,lim-1)c[i]=0;
}
int n;
signed main(){
    n=read();
    // Reduce inputs into [0,mod): the NTT arithmetic assumes non-negative residues.
    fo(i,0,n-1)a[i]=(read()%mod+mod)%mod;
    calc(n);
    fo(i,0,n-1)printf("%lld ",b[i]);
    return 0;
}
FWT: fast Walsh transform
So far, the convolutions we have computed have the form \(c_k=\sum_{i+j=k}a_ib_j\), i.e. \(c_k=\sum_{j=0}^{k}a_j\,b_{k-j}\).
That is the additive convolution handled by \(FFT\) and \(NTT\). What if the indices are combined with AND, OR or XOR instead, i.e. \(c_k=\sum_{i\circ j=k}a_ib_j\) for a bitwise operation \(\circ\)?
For those we use \(FWT\), i.e. fast bitwise convolution.
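|: OR convolution
Computing this convolution directly looks hard, but the high-dimensional prefix sum (sum over subsets) turns it into pointwise multiplication. If \(\hat a_k=\sum_{i\subseteq k}a_i\), then \(\hat c_k=\hat a_k\hat b_k\), because \((i\,|\,j)\subseteq k\) exactly when \(i\subseteq k\) and \(j\subseteq k\). So high-dimensional prefix sums solve OR convolution.
In fact, the high-dimensional prefix sum is exactly the forward \(FWT\) for OR convolution.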
Like \(FFT\), \(FWT\) consists of a forward transform and an inverse transform; see the blog linked at the top.
The transformed array is not a point-value representation; it is better called a subset-sum representation.
The forward transform enumerates the binary bits one at a time and accumulates subset sums. The inverse transform undoes it bit by bit in the same way, subtracting instead of adding.
codevoid ort(int *a,int lim,int tp){//tp=1 represents the forward transform, tp=2 represents the inverse transform for(int d=1;d<lim;d<<=1) for(int i=0;i<lim;i+=(d<<1)) fo(j,0,d-1){ a[i+j+d]=(a[i+j+d]+a[i+j]*tp)%mod; } }
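&: AND convolution
Here the transform is the high-dimensional suffix sum (sum over supersets). With \(\hat a_k=\sum_{i\supseteq k}a_i\) we get \(\hat c_k=\hat a_k\hat b_k\), so high-dimensional suffix sums solve AND convolution.
The forward transform is a suffix sum, and the inverse transform has exactly the same shape, subtracting instead of adding.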
codevoid andt(int *a,int lim,int tp){//tp=1, tp=2 for(int d=1;d<lim;d<<=1) for(int i=0;i<lim;i+=(d<<1)) fo(j,0,d-1){ a[i+j]=(a[i+j]+a[i+j+d]*tp)%mod; } }
^: XOR convolution
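This one is not a plain prefix or suffix sum: the transform is \(\hat a_k=\sum_i(-1)^{\operatorname{popcount}(i\,\&\,k)}a_i\), and it again satisfies \(\hat c_k=\hat a_k\hat b_k\).
It is hard to explain further in words, so just read the code. Each bit step replaces a pair \((x,y)\) with \((x+y,\,x-y)\). The inverse does the same and also divides by 2 at every step.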
codevoid xort(int *a,int lim,int tp){//tp=1 means forward transform, tp=0.5 means inverse transform for(int d=1;d<lim;d<<=1) for(int i=0;i<lim;i+=(d<<1)) fo(j,0,d-1){ int tmp=a[i+j+d]; a[i+j+d]=(a[i+j]-tmp+mod)*tp%mod; a[i+j]=(a[i+j]+tmp)*tp%mod; } }
To compute a high-dimensional prefix sum, enumerate the dimensions one by one and take an ordinary prefix sum along each.
With bitmasks, this means enumerating each binary bit and, for every index whose bit is \(0\), adding its value into the index that is identical except that this bit is \(1\).
Although the method is simple, it is provably correct: each subset is added into each of its supersets exactly once — nothing is counted twice and nothing is missed.
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define fo(i,x,y) for(int i=(x);i<=(y);i++)
#define fu(i,x,y) for(int i=(x);i>=(y);i--)
// FWT template: reads n and two sequences of length 2^n, then prints their OR, AND and XOR
// convolutions (one line each), modulo 998244353.

// Reads one signed integer from stdin, skipping any non-digit characters before it.
int read(){
    int s=0,t=1,ch=getchar();
    // Stop at EOF: with a `char ch` and no EOF check, truncated input looped forever.
    while(ch!=EOF&&!isdigit(ch)){if(ch=='-')t=-1;ch=getchar();}
    while(isdigit(ch)){s=s*10+ch-'0';ch=getchar();}
    return s*t;
}
const int mod=998244353;
const int N=1<<18;   // assumes n <= 18
// Fast modular exponentiation: returns x^y mod `mod` (x expected in [0,mod)).
int ksm(int x,int y){
    int ret=1;
    while(y){
        if(y&1)ret=ret*x%mod;
        x=x*x%mod;y>>=1;
    }
    return ret;
}
int n,lim,a[N],b[N],c[N],aa[N],bb[N];
// Copy the saved inputs aa, bb into the working arrays a, b.
void init(int lim){fo(i,0,lim-1)a[i]=aa[i],b[i]=bb[i];}
// OR transform (sum over subsets). tp=1 forward, tp=mod-1 (i.e. -1) inverse.
void ort(int *a,int lim,int tp){
    for(int d=1;d<lim;d<<=1)
        for(int i=0;i<lim;i+=(d<<1))
            fo(j,0,d-1)a[i+j+d]=(a[i+j+d]+a[i+j]*tp)%mod;
}
// AND transform (sum over supersets). tp=1 forward, tp=mod-1 (i.e. -1) inverse.
void andt(int *a,int lim,int tp){
    for(int d=1;d<lim;d<<=1)
        for(int i=0;i<lim;i+=(d<<1))
            fo(j,0,d-1)a[i+j]=(a[i+j]+a[i+j+d]*tp)%mod;
}
// XOR transform (Walsh-Hadamard): each bit step maps (x,y) -> ((x+y)*tp, (x-y)*tp).
// tp=1 forward, tp = 2^{-1} mod p inverse.
void xort(int *a,int lim,int tp){
    for(int d=1;d<lim;d<<=1)
        for(int i=0;i<lim;i+=(d<<1))
            fo(j,0,d-1){
                int tmp=a[i+j+d];
                a[i+j+d]=(a[i+j]-tmp+mod)*tp%mod;
                a[i+j]=(a[i+j]+tmp)*tp%mod;
            }
}
// c = a * b pointwise.
void mix(int lim){fo(i,0,lim-1)c[i]=a[i]*b[i]%mod;}
// Print a[0..lim-1] separated by spaces, then a newline.
void pt(int *a,int lim){fo(i,0,lim-1)printf("%lld ",a[i]);printf("\n");}
signed main(){
    n=read();lim=1<<n;
    // Reduce inputs into [0,mod): the transforms assume non-negative residues, and negative
    // values would otherwise be printed as negative "residues".
    fo(i,0,lim-1)aa[i]=(read()%mod+mod)%mod;
    fo(i,0,lim-1)bb[i]=(read()%mod+mod)%mod;
    // For each operation: transform both inputs, multiply pointwise, inverse-transform, print.
    init(lim);ort(a,lim,1);ort(b,lim,1);mix(lim);ort(c,lim,mod-1);pt(c,lim);
    init(lim);andt(a,lim,1);andt(b,lim,1);mix(lim);andt(c,lim,mod-1);pt(c,lim);
    init(lim);xort(a,lim,1);xort(b,lim,1);mix(lim);xort(c,lim,ksm(2,mod-2));pt(c,lim);
    return 0;
}
Tricks
1. Symmetry problems with generating functions: two positions \(i\) and \(j\) that are symmetric about a center \(m\) satisfy \(i+j=2m\). Multiplying polynomials adds exponents, so all pairs with the same center of symmetry contribute to the same coefficient, \(x^{2m}\), of the product.
2. When using generating functions we only need the coefficients; we never actually substitute a value for \(x\).
For example, in \(\sum_{i=0}^{n}\sum_{j=0}^{i}j!\cdot2^{i-j}\), the factorials and the powers of two can be regarded as the coefficients of two polynomials, \(F(x)=\sum_j j!\,x^j\) and \(G(x)=\sum_k 2^k x^k\).
By the rule of polynomial multiplication, the inner sum \(\sum_{j=0}^{i}j!\cdot2^{i-j}\) is exactly the coefficient of \(x^i\) in \(F(x)G(x)\). So one convolution gives every inner sum quickly, and adding them up gives the outer sum.