11from abc import ABCMeta
2- from typing import Any , Dict , List , Optional , Sequence , Tuple , Union , cast
2+ from typing import Any , Dict , List , Optional , Sequence , Tuple , Union
33
44import numpy as np
55
1313
1414from autoPyTorch .constants import CLASSIFICATION_OUTPUTS , STRING_TO_OUTPUT_TYPES
1515from autoPyTorch .datasets .resampling_strategy import (
16- CROSS_VAL_FN ,
1716 CrossValTypes ,
1817 DEFAULT_RESAMPLING_PARAMETERS ,
19- HOLDOUT_FN ,
20- HoldoutValTypes ,
21- get_cross_validators ,
22- get_holdout_validators ,
23- is_stratified ,
18+ HoldoutValTypes
2419)
2520from autoPyTorch .utils .common import FitRequirement , hash_array_or_matrix
2621
@@ -112,8 +107,6 @@ def __init__(
112107 if not hasattr (train_tensors [0 ], 'shape' ):
113108 type_check (train_tensors , val_tensors )
114109 self .train_tensors , self .val_tensors , self .test_tensors = train_tensors , val_tensors , test_tensors
115- self .cross_validators : Dict [str , CROSS_VAL_FN ] = {}
116- self .holdout_validators : Dict [str , HOLDOUT_FN ] = {}
117110 self .rng = np .random .RandomState (seed = seed )
118111 self .shuffle = shuffle
119112 self .resampling_strategy = resampling_strategy
@@ -133,9 +126,6 @@ def __init__(
133126 # TODO: Look for a criteria to define small enough to preprocess
134127 self .is_small_preprocess = True
135128
136- # Make sure cross validation splits are created once
137- self .cross_validators = get_cross_validators (* CrossValTypes )
138- self .holdout_validators = get_holdout_validators (* HoldoutValTypes )
139129 self .splits = self .get_splits_from_resampling_strategy ()
140130
141131 # We also need to be able to transform the data, be it for pre-processing
@@ -203,106 +193,82 @@ def __len__(self) -> int:
203193 def _get_indices (self ) -> np .ndarray :
204194 return self .rng .permutation (len (self )) if self .shuffle else np .arange (len (self ))
205195
196+ def _process_resampling_strategy_args (self ) -> None :
197+ """TODO: Refactor this function after introducing BaseDict"""
198+
199+ if not any (isinstance (self .resampling_strategy , val_type )
200+ for val_type in [HoldoutValTypes , CrossValTypes ]):
201+ raise ValueError (f"resampling_strategy { self .resampling_strategy } is not supported." )
202+
203+ if self .splitting_params is not None and \
204+ not isinstance (self .resampling_strategy_args , dict ):
205+
206+ raise TypeError ("resampling_strategy_args must be dict or None,"
207+ f" but got { type (self .resampling_strategy_args )} " )
208+
209+ if self .resampling_strategy_args is None :
210+ self .resampling_strategy_args = {}
211+
212+ if isinstance (self .resampling_strategy , HoldoutValTypes ):
213+ val_share = DEFAULT_RESAMPLING_PARAMETERS [self .resampling_strategy ].get (
214+ 'val_share' , None )
215+ self .resampling_strategy_args ['val_share' ] = val_share
216+ elif isinstance (self .splitting_type , CrossValTypes ):
217+ num_splits = DEFAULT_RESAMPLING_PARAMETERS [self .resampling_strategy ].get (
218+ 'num_splits' , None )
219+ self .resampling_strategy_args ['num_splits' ] = num_splits
220+
221+ """Comment: Do we need this raise Error?"""
222+ if self .val_tensors is not None : # if we need it, we should share it with cross val as well
223+ raise ValueError ('`val_share` specified, but the Dataset was'
224+ ' a given a pre-defined split at initialization already.' )
225+
226+ val_share = self .resampling_strategy_args .get ('val_share' , None )
227+ num_splits = self .resampling_strategy_args .get ('num_splits' , None )
228+
229+ if val_share is not None and (val_share < 0 or val_share > 1 ):
230+ raise ValueError (f"`val_share` must be between 0 and 1, got { val_share } ." )
231+
232+ if num_splits is not None :
233+ if num_splits <= 0 :
234+ raise ValueError (f"`num_splits` must be a positive integer, got { num_splits } ." )
235+ elif not isinstance (num_splits , int ):
236+ raise ValueError (f"`num_splits` must be an integer, got { num_splits } ." )
237+
206238 def get_splits_from_resampling_strategy (self ) -> List [Tuple [List [int ], List [int ]]]:
207239 """
208240 Creates a set of splits based on a resampling strategy provided
209241
210242 Returns
211243 (List[Tuple[List[int], List[int]]]): splits in the [train_indices, val_indices] format
212244 """
213- splits = []
214- if isinstance (self .resampling_strategy , HoldoutValTypes ):
215- val_share = DEFAULT_RESAMPLING_PARAMETERS [self .resampling_strategy ].get (
216- 'val_share' , None )
217- if self .resampling_strategy_args is not None :
218- val_share = self .resampling_strategy_args .get ('val_share' , val_share )
219- splits .append (
220- self .create_holdout_val_split (
221- holdout_val_type = self .resampling_strategy ,
222- val_share = val_share ,
223- )
224- )
225- elif isinstance (self .resampling_strategy , CrossValTypes ):
226- num_splits = DEFAULT_RESAMPLING_PARAMETERS [self .resampling_strategy ].get (
227- 'num_splits' , None )
228- if self .resampling_strategy_args is not None :
229- num_splits = self .resampling_strategy_args .get ('num_splits' , num_splits )
230- # Create the split if it was not created before
231- splits .extend (
232- self .create_cross_val_splits (
233- cross_val_type = self .resampling_strategy ,
234- num_splits = cast (int , num_splits ),
235- )
236- )
237- else :
238- raise ValueError (f"Unsupported resampling strategy={ self .resampling_strategy } " )
239- return splits
240245
241- def create_cross_val_splits (
242- self ,
243- cross_val_type : CrossValTypes ,
244- num_splits : int
245- ) -> List [Tuple [Union [List [int ], np .ndarray ], Union [List [int ], np .ndarray ]]]:
246- """
247- This function creates the cross validation split for the given task.
246+ # check if the requirements are met and if we can get splits
247+ self ._process_resampling_strategy_args ()
248248
249- It is done once per dataset to have comparable results among pipelines
250- Args:
251- cross_val_type (CrossValTypes):
252- num_splits (int): number of splits to be created
253-
254- Returns:
255- (List[Tuple[Union[List[int], np.ndarray], Union[List[int], np.ndarray]]]):
256- list containing 'num_splits' splits.
257- """
258- # Create just the split once
259- # This is gonna be called multiple times, because the current dataset
260- # is being used for multiple pipelines. That is, to be efficient with memory
261- # we dump the dataset to memory and read it on a need basis. So this function
262- # should be robust against multiple calls, and it does so by remembering the splits
263- if not isinstance (cross_val_type , CrossValTypes ):
264- raise NotImplementedError (f'The selected `cross_val_type` "{ cross_val_type } " is not implemented.' )
265249 kwargs = {}
266- if is_stratified (cross_val_type ):
250+ if self . resampling_strategy . is_stratified ():
267251 # we need additional information about the data for stratification
268252 kwargs ["stratify" ] = self .train_tensors [- 1 ]
269- splits = self .cross_validators [cross_val_type .name ](
270- num_splits , self ._get_indices (), ** kwargs )
271- return splits
272253
273- def create_holdout_val_split (
274- self ,
275- holdout_val_type : HoldoutValTypes ,
276- val_share : float ,
277- ) -> Tuple [np .ndarray , np .ndarray ]:
278- """
279- This function creates the holdout split for the given task.
254+ if isinstance (self .resampling_strategy , HoldoutValTypes ):
255+ val_share = self .resampling_strategy_args ['val_share' ]
280256
281- It is done once per dataset to have comparable results among pipelines
282- Args:
283- holdout_val_type (HoldoutValTypes):
284- val_share (float): share of the validation data
257+ return self .resampling_strategy (
258+ val_share = val_share ,
259+ indices = self ._get_indices (),
260+ ** kwargs
261+ )
262+ elif isinstance (self .resampling_strategy , CrossValTypes ):
263+ num_splits = self .resampling_strategy_args ['num_splits' ]
285264
286- Returns:
287- (Tuple[np.ndarray, np.ndarray]): Tuple containing (train_indices, val_indices)
288- """
289- if holdout_val_type is None :
290- raise ValueError (
291- '`val_share` specified, but `holdout_val_type` not specified.'
265+ return self .create_cross_val_splits (
266+ num_splits = int (num_splits ),
267+ indices = self ._get_indices (),
268+ ** kwargs
292269 )
293- if self .val_tensors is not None :
294- raise ValueError (
295- '`val_share` specified, but the Dataset was a given a pre-defined split at initialization already.' )
296- if val_share < 0 or val_share > 1 :
297- raise ValueError (f"`val_share` must be between 0 and 1, got { val_share } ." )
298- if not isinstance (holdout_val_type , HoldoutValTypes ):
299- raise NotImplementedError (f'The specified `holdout_val_type` "{ holdout_val_type } " is not supported.' )
300- kwargs = {}
301- if is_stratified (holdout_val_type ):
302- # we need additional information about the data for stratification
303- kwargs ["stratify" ] = self .train_tensors [- 1 ]
304- train , val = self .holdout_validators [holdout_val_type .name ](val_share , self ._get_indices (), ** kwargs )
305- return train , val
270+ else :
271+ raise ValueError (f"Unsupported resampling strategy={ self .resampling_strategy } " )
306272
307273 def get_dataset_for_training (self , split_id : int ) -> Tuple [Dataset , Dataset ]:
308274 """
0 commit comments