"""
Rabin–Karp String Matching Algorithm
https://en.wikipedia.org/wiki/Rabin%E2%80%93Karp_algorithm

A polynomial rolling hash of each text window is compared with the hash of
the pattern; the characters themselves are compared only when the hashes
agree, so hash collisions can never produce a false match.
"""

from typing import Dict, Iterable, List, Tuple

# Modulus applied to every hash value.
MOD: int = 1_000_000_007
# Radix of the polynomial hash.
BASE: int = 257


def _prefix_hash(s: str, length: int) -> int:
    """Return the polynomial hash of the first `length` characters of `s`."""
    h = 0
    for ch in s[:length]:
        h = (h * BASE + ord(ch)) % MOD
    return h


def _roll(h: int, outgoing: str, incoming: str, power: int) -> int:
    """
    Slide a window hash one position to the right.

    `power` must be BASE ** (window_length - 1) % MOD, i.e. the weight of the
    leftmost (outgoing) character inside the current hash.
    """
    # Python's % always yields a non-negative result, so the subtraction
    # cannot leave the hash negative.
    h = (h - ord(outgoing) * power) % MOD
    return (h * BASE + ord(incoming)) % MOD


def rabin_karp(text: str, pattern: str) -> List[int]:
    """
    Return all starting indices where `pattern` appears in `text`.

    Overlapping occurrences are reported.  An empty pattern matches at every
    position, including the position just past the end of `text`.

    >>> rabin_karp("abracadabra", "abra")
    [0, 7]
    >>> rabin_karp("aaaaa", "aa")  # overlapping matches
    [0, 1, 2, 3]
    >>> rabin_karp("hello", "")  # empty pattern matches everywhere
    [0, 1, 2, 3, 4, 5]
    >>> rabin_karp("", "abc")
    []
    >>> rabin_karp("Lüsai", "Lü")
    [0]
    >>> rabin_karp("Lüsai", "Lue")
    []
    """
    n, m = len(text), len(pattern)
    if m == 0:
        return list(range(n + 1))
    if n < m:
        return []

    # Weight of the leftmost character of a window: BASE^(m-1) % MOD
    power = pow(BASE, m - 1, MOD)

    pattern_hash = _prefix_hash(pattern, m)
    window_hash = _prefix_hash(text, m)

    matches: List[int] = []
    last_start = n - m
    for start in range(last_start + 1):
        # Equal hashes are only a hint; confirm with a real comparison.
        if window_hash == pattern_hash and text.startswith(pattern, start):
            matches.append(start)

        if start < last_start:
            window_hash = _roll(window_hash, text[start], text[start + m], power)

    return matches


def rabin_karp_multi(text: str, patterns: Iterable[str]) -> Dict[str, List[int]]:
    """
    Multiple-pattern Rabin–Karp.

    Patterns are grouped by length and `text` is scanned once per distinct
    pattern length.  The result maps every distinct pattern to the list of
    its starting indices in `text`; duplicate patterns in the input are
    collapsed into a single entry.

    >>> rabin_karp_multi("abracadabra", ["abra", "bra", "cad"])
    {'abra': [0, 7], 'bra': [1, 8], 'cad': [4]}
    >>> rabin_karp_multi("aaaaa", ["aa", "aaa"])
    {'aa': [0, 1, 2, 3], 'aaa': [0, 1, 2]}
    >>> rabin_karp_multi("aa", ["a", "a"])  # duplicates are not double-counted
    {'a': [0, 1]}
    >>> rabin_karp_multi("ab", ["", "abc"])
    {'': [0, 1, 2], 'abc': []}
    """
    # De-duplicate while keeping first-seen order.  Without this, a pattern
    # listed k times would have each of its matches recorded k times.
    unique_patterns = list(dict.fromkeys(patterns))
    result: Dict[str, List[int]] = {p: [] for p in unique_patterns}
    n = len(text)

    # Group patterns by length so that each group shares one rolling window.
    by_length: Dict[int, List[str]] = {}
    for p in unique_patterns:
        by_length.setdefault(len(p), []).append(p)

    for length, group in by_length.items():
        if length == 0:
            # The empty pattern matches at every position, including len(text).
            for p in group:
                result[p] = list(range(n + 1))
            continue
        if n < length:
            # Patterns longer than the text cannot match.
            continue

        # Map each hash value to the patterns of this length that produce it.
        hash_to_patterns: Dict[int, List[str]] = {}
        for p in group:
            hash_to_patterns.setdefault(_prefix_hash(p, length), []).append(p)

        power = pow(BASE, length - 1, MOD)
        window_hash = _prefix_hash(text, length)
        last_start = n - length
        for start in range(last_start + 1):
            candidates = hash_to_patterns.get(window_hash)
            if candidates:
                window = text[start : start + length]
                for p in candidates:
                    # Confirm the match; distinct patterns may share a hash.
                    if window == p:
                        result[p].append(start)

            if start < last_start:
                window_hash = _roll(
                    window_hash, text[start], text[start + length], power
                )

    return result


if __name__ == "__main__":
    import doctest

    doctest.testmod()