Skip to content

Commit dbec9f3

Browse files
committed
Deprecate weak and return_df args
1 parent dd7b020 commit dbec9f3

File tree

2 files changed

+92
-54
lines changed

2 files changed

+92
-54
lines changed

boruta/boruta_py.py

Lines changed: 42 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,10 @@ class BorutaPy(BaseEstimator, SelectorMixin):
142142
The mask of selected tentative features, which haven't gained enough
143143
support during the max_iter number of iterations.
144144
145+
weak : bool, default=False
146+
147+
If set to true, the tentative features are also used to reduce X.
148+
145149
ranking_ : array of shape [n_features]
146150
147151
The feature ranking, such that ``ranking_[i]`` corresponds to the
@@ -194,7 +198,7 @@ class BorutaPy(BaseEstimator, SelectorMixin):
194198

195199
def __init__(self, estimator, n_estimators=1000, perc=100, alpha=0.05,
196200
two_step=True, max_iter=100, random_state=None, verbose=0,
197-
early_stopping=False, n_iter_no_change=20):
201+
early_stopping=False, n_iter_no_change=20, weak: bool = False):
198202
self.estimator = estimator
199203
self.n_estimators = n_estimators
200204
self.perc = perc
@@ -207,8 +211,9 @@ def __init__(self, estimator, n_estimators=1000, perc=100, alpha=0.05,
207211
self.n_iter_no_change = n_iter_no_change
208212
self.__version__ = '0.3'
209213
self._is_lightgbm = 'lightgbm' in str(type(self.estimator))
214+
self.weak = weak
210215

211-
def fit(self, X, y):
216+
def fit(self, X, y, **fit_params):
212217
"""
213218
Fits the Boruta feature selection with the provided estimator.
214219
@@ -223,7 +228,7 @@ def fit(self, X, y):
223228

224229
return self._fit(X, y)
225230

226-
def transform(self, X, weak=False, return_df=False):
231+
def transform(self, X, weak=None, return_df=None):
227232
"""
228233
Reduces the input X to the features selected by Boruta.
229234
@@ -232,23 +237,37 @@ def transform(self, X, weak=False, return_df=False):
232237
X : array-like, shape = [n_samples, n_features]
233238
The training input samples.
234239
235-
weak: boolean, default = False
236-
If set to true, the tentative features are also used to reduce X.
237-
238-
return_df : boolean, default = False
239-
If ``X`` is a pandas dataframe and this parameter is set to True,
240-
the transformed data will also be a dataframe.
240+
weak : boolean, optional
241+
Deprecated. Set ``weak`` in the constructor instead.
241242
242-
Returns
243-
-------
244-
X : array-like, shape = [n_samples, n_features_]
245-
The input matrix X's columns are reduced to the features which were
246-
selected by Boruta.
243+
return_df : bool, optional
244+
Deprecated. Output type now follows scikit-learn's standard
245+
``set_output``/``set_config`` mechanism.
247246
"""
247+
prev_weak = self.weak
248+
if weak is not None:
249+
warnings.warn(
250+
"`weak` is deprecated and will be removed in a future release. "
251+
"Set `weak` in the constructor instead.",
252+
FutureWarning,
253+
stacklevel=2,
254+
)
255+
self.weak = weak
256+
if return_df is not None:
257+
warnings.warn(
258+
"`return_df` is deprecated and will be removed in a future "
259+
"release. Use scikit-learn's `set_output(transform='pandas')` "
260+
"or `set_config(transform_output='pandas')` instead.",
261+
FutureWarning,
262+
stacklevel=2,
263+
)
264+
try:
265+
return super().transform(X)
266+
finally:
267+
if weak is not None:
268+
self.weak = prev_weak
248269

249-
return self._transform(X, weak, return_df)
250-
251-
def fit_transform(self, X, y, weak=False, return_df=False):
270+
def fit_transform(self, X, y=None, **fit_params):
252271
"""
253272
Fits Boruta, then reduces the input X to the selected features.
254273
@@ -259,23 +278,10 @@ def fit_transform(self, X, y, weak=False, return_df=False):
259278
260279
y : array-like, shape = [n_samples]
261280
The target values.
262-
263-
weak: boolean, default = False
264-
If set to true, the tentative features are also used to reduce X.
265-
266-
return_df : boolean, default = False
267-
If ``X`` is a pandas dataframe and this parameter is set to True,
268-
the transformed data will also be a dataframe.
269-
270-
Returns
271-
-------
272-
X : array-like, shape = [n_samples, n_features_]
273-
The input matrix X's columns are reduced to the features which were
274-
selected by Boruta.
275281
"""
276-
277-
self._fit(X, y)
278-
return self._transform(X, weak, return_df)
282+
weak = fit_params.pop("weak", None)
283+
return_df = fit_params.pop("return_df", None)
284+
return self.fit(X, y, **fit_params).transform(X, weak=weak, return_df=return_df)
279285

280286
def _validate_pandas_input(self, arg):
281287
try:
@@ -446,24 +452,6 @@ def _fit(self, X, y):
446452
self._print_results(dec_reg, _iter, 1)
447453
return self
448454

449-
def _transform(self, X, weak=False, return_df=False):
450-
# sanity check
451-
try:
452-
self.ranking_
453-
except AttributeError:
454-
raise ValueError('You need to call the fit(X, y) method first.')
455-
456-
if weak:
457-
indices = self.support_ + self.support_weak_
458-
else:
459-
indices = self.support_
460-
461-
if return_df:
462-
X = X.iloc[:, indices]
463-
else:
464-
X = X[:, indices]
465-
return X
466-
467455
def _set_n_estimators(self, n_estimators):
468456
try:
469457
self.estimator.set_params(n_estimators=n_estimators)
@@ -476,7 +464,9 @@ def _set_n_estimators(self, n_estimators):
476464
return self
477465

478466
def _get_support_mask(self):
479-
check_is_fitted(self, 'support_')
467+
check_is_fitted(self, ['support_', 'support_weak_'])
468+
if self.weak:
469+
return np.logical_or(self.support_, self.support_weak_)
480470
return self.support_
481471

482472
def _get_tree_num(self, n_feat):

boruta/test/test_boruta.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
import re
2+
13
import numpy as np
24
import pandas as pd
35
import pytest
6+
from sklearn import config_context
47
from sklearn.ensemble import RandomForestClassifier
58
from sklearn.exceptions import NotFittedError
69
from sklearn.tree import DecisionTreeClassifier, ExtraTreeClassifier
@@ -65,8 +68,53 @@ def test_dataframe_is_returned(Xy):
6568
X_df, y_df = pd.DataFrame(X), pd.Series(y)
6669
rfc = RandomForestClassifier()
6770
bt = BorutaPy(rfc)
68-
bt.fit(X_df, y_df)
69-
assert isinstance(bt.transform(X_df, return_df=True), pd.DataFrame)
71+
with config_context(transform_output="pandas"):
72+
bt.fit(X_df, y_df)
73+
transformed = bt.transform(X_df)
74+
assert isinstance(transformed, pd.DataFrame)
75+
76+
77+
def test_return_df_parameter_emits_warning(Xy):
78+
X, y = Xy
79+
X_df, y_df = pd.DataFrame(X), pd.Series(y)
80+
bt = BorutaPy(RandomForestClassifier())
81+
with config_context(transform_output="pandas"):
82+
bt.fit(X_df, y_df)
83+
with pytest.warns(FutureWarning, match=re.escape("`set_output(transform='pandas')`")):
84+
transformed = bt.transform(X_df, return_df=True)
85+
assert isinstance(transformed, pd.DataFrame)
86+
87+
88+
def test_weak_attribute_controls_support_mask(Xy):
89+
X, y = Xy
90+
bt = BorutaPy(RandomForestClassifier(), weak=True)
91+
bt.fit(X, y)
92+
93+
union_mask = bt.support_ | bt.support_weak_
94+
assert np.array_equal(bt.get_support(), union_mask)
95+
96+
97+
def test_transform_with_weak_parameter_is_deprecated(Xy):
98+
X, y = Xy
99+
bt = BorutaPy(RandomForestClassifier())
100+
bt.fit(X, y)
101+
bt.support_[5] = False
102+
bt.support_weak_[5] = True
103+
104+
with pytest.warns(FutureWarning, match=re.escape("`weak` is deprecated")):
105+
transformed = bt.transform(X, weak=True)
106+
107+
expected_features = np.count_nonzero(bt.support_ | bt.support_weak_)
108+
assert transformed.shape[1] == expected_features
109+
110+
111+
def test_fit_transform_with_weak_parameter_is_deprecated(Xy):
112+
X, y = Xy
113+
bt = BorutaPy(RandomForestClassifier())
114+
with pytest.warns(FutureWarning, match=re.escape("`weak` is deprecated")):
115+
transformed = bt.fit_transform(X, y, weak=True)
116+
expected_features = np.count_nonzero(bt.support_ | bt.support_weak_)
117+
assert transformed.shape[1] == expected_features
70118

71119

72120
def test_selector_mixin_get_support_requires_fit():

0 commit comments

Comments
 (0)