Commit 2d2ebb8

[refactor] Change files so that we can see the difference easier
1 parent bef4323 commit 2d2ebb8

File tree

2 files changed: +73 −73 lines


autoPyTorch/datasets/base_dataset.py

Lines changed: 2 additions & 2 deletions

@@ -200,7 +200,7 @@ def __len__(self) -> int:
     def _get_indices(self) -> np.ndarray:
         return self.random_state.permutation(len(self)) if self.shuffle else np.arange(len(self))

-    def _process_resampling_strategy_args(self) -> None:
+    def _check_resampling_strategy_args(self) -> None:
         if not any(isinstance(self.resampling_strategy, val_type)
                    for val_type in [HoldoutValTypes, CrossValTypes]):
             raise ValueError(f"resampling_strategy {self.resampling_strategy} is not supported.")
@@ -231,7 +231,7 @@ def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]]]:
             (List[Tuple[List[int], List[int]]]): splits in the [train_indices, val_indices] format
         """
         # check if the requirements are met and if we can get splits
-        self._process_resampling_strategy_args()
+        self._check_resampling_strategy_args()

         labels_to_stratify = self.train_tensors[-1] if self.is_stratify else None

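For orientation, here is a minimal standalone sketch of what the renamed method does: it only validates the type of the resampling strategy and raises on anything unsupported, which is what the new name _check_resampling_strategy_args conveys. The enum members below are placeholders for the real autoPyTorch definitions, not copies of them.

from enum import IntEnum


# Stand-ins for autoPyTorch's HoldoutValTypes / CrossValTypes enums
# (member names and values here are illustrative only).
class HoldoutValTypes(IntEnum):
    holdout_validation = 1


class CrossValTypes(IntEnum):
    k_fold_cross_validation = 2


def check_resampling_strategy(resampling_strategy) -> None:
    # Same pattern as the diff above: accept only members of the two enums.
    if not any(isinstance(resampling_strategy, val_type)
               for val_type in [HoldoutValTypes, CrossValTypes]):
        raise ValueError(f"resampling_strategy {resampling_strategy} is not supported.")


check_resampling_strategy(CrossValTypes.k_fold_cross_validation)  # passes, returns None
check_resampling_strategy("holdout")  # raises ValueError: a plain string is not supported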

autoPyTorch/datasets/resampling_strategy.py

Lines changed: 71 additions & 71 deletions

@@ -22,77 +22,6 @@ class _ResamplingStrategyArgs(NamedTuple):
     stratify: bool = False


-class HoldoutFuncs():
-    @staticmethod
-    def holdout_validation(
-        indices: np.ndarray,
-        random_state: Optional[np.random.RandomState] = None,
-        val_share: Optional[float] = None,
-        shuffle: bool = False,
-        labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None
-    ) -> List[Tuple[np.ndarray, np.ndarray]]:
-
-        """ SKLearn requires shuffle=True for stratify """
-        train, val = train_test_split(
-            indices, test_size=val_share,
-            shuffle=shuffle if labels_to_stratify is None else True,
-            random_state=random_state,
-            stratify=labels_to_stratify
-        )
-        return [(train, val)]
-
-
-class CrossValFuncs():
-    # (shuffle, is_stratify) -> split_fn
-    _args2split_fn = {
-        (True, True): StratifiedShuffleSplit,
-        (True, False): ShuffleSplit,
-        (False, True): StratifiedKFold,
-        (False, False): KFold,
-    }
-
-    @staticmethod
-    def k_fold_cross_validation(
-        indices: np.ndarray,
-        random_state: Optional[np.random.RandomState] = None,
-        num_splits: Optional[int] = None,
-        shuffle: bool = False,
-        labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None
-    ) -> List[Tuple[np.ndarray, np.ndarray]]:
-        """
-        Returns:
-            splits (List[Tuple[List, List]]): list of tuples of training and validation indices
-        """
-
-        split_fn = CrossValFuncs._args2split_fn[(shuffle, labels_to_stratify is not None)]
-        cv = split_fn(n_splits=num_splits, random_state=random_state)
-        splits = list(cv.split(indices))
-        return splits
-
-    @staticmethod
-    def time_series(
-        indices: np.ndarray,
-        random_state: Optional[np.random.RandomState] = None,
-        num_splits: Optional[int] = None,
-        shuffle: bool = False,
-        labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None
-    ) -> List[Tuple[np.ndarray, np.ndarray]]:
-        """
-        Returns train and validation indices respecting the temporal ordering of the data.
-
-        Examples:
-            >>> indices = np.array([0, 1, 2, 3])
-            >>> CrossValFuncs.time_series_cross_validation(3, indices)
-                [([0], [1]),
-                 ([0, 1], [2]),
-                 ([0, 1, 2], [3])]
-
-        """
-        cv = TimeSeriesSplit(n_splits=num_splits)
-        splits = list(cv.split(indices))
-        return splits
-
-
 class CrossValTypes(IntEnum):
     """The type of cross validation

@@ -214,3 +143,74 @@ def __call__(
             shuffle=shuffle,
             labels_to_stratify=labels_to_stratify
         )
+
+
+class HoldoutFuncs():
+    @staticmethod
+    def holdout_validation(
+        indices: np.ndarray,
+        random_state: Optional[np.random.RandomState] = None,
+        val_share: Optional[float] = None,
+        shuffle: bool = False,
+        labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None
+    ) -> List[Tuple[np.ndarray, np.ndarray]]:
+
+        """ SKLearn requires shuffle=True for stratify """
+        train, val = train_test_split(
+            indices, test_size=val_share,
+            shuffle=shuffle if labels_to_stratify is None else True,
+            random_state=random_state,
+            stratify=labels_to_stratify
+        )
+        return [(train, val)]
+
+
+class CrossValFuncs():
+    # (shuffle, is_stratify) -> split_fn
+    _args2split_fn = {
+        (True, True): StratifiedShuffleSplit,
+        (True, False): ShuffleSplit,
+        (False, True): StratifiedKFold,
+        (False, False): KFold,
+    }
+
+    @staticmethod
+    def k_fold_cross_validation(
+        indices: np.ndarray,
+        random_state: Optional[np.random.RandomState] = None,
+        num_splits: Optional[int] = None,
+        shuffle: bool = False,
+        labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None
+    ) -> List[Tuple[np.ndarray, np.ndarray]]:
+        """
+        Returns:
+            splits (List[Tuple[List, List]]): list of tuples of training and validation indices
+        """
+
+        split_fn = CrossValFuncs._args2split_fn[(shuffle, labels_to_stratify is not None)]
+        cv = split_fn(n_splits=num_splits, random_state=random_state)
+        splits = list(cv.split(indices))
+        return splits
+
+    @staticmethod
+    def time_series(
+        indices: np.ndarray,
+        random_state: Optional[np.random.RandomState] = None,
+        num_splits: Optional[int] = None,
+        shuffle: bool = False,
+        labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None
+    ) -> List[Tuple[np.ndarray, np.ndarray]]:
+        """
+        Returns train and validation indices respecting the temporal ordering of the data.
+
+        Examples:
+            >>> indices = np.array([0, 1, 2, 3])
+            >>> CrossValFuncs.time_series_cross_validation(3, indices)
+                [([0], [1]),
+                 ([0, 1], [2]),
+                 ([0, 1, 2], [3])]
+
+        """
+        cv = TimeSeriesSplit(n_splits=num_splits)
+        splits = list(cv.split(indices))
+        return splits
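
For reference, a minimal usage sketch of the logic inside the two relocated helper classes, written directly against scikit-learn (numpy and scikit-learn assumed installed; the dispatch table mirrors CrossValFuncs._args2split_fn above, and the concrete numbers are illustrative):

import numpy as np
from sklearn.model_selection import (
    KFold, ShuffleSplit, StratifiedKFold, StratifiedShuffleSplit, train_test_split
)

indices = np.arange(10)

# Holdout, as in HoldoutFuncs.holdout_validation: a single (train, val) pair,
# here with 20% of the indices held out and no shuffling.
train, val = train_test_split(indices, test_size=0.2, shuffle=False)
print(train, val)  # [0 1 2 3 4 5 6 7] [8 9]

# k-fold, as in CrossValFuncs.k_fold_cross_validation: the (shuffle, stratify)
# pair selects which scikit-learn splitter to use.
args2split_fn = {
    (True, True): StratifiedShuffleSplit,
    (True, False): ShuffleSplit,
    (False, True): StratifiedKFold,
    (False, False): KFold,
}
cv = args2split_fn[(False, False)](n_splits=5)
splits = list(cv.split(indices))  # five (train_indices, val_indices) tuples
print(len(splits))  # 5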
