1818from torch .utils .data .sampler import SubsetRandomSampler
1919from torchvision import datasets , models , transforms
2020
21+ from autoPyTorch .utils .transforms import transform_int64
22+
23+
2124class CreateImageDataLoader (CreateDataLoader ):
2225
2326 def fit (self , pipeline_config , hyperparameter_config , X , Y , train_indices , valid_indices , train_transform , valid_transform , dataset_info ):
@@ -27,18 +30,19 @@ def fit(self, pipeline_config, hyperparameter_config, X, Y, train_indices, valid
2730
2831 torch .manual_seed (pipeline_config ["random_seed" ])
2932 hyperparameter_config = ConfigWrapper (self .get_name (), hyperparameter_config )
33+ to_int64 = transform_int64
3034
3135 if dataset_info .default_dataset :
3236 train_dataset = dataset_info .default_dataset (root = pipeline_config ['default_dataset_download_dir' ], train = True , download = True , transform = train_transform )
3337 if valid_indices is not None :
3438 valid_dataset = dataset_info .default_dataset (root = pipeline_config ['default_dataset_download_dir' ], train = True , download = True , transform = valid_transform )
3539 elif len (X .shape ) > 1 :
36- train_dataset = XYDataset (X , Y , transform = train_transform , target_transform = lambda y : y . astype ( np . int64 ) )
37- valid_dataset = XYDataset (X , Y , transform = valid_transform , target_transform = lambda y : y . astype ( np . int64 ) )
40+ train_dataset = XYDataset (X , Y , transform = train_transform , target_transform = to_int64 )
41+ valid_dataset = XYDataset (X , Y , transform = valid_transform , target_transform = to_int64 )
3842 else :
39- train_dataset = ImageFilelist (X , Y , transform = train_transform , target_transform = lambda y : y . astype ( np . int64 ) , cache_size = pipeline_config ['dataloader_cache_size_mb' ] * 1000 , image_size = dataset_info .x_shape [2 :])
43+ train_dataset = ImageFilelist (X , Y , transform = train_transform , target_transform = to_int64 , cache_size = pipeline_config ['dataloader_cache_size_mb' ] * 1000 , image_size = dataset_info .x_shape [2 :])
4044 if valid_indices is not None :
41- valid_dataset = ImageFilelist (X , Y , transform = valid_transform , target_transform = lambda y : y . astype ( np . int64 ) , cache_size = 0 , image_size = dataset_info .x_shape [2 :])
45+ valid_dataset = ImageFilelist (X , Y , transform = valid_transform , target_transform = to_int64 , cache_size = 0 , image_size = dataset_info .x_shape [2 :])
4246 valid_dataset .cache = train_dataset .cache
4347
4448 train_loader = DataLoader (
0 commit comments