Skip to content
22 changes: 22 additions & 0 deletions pandas/core/dtypes/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def concat_compat(
-------
a single array, preserving the combined dtypes
"""

if len(to_concat) and lib.dtypes_all_equal([obj.dtype for obj in to_concat]):
# fastpath!
obj = to_concat[0]
Expand All @@ -92,6 +93,27 @@ def concat_compat(
to_concat_eas,
axis=axis, # type: ignore[call-arg]
)
# Special handling for categorical arrays solves #51362
if (
len(to_concat)
and all(isinstance(arr.dtype, CategoricalDtype) for arr in to_concat)
and axis == 0
):
# Filter out empty arrays before union, similar to non_empties logic
non_empty_categoricals = [x for x in to_concat if _is_nonempty(x, axis)]

if len(non_empty_categoricals) == 0:
# All arrays are empty, return the first one (they're all categorical)
return to_concat[0]
elif len(non_empty_categoricals) == 1:
# Only one non-empty array, return it directly
return non_empty_categoricals[0]
else:
# Multiple non-empty arrays, use union_categoricals
return union_categoricals(
non_empty_categoricals, sort_categories=True
) # Performance cost, but necessary to keep tests passing.
# see pandas/tests/reshape/concat/test_append_common.py:498

# If all arrays are empty, there's nothing to convert, just short-cut to
# the concatenation, #3121.
Expand Down
29 changes: 25 additions & 4 deletions pandas/tests/dtypes/test_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
import pandas.core.dtypes.concat as _concat

import pandas as pd
from pandas import Series
from pandas import (
DataFrame,
Series,
)
import pandas._testing as tm


Expand All @@ -14,12 +17,12 @@ def test_concat_mismatched_categoricals_with_empty():

result = _concat.concat_compat([ser1._values, ser2._values])
expected = pd.concat([ser1, ser2])._values
tm.assert_numpy_array_equal(result, expected)
tm.assert_categorical_equal(result, expected)


def test_concat_single_dataframe_tz_aware():
# https://github.com/pandas-dev/pandas/issues/25257
df = pd.DataFrame(
df = DataFrame(
{"timestamp": [pd.Timestamp("2020-04-08 09:00:00.709949+0000", tz="UTC")]}
)
expected = df.copy()
Expand Down Expand Up @@ -53,7 +56,7 @@ def test_concat_series_between_empty_and_tzaware_series(using_infer_string):
ser2 = Series(dtype=float)

result = pd.concat([ser1, ser2], axis=1)
expected = pd.DataFrame(
expected = DataFrame(
data=[
(0.0, None),
],
Expand All @@ -64,3 +67,21 @@ def test_concat_series_between_empty_and_tzaware_series(using_infer_string):
dtype=float,
)
tm.assert_frame_equal(result, expected)


def test_concat_categorical_dataframes():
df = DataFrame({"a": [0, 1]}, dtype="category")
df2 = DataFrame({"a": [2, 3]}, dtype="category")

result = pd.concat([df, df2], axis=0)

assert result["a"].dtype.name == "category"


def test_concat_categorical_series():
ser = Series([0, 1], dtype="category")
ser2 = Series([2, 3], dtype="category")

result = pd.concat([ser, ser2], axis=0)

assert result.dtype.name == "category"
Loading