Skip to content

Commit 57ca887

Browse files
nabenabe0928ravinkohli
authored andcommitted
[refactor] Address Shuhei's comments
1 parent 62eebb5 commit 57ca887

File tree

1 file changed

+17
-11
lines changed

1 file changed

+17
-11
lines changed

autoPyTorch/pipeline/components/training/trainer/RowCutMixTrainer.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,25 +26,31 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
2626
np.ndarray: that processes data
2727
typing.Dict[str, np.ndarray]: arguments to the criterion function
2828
"""
29-
beta = 1.0
30-
lam = self.random_state.beta(beta, beta)
31-
batch_size = X.size()[0]
32-
index = torch.randperm(batch_size).cuda() if X.is_cuda else torch.randperm(batch_size)
29+
alpha, beta = 1.0, 1.0
30+
lam = self.random_state.beta(alpha, beta)
31+
batch_size = X.shape[0]
32+
device = torch.device('cuda' if X.is_cuda else 'cpu')
33+
batch_indices = torch.randperm(batch_size).to(device)
3334

3435
r = self.random_state.rand(1)
3536
if beta <= 0 or r > self.alpha:
36-
return X, {'y_a': y, 'y_b': y[index], 'lam': 1}
37+
return X, {'y_a': y, 'y_b': y[batch_indices], 'lam': 1}
3738

38-
size = X.shape[1]
39-
indices = torch.tensor(self.random_state.choice(range(1, size), max(1, np.int32(size * lam)),
40-
replace=False))
39+
row_size = X.shape[1]
40+
row_indices = torch.tensor(
41+
self.random_state.choice(
42+
range(1, row_size),
43+
max(1, int(row_size * lam)),
44+
replace=False
45+
)
46+
)
4147

42-
X[:, indices] = X[index, :][:, indices]
48+
X[:, row_indices] = X[batch_indices, :][:, row_indices]
4349

4450
# Adjust lam
45-
lam = 1 - ((len(indices)) / (X.size()[1]))
51+
lam = 1 - len(row_indices) / X.shape[1]
4652

47-
y_a, y_b = y, y[index]
53+
y_a, y_b = y, y[batch_indices]
4854

4955
return X, {'y_a': y_a, 'y_b': y_b, 'lam': lam}
5056

0 commit comments

Comments
 (0)