|
49 | 49 | LOSS_FUNCTIONS = { |
50 | 50 | "hinge": (Hinge,), |
51 | 51 | "log": (Log,), |
| 52 | + "squared_error": (SquaredLoss,), |
52 | 53 | "squared_loss": (SquaredLoss,), |
53 | 54 | "squared_hinge": (SquaredHinge,), |
54 | 55 | "modified_huber": (ModifiedHuber,), |
55 | 56 | "huber": (Huber, 1.35), # 1.35 is default value. TODO : set as parameter |
56 | 57 | } |
57 | 58 |
|
| 59 | +# Test version of sklearn, in version older than v1.0 squared_loss must be used |
| 60 | +import sklearn |
| 61 | + |
| 62 | +if sklearn.__version__[0] == "0": |
| 63 | + SQ_LOSS = "squared_loss" |
| 64 | +else: |
| 65 | + SQ_LOSS = "squared_error" |
| 66 | + |
58 | 67 |
|
59 | 68 | def _huber_psisx(x, c): |
60 | 69 | """Huber-loss weight for RobustWeightedEstimator algorithm""" |
@@ -107,7 +116,7 @@ class _RobustWeightedEstimator(BaseEstimator): |
107 | 116 | base_estimator. |
108 | 117 | Classification losses supported : 'log', 'hinge', 'squared_hinge', |
109 | 118 | 'modified_huber'. If 'log', then the base_estimator must support |
110 | | - predict_proba. Regression losses supported : 'squared_loss', 'huber'. |
| 119 | + predict_proba. Regression losses supported : 'squared_error', 'huber'. |
111 | 120 | If callable, the function is used as loss function ro construct |
112 | 121 | the weights. |
113 | 122 |
|
@@ -270,7 +279,7 @@ def fit(self, X, y=None): |
270 | 279 | if "warm_start" in parameters: |
271 | 280 | base_estimator.set_params(warm_start=True) |
272 | 281 |
|
273 | | - if "loss" in parameters: |
| 282 | + if ("loss" in parameters) and (loss_param != "squared_error"): |
274 | 283 | base_estimator.set_params(loss=loss_param) |
275 | 284 |
|
276 | 285 | if "eta0" in parameters: |
@@ -971,8 +980,8 @@ class RobustWeightedRegressor(BaseEstimator, RegressorMixin): |
971 | 980 | (using the inter-quartile range), this tends to be conservative |
972 | 981 | (robust). |
973 | 982 |
|
974 | | - loss : string, None or callable, default="squared_loss" |
975 | | - For now, only "squared_loss" and "huber" are implemented. |
| 983 | + loss : string, None or callable, default="squared_error" |
| 984 | + For now, only "squared_error" and "huber" are implemented. |
976 | 985 |
|
977 | 986 | sgd_args : dict, default={} |
978 | 987 | arguments of the SGDClassifier base estimator. |
@@ -1057,7 +1066,7 @@ def __init__( |
1057 | 1066 | eta0=0.01, |
1058 | 1067 | c=None, |
1059 | 1068 | k=0, |
1060 | | - loss="squared_loss", |
| 1069 | + loss=SQ_LOSS, |
1061 | 1070 | sgd_args=None, |
1062 | 1071 | tol=1e-3, |
1063 | 1072 | n_iter_no_change=10, |
|
0 commit comments