
Commit b7a0897

nabenabe0928 authored and ravinkohli committed
[refactor] Address Shuhei's comments
[fix] Fix Flake8 issues
[refactor] Address Shuhei's comments
1 parent 107750a commit b7a0897

File tree

11 files changed: +116 additions, −111 deletions


autoPyTorch/evaluation/tae.py

Lines changed: 1 addition & 2 deletions
@@ -186,8 +186,7 @@ def __init__(
         else:
             raise ValueError("resampling strategy must be in "
                              "(HoldoutValTypes, CrossValTypes, NoResamplingStrategyTypes), "
-                             "but got {}.".format(self.resampling_strategy)
-                             )
+                             "but got {}.".format(self.resampling_strategy))
 
         self.worst_possible_result = cost_for_crash

autoPyTorch/evaluation/train_evaluator.py

Lines changed: 0 additions & 1 deletion
@@ -156,7 +156,6 @@ def __init__(self, backend: Backend, queue: Queue,
             'resampling_strategy, but got {}'.format(self.datamanager.resampling_strategy)
         )
 
-
         self.splits = self.datamanager.splits
         if self.splits is None:
             raise AttributeError("Must have called create_splits on {}".format(self.datamanager.__class__.__name__))

autoPyTorch/pipeline/components/training/trainer/AdversarialTrainer.py

Lines changed: 13 additions & 10 deletions
@@ -37,7 +37,11 @@ def __init__(
 
         Args:
             epsilon (float): The perturbation magnitude.
-
+
+        References:
+            Explaining and Harnessing Adversarial Examples
+            Ian J. Goodfellow et al.
+            https://arxiv.org/pdf/1412.6572.pdf
         """
         super().__init__(random_state=random_state,
                          weighted_loss=weighted_loss,
@@ -96,10 +100,10 @@ def train_step(self, data: np.ndarray, targets: np.ndarray) -> Tuple[float, torch.Tensor]:
         # training
         self.optimizer.zero_grad()
         original_outputs = self.model(original_data)
-        adversarial_output = self.model(adversarial_data)
+        adversarial_outputs = self.model(adversarial_data)
 
         loss_func = self.criterion_preparation(**criterion_kwargs)
-        loss = loss_func(self.criterion, original_outputs, adversarial_output)
+        loss = loss_func(self.criterion, original_outputs, adversarial_outputs)
         loss.backward()
         self.optimizer.step()
         if self.scheduler:
@@ -125,6 +129,9 @@ def fgsm_attack(
 
         Returns:
             adv_data (np.ndarray): the adversarial examples.
+
+        References:
+            https://pytorch.org/tutorials/beginner/fgsm_tutorial.html#fgsm-attack
         """
         data_copy = deepcopy(data)
         data_copy = data_copy.float().to(self.device)
@@ -159,7 +166,7 @@ def get_hyperparameter_search_space(
         dataset_properties: Optional[Dict] = None,
         weighted_loss: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="weighted_loss",
-            value_range=[True, False],
+            value_range=(True, False),
             default_value=True),
         la_steps: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="la_steps",
@@ -196,9 +203,7 @@ def get_hyperparameter_search_space(
 
         add_hyperparameter(cs, epsilon, UniformFloatHyperparameter)
         add_hyperparameter(cs, use_stochastic_weight_averaging, CategoricalHyperparameter)
-        snapshot_ensemble_flag = False
-        if any(use_snapshot_ensemble.value_range):
-            snapshot_ensemble_flag = True
+        snapshot_ensemble_flag = any(use_snapshot_ensemble.value_range)
 
         use_snapshot_ensemble = get_hyperparameter(use_snapshot_ensemble, CategoricalHyperparameter)
         cs.add_hyperparameter(use_snapshot_ensemble)
@@ -209,9 +214,7 @@ def get_hyperparameter_search_space(
         cond = EqualsCondition(se_lastk, use_snapshot_ensemble, True)
         cs.add_condition(cond)
 
-        lookahead_flag = False
-        if any(use_lookahead_optimizer.value_range):
-            lookahead_flag = True
+        lookahead_flag = any(use_lookahead_optimizer.value_range)
 
         use_lookahead_optimizer = get_hyperparameter(use_lookahead_optimizer, CategoricalHyperparameter)
         cs.add_hyperparameter(use_lookahead_optimizer)
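
For reference, the fgsm_attack method documented above follows the fast gradient sign method from the linked PyTorch tutorial. A minimal standalone sketch of the technique, with an illustrative signature rather than the trainer's actual one:

import torch

def fgsm_attack(data: torch.Tensor, epsilon: float, gradient: torch.Tensor) -> torch.Tensor:
    # Step in the direction that increases the loss, scaled by epsilon.
    perturbed = data + epsilon * gradient.sign()
    # Clamp back into the valid input range, as in the tutorial.
    return torch.clamp(perturbed, 0, 1)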

autoPyTorch/pipeline/components/training/trainer/GridCutMixTrainer.py

Lines changed: 10 additions & 8 deletions
@@ -26,14 +26,15 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
             np.ndarray: that processes data
             typing.Dict[str, np.ndarray]: arguments to the criterion function
         """
-        beta = 1.0
-        lam = self.random_state.beta(beta, beta)
-        batch_size, channel, W, H = X.size()
-        index = torch.randperm(batch_size).cuda() if X.is_cuda else torch.randperm(batch_size)
+        alpha, beta = 1.0, 1.0
+        lam = self.random_state.beta(alpha, beta)
+        batch_size, _, W, H = X.shape
+        device = torch.device('cuda' if X.is_cuda else 'cpu')
+        batch_indices = torch.randperm(batch_size).to(device)
 
         r = self.random_state.rand(1)
         if beta <= 0 or r > self.alpha:
-            return X, {'y_a': y, 'y_b': y[index], 'lam': 1}
+            return X, {'y_a': y, 'y_b': y[batch_indices], 'lam': 1}
 
         # Draw parameters of a random bounding box
         # Where to cut basically
@@ -47,12 +48,13 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
         bbx2 = np.clip(cx + cut_w // 2, 0, W)
         bby2 = np.clip(cy + cut_h // 2, 0, H)
 
-        X[:, :, bbx1:bbx2, bby1:bby2] = X[index, :, bbx1:bbx2, bby1:bby2]
+        X[:, :, bbx1:bbx2, bby1:bby2] = X[batch_indices, :, bbx1:bbx2, bby1:bby2]
 
         # Adjust lam
-        lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (X.size()[-1] * X.size()[-2]))
+        pixel_size = W * H
+        lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / pixel_size)
 
-        y_a, y_b = y, y[index]
+        y_a, y_b = y, y[batch_indices]
 
         return X, {'y_a': y_a, 'y_b': y_b, 'lam': lam}
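
The {'y_a', 'y_b', 'lam'} dict returned here is passed as criterion kwargs (compare criterion_preparation in the AdversarialTrainer hunk above). In standard CutMix/MixUp training the loss is interpolated between the two label sets; a sketch under that assumption, with the illustrative name mixed_criterion:

import torch

def mixed_criterion(criterion: torch.nn.Module, outputs: torch.Tensor,
                    y_a: torch.Tensor, y_b: torch.Tensor, lam: float) -> torch.Tensor:
    # Weight each label set by the fraction of the input it still occupies.
    return lam * criterion(outputs, y_a) + (1 - lam) * criterion(outputs, y_b)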

autoPyTorch/pipeline/components/training/trainer/RowCutMixTrainer.py

Lines changed: 17 additions & 11 deletions
@@ -26,25 +26,31 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
             np.ndarray: that processes data
             typing.Dict[str, np.ndarray]: arguments to the criterion function
         """
-        beta = 1.0
-        lam = self.random_state.beta(beta, beta)
-        batch_size = X.size()[0]
-        index = torch.randperm(batch_size).cuda() if X.is_cuda else torch.randperm(batch_size)
+        alpha, beta = 1.0, 1.0
+        lam = self.random_state.beta(alpha, beta)
+        batch_size = X.shape[0]
+        device = torch.device('cuda' if X.is_cuda else 'cpu')
+        batch_indices = torch.randperm(batch_size).to(device)
 
         r = self.random_state.rand(1)
         if beta <= 0 or r > self.alpha:
-            return X, {'y_a': y, 'y_b': y[index], 'lam': 1}
+            return X, {'y_a': y, 'y_b': y[batch_indices], 'lam': 1}
 
-        size = X.shape[1]
-        indices = torch.tensor(self.random_state.choice(range(1, size), max(1, np.int32(size * lam)),
-                                                        replace=False))
+        row_size = X.shape[1]
+        row_indices = torch.tensor(
+            self.random_state.choice(
+                range(1, row_size),
+                max(1, int(row_size * lam)),
+                replace=False
+            )
+        )
 
-        X[:, indices] = X[index, :][:, indices]
+        X[:, row_indices] = X[batch_indices, :][:, row_indices]
 
         # Adjust lam
-        lam = 1 - ((len(indices)) / (X.size()[1]))
+        lam = 1 - len(row_indices) / X.shape[1]
 
-        y_a, y_b = y, y[index]
+        y_a, y_b = y, y[batch_indices]
 
         return X, {'y_a': y_a, 'y_b': y_b, 'lam': lam}
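
A standalone illustration of the row-wise CutMix logic above, on made-up data: a random subset of feature columns is swapped between shuffled rows, and lam is reset to the fraction of untouched features.

import numpy as np
import torch

rng = np.random.RandomState(0)
X = torch.arange(12, dtype=torch.float32).reshape(3, 4)  # 3 samples, 4 features
lam = rng.beta(1.0, 1.0)
batch_indices = torch.randperm(X.shape[0])
n_swap = max(1, int(X.shape[1] * lam))
swap_cols = torch.tensor(rng.choice(range(1, X.shape[1]), n_swap, replace=False))
X[:, swap_cols] = X[batch_indices, :][:, swap_cols]  # mix rows column-wise
lam = 1 - len(swap_cols) / X.shape[1]  # fraction of features left intact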

autoPyTorch/pipeline/components/training/trainer/RowCutOutTrainer.py

Lines changed: 10 additions & 13 deletions
@@ -9,7 +9,9 @@
 
 
 class RowCutOutTrainer(CutOut, BaseTrainerComponent):
+    # 0 is non-informative in image data
     NUMERICAL_VALUE = 0
+    # -1 is conceptually equivalent to 0 in an image, i.e. 0-pad
     CATEGORICAL_VALUE = -1
 
     def data_preparation(self, X: np.ndarray, y: np.ndarray,
@@ -36,23 +38,18 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
             lam = 1
             return X, {'y_a': y_a, 'y_b': y_b, 'lam': lam}
 
-        size = X.shape[1]
-        indices = self.random_state.choice(range(1, size), max(1, np.int32(size * self.patch_ratio)),
-                                           replace=False)
+        row_size = X.shape[1]
+        row_indices = self.random_state.choice(range(1, row_size), max(1, int(row_size * self.patch_ratio)),
+                                               replace=False)
 
         if not isinstance(self.numerical_columns, typing.Iterable):
-            raise ValueError("{} requires numerical columns information of {}"
-                             "to prepare data got {}.".format(self.__class__.__name__,
-                                                              typing.Iterable,
-                                                              self.numerical_columns))
+            raise ValueError("numerical_columns in {} must be iterable, "
+                             "but got {}.".format(self.__class__.__name__,
+                                                  self.numerical_columns))
+
         numerical_indices = torch.tensor(self.numerical_columns)
-        categorical_indices = torch.tensor([index for index in indices if index not in self.numerical_columns])
+        categorical_indices = torch.tensor([idx for idx in row_indices if idx not in self.numerical_columns])
 
-        # We use an ordinal encoder on the categorical columns of tabular data
-        # -1 is the conceptual equivalent to 0 in a image, that does not
-        # have color as a feature and hence the network has to learn to deal
-        # without this data. For numerical columns we use 0 to cutout the features
-        # similar to the effect that setting 0 as a pixel value in an image.
         X[:, categorical_indices.long()] = self.CATEGORICAL_VALUE
         X[:, numerical_indices.long()] = self.NUMERICAL_VALUE
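
To make the fill-value constants concrete, a toy example (illustrative tensors; the categorical column is assumed ordinally encoded, so -1 lies outside every valid category code, while 0 blanks a numerical feature much like a zeroed pixel):

import torch

X = torch.tensor([[2.0, 0.5, 1.0],
                  [1.0, 0.3, 2.0]])    # column 0 categorical, columns 1-2 numerical
categorical_indices = torch.tensor([0])
numerical_indices = torch.tensor([1, 2])
X[:, categorical_indices.long()] = -1  # CATEGORICAL_VALUE: no such category
X[:, numerical_indices.long()] = 0     # NUMERICAL_VALUE: feature cut out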

autoPyTorch/pipeline/components/training/trainer/__init__.py

Lines changed: 10 additions & 11 deletions
@@ -413,12 +413,13 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoice':
 
         if self.choice.use_stochastic_weight_averaging and self.choice.swa_updated:
             # update batch norm statistics
-            swa_utils.update_bn(X['train_data_loader'], self.choice.swa_model.double())
+            swa_utils.update_bn(loader=X['train_data_loader'], model=self.choice.swa_model.double())
+
             # change model
             update_model_state_dict_from_swa(X['network'], self.choice.swa_model.state_dict())
         if self.choice.use_snapshot_ensemble:
             for model in self.choice.model_snapshots:
-                swa_utils.update_bn(X['train_data_loader'], model.double())
+                swa_utils.update_bn(loader=X['train_data_loader'], model=model.double())
 
         # wrap up -- add score if not evaluating every epoch
         if not self.eval_valid_each_epoch(X):
@@ -490,13 +491,10 @@ def early_stop_handler(self, X: Dict[str, Any]) -> bool:
         if self.checkpoint_dir is None:
             self.checkpoint_dir = tempfile.mkdtemp(dir=X['backend'].temporary_directory)
 
+        target_metrics = 'val_loss'
         if X['val_indices'] is None:
-            if X['X_test'] is not None:
-                epochs_since_best = self.run_summary.get_last_epoch() - self.run_summary.get_best_epoch('test_loss')
-            else:
-                epochs_since_best = self.run_summary.get_last_epoch() - self.run_summary.get_best_epoch('train_loss')
-        else:
-            epochs_since_best = self.run_summary.get_last_epoch() - self.run_summary.get_best_epoch()
+            target_metrics = 'test_loss' if X['X_test'] is not None else 'train_loss'
+        epochs_since_best = self.run_summary.get_last_epoch() - self.run_summary.get_best_epoch(target_metrics)
 
         # Save the checkpoint if there is a new best epoch
         best_path = os.path.join(self.checkpoint_dir, 'best.pth')
@@ -626,11 +624,12 @@ def __str__(self) -> str:
     def _get_search_space_updates(self, prefix: Optional[str] = None) -> Dict[str, HyperparameterSearchSpace]:
         """Get the search space updates with the given prefix
 
-        Keyword Arguments:
-            prefix {str} -- Only return search space updates with given prefix (default: {None})
+        Args:
+            prefix (Optional[str]): Only return search space updates with given prefix
 
         Returns:
-            dict -- Mapping of search space updates. Keys don't contain the prefix.
+            Dict[str, HyperparameterSearchSpace]:
+                Mapping of search space updates. Keys don't contain the prefix.
         """
         updates = super()._get_search_space_updates(prefix=prefix)
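
For context on the keyword arguments introduced above: torch.optim.swa_utils.update_bn(loader, model) re-runs the data through the averaged model to recompute BatchNorm running statistics, which weight averaging alone does not maintain. A minimal sketch with a toy model and loader:

import torch
from torch.optim import swa_utils

model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.BatchNorm1d(8))
swa_model = swa_utils.AveragedModel(model)
loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(torch.randn(32, 4)), batch_size=8)
swa_model.update_parameters(model)                   # fold current weights into the running average
swa_utils.update_bn(loader=loader, model=swa_model)  # refresh BN statistics for the averaged weights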

autoPyTorch/pipeline/components/training/trainer/base_trainer.py

Lines changed: 9 additions & 12 deletions
@@ -26,7 +26,7 @@
 from autoPyTorch.pipeline.components.training.base_training import autoPyTorchTrainingComponent
 from autoPyTorch.pipeline.components.training.metrics.metrics import CLASSIFICATION_METRICS, REGRESSION_METRICS
 from autoPyTorch.pipeline.components.training.metrics.utils import calculate_score
-from autoPyTorch.pipeline.components.training.trainer.utils import Lookahead, swa_average_function
+from autoPyTorch.pipeline.components.training.trainer.utils import Lookahead, swa_update
 from autoPyTorch.utils.common import FitRequirement, HyperparameterSearchSpace, add_hyperparameter, get_hyperparameter
 from autoPyTorch.utils.implementations import get_loss_weight_strategy
 
@@ -214,7 +214,7 @@ def __init__(self, weighted_loss: bool = False,
                  use_snapshot_ensemble: bool = True,
                  se_lastk: int = 3,
                  use_lookahead_optimizer: bool = True,
-                 random_state: Optional[Union[np.random.RandomState, int]] = None,
+                 random_state: Optional[np.random.RandomState] = None,
                  swa_model: Optional[torch.nn.Module] = None,
                  model_snapshots: Optional[List[torch.nn.Module]] = None,
                  **lookahead_config: Any) -> None:
@@ -275,13 +275,14 @@ def prepare(
 
         # in case we are using swa, maintain an averaged model,
         if self.use_stochastic_weight_averaging:
-            self.swa_model = swa_utils.AveragedModel(self.model, avg_fn=swa_average_function)
+            self.swa_model = swa_utils.AveragedModel(self.model, avg_fn=swa_update)
 
         # in case we are using se or swa, initialise budget_threshold to know when to start swa or se
         self._budget_threshold = 0
         if self.use_stochastic_weight_averaging or self.use_snapshot_ensemble:
-            assert budget_tracker.max_epochs is not None, "Can only use stochastic weight averaging or snapshot " \
-                                                          "ensemble when budget is epochs"
+            if budget_tracker.max_epochs is None:
+                raise ValueError("Budget for stochastic weight averaging or snapshot ensemble must be `epoch`.")
+
             self._budget_threshold = int(0.75 * budget_tracker.max_epochs)
 
         # in case we are using se, initialise list to store model snapshots
@@ -576,7 +577,7 @@ def get_hyperparameter_search_space(
         dataset_properties: Optional[Dict] = None,
         weighted_loss: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="weighted_loss",
-            value_range=[True, False],
+            value_range=(True, False),
             default_value=True),
         la_steps: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="la_steps",
@@ -608,9 +609,7 @@ def get_hyperparameter_search_space(
         cs = ConfigurationSpace()
 
         add_hyperparameter(cs, use_stochastic_weight_averaging, CategoricalHyperparameter)
-        snapshot_ensemble_flag = False
-        if any(use_snapshot_ensemble.value_range):
-            snapshot_ensemble_flag = True
+        snapshot_ensemble_flag = any(use_snapshot_ensemble.value_range)
 
         use_snapshot_ensemble = get_hyperparameter(use_snapshot_ensemble, CategoricalHyperparameter)
         cs.add_hyperparameter(use_snapshot_ensemble)
@@ -621,9 +620,7 @@ def get_hyperparameter_search_space(
         cond = EqualsCondition(se_lastk, use_snapshot_ensemble, True)
         cs.add_condition(cond)
 
-        lookahead_flag = False
-        if any(use_lookahead_optimizer.value_range):
-            lookahead_flag = True
+        lookahead_flag = any(use_lookahead_optimizer.value_range)
 
         use_lookahead_optimizer = get_hyperparameter(use_lookahead_optimizer, CategoricalHyperparameter)
         cs.add_hyperparameter(use_lookahead_optimizer)
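
A note on the flag simplification repeated in these hunks: value_range holds the candidate values of the boolean hyperparameter, so any() enables the flag exactly when True is among them.

# any() over a boolean value_range:
assert any((True, False))  # True is a candidate -> flag enabled
assert not any((False,))   # only False allowed -> flag disabled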
