快速傅里叶变换(FFT)、数论变换(NTT)及其在字符串匹配中的应用
多项式乘法
一元多项式是指形如$a_0+a_1x+a_2x^2+……+a_{n-1}x^{n-1}$的式子,两个次多项式相乘,运算法则是“逐项相乘,合并同类项”,复杂度是\(O(n^2)\)。多项式还可以用其对应函数图像上的点值来表示,显然,n-1次多项式需要用n个不同的点来表示,因为至少需要n个等于组成的方程组才可求出所有系数的值。多项式点值表示法的好处是计算乘法比较方便,只需函数值相乘即可。所以,要求两个多项式的乘积,可以把两个多项式由系数形式转换成点值形式,相乘后再转换回来。把多项式由系数形式转换成点值形式称为离散傅里叶变换(Discrete Fourier Transform,DFT),把多项式由点值形式转换成系数形式称为离散傅里叶逆变换(IDFT)。
朴素的DFT方法需要计算n个多项式的值,每个值的计算需要n-1次乘法,所以朴素的DFT算法的复杂度还是\(O(n^2)\),并不能降低计算量。但是在计算n个多项式值的时候,可以选取一些特殊的自变量的值,使得可以重复利用前面的计算结果,就可以降低计算量。 快速傅里叶变换(Fast Fourier Transform,FFT)就是利用复数域内“n次单位根”的特性来加快DFT过程的。
快速傅里叶变换基本方法
对于多项式 \[f(x)=\sum_{i=0}^{n-1}a_ix^i=a_0+a_1x+a_2x^2+……+a_{n-1}x^{n-1}\] 按系数下标的奇偶性分开,得\[f(x)=a_0+a_2x^2+……+a_{n-2}x^{n-2}+a_1x+a_3x^3+……a_{n-1}x^{n-1}\]\[=a_0+a_2x^2+……+a_{n-2}x^{n-2}+x(a_1+a_3x^2+……+a_{n-1}x^{n-2})\]令,\[f_1(x)=a_0+a_2x+a_4x^2+……+a_{n-2}x^{\frac{n}{2}-1}\]\[f_2(x)=a_1+a_3x+a_5x^2+……+a_{n-1}x^{\frac{n}{2}-1}\]则有,\[f(x)=f_1(x^2)+xf_2(x^2)\]设\(\omega_n^0\),\(\omega_n^1\),……,\(\omega_n^{n-1}\)为n次单位根,则求f(x)在\(\omega_n^0\),\(\omega_n^1\),……,\(\omega_n^{n-1}\)处的值就变成了在次数只有一半的两个多项式\(f_1(x)\)、\(f_2(x)\)在\((\omega_n^0)^2\),\((\omega_n^1)^2\),……,\((\omega_n^{n-1})^2\)处的值。根据n次单位根的特性,\((\omega_n^0)^2\),\((\omega_n^1)^2\),……,\((\omega_n^{n-1})^2\)并不是互不相同的,而是由\(\frac{n}{2}\)个\(\frac{n}{2}\)次单位根组成。这样,就把DFT的规模降低了一半,递推下去就得到递归版的FFT算法。找到最里层递归及逐层返回合并的系数的规律,就可以把递归FFT改进为迭代FFT。
DFT其实是一个矩阵的运算,其逆运算IDFT也是矩阵运算,算法原理是相同的。所以往往把FFT、IFFT写在一个函数中。
下面是浴谷P3803AC了的代码,可以作为模块使用。
#include <iostream>
#include <complex>
#include <cmath>
using namespace std;
complex<double> p[2097152],g[2097152];//两多项式最高可达2*10^6次,系数则有2*10^6+1项,但是傅里叶变换需要把系数扩展到2的整数次方,所有这里预留的2^21,大多数是直接扩展4倍。
void fft(complex<double> a[],int len,int inv){
for(int i=0,j=0;i<len;i++){//采用雷德算法进行位逆序变换,很多写法i直接从1开始,因为0显然是不需要变换位置,所以不影响结果。
int k=len;//len是2的整数次幂,决定了二进制的位数;
if(i<j)swap(a[i],a[j]);
//以下计算j按二进制逆序的下一个数,从最高位加起,向右进位
while(j&(k>>=1))j&=~k;//这一步在很多代码中用算术运算代替逻辑运算(右移就是除以2,位与可以用大于来判断)结果是一样的,但是就不太好理解了。
j|=k;//前一步只是处理了需要进位的,不需要进位之后,置1(或运算)
}
for(int s=2;s<=len;s<<=1){//s是准备合并序列的长度
complex<double> wm(cos(inv*2*M_PI/s),sin(inv*2*M_PI/s));
for(int k=0;k<len;k+=s){//步长是序列的长度,循环一次处理一个序列
complex<double> w(1,0);
for(int j=0;j<s/2;j++){//前一半和后一半合并,所以循环终止条件是到k+s/2
complex<double> t=w*a[k+j+s/2];
complex<double> u=a[k+j];
a[k+j]=u+t;
a[k+j+s/2]=u-t;
w=w*wm;
}
}
}
if(inv==-1)
for(int i=0;i<len;i++)
a[i].real(a[i].real()/len);
}
int main(){
int n,m,limit=1,temp;
cin>>n>>m;
while(limit<n+m+1)limit<<=1;
for(int i=0;i<=n;i++){
cin>>temp;
p[i].real(temp);
p[i].imag(0);
}
for(int i=0;i<=m;i++){
cin>>temp;
g[i].real(temp);
g[i].imag(0);
}
for(int i=n+1;i<limit;i++){
p[i].real(0);
p[i].imag(0);
}
for(int i=m+1;i<limit;i++){
g[i].real(0);
g[i].imag(0);
}
fft(p,limit,1);
fft(g,limit,1);
for(int i=0;i<limit;i++){
p[i]=p[i]*g[i];
}
fft(p,limit,-1);
for(int i=0;i<n+m+1;i++){
cout<<(int)(p[i].real()+0.5)<<" ";
}
cout<<endl;
}
运用FFT计算多项式乘法需要注意的是n次多项式和m次多项式相乘,结果是n+m次多项式,那么在分别对n次多项式和m次多项式进行FFT转换时就需要计算n+m+1个点的值而不仅仅是n个和m个,这样才能逐点相乘然后再IFFT。也就是说在进行FFT之前需要扩展,扩展很简单,高次系数为0就行。另外,由于FFT需要不断二分,所以扩展之后的系数个数还必须是2的整数次幂。
FFT在字符串匹配中的应用。
现在有2个字符串S和T,令m = |S|,n = |T|,设n <= m。假设S的第i个位置的字符与T的第j个位置的字符匹配,则可以表示为S[i] = T[j],假设S[i … i + n - 1]与T[0 … n - 1]匹配,则\[\sum_{j=0}^{n-1}(S[i+j]-T[j])^2=0\]取平方是为防止正负抵消恰好等于0的情况。这一个式子的下标之后i+j+j不是定值,无法利用多项式的乘积,因为多项式的乘积是把指数相同即系数的下标之和相同的合并同类项。但是如果把T翻转,那么只要T[n-1]=S[i]、T[n-2]=S[i+1]、……、T[0]=S[i]则字符串同样是在i处匹配,表达式就可改写为\[\sum_{j=0}^{n-1}(S[i+j]-T[n-1-j])^2=0\]这个式子可化为\[\sum_{j=0}^{n-1}(S[i+j])^2+\sum_{j=0}^{n-1}(T[n-1-j])^2-\sum_{j=0}^{n-1}S[i+j]T[n-1-j]=0\]这个式子第一项是数组中连续n项的平方和,可以循环逐个计算,第二项是常数,第三项可以看作两个多项式乘积的第i项,这样就可以利用FFT来查找字符串匹配。这种算法由于用到复数域上的计算存在精度问题,并且也不比KMP算法快,但好处是可以用于有通配符的情况。把通配符的值设为0,则\[\sum_{j=0}^{n-1}(S[i+j]-T[j])^2S[i+j]T[n-1-j]=0\]就可用来判断是否匹配。展开后复杂一些,道理是一样的。
灵活地设计匹配判断表达式,还可以解决部分匹配、不匹配等相似的问题。比如OJ4TH 1415这一题,并不是要求查找完全匹配,只要求最多的部分匹配,并且匹配的部分在模式串中还不一定连续。这一题由于字符集很小,可以分别计算每一个字符在各个位置的匹配数量,设待考查的字符的值为1,其他字符为0,那么\[\sum_{j=0}^{n-1}S[i+j]T[n-1-j]\]的值就是i处的匹配的个数,累加起来找最大值即可。AC的代码如下:
#include <iostream>
#include <complex>
#include <cmath>
#include <algorithm>
using namespace std;
const int N=2<<20; //一般开出4倍空间,实际最大可能用到2倍再扩展到2的整数次方,所以这里取2^21
char x[N],y[N];
complex<double> f[N], g[N];
int sum[N];
void fft(complex<double> a[],int len,int inv){
for(int i=0,j=0;i<len;i++){//采用雷德算法进行位逆序变换,很多写法i直接从1开始,因为0显然是不需要变换位置,所以不影响结果。
int k=len;//len是2的整数次幂,决定了二进制的位数;
if(i<j)swap(a[i],a[j]);
//以下计算j按二进制逆序的下一个数,从最高位加起,向右进位
while(j&(k>>=1))j&=~k;//这一步在很多代码中用算术运算代替逻辑运算(右移就是除以2,位与可以用大于来判断)结果是一样的,但是就不太好理解了。
j|=k;//前一步只是处理了需要进位的,不需要进位之后,置1(或运算)
}
for(int s=2;s<=len;s<<=1){//s是准备合并序列的长度
complex<double> wm(cos(inv*2*M_PI/s),sin(inv*2*M_PI/s));
for(int k=0;k<len;k+=s){//步长是序列的长度,循环一次处理一个序列
complex<double> w(1,0);
for(int j=0;j<s/2;j++){//前一半和后一半合并,所以循环终止条件是到k+s/2
complex<double> t=w*a[k+j+s/2];
complex<double> u=a[k+j];
a[k+j]=u+t;
a[k+j+s/2]=u-t;
w=w*wm;
}
}
}
if(inv==-1)
for(int i=0;i<len;i++)
a[i].real(a[i].real()/len);
}
void solve(char c,int n,int m,int len){
for(int i=0;i<n;i++){
if(x[i]==c)
f[i].real(1);
else
f[i].real(0);
f[i].imag(0);
}
for(int i=n;i<len;i++){
f[i].real(0);
f[i].imag(0);
}
for(int i=0;i<m;i++){
if(y[i]==c)
g[i].real(1);
else
g[i].real(0);
g[i].imag(0);
}
for(int i=m;i<len;i++){
g[i].real(0);
g[i].imag(0);
}
fft(f,len,1);
fft(g,len,1);
for(int i=0;i<len;i++)
f[i]=f[i]*g[i];
fft(f,len,-1);
for(int i=0;i<len;i++)
sum[i]+=(int)(f[i].real()+0.5);
}
int main(){
int n,m,len=1,ans=0;
scanf("%d%d",&n,&m);
scanf("%s%s",x,y);
for (int i=0;i<m;i++){
switch(y[i]){
case 'R':
y[i]='S';
break;
case 'S':
y[i]='P';
break;
case 'P':
y[i]='R';
}
}
reverse(y,y+m);
while(len<n+m-1)len<<=1;
solve('R',n,m,len);
solve('S',n,m,len);
solve('P',n,m,len);
for(int i=m-1;i<n+m-1;i++)
ans=max(ans,sum[i]);
printf("%d\n",ans);
return 0;
}
2018 ACM-ICPC 中国大学生程序设计竞赛线上赛H题Rock Paper Scissors Lizard Spock也是类似的,只不过由于规则复杂了一些,枚举模式串中的一个字母需要替换文本串中的两个字母,道理都一样。
数论变换
在这几个字符串匹配的例子中,实际上多项目式的系数都是比较小的整数,但FFT却用到了复数,进行了大量的浮点运算,FFT是利用了“n次单位根”的特性来减小计算量的。在数论中,根据费尔马小定理,\(a^{p-1}\equiv1\)(mod p)是单位元,且\(a^{\frac{p-1}{n}}\)与\(\omega^{\frac{1}{n}}\)具有同样的性质。模p必须是素数且p-1必须能被n整除,同时因为n是2的幂,所以可以查找形如\(p=c⋅2^k+1\)的素数来选择模数,此时满足\(a^1\)、\(a^2\)、……、\(a^{p-1}\)的a即为原根。常见的形如\(P=c⋅2^k+1\)的素数有\(998244353=119⋅2^{23}+1\),\(1004535809=479⋅2^{21}+1\),它们的原根都为3。用\(a^{\frac{p-1}{n}}\)代替FFT中的\(\omega^{\frac{1}{n}}\)就得到了数论变换(Number Theoretic Transforms,NTT)。模p域中的乘方可以使用快速幂算法,求逆元也是。同样根据费尔马小定理不难得出\(a\cdot a^{p-2}\equiv1\)(mod p)所以\(a^{p-2}\)(mod p)就是a的逆元。
下面是OJ4TH 1415这道题AC了的NTT解法,运行时间减少了约35%,运行内存差不多少了40%。
#include <iostream>
#include <complex>
#include <cmath>
#include <algorithm>
using namespace std;
const int mod=998244353;
const int N=2<<20; //一般开出4倍空间,实际最大可能用到2倍再扩展到2的整数次方,所以这里取2^21
char x[N],y[N];
int f[N],g[N],sum[N];
unsigned long long fast_pow(unsigned long long x,int y) {//快速幂算法
unsigned long long p=1;
while(y){
if(y&1)p=x*p%mod;
x=x*x%mod;
y>>=1;
}
return p;
}
void ntt(int a[],int len,int inv){
for(int i=0,j=0;i<len;i++){//采用雷德算法进行位逆序变换,很多写法i直接从1开始,因为0显然是不需要变换位置,所以不影响结果。
int k=len;//len是2的整数次幂,决定了二进制的位数;
if(i<j)swap(a[i],a[j]);
//以下计算j按二进制逆序的下一个数,从最高位加起,向右进位
while(j&(k>>=1))j&=~k;//这一步在很多代码中用算术运算代替逻辑运算(右移就是除以2,位与可以用大于来判断)结果是一样的,但是就不太好理解了。
j|=k;//前一步只是处理了需要进位的,不需要进位之后,置1(或运算)
}
for(int s=2;s<=len;s<<=1){//s是准备合并序列的长度
int gm=fast_pow(3,(mod-1)/s);//3是原根
if(inv==-1)gm=fast_pow(gm,mod-2);
for(int k=0;k<len;k+=s){//步长是序列的长度,循环一次处理一个序列
unsigned long long g=1;
for(int j=0;j<s/2;j++){//前一半和后一半合并,所以循环终止条件是到k+s/2
unsigned long long t=g*a[k+j+s/2]%mod;
unsigned long long u=a[k+j];
a[k+j]=(u+t)%mod;
a[k+j+s/2]=(u-t+mod)%mod;
g=g*gm%mod;
}
}
}
if(inv==-1){
int t=fast_pow(len,mod-2);
for (int i=0;i<len;i++)
a[i]=(unsigned long long)a[i]*t%mod;
}
}
void solve(char c,int n,int m,int len){
for(int i=0;i<n;i++){
if(x[i]==c)
f[i]=1;
else
f[i]=0;
}
for(int i=n;i<len;i++){
f[i]=0;
}
for(int i=0;i<m;i++){
if(y[i]==c)
g[i]=1;
else
g[i]=0;
}
for(int i=m;i<len;i++){
g[i]=0;
}
ntt(f,len,1);
ntt(g,len,1);
for(int i=0;i<len;i++){
f[i]=(unsigned long long)f[i]*g[i]%mod;
}
ntt(f,len,-1);
for(int i=0;i<len;i++){
sum[i]+=f[i];
}
}
int main(){
int n,m,len=1,ans=0;
scanf("%d%d",&n,&m);
scanf("%s%s",x,y);
for (int i=0;i<m;i++){
switch(y[i]){
case 'R':
y[i]='S';
break;
case 'S':
y[i]='P';
break;
case 'P':
y[i]='R';
}
}
reverse(y,y+m);
while(len<n+m-1)len<<=1;
solve('R',n,m,len);
solve('S',n,m,len);
solve('P',n,m,len);
for(int i=m-1;i<n+m-1;i++)
ans=max(ans,sum[i]);
printf("%d\n",ans);
return 0;
}