1010from sklearn .dummy import DummyRegressor as SklearnDummyRegressor
1111from sklearn .utils import check_random_state
1212from sklearn .utils .multiclass import check_classification_targets
13- from sklearn .utils .validation import check_is_fitted
13+ from sklearn .utils .validation import _num_samples , check_is_fitted
1414
1515from tsml .base import BaseTimeSeriesEstimator
1616
@@ -85,65 +85,77 @@ class prior probabilities.
8585 0.5
8686 """
8787
88- def __init__ (self , strategy = "prior" , random_state = None , constant = None ):
88+ def __init__ (
89+ self , strategy = "prior" , validate = False , random_state = None , constant = None
90+ ):
8991 self .strategy = strategy
92+ self .validate = validate
9093 self .random_state = random_state
9194 self .constant = constant
9295
9396 super (DummyClassifier , self ).__init__ ()
9497
9598 def fit (self , X , y ):
9699 """"""
97- X , y = self ._validate_data (X = X , y = y , ensure_min_series_length = 1 )
100+ if self .validate :
101+ X , y = self ._validate_data (X = X , y = y , ensure_min_series_length = 1 )
98102
99- check_classification_targets (y )
103+ check_classification_targets (y )
100104
101- self .classes_ = np .unique (y )
102- self .n_classes_ = self .classes_ .shape [0 ]
103- self .class_dictionary_ = {}
104- for index , classVal in enumerate (self .classes_ ):
105- self .class_dictionary_ [classVal ] = index
105+ self .classes_ = np .unique (np .asarray (y ))
106106
107- if self .n_classes_ == 1 :
108- return self
107+ if self .validate :
108+ self .n_classes_ = self .classes_ .shape [0 ]
109+ self .class_dictionary_ = {}
110+ for index , classVal in enumerate (self .classes_ ):
111+ self .class_dictionary_ [classVal ] = index
109112
110- self ._clf = SklearnDummyClassifier (
113+ if self .n_classes_ == 1 :
114+ return self
115+
116+ self .clf_ = SklearnDummyClassifier (
111117 strategy = self .strategy ,
112118 random_state = self .random_state ,
113119 constant = self .constant ,
114120 )
115- self ._clf .fit (None , y )
121+ self .clf_ .fit (None , y )
116122
117123 return self
118124
119125 def predict (self , X ) -> np .ndarray :
120126 """"""
121127 check_is_fitted (self )
122128
123- X = self ._validate_data (X = X , reset = False , ensure_min_series_length = 1 )
129+ if self .validate :
130+ # treat case of single class seen in fit
131+ if self .n_classes_ == 1 :
132+ return np .repeat (
133+ list (self .class_dictionary_ .keys ()), X .shape [0 ], axis = 0
134+ )
124135
125- # treat case of single class seen in fit
126- if self .n_classes_ == 1 :
127- return np .repeat (list (self .class_dictionary_ .keys ()), X .shape [0 ], axis = 0 )
136+ X = self ._validate_data (X = X , reset = False , ensure_min_series_length = 1 )
128137
129- return self ._clf .predict (np .zeros (X . shape ))
138+ return self .clf_ .predict (np .zeros (( _num_samples ( X ), 2 ) ))
130139
131140 def predict_proba (self , X ) -> np .ndarray :
132141 """"""
133142 check_is_fitted (self )
134143
135- # treat case of single class seen in fit
136- if self .n_classes_ == 1 :
137- return np .repeat ([[1 ]], X .shape [0 ], axis = 0 )
144+ if self .validate :
145+ # treat case of single class seen in fit
146+ if self .n_classes_ == 1 :
147+ return np .repeat ([[1 ]], X .shape [0 ], axis = 0 )
138148
139- X = self ._validate_data (X = X , reset = False , ensure_min_series_length = 1 )
149+ X = self ._validate_data (X = X , reset = False , ensure_min_series_length = 1 )
140150
141- return self ._clf .predict_proba (np .zeros (X . shape ))
151+ return self .clf_ .predict_proba (np .zeros (( _num_samples ( X ), 2 ) ))
142152
143153 def _more_tags (self ):
144154 return {
145155 "X_types" : ["3darray" , "2darray" , "np_list" ],
146156 "equal_length_only" : False ,
157+ "no_validation" : not self .validate ,
158+ "allow_nan" : True ,
147159 }
148160
149161
@@ -199,36 +211,41 @@ class DummyRegressor(RegressorMixin, BaseTimeSeriesEstimator):
199211 -0.07184048625633688
200212 """
201213
202- def __init__ (self , strategy = "mean" , constant = None , quantile = None ):
214+ def __init__ (self , strategy = "mean" , validate = False , constant = None , quantile = None ):
203215 self .strategy = strategy
216+ self .validate = validate
204217 self .constant = constant
205218 self .quantile = quantile
206219
207220 super (DummyRegressor , self ).__init__ ()
208221
209222 def fit (self , X , y ):
210223 """"""
211- _ , y = self ._validate_data (X = X , y = y , ensure_min_series_length = 1 )
224+ if self .validate :
225+ _ , y = self ._validate_data (X = X , y = y , ensure_min_series_length = 1 )
212226
213- self ._reg = SklearnDummyRegressor (
227+ self .reg_ = SklearnDummyRegressor (
214228 strategy = self .strategy , constant = self .constant , quantile = self .quantile
215229 )
216- self ._reg .fit (None , y )
230+ self .reg_ .fit (None , y )
217231
218232 return self
219233
220234 def predict (self , X ):
221235 """"""
222236 check_is_fitted (self )
223237
224- X = self ._validate_data (X = X , reset = False , ensure_min_series_length = 1 )
238+ if self .validate :
239+ X = self ._validate_data (X = X , reset = False , ensure_min_series_length = 1 )
225240
226- return self ._reg .predict (np .zeros (X . shape ))
241+ return self .reg_ .predict (np .zeros (( _num_samples ( X ), 2 ) ))
227242
228243 def _more_tags (self ):
229244 return {
230245 "X_types" : ["3darray" , "2darray" , "np_list" ],
231246 "equal_length_only" : False ,
247+ "no_validation" : not self .validate ,
248+ "allow_nan" : True ,
232249 }
233250
234251
@@ -257,16 +274,20 @@ class DummyClusterer(ClusterMixin, BaseTimeSeriesEstimator):
257274 0.2087729039422543
258275 """
259276
260- def __init__ (self , strategy = "single" , n_clusters = 2 , random_state = None ):
277+ def __init__ (
278+ self , strategy = "single" , validate = False , n_clusters = 2 , random_state = None
279+ ):
261280 self .strategy = strategy
281+ self .validate = validate
262282 self .n_clusters = n_clusters
263283 self .random_state = random_state
264284
265285 super (DummyClusterer , self ).__init__ ()
266286
267287 def fit (self , X , y = None ):
268288 """"""
269- X = self ._validate_data (X = X , ensure_min_series_length = 1 )
289+ if self .validate :
290+ X = self ._validate_data (X = X , ensure_min_series_length = 1 )
270291
271292 if self .strategy == "single" :
272293 self .labels_ = np .zeros (len (X ), dtype = np .int32 )
@@ -284,20 +305,23 @@ def predict(self, X):
284305 """"""
285306 check_is_fitted (self )
286307
287- X = self ._validate_data (X = X , reset = False , ensure_min_series_length = 1 )
308+ if self .validate :
309+ X = self ._validate_data (X = X , reset = False , ensure_min_series_length = 1 )
288310
289311 if self .strategy == "single" :
290- return np .zeros (len (X ), dtype = np .int32 )
312+ return np .zeros (_num_samples (X ), dtype = np .int32 )
291313 elif self .strategy == "unique" :
292- return np .arange (len (X ), dtype = np .int32 )
314+ return np .arange (_num_samples (X ), dtype = np .int32 )
293315 elif self .strategy == "random" :
294316 rng = check_random_state (self .random_state )
295- return rng .randint (self .n_clusters , size = len (X ), dtype = np .int32 )
317+ return rng .randint (self .n_clusters , size = _num_samples (X ), dtype = np .int32 )
296318 else :
297319 raise ValueError (f"Unknown strategy { self .strategy } " )
298320
299321 def _more_tags (self ):
300322 return {
301323 "X_types" : ["3darray" , "2darray" , "np_list" ],
302324 "equal_length_only" : False ,
325+ "no_validation" : not self .validate ,
326+ "allow_nan" : True ,
303327 }
0 commit comments