Skip to content

Commit d55fa6a

Browse files
committed
Preserve functionality for return_df arg
1 parent 72cd436 commit d55fa6a

File tree

2 files changed

+42
-2
lines changed

2 files changed

+42
-2
lines changed

boruta/boruta_py.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from sklearn.base import BaseEstimator
1616
from sklearn.feature_selection import SelectorMixin
1717
from sklearn.utils.validation import check_is_fitted
18+
from sklearn.utils._set_output import _get_output_config
1819
import warnings
1920

2021

@@ -240,7 +241,7 @@ def transform(self, X, weak=None, return_df=None):
240241
weak : boolean, optional
241242
Deprecated. Set ``weak`` in the constructor instead.
242243
243-
return_df : bool, optional
244+
return_df : boolean, optional
244245
Deprecated. Output type now follows scikit-learn's standard
245246
``set_output``/``set_config`` mechanism.
246247
"""
@@ -253,6 +254,9 @@ def transform(self, X, weak=None, return_df=None):
253254
stacklevel=2,
254255
)
255256
self.weak = weak
257+
requested_transform = None
258+
prev_output_config = None
259+
force_numpy = return_df is False
256260
if return_df is not None:
257261
warnings.warn(
258262
"`return_df` is deprecated and will be removed in a future "
@@ -261,11 +265,20 @@ def transform(self, X, weak=None, return_df=None):
261265
FutureWarning,
262266
stacklevel=2,
263267
)
268+
prev_output_config = _get_output_config("transform", estimator=self)["dense"]
269+
requested_transform = "pandas" if return_df else "default"
270+
if prev_output_config != requested_transform:
271+
self.set_output(transform=requested_transform)
264272
try:
265-
return super().transform(X)
273+
result = super().transform(X)
266274
finally:
267275
if weak is not None:
268276
self.weak = prev_weak
277+
if requested_transform is not None and prev_output_config != requested_transform:
278+
self.set_output(transform=prev_output_config)
279+
if force_numpy and hasattr(result, "to_numpy"):
280+
result = result.to_numpy()
281+
return result
269282

270283
def fit_transform(self, X, y=None, **fit_params):
271284
"""

boruta/test/test_boruta.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,33 @@ def test_return_df_parameter_emits_warning(Xy):
8585
assert isinstance(transformed, pd.DataFrame)
8686

8787

88+
def test_return_df_true_temporarily_enables_pandas_output(Xy):
89+
X, y = Xy
90+
bt = BorutaPy(RandomForestClassifier())
91+
bt.fit(X, y)
92+
93+
baseline = bt.transform(X)
94+
assert isinstance(baseline, np.ndarray)
95+
96+
with pytest.warns(FutureWarning, match="`return_df` is deprecated"):
97+
transformed = bt.transform(X, return_df=True)
98+
assert isinstance(transformed, pd.DataFrame)
99+
100+
reverted = bt.transform(X)
101+
assert isinstance(reverted, np.ndarray)
102+
103+
104+
def test_return_df_false_with_dataframe_input_returns_numpy(Xy):
105+
X, y = Xy
106+
X_df = pd.DataFrame(X)
107+
bt = BorutaPy(RandomForestClassifier())
108+
bt.fit(X_df, y)
109+
110+
with pytest.warns(FutureWarning, match="`return_df` is deprecated"):
111+
transformed = bt.transform(X_df, return_df=False)
112+
assert isinstance(transformed, np.ndarray)
113+
114+
88115
def test_weak_attribute_controls_support_mask(Xy):
89116
X, y = Xy
90117
bt = BorutaPy(RandomForestClassifier(), weak=True)

0 commit comments

Comments
 (0)