AC自动机的基础应用和进阶应用



更新(2018.8.24)

  1. 更新了模板,以防出现今天MLE + TLE的情况 QAQ
  2. 以前有道题目打错了还AC,现更正
  3. 删除了入门打法

前言

点进来以为是可以自动AC题目的东西的同学可以出去了(╯#-_-)╯╧═╧
AC自动机是多模匹配算法,相比之下KMP是单模匹配算法,其用了Trie树作为数据结构,并用了KMP中next数组的思想,不会Trie树的同学可以看上一篇文章(虽然看了也不会,本人写得太辣鸡_(:з」∠)_)
个人感觉有些难……
花了整整一个周啃AC自动机,最后写了前两道基础应用的题目和最后一道进阶应用的题目

  1. AC自动机详解
    http://www.cppblog.com/mythit/archive/2009/04/21/80633.html

病毒侵袭 - HDU 2896

题目大意
(略!)

思路
模板题

  #include <cstring>
  #include <cstdio>
  #include <algorithm>
  #include <set>
  using namespace std;
  const int N = 60000;
  const int ascii_size = 96;

  set<int> st;
  char str[10000 + 1];
  int  total_num;

  int  pos[N];
  int  nxt[N][ascii_size];
  int  fail[N];
  int  tot;

  int  que[N];
  int  head, tail;

  inline void newNode(int& x){
      x = tot++;
      memset(nxt[x], -1, sizeof(nxt[x]));
      fail[x] = 0;        //fail直接全部指向root节点,root节点为0
      pos[x] = 0;
  }

  inline void init(){
      total_num = 0;
      head = tail = 0;
      tot = 1;
      memset(nxt[0], -1, sizeof(nxt[0]));
      fail[0] = pos[0] = 0;
  }

  void push(char s[], int i){
      int p = 0;
      for(int i = 0; s[i]; i++){
          int idx = s[i] - 32;
          if(nxt[p][idx] == -1){
              newNode(nxt[p][idx]);
          }
          p = nxt[p][idx];
      }
      pos[p] = i;
  }

  void setFail(){
      for(int idx = 0; idx < ascii_size; idx++){
          if(nxt[0][idx] != -1){
              que[head++] = nxt[0][idx];
          }else{
              nxt[0][idx] = 0;    //不存在则直接指向root结点,符合匹配的规则
          }
      }
      while(tail != head){
          int p = que[tail++];
          for(int idx = 0; idx < ascii_size; idx++){
              if(~nxt[p][idx]){
                  //存在时设置Fail指针,因为设置过与root直接相连的结点的nxt值,所以不会为-1,不存在时会为0,
                  //直接指向root节点,符合匹配规则
                  fail[nxt[p][idx]] = nxt[fail[p]][idx];
                  que[head++] = nxt[p][idx];
              }else{
                  nxt[p][idx] = nxt[fail[p]][idx];  //不存在时直接指向nxt[fail[p]][idx]
              }
          }
      }
  }

  void query(char s[]){
      int p = 0;
      st.clear();
      for(int i = 0; s[i]; i++){
          int idx = s[i] - 32;
          p = nxt[p][idx];
          for(int q = p; q; q = fail[q]){
              if(pos[q])  st.insert(pos[q]);
          }
      }
  }

  int main(){
      int n;
      while(~scanf("%d", &n)){
          init();
          for(int i = 1; i <= n; i++){
              scanf("%s", str);
              push(str, i);
          }
          setFail();

          int m;
          scanf("%d", &m);
          for(int i = 1; i <= m; i++){
              scanf("%s", str);
              query(str);
              if(!st.empty()){
                  total_num++;
                  printf("web %d:", i);
                  for(set<int>::iterator it = st.begin(); it != st.end(); it++){
                      printf(" %d", * it);
                  }
                  puts("");
              }
          }
          printf("total: %d\n", total_num);
      }
      return 0;
  }


病毒侵袭持续中 - HDU 3065

题目大意
(继续略!)

思路
仍然算是模板题
记得搜到时就++该点记录的pos的cnt值

  #include <cstring>
  #include <cstdio>
  #include <algorithm>
  using namespace std;
  const int N = 5e4 + 15;
  const int ascii_size = 96;

  char in[1005][55];
  int  cnt[N];
  int  nxt[N][ascii_size];
  int  fail[N];
  int  pos[N];
  int  tot;

  int  que[N];
  int  head, tail;

  inline void newNode(int& x){
      x = tot++;
      memset(nxt[x], -1, sizeof(nxt[x]));
      fail[x] = 0;
      pos[x] = 0;
  }

  inline void init(){
      memset(cnt, 0, sizeof(cnt));
      head = tail = 0;
      tot = 1;
      memset(nxt[0], -1, sizeof(nxt[0]));
      fail[0] = pos[0] = 0;
  }

  void push(char s[], int j){
      int p = 0;
      for(int i = 0; s[i]; i++){
          int idx = s[i] - 32;
          if(nxt[p][idx] == -1){
              newNode(nxt[p][idx]);
          }
          p = nxt[p][idx];
      }
      pos[p] = j;
  }

  void setFail(){
      for(int idx = 0; idx < ascii_size; idx++){
          if(~nxt[0][idx]){
              que[head++] = nxt[0][idx];
          }else{
              nxt[0][idx] = 0;
          }
      }
      while(tail != head){
          int p = que[tail++];
          for(int idx = 0; idx < ascii_size; idx++){
              if(~nxt[p][idx]){
                  fail[nxt[p][idx]] = nxt[fail[p]][idx];
                  que[head++] = nxt[p][idx];
              }else{
                  nxt[p][idx] = nxt[fail[p]][idx];
              }
          }
      }
  }

  void query(){
      int p = 0;
      char ch;
      while((ch = getchar()) != '\n'){
          int idx = ch - 32;
          p = nxt[p][idx];
          for(int q = p; q; q = fail[q]){
              if(pos[q])  cnt[pos[q]]++;
          }
      }
  }

  int main(){
      int n;
      while(~scanf("%d", &n)){
          init();
          for(int i = 1; i <= n; i++){
              scanf("%s", in[i]);
              push(in[i], i);
          }
          getchar();
          setFail();

          query();
          for(int i = 1; i <= n; i++){
              if(cnt[i])  printf("%s: %d\n", in[i], cnt[i]);
          }

      }
      return 0;
  }


DNA Sequence - POJ 2778

题目大意
给定一些病毒串,问长度为n且不包含病毒串的DNA序列有多少种,结果对100000取模

思路
这题用到了离散数学的一个定理,

如果邻接矩阵中c(i, j)代表从节点i走到节点j走一步的不同走法,那么矩阵的k次方中的c’(i, j)代表从节点i走到节点j走k步的不同走法

所以这题可用AC自动机构建一个邻接矩阵,这个矩阵是排除了走向病毒串的矩阵,然后再进行矩阵快速幂,最后答案就是 sigma[i](c[0][i])

  #include <cstdio>
  #include <cstring>
  #include <iostream>
  #include <algorithm>
  using namespace std;
  typedef long long ll;
  const int N = 4;
  const int M = 100 + 5;
  const int MOD = 100000;

  int  nxt[M][N];
  bool mark[M];
  int  fail[M];
  int  tot;

  int  que[M];
  int  head, tail;

  struct Matrix{
      ll met[M][M];
  };

  Matrix multiply(const Matrix& a, const Matrix& b){
      Matrix ret;
      for(int i = 0; i < tot; i++){
          for(int j = 0; j < tot; j++){
              ret.met[i][j] = 0;
              for(int k = 0; k < tot; k++){
                   ret.met[i][j] += (a.met[i][k] * b.met[k][j]);
              }
              ret.met[i][j] %= MOD;    //先计算再求模,不然会TLE,这也说明了求模的确很耗时
          }
      }
      return ret;
  }

  Matrix quickPow(const Matrix& a, int b){
      Matrix base = a, ans;
      memset(&ans, 0, sizeof(ans));
      for(int i = 0; i < M; i++)  ans.met[i][i] = 1;
      while(b){
          if(b&1)     ans = multiply(ans, base);
          base = multiply(base, base);
          b >>= 1;
      }
      return ans;
  }

  void newNode(int& x){
      x = tot++;
      memset(nxt[x], -1, sizeof(nxt[x]));
      fail[x] = 0;
      mark[x] = true;
  }

  inline void init(){
      tot = 1;
      head = tail = 0;
      memset(nxt[0], -1, sizeof(nxt[0]));
      fail[0] = 0;
      mark[0] = true;
  }

  int cvt(char ch){
      if(ch == 'A')   return 0;
      if(ch == 'T')   return 1;
      if(ch == 'C')   return 2;
      else            return 3;
  }

  void build(char s[]){
      int p = 0;
      for(int i = 0; s[i]; i++){
          int idx = cvt(s[i]);
          if(nxt[p][idx] == -1){
              newNode(nxt[p][idx]);
          }
          mark[nxt[p][idx]] &= mark[p];   //从病毒串走出来的,就需要更新把mark值设为0
          p = nxt[p][idx];
      }
      mark[p] = false;
  }

  void autoMaton(){
      for(int idx = 0; idx < N; idx++){
          if(nxt[0][idx] != -1){
              que[head++] = nxt[0][idx];
          }else{
              nxt[0][idx] = 0;
          }
      }
      while(tail != head){
          int p = que[tail++];
          //Fail指针指向的点是病毒串,那么该点也是病毒串的点
          //因为Fail指针指向的点一定是该串的一个后缀
          mark[p] &= mark[fail[p]];
          for(int idx = 0; idx < N; idx++){
              if(nxt[p][idx] == -1){
                  nxt[p][idx] = nxt[fail[p]][idx];
              }else{
                  fail[nxt[p][idx]] = nxt[fail[p]][idx];
                  que[head++] = nxt[p][idx];
              }
          }
      }
  }

  ll res(ll b){
      ll ans = 0;
      Matrix a;
      memset(&a, 0, sizeof(a));
      for(int i = 0; i < tot; i++){
          for(int j = 0; j < N; j++){
              a.met[i][nxt[i][j]] += (mark[i] & mark[nxt[i][j]]);   //排除掉病毒串的点
          }
      }

      Matrix ret = quickPow(a, b);
      for(int i = 0; i < tot; i++){
          ans += ret.met[0][i];
      }
      return ans%MOD;
  }

  int main(){
      int n;
      ll b;
      char str[15];
      while(~scanf("%d%lld", &n, &b)){
          init();
          while(n--){
              scanf("%s", str);
              build(str);
          }
          autoMaton();
          printf("%lld\n", res(b));
      }
      return 0;
  }