1- from enum import Enum
2- from functools import partial
1+ from enum import IntEnum
32from typing import List , NamedTuple , Optional , Tuple , Union
43
54import numpy as np
@@ -92,7 +91,7 @@ def time_series(
9291 return splits
9392
9493
95- class CrossValTypes (Enum ):
94+ class CrossValTypes (IntEnum ):
9695 """The type of cross validation
9796
9897 This class is used to specify the cross validation function
@@ -107,11 +106,11 @@ class CrossValTypes(Enum):
107106 >>> for cross_val_type in CrossValTypes:
108107 print(cross_val_type.name, cross_val_type.value)
109108
110- k_fold_cross_validation functools.partial(<function CrossValFuncs.k_fold_cross_validation at ...>)
111- time_series <function CrossValFuncs.time_series>
109+ k_fold_cross_validation 100
110+ time_series 101
112111 """
113- k_fold_cross_validation = partial ( CrossValFuncs . k_fold_cross_validation )
114- time_series = partial ( CrossValFuncs . time_series )
112+ k_fold_cross_validation = 100
113+ time_series = 101
115114
116115 def __call__ (
117116 self ,
@@ -140,8 +139,9 @@ def __call__(
140139
141140 default_num_splits = _ResamplingStrategyArgs ().num_splits
142141 num_splits = num_splits if num_splits is not None else default_num_splits
142+ split_fn = getattr (CrossValFuncs , self .name )
143143
144- return self . value (
144+ return split_fn (
145145 random_state = random_state if shuffle else None ,
146146 num_splits = num_splits ,
147147 indices = indices ,
@@ -150,7 +150,7 @@ def __call__(
150150 )
151151
152152
153- class HoldoutValTypes (Enum ):
153+ class HoldoutValTypes (IntEnum ):
154154 """The type of holdout validation
155155
156156 This class is used to specify the holdout validation function
@@ -164,7 +164,7 @@ class HoldoutValTypes(Enum):
164164
165165 >>> print(holdout_type.value)
166166
167- functools.partial(<function HoldoutValTypes.holdout_validation at ...>)
167+ 0
168168
169169 >>> for holdout_type in HoldoutValTypes:
170170 print(holdout_type.name)
@@ -174,7 +174,7 @@ class HoldoutValTypes(Enum):
174174 Additionally, HoldoutValTypes.<function> can be called directly.
175175 """
176176
177- holdout_validation = partial ( HoldoutFuncs . holdout_validation )
177+ holdout_validation = 0
178178
179179 def __call__ (
180180 self ,
@@ -203,8 +203,9 @@ def __call__(
203203
204204 default_val_share = _ResamplingStrategyArgs ().val_share
205205 val_share = val_share if val_share is not None else default_val_share
206+ split_fn = getattr (HoldoutFuncs , self .name )
206207
207- return self . value (
208+ return split_fn (
208209 random_state = random_state if shuffle else None ,
209210 val_share = val_share ,
210211 indices = indices ,
0 commit comments