
Commit 5af8aab

nabenabe0928 authored and ravinkohli committed
[refactor] Address Shuhei's comments
[fix] Fix Flake8 issues
[refactor] Address Shuhei's comment
[refactor] Address Shuhei's comments
[refactor] Address Shuhei's comments
[refactor] Address Shuhei's comments
1 parent 09ad0d7 commit 5af8aab

File tree: 11 files changed (+130 −138 lines)


autoPyTorch/evaluation/tae.py

Lines changed: 0 additions & 15 deletions
@@ -201,27 +201,12 @@ def __init__(
 
         self.search_space_updates = search_space_updates
 
-<<<<<<< HEAD
     def _check_and_get_default_budget(self) -> float:
         budget_type_choices = ('epochs', 'runtime')
         budget_choices = {
             budget_type: float(self.pipeline_config.get(budget_type, np.inf))
             for budget_type in budget_type_choices
         }
-=======
-        if isinstance(self.resampling_strategy, (HoldoutValTypes, CrossValTypes)):
-            eval_function = autoPyTorch.evaluation.train_evaluator.eval_function
-        elif isinstance(self.resampling_strategy, NoResamplingStrategyTypes):
-            eval_function = autoPyTorch.evaluation.fit_evaluator.eval_function
-        else:
-            raise ValueError("resampling strategy must be in "
-                             "(HoldoutValTypes, CrossValTypes, NoResamplingStrategyTypes), "
-                             "but got {}.".format(self.resampling_strategy)
-                             )
-
-        self.worst_possible_result = cost_for_crash
->>>>>>> Cocktail hotfixes (#245)
-
         # budget is defined by epochs by default
         budget_type = str(self.pipeline_config.get('budget_type', 'epochs'))
         if self.budget_type is not None:
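With the conflict block removed, _check_and_get_default_budget simply resolves a default budget from pipeline_config, treating unset budget types as infinite. A standalone sketch of that resolution (the final lookup is an assumption for illustration; the real method goes on to reconcile it with self.budget_type):

import numpy as np

def check_and_get_default_budget(pipeline_config: dict) -> float:
    budget_type_choices = ('epochs', 'runtime')
    # Unset budget types fall back to infinity, i.e. "no limit"
    budget_choices = {
        budget_type: float(pipeline_config.get(budget_type, np.inf))
        for budget_type in budget_type_choices
    }
    # budget is defined by epochs by default
    budget_type = str(pipeline_config.get('budget_type', 'epochs'))
    return budget_choices[budget_type]

print(check_and_get_default_budget({'budget_type': 'epochs', 'epochs': 50}))  # 50.0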

autoPyTorch/evaluation/train_evaluator.py

Lines changed: 0 additions & 4 deletions
@@ -419,11 +419,7 @@ def _predict(self, pipeline: BaseEstimator,
 
 
 # create closure for evaluating an algorithm
-<<<<<<< HEAD
 def eval_train_function(
-=======
-def eval_function(
->>>>>>> Create fit evaluator, no resampling strategy and fix bug for test statistics
     backend: Backend,
     queue: Queue,
     metric: autoPyTorchMetric,

autoPyTorch/pipeline/components/training/trainer/AdversarialTrainer.py

Lines changed: 13 additions & 10 deletions
@@ -37,7 +37,11 @@ def __init__(
 
         Args:
             epsilon (float): The perturbation magnitude.
-
+
+        References:
+            Explaining and Harnessing Adversarial Examples
+            Ian J. Goodfellow et. al.
+            https://arxiv.org/pdf/1412.6572.pdf
         """
         super().__init__(random_state=random_state,
                          weighted_loss=weighted_loss,
@@ -96,10 +100,10 @@ def train_step(self, data: np.ndarray, targets: np.ndarray) -> Tuple[float, torc
         # training
         self.optimizer.zero_grad()
         original_outputs = self.model(original_data)
-        adversarial_output = self.model(adversarial_data)
+        adversarial_outputs = self.model(adversarial_data)
 
         loss_func = self.criterion_preparation(**criterion_kwargs)
-        loss = loss_func(self.criterion, original_outputs, adversarial_output)
+        loss = loss_func(self.criterion, original_outputs, adversarial_outputs)
         loss.backward()
         self.optimizer.step()
         if self.scheduler:
@@ -125,6 +129,9 @@ def fgsm_attack(
 
         Returns:
             adv_data (np.ndarray): the adversarial examples.
+
+        References:
+            https://pytorch.org/tutorials/beginner/fgsm_tutorial.html#fgsm-attack
         """
         data_copy = deepcopy(data)
         data_copy = data_copy.float().to(self.device)
@@ -159,7 +166,7 @@ def get_hyperparameter_search_space(
         dataset_properties: Optional[Dict] = None,
         weighted_loss: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="weighted_loss",
-            value_range=[True, False],
+            value_range=(True, False),
             default_value=True),
         la_steps: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="la_steps",
@@ -196,9 +203,7 @@ def get_hyperparameter_search_space(
 
         add_hyperparameter(cs, epsilon, UniformFloatHyperparameter)
         add_hyperparameter(cs, use_stochastic_weight_averaging, CategoricalHyperparameter)
-        snapshot_ensemble_flag = False
-        if any(use_snapshot_ensemble.value_range):
-            snapshot_ensemble_flag = True
+        snapshot_ensemble_flag = any(use_snapshot_ensemble.value_range)
 
         use_snapshot_ensemble = get_hyperparameter(use_snapshot_ensemble, CategoricalHyperparameter)
         cs.add_hyperparameter(use_snapshot_ensemble)
@@ -209,9 +214,7 @@ def get_hyperparameter_search_space(
         cond = EqualsCondition(se_lastk, use_snapshot_ensemble, True)
         cs.add_condition(cond)
 
-        lookahead_flag = False
-        if any(use_lookahead_optimizer.value_range):
-            lookahead_flag = True
+        lookahead_flag = any(use_lookahead_optimizer.value_range)
 
         use_lookahead_optimizer = get_hyperparameter(use_lookahead_optimizer, CategoricalHyperparameter)
         cs.add_hyperparameter(use_lookahead_optimizer)
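For context on the References added above: FGSM creates an adversarial example by stepping each input along the sign of the loss gradient with magnitude epsilon. A minimal sketch of the attack, independent of the trainer API in this file (function name and signature are illustrative):

import torch

def fgsm_attack(data: torch.Tensor,
                targets: torch.Tensor,
                model: torch.nn.Module,
                criterion: torch.nn.Module,
                epsilon: float) -> torch.Tensor:
    # Differentiate the loss with respect to the inputs, not the weights
    data = data.clone().detach().requires_grad_(True)
    loss = criterion(model(data), targets)
    loss.backward()
    # Step in the direction that locally increases the loss the most
    return (data + epsilon * data.grad.sign()).detach()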

autoPyTorch/pipeline/components/training/trainer/GridCutMixTrainer.py

Lines changed: 10 additions & 8 deletions
@@ -26,14 +26,15 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
             np.ndarray: that processes data
             typing.Dict[str, np.ndarray]: arguments to the criterion function
         """
-        beta = 1.0
-        lam = self.random_state.beta(beta, beta)
-        batch_size, channel, W, H = X.size()
-        index = torch.randperm(batch_size).cuda() if X.is_cuda else torch.randperm(batch_size)
+        alpha, beta = 1.0, 1.0
+        lam = self.random_state.beta(alpha, beta)
+        batch_size, _, W, H = X.shape
+        device = torch.device('cuda' if X.is_cuda else 'cpu')
+        batch_indices = torch.randperm(batch_size).to(device)
 
         r = self.random_state.rand(1)
         if beta <= 0 or r > self.alpha:
-            return X, {'y_a': y, 'y_b': y[index], 'lam': 1}
+            return X, {'y_a': y, 'y_b': y[batch_indices], 'lam': 1}
 
         # Draw parameters of a random bounding box
         # Where to cut basically
@@ -47,12 +48,13 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
         bbx2 = np.clip(cx + cut_w // 2, 0, W)
         bby2 = np.clip(cy + cut_h // 2, 0, H)
 
-        X[:, :, bbx1:bbx2, bby1:bby2] = X[index, :, bbx1:bbx2, bby1:bby2]
+        X[:, :, bbx1:bbx2, bby1:bby2] = X[batch_indices, :, bbx1:bbx2, bby1:bby2]
 
         # Adjust lam
-        lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (X.size()[-1] * X.size()[-2]))
+        pixel_size = W * H
+        lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / pixel_size)
 
-        y_a, y_b = y, y[index]
+        y_a, y_b = y, y[batch_indices]
 
         return X, {'y_a': y_a, 'y_b': y_b, 'lam': lam}
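After clipping, the adjusted lam is exactly the fraction of pixels kept from the original sample, so the criterion can weight the two targets accordingly. A self-contained sketch of the same bounding-box mixing for NCHW tensors (grid_cutmix is an illustrative helper, not autoPyTorch API):

import numpy as np
import torch

def grid_cutmix(X: torch.Tensor, y: torch.Tensor, alpha: float = 1.0, beta: float = 1.0):
    lam = np.random.beta(alpha, beta)
    batch_size, _, W, H = X.shape
    perm = torch.randperm(batch_size, device=X.device)

    # Box whose area is roughly (1 - lam) of the image
    cut_ratio = np.sqrt(1.0 - lam)
    cut_w, cut_h = int(W * cut_ratio), int(H * cut_ratio)
    cx, cy = np.random.randint(W), np.random.randint(H)
    bbx1, bbx2 = np.clip(cx - cut_w // 2, 0, W), np.clip(cx + cut_w // 2, 0, W)
    bby1, bby2 = np.clip(cy - cut_h // 2, 0, H), np.clip(cy + cut_h // 2, 0, H)

    # Paste the patch from shuffled samples into the batch
    X[:, :, bbx1:bbx2, bby1:bby2] = X[perm, :, bbx1:bbx2, bby1:bby2]
    # Recompute lam from the clipped box: fraction of pixels left untouched
    lam = 1 - (bbx2 - bbx1) * (bby2 - bby1) / (W * H)
    return X, {'y_a': y, 'y_b': y[perm], 'lam': lam}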

autoPyTorch/pipeline/components/training/trainer/RowCutMixTrainer.py

Lines changed: 17 additions & 11 deletions
@@ -26,25 +26,31 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
             np.ndarray: that processes data
             typing.Dict[str, np.ndarray]: arguments to the criterion function
         """
-        beta = 1.0
-        lam = self.random_state.beta(beta, beta)
-        batch_size = X.size()[0]
-        index = torch.randperm(batch_size).cuda() if X.is_cuda else torch.randperm(batch_size)
+        alpha, beta = 1.0, 1.0
+        lam = self.random_state.beta(alpha, beta)
+        batch_size = X.shape[0]
+        device = torch.device('cuda' if X.is_cuda else 'cpu')
+        batch_indices = torch.randperm(batch_size).to(device)
 
         r = self.random_state.rand(1)
         if beta <= 0 or r > self.alpha:
-            return X, {'y_a': y, 'y_b': y[index], 'lam': 1}
+            return X, {'y_a': y, 'y_b': y[batch_indices], 'lam': 1}
 
-        size = X.shape[1]
-        indices = torch.tensor(self.random_state.choice(range(1, size), max(1, np.int32(size * lam)),
-                                                        replace=False))
+        row_size = X.shape[1]
+        row_indices = torch.tensor(
+            self.random_state.choice(
+                range(1, row_size),
+                max(1, int(row_size * lam)),
+                replace=False
+            )
+        )
 
-        X[:, indices] = X[index, :][:, indices]
+        X[:, row_indices] = X[batch_indices, :][:, row_indices]
 
         # Adjust lam
-        lam = 1 - ((len(indices)) / (X.size()[1]))
+        lam = 1 - len(row_indices) / X.shape[1]
 
-        y_a, y_b = y, y[index]
+        y_a, y_b = y, y[batch_indices]
 
         return X, {'y_a': y_a, 'y_b': y_b, 'lam': lam}
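The {'y_a', 'y_b', 'lam'} dictionary returned by both CutMix trainers feeds a mixed loss of the form lam * loss(out, y_a) + (1 - lam) * loss(out, y_b). A hedged sketch of such a closure (it mirrors the pattern, not the exact criterion_preparation signature):

import torch

def mixup_criterion(criterion: torch.nn.Module,
                    outputs: torch.Tensor,
                    y_a: torch.Tensor,
                    y_b: torch.Tensor,
                    lam: float) -> torch.Tensor:
    # Weighted sum of losses against both label sets;
    # lam == 1 reduces to the plain, unmixed loss.
    return lam * criterion(outputs, y_a) + (1 - lam) * criterion(outputs, y_b)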

autoPyTorch/pipeline/components/training/trainer/RowCutOutTrainer.py

Lines changed: 10 additions & 13 deletions
@@ -9,7 +9,9 @@
 
 
 class RowCutOutTrainer(CutOut, BaseTrainerComponent):
+    # 0 is non-informative in image data
     NUMERICAL_VALUE = 0
+    # -1 is the conceptually equivalent to 0 in a image, i.e. 0-pad
     CATEGORICAL_VALUE = -1
 
     def data_preparation(self, X: np.ndarray, y: np.ndarray,
@@ -36,23 +38,18 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
             lam = 1
             return X, {'y_a': y_a, 'y_b': y_b, 'lam': lam}
 
-        size = X.shape[1]
-        indices = self.random_state.choice(range(1, size), max(1, np.int32(size * self.patch_ratio)),
-                                           replace=False)
+        row_size = X.shape[1]
+        row_indices = self.random_state.choice(range(1, row_size), max(1, int(row_size * self.patch_ratio)),
+                                               replace=False)
 
         if not isinstance(self.numerical_columns, typing.Iterable):
-            raise ValueError("{} requires numerical columns information of {}"
-                             "to prepare data got {}.".format(self.__class__.__name__,
-                                                              typing.Iterable,
-                                                              self.numerical_columns))
+            raise ValueError("numerical_columns in {} must be iterable, "
+                             "but got {}.".format(self.__class__.__name__,
+                                                  self.numerical_columns))
+
         numerical_indices = torch.tensor(self.numerical_columns)
-        categorical_indices = torch.tensor([index for index in indices if index not in self.numerical_columns])
+        categorical_indices = torch.tensor([idx for idx in row_indices if idx not in self.numerical_columns])
 
-        # We use an ordinal encoder on the categorical columns of tabular data
-        # -1 is the conceptual equivalent to 0 in a image, that does not
-        # have color as a feature and hence the network has to learn to deal
-        # without this data. For numerical columns we use 0 to cutout the features
-        # similar to the effect that setting 0 as a pixel value in an image.
         X[:, categorical_indices.long()] = self.CATEGORICAL_VALUE
         X[:, numerical_indices.long()] = self.NUMERICAL_VALUE
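In effect, row cutout masks a random subset of features in every row of the batch, with a sentinel matched to the column type: 0 for numerical columns and -1 for ordinally encoded categoricals. A minimal sketch of that idea (row_cutout and its arguments are illustrative, not the trainer's API):

import numpy as np
import torch

def row_cutout(X: torch.Tensor, numerical_columns: list, patch_ratio: float = 0.3) -> torch.Tensor:
    n_features = X.shape[1]
    n_masked = max(1, int(n_features * patch_ratio))
    masked = np.random.choice(n_features, n_masked, replace=False)

    # Match the sentinel to the column type, as the trainer above does
    num_cols = [int(c) for c in masked if c in numerical_columns]
    cat_cols = [int(c) for c in masked if c not in numerical_columns]
    X[:, num_cols] = 0    # neutral value for numerical features
    X[:, cat_cols] = -1   # sentinel for ordinally encoded categoricals
    return X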

autoPyTorch/pipeline/components/training/trainer/__init__.py

Lines changed: 7 additions & 16 deletions
@@ -384,11 +384,7 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoic
 
             val_loss, val_metrics, test_loss, test_metrics = None, {}, None, {}
             if self.eval_valid_each_epoch(X):
-<<<<<<< HEAD
                 if X['val_data_loader']:
-=======
-                if 'val_data_loader' in X and X['val_data_loader']:
->>>>>>> Create fit evaluator, no resampling strategy and fix bug for test statistics
                     val_loss, val_metrics = self.choice.evaluate(X['val_data_loader'], epoch, writer)
                 if 'test_data_loader' in X and X['test_data_loader']:
                     test_loss, test_metrics = self.choice.evaluate(X['test_data_loader'], epoch, writer)
@@ -433,26 +429,20 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoic
 
         if self.choice.use_stochastic_weight_averaging and self.choice.swa_updated:
             # update batch norm statistics
-            swa_utils.update_bn(X['train_data_loader'], self.choice.swa_model.double())
+            swa_utils.update_bn(loader=X['train_data_loader'], model=self.choice.swa_model.double())
+
             # change model
             update_model_state_dict_from_swa(X['network'], self.choice.swa_model.state_dict())
         if self.choice.use_snapshot_ensemble:
             for model in self.choice.model_snapshots:
-                swa_utils.update_bn(X['train_data_loader'], model.double())
+                swa_utils.update_bn(loader=X['train_data_loader'], model=model.double())
 
         # wrap up -- add score if not evaluating every epoch
         if not self.eval_valid_each_epoch(X):
-<<<<<<< HEAD
             if X['val_data_loader']:
                 val_loss, val_metrics = self.choice.evaluate(X['val_data_loader'], epoch, writer)
             if 'test_data_loader' in X and X['val_data_loader']:
                 test_loss, test_metrics = self.choice.evaluate(X['test_data_loader'], epoch, writer)
-=======
-            if 'val_data_loader' in X and X['val_data_loader']:
-                val_loss, val_metrics = self.choice.evaluate(X['val_data_loader'], epoch, writer)
-            if 'test_data_loader' in X and X['test_data_loader']:
-                test_loss, test_metrics = self.choice.evaluate(X['test_data_loader'])
->>>>>>> Create fit evaluator, no resampling strategy and fix bug for test statistics
             self.run_summary.add_performance(
                 epoch=epoch,
                 start_time=start_time,
@@ -653,11 +643,12 @@ def __str__(self) -> str:
     def _get_search_space_updates(self, prefix: Optional[str] = None) -> Dict[str, HyperparameterSearchSpace]:
         """Get the search space updates with the given prefix
 
-        Keyword Arguments:
-            prefix {str} -- Only return search space updates with given prefix (default: {None})
+        Args:
+            prefix (Optional[str]): Only return search space updates with given prefix
 
         Returns:
-            dict -- Mapping of search space updates. Keys don't contain the prefix.
+            Dict[str, HyperparameterSearchSpace]:
+                Mapping of search space updates. Keys don't contain the prefix.
         """
         updates = super()._get_search_space_updates(prefix=prefix)
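In both call sites above, swa_utils.update_bn re-estimates BatchNorm running statistics with one forward pass over the training loader; this is needed because the averaged SWA/snapshot weights invalidate the statistics accumulated during training. A standalone usage sketch with placeholder model and data (torch.optim.swa_utils is the real PyTorch module; everything else is illustrative):

import torch
from torch.optim import swa_utils

model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.BatchNorm1d(16), torch.nn.ReLU())
swa_model = swa_utils.AveragedModel(model)

loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(torch.randn(64, 8)), batch_size=16)

# ... a training loop would call swa_model.update_parameters(model) periodically ...

# Recompute BatchNorm running mean/var under the averaged weights
swa_utils.update_bn(loader, swa_model)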

autoPyTorch/pipeline/components/training/trainer/base_trainer.py

Lines changed: 27 additions & 18 deletions
@@ -28,7 +28,7 @@
 from autoPyTorch.pipeline.components.training.metrics.metrics import CLASSIFICATION_METRICS, REGRESSION_METRICS
 from autoPyTorch.pipeline.components.training.trainer.utils import Lookahead
 from autoPyTorch.pipeline.components.training.metrics.utils import calculate_score
-from autoPyTorch.pipeline.components.training.trainer.utils import Lookahead, swa_average_function
+from autoPyTorch.pipeline.components.training.trainer.utils import Lookahead, swa_update
 from autoPyTorch.utils.common import FitRequirement, HyperparameterSearchSpace, add_hyperparameter, get_hyperparameter
 from autoPyTorch.utils.implementations import get_loss_weight_strategy
 
@@ -226,7 +226,7 @@ def __init__(self, weighted_loss: bool = False,
                  use_snapshot_ensemble: bool = True,
                  se_lastk: int = 3,
                  use_lookahead_optimizer: bool = True,
-                 random_state: Optional[Union[np.random.RandomState, int]] = None,
+                 random_state: Optional[np.random.RandomState] = None,
                  swa_model: Optional[torch.nn.Module] = None,
                  model_snapshots: Optional[List[torch.nn.Module]] = None,
                  **lookahead_config: Any) -> None:
@@ -287,13 +287,14 @@ def prepare(
 
         # in case we are using swa, maintain an averaged model,
         if self.use_stochastic_weight_averaging:
-            self.swa_model = swa_utils.AveragedModel(self.model, avg_fn=swa_average_function)
+            self.swa_model = swa_utils.AveragedModel(self.model, avg_fn=swa_update)
 
         # in case we are using se or swa, initialise budget_threshold to know when to start swa or se
         self._budget_threshold = 0
         if self.use_stochastic_weight_averaging or self.use_snapshot_ensemble:
-            assert budget_tracker.max_epochs is not None, "Can only use stochastic weight averaging or snapshot " \
-                                                          "ensemble when budget is epochs"
+            if budget_tracker.max_epochs is None:
+                raise ValueError("Budget for stochastic weight averaging or snapshot ensemble must be `epoch`.")
+
             self._budget_threshold = int(0.75 * budget_tracker.max_epochs)
 
         # in case we are using se, initialise list to store model snapshots
@@ -591,7 +592,7 @@ def get_hyperparameter_search_space(
         dataset_properties: Optional[Dict] = None,
         weighted_loss: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="weighted_loss",
-            value_range=[True, False],
+            value_range=(True, False),
             default_value=True),
         la_steps: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="la_steps",
@@ -623,22 +624,30 @@ def get_hyperparameter_search_space(
         cs = ConfigurationSpace()
 
         add_hyperparameter(cs, use_stochastic_weight_averaging, CategoricalHyperparameter)
+        snapshot_ensemble_flag = any(use_snapshot_ensemble.value_range)
+
         use_snapshot_ensemble = get_hyperparameter(use_snapshot_ensemble, CategoricalHyperparameter)
-        se_lastk = get_hyperparameter(se_lastk, Constant)
-        cs.add_hyperparameters([use_snapshot_ensemble, se_lastk])
-        cond = EqualsCondition(se_lastk, use_snapshot_ensemble, True)
-        cs.add_condition(cond)
+        cs.add_hyperparameter(use_snapshot_ensemble)
 
+        if snapshot_ensemble_flag:
+            se_lastk = get_hyperparameter(se_lastk, Constant)
+            cs.add_hyperparameter(se_lastk)
+            cond = EqualsCondition(se_lastk, use_snapshot_ensemble, True)
+            cs.add_condition(cond)
+
+        lookahead_flag = any(use_lookahead_optimizer.value_range)
         use_lookahead_optimizer = get_hyperparameter(use_lookahead_optimizer, CategoricalHyperparameter)
         cs.add_hyperparameter(use_lookahead_optimizer)
-        la_config_space = Lookahead.get_hyperparameter_search_space(la_steps=la_steps,
-                                                                    la_alpha=la_alpha)
-        parent_hyperparameter = {'parent': use_lookahead_optimizer, 'value': True}
-        cs.add_configuration_space(
-            Lookahead.__name__,
-            la_config_space,
-            parent_hyperparameter=parent_hyperparameter
-        )
+
+        if lookahead_flag:
+            la_config_space = Lookahead.get_hyperparameter_search_space(la_steps=la_steps,
+                                                                        la_alpha=la_alpha)
+            parent_hyperparameter = {'parent': use_lookahead_optimizer, 'value': True}
+            cs.add_configuration_space(
+                Lookahead.__name__,
+                la_config_space,
+                parent_hyperparameter=parent_hyperparameter
+            )
 
         # TODO, decouple the weighted loss from the trainer
         if dataset_properties is not None:
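The refactor above registers child hyperparameters (se_lastk, the Lookahead sub-space) only when the parent's value range can actually activate them. A small ConfigSpace sketch of the underlying parent-child conditioning pattern, outside autoPyTorch:

from ConfigSpace import ConfigurationSpace
from ConfigSpace.conditions import EqualsCondition
from ConfigSpace.hyperparameters import CategoricalHyperparameter, Constant

cs = ConfigurationSpace()
use_snapshot_ensemble = CategoricalHyperparameter('use_snapshot_ensemble', [True, False])
cs.add_hyperparameter(use_snapshot_ensemble)

# The child is only active (and only sampled) when the parent equals True
se_lastk = Constant('se_lastk', 3)
cs.add_hyperparameter(se_lastk)
cs.add_condition(EqualsCondition(se_lastk, use_snapshot_ensemble, True))

print(cs.sample_configuration())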
