Skip to content

Commit 910e7d4

Browse files
committed
[fix] Fix most test cases
1 parent eee3b1c commit 910e7d4

File tree

1 file changed

+13
-12
lines changed

1 file changed

+13
-12
lines changed

autoPyTorch/datasets/resampling_strategy.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from enum import Enum
2-
from functools import partial
1+
from enum import IntEnum
32
from typing import List, NamedTuple, Optional, Tuple, Union
43

54
import numpy as np
@@ -92,7 +91,7 @@ def time_series(
9291
return splits
9392

9493

95-
class CrossValTypes(Enum):
94+
class CrossValTypes(IntEnum):
9695
"""The type of cross validation
9796
9897
This class is used to specify the cross validation function
@@ -107,11 +106,11 @@ class CrossValTypes(Enum):
107106
>>> for cross_val_type in CrossValTypes:
108107
print(cross_val_type.name, cross_val_type.value)
109108
110-
k_fold_cross_validation functools.partial(<function CrossValFuncs.k_fold_cross_validation at ...>)
111-
time_series <function CrossValFuncs.time_series>
109+
k_fold_cross_validation 100
110+
time_series 101
112111
"""
113-
k_fold_cross_validation = partial(CrossValFuncs.k_fold_cross_validation)
114-
time_series = partial(CrossValFuncs.time_series)
112+
k_fold_cross_validation = 100
113+
time_series = 101
115114

116115
def __call__(
117116
self,
@@ -140,8 +139,9 @@ def __call__(
140139

141140
default_num_splits = _ResamplingStrategyArgs().num_splits
142141
num_splits = num_splits if num_splits is not None else default_num_splits
142+
split_fn = getattr(CrossValFuncs, self.name)
143143

144-
return self.value(
144+
return split_fn(
145145
random_state=random_state if shuffle else None,
146146
num_splits=num_splits,
147147
indices=indices,
@@ -150,7 +150,7 @@ def __call__(
150150
)
151151

152152

153-
class HoldoutValTypes(Enum):
153+
class HoldoutValTypes(IntEnum):
154154
"""The type of holdout validation
155155
156156
This class is used to specify the holdout validation function
@@ -164,7 +164,7 @@ class HoldoutValTypes(Enum):
164164
165165
>>> print(holdout_type.value)
166166
167-
functools.partial(<function HoldoutValTypes.holdout_validation at ...>)
167+
0
168168
169169
>>> for holdout_type in HoldoutValTypes:
170170
print(holdout_type.name)
@@ -174,7 +174,7 @@ class HoldoutValTypes(Enum):
174174
Additionally, HoldoutValTypes.<function> can be called directly.
175175
"""
176176

177-
holdout_validation = partial(HoldoutFuncs.holdout_validation)
177+
holdout_validation = 0
178178

179179
def __call__(
180180
self,
@@ -203,8 +203,9 @@ def __call__(
203203

204204
default_val_share = _ResamplingStrategyArgs().val_share
205205
val_share = val_share if val_share is not None else default_val_share
206+
split_fn = getattr(HoldoutFuncs, self.name)
206207

207-
return self.value(
208+
return split_fn(
208209
random_state=random_state if shuffle else None,
209210
val_share=val_share,
210211
indices=indices,

0 commit comments

Comments
 (0)