Skip to content

Commit e1100e2

Browse files
add additional test for custom flags being respected
1 parent 62a8c21 commit e1100e2

File tree

2 files changed

+37
-9
lines changed

2 files changed

+37
-9
lines changed

pandas/core/arrays/string_arrow.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,15 @@ def astype(self, dtype, copy: bool = True):
348348
_str_len = ArrowStringArrayMixin._str_len
349349
_str_slice = ArrowStringArrayMixin._str_slice
350350

351+
@staticmethod
352+
def _is_re_pattern_with_flags(pat: str | re.Pattern) -> bool:
353+
# check if `pat` is a compiled regex pattern with flags that are not
354+
# supported by pyarrow
355+
return (
356+
isinstance(pat, re.Pattern)
357+
and (pat.flags & ~(re.IGNORECASE | re.UNICODE)) != 0
358+
)
359+
351360
@staticmethod
352361
def _preprocess_re_pattern(pat: re.Pattern, case: bool):
353362
flags = pat.flags
@@ -369,12 +378,11 @@ def _str_contains(
369378
na=lib.no_default,
370379
regex: bool = True,
371380
):
372-
if flags:
381+
if flags or self._is_re_pattern_with_flags(pat):
373382
return super()._str_contains(pat, case, flags, na, regex)
374383
if isinstance(pat, re.Pattern):
384+
# TODO flags passed separately by user are ignored
375385
pat, case, flags = self._preprocess_re_pattern(pat, case)
376-
if flags:
377-
return super()._str_contains(pat, case, flags, na, regex)
378386

379387
return ArrowStringArrayMixin._str_contains(self, pat, case, flags, na, regex)
380388

@@ -385,12 +393,10 @@ def _str_match(
385393
flags: int = 0,
386394
na: Scalar | lib.NoDefault = lib.no_default,
387395
):
388-
if flags:
396+
if flags or self._is_re_pattern_with_flags(pat):
389397
return super()._str_match(pat, case, flags, na)
390398
if isinstance(pat, re.Pattern):
391399
pat, case, flags = self._preprocess_re_pattern(pat, case)
392-
if flags:
393-
return super()._str_match(pat, case, flags, na)
394400

395401
return ArrowStringArrayMixin._str_match(self, pat, case, flags, na)
396402

@@ -401,12 +407,10 @@ def _str_fullmatch(
401407
flags: int = 0,
402408
na: Scalar | lib.NoDefault = lib.no_default,
403409
):
404-
if flags:
410+
if flags or self._is_re_pattern_with_flags(pat):
405411
return super()._str_fullmatch(pat, case, flags, na)
406412
if isinstance(pat, re.Pattern):
407413
pat, case, flags = self._preprocess_re_pattern(pat, case)
408-
if flags:
409-
return super()._str_fullmatch(pat, case, flags, na)
410414

411415
return ArrowStringArrayMixin._str_fullmatch(self, pat, case, flags, na)
412416

pandas/tests/strings/test_find_replace.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,30 @@ def test_contains_compiled_regex(any_string_dtype):
317317
ser.str.contains(pat, flags=re.IGNORECASE)
318318

319319

def test_contains_compiled_regex_flags(any_string_dtype):
    # ensure flags other than IGNORECASE on a compiled pattern are respected
    expected_dtype = (
        np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
    )

    ser = Series(["foobar", "foo\nbar", "Baz"], dtype=any_string_dtype)

    # no flags: "^" anchors only at the start of the whole string
    pat = re.compile("^ba")
    result = ser.str.contains(pat)
    expected = Series([False, False, False], dtype=expected_dtype)
    tm.assert_series_equal(result, expected)

    # MULTILINE: "^" also matches right after the newline in "foo\nbar"
    pat = re.compile("^ba", flags=re.MULTILINE)
    result = ser.str.contains(pat)
    expected = Series([False, True, False], dtype=expected_dtype)
    tm.assert_series_equal(result, expected)

    # MULTILINE combined with IGNORECASE: "Baz" now matches as well
    pat = re.compile("^ba", flags=re.MULTILINE | re.IGNORECASE)
    result = ser.str.contains(pat)
    expected = Series([False, True, True], dtype=expected_dtype)
    tm.assert_series_equal(result, expected)
342+
343+
320344
# --------------------------------------------------------------------------------------
321345
# str.startswith
322346
# --------------------------------------------------------------------------------------

0 commit comments

Comments
 (0)