 )
 from autoPyTorch.utils.common import FitRequirement, hash_array_or_matrix
 
-BASE_DATASET_INPUT = Union[Tuple[np.ndarray, np.ndarray], Dataset]
+BaseDatasetType = Union[Tuple[np.ndarray, np.ndarray], Dataset]
 
 
 def check_valid_data(data: Any) -> None:
-    if not (hasattr(data, '__getitem__') and hasattr(data, '__len__')):
+    if not all(hasattr(data, attr) for attr in ['__getitem__', '__len__']):
         raise ValueError(
-            'The specified Data for Dataset does either not have a __getitem__ or a __len__ attribute.')
+            'The specified Data for Dataset must have both __getitem__ and __len__ attributes.')
 
 
-def type_check(train_tensors: BASE_DATASET_INPUT, val_tensors: Optional[BASE_DATASET_INPUT] = None) -> None:
+def type_check(train_tensors: BaseDatasetType, val_tensors: Optional[BaseDatasetType] = None) -> None:
+    """Check train_tensors (and val_tensors, if given); we loop over indices to avoid unexpected behavior."""
     for i in range(len(train_tensors)):
         check_valid_data(train_tensors[i])
     if val_tensors is not None:
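
For reference, a quick sketch of how the refactored validators behave on toy inputs (not part of the diff; assumes the two functions above are importable):

```python
import numpy as np

# Anything indexable and sized passes: ndarrays, lists, pandas objects.
check_valid_data(np.arange(10))   # OK: has __getitem__ and __len__
check_valid_data([1, 2, 3])       # OK

# A generator has neither attribute, so it raises ValueError.
try:
    check_valid_data(x for x in range(10))
except ValueError as e:
    print(e)

# type_check simply applies the check to every element of the tuples.
type_check((np.zeros((5, 3)), np.zeros(5)), val_tensors=None)
```
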
@@ -42,12 +43,20 @@ def type_check(train_tensors: BASE_DATASET_INPUT, val_tensors: Optional[BASE_DAT
 
 
 class TransformSubset(Subset):
-    """
-    Because the BaseDataset contains all the data (train/val/test), the transformations
-    have to be applied with some directions. That is, if yielding train data,
-    we expect to apply train transformation (which have augmentations exclusively).
+    """Wrapper of BaseDataset for split datasets
+
+    Since the BaseDataset contains all the data points (train/val/test),
+    we require a different transformation for each split.
+    This class helps to take a subset of the dataset
+    with either the training or the validation transformation.
 
     We achieve so by adding a train flag to the pytorch subset
+
+    Attributes:
+        dataset (BaseDataset/Dataset): Dataset to sample the subset from
+        indices (Sequence[int]): Indices to sample from the dataset
+        train (bool): Whether to apply the train or validation transformation
+
     """
 
     def __init__(self, dataset: Dataset, indices: Sequence[int], train: bool) -> None:
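
The class is used exactly as in `get_dataset_for_training` further down: same underlying data, different transform per split. A short usage sketch, assuming `dataset` is a fitted instance of a concrete `BaseDataset` subclass:

```python
# The train/val subsets share the same underlying data but apply
# different transformations, controlled by the train flag.
train_indices, val_indices = dataset.splits[0]
train_subset = TransformSubset(dataset, train_indices, train=True)
val_subset = TransformSubset(dataset, val_indices, train=False)

x, y = train_subset[0]  # routed through dataset.__getitem__(idx, train=True)
```
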
@@ -62,10 +71,10 @@ def __getitem__(self, idx: int) -> np.ndarray:
 class BaseDataset(Dataset, metaclass=ABCMeta):
     def __init__(
         self,
-        train_tensors: BASE_DATASET_INPUT,
+        train_tensors: BaseDatasetType,
         dataset_name: Optional[str] = None,
-        val_tensors: Optional[BASE_DATASET_INPUT] = None,
-        test_tensors: Optional[BASE_DATASET_INPUT] = None,
+        val_tensors: Optional[BaseDatasetType] = None,
+        test_tensors: Optional[BaseDatasetType] = None,
         resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
         resampling_strategy_args: Optional[Dict[str, Any]] = None,
         shuffle: Optional[bool] = True,
@@ -97,18 +106,15 @@ def __init__(
             val_transforms (Optional[torchvision.transforms.Compose]):
                 Additional Transforms to be applied to the validation/test data
         """
-        if dataset_name is not None:
-            self.dataset_name = dataset_name
-        else:
-            self.dataset_name = hash_array_or_matrix(train_tensors[0])
+        self.dataset_name = dataset_name if dataset_name is not None \
+            else hash_array_or_matrix(train_tensors[0])
+
         if not hasattr(train_tensors[0], 'shape'):
             type_check(train_tensors, val_tensors)
-        self.train_tensors = train_tensors
-        self.val_tensors = val_tensors
-        self.test_tensors = test_tensors
+        self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors
         self.cross_validators: Dict[str, CROSS_VAL_FN] = {}
         self.holdout_validators: Dict[str, HOLDOUT_FN] = {}
-        self.rand = np.random.RandomState(seed=seed)
+        self.rng = np.random.RandomState(seed=seed)
         self.shuffle = shuffle
         self.resampling_strategy = resampling_strategy
         self.resampling_strategy_args = resampling_strategy_args
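
A minimal construction sketch (toy data; `MyDataset` is a hypothetical concrete subclass, since `BaseDataset` is abstract) showing the name fallback and the seeded `RandomState`:

```python
import numpy as np

X, y = np.random.rand(100, 4), np.random.randint(2, size=100)

# MyDataset is hypothetical; BaseDataset itself cannot be instantiated.
dataset = MyDataset(train_tensors=(X, y), seed=42)

# With no explicit dataset_name, the data is identified by a hash of X.
assert dataset.dataset_name == hash_array_or_matrix(X)

# The seeded RandomState makes shuffled index generation reproducible:
# two generators built from the same seed emit identical permutations.
assert np.array_equal(
    np.random.RandomState(seed=42).permutation(100),
    np.random.RandomState(seed=42).permutation(100),
)
```
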
@@ -128,16 +134,8 @@ def __init__(
         self.is_small_preprocess = True
 
         # Make sure cross validation splits are created once
-        self.cross_validators = get_cross_validators(
-            CrossValTypes.stratified_k_fold_cross_validation,
-            CrossValTypes.k_fold_cross_validation,
-            CrossValTypes.shuffle_split_cross_validation,
-            CrossValTypes.stratified_shuffle_split_cross_validation
-        )
-        self.holdout_validators = get_holdout_validators(
-            HoldoutValTypes.holdout_validation,
-            HoldoutValTypes.stratified_holdout_validation
-        )
+        self.cross_validators = get_cross_validators(*CrossValTypes)
+        self.holdout_validators = get_holdout_validators(*HoldoutValTypes)
         self.splits = self.get_splits_from_resampling_strategy()
 
         # We also need to be able to transform the data, be it for pre-processing
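
The `*CrossValTypes` form relies on the fact that enum classes are iterable, so unpacking one passes every member positionally, and new members are registered automatically. A toy illustration of the idiom (names here are illustrative, not the real validators):

```python
from enum import Enum

class Color(Enum):
    RED = 1
    GREEN = 2
    BLUE = 3

def register(*members: Color) -> None:
    # Receives every enum member as a positional argument.
    for m in members:
        print(m.name)

# Equivalent to register(Color.RED, Color.GREEN, Color.BLUE):
register(*Color)
```
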
@@ -146,19 +144,19 @@ def __init__(
         self.val_transform = val_transforms
 
     def update_transform(self, transform: Optional[torchvision.transforms.Compose],
-                         train: bool = True,
-                         ) -> 'BaseDataset':
+                         train: bool = True) -> 'BaseDataset':
         """
         During the pipeline execution, the pipeline object might propose transformations
         as a product of the current pipeline configuration being tested.
 
-        This utility allows to return a self with the updated transformation, so that
+        This utility allows returning self with the updated transformation, so that
         a dataloader can yield this dataset with the desired transformations
 
         Args:
-            transform (torchvision.transforms.Compose): The transformations proposed
-                by the current pipeline
-            train (bool): Whether to update the train or validation transform
+            transform (torchvision.transforms.Compose):
+                The transformations proposed by the current pipeline
+            train (bool):
+                Whether to update the train or validation transform
 
         Returns:
             self: A copy of the update pipeline
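
A usage sketch of the method, with illustrative torchvision transforms standing in for whatever the pipeline proposes:

```python
import torchvision.transforms as transforms

# Illustrative pipeline-proposed transformations.
train_compose = transforms.Compose([transforms.RandomHorizontalFlip(),
                                    transforms.ToTensor()])
val_compose = transforms.Compose([transforms.ToTensor()])

# Returns the dataset with the corresponding transform updated, so the
# calls can be chained before handing the dataset to a dataloader.
dataset = dataset.update_transform(train_compose, train=True)
dataset = dataset.update_transform(val_compose, train=False)
```
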
@@ -171,9 +169,9 @@ def update_transform(self, transform: Optional[torchvision.transforms.Compose],
 
     def __getitem__(self, index: int, train: bool = True) -> Tuple[np.ndarray, ...]:
         """
-        The base dataset uses a Subset of the data. Nevertheless, the base dataset expect
-        both validation and test data to be present in the same dataset, which motivated the
-        need to dynamically give train/test data with the __getitem__ command.
+        The base dataset uses a Subset of the data. Nevertheless, the base dataset expects
+        both validation and test data to be present in the same dataset, which motivates
+        the need to dynamically give train/test data with the __getitem__ command.
 
         This method yields a datapoint of the whole data (after a Subset has selected a given
         item, based on the resampling strategy) and applies a train/testing transformation, if any.
@@ -186,34 +184,24 @@ def __getitem__(self, index: int, train: bool = True) -> Tuple[np.ndarray, ...]:
             A transformed single point prediction
         """
 
-        if hasattr(self.train_tensors[0], 'loc'):
-            X = self.train_tensors[0].iloc[[index]]
-        else:
-            X = self.train_tensors[0][index]
+        X = self.train_tensors[0].iloc[[index]] if hasattr(self.train_tensors[0], 'loc') \
+            else self.train_tensors[0][index]
 
         if self.train_transform is not None and train:
             X = self.train_transform(X)
         elif self.val_transform is not None and not train:
             X = self.val_transform(X)
 
         # In case of prediction, the targets are not provided
-        Y = self.train_tensors[1]
-        if Y is not None:
-            Y = Y[index]
-        else:
-            Y = None
+        Y = self.train_tensors[1][index] if self.train_tensors[1] is not None else None
 
         return X, Y
 
     def __len__(self) -> int:
         return self.train_tensors[0].shape[0]
 
     def _get_indices(self) -> np.ndarray:
-        if self.shuffle:
-            indices = self.rand.permutation(len(self))
-        else:
-            indices = np.arange(len(self))
-        return indices
+        return self.rng.permutation(len(self)) if self.shuffle else np.arange(len(self))
 
     def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]]]:
         """
@@ -333,7 +321,7 @@ def get_dataset_for_training(self, split_id: int) -> Tuple[Dataset, Dataset]:
         return (TransformSubset(self, self.splits[split_id][0], train=True),
                 TransformSubset(self, self.splits[split_id][1], train=False))
 
-    def replace_data(self, X_train: BASE_DATASET_INPUT, X_test: Optional[BASE_DATASET_INPUT]) -> 'BaseDataset':
+    def replace_data(self, X_train: BaseDatasetType, X_test: Optional[BaseDatasetType]) -> 'BaseDataset':
         """
         To speed up the training of small dataset, early pre-processing of the data
         can be made on the fly by the pipeline.
@@ -361,7 +349,8 @@ def get_dataset_properties(self, dataset_requirements: List[FitRequirement]) ->
         contain.
 
         Returns:
-
+            dataset_properties (Dict[str, Any]):
+                Dict of the dataset properties.
         """
         dataset_properties = dict()
         for dataset_requirement in dataset_requirements:
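
The loop body is cut off by the hunk; a plausible, hedged completion is sketched below, assuming each `FitRequirement` exposes a `name` attribute that maps to an attribute on the dataset:

```python
# Hedged sketch, not the verbatim continuation: copy each required
# attribute off the dataset by name into the properties dict.
dataset_properties = dict()
for dataset_requirement in dataset_requirements:
    dataset_properties[dataset_requirement.name] = getattr(self, dataset_requirement.name)
return dataset_properties
```
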