# Use callback protocols as a workaround: a Callable-typed class attribute is treated
# like a method, so 'self' would be counted as one of its arguments.
class CrossValFunc(Protocol):
    """Call signature shared by all cross-validation split functions.

    Args:
        random_state (np.random.RandomState): random number generator handed
            to the splitter so that splits are reproducible
        num_splits (int): number of cross-validation splits to create
        indices (np.ndarray): array of indices to be split
        stratify (Optional[Any]): labels used for stratified splitting, or
            None when stratification does not apply

    Returns:
        List[Tuple[np.ndarray, np.ndarray]]: one (train indices,
            validation indices) tuple per split
    """

    def __call__(self,
                 random_state: np.random.RandomState,
                 num_splits: int,
                 indices: np.ndarray,
                 stratify: Optional[Any]) -> List[Tuple[np.ndarray, np.ndarray]]:
        ...
2526
2627
class HoldOutFunc(Protocol):
    """Call signature shared by all hold-out split functions.

    Args:
        random_state (np.random.RandomState): random number generator handed
            to the splitter so that the split is reproducible
        val_share (float): share of the indices assigned to the validation set
        indices (np.ndarray): array of indices to be split
        stratify (Optional[Any]): labels used for stratified splitting, or
            None when stratification does not apply

    Returns:
        Tuple[np.ndarray, np.ndarray]: (train indices, validation indices)
    """

    def __call__(self,
                 random_state: np.random.RandomState,
                 val_share: float,
                 indices: np.ndarray,
                 stratify: Optional[Any]
                 ) -> Tuple[np.ndarray, np.ndarray]:
        ...
3133
@@ -85,35 +87,42 @@ def is_stratified(self) -> bool:
8587 'val_share' : 0.33 ,
8688 },
8789 CrossValTypes .k_fold_cross_validation : {
88- 'num_splits' : 3 ,
90+ 'num_splits' : 5 ,
8991 },
9092 CrossValTypes .stratified_k_fold_cross_validation : {
91- 'num_splits' : 3 ,
93+ 'num_splits' : 5 ,
9294 },
9395 CrossValTypes .shuffle_split_cross_validation : {
94- 'num_splits' : 3 ,
96+ 'num_splits' : 5 ,
9597 },
9698 CrossValTypes .time_series_cross_validation : {
97- 'num_splits' : 3 ,
99+ 'num_splits' : 5 ,
98100 },
99101} # type: Dict[Union[HoldoutValTypes, CrossValTypes], Dict[str, Any]]
100102
101103
102104class HoldOutFuncs ():
@staticmethod
def holdout_validation(random_state: np.random.RandomState,
                       val_share: float,
                       indices: np.ndarray,
                       **kwargs: Any
                       ) -> Tuple[np.ndarray, np.ndarray]:
    """Split the indices once into a training and a validation part.

    Args:
        random_state (np.random.RandomState): random number generator used
            when shuffling
        val_share (float): passed to train_test_split as `test_size`
        indices (np.ndarray): array of indices to be split
        **kwargs: `shuffle` (bool, default True) selects whether the indices
            are shuffled before splitting; other keys are ignored

    Returns:
        Tuple[np.ndarray, np.ndarray]: (train indices, validation indices)
    """
    shuffle = kwargs.get('shuffle', True)
    # The seed is only meaningful when shuffling, so it is withheld otherwise.
    train, val = train_test_split(indices,
                                  test_size=val_share,
                                  shuffle=shuffle,
                                  random_state=random_state if shuffle else None,
                                  )
    return train, val
110117
@staticmethod
def stratified_holdout_validation(random_state: np.random.RandomState,
                                  val_share: float,
                                  indices: np.ndarray,
                                  **kwargs: Any
                                  ) -> Tuple[np.ndarray, np.ndarray]:
    """Split the indices once, stratified by the given labels.

    Args:
        random_state (np.random.RandomState): random number generator used
            for the shuffled split
        val_share (float): passed to train_test_split as `test_size`
        indices (np.ndarray): array of indices to be split
        **kwargs: must contain `stratify`, which is forwarded to
            train_test_split

    Returns:
        Tuple[np.ndarray, np.ndarray]: (train indices, validation indices)

    Raises:
        KeyError: if `stratify` is not passed as a keyword argument
    """
    train, val = train_test_split(indices,
                                  test_size=val_share,
                                  shuffle=True,
                                  stratify=kwargs["stratify"],
                                  random_state=random_state)
    return train, val
118127
119128 @classmethod
@@ -128,34 +137,38 @@ def get_holdout_validators(cls, *holdout_val_types: HoldoutValTypes) -> Dict[str
128137
129138class CrossValFuncs ():
@staticmethod
def shuffle_split_cross_validation(random_state: np.random.RandomState,
                                   num_splits: int,
                                   indices: np.ndarray,
                                   **kwargs: Any
                                   ) -> List[Tuple[np.ndarray, np.ndarray]]:
    """Create train/validation splits with ShuffleSplit.

    Args:
        random_state (np.random.RandomState): random number generator passed
            to ShuffleSplit
        num_splits (int): number of splits to create
        indices (np.ndarray): array of indices to be split
        **kwargs: accepted for a uniform call signature; not used here

    Returns:
        splits (List[Tuple[np.ndarray, np.ndarray]]): list of tuples of
            training and validation indices
    """
    cv = ShuffleSplit(n_splits=num_splits, random_state=random_state)
    splits = list(cv.split(indices))
    return splits
138148
@staticmethod
def stratified_shuffle_split_cross_validation(random_state: np.random.RandomState,
                                              num_splits: int,
                                              indices: np.ndarray,
                                              **kwargs: Any
                                              ) -> List[Tuple[np.ndarray, np.ndarray]]:
    """Create train/validation splits with StratifiedShuffleSplit.

    Args:
        random_state (np.random.RandomState): random number generator passed
            to StratifiedShuffleSplit
        num_splits (int): number of splits to create
        indices (np.ndarray): array of indices to be split
        **kwargs: must contain `stratify`, the labels the splitter
            stratifies on

    Returns:
        splits (List[Tuple[np.ndarray, np.ndarray]]): list of tuples of
            training and validation indices

    Raises:
        KeyError: if `stratify` is not passed as a keyword argument
    """
    cv = StratifiedShuffleSplit(n_splits=num_splits, random_state=random_state)
    splits = list(cv.split(indices, kwargs["stratify"]))
    return splits
147158
@staticmethod
def stratified_k_fold_cross_validation(random_state: np.random.RandomState,
                                       num_splits: int,
                                       indices: np.ndarray,
                                       **kwargs: Any
                                       ) -> List[Tuple[np.ndarray, np.ndarray]]:
    """Create stratified k-fold train/validation splits.

    Args:
        random_state (np.random.RandomState): random number generator used
            when shuffling
        num_splits (int): number of folds
        indices (np.ndarray): array of indices to be split
        **kwargs: must contain `stratify`, the labels the splitter
            stratifies on; `shuffle` (bool, default True) selects whether
            the data is shuffled before being split into folds

    Returns:
        splits (List[Tuple[np.ndarray, np.ndarray]]): list of tuples of
            training and validation indices

    Raises:
        KeyError: if `stratify` is not passed as a keyword argument
    """
    shuffle = kwargs.get('shuffle', True)
    # scikit-learn rejects a random_state when shuffle is False, so the seed
    # is only forwarded when shuffling (same handling as k_fold_cross_validation).
    cv = StratifiedKFold(n_splits=num_splits,
                         shuffle=shuffle,
                         random_state=random_state if shuffle else None)
    splits = list(cv.split(indices, kwargs["stratify"]))
    return splits
156168
@staticmethod
def k_fold_cross_validation(random_state: np.random.RandomState,
                            num_splits: int,
                            indices: np.ndarray,
                            **kwargs: Any
                            ) -> List[Tuple[np.ndarray, np.ndarray]]:
    """
    Standard k-fold cross validation.

    Args:
        random_state (np.random.RandomState): random number generator used
            when shuffling
        num_splits (int): number of cross validation splits
        indices (np.ndarray): array of indices to be split
        **kwargs: `shuffle` (bool, default True) selects whether the data is
            shuffled before being split into folds; other keys are ignored

    Returns:
        splits (List[Tuple[List, List]]): list of tuples of training and validation indices
    """
    shuffle = kwargs.get('shuffle', True)
    # The seed is only meaningful when shuffling, so it is withheld otherwise.
    cv = KFold(n_splits=num_splits,
               random_state=random_state if shuffle else None,
               shuffle=shuffle)
    splits = list(cv.split(indices))
    return splits
175189
@staticmethod
def time_series_cross_validation(random_state: np.random.RandomState,
                                 num_splits: int,
                                 indices: np.ndarray,
                                 **kwargs: Any
                                 ) -> List[Tuple[np.ndarray, np.ndarray]]:
    """
    Returns train and validation indices that respect the ordering of the data.

    Args:
        random_state (np.random.RandomState): accepted so that all
            cross-validation functions share one call signature; it is not
            used, because TimeSeriesSplit takes no random_state
        num_splits (int): number of cross validation splits
        indices (np.ndarray): array of indices to be split
        **kwargs: accepted for a uniform call signature; not used here

    Returns:
        splits (List[Tuple[List, List]]): list of tuples of training and validation indices

    Examples:
        >>> indices = np.array([0, 1, 2, 3])
        >>> CrossValFuncs.time_series_cross_validation(None, 3, indices)
            [([0], [1]),
             ([0, 1], [2]),
             ([0, 1, 2], [3])]

    """
    # TimeSeriesSplit has no random_state parameter; forwarding one raises
    # a TypeError, so the argument is intentionally dropped here.
    cv = TimeSeriesSplit(n_splits=num_splits)
    splits = list(cv.split(indices))
    return splits
202217
0 commit comments