Skip to content

Commit 7b8a314

Browse files
BUG: use arrow backend for digits references in str.replace (#62872)
Co-authored-by: zishan044 <winchesterfelix007@gmail.com> (cherry picked from commit a947b55)
1 parent f4e1c73 commit 7b8a314

File tree

3 files changed: 72 additions (+72) and 8 deletions (-8)

3 files changed

+72
-8
lines changed

pandas/core/arrays/_arrow_string_mixins.py

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -174,15 +174,12 @@ def _str_replace(
174174
or callable(repl)
175175
or not case
176176
or flags
177-
or (
178-
isinstance(repl, str)
179-
and (r"\g<" in repl or re.search(r"\\\d", repl) is not None)
180-
)
177+
or (isinstance(repl, str) and r"\g<" in repl)
181178
):
182179
raise NotImplementedError(
183180
"replace is not supported with a re.Pattern, callable repl, "
184181
"case=False, flags!=0, or when the replacement string contains "
185-
"named group references (\\g<...>, \\d+)"
182+
"named group references (\\g<...>)"
186183
)
187184

188185
func = pc.replace_substring_regex if regex else pc.replace_substring

pandas/core/arrays/string_arrow.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -427,7 +427,7 @@ def _str_replace(
427427
or ( # substitution contains a named group pattern
428428
# https://docs.python.org/3/library/re.html
429429
isinstance(repl, str)
430-
and (r"\g<" in repl or re.search(r"\\\d", repl) is not None)
430+
and r"\g<" in repl
431431
)
432432
):
433433
return super()._str_replace(pat, repl, n, case, flags, regex)

pandas/tests/strings/test_find_replace.py

Lines changed: 69 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
import pandas as pd
1010
from pandas import (
1111
Series,
12+
StringDtype,
1213
_testing as tm,
1314
)
1415
from pandas.tests.strings import (
@@ -601,6 +602,10 @@ def test_replace_callable_raises(any_string_dtype, repl):
601602
r"\g<three> \g<two> \g<one>",
602603
["Three Two One", "Baz Bar Foo"],
603604
),
605+
(
606+
r"\3 \2 \1",
607+
["Three Two One", "Baz Bar Foo"],
608+
),
604609
(
605610
r"\g<3> \g<2> \g<1>",
606611
["Three Two One", "Baz Bar Foo"],
@@ -616,6 +621,7 @@ def test_replace_callable_raises(any_string_dtype, repl):
616621
],
617622
ids=[
618623
"named_groups_full_swap",
624+
"numbered_groups_no_g_full_swap",
619625
"numbered_groups_full_swap",
620626
"single_group_with_literal",
621627
"mixed_group_reference_with_literal",
@@ -640,22 +646,83 @@ def test_replace_named_groups_regex_swap(
640646
[
641647
r"\g<20>",
642648
r"\20",
649+
r"\40",
650+
r"\4",
643651
],
644652
)
645653
@pytest.mark.parametrize("use_compile", [True, False])
646654
def test_replace_named_groups_regex_swap_expected_fail(
647-
any_string_dtype, repl, use_compile
655+
any_string_dtype, repl, use_compile, request
648656
):
649657
# GH#57636
658+
if (
659+
not use_compile
660+
and r"\g" not in repl
661+
and isinstance(any_string_dtype, StringDtype)
662+
and any_string_dtype.storage == "pyarrow"
663+
):
664+
# calls pyarrow method directly
665+
if repl == r"\20":
666+
mark = pytest.mark.xfail(reason="PyArrow interprets as group + literal")
667+
request.applymarker(mark)
668+
669+
pa = pytest.importorskip("pyarrow")
670+
error_type = pa.ArrowInvalid
671+
error_msg = r"only has \d parenthesized subexpressions"
672+
else:
673+
error_type = re.error
674+
error_msg = "invalid group reference"
675+
650676
pattern = r"(?P<one>\w+) (?P<two>\w+) (?P<three>\w+)"
651677
if use_compile:
652678
pattern = re.compile(pattern)
653679
ser = Series(["One Two Three", "Foo Bar Baz"], dtype=any_string_dtype)
654680

655-
with pytest.raises(re.error, match="invalid group reference"):
681+
with pytest.raises(error_type, match=error_msg):
656682
ser.str.replace(pattern, repl, regex=True)
657683

658684

685+
@pytest.mark.parametrize(
686+
"pattern, repl",
687+
[
688+
(r"(\w+) (\w+) (\w+)", r"\20"),
689+
(r"(?P<one>\w+) (?P<two>\w+) (?P<three>\w+)", r"\20"),
690+
],
691+
)
692+
def test_pyarrow_ambiguous_group_references(pyarrow_string_dtype, pattern, repl):
693+
# GH#62653
694+
ser = Series(["One Two Three", "Foo Bar Baz"], dtype=pyarrow_string_dtype)
695+
696+
result = ser.str.replace(pattern, repl, regex=True)
697+
expected = Series(["Two0", "Bar0"], dtype=pyarrow_string_dtype)
698+
tm.assert_series_equal(result, expected)
699+
700+
701+
@pytest.mark.parametrize(
702+
"pattern, repl, expected_list",
703+
[
704+
(
705+
r"\[(?P<one>\d+)\]",
706+
r"(\1)",
707+
["var.one(0)", "var.two(1)", "var.three(2)"],
708+
),
709+
(
710+
r"\[(\d+)\]",
711+
r"(\1)",
712+
["var.one(0)", "var.two(1)", "var.three(2)"],
713+
),
714+
],
715+
)
716+
@td.skip_if_no("pyarrow")
717+
def test_pyarrow_backend_group_replacement(pattern, repl, expected_list):
718+
ser = Series(["var.one[0]", "var.two[1]", "var.three[2]"]).convert_dtypes(
719+
dtype_backend="pyarrow"
720+
)
721+
result = ser.str.replace(pattern, repl, regex=True)
722+
expected = Series(expected_list).convert_dtypes(dtype_backend="pyarrow")
723+
tm.assert_series_equal(result, expected)
724+
725+
659726
def test_replace_callable_named_groups(any_string_dtype):
660727
# test regex named groups
661728
ser = Series(["Foo Bar Baz", np.nan], dtype=any_string_dtype)

0 commit comments

Comments
 (0)