FFT解决大数乘法



前言

FFT是至今让我感觉最绝望的东西,因为真的除了模板题勉强能做,剩下的完全就做不了了QAQ
FFT用于加速多项式乘法计算,从O(n^2)优化到O(nlogn),在其他领域用的很广,在ACM领域可以用来算大数乘法,更重要的是用来计算卷积,加速DP,这类题目神难……
还有就是FFT本身难,下面的代码没有理论基础的话,十有八九是看不懂的,除了数学大佬和高智商大佬orz,而且有了理论基础,模板也不一定看得懂……会了模板,题目还是不会做…………
个人感觉FFT需要的理论基础有:线性代数,欧拉方程,复数的运算

  1. FFT详解
    https://blog.csdn.net/ggn_2015/article/details/68922404
    https://www.zybuluo.com/397915842/note/37965
  2. 二进制反转
    https://blog.csdn.net/GGN_2015/article/details/69518685

大数乘法 V2 问题 - 51Nod-1028

题目大意
模板题

思路
使用std的complex类做,旋转因子用 wn = exp(PI*i/(n/2))
这是第一道我能把注释写得如此详细的题目……

  #include <iostream>
  #include <cstdio>
  #include <cstring>
  #include <algorithm>
  #include <complex>
  #include <cmath>
  using namespace std;
  const int N = 3e5 + 15;
  const double PI = acos(-1);
  typedef complex<double> cd;

  cd   a[N], b[N];     //complex数组
  int  rev[N];         //记录反转二进制数字
  char s1[N], s2[N];
  int  ans[N];

  //getRev:
  //作用:生成1 到 2^bit(也就是s)的颠倒二进制数的数组
  //原理:将当前数字二进制反转,即是
  //当前数字前bit - 1位的颠倒结果右移一位 + 当前数字末位移到最高位
  //其中前者的工作已经在rev[i >> 1]生成的时候完成了,所以把这个值直接右移一位,
  //再加上当前数字末位最高位
  void getRev(int bit){   //bit = log2(s)
      for(int i = 0; i < (1 << bit); i++){
          rev[i] = ((rev[i >> 1] >> 1) | ((i&1) << (bit - 1)));
      }
  }

  //fft:
  //作用:计算dft(系数式转点值式)和idft(点值式转系数式)
  //参数:
  // cd a[]: 系数表达式中的系数向量
  // int n:  结果的点值表达式的次数界
  // int dft:传1进行dft,传-1进行idft
  void fft(cd a[], int n, int dft){
      //根据得到的rev数组调整a,为下一步的循环做准备
      //加入 if(i < rev[i]),防止双次交换变回原样
      //建议debug看一下原理
      for(int i = 0; i < n; i++){
          if(i < rev[i])     swap(a[i], a[rev[i]]);
      }

      //step为每次的步长,从底层递推回去,因此
      //不需要走到n,只需要走到step/2
      for(int step = 1; step < n; step <<= 1){
          //计算单位复根,根据消去引理,计算的是w[n/2](k),而不是w[n](2k)
          cd wn = exp(cd(0, PI*dft/step));

          //计算每一个块,具体可看《算法导论》的那颗树
          //加上step << 1是因为下面是同时跑k和k+step的,总共就跑了step << 1
          for(int j = 0; j < n; j += step << 1){
              //递归时同深度的子树相互独立,因此递推回去也是一样
              cd wnk(1, 0);
              for(int k = j; k < j + step; k++){
                  cd x = a[k];
                  cd y = wnk*a[k + step];
                  a[k]        = x + y;
                  a[k + step] = x - y;
                  wnk*=wn;
              }
          }
      }

      if(dft == -1){
          for(int i = 0; i < n; i++)      a[i] /= n;
      }
  }

  int main(){
      while(~scanf("%s%s", s1, s2)){
          memset(ans, 0, sizeof(ans));
          int len_s1 = strlen(s1), len_s2 = strlen(s2);
          for(int i = 0; i < len_s1; i++)     a[i] = cd(s1[len_s1 - i - 1] - '0', 0);
          for(int i = 0; i < len_s2; i++)     b[i] = cd(s2[len_s2 - i - 1] - '0', 0);

          //找到第一个超过(len_s1 + len_s2 - 1)的2的整数幂s作为次数界,并记录下指数bit
          //bit给getRev函数用, s给fft函数用
          int bit = 1, s = 2;
          for(bit = 1; s < len_s1 + len_s2 - 1; bit++, s <<= 1);

          //调用getRev函数生成反转二进制,从1到2^bit - 1
          getRev(bit);

          //将a和b转为点值式,并相乘
          //将相乘的结果
          fft(a, s, 1);
          fft(b, s, 1);
          for(int i = 0; i < s; i++)      a[i] * = b[i];
          fft(a, s, -1);

          //还原成十进制,目前表达式是
          //a[0] + a[1] * x + a[2] * x^2 + ... + a[s - 1] * x^(s - 1)
          //其中a[i]并不一定是10以内的数,所以该进位的要进位
          for(int i = 0; i < s; i++){
              ans[i] += (int)(a[i].real() + 0.5);
              ans[i + 1] += ans[i]/10;
              ans[i] %= 10;
          }

          //输出答案,去掉前导0
          int* p = ans + len_s1 + len_s2;
          while(* p == 0 && p != ans - 1)  p--;
          if(p == ans - 1){
              printf("0\n");
          }else{
              while(p != ans - 1){
                  printf("%d", * p);
                  p--;
              }
              puts("");
          }
      }
      return 0;
  }


Bull Math - POJ 2389

题目大意
同样是大数乘法

思路
本题可模拟做,当然这里用FFT做
复数类是自己写的,wn取 i*sin(PI/(n/2)) + cos(PI/(n/2))

  #include <cstdio>
  #include <iostream>
  #include <cmath>
  #include <algorithm>
  #include <cstring>
  using namespace std;
  typedef long long ll;
  const int N = 200;
  const double PI = acos(-1.0);

  struct cd{
      double real, imag;

      cd(double _real = 0, double _imag = 0): real(_real), imag(_imag) {}
      cd operator + (cd _obj) { return cd(real + _obj.real, imag + _obj.imag);}
      cd operator - (cd _obj) { return cd(real - _obj.real, imag - _obj.imag);}
      cd operator * (cd _obj) { return cd(real * _obj.real - imag * _obj.imag, real * _obj.imag + imag * _obj.real);}
      cd operator * (double _real) { return ( * this) * cd(_real, 0);}
      cd operator = (double x) { real = x, imag = 0; return *this;}
  };

  cd   a[N], b[N];
  int  rev[N];
  char s1[N], s2[N];
  int  ans[N];

  inline void init(){
      memset(a, 0, sizeof(a));
      memset(b, 0, sizeof(b));
      memset(ans, 0, sizeof(ans));
  }

  inline void getRev(int bit){
      for(int i = 0; i < (1 << bit); i++){
          rev[i] = ((rev[i >> 1] >> 1) | (i&1) << (bit - 1));
      }
  }
  void mswap(cd& x, cd& y){
      cd temp = x;
      x = y;
      y = temp;
  }

  void fft(cd a[], int n, int dft){
      for(int i = 0; i < n; i++){
          if(i < rev[i])  mswap(a[i], a[rev[i]]);
      }
      for(int step = 1; step < n; step <<= 1){
          cd wn(cos(dft*PI/step), sin(dft*PI/step));
          for(int j = 0; j < n; j += step << 1){
              cd wnk(1, 0);
              for(int k = j; k < j + step; k++){
                  cd x = a[k];
                  cd y = wnk*a[k + step];
                  a[k]        = x + y;
                  a[k + step] = x - y;
                  wnk = wnk*wn;
              }
          }
      }
      if(dft == -1){
          for(int i = 0; i < n; i++){
              a[i].real /= n;
          }
      }
  }

  int main(){
      while(~scanf("%s%s", s1, s2)){
          init();
          int len_s1 = strlen(s1), len_s2 = strlen(s2);
          for(int i = 0; i < len_s1; i++)
              a[i] = (s1[len_s1 - i - 1] - '0');
          for(int i = 0; i < len_s2; i++)
              b[i] = (s2[len_s2 - i - 1] - '0');
          int bit = 1, s = 2;
          for(bit = 1; s < len_s1 + len_s2 - 1; bit++, s <<= 1);
          getRev(bit);
          fft(a, s, 1);
          fft(b, s, 1);
          for(int i = 0; i < s; i++)      a[i] = a[i]*b[i];
          fft(a, s, -1);

          for(int i = 0; i < s; i++){
              ans[i] += (int)(a[i].real + 0.5);
              ans[i + 1] += ans[i]/10;
              ans[i] %= 10;
          }

          int* p = ans + s;
          while(p != ans - 1 && *p == 0)  p--;
          if(p == ans - 1){
              printf("0\n");
          }else{
              while(p != ans - 1){
                  printf("%d", *p);
                  p--;
              }
              puts("");
          }
      }
      return 0;
  }