import os
import uuid
from abc import ABCMeta
- from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np

import torchvision

from autoPyTorch.constants import CLASSIFICATION_OUTPUTS, STRING_TO_OUTPUT_TYPES
- from autoPyTorch.datasets.resampling_strategy import (
-     CrossValFunc,
-     CrossValFuncs,
-     CrossValTypes,
-     DEFAULT_RESAMPLING_PARAMETERS,
-     HoldOutFunc,
-     HoldOutFuncs,
-     HoldoutValTypes
- )
+ from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutTypes
from autoPyTorch.utils.common import FitRequirement

BaseDatasetInputType = Union[Tuple[np.ndarray, np.ndarray], Dataset]
@@ -77,7 +69,7 @@ def __init__(
        dataset_name: Optional[str] = None,
        val_tensors: Optional[BaseDatasetInputType] = None,
        test_tensors: Optional[BaseDatasetInputType] = None,
-         resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
+         resampling_strategy: Union[CrossValTypes, HoldoutTypes] = HoldoutTypes.holdout,
        resampling_strategy_args: Optional[Dict[str, Any]] = None,
        shuffle: Optional[bool] = True,
        seed: Optional[int] = 42,
@@ -94,14 +86,14 @@ def __init__(
                validation data
            test_tensors (An optional tuple of objects that have a __len__ and a __getitem__ attribute):
                test data
-             resampling_strategy (Union[CrossValTypes, HoldoutValTypes]),
-                 (default=HoldoutValTypes.holdout_validation):
+             resampling_strategy (Union[CrossValTypes, HoldoutTypes]),
+                 (default=HoldoutTypes.holdout):
                strategy to split the training data.
            resampling_strategy_args (Optional[Dict[str, Any]]): arguments
                required for the chosen resampling strategy. If None, uses
                the default values provided in DEFAULT_RESAMPLING_PARAMETERS
                in ```datasets/resampling_strategy.py```.
-             shuffle: Whether to shuffle the data before performing splits
+             shuffle: Whether to shuffle the data when performing splits
            seed (int), (default=42): seed to be used for reproducibility.
            train_transforms (Optional[torchvision.transforms.Compose]):
                Additional Transforms to be applied to the training data
@@ -116,12 +108,12 @@ def __init__(
        if not hasattr(train_tensors[0], 'shape'):
            type_check(train_tensors, val_tensors)
        self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors
-         self.cross_validators: Dict[str, CrossValFunc] = {}
-         self.holdout_validators: Dict[str, HoldOutFunc] = {}
        self.random_state = np.random.RandomState(seed=seed)
        self.shuffle = shuffle
        self.resampling_strategy = resampling_strategy
        self.resampling_strategy_args = resampling_strategy_args
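+         # NOTE: whether to stratify the splits is read from resampling_strategy_args
+         # (assumed key 'stratify'); if set, the targets are forwarded to the splitter
+         # as `labels_to_stratify` in get_splits_from_resampling_strategy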
+         self.is_stratify = (self.resampling_strategy_args or {}).get('stratify', False)
+
        self.task_type: Optional[str] = None
        self.issparse: bool = issparse(self.train_tensors[0])
        self.input_shape: Tuple[int] = self.train_tensors[0].shape[1:]
@@ -137,9 +129,6 @@ def __init__(
        # TODO: Look for a criteria to define small enough to preprocess
        self.is_small_preprocess = True

-         # Make sure cross validation splits are created once
-         self.cross_validators = CrossValFuncs.get_cross_validators(*CrossValTypes)
-         self.holdout_validators = HoldOutFuncs.get_holdout_validators(*HoldoutValTypes)
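+         # The splits are computed once per dataset so that every pipeline is evaluated
+         # on identical folds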
        self.splits = self.get_splits_from_resampling_strategy()

        # We also need to be able to transform the data, be it for pre-processing
@@ -205,7 +194,30 @@ def __len__(self) -> int:
        return self.train_tensors[0].shape[0]

    def _get_indices(self) -> np.ndarray:
-         return self.random_state.permutation(len(self)) if self.shuffle else np.arange(len(self))
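+         # NOTE: shuffling is no longer applied here; it is delegated to the resampling
+         # strategy call through the `shuffle` argument in get_splits_from_resampling_strategy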
+         return np.arange(len(self))
+
+     def _process_resampling_strategy_args(self) -> None:
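+         """Validate the resampling strategy and its arguments before the splits are created."""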
+         if not isinstance(self.resampling_strategy, (HoldoutTypes, CrossValTypes)):
+             raise ValueError(f"resampling_strategy {self.resampling_strategy} is not supported.")
+
+         if self.resampling_strategy_args is not None and \
+                 not isinstance(self.resampling_strategy_args, dict):
+             raise TypeError("resampling_strategy_args must be dict or None,"
+                             f" but got {type(self.resampling_strategy_args)}")
+
+         # Fall back to an empty dict so the lookups below do not fail when no args were passed
+         self.resampling_strategy_args = self.resampling_strategy_args or {}
+
+         val_share = self.resampling_strategy_args.get('val_share', None)
+         num_splits = self.resampling_strategy_args.get('num_splits', None)
+
+         if val_share is not None and (val_share < 0 or val_share > 1):
+             raise ValueError(f"`val_share` must be between 0 and 1, got {val_share}.")
+
+         if num_splits is not None:
+             if not isinstance(num_splits, int):
+                 raise ValueError(f"`num_splits` must be an integer, got {num_splits}.")
+             elif num_splits <= 0:
+                 raise ValueError(f"`num_splits` must be a positive integer, got {num_splits}.")

    def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]]]:
        """
@@ -214,100 +226,33 @@ def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]]]:
        Returns
            (List[Tuple[List[int], List[int]]]): splits in the [train_indices, val_indices] format
        """
-         splits = []
-         if isinstance(self.resampling_strategy, HoldoutValTypes):
-             val_share = DEFAULT_RESAMPLING_PARAMETERS[self.resampling_strategy].get(
-                 'val_share', None)
-             if self.resampling_strategy_args is not None:
-                 val_share = self.resampling_strategy_args.get('val_share', val_share)
-             splits.append(
-                 self.create_holdout_val_split(
-                     holdout_val_type=self.resampling_strategy,
-                     val_share=val_share,
-                 )
+         # check if the requirements are met and if we can get splits
+         self._process_resampling_strategy_args()
+
+         labels_to_stratify = self.train_tensors[-1] if self.is_stratify else None
+
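+         # The resampling strategy member is assumed to be callable here and to return splits
+         # in the [train_indices, val_indices] format described in the docstring above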
+         if isinstance(self.resampling_strategy, HoldoutTypes):
+             # val_share may be None when the user did not specify it in resampling_strategy_args
+             val_share = self.resampling_strategy_args.get('val_share', None)
+
+             return self.resampling_strategy(
+                 random_state=self.random_state,
+                 val_share=val_share,
+                 shuffle=self.shuffle,
+                 indices=self._get_indices(),
+                 labels_to_stratify=labels_to_stratify
            )
        elif isinstance(self.resampling_strategy, CrossValTypes):
-             num_splits = DEFAULT_RESAMPLING_PARAMETERS[self.resampling_strategy].get(
-                 'num_splits', None)
-             if self.resampling_strategy_args is not None:
-                 num_splits = self.resampling_strategy_args.get('num_splits', num_splits)
-             # Create the split if it was not created before
-             splits.extend(
-                 self.create_cross_val_splits(
-                     cross_val_type=self.resampling_strategy,
-                     num_splits=cast(int, num_splits),
-                 )
+             num_splits = self.resampling_strategy_args.get('num_splits', None)
+
+             return self.resampling_strategy(
+                 random_state=self.random_state,
+                 num_splits=num_splits,
+                 shuffle=self.shuffle,
+                 indices=self._get_indices(),
+                 labels_to_stratify=labels_to_stratify
            )
        else:
            raise ValueError(f"Unsupported resampling strategy={self.resampling_strategy}")
-         return splits
-
-     def create_cross_val_splits(
-         self,
-         cross_val_type: CrossValTypes,
-         num_splits: int
-     ) -> List[Tuple[Union[List[int], np.ndarray], Union[List[int], np.ndarray]]]:
-         """
-         This function creates the cross validation split for the given task.
-
-         It is done once per dataset to have comparable results among pipelines
-         Args:
-             cross_val_type (CrossValTypes):
-             num_splits (int): number of splits to be created
-
-         Returns:
-             (List[Tuple[Union[List[int], np.ndarray], Union[List[int], np.ndarray]]]):
-                 list containing 'num_splits' splits.
-         """
-         # Create just the split once
-         # This is gonna be called multiple times, because the current dataset
-         # is being used for multiple pipelines. That is, to be efficient with memory
-         # we dump the dataset to memory and read it on a need basis. So this function
-         # should be robust against multiple calls, and it does so by remembering the splits
-         if not isinstance(cross_val_type, CrossValTypes):
-             raise NotImplementedError(f'The selected `cross_val_type` "{cross_val_type}" is not implemented.')
-         kwargs = {}
-         if cross_val_type.is_stratified():
-             # we need additional information about the data for stratification
-             kwargs["stratify"] = self.train_tensors[-1]
-         splits = self.cross_validators[cross_val_type.name](
-             self.random_state, num_splits, self._get_indices(), **kwargs)
-         return splits
-
-     def create_holdout_val_split(
-         self,
-         holdout_val_type: HoldoutValTypes,
-         val_share: float,
-     ) -> Tuple[np.ndarray, np.ndarray]:
-         """
-         This function creates the holdout split for the given task.
-
-         It is done once per dataset to have comparable results among pipelines
-         Args:
-             holdout_val_type (HoldoutValTypes):
-             val_share (float): share of the validation data
-
-         Returns:
-             (Tuple[np.ndarray, np.ndarray]): Tuple containing (train_indices, val_indices)
-         """
-         if holdout_val_type is None:
-             raise ValueError(
-                 '`val_share` specified, but `holdout_val_type` not specified.'
-             )
-         if self.val_tensors is not None:
-             raise ValueError(
-                 '`val_share` specified, but the Dataset was a given a pre-defined split at initialization already.')
-         if val_share < 0 or val_share > 1:
-             raise ValueError(f"`val_share` must be between 0 and 1, got {val_share}.")
-         if not isinstance(holdout_val_type, HoldoutValTypes):
-             raise NotImplementedError(f'The specified `holdout_val_type` "{holdout_val_type}" is not supported.')
-         kwargs = {}
-         if holdout_val_type.is_stratified():
-             # we need additional information about the data for stratification
-             kwargs["stratify"] = self.train_tensors[-1]
-         train, val = self.holdout_validators[holdout_val_type.name](
-             self.random_state, val_share, self._get_indices(), **kwargs)
-         return train, val

    def get_dataset_for_training(self, split_id: int) -> Tuple[Dataset, Dataset]:
        """
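Note: after this refactor the split helpers are removed and both branches call `self.resampling_strategy(...)` directly, so the `HoldoutTypes` / `CrossValTypes` members need to be callable. The snippet below is a minimal sketch of that assumed interface, reconstructed only from the call sites in this diff; the real definitions live in `autoPyTorch/datasets/resampling_strategy.py` (not shown here), so the member name, the fallback `val_share`, and the stratification handling are illustrative assumptions.

```python
# Illustration only: a stand-in for the callable strategy interface assumed above.
from enum import Enum
from typing import List, Optional, Tuple

import numpy as np


class HoldoutTypes(Enum):
    holdout = "holdout"

    def __call__(
        self,
        random_state: np.random.RandomState,
        val_share: Optional[float],
        shuffle: bool,
        indices: np.ndarray,
        labels_to_stratify: Optional[np.ndarray] = None,
    ) -> List[Tuple[List[int], List[int]]]:
        # labels_to_stratify is ignored in this sketch; a real implementation would
        # use it to keep class proportions identical in both partitions
        val_share = 0.33 if val_share is None else val_share  # assumed default
        if shuffle:
            indices = random_state.permutation(indices)
        n_train = int(len(indices) * (1 - val_share))
        return [(indices[:n_train].tolist(), indices[n_train:].tolist())]


splits = HoldoutTypes.holdout(
    random_state=np.random.RandomState(42),
    val_share=0.25,
    shuffle=True,
    indices=np.arange(100),
)
train_indices, val_indices = splits[0]
print(len(train_indices), len(val_indices))  # 75 25
```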