11from enum import IntEnum
2- from typing import Any , Dict , List , Optional , Tuple , Union , Callable
2+ from typing import Any , Dict , List , Optional , Tuple , Union
33
44import numpy as np
55
1515from typing_extensions import Protocol
1616
1717
18- SplitFunc = Callable [[Union [int , float ], np .ndarray , Any ], List [Tuple [np .ndarray , np .ndarray ]]]
19-
20-
2118# Use callback protocol as workaround, since callable with function fields count 'self' as argument
22- class CROSS_VAL_FN (Protocol ):
23- """TODO: deprecate soon"""
19+ class CrossValFunc (Protocol ):
2420 def __call__ (self ,
2521 num_splits : int ,
2622 indices : np .ndarray ,
2723 stratify : Optional [Any ]) -> List [Tuple [np .ndarray , np .ndarray ]]:
2824 ...
2925
3026
31- class HOLDOUT_FN (Protocol ):
32- """TODO: deprecate soon"""
27+ class HoldOutFunc (Protocol ):
3328 def __call__ (self , val_share : float , indices : np .ndarray , stratify : Optional [Any ]
3429 ) -> Tuple [np .ndarray , np .ndarray ]:
3530 ...
@@ -104,23 +99,6 @@ def is_stratified(self) -> bool:
10499} # type: Dict[Union[HoldoutValTypes, CrossValTypes], Dict[str, Any]]
105100
106101
107- def get_holdout_validators (* holdout_val_types : HoldoutValTypes ) -> Dict [str , HOLDOUT_FN ]:
108- """TODO: deprecate soon"""
109- holdout_validators = {} # type: Dict[str, HOLDOUT_FN]
110- for holdout_val_type in holdout_val_types :
111- holdout_val_fn = globals ()[holdout_val_type .name ]
112- holdout_validators [holdout_val_type .name ] = holdout_val_fn
113- return holdout_validators
114-
115-
116- def is_stratified (val_type : Union [str , CrossValTypes , HoldoutValTypes ]) -> bool :
117- """TODO: deprecate soon"""
118- if isinstance (val_type , str ):
119- return val_type .lower ().startswith ("stratified" )
120- else :
121- return val_type .name .lower ().startswith ("stratified" )
122-
123-
124102class HoldOutFuncs ():
125103 @staticmethod
126104 def holdout_validation (val_share : float , indices : np .ndarray , ** kwargs : Any ) -> Tuple [np .ndarray , np .ndarray ]:
@@ -134,7 +112,7 @@ def stratified_holdout_validation(val_share: float, indices: np.ndarray, **kwarg
134112 return train , val
135113
136114 @classmethod
137- def get_holdout_validators (cls , * holdout_val_types : Tuple [HoldoutValTypes ]) -> Dict [str , SplitFunc ]:
115+ def get_holdout_validators (cls , * holdout_val_types : Tuple [HoldoutValTypes ]) -> Dict [str , HoldOutFunc ]:
138116
139117 holdout_validators = {
140118 holdout_val_type .name : getattr (cls , holdout_val_type .name )
@@ -198,7 +176,7 @@ def time_series_cross_validation(num_splits: int, indices: np.ndarray, **kwargs:
198176 return splits
199177
200178 @classmethod
201- def get_cross_validators (cls , * cross_val_types : CrossValTypes ) -> Dict [str , SplitFunc ]:
179+ def get_cross_validators (cls , * cross_val_types : CrossValTypes ) -> Dict [str , CrossValFunc ]:
202180 cross_validators = {
203181 cross_val_type .name : getattr (cls , cross_val_type .name )
204182 for cross_val_type in cross_val_types
0 commit comments