Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ repos:
rev: v0.12.7
hooks:
- id: ruff
args: [--exit-non-zero-on-fix]
args: [--exit-non-zero-on-fix, --show-fixes]
exclude: ^pandas/tests/frame/test_query_eval.py
- id: ruff
# TODO: remove autofix only rules when they are checked by ruff
Expand Down
2 changes: 2 additions & 0 deletions doc/source/whatsnew/v3.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,14 @@ Other enhancements
- Added support to read and write from and to Apache Iceberg tables with the new :func:`read_iceberg` and :meth:`DataFrame.to_iceberg` functions (:issue:`61383`)
- Errors occurring during SQL I/O will now throw a generic :class:`.DatabaseError` instead of the raw Exception type from the underlying driver manager library (:issue:`60748`)
- Implemented :meth:`Series.str.isascii` and :meth:`Index.str.isascii` (:issue:`59091`)
- Improve the resulting dtypes in :meth:`DataFrame.where` and :meth:`DataFrame.mask` with :class:`ExtensionDtype` ``other`` (:issue:`62038`)
- Improved deprecation message for offset aliases (:issue:`60820`)
- Multiplying two :class:`DateOffset` objects will now raise a ``TypeError`` instead of a ``RecursionError`` (:issue:`59442`)
- Restore support for reading Stata 104-format and enable reading 103-format dta files (:issue:`58554`)
- Support passing a :class:`Iterable[Hashable]` input to :meth:`DataFrame.drop_duplicates` (:issue:`59237`)
- Support reading Stata 102-format (Stata 1) dta files (:issue:`58978`)
- Support reading Stata 110-format (Stata 7) dta files (:issue:`47176`)
-

.. ---------------------------------------------------------------------------
.. _whatsnew_300.notable_bug_fixes:
Expand Down
42 changes: 35 additions & 7 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9788,14 +9788,42 @@ def _where(
raise InvalidIndexError

if other.ndim < self.ndim:
# TODO(EA2D): avoid object-dtype cast in EA case GH#38729
other = other._values
if axis == 0:
other = np.reshape(other, (-1, 1))
elif axis == 1:
other = np.reshape(other, (1, -1))

other = np.broadcast_to(other, self.shape)
if isinstance(other, np.ndarray):
# TODO(EA2D): could also do this for NDArrayBackedEA cases?
if axis == 0:
other = np.reshape(other, (-1, 1))
elif axis == 1:
other = np.reshape(other, (1, -1))

other = np.broadcast_to(other, self.shape)
else:
# GH#38729, GH#62038 avoid lossy casting or object-casting
if axis == 0:
res_cols = [
self.iloc[:, i]._where(
cond.iloc[:, i],
other,
)
for i in range(self.shape[1])
]
elif axis == 1:
# TODO: can we use a zero-copy alternative to "repeat"?
res_cols = [
self.iloc[:, i]._where(
cond.iloc[:, i],
other[i : i + 1].repeat(len(self)),
)
for i in range(self.shape[1])
]
res = self._constructor(
dict(enumerate(res_cols))
)
res.index = self.index
res.columns = self.columns
if inplace:
return self._update_inplace(res)
return res.__finalize__(self)

# slice me out of the other
else:
Expand Down
24 changes: 16 additions & 8 deletions pandas/tests/frame/indexing/test_where.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,22 +698,30 @@ def test_where_categorical_filtering(self):
tm.assert_equal(result, expected)

def test_where_ea_other(self):
# GH#38729/GH#38742
# GH#38729/GH#38742, GH#62038
df = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
arr = pd.array([7, pd.NA, 9])
ser = Series(arr)
mask = np.ones(df.shape, dtype=bool)
mask[1, :] = False

# TODO: ideally we would get Int64 instead of object
result = df.where(mask, ser, axis=0)
expected = DataFrame({"A": [1, np.nan, 3], "B": [4, np.nan, 6]})
tm.assert_frame_equal(result, expected)
result1 = df.where(mask, ser, axis=0)
expected1 = DataFrame({"A": [1, pd.NA, 3], "B": [4, pd.NA, 6]}, dtype="Int64")
tm.assert_frame_equal(result1, expected1)

ser2 = Series(arr[:2], index=["A", "B"])
expected = DataFrame({"A": [1, 7, 3], "B": [4, np.nan, 6]})
result = df.where(mask, ser2, axis=1)
tm.assert_frame_equal(result, expected)
expected2 = DataFrame({"A": [1, 7, 3], "B": [4, pd.NA, 6]})
expected2["B"] = expected2["B"].astype("Int64")
result2 = df.where(mask, ser2, axis=1)
tm.assert_frame_equal(result2, expected2)

result3 = df.copy()
result3.mask(mask, ser, axis=0, inplace=True)
tm.assert_frame_equal(result3, expected1)

result4 = df.copy()
result4.mask(mask, ser2, axis=1, inplace=True)
tm.assert_frame_equal(result4, expected2)

def test_where_interval_noop(self):
# GH#44181
Expand Down
Loading