后缀数组:快速查找模式串的所有出现位置
后缀数组
从定义上来说,后缀数组(用 sa 表示)包含了字符串 s 所有后缀的起始索引,它是一个 int 数组,sa[i] 表示对应后缀的起始索引。注意后缀已经按照字典序排好。
以字符串 fizzbuzz 为例,它的后缀数组是 4 0 1 5 7 3 6 2,对应关系如下
后缀数组 sa |
对应后缀 |
|---|---|
| 4 | buzz |
| 0 | fizzbuzz |
| 1 | izzbuzz |
| 5 | uzz |
| 7 | z |
| 3 | zbuzz |
| 6 | zz |
| 2 | zzbuzz |
生成算法
朴素算法
用 $N$ 表示字符串的长度,朴素算法只需要生成所有后缀,然后对后缀进行排序即可。排序的时间复杂度是 $O(Nlog\ N)$,每一次字符串比较需要 $O(N)$,所以朴素算法的时间复杂度是
$$ 1\times O(Nlog\ N)\times O(N)=\boxed{O(N^2log\ N)} $$
实现很简单,这里不展开
倍增法
倍增法简单来说,就是不断执行如下步骤直到 $2^k\ge n$
- 第 1 轮,只看每个后缀的前 $2^1$ 字符,对后缀数组排序
- 第 2 轮,只看每个后缀的前 $2^2$ 字符,对后缀数组排序
- …
- 第 $k$ 轮,只看每个后缀的前 $2^k$ 字符,对后缀数组排序
可以看到,每一次都从 $2^{k-1}\rightarrow 2^k$,这也是名字中“倍增”的含义
倍增法的本质是:用已经排序好的短字符串,去比较更长的字符串。举例来说,如何比较 2 个长度为 $2^k$ 的字符串?我们可以先比较他们的前 $2^{k-1}$ 个字符,然后再比较后 $2^{k-1}$ 字符
为了复用之前的比较结果,倍增法需要引入一个额外数组,一般称之为 ranking 数组(用 ra 表示),它也是一个 int 数组。在第 $k$ 轮执行开始前,ra[sa[i]] 表示以 sa[i] 开始的后缀的前 $2^{k-1}$ 个字符在所有后缀中的字典序排序
这里可以推导一下前面提到的“第 $k$ 轮,只看每个后缀的前 $2^k$ 字符,对后缀数组排序”是怎么工作的:根据前面的讨论,利用 ra 数组,可以这么干
- 先比较他们的前 $2^{k-1}$ 个字符,那么就是比较 $ra\big[sa[i]\big]$
- 然后再比较后 $2^{k-1}$ 字符,那么就是比较 $ra\big[sa[i]+2^{k-1}\big]$
换言之,只看前 $2^k$ 字符比较字符串被转化为了下面这个 tuple 的排序:
$$ \Big( ra\big[sa[i]],ra\big[sa[i]+2^{k-1}\big] \Big) $$
而 $ra$ 数组的这两个值在上一轮都已经求好了
最后分析一下倍增法的时间复杂度:因为排序每次都在倍增,所以只需要比较 $O(log\ N)$ 轮,每一轮的比较时间复杂度是 $O(Nlog\ N)\times O(1)=O(Nlog\ N)$,因为现在不是字符串比较,而是用 ranking pair 比较
$$ O(log\ N)\times O(Nlog\ N)\times O(1)=\boxed{O(Nlog^2\ N)} $$
更好的倍增法
上面的倍增法还可以利用基数排序进行优化,这样排序的时间复杂度就从 $O(Nlog\ N)$ 变成了 $O(N)$
比如比较 (1,3), (2,2), (3,2) 这些 ranking pair,基数排序会这么做:
- 先基于第二个位置做基数排序得到:
(2,2), (3,2), (1,3) - 再基于第一个位置做基数排序得到:
(1,3), (2,2), (3,2)
利用基数排序优化后,后缀数组生成的时间复杂度就是
$$ O(log\ N)\times O(N)\times O(1)=\boxed{O(Nlog\ N)} $$
倍增法代码
倍增法
代码实现上需要注意:ra 数组的访问可能会越界,越界时应返回一个比所有可能的 rank 更小的数字,这样才符合字典序比较的。下面的实现中所有合法的 ranking $\in[1,)$,所以返回的是 $0$
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>
using std::cin;
using std::cout;
using std::fill;
using std::string;
using std::vector;
int main() {
string s;
cin >> s;
// Init
vector<int> sa(s.size());
vector<int> ra(s.size());
for (int i = 0; i < s.size(); i++) {
sa[i] = i;
ra[i] = s[i] - 'a' + 1; // valid rank starts from 1
}
vector<int> new_sa(s.size());
for (int w = 1; w < s.size(); w <<= 1) {
// Sort suffix array
sort(sa.begin(), sa.end(), [&](int x, int y) {
if (ra[x] != ra[y]) {
return ra[x] < ra[y];
}
int rx{x + w < s.size() ? ra[x + w] : 0};
int ry{y + w < s.size() ? ra[y + w] : 0};
return rx < ry;
});
// Reranking
vector<int> new_ra(s.size(), 1);
int current_rank{1};
new_ra[sa[0]] = 1;
for (int i = 1; i < s.size(); i++) {
int cur{sa[i]};
int prev{sa[i - 1]};
new_ra[sa[i]] =
(ra[cur] == ra[prev] && (cur + w < s.size() ? ra[cur + w] : 0) ==
(prev + w < s.size() ? ra[prev + w] : 0)
? current_rank
: ++current_rank);
}
ra = new_ra;
}
for (int i = 0; i < s.size(); i++) {
cout << sa[i] << (i == s.size() - 1 ? "" : " ");
}
}
更好的倍增法
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>
using std::cin;
using std::cout;
using std::fill;
using std::max;
using std::string;
using std::vector;
// Sort `sa` using `ra[i + delta]`
void radix_sort(vector<int> &sa, vector<int> &new_sa, const vector<int> &ra,
int n, int delta) {
// The valid rank for lowercase letters is 1 ~ 26
// or the length of the input string + 1
int max_rank{max(27, n + 1)};
vector<int> count(max_rank, 0);
// counting
for (int i = 0; i < n; i++) {
int key{sa[i] + delta < n ? ra[sa[i] + delta] : 0};
count[key]++;
}
for (int i = 0, prefix_sum = 0; i < max_rank; i++) {
int temp{count[i]};
count[i] = prefix_sum;
prefix_sum += temp;
}
// invariant: all ranking pair in the same count[i] share the same rank
for (int i = 0; i < n; i++) {
int key{sa[i] + delta < n ? ra[sa[i] + delta] : 0};
new_sa[count[key]++] = sa[i];
}
sa = new_sa;
}
int main() {
string s;
cin >> s;
// Init
vector<int> sa(s.size());
vector<int> ra(s.size());
for (int i = 0; i < s.size(); i++) {
sa[i] = i;
ra[i] = s[i] - 'a' + 1; // valid rank starts from 1
}
vector<int> new_sa(s.size());
for (int w = 1; w < s.size(); w <<= 1) {
// Sort suffix array
radix_sort(sa, new_sa, ra, s.size(), w);
radix_sort(sa, new_sa, ra, s.size(), 0);
// Reranking
vector<int> new_ra(s.size(), 1);
int current_rank{1};
new_ra[sa[0]] = 1;
for (int i = 1; i < s.size(); i++) {
int cur{sa[i]};
int prev{sa[i - 1]};
new_ra[sa[i]] =
(ra[cur] == ra[prev] && (cur + w < s.size() ? ra[cur + w] : 0) ==
(prev + w < s.size() ? ra[prev + w] : 0)
? current_rank
: ++current_rank);
}
ra = new_ra;
}
for (int i = 0; i < s.size(); i++) {
cout << sa[i] << (i == s.size() - 1 ? "" : " ");
}
}
倍增法例子
仍然用之前的 fizzbuzz 为例,在初始化之后,第一轮排序之前是这样子的
后缀数组 sa[i] |
ranking 数组 ra[sa[i]] |
ranking 数组 ra[sa[i]+1] |
对应的后缀 |
|---|---|---|---|
| 0 | 6 | 9 | fizzbuzz |
| 1 | 9 | 26 | izzbuzz |
| 2 | 26 | 26 | zzbuzz |
| 3 | 26 | 2 | zbuzz |
| 4 | 2 | 21 | buzz |
| 5 | 21 | 26 | uzz |
| 6 | 26 | 26 | zz |
| 7 | 26 | 0 | z |
此时第一轮排序,要看的是前 $2^1=2$ 个字符,对应的 ranking pair 是 (ra[sa[i]], ra[sa[i]+1]),排序结果是
后缀数组 sa[i] |
ranking 数组 ra[sa[i]] |
ranking 数组 ra[sa[i]+1] |
对应的后缀 |
|---|---|---|---|
| 4 | 2 | 21 | buzz |
| 0 | 6 | 9 | fizzbuzz |
| 1 | 9 | 26 | izzbuzz |
| 5 | 21 | 26 | uzz |
| 7 | 26 | 0 | z |
| 3 | 26 | 2 | zbuzz |
| 2 | 26 | 26 | zzbuzz |
| 6 | 26 | 26 | zz |
接下来重新分配 ranking(从 1 开始),准备开始第二轮排序
后缀数组 sa[i] |
ranking 数组 ra[sa[i]] |
ranking 数组 ra[sa[i]+2] |
对应的后缀 |
|---|---|---|---|
| 4 | 1 | 7 | buzz |
| 0 | 2 | 7 | fizzbuzz |
| 1 | 3 | 6 | izzbuzz |
| 5 | 4 | 5 | uzz |
| 7 | 5 | 0 | z |
| 3 | 6 | 4 | zbuzz |
| 2 | 7 | 1 | zzbuzz |
| 6 | 7 | 0 | zz |
第二轮排序看前 $2^2=4$ 个字符,对应的 ranking pair 是 (ra[sa[i]], ra[sa[i]+2]),排序完之后是
后缀数组 sa[i] |
ranking 数组 ra[sa[i]] |
ranking 数组 ra[sa[i]+2] |
对应的后缀 |
|---|---|---|---|
| 4 | 1 | 7 | buzz |
| 0 | 2 | 7 | fizzbuzz |
| 1 | 3 | 6 | izzbuzz |
| 5 | 4 | 5 | uzz |
| 7 | 5 | 0 | z |
| 3 | 6 | 4 | zbuzz |
| 6 | 7 | 0 | zz |
| 2 | 7 | 1 | zzbuzz |
接下来重新分配 ranking(从 1 开始),准备开始第三轮排序
后缀数组 sa[i] |
ranking 数组 ra[sa[i]] |
ranking 数组 ra[sa[i]+4] |
对应的后缀 |
|---|---|---|---|
| 4 | 1 | 0 | buzz |
| 0 | 2 | 1 | fizzbuzz |
| 1 | 3 | 4 | izzbuzz |
| 5 | 4 | 0 | uzz |
| 7 | 5 | 0 | z |
| 3 | 6 | 5 | zbuzz |
| 6 | 7 | 0 | zz |
| 2 | 8 | 7 | zzbuzz |
第三轮排序看前 $2^3=8$ 个字符,对应的 ranking pair 是 (ra[sa[i]], ra[sa[i]+4]),排序完之后和上一轮无变化,说明已经排序完了
应用
后缀数组可以用于查找模式串 P 在字符串 S 中的所有出现位置,这是因为它利用到了这个性质:如果模式串 P 是字符串 S 的子串,那么它肯定是 S 的某些后缀的前缀
算法
- 二分搜索后缀数组,找到前缀是模式串
P的下界L - 二分搜索后缀数组,找到前缀是模式串
P的上界R
此时位于 [L, R] 的所有后缀的前缀都是 P,用后缀数组就可以得到出现的位置
时间复杂度:
$$ O(Llog\ N) $$
其中
- $L$ 是字符串 P 的长度
- $N$ 是字符串 S 的长度