Skip to content

Commit c470099

Browse files
committed
[fix]: back to the renamed version of CROSS_VAL_FN from temporal SplitFunc typing.
1 parent ffde177 commit c470099

File tree

2 files changed

+10
-31
lines changed

2 files changed

+10
-31
lines changed

autoPyTorch/datasets/base_dataset.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import ABCMeta
2-
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast, Callable
2+
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
33

44
import numpy as np
55

@@ -15,14 +15,15 @@
1515
from autoPyTorch.datasets.resampling_strategy import (
1616
CrossValFuncs,
1717
CrossValTypes,
18+
CrossValFunc,
1819
DEFAULT_RESAMPLING_PARAMETERS,
1920
HoldoutValTypes,
2021
HoldOutFuncs,
22+
HoldOutFunc
2123
)
2224
from autoPyTorch.utils.common import FitRequirement, hash_array_or_matrix
2325

2426
BaseDatasetType = Union[Tuple[np.ndarray, np.ndarray], Dataset]
25-
SplitFunc = Callable[[Union[int, float], np.ndarray, Any], List[Tuple[np.ndarray, np.ndarray]]]
2627

2728

2829
def check_valid_data(data: Any) -> None:
@@ -102,8 +103,8 @@ def __init__(
102103
if not hasattr(train_tensors[0], 'shape'):
103104
type_check(train_tensors, val_tensors)
104105
self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors
105-
self.cross_validators: Dict[str, SplitFunc] = {}
106-
self.holdout_validators: Dict[str, SplitFunc] = {}
106+
self.cross_validators: Dict[str, CrossValFunc] = {}
107+
self.holdout_validators: Dict[str, HoldOutFunc] = {}
107108
self.rng = np.random.RandomState(seed=seed)
108109
self.shuffle = shuffle
109110
self.resampling_strategy = resampling_strategy

autoPyTorch/datasets/resampling_strategy.py

Lines changed: 5 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from enum import IntEnum
2-
from typing import Any, Dict, List, Optional, Tuple, Union, Callable
2+
from typing import Any, Dict, List, Optional, Tuple, Union
33

44
import numpy as np
55

@@ -15,21 +15,16 @@
1515
from typing_extensions import Protocol
1616

1717

18-
SplitFunc = Callable[[Union[int, float], np.ndarray, Any], List[Tuple[np.ndarray, np.ndarray]]]
19-
20-
2118
# Use callback protocol as workaround, since callable with function fields count 'self' as argument
22-
class CROSS_VAL_FN(Protocol):
23-
"""TODO: deprecate soon"""
19+
class CrossValFunc(Protocol):
2420
def __call__(self,
2521
num_splits: int,
2622
indices: np.ndarray,
2723
stratify: Optional[Any]) -> List[Tuple[np.ndarray, np.ndarray]]:
2824
...
2925

3026

31-
class HOLDOUT_FN(Protocol):
32-
"""TODO: deprecate soon"""
27+
class HoldOutFunc(Protocol):
3328
def __call__(self, val_share: float, indices: np.ndarray, stratify: Optional[Any]
3429
) -> Tuple[np.ndarray, np.ndarray]:
3530
...
@@ -104,23 +99,6 @@ def is_stratified(self) -> bool:
10499
} # type: Dict[Union[HoldoutValTypes, CrossValTypes], Dict[str, Any]]
105100

106101

107-
def get_holdout_validators(*holdout_val_types: HoldoutValTypes) -> Dict[str, HOLDOUT_FN]:
108-
"""TODO: deprecate soon"""
109-
holdout_validators = {} # type: Dict[str, HOLDOUT_FN]
110-
for holdout_val_type in holdout_val_types:
111-
holdout_val_fn = globals()[holdout_val_type.name]
112-
holdout_validators[holdout_val_type.name] = holdout_val_fn
113-
return holdout_validators
114-
115-
116-
def is_stratified(val_type: Union[str, CrossValTypes, HoldoutValTypes]) -> bool:
117-
"""TODO: deprecate soon"""
118-
if isinstance(val_type, str):
119-
return val_type.lower().startswith("stratified")
120-
else:
121-
return val_type.name.lower().startswith("stratified")
122-
123-
124102
class HoldOutFuncs():
125103
@staticmethod
126104
def holdout_validation(val_share: float, indices: np.ndarray, **kwargs: Any) -> Tuple[np.ndarray, np.ndarray]:
@@ -134,7 +112,7 @@ def stratified_holdout_validation(val_share: float, indices: np.ndarray, **kwarg
134112
return train, val
135113

136114
@classmethod
137-
def get_holdout_validators(cls, *holdout_val_types: Tuple[HoldoutValTypes]) -> Dict[str, SplitFunc]:
115+
def get_holdout_validators(cls, *holdout_val_types: Tuple[HoldoutValTypes]) -> Dict[str, HoldOutFunc]:
138116

139117
holdout_validators = {
140118
holdout_val_type.name: getattr(cls, holdout_val_type.name)
@@ -198,7 +176,7 @@ def time_series_cross_validation(num_splits: int, indices: np.ndarray, **kwargs:
198176
return splits
199177

200178
@classmethod
201-
def get_cross_validators(cls, *cross_val_types: CrossValTypes) -> Dict[str, SplitFunc]:
179+
def get_cross_validators(cls, *cross_val_types: CrossValTypes) -> Dict[str, CrossValFunc]:
202180
cross_validators = {
203181
cross_val_type.name: getattr(cls, cross_val_type.name)
204182
for cross_val_type in cross_val_types

0 commit comments

Comments
 (0)