Skip to content

Commit 38daa15

Browse files
committed
Add Rabin-Karp String Matching Algorithm (#13918)
1 parent a051ab5 commit 38daa15

File tree

1 file changed

+228
-76
lines changed

1 file changed

+228
-76
lines changed

strings/rabin_karp.py

Lines changed: 228 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,91 +1,243 @@
1-
# Numbers of alphabet which we call base
2-
alphabet_size = 256
3-
# Modulus to hash a string
4-
modulus = 1000003
1+
"""
2+
Rabin-Karp String Matching Algorithm
53
4+
The Rabin-Karp algorithm uses hashing to find patterns in text.
5+
It employs a rolling hash technique for efficient pattern searching.
66
7-
def rabin_karp(pattern: str, text: str) -> bool:
8-
"""
9-
The Rabin-Karp Algorithm for finding a pattern within a piece of text
10-
with complexity O(nm), most efficient when it is used with multiple patterns
11-
as it is able to check if any of a set of patterns match a section of text in o(1)
12-
given the precomputed hashes.
7+
Time Complexity:
8+
- Average case: O(n + m) where n is text length, m is pattern length
9+
- Worst case: O(nm) when many spurious hits occur
1310
14-
This will be the simple version which only assumes one pattern is being searched
15-
for but it's not hard to modify
11+
Space Complexity: O(1) for single pattern, O(k) for k patterns
1612
17-
1) Calculate pattern hash
13+
Applications:
14+
- Plagiarism detection
15+
- DNA sequence matching
16+
- Multiple pattern searching
17+
- Finding duplicate content
18+
"""
1819

19-
2) Step through the text one character at a time passing a window with the same
20-
length as the pattern
21-
calculating the hash of the text within the window compare it with the hash
22-
of the pattern. Only testing equality if the hashes match
23-
"""
24-
p_len = len(pattern)
25-
t_len = len(text)
26-
if p_len > t_len:
27-
return False
2820

29-
p_hash = 0
30-
text_hash = 0
31-
modulus_power = 1
21+
def rabin_karp_search(
    text: str, pattern: str, base: int = 256, modulus: int = 101
) -> list[int]:
    """
    Search for a pattern in text using Rabin-Karp algorithm.

    A polynomial hash of the pattern is compared against a rolling hash of
    every window of ``text`` with the same length.  Equal hashes are always
    confirmed with a direct string comparison, so hash collisions can only
    cost time, never produce a wrong result.

    Args:
        text: The text to search in
        pattern: The pattern to search for
        base: The base for hash calculation (default: 256 for ASCII)
        modulus: The modulus for hash calculation (ideally a prime number;
            any non-zero integer yields correct results)

    Returns:
        List of starting indices where pattern is found (overlapping
        occurrences included).  Empty if either string is empty or the
        pattern is longer than the text.

    Raises:
        ZeroDivisionError: If ``modulus`` is 0 and a search is performed.

    Examples:
        >>> rabin_karp_search("hello world hello", "hello")
        [0, 12]
        >>> rabin_karp_search("aaaa", "aa")
        [0, 1, 2]
        >>> rabin_karp_search("abc", "xyz")
        []
        >>> rabin_karp_search("", "a")
        []
        >>> rabin_karp_search("a", "")
        []
        >>> rabin_karp_search("abcdefg", "cde")
        [2]
        >>> rabin_karp_search("ABABDABACDABABCABAB", "ABABCABAB")
        [10]
        >>> rabin_karp_search("test test test", "test")
        [0, 5, 10]
        >>> rabin_karp_search("abracadabra", "abra", modulus=1)
        [0, 7]
    """
    if not pattern or not text or len(pattern) > len(text):
        return []

    n = len(text)
    m = len(pattern)
    matches: list[int] = []

    # Hash the pattern and the first window of text together (Horner's rule).
    # zip() stops after m characters because the pattern is the shorter one.
    pattern_hash = 0
    text_hash = 0
    for pattern_char, text_char in zip(pattern, text):
        pattern_hash = (base * pattern_hash + ord(pattern_char)) % modulus
        text_hash = (base * text_hash + ord(text_char)) % modulus

    # Weight of the window's leading character: base**(m - 1) mod modulus.
    h = pow(base, m - 1, modulus)

    # Slide the window over the text one position at a time
    for i in range(n - m + 1):
        # Only compare the actual characters when the hashes agree; this
        # filters out spurious hits caused by hash collisions.
        if pattern_hash == text_hash and text[i : i + m] == pattern:
            matches.append(i)

        # Rolling hash: drop text[i], shift, and append text[i + m].
        # Python's % already returns a residue with the sign of the modulus,
        # so no extra normalisation step is needed (adding one would make
        # window hashes incomparable with pattern_hash).
        if i < n - m:
            text_hash = (
                base * (text_hash - ord(text[i]) * h) + ord(text[i + m])
            ) % modulus

    return matches
95+
96+
97+
def rabin_karp_multiple(
    text: str, patterns: list[str], base: int = 256, modulus: int = 101
) -> dict[str, list[int]]:
    """
    Search for multiple patterns in text using Rabin-Karp algorithm.

    Patterns are grouped by length; for each distinct length the text is
    scanned once with a rolling hash, and each window's hash is looked up in
    a table of pattern hashes, so the per-window cost does not grow with the
    number of patterns unless their hashes collide.  Every hash hit is
    confirmed by comparing the actual characters.

    Args:
        text: The text to search in
        patterns: List of patterns to search for
        base: The base for hash calculation
        modulus: The modulus for hash calculation (ideally a prime number;
            any non-zero integer yields correct results)

    Returns:
        Dictionary mapping each pattern to list of indices where found.
        Empty patterns and patterns longer than the text map to ``[]``.

    Raises:
        ZeroDivisionError: If ``modulus`` is 0 and at least one non-empty
            pattern fits inside the text.

    Examples:
        >>> result = rabin_karp_multiple("hello world hello", ["hello", "world"])
        >>> result == {"hello": [0, 12], "world": [6]}
        True
        >>> result = rabin_karp_multiple("aaaa", ["aa", "aaa"])
        >>> result == {"aa": [0, 1, 2], "aaa": [0, 1]}
        True
        >>> result = rabin_karp_multiple("test", ["abc", "xyz"])
        >>> result == {"abc": [], "xyz": []}
        True
        >>> result = rabin_karp_multiple("", ["a", "b"])
        >>> result == {"a": [], "b": []}
        True
        >>> result = rabin_karp_multiple("abcdef", ["ab", "cd", "ef"])
        >>> result == {"ab": [0], "cd": [2], "ef": [4]}
        True
    """
    results: dict[str, list[int]] = {pattern: [] for pattern in patterns}
    if not text or not patterns:
        return results

    # Group the distinct, non-empty patterns by length so that one rolling
    # hash pass can serve every pattern of that length.
    patterns_by_length: dict[int, list[str]] = {}
    for pattern in results:
        if pattern:  # Skip empty patterns
            patterns_by_length.setdefault(len(pattern), []).append(pattern)

    text_length = len(text)

    # Process each group of patterns with same length
    for pattern_length, pattern_group in patterns_by_length.items():
        if pattern_length > text_length:
            continue

        # Bucket the patterns by hash value: a window then needs a single
        # dict lookup instead of a comparison against every pattern.
        patterns_by_hash: dict[int, list[str]] = {}
        for pattern in pattern_group:
            pattern_hash = 0
            for char in pattern:
                pattern_hash = (base * pattern_hash + ord(char)) % modulus
            patterns_by_hash.setdefault(pattern_hash, []).append(pattern)

        # Hash of the first window
        text_hash = 0
        for char in text[:pattern_length]:
            text_hash = (base * text_hash + ord(char)) % modulus

        # Weight of the window's leading character
        h = pow(base, pattern_length - 1, modulus)

        # Slide the window over text
        last_start = text_length - pattern_length
        for i in range(last_start + 1):
            candidates = patterns_by_hash.get(text_hash)
            if candidates:
                # Verify to avoid spurious hits
                window = text[i : i + pattern_length]
                for pattern in candidates:
                    if pattern == window:
                        results[pattern].append(i)

            # Rolling hash for the next window.  Python's % already yields a
            # residue with the sign of the modulus, so no further
            # normalisation is required.
            if i < last_start:
                text_hash = (
                    base * (text_hash - ord(text[i]) * h)
                    + ord(text[i + pattern_length])
                ) % modulus

    return results
188+
189+
190+
def rabin_karp_search_optimized(
    text: str, pattern: str, base: int = 256, modulus: int = 1_000_000_007
) -> list[int]:
    """
    Single-pattern Rabin-Karp search with a large default modulus.

    This is a convenience wrapper: it forwards all arguments unchanged to
    ``rabin_karp_search`` and differs only in its default ``modulus``
    (10^9 + 7).  With far more possible hash values, fewer non-matching
    windows share the pattern's hash, so fewer character-by-character
    verifications are needed.

    Args:
        text: The text to search in
        pattern: The pattern to search for
        base: The base for hash calculation
        modulus: Large prime modulus (default: 10^9 + 7)

    Returns:
        List of starting indices where pattern is found

    Examples:
        >>> rabin_karp_search_optimized("hello world", "world")
        [6]
        >>> rabin_karp_search_optimized("aaabaaaa", "aaaa")
        [4]
        >>> rabin_karp_search_optimized("abc", "d")
        []
    """
    return rabin_karp_search(
        text=text, pattern=pattern, base=base, modulus=modulus
    )
88217

89218

90219
if __name__ == "__main__":
    import doctest

    # testmod() only reports failures; keep its result so the closing
    # summary reflects what actually happened instead of always claiming
    # success.
    doctest_results = doctest.testmod()

    # Performance demonstration
    text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit" * 100
    pattern = "consectetur"

    print("Rabin-Karp String Matching Algorithm Demo")
    print("=" * 50)

    # Single pattern search
    matches = rabin_karp_search(text, pattern)
    print(f"\nSearching for '{pattern}' in text ({len(text)} chars)")
    print(f"Found {len(matches)} matches at indices: {matches[:5]}...")

    # Multiple pattern search
    patterns = ["Lorem", "ipsum", "consectetur", "adipiscing"]
    results = rabin_karp_multiple(text, patterns)
    print(f"\nSearching for {len(patterns)} patterns:")
    for p, indices in results.items():
        print(f"  '{p}': {len(indices)} matches")

    # Exit with a non-zero status (message goes to stderr) when any doctest
    # failed; otherwise print the success banner.
    if doctest_results.failed:
        raise SystemExit(
            f"\n✗ {doctest_results.failed} of "
            f"{doctest_results.attempted} doctests failed"
        )
    print("\n✓ All tests passed!")

0 commit comments

Comments
 (0)