Skip to content

Commit dd80bec

Browse files
committed
Add new transform-utils for common type transforms and apply to CreateImageDataLoader
1 parent 48771ec commit dd80bec

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

autoPyTorch/pipeline/nodes/image/create_image_dataloader.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
from torch.utils.data.sampler import SubsetRandomSampler
1919
from torchvision import datasets, models, transforms
2020

21+
from autoPyTorch.utils.transforms import transform_int64
22+
23+
2124
class CreateImageDataLoader(CreateDataLoader):
2225

2326
def fit(self, pipeline_config, hyperparameter_config, X, Y, train_indices, valid_indices, train_transform, valid_transform, dataset_info):
@@ -27,18 +30,19 @@ def fit(self, pipeline_config, hyperparameter_config, X, Y, train_indices, valid
2730

2831
torch.manual_seed(pipeline_config["random_seed"])
2932
hyperparameter_config = ConfigWrapper(self.get_name(), hyperparameter_config)
33+
to_int64 = transform_int64
3034

3135
if dataset_info.default_dataset:
3236
train_dataset = dataset_info.default_dataset(root=pipeline_config['default_dataset_download_dir'], train=True, download=True, transform=train_transform)
3337
if valid_indices is not None:
3438
valid_dataset = dataset_info.default_dataset(root=pipeline_config['default_dataset_download_dir'], train=True, download=True, transform=valid_transform)
3539
elif len(X.shape) > 1:
36-
train_dataset = XYDataset(X, Y, transform=train_transform, target_transform=lambda y: y.astype(np.int64))
37-
valid_dataset = XYDataset(X, Y, transform=valid_transform, target_transform=lambda y: y.astype(np.int64))
40+
train_dataset = XYDataset(X, Y, transform=train_transform, target_transform=to_int64)
41+
valid_dataset = XYDataset(X, Y, transform=valid_transform, target_transform=to_int64)
3842
else:
39-
train_dataset = ImageFilelist(X, Y, transform=train_transform, target_transform=lambda y: y.astype(np.int64), cache_size=pipeline_config['dataloader_cache_size_mb'] * 1000, image_size=dataset_info.x_shape[2:])
43+
train_dataset = ImageFilelist(X, Y, transform=train_transform, target_transform=to_int64, cache_size=pipeline_config['dataloader_cache_size_mb'] * 1000, image_size=dataset_info.x_shape[2:])
4044
if valid_indices is not None:
41-
valid_dataset = ImageFilelist(X, Y, transform=valid_transform, target_transform=lambda y: y.astype(np.int64), cache_size=0, image_size=dataset_info.x_shape[2:])
45+
valid_dataset = ImageFilelist(X, Y, transform=valid_transform, target_transform=to_int64, cache_size=0, image_size=dataset_info.x_shape[2:])
4246
valid_dataset.cache = train_dataset.cache
4347

4448
train_loader = DataLoader(

autoPyTorch/utils/transforms.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
import numpy as np
2+
3+
4+
def transform_int64(y):
5+
return y.astype(np.int64)

0 commit comments

Comments
 (0)