目录

后缀数组:快速查找模式串的所有出现位置

从定义上来说,后缀数组(用 sa 表示)包含了字符串 s 所有后缀的起始索引,它是一个 int 数组,sa[i] 表示对应后缀的起始索引。注意后缀已经按照字典序排好。

以字符串 fizzbuzz 为例,它的后缀数组是 4 0 1 5 7 3 6 2,对应关系如下

后缀数组 sa 对应后缀
4 buzz
0 fizzbuzz
1 izzbuzz
5 uzz
7 z
3 zbuzz
6 zz
2 zzbuzz

用 $N$ 表示字符串的长度,朴素算法只需要生成所有后缀,然后对后缀进行排序即可。排序的时间复杂度是 $O(Nlog\ N)$,每一次字符串比较需要 $O(N)$,所以朴素算法的时间复杂度是

$$ 1\times O(Nlog\ N)\times O(N)=\boxed{O(N^2log\ N)} $$

实现很简单,这里不展开

Note

我看过很多倍增法的解释 12,但我发现都不大好理解,下面是我自己整理的理解角度

倍增法简单来说,就是不断执行如下步骤直到 $2^k\ge n$

  • 第 1 轮,只看每个后缀的前 $2^1$ 字符,对后缀数组排序
  • 第 2 轮,只看每个后缀的前 $2^2$ 字符,对后缀数组排序
  • 第 $k$ 轮,只看每个后缀的前 $2^k$ 字符,对后缀数组排序

可以看到,每一次都从 $2^{k-1}\rightarrow 2^k$,这也是名字中“倍增”的含义

倍增法的本质是:用已经排序好的短字符串,去比较更长的字符串。举例来说,如何比较 2 个长度为 $2^k$ 的字符串?我们可以比较他们的前 $2^{k-1}$ 个字符,然后再比较后 $2^{k-1}$ 字符

为了复用之前的比较结果,倍增法需要引入一个额外数组,一般称之为 ranking 数组(用 ra 表示),它也是一个 int 数组。在第 $k$ 轮执行开始前ra[sa[i]] 表示以 sa[i] 开始的后缀的前 $2^{k-1}$ 个字符在所有后缀中的字典序排序

这里可以推导一下前面提到的“第 $k$ 轮,只看每个后缀的前 $2^k$ 字符,对后缀数组排序”是怎么工作的:根据前面的讨论,利用 ra 数组,可以这么干

  • 比较他们的前 $2^{k-1}$ 个字符,那么就是比较 $ra\big[sa[i]\big]$
  • 然后再比较后 $2^{k-1}$ 字符,那么就是比较 $ra\big[sa[i]+2^{k-1}\big]$

换言之,只看前 $2^k$ 字符比较字符串被转化为了下面这个 tuple 的排序:

$$ \Big( ra\big[sa[i]],ra\big[sa[i]+2^{k-1}\big] \Big) $$

而 $ra$ 数组的这两个值在上一轮都已经求好了

最后分析一下倍增法的时间复杂度:因为排序每次都在倍增,所以只需要比较 $O(log\ N)$ 轮,每一轮的比较时间复杂度是 $O(Nlog\ N)\times O(1)=O(Nlog\ N)$,因为现在不是字符串比较,而是用 ranking pair 比较

$$ O(log\ N)\times O(Nlog\ N)\times O(1)=\boxed{O(Nlog^2\ N)} $$

上面的倍增法还可以利用基数排序进行优化,这样排序的时间复杂度就从 $O(Nlog\ N)$ 变成了 $O(N)$

比如比较 (1,3), (2,2), (3,2) 这些 ranking pair,基数排序会这么做:

  • 先基于第二个位置做基数排序得到:(2,2), (3,2), (1,3)
  • 再基于第一个位置做基数排序得到:(1,3), (2,2), (3,2)

利用基数排序优化后,后缀数组生成的时间复杂度就是

$$ O(log\ N)\times O(N)\times O(1)=\boxed{O(Nlog\ N)} $$

代码实现上需要注意:ra 数组的访问可能会越界,越界时应返回一个比所有可能的 rank 更小的数字,这样才符合字典序比较的。下面的实现中所有合法的 ranking $\in[1,)$,所以返回的是 $0$

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

using std::cin;
using std::cout;
using std::fill;
using std::string;
using std::vector;

int main() {
  string s;
  cin >> s;

  // Init
  vector<int> sa(s.size());
  vector<int> ra(s.size());
  for (int i = 0; i < s.size(); i++) {
    sa[i] = i;
    ra[i] = s[i] - 'a' + 1; // valid rank starts from 1
  }

  vector<int> new_sa(s.size());
  for (int w = 1; w < s.size(); w <<= 1) {
    // Sort suffix array
    sort(sa.begin(), sa.end(), [&](int x, int y) {
      if (ra[x] != ra[y]) {
        return ra[x] < ra[y];
      }
      int rx{x + w < s.size() ? ra[x + w] : 0};
      int ry{y + w < s.size() ? ra[y + w] : 0};
      return rx < ry;
    });

    // Reranking
    vector<int> new_ra(s.size(), 1);
    int current_rank{1};
    new_ra[sa[0]] = 1;
    for (int i = 1; i < s.size(); i++) {
      int cur{sa[i]};
      int prev{sa[i - 1]};
      new_ra[sa[i]] =
          (ra[cur] == ra[prev] && (cur + w < s.size() ? ra[cur + w] : 0) ==
                                      (prev + w < s.size() ? ra[prev + w] : 0)
               ? current_rank
               : ++current_rank);
    }
    ra = new_ra;
  }
  for (int i = 0; i < s.size(); i++) {
    cout << sa[i] << (i == s.size() - 1 ? "" : " ");
  }
}
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

using std::cin;
using std::cout;
using std::fill;
using std::max;
using std::string;
using std::vector;

// Sort `sa` using `ra[i + delta]`
void radix_sort(vector<int> &sa, vector<int> &new_sa, const vector<int> &ra,
                int n, int delta) {
  // The valid rank for lowercase letters is 1 ~ 26
  // or the length of the input string + 1
  int max_rank{max(27, n + 1)};
  vector<int> count(max_rank, 0);

  // counting
  for (int i = 0; i < n; i++) {
    int key{sa[i] + delta < n ? ra[sa[i] + delta] : 0};
    count[key]++;
  }

  for (int i = 0, prefix_sum = 0; i < max_rank; i++) {
    int temp{count[i]};
    count[i] = prefix_sum;
    prefix_sum += temp;
  }

  // invariant: all ranking pair in the same count[i] share the same rank
  for (int i = 0; i < n; i++) {
    int key{sa[i] + delta < n ? ra[sa[i] + delta] : 0};
    new_sa[count[key]++] = sa[i];
  }
  sa = new_sa;
}

int main() {
  string s;
  cin >> s;

  // Init
  vector<int> sa(s.size());
  vector<int> ra(s.size());
  for (int i = 0; i < s.size(); i++) {
    sa[i] = i;
    ra[i] = s[i] - 'a' + 1; // valid rank starts from 1
  }

  vector<int> new_sa(s.size());
  for (int w = 1; w < s.size(); w <<= 1) {
    // Sort suffix array
    radix_sort(sa, new_sa, ra, s.size(), w);
    radix_sort(sa, new_sa, ra, s.size(), 0);

    // Reranking
    vector<int> new_ra(s.size(), 1);
    int current_rank{1};
    new_ra[sa[0]] = 1;
    for (int i = 1; i < s.size(); i++) {
      int cur{sa[i]};
      int prev{sa[i - 1]};
      new_ra[sa[i]] =
          (ra[cur] == ra[prev] && (cur + w < s.size() ? ra[cur + w] : 0) ==
                                      (prev + w < s.size() ? ra[prev + w] : 0)
               ? current_rank
               : ++current_rank);
    }
    ra = new_ra;
  }
  for (int i = 0; i < s.size(); i++) {
    cout << sa[i] << (i == s.size() - 1 ? "" : " ");
  }
}

仍然用之前的 fizzbuzz 为例,在初始化之后,第一轮排序之前是这样子的

后缀数组 sa[i] ranking 数组 ra[sa[i]] ranking 数组 ra[sa[i]+1] 对应的后缀
0 6 9 fizzbuzz
1 9 26 izzbuzz
2 26 26 zzbuzz
3 26 2 zbuzz
4 2 21 buzz
5 21 26 uzz
6 26 26 zz
7 26 0 z

此时第一轮排序,要看的是前 $2^1=2$ 个字符,对应的 ranking pair 是 (ra[sa[i]], ra[sa[i]+1]),排序结果是

后缀数组 sa[i] ranking 数组 ra[sa[i]] ranking 数组 ra[sa[i]+1] 对应的后缀
4 2 21 buzz
0 6 9 fizzbuzz
1 9 26 izzbuzz
5 21 26 uzz
7 26 0 z
3 26 2 zbuzz
2 26 26 zzbuzz
6 26 26 zz

接下来重新分配 ranking(从 1 开始),准备开始第二轮排序

后缀数组 sa[i] ranking 数组 ra[sa[i]] ranking 数组 ra[sa[i]+2] 对应的后缀
4 1 7 buzz
0 2 7 fizzbuzz
1 3 6 izzbuzz
5 4 5 uzz
7 5 0 z
3 6 4 zbuzz
2 7 1 zzbuzz
6 7 0 zz

第二轮排序看前 $2^2=4$ 个字符,对应的 ranking pair 是 (ra[sa[i]], ra[sa[i]+2]),排序完之后是

后缀数组 sa[i] ranking 数组 ra[sa[i]] ranking 数组 ra[sa[i]+2] 对应的后缀
4 1 7 buzz
0 2 7 fizzbuzz
1 3 6 izzbuzz
5 4 5 uzz
7 5 0 z
3 6 4 zbuzz
6 7 0 zz
2 7 1 zzbuzz

接下来重新分配 ranking(从 1 开始),准备开始第三轮排序

后缀数组 sa[i] ranking 数组 ra[sa[i]] ranking 数组 ra[sa[i]+4] 对应的后缀
4 1 0 buzz
0 2 1 fizzbuzz
1 3 4 izzbuzz
5 4 0 uzz
7 5 0 z
3 6 5 zbuzz
6 7 0 zz
2 8 7 zzbuzz

第三轮排序看前 $2^3=8$ 个字符,对应的 ranking pair 是 (ra[sa[i]], ra[sa[i]+4]),排序完之后和上一轮无变化,说明已经排序完了

后缀数组可以用于查找模式串 P 在字符串 S 中的所有出现位置,这是因为它利用到了这个性质:如果模式串 P 是字符串 S 的子串,那么它肯定是 S 的某些后缀的前缀

算法

  • 二分搜索后缀数组,找到前缀是模式串 P下界 L
  • 二分搜索后缀数组,找到前缀是模式串 P上界 R

此时位于 [L, R] 的所有后缀的前缀都是 P,用后缀数组就可以得到出现的位置

时间复杂度

$$ O(Llog\ N) $$

其中

  • $L$ 是字符串 P 的长度
  • $N$ 是字符串 S 的长度