@@ -129,7 +129,16 @@ def compute_mean_std(engine, batch):
129129 """
130130
131131 _state_dict_all_req_keys = ("epoch_length" , "max_epochs" )
132- _state_dict_one_of_opt_keys = ("iteration" , "epoch" )
132+ _state_dict_one_of_opt_keys = (
133+ (
134+ "iteration" ,
135+ "epoch" ,
136+ ),
137+ (
138+ "max_epochs" ,
139+ "max_iters" ,
140+ ),
141+ )
133142
134143 # Flag to disable engine._internal_run as generator feature for BC
135144 interrupt_resume_enabled = True
@@ -310,6 +319,7 @@ def execute_something():
310319 for e in event_name :
311320 self .add_event_handler (e , handler , * args , ** kwargs )
312321 return RemovableEventHandle (event_name , handler , self )
322+
313323 if isinstance (event_name , CallableEventWithFilter ) and event_name .filter is not None :
314324 event_filter = event_name .filter
315325 handler = self ._handler_wrapper (handler , event_name , event_filter )
@@ -332,6 +342,16 @@ def execute_something():
332342
333343 return RemovableEventHandle (event_name , handler , self )
334344
345+ @staticmethod
346+ def _assert_non_filtered_event (event_name : Any ) -> None :
347+ if (
348+ isinstance (event_name , CallableEventWithFilter )
349+ and event_name .filter != CallableEventWithFilter .default_event_filter
350+ ):
351+ raise TypeError (
352+ "Argument event_name should not be a filtered event, " "please use event without any event filtering"
353+ )
354+
335355 def has_event_handler (self , handler : Callable , event_name : Optional [Any ] = None ) -> bool :
336356 """Check if the specified event has the specified handler.
337357
@@ -675,7 +695,12 @@ def save_engine(_):
675695 a dictionary containing engine's state
676696
677697 """
678- keys : Tuple [str , ...] = self ._state_dict_all_req_keys + (self ._state_dict_one_of_opt_keys [0 ],)
698+ keys : Tuple [str , ...] = self ._state_dict_all_req_keys
699+ keys += ("iteration" ,)
700+ if self .state .max_epochs is not None :
701+ keys += ("max_epochs" ,)
702+ else :
703+ keys += ("max_iters" ,)
679704 keys += tuple (self ._state_dict_user_keys )
680705 return OrderedDict ([(k , getattr (self .state , k )) for k in keys ])
681706
@@ -728,6 +753,8 @@ def load_state_dict(self, state_dict: Mapping) -> None:
728753 f"Input state_dict: { state_dict } "
729754 )
730755 self .state .iteration = self .state .epoch_length * self .state .epoch
756+ self ._check_and_set_max_epochs (state_dict .get ("max_epochs" , None ))
757+ self ._check_and_set_max_iters (state_dict .get ("max_iters" , None ))
731758
732759 @staticmethod
733760 def _is_done (state : State ) -> bool :
@@ -864,12 +891,26 @@ def switch_batch(engine):
864891
865892 epoch_length = self ._get_data_length (data )
866893 if epoch_length is not None and epoch_length < 1 :
867- raise ValueError ("Input data has zero size. Please provide non-empty data" )
894+ raise ValueError (
895+ "Argument epoch_length is invalid. Please, either set a"
896+ " correct epoch_length value or check if input data has"
897+ " non-zero size."
898+ )
868899
869900 if max_iters is None :
870901 if max_epochs is None :
871902 max_epochs = 1
872903 else :
904+ if max_iters < 1 :
905+ raise ValueError ("Argument max_iters is invalid. Please, set a correct max_iters positive value" )
906+ if (self .state .max_iters is not None ) and max_iters <= self .state .iteration :
907+ raise ValueError (
908+ "Argument max_iters should be larger than the current iteration "
909+ f"defined in the state: { max_iters } vs { self .state .iteration } . "
910+ "Please, set engine.state.max_iters = None "
911+ "before calling engine.run() in order to restart the training from the beginning."
912+ )
913+ self .state .max_iters = max_iters
873914 if max_epochs is not None :
874915 raise ValueError (
875916 "Arguments max_iters and max_epochs are mutually exclusive."
@@ -932,6 +973,53 @@ def _setup_dataloader_iter(self) -> None:
932973 else :
933974 self ._dataloader_iter = iter (self .state .dataloader )
934975
976+ def _check_and_set_max_epochs (self , max_epochs : Optional [int ] = None ) -> None :
977+ if max_epochs is not None :
978+ if max_epochs < 1 :
979+ raise ValueError ("Argument max_epochs is invalid. Please, set a correct max_epochs positive value" )
980+ if self .state .max_epochs is not None and max_epochs <= self .state .epoch :
981+ raise ValueError (
982+ "Argument max_epochs should be larger than the current epoch "
983+ f"defined in the state: { max_epochs } vs { self .state .epoch } . "
984+ "Please, set engine.state.max_epochs = None "
985+ "before calling engine.run() in order to restart the training from the beginning."
986+ )
987+ self .state .max_epochs = max_epochs
988+
989+ def _check_and_set_max_iters (self , max_iters : Optional [int ] = None ) -> None :
990+ if max_iters is not None :
991+ if max_iters < 1 :
992+ raise ValueError ("Argument max_iters is invalid. Please, set a correct max_iters positive value" )
993+ if (self .state .max_iters is not None ) and max_iters <= self .state .iteration :
994+ raise ValueError (
995+ "Argument max_iters should be larger than the current iteration "
996+ f"defined in the state: { max_iters } vs { self .state .iteration } . "
997+ "Please, set engine.state.max_iters = None "
998+ "before calling engine.run() in order to restart the training from the beginning."
999+ )
1000+ self .state .max_iters = max_iters
1001+
1002+ def _check_and_set_epoch_length (self , data : Iterable , epoch_length : Optional [int ] = None ) -> None :
1003+ # Can't we accept a redefinition ?
1004+ if self .state .epoch_length is not None :
1005+ if epoch_length is not None :
1006+ if epoch_length != self .state .epoch_length :
1007+ raise ValueError (
1008+ "Argument epoch_length should be same as in the state, "
1009+ f"but given { epoch_length } vs { self .state .epoch_length } "
1010+ )
1011+ else :
1012+ if epoch_length is None :
1013+ epoch_length = self ._get_data_length (data )
1014+
1015+ if epoch_length is not None and epoch_length < 1 :
1016+ raise ValueError (
1017+ "Argument epoch_length is invalid. Please, either set a correct epoch_length value or "
1018+ "check if input data has non-zero size."
1019+ )
1020+
1021+ self .state .epoch_length = epoch_length
1022+
9351023 def _setup_engine (self ) -> None :
9361024 self ._setup_dataloader_iter ()
9371025
@@ -1064,6 +1152,13 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
10641152 self .state .epoch_length = iter_counter
10651153 if self .state .max_iters is not None :
10661154 self .state .max_epochs = math .ceil (self .state .max_iters / self .state .epoch_length )
1155+ # Warn but will continue until max iters is reached
1156+ warnings .warn (
1157+ "Data iterator can not provide data anymore but required total number of "
1158+ "iterations to run is not reached. "
1159+ f"Current iteration: { self .state .iteration } vs Total iterations to run :"
1160+ f" { self .state .max_iters } "
1161+ )
10671162 break
10681163
10691164 # Should exit while loop if we can not iterate
@@ -1106,7 +1201,13 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
11061201
11071202 if self .state .max_iters is not None and self .state .iteration == self .state .max_iters :
11081203 self .should_terminate = True
1109- raise _EngineTerminateException ()
1204+ warnings .warn (
1205+ "Data iterator can not provide data anymore but required total number of "
1206+ "iterations to run is not reached. "
1207+ f"Current iteration: { self .state .iteration } vs Total iterations to run : ? total_iters"
1208+ )
1209+ break
1210+ # raise _EngineTerminateException()
11101211
11111212 except _EngineTerminateSingleEpochException :
11121213 self ._fire_event (Events .TERMINATE_SINGLE_EPOCH , iter_counter = iter_counter )
@@ -1231,6 +1332,13 @@ def _run_once_on_dataset_legacy(self) -> float:
12311332 self .state .epoch_length = iter_counter
12321333 if self .state .max_iters is not None :
12331334 self .state .max_epochs = math .ceil (self .state .max_iters / self .state .epoch_length )
1335+ # Warn but will continue until max iters is reached
1336+ warnings .warn (
1337+ "Data iterator can not provide data anymore but required total number of "
1338+ "iterations to run is not reached. "
1339+ f"Current iteration: { self .state .iteration } vs Total iterations to run :"
1340+ f" { self .state .max_iters } "
1341+ )
12341342 break
12351343
12361344 # Should exit while loop if we can not iterate
@@ -1291,6 +1399,16 @@ def _run_once_on_dataset_legacy(self) -> float:
12911399
12921400 return time .time () - start_time
12931401
1402+ def debug (self , enabled : bool = True ) -> None :
1403+ """Enables/disables engine's logging debug mode"""
1404+ from ignite .utils import setup_logger
1405+
1406+ if enabled :
1407+ setattr (self , "_stored_logger" , self .logger )
1408+ self .logger = setup_logger (level = logging .DEBUG )
1409+ elif hasattr (self , "_stored_logger" ):
1410+ self .logger = getattr (self , "_stored_logger" )
1411+
12941412
12951413def _get_none_data_iter (size : int ) -> Iterator :
12961414 # Sized iterator for data as None
0 commit comments