
Commit 107750a

ArlindKadra authored and ravinkohli committed
Cocktail hotfixes (#245)
* Fixes for the development branch and regularization cocktails
* Update implementation
* Fix unit tests temporarily
* Implementation update and bug fixes
* Removing unnecessary code
* Addressing Ravin's comments
* [refactor] Address Shuhei's comments
1 parent 8b8ba42 commit 107750a

File tree

12 files changed: +83 additions, −49 deletions


autoPyTorch/api/base_task.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -401,7 +401,7 @@ def set_pipeline_config(self, **pipeline_config_kwargs: Any) -> None:
             None
         """
         unknown_keys = []
-        for option, value in pipeline_config_kwargs.items():
+        for option in pipeline_config_kwargs.keys():
             if option in self.pipeline_options.keys():
                 pass
             else:
```
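The loop only inspects option names, so the unused `value` binding from `.items()` is dropped. A minimal, self-contained sketch of the validation pattern (option names here are hypothetical):

```python
# Stand-in for the pipeline-options check above; names are illustrative only.
pipeline_options = {'budget_type': 'epochs', 'epochs': 50}

def set_pipeline_config(**pipeline_config_kwargs):
    # Iterate over keys() because only the option names are checked;
    # .items() would bind a `value` that is never used.
    unknown_keys = [
        option for option in pipeline_config_kwargs.keys()
        if option not in pipeline_options.keys()
    ]
    if unknown_keys:
        raise ValueError(f'Unknown pipeline options: {unknown_keys}')
    pipeline_options.update(pipeline_config_kwargs)

set_pipeline_config(epochs=100)   # accepted
# set_pipeline_config(lr=0.1)     # would raise ValueError
```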

autoPyTorch/api/tabular_classification.py

Lines changed: 14 additions & 5 deletions
```diff
@@ -379,7 +379,16 @@ def search(
             y_test=y_test,
             resampling_strategy=self.resampling_strategy,
             resampling_strategy_args=self.resampling_strategy_args,
-            dataset_name=dataset_name)
+            dataset_name=dataset_name
+        )
+
+        if not isinstance(self.resampling_strategy, (CrossValTypes, HoldoutValTypes)):
+            raise ValueError(
+                'Hyperparameter optimization requires a validation split. '
+                'Expected `self.resampling_strategy` to be either '
+                '(CrossValTypes, HoldoutValTypes), but got {}'.format(self.resampling_strategy)
+            )
+

         return self._search(
             dataset=self.dataset,
@@ -420,23 +429,23 @@ def predict(
             raise ValueError("predict() is only supported after calling search. Kindly call first "
                              "the estimator search() method.")

-        X_test = self.InputValidator.feature_validator.transform(X_test)
+        X_test = self.input_validator.feature_validator.transform(X_test)
         predicted_probabilities = super().predict(X_test, batch_size=batch_size,
                                                   n_jobs=n_jobs)

-        if self.InputValidator.target_validator.is_single_column_target():
+        if self.input_validator.target_validator.is_single_column_target():
             predicted_indexes = np.argmax(predicted_probabilities, axis=1)
         else:
             predicted_indexes = (predicted_probabilities > 0.5).astype(int)

         # Allow to predict in the original domain -- that is, the user is not interested
         # in our encoded values
-        return self.InputValidator.target_validator.inverse_transform(predicted_indexes)
+        return self.input_validator.target_validator.inverse_transform(predicted_indexes)

     def predict_proba(self,
                       X_test: Union[np.ndarray, pd.DataFrame, List],
                       batch_size: Optional[int] = None, n_jobs: int = 1) -> np.ndarray:
-        if self.InputValidator is None or not self.InputValidator._is_fitted:
+        if self.input_validator is None or not self.input_validator._is_fitted:
             raise ValueError("predict() is only supported after calling search. Kindly call first "
                              "the estimator search() method.")
         X_test = self.InputValidator.feature_validator.transform(X_test)
```
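`search()` now fails fast when the resampling strategy provides no validation split, instead of relying on an assertion deep inside the evaluator. A hedged sketch of the new behaviour — the constructor and `search()` arguments are assumptions based on this development branch's API:

```python
import numpy as np
from autoPyTorch.api.tabular_classification import TabularClassificationTask
from autoPyTorch.datasets.resampling_strategy import NoResamplingStrategyTypes

X = np.random.rand(100, 4)
y = np.random.randint(0, 2, size=100)

# A strategy without a validation split cannot drive HPO.
api = TabularClassificationTask(resampling_strategy=NoResamplingStrategyTypes.no_resampling)
try:
    api.search(X_train=X, y_train=y, optimize_metric='accuracy',
               total_walltime_limit=60)
except ValueError as e:
    print(e)  # 'Hyperparameter optimization requires a validation split. ...'
```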

autoPyTorch/api/tabular_regression.py

Lines changed: 13 additions & 4 deletions
```diff
@@ -377,7 +377,16 @@ def search(
             y_test=y_test,
             resampling_strategy=self.resampling_strategy,
             resampling_strategy_args=self.resampling_strategy_args,
-            dataset_name=dataset_name)
+            dataset_name=dataset_name
+        )
+
+        if not isinstance(self.resampling_strategy, (CrossValTypes, HoldoutValTypes)):
+            raise ValueError(
+                'Hyperparameter optimization requires a validation split. '
+                'Expected `self.resampling_strategy` to be either '
+                '(CrossValTypes, HoldoutValTypes), but got {}'.format(self.resampling_strategy)
+            )
+

         return self._search(
             dataset=self.dataset,
@@ -404,14 +413,14 @@ def predict(
                 batch_size: Optional[int] = None,
                 n_jobs: int = 1
                 ) -> np.ndarray:
-        if self.InputValidator is None or not self.InputValidator._is_fitted:
+        if self.input_validator is None or not self.input_validator._is_fitted:
             raise ValueError("predict() is only supported after calling search. Kindly call first "
                              "the estimator search() method.")

-        X_test = self.InputValidator.feature_validator.transform(X_test)
+        X_test = self.input_validator.feature_validator.transform(X_test)
         predicted_values = super().predict(X_test, batch_size=batch_size,
                                            n_jobs=n_jobs)

         # Allow to predict in the original domain -- that is, the user is not interested
         # in our encoded values
-        return self.InputValidator.target_validator.inverse_transform(predicted_values)
+        return self.input_validator.target_validator.inverse_transform(predicted_values)
```

autoPyTorch/data/tabular_feature_validator.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -391,7 +391,7 @@ def _get_columns_to_encode(
         feat_type = []

         # Make sure each column is a valid type
-        for i, column in enumerate(X.columns):
+        for column in X.columns:
             if X[column].dtype.name in ['category', 'bool']:

                 transformed_columns.append(column)
@@ -512,7 +512,7 @@ def infer_objects(self, X: pd.DataFrame) -> pd.DataFrame:
                     X[key] = X[key].astype(dtype.name)
                 except Exception as e:
                     # Try inference if possible
-                    self.logger.warning(f"Tried to cast column {key} to {dtype} caused {e}")
+                    self.logger.warning(f'Casting the column {key} to {dtype} caused the exception {e}')
                     pass
         else:
             X = X.infer_objects()
```
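The reworded warning names both the failing cast and the exception. A minimal sketch of the casting loop's behaviour, with `print` standing in for the logger and a hypothetical recorded-dtype mapping:

```python
import pandas as pd

X = pd.DataFrame({'a': ['1', '2', 'x']})
# Hypothetical dtypes recorded during fit; 'x' makes the int cast fail.
object_dtype_mapping = {'a': pd.api.types.pandas_dtype('int64')}

for key, dtype in object_dtype_mapping.items():
    try:
        X[key] = X[key].astype(dtype.name)
    except Exception as e:
        # Mirrors the reworded warning: report which cast failed and why.
        print(f'Casting the column {key} to {dtype} caused the exception {e}')
```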

autoPyTorch/evaluation/fit_evaluator.py

Lines changed: 6 additions & 4 deletions
```diff
@@ -58,10 +58,12 @@ def __init__(self, backend: Backend, queue: Queue,
             pipeline_config=pipeline_config,
             search_space_updates=search_space_updates
         )
-        assert isinstance(self.datamanager.resampling_strategy, NoResamplingStrategyTypes),\
-            "This Evaluator is used for fitting a pipeline on the whole dataset. " \
-            "Expected 'self.resampling_strategy' to be" \
-            " 'NoResamplingStrategyTypes' got {}".format(self.datamanager.resampling_strategy)
+        if not isinstance(self.datamanager.resampling_strategy, NoResamplingStrategyTypes):
+            raise ValueError(
+                "FitEvaluator needs to be fitted on the whole dataset and resampling_strategy "
+                "must be `NoResamplingStrategyTypes`, but got {}".format(
+                    self.datamanager.resampling_strategy
+                ))

         self.splits = self.datamanager.splits
         self.Y_target: Optional[np.ndarray] = None
```
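The switch from `assert` to `raise ValueError` matters beyond style: assertions are stripped when Python runs with `-O`, so a guard that must always hold needs an explicit exception. A small self-contained sketch with a stand-in enum:

```python
import enum

class NoResamplingStrategyTypes(enum.Enum):  # stand-in for the autoPyTorch enum
    no_resampling = 8

def check_resampling(resampling_strategy):
    # An `assert` here would silently disappear under `python -O`;
    # an explicit ValueError survives optimized mode.
    if not isinstance(resampling_strategy, NoResamplingStrategyTypes):
        raise ValueError(
            'resampling_strategy must be NoResamplingStrategyTypes, '
            'but got {}'.format(resampling_strategy)
        )

check_resampling(NoResamplingStrategyTypes.no_resampling)  # passes
# check_resampling('holdout')  # raises ValueError, even with -O
```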

autoPyTorch/evaluation/tae.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -184,10 +184,10 @@ def __init__(
         elif isinstance(self.resampling_strategy, NoResamplingStrategyTypes):
             eval_function = autoPyTorch.evaluation.fit_evaluator.eval_function
         else:
-            raise ValueError("Unknown resampling strategy specified."
-                             "Expected resampling strategy to be in "
-                             "'(HoldoutValTypes, CrossValTypes, NoResamplingStrategyTypes)"
-                             "got {}".format(self.resampling_strategy))
+            raise ValueError("resampling strategy must be in "
+                             "(HoldoutValTypes, CrossValTypes, NoResamplingStrategyTypes), "
+                             "but got {}.".format(self.resampling_strategy)
+                             )

         self.worst_possible_result = cost_for_crash

@@ -336,6 +336,7 @@ def run(
         info: Optional[List[RunValue]]
         additional_run_info: Dict[str, Any]
         try:
+            # By default, self.ta is fit_predict_try_except_decorator
             obj = pynisher.enforce_limits(**pynisher_arguments)(self.ta)
             obj(**obj_kwargs)
         except Exception as e:
```
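The new comment documents what `self.ta` is at the point where the pynisher wrapper is applied. For context, a hedged sketch of the `enforce_limits` pattern — keyword names follow pynisher's 0.6.x API, which this codebase uses; newer pynisher versions changed the interface:

```python
import pynisher

def train_function(n):
    return sum(i * i for i in range(n))

# Wrap the target function with wall-clock and memory limits, then call it.
safe_train = pynisher.enforce_limits(wall_time_in_s=10, mem_in_mb=2048)(train_function)
result = safe_train(10_000)

if safe_train.exit_status is pynisher.TimeoutException:
    print('hit the wall-time limit')
else:
    print('result:', result)
```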

autoPyTorch/evaluation/train_evaluator.py

Lines changed: 7 additions & 5 deletions
```diff
@@ -149,11 +149,13 @@ def __init__(self, backend: Backend, queue: Queue,
             pipeline_config=pipeline_config,
             search_space_updates=search_space_updates
         )
-        assert isinstance(self.datamanager.resampling_strategy, (CrossValTypes, HoldoutValTypes)),\
-            "This Evaluator is used for HPO Search. " \
-            "Val Split is required for HPO search. " \
-            "Expected 'self.resampling_strategy' in" \
-            " '(CrossValTypes, HoldoutValTypes)' got {}".format(self.datamanager.resampling_strategy)
+
+        if not isinstance(self.datamanager.resampling_strategy, (CrossValTypes, HoldoutValTypes)):
+            raise ValueError(
+                'TrainEvaluator expect to have (CrossValTypes, HoldoutValTypes) as '
+                'resampling_strategy, but got {}'.format(self.datamanager.resampling_strategy)
+            )
+

         self.splits = self.datamanager.splits
         if self.splits is None:
```
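Because the guard now raises `ValueError` instead of failing an `assert`, unit tests should expect the exception explicitly (the commit message mentions temporary test fixes). A hedged pytest sketch with the guard inlined and stand-in types:

```python
import pytest

def check_val_split(resampling_strategy, allowed=(int, float)):
    # `allowed` stands in for (CrossValTypes, HoldoutValTypes) here.
    if not isinstance(resampling_strategy, allowed):
        raise ValueError(
            'resampling_strategy must be in (CrossValTypes, HoldoutValTypes), '
            'but got {}'.format(resampling_strategy)
        )

def test_rejects_no_resampling():
    # pytest.raises replaces the old reliance on AssertionError.
    with pytest.raises(ValueError, match='resampling_strategy'):
        check_val_split('no_resampling')
```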

autoPyTorch/pipeline/components/setup/network_backbone/ResNetBackbone.py

Lines changed: 13 additions & 5 deletions
```diff
@@ -64,7 +64,8 @@ def _add_group(self, in_features: int, out_features: int,
             out_features (int): output dimensionality for the current block
             blocks_per_group (int): Number of ResNet per group
             last_block_index (int): block index for shake regularization
-            dropout (bool): whether or not use dropout
+            dropout (None, float): dropout value for the group. If none,
+                no dropout is applied.
         """
         blocks = list()
         for i in range(blocks_per_group):
@@ -180,9 +181,7 @@ def get_hyperparameter_search_space(

         if skip_connection_flag:

-            shake_drop_prob_flag = False
-            if 'shake-drop' in multi_branch_choice.value_range:
-                shake_drop_prob_flag = True
+            shake_drop_prob_flag = 'shake-drop' in multi_branch_choice.value_range

             mb_choice = get_hyperparameter(multi_branch_choice, CategoricalHyperparameter)
             cs.add_hyperparameter(mb_choice)
@@ -290,13 +289,21 @@ def _build_block(self, in_features: int, out_features: int) -> nn.Module:
             if self.config['use_batch_norm']:
                 layers.append(nn.BatchNorm1d(in_features))
             layers.append(self.activation())
+        elif not self.config['use_skip_connection']:
+            # if start norm is not None and skip connection is False
+            # we will never apply the start_norm for the first layer in the block,
+            # which is why we should account for this case.
+            if self.config['use_batch_norm']:
+                layers.append(nn.BatchNorm1d(in_features))
+            layers.append(self.activation())
+
         layers.append(nn.Linear(in_features, out_features))

         if self.config['use_batch_norm']:
             layers.append(nn.BatchNorm1d(out_features))
         layers.append(self.activation())

-        if self.config["use_dropout"]:
+        if self.dropout is not None:
             layers.append(nn.Dropout(self.dropout))
         layers.append(nn.Linear(out_features, out_features))

@@ -321,6 +328,7 @@ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
         if self.config["use_skip_connection"]:
             residual = self.shortcut(x)

+        # TODO make the below code better
         if self.config["use_skip_connection"]:
             if self.config["multi_branch_choice"] == 'shake-shake':
                 x1 = self.layers(x)
```
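`_build_block` now treats dropout as a per-group `Optional[float]` rather than a global `use_dropout` flag. A condensed, hedged sketch of the resulting block layout — the real block also wires skip connections and the start-norm handling shown in the diff:

```python
from typing import Optional

import torch.nn as nn

def build_block(in_features: int, out_features: int,
                use_batch_norm: bool, dropout: Optional[float]) -> nn.Module:
    layers: list = []
    if use_batch_norm:
        layers.append(nn.BatchNorm1d(in_features))
    layers.append(nn.ReLU())
    layers.append(nn.Linear(in_features, out_features))
    if use_batch_norm:
        layers.append(nn.BatchNorm1d(out_features))
    layers.append(nn.ReLU())
    if dropout is not None:  # None now means "no dropout for this group"
        layers.append(nn.Dropout(dropout))
    layers.append(nn.Linear(out_features, out_features))
    return nn.Sequential(*layers)

block = build_block(16, 32, use_batch_norm=True, dropout=0.2)
```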

autoPyTorch/pipeline/components/setup/network_backbone/ShapedResNetBackbone.py

Lines changed: 13 additions & 11 deletions
```diff
@@ -31,11 +31,13 @@ def build_backbone(self, input_shape: Tuple[int, ...]) -> torch.nn.Sequential:
         out_features = self.config["output_dim"]

         # use the get_shaped_neuron_counts to update the number of units
-        neuron_counts = get_shaped_neuron_counts(self.config['resnet_shape'],
-                                                 in_features,
-                                                 out_features,
-                                                 self.config['max_units'],
-                                                 self.config['num_groups'] + 2)[:-1]
+        neuron_counts = get_shaped_neuron_counts(
+            shape=self.config['resnet_shape'],
+            in_feat=in_features,
+            out_feat=out_features,
+            max_neurons=self.config['max_units'],
+            layer_count=self.config['num_groups'] + 2,
+        )[:-1]
         self.config.update(
             {"num_units_%d" % (i): num for i, num in enumerate(neuron_counts)}
         )
@@ -45,12 +47,12 @@ def build_backbone(self, input_shape: Tuple[int, ...]) -> torch.nn.Sequential:
         # n_units for the architecture, since, it is mostly implemented for the
         # output layer, which is part of the head and not of the backbone.
         dropout_shape = get_shaped_neuron_counts(
-            self.config['dropout_shape'], 0, 0, 1000, self.config['num_groups']
-        )
-
-        dropout_shape = [
-            dropout / 1000 * self.config["max_dropout"] for dropout in dropout_shape
-        ]
+            shape=self.config['dropout_shape'],
+            in_feat=0,
+            out_feat=0,
+            max_neurons=self.config["max_dropout"],
+            layer_count=self.config['num_groups'] + 1,
+        )[:-1]

         self.config.update(
             {"dropout_%d" % (i + 1): dropout for i, dropout in enumerate(dropout_shape)}
```

autoPyTorch/pipeline/components/setup/network_head/no_head.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -20,7 +20,7 @@ class NoHead(NetworkHeadComponent):
     """

     def build_head(self, input_shape: Tuple[int, ...], output_shape: Tuple[int, ...]) -> nn.Module:
-        layers = [nn.Flatten()]
+        layers = []
         in_features = np.prod(input_shape).item()
         out_features = np.prod(output_shape).item()
         layers.append(_activations[self.config["activation"]]())
@@ -34,8 +34,8 @@ def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[
             'shortname': 'NoHead',
             'name': 'NoHead',
             'handles_tabular': True,
-            'handles_image': True,
-            'handles_time_series': True,
+            'handles_image': False,
+            'handles_time_series': False,
         }

     @staticmethod
```
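With `nn.Flatten()` removed, the head assumes 2-D tabular activations, consistent with flipping `handles_image` and `handles_time_series` to `False`. A sketch of what `build_head` now produces, with `nn.ReLU` standing in for the configured activation:

```python
import numpy as np
import torch.nn as nn

input_shape, output_shape = (16,), (2,)
in_features = np.prod(input_shape).item()
out_features = np.prod(output_shape).item()

# Activation followed by a linear projection; no Flatten for tabular inputs.
head = nn.Sequential(
    nn.ReLU(),
    nn.Linear(in_features, out_features),
)
```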
