Skip to content

Commit 3473933

Browse files
committed
[fix] Pull upstream and fix the incompatible codes
Preparation before merging to the refactor_development.
1 parent 8d68ab0 commit 3473933

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

autoPyTorch/datasets/base_dataset.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import ABCMeta
2-
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
2+
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast, Callable
33

44
import numpy as np
55

@@ -22,6 +22,7 @@
2222
from autoPyTorch.utils.common import FitRequirement, hash_array_or_matrix
2323

2424
BaseDatasetType = Union[Tuple[np.ndarray, np.ndarray], Dataset]
25+
SplitFunc = Callable[[Union[int, float], np.ndarray, Any], List[Tuple[np.ndarray, np.ndarray]]]
2526

2627

2728
def check_valid_data(data: Any) -> None:

autoPyTorch/datasets/resampling_strategy.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
from typing_extensions import Protocol
1717

1818

19+
SplitFunc = Callable[[Union[int, float], np.ndarray, Any], List[Tuple[np.ndarray, np.ndarray]]]
20+
21+
1922
# Use callback protocol as workaround, since callable with function fields count 'self' as argument
2023
class CrossValFunc(Protocol):
2124
"""TODO: This class is not required anymore, because CrossValTypes class does not require get_validators()"""
@@ -150,6 +153,11 @@ def get_validators(*choices: CrossValFunc):
150153
"""TODO: to be compatible, it is here now, but will be deprecated soon."""
151154
return {choice.name: choice.value for choice in choices}
152155

156+
def is_stratified(self) -> bool:
157+
stratified = [self.stratified_k_fold_cross_validation,
158+
self.stratified_shuffle_split_cross_validation]
159+
return getattr(self, self.name) in stratified
160+
153161

154162
class HoldoutValTypes(Enum):
155163
"""The type of hold out validation (refer to CrossValTypes' doc-string)"""

0 commit comments

Comments
 (0)