@@ -22,77 +22,6 @@ class _ResamplingStrategyArgs(NamedTuple):
2222 stratify : bool = False
2323
2424
class HoldoutFuncs():
    """Namespace for hold-out (single train/validation) split functions."""

    @staticmethod
    def holdout_validation(
        indices: np.ndarray,
        random_state: Optional[np.random.RandomState] = None,
        val_share: Optional[float] = None,
        shuffle: bool = False,
        labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None
    ) -> List[Tuple[np.ndarray, np.ndarray]]:
        """Split `indices` once into a training part and a validation part.

        Args:
            indices: The indices to split.
            random_state: Forwarded to `train_test_split` as `random_state`.
            val_share: Forwarded to `train_test_split` as `test_size`.
            shuffle: Whether to shuffle before splitting. Ignored (treated as
                True) whenever `labels_to_stratify` is given.
            labels_to_stratify: Forwarded to `train_test_split` as `stratify`.

        Returns:
            A one-element list holding the `(train, val)` pair.
        """
        if labels_to_stratify is not None:
            # SKLearn requires shuffle=True for stratify
            effective_shuffle = True
        else:
            effective_shuffle = shuffle

        train_part, val_part = train_test_split(
            indices,
            test_size=val_share,
            shuffle=effective_shuffle,
            random_state=random_state,
            stratify=labels_to_stratify,
        )
        single_split = (train_part, val_part)
        return [single_split]
43-
44-
class CrossValFuncs():
    """Namespace for cross-validation split functions.

    All functions take the same parameters and return a list of
    `(train_indices, val_indices)` tuples.
    """

    # (shuffle, is_stratify) -> split_fn
    _args2split_fn = {
        (True, True): StratifiedShuffleSplit,
        (True, False): ShuffleSplit,
        (False, True): StratifiedKFold,
        (False, False): KFold,
    }

    @staticmethod
    def k_fold_cross_validation(
        indices: np.ndarray,
        random_state: Optional[np.random.RandomState] = None,
        num_splits: Optional[int] = None,
        shuffle: bool = False,
        labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None
    ) -> List[Tuple[np.ndarray, np.ndarray]]:
        """Generate cross-validation splits of `indices`.

        The splitter class is looked up in `_args2split_fn` from `shuffle` and
        from whether `labels_to_stratify` was supplied.

        Args:
            indices: The indices to split.
            random_state: Seed / random state for the shuffling splitters.
                Only used when `shuffle` is true.
            num_splits: Forwarded to the splitter as `n_splits`.
            shuffle: Select a shuffling splitter instead of a K-fold one.
            labels_to_stratify: Labels used for stratification; when given,
                a stratified splitter is selected and the labels are passed
                to its `split` call.

        Returns:
            splits (List[Tuple[List, List]]): list of tuples of training and validation indices
        """
        is_stratified = labels_to_stratify is not None
        split_fn = CrossValFuncs._args2split_fn[(shuffle, is_stratified)]

        if shuffle:
            cv = split_fn(n_splits=num_splits, random_state=random_state)
        else:
            # KFold / StratifiedKFold are built with their default
            # shuffle=False here; scikit-learn rejects a random_state in that
            # configuration, and it would have no effect anyway.
            cv = split_fn(n_splits=num_splits)

        # Stratified splitters need the labels to build the folds; the plain
        # splitters accept (and ignore) y=None.
        return list(cv.split(indices, y=labels_to_stratify))

    @staticmethod
    def time_series(
        indices: np.ndarray,
        random_state: Optional[np.random.RandomState] = None,
        num_splits: Optional[int] = None,
        shuffle: bool = False,
        labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None
    ) -> List[Tuple[np.ndarray, np.ndarray]]:
        """
        Returns train and validation indices respecting the temporal ordering of the data.

        Args:
            indices: The indices to split, in temporal order.
            random_state: Not used by this function.
            num_splits: Forwarded to `TimeSeriesSplit` as `n_splits`.
            shuffle: Not used by this function.
            labels_to_stratify: Not used by this function.

        Returns:
            List of `(train_indices, val_indices)` tuples.

        Example:
            With ``indices = np.array([0, 1, 2, 3])``,
            ``CrossValFuncs.time_series(indices, num_splits=3)`` yields the
            index pairs::

                [([0], [1]),
                 ([0, 1], [2]),
                 ([0, 1, 2], [3])]

        """
        cv = TimeSeriesSplit(n_splits=num_splits)
        splits = list(cv.split(indices))
        return splits
94-
95-
9625class CrossValTypes (IntEnum ):
9726 """The type of cross validation
9827
@@ -214,3 +143,74 @@ def __call__(
214143 shuffle = shuffle ,
215144 labels_to_stratify = labels_to_stratify
216145 )
146+
147+
class HoldoutFuncs():
    """Namespace for hold-out (single train/validation) split functions."""

    @staticmethod
    def holdout_validation(
        indices: np.ndarray,
        random_state: Optional[np.random.RandomState] = None,
        val_share: Optional[float] = None,
        shuffle: bool = False,
        labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None
    ) -> List[Tuple[np.ndarray, np.ndarray]]:
        """Split `indices` once into a training part and a validation part.

        Args:
            indices: The indices to split.
            random_state: Forwarded to `train_test_split` as `random_state`.
            val_share: Forwarded to `train_test_split` as `test_size`.
            shuffle: Whether to shuffle before splitting. Ignored (treated as
                True) whenever `labels_to_stratify` is given.
            labels_to_stratify: Forwarded to `train_test_split` as `stratify`.

        Returns:
            A one-element list holding the `(train, val)` pair.
        """
        if labels_to_stratify is not None:
            # SKLearn requires shuffle=True for stratify
            effective_shuffle = True
        else:
            effective_shuffle = shuffle

        train_part, val_part = train_test_split(
            indices,
            test_size=val_share,
            shuffle=effective_shuffle,
            random_state=random_state,
            stratify=labels_to_stratify,
        )
        single_split = (train_part, val_part)
        return [single_split]
166+
167+
class CrossValFuncs():
    """Namespace for cross-validation split functions.

    All functions take the same parameters and return a list of
    `(train_indices, val_indices)` tuples.
    """

    # (shuffle, is_stratify) -> split_fn
    _args2split_fn = {
        (True, True): StratifiedShuffleSplit,
        (True, False): ShuffleSplit,
        (False, True): StratifiedKFold,
        (False, False): KFold,
    }

    @staticmethod
    def k_fold_cross_validation(
        indices: np.ndarray,
        random_state: Optional[np.random.RandomState] = None,
        num_splits: Optional[int] = None,
        shuffle: bool = False,
        labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None
    ) -> List[Tuple[np.ndarray, np.ndarray]]:
        """Generate cross-validation splits of `indices`.

        The splitter class is looked up in `_args2split_fn` from `shuffle` and
        from whether `labels_to_stratify` was supplied.

        Args:
            indices: The indices to split.
            random_state: Seed / random state for the shuffling splitters.
                Only used when `shuffle` is true.
            num_splits: Forwarded to the splitter as `n_splits`.
            shuffle: Select a shuffling splitter instead of a K-fold one.
            labels_to_stratify: Labels used for stratification; when given,
                a stratified splitter is selected and the labels are passed
                to its `split` call.

        Returns:
            splits (List[Tuple[List, List]]): list of tuples of training and validation indices
        """
        is_stratified = labels_to_stratify is not None
        split_fn = CrossValFuncs._args2split_fn[(shuffle, is_stratified)]

        if shuffle:
            cv = split_fn(n_splits=num_splits, random_state=random_state)
        else:
            # KFold / StratifiedKFold are built with their default
            # shuffle=False here; scikit-learn rejects a random_state in that
            # configuration, and it would have no effect anyway.
            cv = split_fn(n_splits=num_splits)

        # Stratified splitters need the labels to build the folds; the plain
        # splitters accept (and ignore) y=None.
        return list(cv.split(indices, y=labels_to_stratify))

    @staticmethod
    def time_series(
        indices: np.ndarray,
        random_state: Optional[np.random.RandomState] = None,
        num_splits: Optional[int] = None,
        shuffle: bool = False,
        labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None
    ) -> List[Tuple[np.ndarray, np.ndarray]]:
        """
        Returns train and validation indices respecting the temporal ordering of the data.

        Args:
            indices: The indices to split, in temporal order.
            random_state: Not used by this function.
            num_splits: Forwarded to `TimeSeriesSplit` as `n_splits`.
            shuffle: Not used by this function.
            labels_to_stratify: Not used by this function.

        Returns:
            List of `(train_indices, val_indices)` tuples.

        Example:
            With ``indices = np.array([0, 1, 2, 3])``,
            ``CrossValFuncs.time_series(indices, num_splits=3)`` yields the
            index pairs::

                [([0], [1]),
                 ([0, 1], [2]),
                 ([0, 1, 2], [3])]

        """
        cv = TimeSeriesSplit(n_splits=num_splits)
        splits = list(cv.split(indices))
        return splits
0 commit comments