1- from enum import IntEnum
1+ from enum import Enum
2+ from functools import partial
23from typing import Any , Dict , List , Optional , Tuple , Union
34
45import numpy as np
1718
# Use a callback protocol as a workaround, since a callable stored as a class field counts 'self' as an argument
class CrossValFunc(Protocol):
    """Structural type of a cross-validation split function.

    Matches any callable that takes ``(num_splits, indices, stratify)`` and
    returns a list of ``(ndarray, ndarray)`` tuples.

    TODO: This class is not required anymore, because CrossValTypes class
    does not require get_validators()
    """

    def __call__(
        self,
        num_splits: int,
        indices: np.ndarray,
        stratify: Optional[Any],
    ) -> List[Tuple[np.ndarray, np.ndarray]]:
        """Signature only; implementations produce the list of index splits."""
        ...
2527
2628
class HoldoutValFunc(Protocol):
    """Structural type of a hold-out split function.

    Matches any callable that takes ``(val_share, indices, stratify)`` and
    returns a single ``(ndarray, ndarray)`` tuple.
    """

    def __call__(
        self,
        val_share: float,
        indices: np.ndarray,
        stratify: Optional[Any],
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Signature only; implementations produce one pair of index arrays."""
        ...
3133
3234
33- class CrossValTypes (IntEnum ):
34- """The type of cross validation
35-
36- This class is used to specify the cross validation function
37- and is not supposed to be instantiated.
38-
39- Examples: This class is supposed to be used as follows
40- >>> cv_type = CrossValTypes.k_fold_cross_validation
41- >>> print(cv_type.name)
42-
43- k_fold_cross_validation
44-
45- >>> for cross_val_type in CrossValTypes:
46- print(cross_val_type.name, cross_val_type.value)
47-
48- stratified_k_fold_cross_validation 1
49- k_fold_cross_validation 2
50- stratified_shuffle_split_cross_validation 3
51- shuffle_split_cross_validation 4
52- time_series_cross_validation 5
53- """
54- stratified_k_fold_cross_validation = 1
55- k_fold_cross_validation = 2
56- stratified_shuffle_split_cross_validation = 3
57- shuffle_split_cross_validation = 4
58- time_series_cross_validation = 5
59-
60- def is_stratified (self ) -> bool :
61- stratified = [self .stratified_k_fold_cross_validation ,
62- self .stratified_shuffle_split_cross_validation ]
63- return getattr (self , self .name ) in stratified
64-
65-
66- class HoldoutValTypes (IntEnum ):
67- """TODO: change to enum using functools.partial"""
68- """The type of hold out validation (refer to CrossValTypes' doc-string)"""
69- holdout_validation = 6
70- stratified_holdout_validation = 7
71-
72- def is_stratified (self ) -> bool :
73- stratified = [self .stratified_holdout_validation ]
74- return getattr (self , self .name ) in stratified
75-
76-
77- """TODO: deprecate soon"""
78- RESAMPLING_STRATEGIES = [CrossValTypes , HoldoutValTypes ]
79-
80- """TODO: deprecate soon"""
81- DEFAULT_RESAMPLING_PARAMETERS = {
82- HoldoutValTypes .holdout_validation : {
83- 'val_share' : 0.33 ,
84- },
85- HoldoutValTypes .stratified_holdout_validation : {
86- 'val_share' : 0.33 ,
87- },
88- CrossValTypes .k_fold_cross_validation : {
89- 'num_splits' : 3 ,
90- },
91- CrossValTypes .stratified_k_fold_cross_validation : {
92- 'num_splits' : 3 ,
93- },
94- CrossValTypes .shuffle_split_cross_validation : {
95- 'num_splits' : 3 ,
96- },
97- CrossValTypes .time_series_cross_validation : {
98- 'num_splits' : 3 ,
99- },
100- } # type: Dict[Union[HoldoutValTypes, CrossValTypes], Dict[str, Any]]
101-
102-
class HoldoutValFuncs:
    """Namespace of hold-out split functions.

    Every function takes ``(val_share, indices, stratify)`` and returns a
    ``(train, val)`` pair produced by ``train_test_split``.
    """

    @staticmethod
    def holdout_validation(val_share: float, indices: np.ndarray, stratify: Optional[Any] = None) \
            -> Tuple[np.ndarray, np.ndarray]:
        """Split ``indices`` in their given order, without shuffling.

        Args:
            val_share: Passed to ``train_test_split`` as ``test_size``.
            indices: The indices to split.
            stratify: Accepted so all hold-out functions share one
                signature; it is not used here.

        Returns:
            The ``(train, val)`` pair returned by ``train_test_split``.
        """
        train, val = train_test_split(indices, test_size=val_share, shuffle=False)
        return train, val

    @staticmethod
    def stratified_holdout_validation(val_share: float, indices: np.ndarray, stratify: Optional[Any] = None) \
            -> Tuple[np.ndarray, np.ndarray]:
        """Split ``indices`` using ``stratify`` as the stratification labels.

        Args:
            val_share: Passed to ``train_test_split`` as ``test_size``.
            indices: The indices to split.
            stratify: Labels forwarded to ``train_test_split``. When it is
                None the split is the same ordered, unshuffled split that
                ``holdout_validation`` produces.

        Returns:
            The ``(train, val)`` pair returned by ``train_test_split``.
        """
        # scikit-learn raises ValueError when `stratify` is combined with
        # shuffle=False, so shuffling has to be enabled whenever labels are
        # supplied. No random_state is passed, so a stratified split is not
        # reproducible across calls.
        shuffle = stratify is not None
        train, val = train_test_split(indices, test_size=val_share, shuffle=shuffle, stratify=stratify)
        return train, val
11447
115- @classmethod
116- def get_holdout_validators (cls , * holdout_val_types : Tuple [HoldoutValTypes ]) -> Dict [str , HoldOutFunc ]:
117-
118- holdout_validators = {
119- holdout_val_type .name : getattr (cls , holdout_val_type .name )
120- for holdout_val_type in holdout_val_types
121- }
122- return holdout_validators
123-
12448
12549class CrossValFuncs ():
12650 @staticmethod
127- def shuffle_split_cross_validation (num_splits : int , indices : np .ndarray , ** kwargs : Any ) \
51+ def shuffle_split_cross_validation (num_splits : int , indices : np .ndarray , stratify : Optional [ Any ] = None ) \
12852 -> List [Tuple [np .ndarray , np .ndarray ]]:
12953 cv = ShuffleSplit (n_splits = num_splits )
13054 splits = list (cv .split (indices ))
13155 return splits
13256
13357 @staticmethod
134- def stratified_shuffle_split_cross_validation (num_splits : int , indices : np .ndarray , ** kwargs : Any ) \
58+ def stratified_shuffle_split_cross_validation (num_splits : int , indices : np .ndarray ,
59+ stratify : Optional [Any ] = None ) \
13560 -> List [Tuple [np .ndarray , np .ndarray ]]:
13661 cv = StratifiedShuffleSplit (n_splits = num_splits )
137- splits = list (cv .split (indices , kwargs [ " stratify" ] ))
62+ splits = list (cv .split (indices , stratify ))
13863 return splits
13964
14065 @staticmethod
141- def stratified_k_fold_cross_validation (num_splits : int , indices : np .ndarray , ** kwargs : Any ) \
66+ def stratified_k_fold_cross_validation (num_splits : int , indices : np .ndarray , stratify : Optional [ Any ] = None ) \
14267 -> List [Tuple [np .ndarray , np .ndarray ]]:
14368 cv = StratifiedKFold (n_splits = num_splits )
144- splits = list (cv .split (indices , kwargs [ " stratify" ] ))
69+ splits = list (cv .split (indices , stratify ))
14570 return splits
14671
14772 @staticmethod
148- def k_fold_cross_validation (num_splits : int , indices : np .ndarray , ** kwargs : Any ) \
73+ def k_fold_cross_validation (num_splits : int , indices : np .ndarray , stratify : Optional [ Any ] = None ) \
14974 -> List [Tuple [np .ndarray , np .ndarray ]]:
15075 """
15176 Standard k fold cross validation.
@@ -159,7 +84,7 @@ def k_fold_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any)
15984 return splits
16085
16186 @staticmethod
162- def time_series_cross_validation (num_splits : int , indices : np .ndarray , ** kwargs : Any ) \
87+ def time_series_cross_validation (num_splits : int , indices : np .ndarray , stratify : Optional [ Any ] = None ) \
16388 -> List [Tuple [np .ndarray , np .ndarray ]]:
16489 """
16590 Returns train and validation indices respecting the temporal ordering of the data.
@@ -176,10 +101,96 @@ def time_series_cross_validation(num_splits: int, indices: np.ndarray, **kwargs:
176101 splits = list (cv .split (indices ))
177102 return splits
178103
179- @classmethod
180- def get_cross_validators (cls , * cross_val_types : CrossValTypes ) -> Dict [str , CrossValFunc ]:
181- cross_validators = {
182- cross_val_type .name : getattr (cls , cross_val_type .name )
183- for cross_val_type in cross_val_types
184- }
185- return cross_validators
104+
class CrossValTypes(Enum):
    """The type of cross validation

    This class is used to specify the cross validation function
    and is not supposed to be instantiated.

    Each member's value is the matching ``CrossValFuncs`` function wrapped in
    ``functools.partial``, and a member can be called directly to run it.

    Examples: This class is supposed to be used as follows
    >>> cv_type = CrossValTypes.k_fold_cross_validation
    >>> print(cv_type.name)
    k_fold_cross_validation

    >>> print(cv_type.value)
    functools.partial(<function CrossValFuncs.k_fold_cross_validation at ...>)

    >>> for cross_val_type in CrossValTypes:
    ...     print(cross_val_type.name)
    stratified_k_fold_cross_validation
    k_fold_cross_validation
    stratified_shuffle_split_cross_validation
    shuffle_split_cross_validation
    time_series_cross_validation

    >>> splits = cv_type(num_splits=3, indices=indices, stratify=None)
    """
    # A plain function assigned in an Enum body is treated as a method, not
    # as a member; wrapping it in `partial` makes it a member value.
    # NOTE(review): newer Python versions warn that `partial` will become a
    # method descriptor and recommend `enum.member()` -- confirm against the
    # supported Python versions.
    stratified_k_fold_cross_validation = partial(CrossValFuncs.stratified_k_fold_cross_validation)
    k_fold_cross_validation = partial(CrossValFuncs.k_fold_cross_validation)
    stratified_shuffle_split_cross_validation = partial(CrossValFuncs.stratified_shuffle_split_cross_validation)
    shuffle_split_cross_validation = partial(CrossValFuncs.shuffle_split_cross_validation)
    time_series_cross_validation = partial(CrossValFuncs.time_series_cross_validation)

    def is_stratified(self) -> bool:
        """Return True if this member is one of the two stratified strategies."""
        cls = type(self)
        return self in (
            cls.stratified_k_fold_cross_validation,
            cls.stratified_shuffle_split_cross_validation,
        )

    def __call__(self, num_splits: int, indices: np.ndarray, stratify: Optional[Any] = None
                 ) -> List[Tuple[np.ndarray, np.ndarray]]:
        """Run the split function stored in this member.

        Args:
            num_splits: Forwarded as ``num_splits``.
            indices: Forwarded as ``indices``.
            stratify: Forwarded as ``stratify``.

        Returns:
            Whatever the wrapped function returns (a list of index-array pairs).
        """
        # The result must be returned; dropping it made every call yield None.
        return self.value(num_splits=num_splits, indices=indices, stratify=stratify)

    @staticmethod
    def get_validators(*choices: 'CrossValTypes') -> Dict[str, CrossValFunc]:
        """Map each given member's name to its wrapped split function.

        TODO: to be compatible, it is here now, but will be deprecated soon.
        """
        return {choice.name: choice.value for choice in choices}
152+
153+
class HoldoutValTypes(Enum):
    """The type of hold out validation

    Each member's value is the matching ``HoldoutValFuncs`` function wrapped
    in ``functools.partial``, and a member can be called directly to run it.

    Examples:
    >>> holdout_type = HoldoutValTypes.holdout_validation
    >>> print(holdout_type.name)
    holdout_validation

    >>> train, val = holdout_type(val_share=0.33, indices=indices, stratify=None)
    """
    # A plain function assigned in an Enum body is treated as a method, not
    # as a member; wrapping it in `partial` makes it a member value.
    holdout_validation = partial(HoldoutValFuncs.holdout_validation)
    stratified_holdout_validation = partial(HoldoutValFuncs.stratified_holdout_validation)

    def is_stratified(self) -> bool:
        """Return True if this member is the stratified hold-out strategy."""
        return self is type(self).stratified_holdout_validation

    def __call__(self, val_share: float, indices: np.ndarray, stratify: Optional[Any] = None
                 ) -> Tuple[np.ndarray, np.ndarray]:
        """Run the split function stored in this member.

        Args:
            val_share: Forwarded as ``val_share``.
            indices: Forwarded as ``indices``.
            stratify: Forwarded as ``stratify``.

        Returns:
            Whatever the wrapped function returns (one pair of index arrays).
        """
        # The result must be returned; dropping it made every call yield None.
        return self.value(val_share=val_share, indices=indices, stratify=stratify)

    @staticmethod
    def get_validators(*choices: 'HoldoutValTypes') -> Dict[str, HoldoutValFunc]:
        """Map each given member's name to its wrapped split function.

        TODO: to be compatible, it is here now, but will be deprecated soon.
        """
        return {choice.name: choice.value for choice in choices}
171+
172+
# TODO: deprecate soon (Will rename CrossValTypes -> CrossValFunc)
RESAMPLING_STRATEGIES = [CrossValTypes, HoldoutValTypes]

# TODO: deprecate soon
# Default keyword arguments for each resampling strategy, keyed by enum member.
# NOTE(review): CrossValTypes.stratified_shuffle_split_cross_validation has no
# entry here -- confirm whether that omission is intentional.
DEFAULT_RESAMPLING_PARAMETERS: Dict[Union[HoldoutValTypes, CrossValTypes], Dict[str, Any]] = {
    HoldoutValTypes.holdout_validation: dict(val_share=0.33),
    HoldoutValTypes.stratified_holdout_validation: dict(val_share=0.33),
    CrossValTypes.k_fold_cross_validation: dict(num_splits=3),
    CrossValTypes.stratified_k_fold_cross_validation: dict(num_splits=3),
    CrossValTypes.shuffle_split_cross_validation: dict(num_splits=3),
    CrossValTypes.time_series_cross_validation: dict(num_splits=3),
}
0 commit comments