@@ -332,6 +332,16 @@ def execute_something():
332332
333333 return RemovableEventHandle (event_name , handler , self )
334334
335+ @staticmethod
336+ def _assert_non_filtered_event (event_name : Any ) -> None :
337+ if (
338+ isinstance (event_name , CallableEventWithFilter )
339+ and event_name .filter != CallableEventWithFilter .default_event_filter
340+ ):
341+ raise TypeError (
342+ "Argument event_name should not be a filtered event, " "please use event without any event filtering"
343+ )
344+
335345 def has_event_handler (self , handler : Callable , event_name : Optional [Any ] = None ) -> bool :
336346 """Check if the specified event has the specified handler.
337347
@@ -932,6 +942,53 @@ def _setup_dataloader_iter(self) -> None:
932942 else :
933943 self ._dataloader_iter = iter (self .state .dataloader )
934944
945+ def _check_and_set_max_epochs (self , max_epochs : Optional [int ] = None ) -> None :
946+ if max_epochs is not None :
947+ if max_epochs < 1 :
948+ raise ValueError ("Argument max_epochs is invalid. Please, set a correct max_epochs positive value" )
949+ if self .state .max_epochs is not None and max_epochs <= self .state .epoch :
950+ raise ValueError (
951+ "Argument max_epochs should be larger than the current epoch "
952+ f"defined in the state: { max_epochs } vs { self .state .epoch } . "
953+ "Please, set engine.state.max_epochs = None "
954+ "before calling engine.run() in order to restart the training from the beginning."
955+ )
956+ self .state .max_epochs = max_epochs
957+
958+ def _check_and_set_max_iters (self , max_iters : Optional [int ] = None ) -> None :
959+ if max_iters is not None :
960+ if max_iters < 1 :
961+ raise ValueError ("Argument max_iters is invalid. Please, set a correct max_iters positive value" )
962+ if (self .state .max_iters is not None ) and max_iters <= self .state .iteration :
963+ raise ValueError (
964+ "Argument max_iters should be larger than the current iteration "
965+ f"defined in the state: { max_iters } vs { self .state .iteration } . "
966+ "Please, set engine.state.max_iters = None "
967+ "before calling engine.run() in order to restart the training from the beginning."
968+ )
969+ self .state .max_iters = max_iters
970+
971+ def _check_and_set_epoch_length (self , data : Iterable , epoch_length : Optional [int ] = None ) -> None :
972+ # Can't we accept a redefinition ?
973+ if self .state .epoch_length is not None :
974+ if epoch_length is not None :
975+ if epoch_length != self .state .epoch_length :
976+ raise ValueError (
977+ "Argument epoch_length should be same as in the state, "
978+ f"but given { epoch_length } vs { self .state .epoch_length } "
979+ )
980+ else :
981+ if epoch_length is None :
982+ epoch_length = self ._get_data_length (data )
983+
984+ if epoch_length is not None and epoch_length < 1 :
985+ raise ValueError (
986+ "Argument epoch_length is invalid. Please, either set a correct epoch_length value or "
987+ "check if input data has non-zero size."
988+ )
989+
990+ self .state .epoch_length = epoch_length
991+
935992 def _setup_engine (self ) -> None :
936993 self ._setup_dataloader_iter ()
937994
@@ -1291,6 +1348,16 @@ def _run_once_on_dataset_legacy(self) -> float:
12911348
12921349 return time .time () - start_time
12931350
1351+ def debug (self , enabled : bool = True ) -> None :
1352+ """Enables/disables engine's logging debug mode"""
1353+ from ignite .utils import setup_logger
1354+
1355+ if enabled :
1356+ setattr (self , "_stored_logger" , self .logger )
1357+ self .logger = setup_logger (level = logging .DEBUG )
1358+ elif hasattr (self , "_stored_logger" ):
1359+ self .logger = getattr (self , "_stored_logger" )
1360+
12941361
12951362def _get_none_data_iter (size : int ) -> Iterator :
12961363 # Sized iterator for data as None
0 commit comments