1616import numpy as np
1717from sklearn .base import ClassifierMixin , RegressorMixin
1818from sklearn .ensemble import RandomForestClassifier , RandomForestRegressor
19- from sklearn .utils .validation import check_is_fitted
19+ from sklearn .ensemble ._base import _set_random_states
20+ from sklearn .utils .validation import check_is_fitted , check_random_state
2021
2122from tsml .base import BaseTimeSeriesEstimator , _clone_estimator
2223from tsml .transformations ._interval_extraction import (
@@ -47,6 +48,13 @@ class RandomIntervalClassifier(ClassifierMixin, BaseTimeSeriesEstimator):
4748 of said transformers and functions, default=None
4849 Transformers and functions used to extract features from selected intervals.
4950 If None, defaults to [mean, median, min, max, std, 25% quantile, 75% quantile]
51+ series_transformers : TransformerMixin, list, tuple, or None, default=None
52+ The transformers to apply to the series before extracting intervals.
53+ If None, use the series as is.
54+
55+ A list or tuple of transformers will extract intervals from
56+ all transformations and concatenate the output. Including None in the list or tuple
57+ will use the series as is for interval extraction.
5058 dilation : int, list or None, default=None
5159 Add dilation to extracted intervals. No dilation is added if None or 1. If a
5260 list of ints, a random dilation value is selected from the list for each
@@ -110,6 +118,7 @@ def __init__(
110118 min_interval_length = 3 ,
111119 max_interval_length = np .inf ,
112120 features = None ,
121+ series_transformers = None ,
113122 dilation = None ,
114123 estimator = None ,
115124 n_jobs = 1 ,
@@ -120,6 +129,7 @@ def __init__(
120129 self .min_interval_length = min_interval_length
121130 self .max_interval_length = max_interval_length
122131 self .features = features
132+ self .series_transformers = series_transformers
123133 self .dilation = dilation
124134 self .estimator = estimator
125135 self .random_state = random_state
@@ -159,17 +169,42 @@ def fit(self, X: Union[np.ndarray, List[np.ndarray]], y: np.ndarray) -> object:
159169 return self
160170
161171 self ._n_jobs = check_n_jobs (self .n_jobs )
172+ rng = check_random_state (self .random_state )
162173
163- self ._transformer = RandomIntervalTransformer (
164- n_intervals = self .n_intervals ,
165- min_interval_length = self .min_interval_length ,
166- max_interval_length = self .max_interval_length ,
167- features = self .features ,
168- dilation = self .dilation ,
169- random_state = self .random_state ,
170- n_jobs = self ._n_jobs ,
171- parallel_backend = self .parallel_backend ,
172- )
174+ if isinstance (self .series_transformers , (list , tuple )):
175+ self ._series_transformers = [
176+ None if st is None else _clone_estimator (st , random_state = rng )
177+ for st in self .series_transformers
178+ ]
179+ else :
180+ self ._series_transformers = [
181+ None
182+ if self .series_transformers is None
183+ else _clone_estimator (self .series_transformers , random_state = rng )
184+ ]
185+
186+ X_t = np .empty ((X .shape [0 ], 0 ))
187+ self ._transformers = []
188+ for st in self ._series_transformers :
189+ if st is not None :
190+ s = st .fit_transform (X , y )
191+ else :
192+ s = X
193+
194+ ct = RandomIntervalTransformer (
195+ n_intervals = self .n_intervals ,
196+ min_interval_length = self .min_interval_length ,
197+ max_interval_length = self .max_interval_length ,
198+ features = self .features ,
199+ dilation = self .dilation ,
200+ n_jobs = self ._n_jobs ,
201+ parallel_backend = self .parallel_backend ,
202+ )
203+ _set_random_states (ct , rng )
204+ self ._transformers .append (ct )
205+ t = ct .fit_transform (s , y )
206+
207+ X_t = np .hstack ((X_t , t ))
173208
174209 self ._estimator = _clone_estimator (
175210 RandomForestClassifier (n_estimators = 200 )
@@ -182,7 +217,6 @@ def fit(self, X: Union[np.ndarray, List[np.ndarray]], y: np.ndarray) -> object:
182217 if m is not None :
183218 self ._estimator .n_jobs = self ._n_jobs
184219
185- X_t = self ._transformer .fit_transform (X , y )
186220 self ._estimator .fit (X_t , y )
187221
188222 return self
@@ -209,7 +243,17 @@ def predict(self, X: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
209243 X = self ._validate_data (X = X , reset = False , ensure_min_series_length = 3 )
210244 X = self ._convert_X (X )
211245
212- return self ._estimator .predict (self ._transformer .transform (X ))
246+ X_t = np .empty ((X .shape [0 ], 0 ))
247+ for i , st in enumerate (self ._series_transformers ):
248+ if st is not None :
249+ s = st .transform (X )
250+ else :
251+ s = X
252+
253+ t = self ._transformers [i ].transform (s )
254+ X_t = np .hstack ((X_t , t ))
255+
256+ return self ._estimator .predict (X_t )
213257
214258 def predict_proba (self , X : Union [np .ndarray , List [np .ndarray ]]) -> np .ndarray :
215259 """Predicts labels probabilities for sequences in X.
@@ -233,12 +277,22 @@ def predict_proba(self, X: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
233277 X = self ._validate_data (X = X , reset = False , ensure_min_series_length = 3 )
234278 X = self ._convert_X (X )
235279
280+ X_t = np .empty ((X .shape [0 ], 0 ))
281+ for i , st in enumerate (self ._series_transformers ):
282+ if st is not None :
283+ s = st .transform (X )
284+ else :
285+ s = X
286+
287+ t = self ._transformers [i ].transform (s )
288+ X_t = np .hstack ((X_t , t ))
289+
236290 m = getattr (self ._estimator , "predict_proba" , None )
237291 if callable (m ):
238- return self ._estimator .predict_proba (self . _transformer . transform ( X ) )
292+ return self ._estimator .predict_proba (X_t )
239293 else :
240294 dists = np .zeros ((X .shape [0 ], self .n_classes_ ))
241- preds = self ._estimator .predict (self . _transformer . transform ( X ) )
295+ preds = self ._estimator .predict (X_t )
242296 for i in range (0 , X .shape [0 ]):
243297 dists [i , self .class_dictionary_ [preds [i ]]] = 1
244298 return dists
@@ -290,6 +344,13 @@ class RandomIntervalRegressor(RegressorMixin, BaseTimeSeriesEstimator):
290344 of said transformers and functions, default=None
291345 Transformers and functions used to extract features from selected intervals.
292346 If None, defaults to [mean, median, min, max, std, 25% quantile, 75% quantile]
347+ series_transformers : TransformerMixin, list, tuple, or None, default=None
348+ The transformers to apply to the series before extracting intervals.
349+ If None, use the series as is.
350+
351+ A list or tuple of transformers will extract intervals from
352+ all transformations and concatenate the output. Including None in the list or tuple
353+ will use the series as is for interval extraction.
293354 dilation : int, list or None, default=None
294355 Add dilation to extracted intervals. No dilation is added if None or 1. If a
295356 list of ints, a random dilation value is selected from the list for each
@@ -338,8 +399,8 @@ class RandomIntervalRegressor(RegressorMixin, BaseTimeSeriesEstimator):
338399 >>> reg.fit(X, y)
339400 RandomIntervalRegressor(...)
340401 >>> reg.predict(X)
341- array([0.46836751 , 1.32023847 , 1.13355919 , 0.63979608 , 0.58309353 ,
342- 1.18197903 , 0.57859747 , 1.0772939 ])
402+ array([0.44924979 , 1.31424037 , 1.11951504 , 0.63780969 , 0.58123516 ,
403+ 1.17135463 , 0.56450198 , 1.10128837 ])
343404 """
344405
345406 def __init__ (
@@ -348,6 +409,7 @@ def __init__(
348409 min_interval_length = 3 ,
349410 max_interval_length = np .inf ,
350411 features = None ,
412+ series_transformers = None ,
351413 dilation = None ,
352414 estimator = None ,
353415 n_jobs = 1 ,
@@ -358,6 +420,7 @@ def __init__(
358420 self .min_interval_length = min_interval_length
359421 self .max_interval_length = max_interval_length
360422 self .features = features
423+ self .series_transformers = series_transformers
361424 self .dilation = dilation
362425 self .estimator = estimator
363426 self .random_state = random_state
@@ -389,17 +452,42 @@ def fit(self, X: Union[np.ndarray, List[np.ndarray]], y: np.ndarray) -> object:
389452 self .n_instances_ , self .n_channels_ , self .n_timepoints_ = X .shape
390453
391454 self ._n_jobs = check_n_jobs (self .n_jobs )
455+ rng = check_random_state (self .random_state )
392456
393- self ._transformer = RandomIntervalTransformer (
394- n_intervals = self .n_intervals ,
395- min_interval_length = self .min_interval_length ,
396- max_interval_length = self .max_interval_length ,
397- features = self .features ,
398- dilation = self .dilation ,
399- random_state = self .random_state ,
400- n_jobs = self ._n_jobs ,
401- parallel_backend = self .parallel_backend ,
402- )
457+ if isinstance (self .series_transformers , (list , tuple )):
458+ self ._series_transformers = [
459+ None if st is None else _clone_estimator (st , random_state = rng )
460+ for st in self .series_transformers
461+ ]
462+ else :
463+ self ._series_transformers = [
464+ None
465+ if self .series_transformers is None
466+ else _clone_estimator (self .series_transformers , random_state = rng )
467+ ]
468+
469+ X_t = np .empty ((X .shape [0 ], 0 ))
470+ self ._transformers = []
471+ for st in self ._series_transformers :
472+ if st is not None :
473+ s = st .fit_transform (X , y )
474+ else :
475+ s = X
476+
477+ ct = RandomIntervalTransformer (
478+ n_intervals = self .n_intervals ,
479+ min_interval_length = self .min_interval_length ,
480+ max_interval_length = self .max_interval_length ,
481+ features = self .features ,
482+ dilation = self .dilation ,
483+ n_jobs = self ._n_jobs ,
484+ parallel_backend = self .parallel_backend ,
485+ )
486+ _set_random_states (ct , rng )
487+ self ._transformers .append (ct )
488+ t = ct .fit_transform (s , y )
489+
490+ X_t = np .hstack ((X_t , t ))
403491
404492 self ._estimator = _clone_estimator (
405493 RandomForestRegressor (n_estimators = 200 )
@@ -412,7 +500,6 @@ def fit(self, X: Union[np.ndarray, List[np.ndarray]], y: np.ndarray) -> object:
412500 if m is not None :
413501 self ._estimator .n_jobs = self ._n_jobs
414502
415- X_t = self ._transformer .fit_transform (X , y )
416503 self ._estimator .fit (X_t , y )
417504
418505 return self
@@ -435,7 +522,17 @@ def predict(self, X: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
435522 X = self ._validate_data (X = X , reset = False , ensure_min_series_length = 3 )
436523 X = self ._convert_X (X )
437524
438- return self ._estimator .predict (self ._transformer .transform (X ))
525+ X_t = np .empty ((X .shape [0 ], 0 ))
526+ for i , st in enumerate (self ._series_transformers ):
527+ if st is not None :
528+ s = st .transform (X )
529+ else :
530+ s = X
531+
532+ t = self ._transformers [i ].transform (s )
533+ X_t = np .hstack ((X_t , t ))
534+
535+ return self ._estimator .predict (X_t )
439536
440537 @classmethod
441538 def get_test_params (
0 commit comments