字符串哈希

Hash 的思想

Hash 的核心思想在于,将输入映射到一个值域较小、可以方便比较的范围。

Warning

这里的“值域较小”在不同情况下意义不同。

哈希表 中,值域需要小到能够接受线性的空间与时间复杂度。

在字符串哈希中,值域需要小到能够快速比较( 10^9 10^{18} 都是可以快速比较的)。

同时,为了降低哈希冲突率,值域也不能太小。

我们定义一个把字符串映射到整数的函数 f ,这个 f 称为是 Hash 函数。

我们希望这个函数 f 可以方便地帮我们判断两个字符串是否相等。

具体来说,我们希望在 Hash 函数值不一样的时候,两个字符串一定不一样。

另外,反过来不需要成立。我们把这种条件称为是单侧错误。

我们需要关注的是什么?

时间复杂度和 Hash 的准确率。

通常我们采用的是多项式 Hash 的方法,即 f(s) = \sum s[i] \times b^i \pmod M

这里面的 b M 需要选取得足够合适才行,以使得 Hash 函数的值分布尽量均匀。

如果 b M 互质,在输入随机的情况下,这个 Hash 函数在 [0,M) 上每个值概率相等,此时单次比较的错误率为 \frac 1 M 。所以,哈希的模数一般会选用大质数。

Hash 的实现

参考代码:(效率低下的版本,实际使用时一般不会这么写)

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
using std::string;

const int M = 1e9 + 7;
const int B = 233;

typedef long long ll;

int get_hash(const string& s) {
  int res = 0;
  for (int i = 0; i < s.size(); ++i) {
    res = (ll)(res * B + s[i]) % M;
  }
  return res;
}

bool cmp(const string& s, const string& t) {
  return get_hash(s) == get_hash(t);
}

Hash 的分析与改进

错误率

若进行 n 次比较,每次错误率 \dfrac 1 M ,那么总错误率是 1-\left(1-\dfrac 1 M\right)^n 。在随机数据下,若 M=10^9 + 7 , n=10^6 ,错误率约为 \dfrac 1{1000} ,并不是能够完全忽略不计的。

所以,进行字符串哈希时,经常会对两个大质数分别取模,这样的话哈希函数的值域就能扩大到两者之积,错误率就非常小了。

多次询问子串哈希

单次计算一个字符串的哈希值复杂度是 O(n) ,其中 n 为串长,与暴力匹配没有区别,如果需要多次询问一个字符串的子串的哈希值,每次重新计算效率非常低下。

一般采取的方法是对整个字符串先预处理出每个前缀的哈希值,将哈希值看成一个 b 进制的数对 M 取模的结果,这样的话每次就能快速求出子串的哈希了:

f_i(s) 表示 f(s[1..i]) ,那么 f(s[l..r])=f_r(s)-f_{l-1}(s) \times b^{r-l+1} ,其中 b^{r-l+1} 可以预处理出来。

这样的话,就可以在 O(n) 的预处理后每次 O(1) 地计算子串的哈希值了。

Hash 的应用

字符串匹配

求出模式串的哈希值后,求出文本串每个长度为模式串长度的子串的哈希值,分别与模式串的哈希值比较即可。

最长回文子串

二分答案,判断是否可行时枚举回文中心(对称轴),哈希判断两侧是否相等。需要分别预处理正着和倒着的哈希值。时间复杂度 O(n\log n)

这个问题可以使用 manacher 算法 O(n) 的时间内解决。

确定字符串中不同子字符串的数量

问题:给定长为 n 的字符串,仅由小写英文字母组成,查找该字符串中不同子串的数量。

为了解决这个问题,我们遍历了所有长度为 l=1,\cdots ,n 的子串。对于每个长度为 l ,我们将其 Hash 值乘以相同的 b 的幂次方,并存入一个数组中。数组中不同元素的数量等于字符串中长度不同的子串的数量,并此数字将添加到最终答案中。

为了方便起见,我们将使用 h [i] 作为 Hash 的前缀字符,并定义 h[0]=0

参考代码
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
int count_unique_substrings(string const& s) {
  int n = s.size();

  const int b = 31;
  const int m = 1e9 + 9;
  vector<long long> b_pow(n);
  b_pow[0] = 1;
  for (int i = 1; i < n; i++) b_pow[i] = (b_pow[i - 1] * b) % m;

  vector<long long> h(n + 1, 0);
  for (int i = 0; i < n; i++)
    h[i + 1] = (h[i] + (s[i] - 'a' + 1) * b_pow[i]) % m;

  int cnt = 0;
  for (int l = 1; l <= n; l++) {
    set<long long> hs;
    for (int i = 0; i <= n - l; i++) {
      long long cur_h = (h[i + l] + m - h[i]) % m;
      cur_h = (cur_h * b_pow[n - i - 1]) % m;
      hs.insert(cur_h);
    }
    cnt += hs.size();
  }
  return cnt;
}

例题

CF1200E Compress Words

给你若干个字符串,答案串初始为空。第 i 步将第 i 个字符串加到答案串的后面,但是尽量地去掉重复部分(即去掉一个最长的、是原答案串的后缀、也是第 i 个串的前缀的字符串),求最后得到的字符串。

字符串个数不超过 10^5 ,总长不超过 10^6

题解

每次需要求最长的、是原答案串的后缀、也是第 i 个串的前缀的字符串。枚举这个串的长度,哈希比较即可。

当然,这道题也可以使用 KMP 算法 解决。

参考代码
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#include <cstdio>
#include <cstring>
#include <iostream>
#include <string>
using namespace std;
const int CN = 1e6 + 6;
const int M1 = 11431471;
const int B1 = 231;
const int M2 = 37101101;
const int B2 = 312;
int read() {
  int s = 0, ne = 1;
  char c = getchar();
  while (c < '0' || c > '9') ne = c == '-' ? -1 : 1, c = getchar();
  while (c >= '0' && c <= '9') s = (s << 1) + (s << 3) + c - '0', c = getchar();
  return s * ne;
}
int qp(int a, int b, int P) {
  int r = 1;
  while (b) {
    if (b & 1) r = 1ll * r * a % P;
    a = 1ll * a * a % P;
    b >>= 1;
  }
  return r;
}
int H1[CN], H2[CN], l1 = 0;
void add1(int x) {
  H1[l1 + 1] = (1ll * H1[l1] * B1 % M1 + x) % M1,
          H2[l1 + 1] = (1ll * H2[l1] * B2 % M2 + x) % M2;
  l1++;
}
int h1[CN], h2[CN], l2 = 0;
void add2(int x) {
  h1[l2 + 1] = (1ll * h1[l2] * B1 % M1 + x) % M1,
          h2[l2 + 1] = (1ll * h2[l2] * B2 % M2 + x) % M2;
  l2++;
}
int get(int* h, int l, int r, int b, int m) {
  return 1ll * (h[r] - 1ll * h[l - 1] * qp(b, r - l + 1, m) % m + m) % m;
}
int n, len;
char cur[CN], nxt[CN];
int main() {
  n = read() - 1;
  cin >> cur;
  len = strlen(cur);
  for (int i = 0; i < len; i++) add1(cur[i] - '0');
  while (n--) {
    cin >> nxt;
    int l = strlen(nxt);
    for (int i = 0; i < l; i++) add2(nxt[i] - '0');
    int p = 0;
    for (int i = 0; i < l && i < len; i++) {
      int G1 = get(H1, len - i, len, B1, M1),
          G2 = get(H2, len - i, len, B2, M2);
      int g1 = get(h1, 1, i + 1, B1, M1), g2 = get(h2, 1, i + 1, B2, M2);
      if (G1 == g1 && G2 == g2) p = i + 1;
    }

    for (int i = len; i < len - p + l; i++)
      cur[i] = nxt[i - len + p], add1(cur[i] - '0');
    len = len - p + l, cur[len] = '\0';
    l2 = 0;
  }
  cout << cur;
}

本页面部分内容译自博文 строковый хеш 与其英文翻译版 String Hashing 。其中俄文版版权协议为 Public Domain + Leave a Link;英文版版权协议为 CC-BY-SA 4.0。


评论