Skip to content

Commit 125a9b7

Browse files
authored
REF: move string arithmetic tests to tests.arithmetic.test_string (#62869)
1 parent a947b55 commit 125a9b7

File tree

5 files changed

+312
-268
lines changed

5 files changed

+312
-268
lines changed

pandas/conftest.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1447,6 +1447,9 @@ def any_string_dtype(request):
14471447
return pd.StringDtype(storage, na_value)
14481448

14491449

1450+
# Second name for the ``any_string_dtype`` fixture so that a single test can
# request two string dtypes at once (``any_string_dtype`` and
# ``any_string_dtype2``) -- presumably parametrized independently; confirm.
any_string_dtype2 = any_string_dtype
1451+
1452+
14501453
@pytest.fixture(params=tm.DATETIME64_DTYPES)
14511454
def datetime64_dtype(request):
14521455
"""

pandas/tests/arithmetic/test_string.py

Lines changed: 308 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,49 @@
1+
import operator
12
from pathlib import Path
23

34
import numpy as np
45
import pytest
56

7+
from pandas.compat import HAS_PYARROW
68
from pandas.errors import Pandas4Warning
9+
import pandas.util._test_decorators as td
710

11+
import pandas as pd
812
from pandas import (
913
NA,
1014
ArrowDtype,
1115
Series,
1216
StringDtype,
1317
)
1418
import pandas._testing as tm
19+
from pandas.core.construction import extract_array
20+
21+
22+
def string_dtype_highest_priority(dtype1, dtype2):
23+
if HAS_PYARROW:
24+
DTYPE_HIERARCHY = [
25+
StringDtype("python", na_value=np.nan),
26+
StringDtype("pyarrow", na_value=np.nan),
27+
StringDtype("python", na_value=NA),
28+
StringDtype("pyarrow", na_value=NA),
29+
]
30+
else:
31+
DTYPE_HIERARCHY = [
32+
StringDtype("python", na_value=np.nan),
33+
StringDtype("python", na_value=NA),
34+
]
35+
36+
h1 = DTYPE_HIERARCHY.index(dtype1)
37+
h2 = DTYPE_HIERARCHY.index(dtype2)
38+
return DTYPE_HIERARCHY[max(h1, h2)]
39+
40+
41+
def test_eq_all_na():
    """
    Compare an all-NA pyarrow-backed string array with itself.

    The expected result is an all-NA ``boolean[pyarrow]`` array. The test is
    skipped when pyarrow cannot be imported.
    """
    pytest.importorskip("pyarrow")
    values = pd.array([NA, NA], dtype=StringDtype("pyarrow"))
    compared = values == values
    all_missing = pd.array([NA, NA], dtype="boolean[pyarrow]")
    tm.assert_extension_array_equal(compared, all_missing)
1547

1648

1749
def test_reversed_logical_ops(any_string_dtype):
@@ -134,3 +166,279 @@ def test_mul_bool_invalid(any_string_dtype):
134166
ser * np.array([True, False, True], dtype=bool)
135167
with pytest.raises(TypeError, match=msg):
136168
np.array([True, False, True], dtype=bool) * ser
169+
170+
171+
def test_add(any_string_dtype, request):
    """
    Concatenate two string Series via ``+``, ``Series.add`` and ``Series.radd``.

    A missing value on either side gives a missing result. With
    ``fill_value="-"`` a one-sided missing value is replaced by ``"-"``, while a
    position missing on both sides stays missing.
    """
    dtype = any_string_dtype
    if dtype == object:
        # object dtype is marked as an expected failure, see the reason below
        request.applymarker(
            pytest.mark.xfail(reason="Need to update expected for numpy object dtype")
        )

    left = Series(["a", "b", "c", None, None], dtype=dtype)
    right = Series(["x", "y", None, "z", None], dtype=dtype)

    # the operator and the method spelling must agree
    via_operator = left + right
    concatenated = Series(["ax", "by", None, None, None], dtype=dtype)
    tm.assert_series_equal(via_operator, concatenated)
    tm.assert_series_equal(left.add(right), concatenated)

    # reflected add puts the other operand first
    tm.assert_series_equal(
        left.radd(right), Series(["xa", "yb", None, None, None], dtype=dtype)
    )

    tm.assert_series_equal(
        left.add(right, fill_value="-"),
        Series(["ax", "by", "c-", "-z", None], dtype=dtype),
    )
196+
197+
198+
def test_add_2d(any_string_dtype, request):
    """
    Add a 2D object ndarray of shape (1, 3) to a length-3 string array.

    Both the array and a Series wrapping it are expected to raise a
    ``ValueError`` matching ``"3 != 1"``. Object dtype and pyarrow storage are
    marked as expected failures.
    """
    dtype = any_string_dtype

    if dtype == object or dtype.storage == "pyarrow":
        # per the xfail reason, no ValueError is raised for these dtypes
        request.applymarker(
            pytest.mark.xfail(
                raises=None, reason="Failed: DID NOT RAISE <class 'ValueError'>"
            )
        )

    arr = pd.array(["a", "b", "c"], dtype=dtype)
    row_matrix = np.array([["a", "b", "c"]], dtype=object)
    with pytest.raises(ValueError, match="3 != 1"):
        arr + row_matrix

    ser = Series(arr)
    with pytest.raises(ValueError, match="3 != 1"):
        ser + row_matrix
214+
215+
216+
def test_add_sequence(any_string_dtype, request):
    """
    Add a plain Python list to a string array, from either side.

    Any position where the array or the list holds a missing value is missing
    in the result. Object dtype is marked as an expected failure.
    """
    dtype = any_string_dtype
    if dtype == np.dtype(object):
        request.applymarker(pytest.mark.xfail(reason="Cannot broadcast list"))

    arr = pd.array(["a", "b", None, None], dtype=dtype)
    as_list = ["x", None, "y", None]

    # array on the left
    tm.assert_extension_array_equal(
        arr + as_list, pd.array(["ax", None, None, None], dtype=dtype)
    )

    # list on the left
    tm.assert_extension_array_equal(
        as_list + arr, pd.array(["xa", None, None, None], dtype=dtype)
    )
232+
233+
234+
def test_mul(any_string_dtype):
    """
    Multiply a string array by the integer 2, in both operand orders.

    Each string is expected to be repeated twice and the missing entry to
    stay missing.
    """
    dtype = any_string_dtype
    arr = pd.array(["a", "b", None], dtype=dtype)

    array_first = arr * 2
    doubled = pd.array(["aa", "bb", None], dtype=dtype)
    tm.assert_extension_array_equal(array_first, doubled)

    # the integer on the left must give the same answer
    tm.assert_extension_array_equal(2 * arr, doubled)
243+
244+
245+
def test_add_strings(any_string_dtype, request):
    """
    Add a string array and a one-row object DataFrame, in both operand orders.

    ``arr.__add__(df)`` must return ``NotImplemented``, and the operators must
    produce the column-wise concatenation. Every dtype other than numpy object
    is marked as an expected failure (GH-28527).
    """
    dtype = any_string_dtype
    if dtype != np.dtype(object):
        request.applymarker(pytest.mark.xfail(reason="GH-28527"))

    arr = pd.array(["a", "b", "c", "d"], dtype=dtype)
    frame = pd.DataFrame([["t", "y", "v", "w"]], dtype=object)
    assert arr.__add__(frame) is NotImplemented

    # array on the left: array values come first in each cell
    array_first = arr + frame
    tm.assert_frame_equal(
        array_first, pd.DataFrame([["at", "by", "cv", "dw"]]).astype(dtype)
    )

    # frame on the left: frame values come first in each cell
    frame_first = frame + arr
    tm.assert_frame_equal(
        frame_first, pd.DataFrame([["ta", "yb", "vc", "wd"]]).astype(dtype)
    )
261+
262+
263+
@pytest.mark.xfail(reason="GH-28527")
def test_add_frame(dtype):
    """
    Add a string array with missing values and a one-row DataFrame.

    ``arr.__add__(df)`` must return ``NotImplemented``, and both operand orders
    must give a frame that is missing wherever either input is missing. The
    whole test is marked as an expected failure (GH-28527).
    """
    # NOTE(review): this requests a ``dtype`` fixture while the neighbouring tests
    # use ``any_string_dtype`` -- confirm that a ``dtype`` fixture is in scope here.
    arr = pd.array(["a", "b", np.nan, np.nan], dtype=dtype)
    frame = pd.DataFrame([["x", np.nan, "y", np.nan]])

    assert arr.__add__(frame) is NotImplemented

    array_first = arr + frame
    tm.assert_frame_equal(
        array_first, pd.DataFrame([["ax", np.nan, np.nan, np.nan]]).astype(dtype)
    )

    frame_first = frame + arr
    tm.assert_frame_equal(
        frame_first, pd.DataFrame([["xa", np.nan, np.nan, np.nan]]).astype(dtype)
    )
277+
278+
279+
def test_comparison_methods_scalar(comparison_op, any_string_dtype):
    """
    Compare a string array holding a missing value against the scalar ``"a"``.

    For object dtype and NaN-backed string dtypes the result is checked as a
    NumPy boolean array whose missing position is ``True`` for ``!=`` and
    ``False`` for every other operator. For NA-backed string dtypes the result
    is checked as a nullable boolean array (``boolean[pyarrow]`` for pyarrow
    storage) built from the element-wise comparisons.
    """
    dtype = any_string_dtype
    dunder = f"__{comparison_op.__name__}__"
    arr = pd.array(["a", None, "c"], dtype=dtype)
    scalar = "a"
    result = getattr(arr, dunder)(scalar)

    # reference values obtained by comparing each element on its own
    elementwise = [getattr(item, dunder)(scalar) for item in arr]

    if dtype == object or dtype.na_value is np.nan:
        expected = np.array(elementwise)
        # the missing entry only counts as "not equal"
        expected[1] = comparison_op == operator.ne
        unwrapped = extract_array(result, extract_numpy=True)
        tm.assert_numpy_array_equal(unwrapped, expected.astype(np.bool_))
    else:
        expected_dtype = "boolean[pyarrow]" if dtype.storage == "pyarrow" else "boolean"
        as_objects = np.array(elementwise, dtype=object)
        tm.assert_extension_array_equal(
            result, pd.array(as_objects, dtype=expected_dtype)
        )
298+
299+
300+
def test_comparison_methods_scalar_pd_na(comparison_op, any_string_dtype):
    """
    Compare a string array against the scalar ``pd.NA``.

    For object dtype and NaN-backed string dtypes the result is checked as a
    NumPy boolean array that is all ``True`` for ``!=`` and all ``False`` for
    every other operator. For NA-backed string dtypes every position of the
    nullable boolean result (``boolean[pyarrow]`` for pyarrow storage) must be
    missing.
    """
    dtype = any_string_dtype
    op_name = f"__{comparison_op.__name__}__"
    a = pd.array(["a", None, "c"], dtype=dtype)
    result = getattr(a, op_name)(NA)

    if dtype == np.dtype(object) or dtype.na_value is np.nan:
        if operator.ne == comparison_op:
            expected = np.array([True, True, True])
        else:
            expected = np.array([False, False, False])
        result = extract_array(result, extract_numpy=True)
        tm.assert_numpy_array_equal(result, expected)
    else:
        expected_dtype = "boolean[pyarrow]" if dtype.storage == "pyarrow" else "boolean"
        expected = pd.array([None, None, None], dtype=expected_dtype)
        tm.assert_extension_array_equal(result, expected)
318+
319+
320+
def test_comparison_methods_scalar_not_string(comparison_op, any_string_dtype):
    """
    Compare a string array against the non-string scalar ``42``.

    Operators other than ``==`` and ``!=`` must raise ``TypeError``. For ``==``
    and ``!=`` no element equals 42: object dtype and NaN-backed string dtypes
    give a plain NumPy boolean array, while NA-backed string dtypes give a
    nullable boolean array whose missing position stays missing.
    """
    dtype = any_string_dtype
    method_name = f"__{comparison_op.__name__}__"
    arr = pd.array(["a", None, "c"], dtype=dtype)
    number = 42

    if method_name not in ("__eq__", "__ne__"):
        with pytest.raises(TypeError, match="Invalid comparison|not supported between"):
            getattr(arr, method_name)(number)

        return

    result = extract_array(getattr(arr, method_name)(number), extract_numpy=True)
    # True for "__ne__", False for "__eq__"
    unequal = method_name == "__ne__"

    if dtype == np.dtype(object) or dtype.na_value is np.nan:
        tm.assert_numpy_array_equal(result, np.array([unequal, unequal, unequal]))
    else:
        expected_dtype = "boolean[pyarrow]" if dtype.storage == "pyarrow" else "boolean"
        tm.assert_extension_array_equal(
            result, pd.array([unequal, None, unequal], dtype=expected_dtype)
        )
350+
351+
352+
def test_comparison_methods_array(comparison_op, any_string_dtype, any_string_dtype2):
    """
    Compare two arrays with missing values for every pairing of string dtypes.

    Swapping the operands must give the same result. When both dtypes are
    object or NaN-backed, the result is checked as a NumPy boolean array whose
    missing positions are ``True`` for ``!=`` and ``False`` otherwise. In every
    other pairing the result is checked as a nullable boolean array that is
    missing wherever either input is missing; its dtype follows the storage of
    the higher-priority string dtype of the pair.
    """
    op_name = f"__{comparison_op.__name__}__"
    dtype = any_string_dtype
    dtype2 = any_string_dtype2

    a = pd.array(["a", None, "c"], dtype=dtype)
    other = pd.array([None, None, "c"], dtype=dtype2)
    # unwrap once here; both branches below assert against this unwrapped value
    result = extract_array(comparison_op(a, other), extract_numpy=True)

    # swapping the operands must give the same result for this data
    result2 = extract_array(comparison_op(other, a), extract_numpy=True)
    tm.assert_equal(result, result2)

    if (dtype == object or dtype.na_value is np.nan) and (
        dtype2 == object or dtype2.na_value is np.nan
    ):
        if operator.ne == comparison_op:
            expected = np.array([True, True, False])
        else:
            expected = np.array([False, False, False])
            # only the last position holds two real strings
            expected[-1] = getattr(other[-1], op_name)(a[-1])
        tm.assert_numpy_array_equal(result, expected)

    else:
        # object dtype has no priority of its own, so the string dtype decides
        if dtype == object:
            max_dtype = dtype2
        elif dtype2 == object:
            max_dtype = dtype
        else:
            max_dtype = string_dtype_highest_priority(dtype, dtype2)
        if max_dtype.storage == "python":
            expected_dtype = "boolean"
        else:
            expected_dtype = "bool[pyarrow]"

        expected = np.full(len(a), fill_value=None, dtype="object")
        expected[-1] = getattr(other[-1], op_name)(a[-1])
        expected = pd.array(expected, dtype=expected_dtype)
        tm.assert_equal(result, expected)
394+
395+
396+
@td.skip_if_no("pyarrow")
def test_comparison_methods_array_arrow_extension(comparison_op, any_string_dtype):
    """
    Compare an ``ArrowDtype(pa.string())`` array against other string arrays.

    Swapping the operands must give the same result, and the result must be a
    ``bool[pyarrow]`` array that is missing in the first two positions and
    holds the comparison of the two ``"c"`` entries in the last one.
    """
    # Test pd.ArrowDtype(pa.string()) against other string arrays
    import pyarrow as pa

    method_name = f"__{comparison_op.__name__}__"
    arrow_arr = pd.array(["a", None, "c"], dtype=ArrowDtype(pa.string()))
    string_arr = pd.array([None, None, "c"], dtype=any_string_dtype)
    result = comparison_op(arrow_arr, string_arr)

    # swapping the operands must give the same result for this data
    tm.assert_equal(result, comparison_op(string_arr, arrow_arr))

    expected = pd.array([None, None, True], dtype="bool[pyarrow]")
    # overwrite the placeholder with the real outcome of the last comparison
    expected[-1] = getattr(string_arr[-1], method_name)(arrow_arr[-1])
    tm.assert_extension_array_equal(result, expected)
416+
417+
418+
def test_comparison_methods_list(comparison_op, any_string_dtype):
    """
    Compare a string array against a plain Python list holding missing values.

    Swapping the operands must give the same result. For NA-backed string
    dtypes the result is a nullable boolean array (``boolean[pyarrow]`` for
    pyarrow storage) that is missing except in the last position. For object
    dtype and NaN-backed string dtypes it is checked as a NumPy boolean array
    whose missing positions are ``True`` for ``!=`` and ``False`` otherwise.
    """
    dtype = any_string_dtype
    method_name = f"__{comparison_op.__name__}__"

    arr = pd.array(["a", None, "c"], dtype=dtype)
    values = [None, None, "c"]
    result = comparison_op(arr, values)

    # swapping the operands must give the same result for this data
    tm.assert_equal(result, comparison_op(values, arr))

    if not (dtype == object or dtype.na_value is np.nan):
        expected_dtype = "boolean[pyarrow]" if dtype.storage == "pyarrow" else "boolean"
        raw = np.full(len(arr), fill_value=None, dtype="object")
        # only the last position holds two real strings
        raw[-1] = getattr(values[-1], method_name)(arr[-1])
        tm.assert_extension_array_equal(result, pd.array(raw, dtype=expected_dtype))
        return

    if operator.ne == comparison_op:
        expected = np.array([True, True, False])
    else:
        expected = np.array([False, False, False])
        expected[-1] = getattr(values[-1], method_name)(arr[-1])
    tm.assert_numpy_array_equal(extract_array(result, extract_numpy=True), expected)

0 commit comments

Comments
 (0)