Skip to content

Commit fb4f7ab

Browse files
Interval classifiers (#5)
* actually change python version * dummy classifiers and sklearn lower bound change * test fix * test fix * dev * early sklearn version fixes * all interval classifiers
1 parent d0a8c6d commit fb4f7ab

File tree

20 files changed

+617
-261
lines changed

20 files changed

+617
-261
lines changed

pyproject.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "tsml"
7-
version = "0.0.4"
7+
version = "0.0.5"
88
description = "A toolkit for time series machine learning algorithms."
99
authors = [
1010
{name = "Matthew Middlehurst", email = "m.middlehurst@uea.ac.uk"},
@@ -42,8 +42,9 @@ dependencies = [
4242

4343
[project.optional-dependencies]
4444
extras = [
45-
"pycatch22",
46-
"pyfftw"
45+
"pycatch22>=0.4.2",
46+
"pyfftw>=0.12.0",
47+
"statsmodels>=0.12.1",
4748
]
4849
dev = [
4950
"pre-commit",

tsml/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
# -*- coding: utf-8 -*-
22
"""tsml."""
33

4-
__version__ = "0.0.4"
4+
__version__ = "0.0.5"

tsml/dummy/_dummy.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def fit(self, X, y):
104104
for index, classVal in enumerate(self.classes_):
105105
self.class_dictionary_[classVal] = index
106106

107-
if len(self.classes_) == 1:
107+
if self.n_classes_ == 1:
108108
return self
109109

110110
self._clf = SklearnDummyClassifier(
@@ -120,12 +120,12 @@ def predict(self, X) -> np.ndarray:
120120
""""""
121121
check_is_fitted(self)
122122

123+
X = self._validate_data(X=X, reset=False, ensure_min_series_length=1)
124+
123125
# treat case of single class seen in fit
124126
if self.n_classes_ == 1:
125127
return np.repeat(list(self.class_dictionary_.keys()), X.shape[0], axis=0)
126128

127-
X = self._validate_data(X=X, reset=False, ensure_min_series_length=1)
128-
129129
return self._clf.predict(np.zeros(X.shape))
130130

131131
def predict_proba(self, X) -> np.ndarray:

tsml/feature_based/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@
66
"Catch22Regressor",
77
]
88

9-
from tsml.feature_based._catch22_classifier import Catch22Classifier, Catch22Regressor
9+
from tsml.feature_based._catch22 import Catch22Classifier, Catch22Regressor

tsml/feature_based/_catch22_classifier.py renamed to tsml/feature_based/_catch22.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,9 @@ def fit(self, X, y):
123123
for index, classVal in enumerate(self.classes_):
124124
self.class_dictionary_[classVal] = index
125125

126+
if self.n_classes_ == 1:
127+
return self
128+
126129
self._n_jobs = check_n_jobs(self.n_jobs)
127130

128131
self._transformer = Catch22Transformer(
@@ -164,6 +167,10 @@ def predict(self, X) -> np.ndarray:
164167
"""
165168
check_is_fitted(self)
166169

170+
# treat case of single class seen in fit
171+
if self.n_classes_ == 1:
172+
return np.repeat(list(self.class_dictionary_.keys()), X.shape[0], axis=0)
173+
167174
X = self._validate_data(X=X, reset=False)
168175

169176
return self._estimator.predict(self._transformer.transform(X))
@@ -183,6 +190,10 @@ def predict_proba(self, X) -> np.ndarray:
183190
"""
184191
check_is_fitted(self)
185192

193+
# treat case of single class seen in fit
194+
if self.n_classes_ == 1:
195+
return np.repeat([[1]], X.shape[0], axis=0)
196+
186197
X = self._validate_data(X=X, reset=False)
187198

188199
m = getattr(self._estimator, "predict_proba", None)

tsml/interval_based/__init__.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,28 @@
55
"BaseIntervalForest",
66
"CIFClassifier",
77
"CIFRegressor",
8-
# "DrCIFClassifier",
9-
# "DrCIFRegressor",
8+
"DrCIFClassifier",
9+
"DrCIFRegressor",
1010
"IntervalForestClassifier",
1111
"IntervalForestRegressor",
1212
"RandomIntervalClassifier",
1313
"RandomIntervalRegressor",
1414
"SupervisedIntervalClassifier",
15-
# "RISEClassifier",
16-
# "RISERegressor",
17-
# "STSFClassifier",
18-
# "RSTSFClassifier",
15+
"RISEClassifier",
16+
"RISERegressor",
17+
"STSFClassifier",
18+
"RSTSFClassifier",
1919
"TSFClassifier",
2020
"TSFRegressor",
2121
]
2222

2323
from tsml.interval_based._base import BaseIntervalForest
24-
from tsml.interval_based._cif import CIFClassifier, CIFRegressor
24+
from tsml.interval_based._cif import (
25+
CIFClassifier,
26+
CIFRegressor,
27+
DrCIFClassifier,
28+
DrCIFRegressor,
29+
)
2530
from tsml.interval_based._interval_forest import (
2631
IntervalForestClassifier,
2732
IntervalForestRegressor,
@@ -31,4 +36,6 @@
3136
RandomIntervalRegressor,
3237
SupervisedIntervalClassifier,
3338
)
39+
from tsml.interval_based._rise import RISEClassifier, RISERegressor
40+
from tsml.interval_based._stsf import RSTSFClassifier, STSFClassifier
3441
from tsml.interval_based._tsf import TSFClassifier, TSFRegressor

0 commit comments

Comments
 (0)