@@ -48,7 +48,7 @@ def supervised_training_step(
4848 device : Optional [Union [str , torch .device ]] = None ,
4949 non_blocking : bool = False ,
5050 prepare_batch : Callable = _prepare_batch ,
51- output_transform : Callable = lambda x , y , y_pred , loss : loss .item (),
51+ output_transform : Callable [[ Any , Any , Any , torch . Tensor ], Any ] = lambda x , y , y_pred , loss : loss .item (),
5252 gradient_accumulation_steps : int = 1 ,
5353) -> Callable :
5454 """Factory function for supervised training.
@@ -117,7 +117,7 @@ def supervised_training_step_amp(
117117 device : Optional [Union [str , torch .device ]] = None ,
118118 non_blocking : bool = False ,
119119 prepare_batch : Callable = _prepare_batch ,
120- output_transform : Callable = lambda x , y , y_pred , loss : loss .item (),
120+ output_transform : Callable [[ Any , Any , Any , torch . Tensor ], Any ] = lambda x , y , y_pred , loss : loss .item (),
121121 scaler : Optional ["torch.cuda.amp.GradScaler" ] = None ,
122122 gradient_accumulation_steps : int = 1 ,
123123) -> Callable :
@@ -203,7 +203,7 @@ def supervised_training_step_apex(
203203 device : Optional [Union [str , torch .device ]] = None ,
204204 non_blocking : bool = False ,
205205 prepare_batch : Callable = _prepare_batch ,
206- output_transform : Callable = lambda x , y , y_pred , loss : loss .item (),
206+ output_transform : Callable [[ Any , Any , Any , torch . Tensor ], Any ] = lambda x , y , y_pred , loss : loss .item (),
207207 gradient_accumulation_steps : int = 1 ,
208208) -> Callable :
209209 """Factory function for supervised training using apex.
@@ -279,7 +279,7 @@ def supervised_training_step_tpu(
279279 device : Optional [Union [str , torch .device ]] = None ,
280280 non_blocking : bool = False ,
281281 prepare_batch : Callable = _prepare_batch ,
282- output_transform : Callable = lambda x , y , y_pred , loss : loss .item (),
282+ output_transform : Callable [[ Any , Any , Any , torch . Tensor ], Any ] = lambda x , y , y_pred , loss : loss .item (),
283283 gradient_accumulation_steps : int = 1 ,
284284) -> Callable :
285285 """Factory function for supervised training using ``torch_xla``.
@@ -381,7 +381,7 @@ def create_supervised_trainer(
381381 device : Optional [Union [str , torch .device ]] = None ,
382382 non_blocking : bool = False ,
383383 prepare_batch : Callable = _prepare_batch ,
384- output_transform : Callable = lambda x , y , y_pred , loss : loss .item (),
384+ output_transform : Callable [[ Any , Any , Any , torch . Tensor ], Any ] = lambda x , y , y_pred , loss : loss .item (),
385385 deterministic : bool = False ,
386386 amp_mode : Optional [str ] = None ,
387387 scaler : Union [bool , "torch.cuda.amp.GradScaler" ] = False ,
@@ -418,6 +418,50 @@ def create_supervised_trainer(
418418 Returns:
419419 a trainer engine with supervised update function.
420420
421+ Examples:
422+
423+ Create a trainer
424+
425+ .. code-block:: python
426+
427+ from ignite.engine import create_supervised_trainer
428+ from ignite.utils import convert_tensor
429+ from ignite.contrib.handlers.tqdm_logger import ProgressBar
430+
431+ model = ...
432+ loss = ...
433+ optimizer = ...
434+ dataloader = ...
435+
436+ def prepare_batch_fn(batch, device, non_blocking):
437+ x = ... # get x from batch
438+ y = ... # get y from batch
439+
440+ # return a tuple of (x, y) that can be directly runned as
441+ # `loss_fn(model(x), y)`
442+ return (
443+ convert_tensor(x, device, non_blocking),
444+ convert_tensor(y, device, non_blocking)
445+ )
446+
447+ def output_transform_fn(x, y, y_pred, loss):
448+ # return only the loss is actually the default behavior for
449+ # trainer engine, but you can return anything you want
450+ return loss.item()
451+
452+ trainer = create_supervised_trainer(
453+ model,
454+ optimizer,
455+ loss,
456+ prepare_batch=prepare_batch_fn,
457+ output_transform=output_transform_fn
458+ )
459+
460+ pbar = ProgressBar()
461+ pbar.attach(trainer, output_transform=lambda x: {"loss": x})
462+
463+ trainer.run(dataloader, max_epochs=5)
464+
421465 Note:
422466 If ``scaler`` is True, GradScaler instance will be created internally and trainer state has attribute named
423467 ``scaler`` for that instance and can be used for saving and loading.
@@ -513,7 +557,7 @@ def supervised_evaluation_step(
513557 device : Optional [Union [str , torch .device ]] = None ,
514558 non_blocking : bool = False ,
515559 prepare_batch : Callable = _prepare_batch ,
516- output_transform : Callable = lambda x , y , y_pred : (y_pred , y ),
560+ output_transform : Callable [[ Any , Any , Any ], Any ] = lambda x , y , y_pred : (y_pred , y ),
517561) -> Callable :
518562 """
519563 Factory function for supervised evaluation.
@@ -561,7 +605,7 @@ def supervised_evaluation_step_amp(
561605 device : Optional [Union [str , torch .device ]] = None ,
562606 non_blocking : bool = False ,
563607 prepare_batch : Callable = _prepare_batch ,
564- output_transform : Callable = lambda x , y , y_pred : (y_pred , y ),
608+ output_transform : Callable [[ Any , Any , Any ], Any ] = lambda x , y , y_pred : (y_pred , y ),
565609) -> Callable :
566610 """
567611 Factory function for supervised evaluation using ``torch.cuda.amp``.
@@ -615,7 +659,7 @@ def create_supervised_evaluator(
615659 device : Optional [Union [str , torch .device ]] = None ,
616660 non_blocking : bool = False ,
617661 prepare_batch : Callable = _prepare_batch ,
618- output_transform : Callable = lambda x , y , y_pred : (y_pred , y ),
662+ output_transform : Callable [[ Any , Any , Any ], Any ] = lambda x , y , y_pred : (y_pred , y ),
619663 amp_mode : Optional [str ] = None ,
620664) -> Engine :
621665 """
0 commit comments