Skip to content

Commit d0a8c6d

Browse files
More content and compatability for earlier sklearn versions (#4)
* actually change python version * dummy classifiers and sklearn lower bound change * test fix * test fix * dev * early sklearn version fixes
1 parent 1092774 commit d0a8c6d

25 files changed

+756
-302
lines changed

.github/workflows/release.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ jobs:
3636
python -m pip install build
3737
python -m build
3838
39-
- name: Store built files
39+
- name: Store build files
4040
uses: actions/upload-artifact@v3
4141
with:
4242
name: dist
@@ -73,10 +73,10 @@ jobs:
7373

7474
- if: matrix.os == 'windows-latest'
7575
name: Windows install
76-
run: python -m pip install "${env:WHEELNAME}[optional_dependencies,dev]"
76+
run: python -m pip install "${env:WHEELNAME}[extras,dev]"
7777
- if: matrix.os != 'windows-latest'
7878
name: Unix install
79-
run: python -m pip install "${{ env.WHEELNAME }}[optional_dependencies,dev]"
79+
run: python -m pip install "${{ env.WHEELNAME }}[extras,dev]"
8080

8181
- name: Tests
8282
run: python -m pytest

.github/workflows/tests.yml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@ jobs:
1818
with:
1919
python-version: "3.10"
2020

21-
- id: file_changes
22-
uses: trilom/file-changes-action@v1.2.4
21+
- uses: trilom/file-changes-action@v1.2.4
2322
with:
2423
output: " "
2524

@@ -43,7 +42,7 @@ jobs:
4342
python-version: ${{ matrix.python-version }}
4443

4544
- name: Install
46-
run: python -m pip install .[dev,optional_dependencies]
45+
run: python -m pip install .[dev,extras]
4746

4847
- name: Tests
4948
run: python -m pytest
@@ -63,7 +62,7 @@ jobs:
6362
run: echo "NUMBA_DISABLE_JIT=1" >> $GITHUB_ENV
6463

6564
- name: Install
66-
run: python -m pip install .[dev,optional_dependencies]
65+
run: python -m pip install .[dev,extras]
6766

6867
- name: Tests
6968
run: python -m pytest --cov=tsml --cov-report=xml

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "tsml"
7-
version = "0.0.3"
7+
version = "0.0.4"
88
description = "A toolkit for time series machine learning algorithms."
99
authors = [
1010
{name = "Matthew Middlehurst", email = "m.middlehurst@uea.ac.uk"},
@@ -41,8 +41,9 @@ dependencies = [
4141
]
4242

4343
[project.optional-dependencies]
44-
optional_dependencies = [
44+
extras = [
4545
"pycatch22",
46+
"pyfftw"
4647
]
4748
dev = [
4849
"pre-commit",

tsml/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
# -*- coding: utf-8 -*-
22
"""tsml."""
33

4-
__version__ = "0.0.3"
4+
__version__ = "0.0.4"

tsml/base.py

Lines changed: 70 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
"_clone_estimator",
88
]
99

10+
from abc import ABCMeta
1011
from typing import List, Tuple, Union
1112

1213
import numpy as np
@@ -19,7 +20,7 @@
1920
from tsml.utils.validation import _num_features, check_X, check_X_y
2021

2122

22-
class BaseTimeSeriesEstimator(BaseEstimator):
23+
class BaseTimeSeriesEstimator(BaseEstimator, metaclass=ABCMeta):
2324
"""Base class for time series estimators in tsml."""
2425

2526
def _validate_data(
@@ -40,7 +41,7 @@ def _validate_data(
4041
4142
Parameters
4243
----------
43-
X : ndarray or list of ndarrays of shape (n_samples, n_dimensions, \
44+
X : ndarray or list of ndarrays of shape (n_samples, n_channels, \
4445
series_length), array-like, or 'no validation', default='no validation'
4546
The input samples. ideally a 3D numpy array or a list of 2D numpy
4647
arrays.
@@ -109,6 +110,67 @@ def _validate_data(
109110

110111
return out
111112

113+
def _convert_X(
114+
self, X: Union[np.ndarray, List[np.ndarray]], concatenate_channels: bool = False
115+
) -> Union[np.ndarray, List[np.ndarray]]:
116+
dtypes = self._get_tags()["X_types"]
117+
118+
if isinstance(X, np.ndarray) and X.ndim == 3:
119+
if "3darray" in dtypes:
120+
return X
121+
elif dtypes[0] == "2darray":
122+
if X.shape[1] == 1 or concatenate_channels:
123+
return X.reshape((X.shape[0], -1))
124+
else:
125+
raise ValueError(
126+
"Can only convert 3D numpy array with 1 channel to 2D numpy "
127+
f"array if concatenate_channels is True, found {X.shape[1]} "
128+
"channels."
129+
)
130+
elif dtypes[0] == "np_list":
131+
return [x for x in X]
132+
elif isinstance(X, np.ndarray) and X.ndim == 2:
133+
if "2darray" in dtypes:
134+
return X
135+
elif dtypes[0] == "3darray":
136+
return X.reshape((X.shape[0], 1, -1))
137+
elif dtypes[0] == "np_list":
138+
return [x.reshape(1, X.shape[1]) for x in X]
139+
elif isinstance(X, list) and all(
140+
isinstance(x, np.ndarray) and x.ndim == 2 for x in X
141+
):
142+
if "np_list" in dtypes:
143+
return X
144+
elif dtypes[0] == "3darray":
145+
max_len = max(x.shape[1] for x in X)
146+
arr = np.zeros((len(X), X[0].shape[0], max_len))
147+
148+
for i, x in enumerate(X):
149+
arr[i, :, : x.shape[1]] = x
150+
151+
return arr
152+
elif dtypes[0] == "2darray":
153+
if X[0].shape[0] == 1 or concatenate_channels:
154+
max_len = max(x.shape[1] for x in X)
155+
arr = np.zeros((len(X), X[0].shape[0], max_len))
156+
157+
for i, x in enumerate(X):
158+
arr[i, :, : x.shape[1]] = x
159+
160+
return arr.reshape((arr.shape[0], -1))
161+
else:
162+
raise ValueError(
163+
"Can only convert list of 2D numpy arrays with 1 channel to 2D "
164+
"numpy array if concatenate_channels is True, found "
165+
f"{X[0].shape[0]} channels."
166+
)
167+
else:
168+
raise ValueError(
169+
"X must be a 2D/3D numpy array or a list of 2D numpy arrays, got "
170+
f"{f'list of {type(X[0])}' if isinstance(X, list) else type(X)} "
171+
"instead."
172+
)
173+
112174
def _check_n_features(self, X: Union[np.ndarray, List[np.ndarray]], reset: bool):
113175
"""Set the `n_features_in_` attribute, or check against it.
114176
@@ -117,14 +179,14 @@ def _check_n_features(self, X: Union[np.ndarray, List[np.ndarray]], reset: bool)
117179
Parameters
118180
----------
119181
X : ndarray or list of ndarrays of shape \
120-
(n_samples, n_dimensions, series_length)
182+
(n_samples, n_channels, series_length)
121183
The input samples. Should be a 3D numpy array or a list of 2D numpy
122184
arrays.
123185
reset : bool
124186
If True, the `n_features_in_` attribute is set to
125-
`(n_dimensions, min_series_length, max_series_length)`.
187+
`(n_channels, min_series_length, max_series_length)`.
126188
If False and the attribute exists, then check that it is equal to
127-
`(n_dimensions, min_series_length, max_series_length)`.
189+
`(n_channels, min_series_length, max_series_length)`.
128190
If False and the attribute does *not* exist, then the check is skipped.
129191
.. note::
130192
It is recommended to call reset=True in `fit`. All other methods that
@@ -137,7 +199,7 @@ def _check_n_features(self, X: Union[np.ndarray, List[np.ndarray]], reset: bool)
137199
raise ValueError(
138200
"X does not contain any features to extract, but "
139201
f"{self.__class__.__name__} is expecting "
140-
f"{self.n_features_in_[0]} dimensions as input."
202+
f"{self.n_features_in_[0]} channels as input."
141203
) from e
142204
# If the number of features is not defined and reset=True,
143205
# then we skip this check
@@ -155,8 +217,8 @@ def _check_n_features(self, X: Union[np.ndarray, List[np.ndarray]], reset: bool)
155217

156218
if n_features[0] != self.n_features_in_[0]:
157219
raise ValueError(
158-
f"X has {n_features[0]} dimensions, but {self.__class__.__name__} "
159-
f"is expecting {self.n_features_in_[0]} dimensions as input."
220+
f"X has {n_features[0]} channels, but {self.__class__.__name__} "
221+
f"is expecting {self.n_features_in_[0]} channels as input."
160222
)
161223

162224
tags = _safe_tags(self)

tsml/datasets/_data_io.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -161,15 +161,15 @@ def load_from_ts_file(
161161
if len(tokens) != 2:
162162
raise IOError(
163163
"Invalid .ts file. @dimension tag requires a int value "
164-
"(the number of dimensions for the problem)."
164+
"(the number of channels for the problem)."
165165
)
166166

167167
try:
168168
dimensions = int(tokens[1])
169169
except ValueError:
170170
raise IOError(
171171
"Invalid .ts file. @dimension tag requires a int value "
172-
"(the number of dimensions for the problem)."
172+
"(the number of channels for the problem)."
173173
)
174174

175175
dimensions_tag = True
@@ -194,15 +194,15 @@ def load_from_ts_file(
194194
if len(tokens) != 2:
195195
raise IOError(
196196
"Invalid .ts file. @serieslength tag requires a int value "
197-
"(the number of dimensions for the problem)."
197+
"(the series length for the problem)."
198198
)
199199

200200
try:
201201
serieslength = int(tokens[1])
202202
except ValueError:
203203
raise IOError(
204204
"Invalid .ts file. @serieslength tag requires a int value "
205-
"(the number of dimensions for the problem)."
205+
"(the series length for the problem)."
206206
)
207207

208208
serieslength_tag = True
@@ -341,13 +341,13 @@ def load_from_ts_file(
341341
) and data_dims > 1:
342342
raise IOError(
343343
"Value mismatch in .ts file. @univariate tag is missing or True "
344-
"but data has more than one dimension."
344+
"but data has more than one channel."
345345
)
346346

347347
if dimensions_tag and dimensions != data_dims:
348348
raise IOError(
349349
f"Value mismatch in .ts file. @dimensions tag value {dimensions} "
350-
f"and read number of dimensions {data_dims} do not match."
350+
f"and read number of channels {data_dims} do not match."
351351
)
352352

353353
if serieslength_tag and serieslength != data_length:
@@ -375,19 +375,19 @@ def load_from_ts_file(
375375

376376
line = line.split(":")
377377

378-
# Does not support different number of dimensions
378+
# Does not support different number of channels
379379
read_dims = len(line) - 1 if has_labels else len(line)
380380
if read_dims != data_dims:
381381
raise IOError(
382-
"Unable to read .ts file. Inconsistent number of dimensions."
382+
"Unable to read .ts file. Inconsistent number of channels."
383383
f"Expected {data_dims} but read {read_dims} on line {data_idx}."
384384
)
385385

386386
dimensions = line[:data_dims]
387387
if not equallength:
388388
data_length = len(dimensions[0].strip().split(","))
389389

390-
# Process the data for each dimension
390+
# Process the data for each channel
391391
series = np.zeros((data_dims, data_length), dtype=X_dtype)
392392
for i in range(data_dims):
393393
series[i, :] = dimensions[i].strip().split(",")
@@ -500,7 +500,7 @@ def load_equal_minimal_japanese_vowels(
500500
stripped down version of the JapaneseVowels problem that is used in correctness
501501
tests for classification. It has been altered so all series are equal length. It
502502
loads a nine class classification problem with 20 cases for both the train and test
503-
split, 12 dimensions and a series length of 25.
503+
split, 12 channels and a series length of 25.
504504
505505
For the full dataset see
506506
http://www.timeseriesclassification.com/description.php?Dataset=JapaneseVowels
@@ -534,7 +534,7 @@ def load_minimal_japanese_vowels(
534534
This is an unequal length multivariate time series classification problem. It is a
535535
stripped down version of the JapaneseVowels problem that is used in correctness
536536
tests for classification. It loads a nine class classification problem with 20 cases
537-
for both the train and test split and 12 dimensions.
537+
for both the train and test split and 12 channels.
538538
539539
For the full dataset see
540540
http://www.timeseriesclassification.com/description.php?Dataset=JapaneseVowels

0 commit comments

Comments
 (0)