|
| 1 | +import operator |
1 | 2 | from pathlib import Path |
2 | 3 |
|
3 | 4 | import numpy as np |
4 | 5 | import pytest |
5 | 6 |
|
| 7 | +from pandas.compat import HAS_PYARROW |
6 | 8 | from pandas.errors import Pandas4Warning |
| 9 | +import pandas.util._test_decorators as td |
7 | 10 |
|
| 11 | +import pandas as pd |
8 | 12 | from pandas import ( |
9 | 13 | NA, |
10 | 14 | ArrowDtype, |
11 | 15 | Series, |
12 | 16 | StringDtype, |
13 | 17 | ) |
14 | 18 | import pandas._testing as tm |
| 19 | +from pandas.core.construction import extract_array |
| 20 | + |
| 21 | + |
| 22 | +def string_dtype_highest_priority(dtype1, dtype2): |
| 23 | + if HAS_PYARROW: |
| 24 | + DTYPE_HIERARCHY = [ |
| 25 | + StringDtype("python", na_value=np.nan), |
| 26 | + StringDtype("pyarrow", na_value=np.nan), |
| 27 | + StringDtype("python", na_value=NA), |
| 28 | + StringDtype("pyarrow", na_value=NA), |
| 29 | + ] |
| 30 | + else: |
| 31 | + DTYPE_HIERARCHY = [ |
| 32 | + StringDtype("python", na_value=np.nan), |
| 33 | + StringDtype("python", na_value=NA), |
| 34 | + ] |
| 35 | + |
| 36 | + h1 = DTYPE_HIERARCHY.index(dtype1) |
| 37 | + h2 = DTYPE_HIERARCHY.index(dtype2) |
| 38 | + return DTYPE_HIERARCHY[max(h1, h2)] |
| 39 | + |
| 40 | + |
| 41 | +def test_eq_all_na(): |
| 42 | + pytest.importorskip("pyarrow") |
| 43 | + a = pd.array([NA, NA], dtype=StringDtype("pyarrow")) |
| 44 | + result = a == a |
| 45 | + expected = pd.array([NA, NA], dtype="boolean[pyarrow]") |
| 46 | + tm.assert_extension_array_equal(result, expected) |
15 | 47 |
|
16 | 48 |
|
17 | 49 | def test_reversed_logical_ops(any_string_dtype): |
@@ -134,3 +166,279 @@ def test_mul_bool_invalid(any_string_dtype): |
134 | 166 | ser * np.array([True, False, True], dtype=bool) |
135 | 167 | with pytest.raises(TypeError, match=msg): |
136 | 168 | np.array([True, False, True], dtype=bool) * ser |
| 169 | + |
| 170 | + |
| 171 | +def test_add(any_string_dtype, request): |
| 172 | + dtype = any_string_dtype |
| 173 | + if dtype == object: |
| 174 | + mark = pytest.mark.xfail( |
| 175 | + reason="Need to update expected for numpy object dtype" |
| 176 | + ) |
| 177 | + request.applymarker(mark) |
| 178 | + |
| 179 | + a = Series(["a", "b", "c", None, None], dtype=dtype) |
| 180 | + b = Series(["x", "y", None, "z", None], dtype=dtype) |
| 181 | + |
| 182 | + result = a + b |
| 183 | + expected = Series(["ax", "by", None, None, None], dtype=dtype) |
| 184 | + tm.assert_series_equal(result, expected) |
| 185 | + |
| 186 | + result = a.add(b) |
| 187 | + tm.assert_series_equal(result, expected) |
| 188 | + |
| 189 | + result = a.radd(b) |
| 190 | + expected = Series(["xa", "yb", None, None, None], dtype=dtype) |
| 191 | + tm.assert_series_equal(result, expected) |
| 192 | + |
| 193 | + result = a.add(b, fill_value="-") |
| 194 | + expected = Series(["ax", "by", "c-", "-z", None], dtype=dtype) |
| 195 | + tm.assert_series_equal(result, expected) |
| 196 | + |
| 197 | + |
| 198 | +def test_add_2d(any_string_dtype, request): |
| 199 | + dtype = any_string_dtype |
| 200 | + |
| 201 | + if dtype == object or dtype.storage == "pyarrow": |
| 202 | + reason = "Failed: DID NOT RAISE <class 'ValueError'>" |
| 203 | + mark = pytest.mark.xfail(raises=None, reason=reason) |
| 204 | + request.applymarker(mark) |
| 205 | + |
| 206 | + a = pd.array(["a", "b", "c"], dtype=dtype) |
| 207 | + b = np.array([["a", "b", "c"]], dtype=object) |
| 208 | + with pytest.raises(ValueError, match="3 != 1"): |
| 209 | + a + b |
| 210 | + |
| 211 | + s = Series(a) |
| 212 | + with pytest.raises(ValueError, match="3 != 1"): |
| 213 | + s + b |
| 214 | + |
| 215 | + |
| 216 | +def test_add_sequence(any_string_dtype, request): |
| 217 | + dtype = any_string_dtype |
| 218 | + if dtype == np.dtype(object): |
| 219 | + mark = pytest.mark.xfail(reason="Cannot broadcast list") |
| 220 | + request.applymarker(mark) |
| 221 | + |
| 222 | + a = pd.array(["a", "b", None, None], dtype=dtype) |
| 223 | + other = ["x", None, "y", None] |
| 224 | + |
| 225 | + result = a + other |
| 226 | + expected = pd.array(["ax", None, None, None], dtype=dtype) |
| 227 | + tm.assert_extension_array_equal(result, expected) |
| 228 | + |
| 229 | + result = other + a |
| 230 | + expected = pd.array(["xa", None, None, None], dtype=dtype) |
| 231 | + tm.assert_extension_array_equal(result, expected) |
| 232 | + |
| 233 | + |
| 234 | +def test_mul(any_string_dtype): |
| 235 | + dtype = any_string_dtype |
| 236 | + a = pd.array(["a", "b", None], dtype=dtype) |
| 237 | + result = a * 2 |
| 238 | + expected = pd.array(["aa", "bb", None], dtype=dtype) |
| 239 | + tm.assert_extension_array_equal(result, expected) |
| 240 | + |
| 241 | + result = 2 * a |
| 242 | + tm.assert_extension_array_equal(result, expected) |
| 243 | + |
| 244 | + |
| 245 | +def test_add_strings(any_string_dtype, request): |
| 246 | + dtype = any_string_dtype |
| 247 | + if dtype != np.dtype(object): |
| 248 | + mark = pytest.mark.xfail(reason="GH-28527") |
| 249 | + request.applymarker(mark) |
| 250 | + arr = pd.array(["a", "b", "c", "d"], dtype=dtype) |
| 251 | + df = pd.DataFrame([["t", "y", "v", "w"]], dtype=object) |
| 252 | + assert arr.__add__(df) is NotImplemented |
| 253 | + |
| 254 | + result = arr + df |
| 255 | + expected = pd.DataFrame([["at", "by", "cv", "dw"]]).astype(dtype) |
| 256 | + tm.assert_frame_equal(result, expected) |
| 257 | + |
| 258 | + result = df + arr |
| 259 | + expected = pd.DataFrame([["ta", "yb", "vc", "wd"]]).astype(dtype) |
| 260 | + tm.assert_frame_equal(result, expected) |
| 261 | + |
| 262 | + |
| 263 | +@pytest.mark.xfail(reason="GH-28527") |
| 264 | +def test_add_frame(dtype): |
| 265 | + arr = pd.array(["a", "b", np.nan, np.nan], dtype=dtype) |
| 266 | + df = pd.DataFrame([["x", np.nan, "y", np.nan]]) |
| 267 | + |
| 268 | + assert arr.__add__(df) is NotImplemented |
| 269 | + |
| 270 | + result = arr + df |
| 271 | + expected = pd.DataFrame([["ax", np.nan, np.nan, np.nan]]).astype(dtype) |
| 272 | + tm.assert_frame_equal(result, expected) |
| 273 | + |
| 274 | + result = df + arr |
| 275 | + expected = pd.DataFrame([["xa", np.nan, np.nan, np.nan]]).astype(dtype) |
| 276 | + tm.assert_frame_equal(result, expected) |
| 277 | + |
| 278 | + |
| 279 | +def test_comparison_methods_scalar(comparison_op, any_string_dtype): |
| 280 | + dtype = any_string_dtype |
| 281 | + op_name = f"__{comparison_op.__name__}__" |
| 282 | + a = pd.array(["a", None, "c"], dtype=dtype) |
| 283 | + other = "a" |
| 284 | + result = getattr(a, op_name)(other) |
| 285 | + if dtype == object or dtype.na_value is np.nan: |
| 286 | + expected = np.array([getattr(item, op_name)(other) for item in a]) |
| 287 | + if comparison_op == operator.ne: |
| 288 | + expected[1] = True |
| 289 | + else: |
| 290 | + expected[1] = False |
| 291 | + result = extract_array(result, extract_numpy=True) |
| 292 | + tm.assert_numpy_array_equal(result, expected.astype(np.bool_)) |
| 293 | + else: |
| 294 | + expected_dtype = "boolean[pyarrow]" if dtype.storage == "pyarrow" else "boolean" |
| 295 | + expected = np.array([getattr(item, op_name)(other) for item in a], dtype=object) |
| 296 | + expected = pd.array(expected, dtype=expected_dtype) |
| 297 | + tm.assert_extension_array_equal(result, expected) |
| 298 | + |
| 299 | + |
| 300 | +def test_comparison_methods_scalar_pd_na(comparison_op, any_string_dtype): |
| 301 | + dtype = any_string_dtype |
| 302 | + op_name = f"__{comparison_op.__name__}__" |
| 303 | + a = pd.array(["a", None, "c"], dtype=dtype) |
| 304 | + result = getattr(a, op_name)(NA) |
| 305 | + |
| 306 | + if dtype == np.dtype(object) or dtype.na_value is np.nan: |
| 307 | + if operator.ne == comparison_op: |
| 308 | + expected = np.array([True, True, True]) |
| 309 | + else: |
| 310 | + expected = np.array([False, False, False]) |
| 311 | + result = extract_array(result, extract_numpy=True) |
| 312 | + tm.assert_numpy_array_equal(result, expected) |
| 313 | + else: |
| 314 | + expected_dtype = "boolean[pyarrow]" if dtype.storage == "pyarrow" else "boolean" |
| 315 | + expected = pd.array([None, None, None], dtype=expected_dtype) |
| 316 | + tm.assert_extension_array_equal(result, expected) |
| 317 | + tm.assert_extension_array_equal(result, expected) |
| 318 | + |
| 319 | + |
| 320 | +def test_comparison_methods_scalar_not_string(comparison_op, any_string_dtype): |
| 321 | + op_name = f"__{comparison_op.__name__}__" |
| 322 | + dtype = any_string_dtype |
| 323 | + |
| 324 | + a = pd.array(["a", None, "c"], dtype=dtype) |
| 325 | + other = 42 |
| 326 | + |
| 327 | + if op_name not in ["__eq__", "__ne__"]: |
| 328 | + with pytest.raises(TypeError, match="Invalid comparison|not supported between"): |
| 329 | + getattr(a, op_name)(other) |
| 330 | + |
| 331 | + return |
| 332 | + |
| 333 | + result = getattr(a, op_name)(other) |
| 334 | + result = extract_array(result, extract_numpy=True) |
| 335 | + |
| 336 | + if dtype == np.dtype(object) or dtype.na_value is np.nan: |
| 337 | + expected_data = { |
| 338 | + "__eq__": [False, False, False], |
| 339 | + "__ne__": [True, True, True], |
| 340 | + }[op_name] |
| 341 | + expected = np.array(expected_data) |
| 342 | + tm.assert_numpy_array_equal(result, expected) |
| 343 | + else: |
| 344 | + expected_data = {"__eq__": [False, None, False], "__ne__": [True, None, True]}[ |
| 345 | + op_name |
| 346 | + ] |
| 347 | + expected_dtype = "boolean[pyarrow]" if dtype.storage == "pyarrow" else "boolean" |
| 348 | + expected = pd.array(expected_data, dtype=expected_dtype) |
| 349 | + tm.assert_extension_array_equal(result, expected) |
| 350 | + |
| 351 | + |
| 352 | +def test_comparison_methods_array(comparison_op, any_string_dtype, any_string_dtype2): |
| 353 | + op_name = f"__{comparison_op.__name__}__" |
| 354 | + dtype = any_string_dtype |
| 355 | + dtype2 = any_string_dtype2 |
| 356 | + |
| 357 | + a = pd.array(["a", None, "c"], dtype=dtype) |
| 358 | + other = pd.array([None, None, "c"], dtype=dtype2) |
| 359 | + result = comparison_op(a, other) |
| 360 | + result = extract_array(result, extract_numpy=True) |
| 361 | + |
| 362 | + # ensure operation is commutative |
| 363 | + result2 = comparison_op(other, a) |
| 364 | + result2 = extract_array(result2, extract_numpy=True) |
| 365 | + tm.assert_equal(result, result2) |
| 366 | + |
| 367 | + if (dtype == object or dtype.na_value is np.nan) and ( |
| 368 | + dtype2 == object or dtype2.na_value is np.nan |
| 369 | + ): |
| 370 | + if operator.ne == comparison_op: |
| 371 | + expected = np.array([True, True, False]) |
| 372 | + else: |
| 373 | + expected = np.array([False, False, False]) |
| 374 | + expected[-1] = getattr(other[-1], op_name)(a[-1]) |
| 375 | + result = extract_array(result, extract_numpy=True) |
| 376 | + tm.assert_numpy_array_equal(result, expected) |
| 377 | + |
| 378 | + else: |
| 379 | + if dtype == object: |
| 380 | + max_dtype = dtype2 |
| 381 | + elif dtype2 == object: |
| 382 | + max_dtype = dtype |
| 383 | + else: |
| 384 | + max_dtype = string_dtype_highest_priority(dtype, dtype2) |
| 385 | + if max_dtype.storage == "python": |
| 386 | + expected_dtype = "boolean" |
| 387 | + else: |
| 388 | + expected_dtype = "bool[pyarrow]" |
| 389 | + |
| 390 | + expected = np.full(len(a), fill_value=None, dtype="object") |
| 391 | + expected[-1] = getattr(other[-1], op_name)(a[-1]) |
| 392 | + expected = pd.array(expected, dtype=expected_dtype) |
| 393 | + tm.assert_equal(result, expected) |
| 394 | + |
| 395 | + |
| 396 | +@td.skip_if_no("pyarrow") |
| 397 | +def test_comparison_methods_array_arrow_extension(comparison_op, any_string_dtype): |
| 398 | + # Test pd.ArrowDtype(pa.string()) against other string arrays |
| 399 | + import pyarrow as pa |
| 400 | + |
| 401 | + dtype2 = any_string_dtype |
| 402 | + |
| 403 | + op_name = f"__{comparison_op.__name__}__" |
| 404 | + dtype = ArrowDtype(pa.string()) |
| 405 | + a = pd.array(["a", None, "c"], dtype=dtype) |
| 406 | + other = pd.array([None, None, "c"], dtype=dtype2) |
| 407 | + result = comparison_op(a, other) |
| 408 | + |
| 409 | + # ensure operation is commutative |
| 410 | + result2 = comparison_op(other, a) |
| 411 | + tm.assert_equal(result, result2) |
| 412 | + |
| 413 | + expected = pd.array([None, None, True], dtype="bool[pyarrow]") |
| 414 | + expected[-1] = getattr(other[-1], op_name)(a[-1]) |
| 415 | + tm.assert_extension_array_equal(result, expected) |
| 416 | + |
| 417 | + |
| 418 | +def test_comparison_methods_list(comparison_op, any_string_dtype): |
| 419 | + dtype = any_string_dtype |
| 420 | + op_name = f"__{comparison_op.__name__}__" |
| 421 | + |
| 422 | + a = pd.array(["a", None, "c"], dtype=dtype) |
| 423 | + other = [None, None, "c"] |
| 424 | + result = comparison_op(a, other) |
| 425 | + |
| 426 | + # ensure operation is commutative |
| 427 | + result2 = comparison_op(other, a) |
| 428 | + tm.assert_equal(result, result2) |
| 429 | + |
| 430 | + if dtype == object or dtype.na_value is np.nan: |
| 431 | + if operator.ne == comparison_op: |
| 432 | + expected = np.array([True, True, False]) |
| 433 | + else: |
| 434 | + expected = np.array([False, False, False]) |
| 435 | + expected[-1] = getattr(other[-1], op_name)(a[-1]) |
| 436 | + result = extract_array(result, extract_numpy=True) |
| 437 | + tm.assert_numpy_array_equal(result, expected) |
| 438 | + |
| 439 | + else: |
| 440 | + expected_dtype = "boolean[pyarrow]" if dtype.storage == "pyarrow" else "boolean" |
| 441 | + expected = np.full(len(a), fill_value=None, dtype="object") |
| 442 | + expected[-1] = getattr(other[-1], op_name)(a[-1]) |
| 443 | + expected = pd.array(expected, dtype=expected_dtype) |
| 444 | + tm.assert_extension_array_equal(result, expected) |
0 commit comments