11import collections
22import logging .handlers
33import os
4+ import tempfile
45import time
56from typing import Any , Dict , List , Optional , Tuple , cast
67
@@ -66,7 +67,7 @@ def __init__(self,
6667 self .writer = None # type: Optional[SummaryWriter]
6768 self ._fit_requirements : Optional [List [FitRequirement ]] = [
6869 FitRequirement ("lr_scheduler" , (_LRScheduler ,), user_defined = False , dataset_property = False ),
69- FitRequirement ("job_id " , (str ,), user_defined = False , dataset_property = False ),
70+ FitRequirement ("num_run " , (int ,), user_defined = False , dataset_property = False ),
7071 FitRequirement (
7172 "optimizer" , (Optimizer ,), user_defined = False , dataset_property = False ),
7273 FitRequirement ("train_data_loader" ,
@@ -75,6 +76,7 @@ def __init__(self,
7576 FitRequirement ("val_data_loader" ,
7677 (torch .utils .data .DataLoader ,),
7778 user_defined = False , dataset_property = False )]
79+ self .checkpoint_dir = None # type: Optional[str]
7880
def get_fit_requirements(self) -> Optional[List[FitRequirement]]:
    """Return the fit requirements declared by this trainer component.

    Returns:
        Optional[List[FitRequirement]]: the requirements registered in
        ``self._fit_requirements`` during ``__init__`` (e.g. lr_scheduler,
        num_run, optimizer, train/val data loaders), or None if unset.
    """
    requirements = self._fit_requirements
    return requirements
@@ -185,7 +187,7 @@ def fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> autoPyTorchCom
185187
186188 # Setup the logger
187189 self .logger = get_named_client_logger (
188- name = X ['job_id ' ],
190+ name = X ['num_run ' ],
189191 # Log to a user provided port else to the default logging port
190192 port = X ['logger_port'
191193 ] if 'logger_port' in X else logging .handlers .DEFAULT_TCP_LOGGING_PORT ,
@@ -369,8 +371,29 @@ def early_stop_handler(self, X: Dict[str, Any]) -> bool:
369371 bool: If true, training should be stopped
370372 """
371373 assert self .run_summary is not None
372- epochs_since_best = self .run_summary .get_best_epoch () - self .run_summary .get_last_epoch ()
374+
375+ # Allow to disable early stopping
376+ if X ['early_stopping' ] is None or X ['early_stopping' ] < 0 :
377+ return False
378+
379+ # Store the best weights seen so far:
380+ if self .checkpoint_dir is None :
381+ self .checkpoint_dir = tempfile .mkdtemp (dir = X ['backend' ].temporary_directory )
382+
383+ epochs_since_best = self .run_summary .get_last_epoch () - self .run_summary .get_best_epoch ()
384+
385+ # Save the checkpoint if there is a new best epoch
386+ best_path = os .path .join (self .checkpoint_dir , 'best.pth' )
387+ if epochs_since_best == 0 :
388+ torch .save (X ['network' ].state_dict (), best_path )
389+
373390 if epochs_since_best > X ['early_stopping' ]:
391+ self .logger .debug (f" Early stopped model { X ['num_run' ]} on epoch { self .run_summary .get_best_epoch ()} " )
392+ # We will stop the training. Load the last best performing weights
393+ X ['network' ].load_state_dict (torch .load (best_path ))
394+
395+ # Let the tempfile module clean the temp dir
396+ self .checkpoint_dir = None
374397 return True
375398
376399 return False
@@ -458,8 +481,8 @@ def check_requirements(self, X: Dict[str, Any], y: Any = None) -> None:
458481 X ['budget_type' ]
459482 ))
460483
461- if 'job_id ' not in X :
462- raise ValueError ('To fit a trainer, expected fit dictionary to have a job_id ' )
484+ if 'num_run ' not in X :
485+ raise ValueError ('To fit a trainer, expected fit dictionary to have a num_run ' )
463486
464487 for config_option in ["torch_num_threads" , 'device' ]:
465488 if config_option not in X :
0 commit comments