|
14 | 14 | import sklearn.datasets |
15 | 15 | from sklearn.ensemble import VotingClassifier, VotingRegressor |
16 | 16 |
|
| 17 | +from smac.runhistory.runhistory import RunHistory |
| 18 | + |
17 | 19 | import torch |
18 | 20 |
|
19 | 21 | from autoPyTorch.api.tabular_classification import TabularClassificationTask |
@@ -104,17 +106,20 @@ def test_tabular_classification(openml_id, resampling_strategy, backend): |
104 | 106 |
|
105 | 107 | # Search for an existing run key in disc. A individual model might have |
106 | 108 | # a timeout and hence was not written to disc |
| 109 | + successful_num_run = None |
| 110 | + SUCCESS = False |
107 | 111 | for i, (run_key, value) in enumerate(estimator.run_history.data.items()): |
108 | | - if 'SUCCESS' not in str(value.status): |
109 | | - continue |
110 | | - |
111 | | - run_key_model_run_dir = estimator._backend.get_numrun_directory( |
112 | | - estimator.seed, run_key.config_id + 1, run_key.budget) |
113 | | - if os.path.exists(run_key_model_run_dir): |
114 | | - # Runkey config id is different from the num_run |
115 | | - # more specifically num_run = config_id + 1(dummy) |
| 112 | + if 'SUCCESS' in str(value.status): |
| 113 | + run_key_model_run_dir = estimator._backend.get_numrun_directory( |
| 114 | + estimator.seed, run_key.config_id + 1, run_key.budget) |
116 | 115 | successful_num_run = run_key.config_id + 1 |
117 | | - break |
| 116 | + if os.path.exists(run_key_model_run_dir): |
| 117 | + # Runkey config id is different from the num_run |
| 118 | + # more specifically num_run = config_id + 1(dummy) |
| 119 | + SUCCESS = True |
| 120 | + break |
| 121 | + |
| 122 | + assert SUCCESS, f"Successful run was not properly saved for num_run: {successful_num_run}" |
118 | 123 |
|
119 | 124 | if resampling_strategy == HoldoutValTypes.holdout_validation: |
120 | 125 | model_file = os.path.join(run_key_model_run_dir, |
@@ -272,17 +277,20 @@ def test_tabular_regression(openml_name, resampling_strategy, backend): |
272 | 277 |
|
273 | 278 | # Search for an existing run key in disc. A individual model might have |
274 | 279 | # a timeout and hence was not written to disc |
| 280 | + successful_num_run = None |
| 281 | + SUCCESS = False |
275 | 282 | for i, (run_key, value) in enumerate(estimator.run_history.data.items()): |
276 | | - if 'SUCCESS' not in str(value.status): |
277 | | - continue |
278 | | - |
279 | | - run_key_model_run_dir = estimator._backend.get_numrun_directory( |
280 | | - estimator.seed, run_key.config_id + 1, run_key.budget) |
281 | | - if os.path.exists(run_key_model_run_dir): |
282 | | - # Runkey config id is different from the num_run |
283 | | - # more specifically num_run = config_id + 1(dummy) |
| 283 | + if 'SUCCESS' in str(value.status): |
| 284 | + run_key_model_run_dir = estimator._backend.get_numrun_directory( |
| 285 | + estimator.seed, run_key.config_id + 1, run_key.budget) |
284 | 286 | successful_num_run = run_key.config_id + 1 |
285 | | - break |
| 287 | + if os.path.exists(run_key_model_run_dir): |
| 288 | + # Runkey config id is different from the num_run |
| 289 | + # more specifically num_run = config_id + 1(dummy) |
| 290 | + SUCCESS = True |
| 291 | + break |
| 292 | + |
| 293 | + assert SUCCESS, f"Successful run was not properly saved for num_run: {successful_num_run}" |
286 | 294 |
|
287 | 295 | if resampling_strategy == HoldoutValTypes.holdout_validation: |
288 | 296 | model_file = os.path.join(run_key_model_run_dir, |
@@ -384,7 +392,7 @@ def test_tabular_input_support(openml_id, backend): |
384 | 392 | estimator._do_dummy_prediction = unittest.mock.MagicMock() |
385 | 393 |
|
386 | 394 | with unittest.mock.patch.object(AutoMLSMBO, 'run_smbo') as AutoMLSMBOMock: |
387 | | - AutoMLSMBOMock.return_value = ({}, {}, 'epochs') |
| 395 | + AutoMLSMBOMock.return_value = (RunHistory(), {}, 'epochs') |
388 | 396 | estimator.search( |
389 | 397 | X_train=X_train, y_train=y_train, |
390 | 398 | X_test=X_test, y_test=y_test, |
|
0 commit comments