Skip to content

Commit 3944bb9

Browse files
nabenabe0928ravinkohli
authored andcommitted
[refactor] Address Shuhei's comments
1 parent 57ca887 commit 3944bb9

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

autoPyTorch/pipeline/components/training/trainer/GridCutMixTrainer.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,15 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
2626
np.ndarray: that processes data
2727
typing.Dict[str, np.ndarray]: arguments to the criterion function
2828
"""
29-
beta = 1.0
30-
lam = self.random_state.beta(beta, beta)
31-
batch_size, channel, W, H = X.size()
32-
index = torch.randperm(batch_size).cuda() if X.is_cuda else torch.randperm(batch_size)
29+
alpha, beta = 1.0, 1.0
30+
lam = self.random_state.beta(alpha, beta)
31+
batch_size, _, W, H = X.shape
32+
device = torch.device('cuda' if X.is_cuda else 'cpu')
33+
batch_indices = torch.randperm(batch_size).to(device)
3334

3435
r = self.random_state.rand(1)
3536
if beta <= 0 or r > self.alpha:
36-
return X, {'y_a': y, 'y_b': y[index], 'lam': 1}
37+
return X, {'y_a': y, 'y_b': y[batch_indices], 'lam': 1}
3738

3839
# Draw parameters of a random bounding box
3940
# Where to cut basically
@@ -47,12 +48,13 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
4748
bbx2 = np.clip(cx + cut_w // 2, 0, W)
4849
bby2 = np.clip(cy + cut_h // 2, 0, H)
4950

50-
X[:, :, bbx1:bbx2, bby1:bby2] = X[index, :, bbx1:bbx2, bby1:bby2]
51+
X[:, :, bbx1:bbx2, bby1:bby2] = X[batch_indices, :, bbx1:bbx2, bby1:bby2]
5152

5253
# Adjust lam
53-
lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (X.size()[-1] * X.size()[-2]))
54+
pixel_size = W * H
55+
lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / pixel_size)
5456

55-
y_a, y_b = y, y[index]
57+
y_a, y_b = y, y[batch_indices]
5658

5759
return X, {'y_a': y_a, 'y_b': y_b, 'lam': lam}
5860

0 commit comments

Comments
 (0)