kmp 算法

发布于:2025-08-31 ⋅ 阅读:(23) ⋅ 点赞:(0)

hello,大家好!今天是2025年8月29日,我要来给大家分享的是 kmp 算法。

1:kmp 算法

1.1 相关概念

【字符串】

  • 用字符构成的序列就是字符串。

在字符串匹配问题中,我们会让字符串的下标从 1 开始,这样便于我们处理一些边界问题。因此,在输入字符串时,我们一般会在前面加上一个空格,这样字符就从 1 开始计数了

string s; cin >> s;
int n = s.size();
s = ' ' + s;

注意:上面代码的第二行和第三行尽量不要颠倒,如果硬要颠倒的话,要注意枚举字符时候的上界

回顾我们之前字符串相关的动态规划问题,如果我们把下标为 0 的格子空出来,从下标为 1 的格子开始填表,就可以省去一些初始化问题,又可以防止数组下标越界 “-1”。

【子串】

  • 选取字符串中连续的一段。

【前缀】

  • 从字符串的首端开始,到某一个位置结束的子串。

那么,字符串长度为 i 的前缀,就是字符串【1,i】区间的子串。

【真前缀】

  • 不包含字符串本身的前缀子串。

【后缀】

  • 从字符串的某一个位置开始,到字符串末端的子串。

那么,字符串长度为 i 的后缀,就是字符串【n - i + 1,n】区间的子串。

【真后缀】

  • 不包含字符串本身的后缀子串。

【真公共前后缀 - border】

  • 字符串 s 的真公共前后缀为 s 的一个子串 t,满足 t 既是 s 的真前缀,又是 s 的真后缀,又称为字符串 s 的 border。
  • 在一个字符串中,最长的真公共前后缀的长度为用 pi 表示。

练习:

性质:

  • 传递性:字符串 s 的 border 的 border 也是字符串 s 的 border。

这个性质是很显然的,可以画图看一看:

左 == 右 ==> A == C && B == D,A == B (已知) ==> A == D 

【字符串匹配(kmp)问题】

  • 字符串匹配又称为模式匹配,给定两个字符串 S 和 T,需要在主串 S 中找到模式串 T。

比如,主串 S = “abcdefcde”,模式串 T = “cde”。如果下标从 1 开始计数,模式串会在主串 3,7 位置出现。

关于字符串匹配,大家首先想到的大概率是暴力解法,拿着模式串在主串中一个位置一个位置判断。但是,暴力解法最差情况下时间开销会达到 O(n * m),在算法题中大概率是会超时的。

接下来要介绍的 kmp 算法,能在 O(n + m)的线性时间复杂度内找到所有的匹配位置,并且维护出更多的信息。

1.2 前缀函数

【前缀函数】

  • 字符串每一个前缀子串的 pi 值。

以字符串 “aabaab” 为例,pi[i] 表示:字符串 s 长度为 i 的前缀,最长的 border 长度(最长公共前后缀)。

【小用途】

  • 从大到小枚举字符串 s 某个前缀的所有 border。

与求解前缀函数的过程息息相关。

假设我们此时生成了一个字符串 s 的前缀函数表,我们可以利用这张表,从大到小枚举某个前缀的所有 border。原理就是 border 的传递性:字符串 border 的 border 还是字符串的 border。

证明:(反证法)

正确性:会不会漏掉一些 border ?

显然是不会的:

  • 首先 pi[i] 存的是最长的 border 的长度,定义 pi 数组时就定死了;
  • 其次(反证法),下一个 pi[pi[i]] 如果不是次长的,那么 pi[pi[i]] 就不是长度为 pi[i] 的前缀的最长 border 的长度,与我们 pi 数组的定义相违背;
  • 因此,整个过程一定能够不重不漏的将所有的 border 从大到小枚举出来; 

代码实现:(假设已经生成好了前缀函数)

string s;
int pi[N];

// 长度为 i 的前缀中,所有 border 的长度
void get_border(int i)
{
    int j = pi[i];
    while(j)
    {
        cout << j << endl;
        j = pi[j];
    }
}

1.3 计算前缀函数

【计算过程】

  • 前缀函数的计算过程包含了动态规划的思想,就是推导状态转移方程。

对于字符串 s:

1. 状态表示:

        pi[i] 表示:字符串 s 长度为 i 的前缀,最长的 border 的长度(最长真公共前后缀的长度)。

2. 状态转移方程:

        a. 我们发现,如果将长度为 i 的前缀中的 border 删去最后一个字符,就变成了长度为 i - 1 的前缀中的 border;

        b. 那么,我们就可以从大到小枚举长度 i - 1 的前缀中的所有 border,然后判断这个 border 的下一个字符是否和 s[i] 相等:(小贪心)

        如果相等,说明这个就是最长的 ;

        如果不相等,那就继续判断下一个 border,直到所有的 border 验证完毕。

【正确性】

  • 毋庸置疑是正确的。因为我们只从大到小枚举【1,i - 1】区间字符串所有的 border;
  • 当第一次判断成功时,一定是【1,i】最长的 border;

【代码实现】

string s;
int pi[N];

void get_pi()
{
	cin >> s;
	int n = s.size();
	s = ' ' + s;
	// pi[1] = 0
	
	for(int i = 2; i <= n; i++)
	{
		int j = pi[i - 1];
		while(j && s[i] != s[j + 1]) j = pi[j];
		if(s[i] == s[j + 1]) j++;
		pi[i] = j;
	}
	
	// 我们注意到,i++ 之后 pi[i - 1] 依旧是上一次的 j,所以代码也可以这样写
	// 更能体现到 j 指针基本上不回退
	
	for(int i = 2, j = 0; i <= n; i++)
	{
		while(j && s[i] != s[j + 1]) j = pi[j];
		if(s[i] == s[j + 1]) j++;
		pi[i] = j;
	} 
}

这段代码中间的 for 循环和我们传统(学校讲 kmp)学习的求 next 数组的过程完全一样。

【时间复杂度】

  • 模拟一遍过程会发现,i 指针每次向后移动一位后,j 指针最多会向后移动一位,然后继续往前跳。因此两个指针最差情况下会遍历字符串两遍,时间复杂度 O(2n)= O(n)。

自己一定要去尝试模拟一遍!!!

1.4 用前缀函数解决字符串匹配

  • 设主串 S = “abcabaaaba”,模式串 T = “aba”,主串的长度为 n,模式串的长度为 m。

将两个字符串拼接起来:S = T + '#' + S = "aba#abcabaaaba",对于新的字符串,可以在线性时间复杂度 O(n + m)内生成前缀函数。

注意:上面的 '#' 并不是固定的,只要选择一个主串和模式串都不会出现的字符就可以。

前缀函数等于模式串长度的位置 i ,就是能够匹配的末端。在主串中,出现的位置就是 i - 2 * m。

那么,有了前缀函数之后,我们不仅能够知道模式串在主串中匹配了多少次,还能知道每次匹配的起始位置。

(这里其实就已经是 kmp 算法了。因为 kmp 算法的本质,就是利用模式串的前缀函数的信息,实现快速匹配。只不过大家经常遇到的 next 数组的形式(过程更复杂),是把上述过程拆成两部分而已。)

1.5 kmp 算法模板题

题目链接:【模板】KMP

【题目解析】

本题直接就是 KMP 算法的模板题,直接利用前缀函数解决字符串匹配问题就可以了。这样的话顺便还可以求出模式串所有前缀的 border。

注意 N 要定义成 2e6 + 10,我们要将两个字符串拼接。2 * len

【代码实现】

#include <iostream>

using namespace std;

const int N = 2e6 + 10;

string s, t;
int n, m;
int pi[N];

int main()
{
	cin >> s >> t;
	
	n = s.size(), m = t.size();
	
	s = ' ' + t + '#' + s;
	for(int i = 2; i <= n + m + 1; i++)
	{
		int j = pi[i - 1];
		while(j && s[i] != s[j + 1]) j = pi[j];
		if(s[i] == s[j + 1]) j++;
		pi[i] = j;
		
		if(pi[i] == m) cout << i - 2 * m << endl;
		
	}
	
	for(int i = 1; i <= m; i++) cout << pi[i] << " ";
	
	return 0;
}

至于上面为什么 i <= n + m + 1,cout << i - 2 * m,可以在草稿纸上画图找找规律~~

1.6 next 数组版本(选学)

大多数教材中的 next 数组版本,其实是把【用前缀函数解决字符串匹配】的过程拆成了两部分:

  1. 先预处理模式串 t 的前缀函数 - next 数组;
  2. 在暴力枚举的过程中,用生成的 next 数组,加速匹配。

next 数组本质上就是前缀函数~~(找某一个前缀的最长 border)

只预处理出模式串的前缀函数就可以了~~

只需知道模式串的前缀函数就足矣完成匹配了。

加速匹配的原理与求解前缀函数的原理是一样的,本质都是枚举 border 的 border。

接下来,用一个实际的例子模拟一下匹配的过程,大家会发现,next 数组版本和求前缀函数的本质是一样的。

【代码实现】

string s, t;
int n, m;
int ne[N];

void kmp()
{
	n = s.size(); m = t.size();
	s = ' ' + s; t = ' ' + t;
	// 预处理模式串的 next 数组
	for(int i = 2, j = 0; i <= m; i++)
	{
		while(j && t[i] != t[j + 1]) j = ne[j];
		if(t[i] == t[j + 1]) j++;
		ne[i] = j;
	} 
	
	// 利用 next 数组匹配
	for(int i = 1, j = 0; i <= m; i++)
	{
		while(j && s[i] != t[j + 1]) j = ne[j];
		if(s[i] == t[j + 1]) j++;
		if(j == m)
		{
			cout << i - m + 1 << endl;
		}
	} 
}

2:周期和循环节

2.1 相关概念

【周期】

  • 对于字符串 s 和正整数 p,如果有 s[i] = s[i + p],对于 1 <= i <= |s| - p 成立,则称 p 为字符串  s 的一个周期。

比如:字符串 “abbabba” 的周期有 3,6,7。

【循环节】

  • 如果字符串 s 的周期 p 满足 p | |s|,则称 p 为 s 的一个循环节。

比如:字符串 “abbabba” 的循环节为 7;字符串 “abbabbabb” 的循环节为 3,9。

特殊地,p = |s| 既是 s 的周期,也是 s 的循环节。

注意:这里不要死记概念,理解了即可。

【性质】(很重要)

设字符串 s 的长度为 n,那么有:

  • 字符串 s 有长度为 p 的 border 等价于 n - p 是字符串 s 的周期。

证明的话比较简单,画个图就可以了:

因此,字符串的周期性等价于 border 的性质,求周期就是在求字符串的 border。

字符串最小的周期 = n - 字符串最长的 border。

如何求所有的周期?

前缀函数:小用途 -> 从大到小枚举出所有的 border 就可以从小到大算出所有的周期~~

2.3 相关算法题

【题目一:Power Strings】

题目链接 Power Strings

这道题在求的是循环节,只需要求出最长的 border 就可以了。

至于求出的是周期还是循环节,再加一层判断就 OK

具体请看代码:

#include <iostream>

using namespace std;

const int N = 1e6 + 10;

string s;
int n;
int pi[N];

int kmp()
{
    n = s.size();
    s = ' ' + s;

    for(int i = 2; i <= n; i++)
    {
        int j = pi[i - 1];
        while(j && s[i] != s[j + 1]) j = pi[j];
        if(s[i] == s[j + 1]) j++;
        pi[i] = j;
    }

    if(n % (n - pi[n])) return 1;
    else return n / (n - pi[n]);
}

int main()
{
    while(cin >> s)
    {
        if(s == ".") break;

        cout << kmp() << endl;
    }

    return 0;
}

【题目二:Radio Transmission】

题目链接:Radio Transmission

这道题求的就是最小的周期,直接求最长的 border 就可以了。n - pi[n]。

#include <iostream>

using namespace std;

const int N = 1e6 + 10;

string s;
int n, pi[N];

int main()
{
    cin >> n >> s;
    s = ' ' + s;

    for(int i = 2; i <= n; i++)
    {
        int j = pi[i - 1];
        while(j && s[i] != s[j + 1]) j = pi[j];
        if(s[i] == s[j + 1]) j++;
        pi[i] = j;
    }

    cout << n - pi[n] << endl;

    return 0;
}

好了,这一次的分享就到达这里了,谢谢大家的观看~~ 


网站公告

今日签到

点亮在社区的每一天
去签到