Skip to content

Commit a2f2862

Browse files
authored
BUG: Fix convert_dtype complex (#62960)
1 parent 2d73d62 commit a2f2862

File tree

4 files changed

+26
-4
lines changed

4 files changed

+26
-4
lines changed

doc/source/whatsnew/v3.0.0.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1111,6 +1111,7 @@ Conversion
11111111
- Bug in :meth:`DataFrame.astype` not casting ``values`` for Arrow-based dictionary dtype correctly (:issue:`58479`)
11121112
- Bug in :meth:`DataFrame.update` bool dtype being converted to object (:issue:`55509`)
11131113
- Bug in :meth:`Series.astype` might modify read-only array inplace when casting to a string dtype (:issue:`57212`)
1114+
- Bug in :meth:`Series.convert_dtypes` and :meth:`DataFrame.convert_dtypes` raising ``TypeError`` when called on data with complex dtype (:issue:`60129`)
11141115
- Bug in :meth:`Series.convert_dtypes` and :meth:`DataFrame.convert_dtypes` removing timezone information for objects with :class:`ArrowDtype` (:issue:`60237`)
11151116
- Bug in :meth:`Series.reindex` not maintaining ``float32`` type when a ``reindex`` introduces a missing value (:issue:`45857`)
11161117
- Bug in :meth:`to_datetime` and :meth:`to_timedelta` with input ``None`` returning ``None`` instead of ``NaT``, inconsistent with other conversion methods (:issue:`23055`)

pandas/core/dtypes/cast.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -934,6 +934,9 @@ def convert_dtypes(
934934
if (
935935
convert_string or convert_integer or convert_boolean or convert_floating
936936
) and isinstance(input_array, np.ndarray):
937+
if input_array.dtype.kind == "c":
938+
return input_array.dtype
939+
937940
if input_array.dtype == object:
938941
inferred_dtype = lib.infer_dtype(input_array)
939942
else:
@@ -954,7 +957,7 @@ def convert_dtypes(
954957
inferred_dtype = NUMPY_INT_TO_DTYPE.get(
955958
input_array.dtype, target_int_dtype
956959
)
957-
elif input_array.dtype.kind in "fcb":
960+
elif input_array.dtype.kind in "fb":
958961
# TODO: de-dup with maybe_cast_to_integer_array?
959962
arr = input_array[notna(input_array)]
960963
if len(arr) < len(input_array) and not is_nan_na():
@@ -972,7 +975,7 @@ def convert_dtypes(
972975
inferred_dtype = target_int_dtype
973976

974977
if convert_floating:
975-
if input_array.dtype.kind in "fcb":
978+
if input_array.dtype.kind in "fb":
976979
# i.e. numeric but not integer
977980
from pandas.core.arrays.floating import NUMPY_FLOAT_TO_DTYPE
978981

@@ -1028,11 +1031,11 @@ def convert_dtypes(
10281031

10291032
if (
10301033
(convert_integer and inferred_dtype.kind in "iu")
1031-
or (convert_floating and inferred_dtype.kind in "fc")
1034+
or (convert_floating and inferred_dtype.kind in "f")
10321035
or (convert_boolean and inferred_dtype.kind == "b")
10331036
or (convert_string and isinstance(inferred_dtype, StringDtype))
10341037
or (
1035-
inferred_dtype.kind not in "iufcb"
1038+
inferred_dtype.kind not in "iufb"
10361039
and not isinstance(inferred_dtype, StringDtype)
10371040
and not isinstance(inferred_dtype, CategoricalDtype)
10381041
)

pandas/tests/frame/methods/test_convert_dtypes.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,3 +228,15 @@ def test_convert_dtype_pyarrow_timezone_preserve(self):
228228
result = df.convert_dtypes(dtype_backend="pyarrow")
229229
expected = df.copy()
230230
tm.assert_frame_equal(result, expected)
231+
232+
def test_convert_dtypes_complex(self):
233+
# GH 60129
234+
df = pd.DataFrame({"a": [1.0 + 5.0j, 1.5 - 3.0j], "b": [1, 2]})
235+
expected = pd.DataFrame(
236+
{
237+
"a": pd.array([1.0 + 5.0j, 1.5 - 3.0j], dtype="complex128"),
238+
"b": pd.array([1, 2], dtype="Int64"),
239+
}
240+
)
241+
result = df.convert_dtypes()
242+
tm.assert_frame_equal(result, expected)

pandas/tests/series/methods/test_convert_dtypes.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,3 +332,9 @@ def test_convert_dtype_pyarrow_timezone_preserve(self):
332332
result = ser.convert_dtypes(dtype_backend="pyarrow")
333333
expected = ser.copy()
334334
tm.assert_series_equal(result, expected)
335+
336+
def test_convert_dtypes_complex(self):
337+
# GH 60129
338+
ser = pd.Series([1.5 + 3.0j, 1.5 - 3.0j])
339+
result = ser.convert_dtypes()
340+
tm.assert_series_equal(result, ser)

0 commit comments

Comments
 (0)