Skip to content

Commit 62eebb5

Browse files
nabenabe0928ravinkohli
authored and committed
[refactor] Address Shuhei's comment
1 parent 4e6c8be commit 62eebb5

File tree

1 file changed

+10
-13
lines changed

1 file changed

+10
-13
lines changed

autoPyTorch/pipeline/components/training/trainer/RowCutOutTrainer.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99

1010

1111
class RowCutOutTrainer(CutOut, BaseTrainerComponent):
12+
# 0 is non-informative in image data
1213
NUMERICAL_VALUE = 0
14+
# -1 is conceptually equivalent to 0 in an image, i.e. 0-pad
1315
CATEGORICAL_VALUE = -1
1416

1517
def data_preparation(self, X: np.ndarray, y: np.ndarray,
@@ -36,23 +38,18 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
3638
lam = 1
3739
return X, {'y_a': y_a, 'y_b': y_b, 'lam': lam}
3840

39-
size = X.shape[1]
40-
indices = self.random_state.choice(range(1, size), max(1, np.int32(size * self.patch_ratio)),
41-
replace=False)
41+
row_size = X.shape[1]
42+
row_indices = self.random_state.choice(range(1, row_size), max(1, int(row_size * self.patch_ratio)),
43+
replace=False)
4244

4345
if not isinstance(self.numerical_columns, typing.Iterable):
44-
raise ValueError("{} requires numerical columns information of {}"
45-
"to prepare data got {}.".format(self.__class__.__name__,
46-
typing.Iterable,
47-
self.numerical_columns))
46+
raise ValueError("numerical_columns in {} must be iterable, "
47+
"but got {}.".format(self.__class__.__name__,
48+
self.numerical_columns))
49+
4850
numerical_indices = torch.tensor(self.numerical_columns)
49-
categorical_indices = torch.tensor([index for index in indices if index not in self.numerical_columns])
51+
categorical_indices = torch.tensor([idx for idx in row_indices if idx not in self.numerical_columns])
5052

51-
# We use an ordinal encoder on the categorical columns of tabular data
52-
# -1 is the conceptual equivalent to 0 in an image, which does not
53-
# have color as a feature and hence the network has to learn to deal
54-
# without this data. For numerical columns we use 0 to cutout the features
55-
# similar to the effect that setting 0 as a pixel value in an image.
5653
X[:, categorical_indices.long()] = self.CATEGORICAL_VALUE
5754
X[:, numerical_indices.long()] = self.NUMERICAL_VALUE
5855

0 commit comments

Comments
 (0)