Skip to content

Commit 1a9cd71

Browse files
Make sure the performance of pipeline is at least 0.8
1 parent ddc0f3d commit 1a9cd71

File tree

2 files changed

+17
-11
lines changed

2 files changed

+17
-11
lines changed

autoPyTorch/pipeline/components/setup/lr_scheduler/base_scheduler_choice.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -138,13 +138,14 @@ def get_hyperparameter_search_space(
138138
raise ValueError("No scheduler found")
139139

140140
if default is None:
141-
defaults = ['no_LRScheduler',
142-
'LambdaLR',
143-
'StepLR',
144-
'ExponentialLR',
145-
'CosineAnnealingLR',
146-
'ReduceLROnPlateau'
147-
]
141+
defaults = [
142+
'ReduceLROnPlateau',
143+
'CosineAnnealingLR',
144+
'no_LRScheduler',
145+
'LambdaLR',
146+
'StepLR',
147+
'ExponentialLR',
148+
]
148149
for default_ in defaults:
149150
if default_ in available_schedulers:
150151
default = default_

test/test_pipeline/components/test_setup_networks.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ def test_pipeline_fit(self, fit_dictionary, backbone, head):
3232
assert backbone == config.get('network_backbone:__choice__', None)
3333
assert head == config.get('network_head:__choice__', None)
3434
pipeline.set_hyperparameters(config)
35+
36+
# Need more epochs to make sure validation performance is met
37+
fit_dictionary['epochs'] = 100
38+
3539
pipeline.fit(fit_dictionary)
3640

3741
# To make sure we fitted the model, there should be a
@@ -44,9 +48,10 @@ def test_pipeline_fit(self, fit_dictionary, backbone, head):
4448
assert run_summary.total_parameter_count > 0
4549
assert 'accuracy' in run_summary.performance_tracker['train_metrics'][1]
4650

47-
# Commented out the next line as some pipelines are not
48-
# achieving this accuracy with default configuration and 10 epochs
49-
# To be added once we fix the search space
50-
# assert run_summary.performance_tracker['val_metrics'][fit_dictionary['epochs']]['accuracy'] >= 0.8
51+
# Make sure default pipeline achieves a good score for dummy datasets
52+
assert run_summary.performance_tracker[
53+
'val_metrics'
54+
][fit_dictionary['epochs']]['accuracy'] >= 0.8, run_summary.performance_tracker['val_metrics']
55+
5156
# Make sure a network was fit
5257
assert isinstance(pipeline.named_steps['network'].get_network(), torch.nn.Module)

0 commit comments

Comments
 (0)