Commit df2cdb3

[refactor] Address Shuhei's comments

nabenabe0928 authored and ravinkohli committed
1 parent 25e85db

5 files changed: +65 −66 lines
autoPyTorch/pipeline/components/training/trainer/__init__.py

Lines changed: 10 additions & 11 deletions
@@ -413,12 +413,13 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoice':
 
         if self.choice.use_stochastic_weight_averaging and self.choice.swa_updated:
             # update batch norm statistics
-            swa_utils.update_bn(X['train_data_loader'], self.choice.swa_model.double())
+            swa_utils.update_bn(loader=X['train_data_loader'], model=self.choice.swa_model.double())
+
             # change model
             update_model_state_dict_from_swa(X['network'], self.choice.swa_model.state_dict())
         if self.choice.use_snapshot_ensemble:
             for model in self.choice.model_snapshots:
-                swa_utils.update_bn(X['train_data_loader'], model.double())
+                swa_utils.update_bn(loader=X['train_data_loader'], model=model.double())
 
         # wrap up -- add score if not evaluating every epoch
         if not self.eval_valid_each_epoch(X):
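
Passing loader= and model= by keyword guards against swapping the two positional arguments. For reference, a minimal sketch of how torch.optim.swa_utils.update_bn refreshes BatchNorm running statistics (the tiny model and loader are stand-ins, not the commit's objects):

import torch
from torch.optim import swa_utils

# stand-ins for the commit's network and X['train_data_loader']
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.BatchNorm1d(8))
dataset = torch.utils.data.TensorDataset(torch.randn(32, 8))
loader = torch.utils.data.DataLoader(dataset, batch_size=8)

# one pass over the loader re-estimates the BatchNorm running statistics;
# keyword arguments make the (loader, model) order explicit at the call site
swa_utils.update_bn(loader=loader, model=model)
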
@@ -490,13 +491,10 @@ def early_stop_handler(self, X: Dict[str, Any]) -> bool:
         if self.checkpoint_dir is None:
             self.checkpoint_dir = tempfile.mkdtemp(dir=X['backend'].temporary_directory)
 
+        target_metrics = 'val_loss'
         if X['val_indices'] is None:
-            if X['X_test'] is not None:
-                epochs_since_best = self.run_summary.get_last_epoch() - self.run_summary.get_best_epoch('test_loss')
-            else:
-                epochs_since_best = self.run_summary.get_last_epoch() - self.run_summary.get_best_epoch('train_loss')
-        else:
-            epochs_since_best = self.run_summary.get_last_epoch() - self.run_summary.get_best_epoch()
+            target_metrics = 'test_loss' if X['X_test'] is not None else 'train_loss'
+        epochs_since_best = self.run_summary.get_last_epoch() - self.run_summary.get_best_epoch(target_metrics)
 
         # Save the checkpoint if there is a new best epoch
         best_path = os.path.join(self.checkpoint_dir, 'best.pth')
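
The refactor collapses three near-duplicate branches into one metric-selection step followed by a single subtraction. A standalone sketch of the same selection logic (the dict-based bookkeeping is illustrative; the project uses run_summary.get_best_epoch):

from typing import Dict

def epochs_since_best(last_epoch: int, best_epoch: Dict[str, int],
                      has_val_indices: bool, has_test_data: bool) -> int:
    # fall back from validation loss to test loss, then train loss,
    # mirroring the refactored early_stop_handler
    target_metrics = 'val_loss'
    if not has_val_indices:
        target_metrics = 'test_loss' if has_test_data else 'train_loss'
    return last_epoch - best_epoch[target_metrics]

assert epochs_since_best(10, {'val_loss': 7, 'test_loss': 5, 'train_loss': 2}, True, True) == 3
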
@@ -626,11 +624,12 @@ def __str__(self) -> str:
     def _get_search_space_updates(self, prefix: Optional[str] = None) -> Dict[str, HyperparameterSearchSpace]:
         """Get the search space updates with the given prefix
 
-        Keyword Arguments:
-            prefix {str} -- Only return search space updates with given prefix (default: {None})
+        Args:
+            prefix (Optional[str]): Only return search space updates with given prefix
 
         Returns:
-            dict -- Mapping of search space updates. Keys don't contain the prefix.
+            Dict[str, HyperparameterSearchSpace]:
+                Mapping of search space updates. Keys don't contain the prefix.
         """
         updates = super()._get_search_space_updates(prefix=prefix)

autoPyTorch/pipeline/components/training/trainer/base_trainer.py

Lines changed: 9 additions & 12 deletions
@@ -28,7 +28,7 @@
 from autoPyTorch.pipeline.components.training.metrics.metrics import CLASSIFICATION_METRICS, REGRESSION_METRICS
 from autoPyTorch.pipeline.components.training.trainer.utils import Lookahead
 from autoPyTorch.pipeline.components.training.metrics.utils import calculate_score
-from autoPyTorch.pipeline.components.training.trainer.utils import Lookahead, swa_average_function
+from autoPyTorch.pipeline.components.training.trainer.utils import Lookahead, swa_update
 from autoPyTorch.utils.common import FitRequirement, HyperparameterSearchSpace, add_hyperparameter, get_hyperparameter
 from autoPyTorch.utils.implementations import get_loss_weight_strategy

@@ -216,7 +216,7 @@ def __init__(self, weighted_loss: bool = False,
                  use_snapshot_ensemble: bool = True,
                  se_lastk: int = 3,
                  use_lookahead_optimizer: bool = True,
-                 random_state: Optional[Union[np.random.RandomState, int]] = None,
+                 random_state: Optional[np.random.RandomState] = None,
                  swa_model: Optional[torch.nn.Module] = None,
                  model_snapshots: Optional[List[torch.nn.Module]] = None,
                  **lookahead_config: Any) -> None:
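
Narrowing the annotation from Optional[Union[np.random.RandomState, int]] to Optional[np.random.RandomState] pushes the int-to-RandomState conversion to the caller. One common way to do that conversion, assuming scikit-learn (a core autoPyTorch dependency) is used for it:

import numpy as np
from sklearn.utils import check_random_state

# check_random_state turns an int seed (or None) into a RandomState instance,
# so callers can satisfy the narrowed Optional[np.random.RandomState] annotation
random_state = check_random_state(1)
assert isinstance(random_state, np.random.RandomState)
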
@@ -277,13 +277,14 @@ def prepare(
 
         # in case we are using swa, maintain an averaged model,
         if self.use_stochastic_weight_averaging:
-            self.swa_model = swa_utils.AveragedModel(self.model, avg_fn=swa_average_function)
+            self.swa_model = swa_utils.AveragedModel(self.model, avg_fn=swa_update)
 
         # in case we are using se or swa, initialise budget_threshold to know when to start swa or se
         self._budget_threshold = 0
         if self.use_stochastic_weight_averaging or self.use_snapshot_ensemble:
-            assert budget_tracker.max_epochs is not None, "Can only use stochastic weight averaging or snapshot " \
-                                                          "ensemble when budget is epochs"
+            if budget_tracker.max_epochs is None:
+                raise ValueError("Budget for stochastic weight averaging or snapshot ensemble must be `epoch`.")
+
             self._budget_threshold = int(0.75 * budget_tracker.max_epochs)
 
         # in case we are using se, initialise list to store model snapshots
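
Replacing the assert with an explicit raise keeps the guard active even when Python runs with -O, which strips assert statements. A minimal sketch of the guarded computation (function name and wrapper are illustrative):

from typing import Optional

def budget_threshold(max_epochs: Optional[int]) -> int:
    # unlike an assert, this check survives `python -O`
    if max_epochs is None:
        raise ValueError("Budget for stochastic weight averaging or snapshot ensemble must be `epoch`.")
    return int(0.75 * max_epochs)

assert budget_threshold(100) == 75
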
@@ -578,7 +579,7 @@ def get_hyperparameter_search_space(
         dataset_properties: Optional[Dict] = None,
         weighted_loss: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="weighted_loss",
-            value_range=[True, False],
+            value_range=(True, False),
             default_value=True),
         la_steps: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="la_steps",
@@ -610,9 +611,7 @@
         cs = ConfigurationSpace()
 
         add_hyperparameter(cs, use_stochastic_weight_averaging, CategoricalHyperparameter)
-        snapshot_ensemble_flag = False
-        if any(use_snapshot_ensemble.value_range):
-            snapshot_ensemble_flag = True
+        snapshot_ensemble_flag = any(use_snapshot_ensemble.value_range)
 
         use_snapshot_ensemble = get_hyperparameter(use_snapshot_ensemble, CategoricalHyperparameter)
         cs.add_hyperparameter(use_snapshot_ensemble)
@@ -623,9 +622,7 @@
         cond = EqualsCondition(se_lastk, use_snapshot_ensemble, True)
         cs.add_condition(cond)
 
-        lookahead_flag = False
-        if any(use_lookahead_optimizer.value_range):
-            lookahead_flag = True
+        lookahead_flag = any(use_lookahead_optimizer.value_range)
 
         use_lookahead_optimizer = get_hyperparameter(use_lookahead_optimizer, CategoricalHyperparameter)
         cs.add_hyperparameter(use_lookahead_optimizer)
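
any(value_range) is True exactly when True is among the categorical choices, which is all the replaced three-line flag pattern computed. A sketch of how the flag gates the ConfigSpace condition, matching the EqualsCondition shown in the hunk above (the standalone value range is illustrative):

from ConfigSpace import ConfigurationSpace
from ConfigSpace.conditions import EqualsCondition
from ConfigSpace.hyperparameters import CategoricalHyperparameter, Constant

value_range = (True, False)                   # stand-in for use_snapshot_ensemble.value_range
snapshot_ensemble_flag = any(value_range)     # the refactored one-liner

cs = ConfigurationSpace()
use_se = CategoricalHyperparameter("use_snapshot_ensemble", choices=list(value_range))
se_lastk = Constant("se_lastk", 3)            # the commit's se_lastk range is the constant (3,)
cs.add_hyperparameters([use_se, se_lastk])
if snapshot_ensemble_flag:
    # se_lastk only matters when snapshot ensembling can be enabled
    cs.add_condition(EqualsCondition(se_lastk, use_se, True))
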

autoPyTorch/pipeline/components/training/trainer/cutout_utils.py

Lines changed: 39 additions & 39 deletions
@@ -60,45 +60,45 @@ def criterion_preparation(self, y_a: np.ndarray, y_b: np.ndarray = None, lam: float
 
     @staticmethod
     def get_hyperparameter_search_space(
-            dataset_properties: Optional[Dict] = None,
-            weighted_loss: HyperparameterSearchSpace = HyperparameterSearchSpace(
-                hyperparameter="weighted_loss",
-                value_range=[True, False],
-                default_value=True),
-            la_steps: HyperparameterSearchSpace = HyperparameterSearchSpace(
-                hyperparameter="la_steps",
-                value_range=(5, 10),
-                default_value=6,
-                log=False),
-            la_alpha: HyperparameterSearchSpace = HyperparameterSearchSpace(
-                hyperparameter="la_alpha",
-                value_range=(0.5, 0.8),
-                default_value=0.6,
-                log=False),
-            use_lookahead_optimizer: HyperparameterSearchSpace = HyperparameterSearchSpace(
-                hyperparameter="use_lookahead_optimizer",
-                value_range=(True, False),
-                default_value=True),
-            use_stochastic_weight_averaging: HyperparameterSearchSpace = HyperparameterSearchSpace(
-                hyperparameter="use_stochastic_weight_averaging",
-                value_range=(True, False),
-                default_value=True),
-            use_snapshot_ensemble: HyperparameterSearchSpace = HyperparameterSearchSpace(
-                hyperparameter="use_snapshot_ensemble",
-                value_range=(True, False),
-                default_value=True),
-            se_lastk: HyperparameterSearchSpace = HyperparameterSearchSpace(
-                hyperparameter="se_lastk",
-                value_range=(3,),
-                default_value=3),
-            patch_ratio: HyperparameterSearchSpace = HyperparameterSearchSpace(
-                hyperparameter="patch_ratio",
-                value_range=(0, 1),
-                default_value=0.2),
-            cutout_prob: HyperparameterSearchSpace = HyperparameterSearchSpace(
-                hyperparameter="cutout_prob",
-                value_range=(0, 1),
-                default_value=0.2),
+        dataset_properties: Optional[Dict] = None,
+        weighted_loss: HyperparameterSearchSpace = HyperparameterSearchSpace(
+            hyperparameter="weighted_loss",
+            value_range=(True, False),
+            default_value=True),
+        la_steps: HyperparameterSearchSpace = HyperparameterSearchSpace(
+            hyperparameter="la_steps",
+            value_range=(5, 10),
+            default_value=6,
+            log=False),
+        la_alpha: HyperparameterSearchSpace = HyperparameterSearchSpace(
+            hyperparameter="la_alpha",
+            value_range=(0.5, 0.8),
+            default_value=0.6,
+            log=False),
+        use_lookahead_optimizer: HyperparameterSearchSpace = HyperparameterSearchSpace(
+            hyperparameter="use_lookahead_optimizer",
+            value_range=(True, False),
+            default_value=True),
+        use_stochastic_weight_averaging: HyperparameterSearchSpace = HyperparameterSearchSpace(
+            hyperparameter="use_stochastic_weight_averaging",
+            value_range=(True, False),
+            default_value=True),
+        use_snapshot_ensemble: HyperparameterSearchSpace = HyperparameterSearchSpace(
+            hyperparameter="use_snapshot_ensemble",
+            value_range=(True, False),
+            default_value=True),
+        se_lastk: HyperparameterSearchSpace = HyperparameterSearchSpace(
+            hyperparameter="se_lastk",
+            value_range=(3,),
+            default_value=3),
+        patch_ratio: HyperparameterSearchSpace = HyperparameterSearchSpace(
+            hyperparameter="patch_ratio",
+            value_range=(0, 1),
+            default_value=0.2),
+        cutout_prob: HyperparameterSearchSpace = HyperparameterSearchSpace(
+            hyperparameter="cutout_prob",
+            value_range=(0, 1),
+            default_value=0.2),
     ) -> ConfigurationSpace:
 
         cs = ConfigurationSpace()
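
Among these hyperparameters, patch_ratio controls how large a chunk of each input is knocked out and cutout_prob how often cutout fires at all. A hypothetical tabular-style sketch of that behaviour (the helper name and column-zeroing strategy are illustrative, not the project's implementation):

import numpy as np

def cutout_batch(X: np.ndarray, patch_ratio: float = 0.2, cutout_prob: float = 0.2,
                 rng: np.random.RandomState = np.random.RandomState(0)) -> np.ndarray:
    # with probability cutout_prob, zero out ~patch_ratio of the feature columns
    if rng.random_sample() > cutout_prob:
        return X
    n_features = X.shape[1]
    n_cut = max(1, int(n_features * patch_ratio))
    columns = rng.choice(n_features, size=n_cut, replace=False)
    X = X.copy()
    X[:, columns] = 0.0
    return X

batch = np.ones((4, 10))
print(cutout_batch(batch, patch_ratio=0.5, cutout_prob=1.0))  # half the columns zeroed
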

autoPyTorch/pipeline/components/training/trainer/mixup_utils.py

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ def get_hyperparameter_search_space(
         dataset_properties: Optional[Dict] = None,
         weighted_loss: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="weighted_loss",
-            value_range=[True, False],
+            value_range=(True, False),
             default_value=True),
         la_steps: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="la_steps",

autoPyTorch/pipeline/components/training/trainer/utils.py

Lines changed: 6 additions & 3 deletions
@@ -34,13 +34,16 @@ def update_model_state_dict_from_swa(model: torch.nn.Module, swa_state_dict: Dict
         model_state[name].copy_(param)
 
 
-def swa_average_function(averaged_model_parameter: torch.nn.parameter.Parameter,
-                         model_parameter: torch.nn.parameter.Parameter,
-                         num_averaged: int) -> torch.nn.parameter.Parameter:
+def swa_update(averaged_model_parameter: torch.nn.parameter.Parameter,
+               model_parameter: torch.nn.parameter.Parameter,
+               num_averaged: int) -> torch.nn.parameter.Parameter:
     """
     Pickling the averaged function causes an error because of
     how pytorch initialises the average function.
     Passing this function fixes the issue.
+    The sequential update is performed via:
+        avg[n + 1] = (avg[n] * n + W[n + 1]) / (n + 1)
+
     Args:
         averaged_model_parameter:
         model_parameter:
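
The docstring now spells out the recurrence: after n + 1 models, the averaged parameter is the running mean avg[n + 1] = (avg[n] * n + W[n + 1]) / (n + 1). A minimal sketch of how such an avg_fn plugs into torch.optim.swa_utils.AveragedModel (the tiny model and loop are illustrative):

import torch
from torch.optim import swa_utils


def swa_update(averaged_model_parameter: torch.Tensor,
               model_parameter: torch.Tensor,
               num_averaged: int) -> torch.Tensor:
    # running mean: avg[n + 1] = (avg[n] * n + W[n + 1]) / (n + 1)
    return (averaged_model_parameter * num_averaged + model_parameter) / (num_averaged + 1)


model = torch.nn.Linear(4, 2)                  # stand-in for the real network
swa_model = swa_utils.AveragedModel(model, avg_fn=swa_update)

for _ in range(3):                             # stand-in for post-threshold epochs
    # ... optimiser updates to `model` would happen here ...
    swa_model.update_parameters(model)         # folds current weights into the average

Defining the function at module level, rather than relying on the default function created inside AveragedModel, is what makes the averaged model picklable, as the docstring notes.
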
