1 | 1 | import collections |
2 | 2 | import logging.handlers |
3 | 3 | import os |
| 4 | +import shutil |
4 | 5 | import tempfile |
5 | 6 | import time |
6 | 7 | from typing import Any, Dict, List, Optional, Tuple, cast |
@@ -351,13 +352,35 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoic |
351 | 352 | ) |
352 | 353 | self.save_model_for_ensemble() |
353 | 354 |
| 355 | + # As training has finished, load the best weights |
| 356 | + if self.checkpoint_dir is not None: |
| 357 | + self._load_best_weights_and_clean_checkpoints(X) |
| 358 | + |
354 | 359 | self.logger.info(f"Finished training with {self.run_summary.repr_last_epoch()}") |
355 | 360 |
356 | 361 | # Tag as fitted |
357 | 362 | self.fitted_ = True |
358 | 363 |
359 | 364 | return self |
360 | 365 |
| 366 | + def _load_best_weights_and_clean_checkpoints(self, X: Dict[str, Any]) -> None: |
| 367 | + """ |
| 368 | + Load the weights of the best model observed during training and delete all checkpoint files. |
| 369 | +
| 370 | + Args: |
| 371 | + X (Dict[str, Any]): Dependencies needed by current component to perform fit |
| 372 | + """ |
| 373 | + assert self.checkpoint_dir is not None # mypy |
| 374 | + assert self.run_summary is not None # mypy |
| 375 | + |
| 376 | + best_path = os.path.join(self.checkpoint_dir, 'best.pth') |
| 377 | + self.logger.debug(f" Early stopped model {X['num_run']} on epoch {self.run_summary.get_best_epoch()}") |
| 378 | + # Training has finished; load the best-performing weights seen so far |
| 379 | + X['network'].load_state_dict(torch.load(best_path)) |
| 380 | + |
| 381 | + # Clean the temp dir |
| 382 | + shutil.rmtree(self.checkpoint_dir) |
| 383 | + |
361 | 384 | def early_stop_handler(self, X: Dict[str, Any]) -> bool: |
362 | 385 | """ |
363 | 386 | If early stopping is enabled, this procedure stops the training after a |
@@ -387,16 +410,7 @@ def early_stop_handler(self, X: Dict[str, Any]) -> bool: |
387 | 410 | if epochs_since_best == 0: |
388 | 411 | torch.save(X['network'].state_dict(), best_path) |
389 | 412 |
390 | | - if epochs_since_best > X['early_stopping']: |
391 | | - self.logger.debug(f" Early stopped model {X['num_run']} on epoch {self.run_summary.get_best_epoch()}") |
392 | | - # We will stop the training. Load the last best performing weights |
393 | | - X['network'].load_state_dict(torch.load(best_path)) |
394 | | - |
395 | | - # Let the tempfile module clean the temp dir |
396 | | - self.checkpoint_dir = None |
397 | | - return True |
398 | | - |
399 | | - return False |
| 413 | + return epochs_since_best > cast(int, X['early_stopping']) |
400 | 414 |
401 | 415 | def eval_valid_each_epoch(self, X: Dict[str, Any]) -> bool: |
402 | 416 | """ |
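For context, the pattern after this change is: `early_stop_handler` keeps checkpointing `best.pth` whenever the validation metric improves and now only reports whether the patience budget is exhausted, while `_load_best_weights_and_clean_checkpoints` restores the best weights and deletes the checkpoint directory once training ends, whether it stopped early or ran out of epochs. Below is a minimal, self-contained sketch of that save-best / stop / restore / clean-up pattern; names such as `fit_with_best_checkpoint`, `run_epoch` and `patience` are illustrative only and not part of the codebase.

```python
import os
import shutil
import tempfile
from typing import Callable

import torch
import torch.nn as nn


def fit_with_best_checkpoint(
    model: nn.Module,
    run_epoch: Callable[[nn.Module, int], float],  # returns the validation loss for one epoch
    patience: int = 5,
    max_epochs: int = 100,
) -> nn.Module:
    checkpoint_dir = tempfile.mkdtemp()              # temp dir holding best.pth, as in the diff
    best_path = os.path.join(checkpoint_dir, 'best.pth')
    best_loss = float('inf')
    epochs_since_best = 0

    for epoch in range(max_epochs):
        valid_loss = run_epoch(model, epoch)

        if valid_loss < best_loss:                   # new best: checkpoint the weights
            best_loss, epochs_since_best = valid_loss, 0
            torch.save(model.state_dict(), best_path)
        else:
            epochs_since_best += 1

        if epochs_since_best > patience:             # patience exhausted: stop training
            break

    # Training is over (early stop or epoch budget): restore the best weights and
    # remove the checkpoint dir, mirroring _load_best_weights_and_clean_checkpoints.
    model.load_state_dict(torch.load(best_path))
    shutil.rmtree(checkpoint_dir)
    return model
```

The explicit `shutil.rmtree` (hence the new `import shutil`) replaces the removed branch that set `self.checkpoint_dir = None` and relied on the tempfile module for cleanup, so the checkpoint directory is now deleted even when training finishes without early stopping.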