KMP算法

前缀数组讲解 KMP回退原理讲解  印度阿三KMP演示视频 

规定:本文所有的字符串和指针下标都是从0开始

算法引入

假设有字符串 s1s2 ,长度分别为 nm 。求解 s2 是否为 s1 的子串这个问题中,可称为 s1 为匹配串, s2 为模式串。即解决模式串 s2 是否能够在匹配串 s1 成功匹配问题

朴素算法就是单纯的暴力枚举尝试匹配从左到右一个个匹配,如果这个过程中有某个字符不匹配,就跳回去,将模式串向右移动一位,分别为两个字符串设两个匹配指针

之后我们只需要比较 i 指针指向的字符和 j 指针指向的字符是否一致。如果一致就都向后移动,如果不一致如图

AE 不一致,那就把 i 指针移回开始匹配位置的后一位(这里是第 1 位),j移动到模式串的第 0 位,然后又重新开始下一轮的匹配

朴素算法简单但是太慢,时间复杂度高达 O(n*m) 。如果是人为进行寻找绝对不会把 i 移动到第 1 位进行匹配尝试,因为先前尝试的时候模式串的前三位都成功匹配了而且模式串只有开头一个 A ,所以不难知道匹配第 2 ,3位肯定不和模式串的 A 匹配(移动过去肯定不匹配)。那么如果能够忽略这种无意义的匹配尝试减少 i的回溯不就可以加快匹配了吗?

KMP 算法就是利用这个想法产生的,它的时间复杂度仅为 O(n+m) 的算法,既可以求出字符串 s2 是不是字符串 s1 的子串,又可以求出字符串 s2 在字符串 s1 中的出现的次数,每次出现在什么位置

该算法主要有两个操作:首先用 O(m) 的复杂度求出一个 next 数组,记录了关于模式串 s2 的信息,再用 O(n) 的复杂度快速求出模式串串 s2 在匹配串 s1 出现的位置

步骤一:求 next 数组

注意:在 KMP算法中,有很多定义 next 数组的方式。这里只选一种讲解,其他的几种可以使相似的方法,在使用KMP算法的时候一定要注意 next数组的定义

next 数组处理对象是模式串自身,具体什么作用看了体会就知道了

本文采用的定义是:若 next[j]==k ,则 k 为满足以下条件的最大值:s2[0..k]s2[j-k+1..j] 相同,且 s2[k+1]s2[j+1] 不同。若不存在这样的 k ,则 next[j]=-1

next 数组若是从 -1 开始一般记录的是匹配前缀末尾元素在模式串中的位置,若是从 0 开始一般记录的就是匹配的前缀的长度

以模式串 ababaab 为例

模式串子串 next 数组
a nex[0]=-1
ab nex[1]=-1
aba nex[2]=0
abab nex[3]=1
abAba nex[4]=2
ababaa nex[5]=0
ababaab nex[6]=1

在求解 next 数组时,暴⼒的算法效率不满⾜要求,这⾥使⽤递推的思想求解,也就是已知第0位的值为-1,计算当前位置的 next 值时,要根据前⼀个字符的最⻓前后的值来判断,这基于 next 数组的意义,它的值始终是该子串的最长匹配前缀的最后⼀位下标

更具体的说,每当我们要计算当前长度子串的 next 值,我们已经知道上⼀位的 next
如果此时新增的字符和上⼀位最长前缀的下⼀个字符相同,那么无疑 nex[i]=nex[i-1]+1 ,这样可以保证nex值最大,否则新增的字符不匹配了,我们就要找次长的匹配前缀了,直到找到的前长度的下⼀位和新增的字符相等

模式串子串 next 数组
a i=0,nex[0]=-1
ab i=1,str[nex[0]+1]!=str[1],nex[1]=-1
aba i=2,str[nex[1]+1]==str[2],nex[2]=0
abab i=3,str[nex[2]+1]==str[3],nex[3]=1
abAba i=4,str[nex[3]+1]==str[4],nex[4]=2
ababaa i=5,str[nex[4]+1]!=str[5]->str[nex[0]+1]==str[5],nex[5]=0
ababaab i=6,str[nex[5]+1]==str[6],nex[6]=1

ababa 子串求 nex[4] 为例,nex[3]=1 并且此时最长前缀的末尾元素的后一个字符为 a 恰好和当前新增字符 a 相同,所以 nex[4]=nex[3]+1=2

再以 ababaanex[5] 子串为,nex[4]=2 但是此时最长前缀的末尾元素的后一个字符 b 和当前新增字符 a 不同,所以再向前找次为的最长前缀nex[nex[4]]=0并且此时最长前缀的末尾元素的后一个字符为 a 和新增字符相同了所以 nex[5]=nex[0]+1=0

注意:根据上面模拟可以看出对于字符串来说,求解 nex[] 数组时不把本身长度作为最长前缀,因为求解 nex[] 时是在当前位置前面的字符串中找最长前缀

例如:对于 abcd 的末尾 nex[3] 来说 k=3 也可以保证前缀和后缀相同但是求解保证了该情况不算在内,这是合理的也是明显的(因为进行匹配操作时这个情况不可以使用)

模板

1
2
3
4
5
6
7
8
9
int nex[N];
void getnex(char str[],int len){
nex[0]=-1;
for(int i=1,j=-1;i<len;i++){
while(j!=-1&&str[i]!=str[j+1])j=nex[j];//不符合回退
if(str[i]==str[j+1])j++;//符合前进
nex[i]=j;//记录当前位置的最长前缀的最后一个字符的位置
}
}

步骤二:匹配模式串和匹配串

借助求出的关于模式串 nex 数组进行如下操作:利用已经部分匹配这个有效信息,保持 i 指针不回溯,通过修改 j 指针,让模式串尽量地移动到有效的位置

所以,整个 KMP 的重点就在于当某一个字符与主串不匹配时,我们应该知道j指针要移动到哪?而 nex 数组的存在恰好解决了这个问题。看两个例子就找到规律了:

例子一:

CD 不匹配了,保持 i 不回溯,我们显然要把 j 移动到第 1 位,为什么?因为前面有一个 A 相同可以匹配上

例子二:

CB 不匹配了,保持 i 不回溯,把 j 指针移动到第 2 位,因为前面有两个字母是一样的

其实 j 的移动就是移到了之前求出的 nex[] 的位置,但是有点区别,本文转移应该按照新增字符的前一个位置的 nex 转移,所以不妨让 j 指针错一个位置也就是从 -1 位置开始,尝试匹配的是 tex=[i]patt[j+1]

模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
//判断是否是子串
bool kmp1(char tex[],char patt[]){
int len1=strlen(tex),len2=strlen(patt);
//预处理求出patt的前缀函数
getnex(patt,len2);
for(int i=0,j=-1;i<len1;i++){
//如果不匹配
while(j!=-1&&tex[i]!=patt[j+1])j=nex[j];
//匹配到了,那么下一个
if(tex[i]==patt[j+1])j++;
if(j==len2-1)return true;
}
return false;
}

//统计patt在tex中出现的次数
int kmp2(char tex[],char patt[]){
int len1=strlen(tex),len2=strlen(patt),cnt=0;
getnex(patt,len2);
for(int i=0,j=-1;i<len1;i++){
while(j!=-1&&tex[i]!=patt[j+1])j=nex[j];
if(patt[j+1]==tex[i])j++;
//匹配成功后次数+1,同时j回推继续匹配
if(j==len2-1)cnt++,j=nex[j];
}
return cnt;
}
//变式:如果题目要求匹配的字符串位置不能重叠,那么匹配成功后j就不能回溯

模板

模板题:给定一个字符串 A 和一个字符串 B ,求 BA 中的出现次数。AB 中的字符均为英语大写字母或小写字母。 A 中不同位置出现的 B 可重叠。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1e6+10;

int nex[N];
void getnex(char str[],int len){
nex[0]=-1;
for(int i=1,j=-1;i<len;i++){
while(j!=-1&&str[i]!=str[j+1])j=nex[j];//不符合回退
if(str[i]==str[j+1])j++;//符合前进
nex[i]=j;//记录当前位置的最长前缀的最后一个字符的位置
}
}

//判断是否是子串
bool kmp1(char tex[],char patt[]){
int len1=strlen(tex),len2=strlen(patt);
//预处理求出patt的前缀函数
getnex(patt,len2);
for(int i=0,j=-1;i<len1;i++){
//如果不匹配
while(j!=-1&&tex[i]!=patt[j+1])j=nex[j];
//匹配到了,那么下一个
if(tex[i]==patt[j+1])j++;
if(j==len2-1)return true;
}
return false;
}

int kmp2(char tex[],char patt[]){
int len1=strlen(tex),len2=strlen(patt),cnt=0;
getnex(patt,len2);
for(int i=0,j=-1;i<len1;i++){
while(j!=-1&&tex[i]!=patt[j+1])j=nex[j];
if(patt[j+1]==tex[i])j++;
//匹配成功后次数+1,同时j回推继续匹配
if(j==len2-1)cnt++,j=nex[j];
}
return cnt;
}

signed main()
{
ios::sync_with_stdio(false),cin.tie(0);
char a[N],b[N];
cin>>a>>b;
getnex(b,strlen(b));
cout<<kmp2(a,b)<<endl;
//system("pause");
return 0;
}

拓展KMP算法

参考博客

拓展KMP算法时间复杂度也是 O(n+m) ,它是由 KMP 算法拓展而来的而另一种快速匹配方法,同样是先求模式串本身的 next 数组但是匹配是在借助 next 数组计算出来的 extend 数组中寻找满足 extend[i]==m

模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#include<bits/stdc++.h>
using namespace std;

/* 求解T中next[],注释参考GetExtend()*/
void GetNext(string &T,int &m,int next[]){
int a=0,p=0;
next[0]=m;
for(int i=1;i<m;i++){
if(i>=p||i+next[i-a]>=p){
if(i>=p)p=i;
while(p<m&&T[p]==T[p-i])p++;
next[i]=p-i;
a=i;
}
else next[i]=next[i-a];
}
}

/* 求解extend[]*/
void GetExtend(string &S,int &n,string &T,int &m,int extend[],int next[]){
int a=0,p=0;
GetNext(T,m,next);
for(int i=0;i<n;i++){
if(i>=p||i+next[i-a]>=p){//i>=p的作用:举个典型例子,S和T无一字符相同
if(i>=p)p=i;
while(p<n&&p-i<m&&S[p]==T[p-i])p++;
extend[i]=p-i;
a=i;
}
else extend[i] = next[i - a];
}
}

signed main()
{
int next[100];
int extend[100];
string S,T;
int n,m;
while (cin>>S>>T){//匹配串和模式串
n=S.size();
m=T.size();
GetExtend(S,n,T,m,extend,next);
//打印next
cout<< "next: ";
for(int i=0;i<m;i++)
cout<<next[i]<<" ";
//打印extend
cout<< "\nextend: ";
for(int i=0;i<n;i++)
cout<<extend[i]<< " ";
cout<<endl<<endl;
//extend[i]==m的就是匹配成功
}
system("pause");
return 0;
}

Trie(字典树)

字典树,顾名思义就是一个像字典一样的树。

定义:Trie 是⼀种⽤于实现字符串快速检索的多叉树结构。Trie 的每个节点都拥有若⼲个字符指针,若在插⼊或检索字符串时扫描到⼀个字符串 c,就沿着当前节点的 c 这个这个字符指针,⾛向该指针指向的节点

借用这样的数据结构,我们可以方便存取大量字符串,大幅度优化空间复杂度

操作:它是首先建立一个 root 根节点,然后在读取后来的字符串的同时,从根节点出发,查找字符串每一位的节点是否存在。若存在,就从这一位出发继续查找下一位;若不存在,就建立这个节点。反复以上过程。

注意:Trie 树是将字符转换为 ASCii 码存取需要转换

特点

  • 根节点不包含字符,除根节点外每一个节点都只包含一个字符
  • 从根节点到某一个节点,路径上经过的字符连接起来,为该节点对应的字符串
  • 在字典树中查找某一个关键字的时间和树种包含的节点数无关,而取决去组成关键字的字符数,也就是查询关键字的时间复杂度 O(s.length)
  • 实际上每一个节点指向的是下一个字符的 ASCii 码(节点相当于是辅助)真正记录字符的在边上

代码实现

字典树的两种基本操作分别是建树和查询。其中建树操作就是把一个新单词插入到字典树里。查询操作就是查询给定单词是否在字典树上

初始化:⼀颗空的字典树仅包含⼀个指向空的根节点

插入:当需要插⼊⼀个字符串 S 时,令⼀个指针 P 从根节点开始,依次扫描 S 中的每⼀个字符 c

  • 如果 c 字符的指针指向⼀个已经存在的节点 Q ,则令 P=Q
  • 如果 c 字符的指针指向空,那么新建⼀个 Q 节点,令 Pc 字符指针指向 Q ,然后令 P=Q

每次插入的最后将插入的字符串末尾的节点标记表示走到当前节点路径上的字符串存在

查询:开始检索之前,令⼀个指针 P 指向根节点,然后依次扫描S

  • PC 字符指针指向空,说明 S 不在 Trie 中。
  • PC 字符指针指向⼀个已经存在的节点 Q ,令 P=Q

扫描完毕时如果 P 被标记为字符串的末尾,则说明 S 存在,否则说明不在 Trie 中

图示:依次插入 cabcatrain

上图的 Trie 插⼊的均是⼩写字符,所以每个节点最多有 26 个指针。底部的绿⾊节点代表⼀个字符串的结尾。可以发现在 Tire 中,所有字符保存在树的指针(边)上,⽽节点存储的是⼀些辅助信息

模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#include<bits/stdc++.h>
using namespace std;
const int N=1e5;//插入字符串的最大字符数

int trie[N][26];
bool ed[N];//字典树[当前指针][指向指针]和结尾标记
int tot;//寄存指针,根节点从1开始,可以统计有多少个节点,值-1就是字符数

int cal(char x){
return x-'a';
}

void init(){
memset(trie,0,sizeof trie);
memset(ed,0,sizeof ed);
tot=1;
}

void insert(char str[]){
int len=strlen(str),p=1;//从根节点1开始往下存
for(int i=0;i<len;i++){
int ch=cal(str[i]);
if(trie[p][ch]==0)trie[p][ch]=++tot;
p=trie[p][ch];
}
ed[p]=true;
}

bool search(char str[]){
int len=strlen(str),p=1;//这里的str[]为待查询的字符串
for(int i=0;i<len;i++){
int ch=cal(str[i]);
p=trie[p][ch];
if(p==0)return false;//没有存储字符串的字符
}
return ed[p];
//关键:若确实到了结尾就返还true,否则说明可能只是字典树里的某个字符串的子串,那也是没找到
}

signed main()
{
init();
int n;cin>>n;
char s[100005];
//插入字符串
for(int i=0;i<n;i++)
cin>>s,insert(s);

//查询字符串
while(cin>>s)
search(s);
return 0;
}

优化:边插入边查询

模板题:给你一系列字符串,以 9 为分割线,判断在某个集合里的字符串有没有作为其他字符串的前缀出现

1

拓展:字典树异或操作

模板题:在给定的 N 个整数 A1,A2,…,AN 中选出两个进行异或运算,得到的结果最大是多少

思路:先把所有的数插入到字典树。根据贪心原理:尽可能出现多的 1 ,1 数量相同的情况下尽量往高处放。所以插入的时候数从高位到低位的顺序插入。因此在查询操作中对每一个数尝试将每个为 0 的位变为1(前提是字典树内存在对应的点,因为只有存在对应的点才存在这个数)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;

ll a[100005],trie[100005*31][2];//只有01取值所以开第二维开两个空间
ll tot=1;//寄存指针

//插入N个31位的数
void insert(ll x){
ll p=1;//根节点从1开始
for(ll i=30;i>=0;i--){
ll ch=(x>>i)&1;
if(trie[p][ch]==0)trie[p][ch]=++tot;
p=trie[p][ch];
}
}

ll query(ll x){
ll res=0;//记录查找的最大异或数的数值
ll p=1;
for(ll i=30;i>=0;i--){
ll ch=(x>>i)&1;
if(trie[p][!ch]>0){
p=trie[p][!ch];
res=res*2+!ch;//向前进一位然后加上新的一位
}
else{
p=trie[p][ch];
res=res*2+ch;
}
}
return res;
}

signed main()
{
ios::sync_with_stdio(false),cin.tie(0);
ll n;cin>>n;
for(ll i=0;i<n;i++){
cin>>a[i];
insert(a[i]);
}
ll cnt=-1;
for(ll i=0;i<n;i++){
cnt=max(cnt,query(a[i]);
}
cout<<cnt<<endl;
//system("pause");
return 0;
}

AC自动机

看这个博客讲解足够了解大概思想,AC自动机主要是用于多个模式串同时匹配

模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 2e6+9;

int trie[maxn][26];//字典树
int cntword[maxn];//记录该单词出现次数
int fail[maxn];//失败时的回溯指针
int cnt = 0;

void insert(string s){
int root=0;
for(int i=0;i<s.size();i++){
int next=s[i]-'a';
if(!trie[root][next])
trie[root][next]=++cnt;
root=trie[root][next];
}
cntword[root]++;//当前节点单词数+1
}
void getFail(){
queue<int>q;
for(int i=0;i<26;i++){//将第二层所有出现了的字母扔进队列
if(trie[0][i]){
fail[trie[0][i]] = 0;
q.push(trie[0][i]);
}
}

//fail[now]->当前节点now的失败指针指向的地方
//tire[now][i]->下一个字母为i+'a'的节点的下标为tire[now][i]
while(!q.empty()){
int now=q.front();
q.pop();

for(int i=0;i<26;i++){//查询26个字母
if(trie[now][i]){
//如果有这个子节点为字母i+'a',则
//让这个节点的失败指针指向(((他父亲节点)的失败指针所指向的那个节点)的下一个节点)
//有点绕,为了方便理解特意加了括号

fail[trie[now][i]] = trie[fail[now]][i];
q.push(trie[now][i]);
}
else//否则就让当前节点的这个子节点
//指向当前节点fail指针的这个子节点
trie[now][i] = trie[fail[now]][i];
}
}
}


int query(string s){
int now=0,ans=0;
for(int i=0;i<s.size();i++){//遍历文本串
now=trie[now][s[i]-'a'];//从s[i]点开始寻找
for(int j=now;j&&cntword[j]!=-1;j=fail[j]){
//一直向下寻找,直到匹配失败(失败指针指向根或者当前节点已找过).
ans+=cntword[j];
cntword[j]=-1;//将遍历国后的节点标记,防止重复计算
}
}
return ans;
}

signed main() {
int n;cin>>n;
string s;
for(int i=0;i<n;i++){//存入多组模式串
cin>>s ;
insert(s);
}
fail[0]=0;
getFail();
cin>>s;//存入匹配串(文本串)
cout<<query(s)<<endl;
return 0;
}