
Commit 4b28a1c

nabenabe0928 authored and ravinkohli committed
[refactor] Address Shuhei's comments
[fix] Fix Flake8 issues
[refactor] Address Shuhei's comment
[refactor] Address Shuhei's comments
[refactor] Address Shuhei's comments
[refactor] Address Shuhei's comments
1 parent 23a3b6c commit 4b28a1c

11 files changed: +116 -111 lines changed


autoPyTorch/evaluation/tae.py

Lines changed: 1 addition & 2 deletions
@@ -186,8 +186,7 @@ def __init__(
         else:
             raise ValueError("resampling strategy must be in "
                              "(HoldoutValTypes, CrossValTypes, NoResamplingStrategyTypes), "
-                             "but got {}.".format(self.resampling_strategy)
-                             )
+                             "but got {}.".format(self.resampling_strategy))
 
         self.worst_possible_result = cost_for_crash

autoPyTorch/evaluation/train_evaluator.py

Lines changed: 0 additions & 1 deletion
@@ -156,7 +156,6 @@ def __init__(self, backend: Backend, queue: Queue,
                 'resampling_strategy, but got {}'.format(self.datamanager.resampling_strategy)
             )
 
-
         self.splits = self.datamanager.splits
         if self.splits is None:
             raise AttributeError("Must have called create_splits on {}".format(self.datamanager.__class__.__name__))

autoPyTorch/pipeline/components/training/trainer/AdversarialTrainer.py

Lines changed: 13 additions & 10 deletions
@@ -37,7 +37,11 @@ def __init__(
 
         Args:
             epsilon (float): The perturbation magnitude.
-
+
+        References:
+            Explaining and Harnessing Adversarial Examples
+            Ian J. Goodfellow et al.
+            https://arxiv.org/pdf/1412.6572.pdf
         """
         super().__init__(random_state=random_state,
                          weighted_loss=weighted_loss,
@@ -96,10 +100,10 @@ def train_step(self, data: np.ndarray, targets: np.ndarray) -> Tuple[float, torch.Tensor]:
         # training
         self.optimizer.zero_grad()
         original_outputs = self.model(original_data)
-        adversarial_output = self.model(adversarial_data)
+        adversarial_outputs = self.model(adversarial_data)
 
         loss_func = self.criterion_preparation(**criterion_kwargs)
-        loss = loss_func(self.criterion, original_outputs, adversarial_output)
+        loss = loss_func(self.criterion, original_outputs, adversarial_outputs)
         loss.backward()
         self.optimizer.step()
         if self.scheduler:
@@ -125,6 +129,9 @@ def fgsm_attack(
 
         Returns:
             adv_data (np.ndarray): the adversarial examples.
+
+        References:
+            https://pytorch.org/tutorials/beginner/fgsm_tutorial.html#fgsm-attack
         """
         data_copy = deepcopy(data)
         data_copy = data_copy.float().to(self.device)
@@ -159,7 +166,7 @@ def get_hyperparameter_search_space(
         dataset_properties: Optional[Dict] = None,
         weighted_loss: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="weighted_loss",
-            value_range=[True, False],
+            value_range=(True, False),
             default_value=True),
         la_steps: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="la_steps",
@@ -196,9 +203,7 @@
 
         add_hyperparameter(cs, epsilon, UniformFloatHyperparameter)
         add_hyperparameter(cs, use_stochastic_weight_averaging, CategoricalHyperparameter)
-        snapshot_ensemble_flag = False
-        if any(use_snapshot_ensemble.value_range):
-            snapshot_ensemble_flag = True
+        snapshot_ensemble_flag = any(use_snapshot_ensemble.value_range)
 
         use_snapshot_ensemble = get_hyperparameter(use_snapshot_ensemble, CategoricalHyperparameter)
         cs.add_hyperparameter(use_snapshot_ensemble)
@@ -209,9 +214,7 @@
         cond = EqualsCondition(se_lastk, use_snapshot_ensemble, True)
         cs.add_condition(cond)
 
-        lookahead_flag = False
-        if any(use_lookahead_optimizer.value_range):
-            lookahead_flag = True
+        lookahead_flag = any(use_lookahead_optimizer.value_range)
 
         use_lookahead_optimizer = get_hyperparameter(use_lookahead_optimizer, CategoricalHyperparameter)
         cs.add_hyperparameter(use_lookahead_optimizer)

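For context on the `fgsm_attack` docstring reference added above: the fast gradient sign method perturbs each input one `epsilon`-sized step along the sign of the input gradient, the direction that locally increases the loss. A minimal sketch of that idea, not the repository's exact implementation (`model`, `criterion`, and `epsilon` are stand-ins):

    import torch

    def fgsm_attack(data: torch.Tensor, targets: torch.Tensor,
                    model: torch.nn.Module, criterion: torch.nn.Module,
                    epsilon: float) -> torch.Tensor:
        # Sketch of FGSM (Goodfellow et al.): x_adv = x + epsilon * sign(grad_x loss)
        data = data.clone().detach().requires_grad_(True)
        loss = criterion(model(data), targets)
        loss.backward()
        # data.grad.sign() points in the direction that increases the loss
        return (data + epsilon * data.grad.sign()).detach()

`train_step` then forwards both the clean and the adversarial batch and combines the two losses via `criterion_preparation`, which is why the rename to `adversarial_outputs` keeps the two output tensors symmetric.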
autoPyTorch/pipeline/components/training/trainer/GridCutMixTrainer.py

Lines changed: 10 additions & 8 deletions
@@ -26,14 +26,15 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
             np.ndarray: that processes data
             typing.Dict[str, np.ndarray]: arguments to the criterion function
         """
-        beta = 1.0
-        lam = self.random_state.beta(beta, beta)
-        batch_size, channel, W, H = X.size()
-        index = torch.randperm(batch_size).cuda() if X.is_cuda else torch.randperm(batch_size)
+        alpha, beta = 1.0, 1.0
+        lam = self.random_state.beta(alpha, beta)
+        batch_size, _, W, H = X.shape
+        device = torch.device('cuda' if X.is_cuda else 'cpu')
+        batch_indices = torch.randperm(batch_size).to(device)
 
         r = self.random_state.rand(1)
         if beta <= 0 or r > self.alpha:
-            return X, {'y_a': y, 'y_b': y[index], 'lam': 1}
+            return X, {'y_a': y, 'y_b': y[batch_indices], 'lam': 1}
 
         # Draw parameters of a random bounding box
         # Where to cut basically
@@ -47,12 +48,13 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
         bbx2 = np.clip(cx + cut_w // 2, 0, W)
         bby2 = np.clip(cy + cut_h // 2, 0, H)
 
-        X[:, :, bbx1:bbx2, bby1:bby2] = X[index, :, bbx1:bbx2, bby1:bby2]
+        X[:, :, bbx1:bbx2, bby1:bby2] = X[batch_indices, :, bbx1:bbx2, bby1:bby2]
 
         # Adjust lam
-        lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (X.size()[-1] * X.size()[-2]))
+        pixel_size = W * H
+        lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / pixel_size)
 
-        y_a, y_b = y, y[index]
+        y_a, y_b = y, y[batch_indices]
 
         return X, {'y_a': y_a, 'y_b': y_b, 'lam': lam}

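The bounding box that `data_preparation` cuts is the standard CutMix draw: sample a centre uniformly, size the box so its area is roughly `1 - lam` of the image, and clip to the image borders. A short sketch under those assumptions (`rand_bbox` is a hypothetical helper; `W`, `H` follow the `batch_size, _, W, H = X.shape` unpacking above):

    import numpy as np

    def rand_bbox(W: int, H: int, lam: float,
                  rng: np.random.RandomState) -> tuple:
        # Box side lengths chosen so the box covers ~(1 - lam) of the image
        cut_rat = np.sqrt(1.0 - lam)
        cut_w, cut_h = int(W * cut_rat), int(H * cut_rat)
        # Sample the box centre uniformly, then clip the corners to the image
        cx, cy = rng.randint(W), rng.randint(H)
        bbx1, bby1 = np.clip(cx - cut_w // 2, 0, W), np.clip(cy - cut_h // 2, 0, H)
        bbx2, bby2 = np.clip(cx + cut_w // 2, 0, W), np.clip(cy + cut_h // 2, 0, H)
        return bbx1, bby1, bbx2, bby2

Because clipping can shrink the box, `lam` is recomputed from the exact pasted area, which is what the new `pixel_size = W * H` line makes explicit: `lam = 1 - (bbx2 - bbx1) * (bby2 - bby1) / pixel_size`.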
autoPyTorch/pipeline/components/training/trainer/RowCutMixTrainer.py

Lines changed: 17 additions & 11 deletions
@@ -26,25 +26,31 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
             np.ndarray: that processes data
             typing.Dict[str, np.ndarray]: arguments to the criterion function
         """
-        beta = 1.0
-        lam = self.random_state.beta(beta, beta)
-        batch_size = X.size()[0]
-        index = torch.randperm(batch_size).cuda() if X.is_cuda else torch.randperm(batch_size)
+        alpha, beta = 1.0, 1.0
+        lam = self.random_state.beta(alpha, beta)
+        batch_size = X.shape[0]
+        device = torch.device('cuda' if X.is_cuda else 'cpu')
+        batch_indices = torch.randperm(batch_size).to(device)
 
         r = self.random_state.rand(1)
         if beta <= 0 or r > self.alpha:
-            return X, {'y_a': y, 'y_b': y[index], 'lam': 1}
+            return X, {'y_a': y, 'y_b': y[batch_indices], 'lam': 1}
 
-        size = X.shape[1]
-        indices = torch.tensor(self.random_state.choice(range(1, size), max(1, np.int32(size * lam)),
-                                                        replace=False))
+        row_size = X.shape[1]
+        row_indices = torch.tensor(
+            self.random_state.choice(
+                range(1, row_size),
+                max(1, int(row_size * lam)),
+                replace=False
+            )
+        )
 
-        X[:, indices] = X[index, :][:, indices]
+        X[:, row_indices] = X[batch_indices, :][:, row_indices]
 
         # Adjust lam
-        lam = 1 - ((len(indices)) / (X.size()[1]))
+        lam = 1 - len(row_indices) / X.shape[1]
 
-        y_a, y_b = y, y[index]
+        y_a, y_b = y, y[batch_indices]
 
         return X, {'y_a': y_a, 'y_b': y_b, 'lam': lam}

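RowCutMix applies the same mixing idea to tabular data: instead of pasting a pixel patch, it swaps a random subset of feature columns between each sample and a shuffled partner. A simplified sketch (the actual trainer draws `lam` from a Beta distribution and samples from `range(1, row_size)`; `row_cutmix` here is illustrative only):

    import numpy as np
    import torch

    def row_cutmix(X: torch.Tensor, y: torch.Tensor, lam: float,
                   rng: np.random.RandomState) -> tuple:
        # Pair every sample with a random partner from the same batch
        batch_indices = torch.randperm(X.shape[0])
        # Swap roughly lam of the feature columns with the partner's values
        n_cols = max(1, int(X.shape[1] * lam))
        cols = torch.as_tensor(rng.choice(X.shape[1], n_cols, replace=False))
        X[:, cols] = X[batch_indices][:, cols]
        # After the swap, lam becomes the fraction of untouched features
        lam = 1 - len(cols) / X.shape[1]
        return X, {'y_a': y, 'y_b': y[batch_indices], 'lam': lam}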
autoPyTorch/pipeline/components/training/trainer/RowCutOutTrainer.py

Lines changed: 10 additions & 13 deletions
@@ -9,7 +9,9 @@
 
 
 class RowCutOutTrainer(CutOut, BaseTrainerComponent):
+    # 0 is non-informative in image data
     NUMERICAL_VALUE = 0
+    # -1 is conceptually equivalent to 0 in an image, i.e. 0-pad
     CATEGORICAL_VALUE = -1
 
     def data_preparation(self, X: np.ndarray, y: np.ndarray,
@@ -36,23 +38,18 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
             lam = 1
             return X, {'y_a': y_a, 'y_b': y_b, 'lam': lam}
 
-        size = X.shape[1]
-        indices = self.random_state.choice(range(1, size), max(1, np.int32(size * self.patch_ratio)),
-                                           replace=False)
+        row_size = X.shape[1]
+        row_indices = self.random_state.choice(range(1, row_size), max(1, int(row_size * self.patch_ratio)),
+                                               replace=False)
 
         if not isinstance(self.numerical_columns, typing.Iterable):
-            raise ValueError("{} requires numerical columns information of {}"
-                             "to prepare data got {}.".format(self.__class__.__name__,
-                                                              typing.Iterable,
-                                                              self.numerical_columns))
+            raise ValueError("numerical_columns in {} must be iterable, "
+                             "but got {}.".format(self.__class__.__name__,
+                                                  self.numerical_columns))
+
         numerical_indices = torch.tensor(self.numerical_columns)
-        categorical_indices = torch.tensor([index for index in indices if index not in self.numerical_columns])
+        categorical_indices = torch.tensor([idx for idx in row_indices if idx not in self.numerical_columns])
 
-        # We use an ordinal encoder on the categorical columns of tabular data
-        # -1 is the conceptual equivalent to 0 in a image, that does not
-        # have color as a feature and hence the network has to learn to deal
-        # without this data. For numerical columns we use 0 to cutout the features
-        # similar to the effect that setting 0 as a pixel value in an image.
         X[:, categorical_indices.long()] = self.CATEGORICAL_VALUE
         X[:, numerical_indices.long()] = self.NUMERICAL_VALUE

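RowCutOut is the cutout counterpart: rather than mixing in another sample's features, it masks the selected columns outright, which is where the `NUMERICAL_VALUE = 0` and `CATEGORICAL_VALUE = -1` constants commented above come in (categorical columns are ordinal-encoded, so -1 is an otherwise unused code). A hedged sketch of that masking step (`row_cutout` is illustrative, not the component's API):

    import numpy as np
    import torch

    def row_cutout(X: torch.Tensor, numerical_columns: list,
                   patch_ratio: float, rng: np.random.RandomState) -> torch.Tensor:
        # Pick a random subset of columns to mask
        n_cols = max(1, int(X.shape[1] * patch_ratio))
        cols = rng.choice(X.shape[1], n_cols, replace=False)
        for col in cols:
            # 0 masks numerical features; -1 masks ordinal-encoded categoricals
            X[:, col] = 0 if int(col) in numerical_columns else -1
        return X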
autoPyTorch/pipeline/components/training/trainer/__init__.py

Lines changed: 10 additions & 11 deletions
@@ -423,12 +423,13 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoice':
 
         if self.choice.use_stochastic_weight_averaging and self.choice.swa_updated:
             # update batch norm statistics
-            swa_utils.update_bn(X['train_data_loader'], self.choice.swa_model.double())
+            swa_utils.update_bn(loader=X['train_data_loader'], model=self.choice.swa_model.double())
+
             # change model
             update_model_state_dict_from_swa(X['network'], self.choice.swa_model.state_dict())
         if self.choice.use_snapshot_ensemble:
             for model in self.choice.model_snapshots:
-                swa_utils.update_bn(X['train_data_loader'], model.double())
+                swa_utils.update_bn(loader=X['train_data_loader'], model=model.double())
 
         # wrap up -- add score if not evaluating every epoch
         if not self.eval_valid_each_epoch(X):
@@ -500,13 +501,10 @@ def early_stop_handler(self, X: Dict[str, Any]) -> bool:
         if self.checkpoint_dir is None:
             self.checkpoint_dir = tempfile.mkdtemp(dir=X['backend'].temporary_directory)
 
+        target_metrics = 'val_loss'
         if X['val_indices'] is None:
-            if X['X_test'] is not None:
-                epochs_since_best = self.run_summary.get_last_epoch() - self.run_summary.get_best_epoch('test_loss')
-            else:
-                epochs_since_best = self.run_summary.get_last_epoch() - self.run_summary.get_best_epoch('train_loss')
-        else:
-            epochs_since_best = self.run_summary.get_last_epoch() - self.run_summary.get_best_epoch()
+            target_metrics = 'test_loss' if X['X_test'] is not None else 'train_loss'
+        epochs_since_best = self.run_summary.get_last_epoch() - self.run_summary.get_best_epoch(target_metrics)
 
         # Save the checkpoint if there is a new best epoch
         best_path = os.path.join(self.checkpoint_dir, 'best.pth')
@@ -636,11 +634,12 @@ def __str__(self) -> str:
     def _get_search_space_updates(self, prefix: Optional[str] = None) -> Dict[str, HyperparameterSearchSpace]:
         """Get the search space updates with the given prefix
 
-        Keyword Arguments:
-            prefix {str} -- Only return search space updates with given prefix (default: {None})
+        Args:
+            prefix (Optional[str]): Only return search space updates with given prefix
 
         Returns:
-            dict -- Mapping of search space updates. Keys don't contain the prefix.
+            Dict[str, HyperparameterSearchSpace]:
+                Mapping of search space updates. Keys don't contain the prefix.
         """
         updates = super()._get_search_space_updates(prefix=prefix)

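The `early_stop_handler` hunk above collapses three nested branches into a single metric lookup. A standalone restatement of that selection logic (`pick_target_metric` is a hypothetical helper, not part of the codebase):

    from typing import Any, List, Optional

    def pick_target_metric(val_indices: Optional[List[int]], X_test: Optional[Any]) -> str:
        # Prefer validation loss; without a validation split, fall back to
        # test loss if test data exists, otherwise to train loss.
        if val_indices is not None:
            return 'val_loss'
        return 'test_loss' if X_test is not None else 'train_loss'

`epochs_since_best` is then always the same expression, `get_last_epoch() - get_best_epoch(target_metrics)`, regardless of which split is available.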
autoPyTorch/pipeline/components/training/trainer/base_trainer.py

Lines changed: 9 additions & 12 deletions
@@ -26,7 +26,7 @@
 from autoPyTorch.pipeline.components.training.base_training import autoPyTorchTrainingComponent
 from autoPyTorch.pipeline.components.training.metrics.metrics import CLASSIFICATION_METRICS, REGRESSION_METRICS
 from autoPyTorch.pipeline.components.training.metrics.utils import calculate_score
-from autoPyTorch.pipeline.components.training.trainer.utils import Lookahead, swa_average_function
+from autoPyTorch.pipeline.components.training.trainer.utils import Lookahead, swa_update
 from autoPyTorch.utils.common import FitRequirement, HyperparameterSearchSpace, add_hyperparameter, get_hyperparameter
 from autoPyTorch.utils.implementations import get_loss_weight_strategy
 
@@ -224,7 +224,7 @@ def __init__(self, weighted_loss: bool = False,
                  use_snapshot_ensemble: bool = True,
                  se_lastk: int = 3,
                  use_lookahead_optimizer: bool = True,
-                 random_state: Optional[Union[np.random.RandomState, int]] = None,
+                 random_state: Optional[np.random.RandomState] = None,
                  swa_model: Optional[torch.nn.Module] = None,
                  model_snapshots: Optional[List[torch.nn.Module]] = None,
                  **lookahead_config: Any) -> None:
@@ -285,13 +285,14 @@ def prepare(
 
         # in case we are using swa, maintain an averaged model,
         if self.use_stochastic_weight_averaging:
-            self.swa_model = swa_utils.AveragedModel(self.model, avg_fn=swa_average_function)
+            self.swa_model = swa_utils.AveragedModel(self.model, avg_fn=swa_update)
 
         # in case we are using se or swa, initialise budget_threshold to know when to start swa or se
         self._budget_threshold = 0
         if self.use_stochastic_weight_averaging or self.use_snapshot_ensemble:
-            assert budget_tracker.max_epochs is not None, "Can only use stochastic weight averaging or snapshot " \
-                                                          "ensemble when budget is epochs"
+            if budget_tracker.max_epochs is None:
+                raise ValueError("Budget for stochastic weight averaging or snapshot ensemble must be `epoch`.")
+
             self._budget_threshold = int(0.75 * budget_tracker.max_epochs)
 
         # in case we are using se, initialise list to store model snapshots
@@ -589,7 +590,7 @@ def get_hyperparameter_search_space(
         dataset_properties: Optional[Dict] = None,
         weighted_loss: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="weighted_loss",
-            value_range=[True, False],
+            value_range=(True, False),
             default_value=True),
         la_steps: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="la_steps",
@@ -621,9 +622,7 @@
         cs = ConfigurationSpace()
 
         add_hyperparameter(cs, use_stochastic_weight_averaging, CategoricalHyperparameter)
-        snapshot_ensemble_flag = False
-        if any(use_snapshot_ensemble.value_range):
-            snapshot_ensemble_flag = True
+        snapshot_ensemble_flag = any(use_snapshot_ensemble.value_range)
 
         use_snapshot_ensemble = get_hyperparameter(use_snapshot_ensemble, CategoricalHyperparameter)
         cs.add_hyperparameter(use_snapshot_ensemble)
@@ -634,9 +633,7 @@
         cond = EqualsCondition(se_lastk, use_snapshot_ensemble, True)
         cs.add_condition(cond)
 
-        lookahead_flag = False
-        if any(use_lookahead_optimizer.value_range):
-            lookahead_flag = True
+        lookahead_flag = any(use_lookahead_optimizer.value_range)
 
         use_lookahead_optimizer = get_hyperparameter(use_lookahead_optimizer, CategoricalHyperparameter)
         cs.add_hyperparameter(use_lookahead_optimizer)

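On the `swa_average_function` → `swa_update` rename: the function is passed as `avg_fn` to `torch.optim.swa_utils.AveragedModel`, whose contract is `avg_fn(averaged_parameter, current_parameter, num_averaged)`. Assuming it keeps the usual running mean of weights (a sketch, not necessarily the exact implementation in `trainer/utils.py`):

    import torch

    def swa_update(averaged_param: torch.Tensor, current_param: torch.Tensor,
                   num_averaged: torch.Tensor) -> torch.Tensor:
        # Incremental mean: avg_{n+1} = avg_n + (x_{n+1} - avg_n) / (n + 1)
        return averaged_param + (current_param - averaged_param) / (num_averaged + 1)

It is wired up in `prepare()` above as `swa_utils.AveragedModel(self.model, avg_fn=swa_update)`, and `swa_utils.update_bn` recomputes batch-norm statistics for the averaged model after training.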