11from enum import IntEnum
2- from typing import Any , Dict , List , Optional , Tuple , Union
2+ from typing import Any , Dict , List , Optional , Tuple , Union , Callable
33
44import numpy as np
55
1515from typing_extensions import Protocol
1616
1717
# Call signature used to annotate the split functions that this module looks up
# by name: (num_splits or val_share, indices, extra argument) -> split(s).
# Cross-validation functions return a list of (train, val) index pairs, while
# hold-out functions return one (train, val) pair, so both shapes are accepted.
SplitFunc = Callable[
    [Union[int, float], np.ndarray, Any],
    Union[List[Tuple[np.ndarray, np.ndarray]], Tuple[np.ndarray, np.ndarray]],
]
19+
20+
1821# Use callback protocol as workaround, since callable with function fields count 'self' as argument
1922class CROSS_VAL_FN (Protocol ):
23+ """TODO: deprecate soon"""
2024 def __call__ (self ,
2125 num_splits : int ,
2226 indices : np .ndarray ,
@@ -25,26 +29,59 @@ def __call__(self,
2529
2630
class HOLDOUT_FN(Protocol):
    """Callback protocol for a hold-out split function.

    TODO: deprecate soon

    A conforming callable takes the validation share, the indices to split and
    a stratification argument, and returns a pair of index arrays.
    """

    def __call__(
        self,
        val_share: float,
        indices: np.ndarray,
        stratify: Optional[Any],
    ) -> Tuple[np.ndarray, np.ndarray]:
        ...
3136
3237
class CrossValTypes(IntEnum):
    """The type of cross validation.

    This class is used to specify the cross validation function. Its members
    are accessed as attributes of the class; the class itself is not supposed
    to be instantiated.

    Examples:
        >>> cv_type = CrossValTypes.k_fold_cross_validation
        >>> print(cv_type.name)
        k_fold_cross_validation
        >>> for cross_val_type in CrossValTypes:
        ...     print(cross_val_type.name, cross_val_type.value)
        stratified_k_fold_cross_validation 1
        k_fold_cross_validation 2
        stratified_shuffle_split_cross_validation 3
        shuffle_split_cross_validation 4
        time_series_cross_validation 5
    """
    stratified_k_fold_cross_validation = 1
    k_fold_cross_validation = 2
    stratified_shuffle_split_cross_validation = 3
    shuffle_split_cross_validation = 4
    time_series_cross_validation = 5

    def is_stratified(self) -> bool:
        """Tell whether this member is one of the stratified strategies.

        :return: True for ``stratified_k_fold_cross_validation`` and
            ``stratified_shuffle_split_cross_validation``, False otherwise
        """
        # Compare the member itself and name the members through the class:
        # looking a member up on another member is discouraged by the enum docs.
        return self in (
            CrossValTypes.stratified_k_fold_cross_validation,
            CrossValTypes.stratified_shuffle_split_cross_validation,
        )
69+
4070
class HoldoutValTypes(IntEnum):
    """The type of hold out validation (refer to CrossValTypes' doc-string)"""
    holdout_validation = 6
    stratified_holdout_validation = 7

    def is_stratified(self) -> bool:
        """Tell whether this member is the stratified hold-out strategy.

        :return: True only for ``stratified_holdout_validation``
        """
        # Compare the member itself and name it through the class: looking a
        # member up on another member is discouraged by the enum docs.
        return self in (HoldoutValTypes.stratified_holdout_validation,)
79+
4580
# TODO: deprecate soon
# The enum classes whose members name the supported resampling strategies.
RESAMPLING_STRATEGIES = [CrossValTypes, HoldoutValTypes]
4783
84+ """TODO: deprecate soon"""
4885DEFAULT_RESAMPLING_PARAMETERS = {
4986 HoldoutValTypes .holdout_validation : {
5087 'val_share' : 0.33 ,
@@ -67,15 +104,8 @@ class HoldoutValTypes(IntEnum):
67104} # type: Dict[Union[HoldoutValTypes, CrossValTypes], Dict[str, Any]]
68105
69106
70- def get_cross_validators (* cross_val_types : CrossValTypes ) -> Dict [str , CROSS_VAL_FN ]:
71- cross_validators = {} # type: Dict[str, CROSS_VAL_FN]
72- for cross_val_type in cross_val_types :
73- cross_val_fn = globals ()[cross_val_type .name ]
74- cross_validators [cross_val_type .name ] = cross_val_fn
75- return cross_validators
76-
77-
78107def get_holdout_validators (* holdout_val_types : HoldoutValTypes ) -> Dict [str , HOLDOUT_FN ]:
108+ """TODO: deprecate soon"""
79109 holdout_validators = {} # type: Dict[str, HOLDOUT_FN]
80110 for holdout_val_type in holdout_val_types :
81111 holdout_val_fn = globals ()[holdout_val_type .name ]
@@ -84,70 +114,93 @@ def get_holdout_validators(*holdout_val_types: HoldoutValTypes) -> Dict[str, HOL
84114
85115
def is_stratified(val_type: Union[str, CrossValTypes, HoldoutValTypes]) -> bool:
    """Tell whether a resampling strategy is a stratified one.

    TODO: deprecate soon

    The decision is based on the name alone: a strategy counts as stratified
    when its lower-cased name begins with ``"stratified"``.

    :param val_type: the strategy name, or an enum member whose ``name`` is used
    :return: True if the name starts with "stratified" (case-insensitive)
    """
    label = val_type if isinstance(val_type, str) else val_type.name
    return label.lower().startswith("stratified")
91122
92123
93- def holdout_validation (val_share : float , indices : np .ndarray , ** kwargs : Any ) -> Tuple [np .ndarray , np .ndarray ]:
94- train , val = train_test_split (indices , test_size = val_share , shuffle = False )
95- return train , val
96-
97-
98- def stratified_holdout_validation (val_share : float , indices : np .ndarray , ** kwargs : Any ) \
99- -> Tuple [np .ndarray , np .ndarray ]:
100- train , val = train_test_split (indices , test_size = val_share , shuffle = False , stratify = kwargs ["stratify" ])
101- return train , val
102-
103-
104- def shuffle_split_cross_validation (num_splits : int , indices : np .ndarray , ** kwargs : Any ) \
105- -> List [Tuple [np .ndarray , np .ndarray ]]:
106- cv = ShuffleSplit (n_splits = num_splits )
107- splits = list (cv .split (indices ))
108- return splits
109-
110-
111- def stratified_shuffle_split_cross_validation (num_splits : int , indices : np .ndarray , ** kwargs : Any ) \
112- -> List [Tuple [np .ndarray , np .ndarray ]]:
113- cv = StratifiedShuffleSplit (n_splits = num_splits )
114- splits = list (cv .split (indices , kwargs ["stratify" ]))
115- return splits
116-
117-
118- def stratified_k_fold_cross_validation (num_splits : int , indices : np .ndarray , ** kwargs : Any ) \
119- -> List [Tuple [np .ndarray , np .ndarray ]]:
120- cv = StratifiedKFold (n_splits = num_splits )
121- splits = list (cv .split (indices , kwargs ["stratify" ]))
122- return splits
123-
124-
125- def k_fold_cross_validation (num_splits : int , indices : np .ndarray , ** kwargs : Any ) -> List [Tuple [np .ndarray , np .ndarray ]]:
126- """
127- Standard k fold cross validation.
128-
129- :param indices: array of indices to be split
130- :param num_splits: number of cross validation splits
131- :return: list of tuples of training and validation indices
132- """
133- cv = KFold (n_splits = num_splits )
134- splits = list (cv .split (indices ))
135- return splits
136-
137-
138- def time_series_cross_validation (num_splits : int , indices : np .ndarray , ** kwargs : Any ) \
139- -> List [Tuple [np .ndarray , np .ndarray ]]:
140- """
141- Returns train and validation indices respecting the temporal ordering of the data.
142- Dummy example: [0, 1, 2, 3] with 3 folds yields
143- [0] [1]
144- [0, 1] [2]
145- [0, 1, 2] [3]
146-
147- :param indices: array of indices to be split
148- :param num_splits: number of cross validation splits
149- :return: list of tuples of training and validation indices
150- """
151- cv = TimeSeriesSplit (n_splits = num_splits )
152- splits = list (cv .split (indices ))
153- return splits
class HoldOutFuncs:
    """Namespace of hold-out split functions, addressable by name.

    Each static method carries the name of a ``HoldoutValTypes`` member, which
    is how :meth:`get_holdout_validators` resolves a type to its function.

    NOTE(review): ``train_test_split`` is imported outside the visible part of
    this file; the keyword arguments used here match scikit-learn's
    ``sklearn.model_selection.train_test_split`` -- confirm.
    """

    @staticmethod
    def holdout_validation(val_share: float, indices: np.ndarray, **kwargs: Any) -> Tuple[np.ndarray, np.ndarray]:
        """Split the indices into a training and a validation part without shuffling.

        :param val_share: forwarded to ``train_test_split`` as ``test_size``
        :param indices: array of indices to be split
        :param kwargs: accepted so all split functions share one signature; ignored
        :return: tuple of training and validation indices
        """
        train, val = train_test_split(indices, test_size=val_share, shuffle=False)
        return train, val

    @staticmethod
    def stratified_holdout_validation(val_share: float, indices: np.ndarray, **kwargs: Any) \
            -> Tuple[np.ndarray, np.ndarray]:
        """Split the indices into a training and a validation part, stratified by labels.

        :param val_share: forwarded to ``train_test_split`` as ``test_size``
        :param indices: array of indices to be split
        :param kwargs: must contain ``stratify`` (the labels to stratify on, or
            None for a plain unshuffled split); may contain ``random_state`` to
            seed the shuffling
        :return: tuple of training and validation indices
        :raises KeyError: if ``stratify`` is not given
        """
        stratify = kwargs["stratify"]
        # scikit-learn documents that stratify must be None when shuffle=False
        # (it raises ValueError otherwise), so shuffle exactly when stratifying.
        # With stratify=None the previous unshuffled behaviour is kept.
        shuffle = stratify is not None
        train, val = train_test_split(
            indices,
            test_size=val_share,
            shuffle=shuffle,
            stratify=stratify,
            random_state=kwargs.get("random_state"),
        )
        return train, val

    @classmethod
    def get_holdout_validators(cls, *holdout_val_types: HoldoutValTypes) -> Dict[str, SplitFunc]:
        """Map each requested hold-out type to the function of the same name.

        :param holdout_val_types: the hold-out types to look up
        :return: dictionary from type name to the static method with that name
        :raises AttributeError: if no method is named after a given type
        """
        return {
            holdout_val_type.name: getattr(cls, holdout_val_type.name)
            for holdout_val_type in holdout_val_types
        }
144+
145+
class CrossValFuncs:
    """Namespace of cross-validation split functions, addressable by name.

    Each static method carries the name of a ``CrossValTypes`` member, which is
    how :meth:`get_cross_validators` resolves a type to its function. Every
    splitter takes ``(num_splits, indices, **kwargs)`` and returns a list of
    ``(train_indices, val_indices)`` tuples.

    NOTE(review): the splitter classes used below (``KFold``, ``ShuffleSplit``,
    ...) are imported outside the visible part of this file; presumably they
    are scikit-learn's ``sklearn.model_selection`` classes -- confirm.
    """

    @staticmethod
    def shuffle_split_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \
            -> List[Tuple[np.ndarray, np.ndarray]]:
        """Random permutation cross validation.

        :param num_splits: number of cross validation splits
        :param indices: array of indices to be split
        :param kwargs: may contain ``random_state`` to make the splits reproducible
        :return: list of tuples of training and validation indices
        """
        # random_state defaults to None, i.e. the splitter's own default.
        cv = ShuffleSplit(n_splits=num_splits, random_state=kwargs.get("random_state"))
        return list(cv.split(indices))

    @staticmethod
    def stratified_shuffle_split_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \
            -> List[Tuple[np.ndarray, np.ndarray]]:
        """Random permutation cross validation, stratified by labels.

        :param num_splits: number of cross validation splits
        :param indices: array of indices to be split
        :param kwargs: must contain ``stratify`` (the labels to stratify on);
            may contain ``random_state`` to make the splits reproducible
        :return: list of tuples of training and validation indices
        :raises KeyError: if ``stratify`` is not given
        """
        # random_state defaults to None, i.e. the splitter's own default.
        cv = StratifiedShuffleSplit(n_splits=num_splits, random_state=kwargs.get("random_state"))
        return list(cv.split(indices, kwargs["stratify"]))

    @staticmethod
    def stratified_k_fold_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \
            -> List[Tuple[np.ndarray, np.ndarray]]:
        """K fold cross validation, stratified by labels.

        :param num_splits: number of cross validation splits
        :param indices: array of indices to be split
        :param kwargs: must contain ``stratify`` (the labels to stratify on)
        :return: list of tuples of training and validation indices
        :raises KeyError: if ``stratify`` is not given
        """
        cv = StratifiedKFold(n_splits=num_splits)
        return list(cv.split(indices, kwargs["stratify"]))

    @staticmethod
    def k_fold_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \
            -> List[Tuple[np.ndarray, np.ndarray]]:
        """
        Standard k fold cross validation.

        :param indices: array of indices to be split
        :param num_splits: number of cross validation splits
        :return: list of tuples of training and validation indices
        """
        cv = KFold(n_splits=num_splits)
        return list(cv.split(indices))

    @staticmethod
    def time_series_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \
            -> List[Tuple[np.ndarray, np.ndarray]]:
        """
        Returns train and validation indices respecting the temporal ordering of the data.
        Dummy example: [0, 1, 2, 3] with 3 folds yields
            [0] [1]
            [0, 1] [2]
            [0, 1, 2] [3]

        :param indices: array of indices to be split
        :param num_splits: number of cross validation splits
        :return: list of tuples of training and validation indices
        """
        cv = TimeSeriesSplit(n_splits=num_splits)
        return list(cv.split(indices))

    @classmethod
    def get_cross_validators(cls, *cross_val_types: CrossValTypes) -> Dict[str, SplitFunc]:
        """Map each requested cross validation type to the function of the same name.

        :param cross_val_types: the cross validation types to look up
        :return: dictionary from type name to the static method with that name
        :raises AttributeError: if no method is named after a given type
        """
        return {
            cross_val_type.name: getattr(cls, cross_val_type.name)
            for cross_val_type in cross_val_types
        }
0 commit comments