77 "_clone_estimator" ,
88]
99
10+ from abc import ABCMeta
1011from typing import List , Tuple , Union
1112
1213import numpy as np
1920from tsml .utils .validation import _num_features , check_X , check_X_y
2021
2122
22- class BaseTimeSeriesEstimator (BaseEstimator ):
23+ class BaseTimeSeriesEstimator (BaseEstimator , metaclass = ABCMeta ):
2324 """Base class for time series estimators in tsml."""
2425
2526 def _validate_data (
@@ -40,7 +41,7 @@ def _validate_data(
4041
4142 Parameters
4243 ----------
43- X : ndarray or list of ndarrays of shape (n_samples, n_dimensions , \
44+ X : ndarray or list of ndarrays of shape (n_samples, n_channels , \
4445 series_length), array-like, or 'no validation', default='no validation'
4546 The input samples. ideally a 3D numpy array or a list of 2D numpy
4647 arrays.
@@ -109,6 +110,67 @@ def _validate_data(
109110
110111 return out
111112
113+ def _convert_X (
114+ self , X : Union [np .ndarray , List [np .ndarray ]], concatenate_channels : bool = False
115+ ) -> Union [np .ndarray , List [np .ndarray ]]:
116+ dtypes = self ._get_tags ()["X_types" ]
117+
118+ if isinstance (X , np .ndarray ) and X .ndim == 3 :
119+ if "3darray" in dtypes :
120+ return X
121+ elif dtypes [0 ] == "2darray" :
122+ if X .shape [1 ] == 1 or concatenate_channels :
123+ return X .reshape ((X .shape [0 ], - 1 ))
124+ else :
125+ raise ValueError (
126+ "Can only convert 3D numpy array with 1 channel to 2D numpy "
127+ f"array if concatenate_channels is True, found { X .shape [1 ]} "
128+ "channels."
129+ )
130+ elif dtypes [0 ] == "np_list" :
131+ return [x for x in X ]
132+ elif isinstance (X , np .ndarray ) and X .ndim == 2 :
133+ if "2darray" in dtypes :
134+ return X
135+ elif dtypes [0 ] == "3darray" :
136+ return X .reshape ((X .shape [0 ], 1 , - 1 ))
137+ elif dtypes [0 ] == "np_list" :
138+ return [x .reshape (1 , X .shape [1 ]) for x in X ]
139+ elif isinstance (X , list ) and all (
140+ isinstance (x , np .ndarray ) and x .ndim == 2 for x in X
141+ ):
142+ if "np_list" in dtypes :
143+ return X
144+ elif dtypes [0 ] == "3darray" :
145+ max_len = max (x .shape [1 ] for x in X )
146+ arr = np .zeros ((len (X ), X [0 ].shape [0 ], max_len ))
147+
148+ for i , x in enumerate (X ):
149+ arr [i , :, : x .shape [1 ]] = x
150+
151+ return arr
152+ elif dtypes [0 ] == "2darray" :
153+ if X [0 ].shape [0 ] == 1 or concatenate_channels :
154+ max_len = max (x .shape [1 ] for x in X )
155+ arr = np .zeros ((len (X ), X [0 ].shape [0 ], max_len ))
156+
157+ for i , x in enumerate (X ):
158+ arr [i , :, : x .shape [1 ]] = x
159+
160+ return arr .reshape ((arr .shape [0 ], - 1 ))
161+ else :
162+ raise ValueError (
163+ "Can only convert list of 2D numpy arrays with 1 channel to 2D "
164+ "numpy array if concatenate_channels is True, found "
165+ f"{ X [0 ].shape [0 ]} channels."
166+ )
167+ else :
168+ raise ValueError (
169+ "X must be a 2D/3D numpy array or a list of 2D numpy arrays, got "
170+ f"{ f'list of { type (X [0 ])} ' if isinstance (X , list ) else type (X )} "
171+ "instead."
172+ )
173+
112174 def _check_n_features (self , X : Union [np .ndarray , List [np .ndarray ]], reset : bool ):
113175 """Set the `n_features_in_` attribute, or check against it.
114176
@@ -117,14 +179,14 @@ def _check_n_features(self, X: Union[np.ndarray, List[np.ndarray]], reset: bool)
117179 Parameters
118180 ----------
119181 X : ndarray or list of ndarrays of shape \
120- (n_samples, n_dimensions , series_length)
182+ (n_samples, n_channels , series_length)
121183 The input samples. Should be a 3D numpy array or a list of 2D numpy
122184 arrays.
123185 reset : bool
124186 If True, the `n_features_in_` attribute is set to
125- `(n_dimensions , min_series_length, max_series_length)`.
187+ `(n_channels , min_series_length, max_series_length)`.
126188 If False and the attribute exists, then check that it is equal to
127- `(n_dimensions , min_series_length, max_series_length)`.
189+ `(n_channels , min_series_length, max_series_length)`.
128190 If False and the attribute does *not* exist, then the check is skipped.
129191 .. note::
130192 It is recommended to call reset=True in `fit`. All other methods that
@@ -137,7 +199,7 @@ def _check_n_features(self, X: Union[np.ndarray, List[np.ndarray]], reset: bool)
137199 raise ValueError (
138200 "X does not contain any features to extract, but "
139201 f"{ self .__class__ .__name__ } is expecting "
140- f"{ self .n_features_in_ [0 ]} dimensions as input."
202+ f"{ self .n_features_in_ [0 ]} channels as input."
141203 ) from e
142204 # If the number of features is not defined and reset=True,
143205 # then we skip this check
@@ -155,8 +217,8 @@ def _check_n_features(self, X: Union[np.ndarray, List[np.ndarray]], reset: bool)
155217
156218 if n_features [0 ] != self .n_features_in_ [0 ]:
157219 raise ValueError (
158- f"X has { n_features [0 ]} dimensions , but { self .__class__ .__name__ } "
159- f"is expecting { self .n_features_in_ [0 ]} dimensions as input."
220+ f"X has { n_features [0 ]} channels , but { self .__class__ .__name__ } "
221+ f"is expecting { self .n_features_in_ [0 ]} channels as input."
160222 )
161223
162224 tags = _safe_tags (self )
0 commit comments