2 changes: 1 addition & 1 deletion src/lightning/fabric/CHANGELOG.md
@@ -27,7 +27,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed
 
--
+- Learning rate scheduler is stepped at the end of the epoch when `on_train_batch_start` returns -1 ([#21296](https://github.com/Lightning-AI/pytorch-lightning/issues/21296)).
 
 
 ---
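
For context, a minimal sketch of the situation the changelog entry above describes: a LightningModule whose `on_train_batch_start` returns -1 to skip the rest of the epoch while an epoch-interval learning rate scheduler is configured. The module name, the `batch_idx >= 1` cutoff, and the 32-feature batches are illustrative assumptions, not part of this PR.

```python
# Illustrative sketch only (not part of this PR): skipping the rest of an
# epoch from `on_train_batch_start` while using an epoch-interval scheduler.
import torch
from torch import nn
import lightning.pytorch as pl


class SkipEpochModel(pl.LightningModule):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        # assumes `batch` is a float tensor of shape (N, 32)
        return self.layer(batch).sum()

    def on_train_batch_start(self, batch, batch_idx):
        # Returning -1 tells the Trainer to skip the remaining batches of this epoch.
        if batch_idx >= 1:  # arbitrary cutoff, for illustration only
            return -1

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        # "interval": "epoch" -> the Trainer steps this scheduler once per epoch
        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "interval": "epoch"}}
```

With the `training_epoch_loop.py` change below, an epoch cut short this way still reaches the epoch-level `update_lr_schedulers` call before `StopIteration` is raised, so the scheduler is stepped as expected.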
41 changes: 24 additions & 17 deletions src/lightning/pytorch/loops/training_epoch_loop.py
@@ -325,30 +325,33 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
         trainer._logger_connector.on_batch_start(batch)
 
         batch_output: _BATCH_OUTPUTS_TYPE = None  # for mypy
+        should_skip_rest_of_epoch = False
 
         if batch is None and not using_dataloader_iter:
             self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
         else:
             # hook
             call._call_callback_hooks(trainer, "on_train_batch_start", batch, batch_idx)
             response = call._call_lightning_module_hook(trainer, "on_train_batch_start", batch, batch_idx)
             call._call_strategy_hook(trainer, "on_train_batch_start", batch, batch_idx)
-            if response == -1:
-                self.batch_progress.increment_processed()
-                raise StopIteration
-
-            self.batch_progress.increment_started()
-
-            kwargs = (
-                self._build_kwargs(OrderedDict(), batch, batch_idx)
-                if not using_dataloader_iter
-                else OrderedDict(any=dataloader_iter)
-            )
-            with trainer.profiler.profile("run_training_batch"):
-                if trainer.lightning_module.automatic_optimization:
-                    # in automatic optimization, there can only be one optimizer
-                    batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
-                else:
-                    batch_output = self.manual_optimization.run(kwargs)
+            should_skip_rest_of_epoch = response == -1
+            # Signal this is the last batch for the current epoch
+            if should_skip_rest_of_epoch:
+                self.batch_progress.increment_by(0, is_last_batch=True)
+            else:
+                self.batch_progress.increment_started()
+
+                kwargs = (
+                    self._build_kwargs(OrderedDict(), batch, batch_idx)
+                    if not using_dataloader_iter
+                    else OrderedDict(any=dataloader_iter)
+                )
+                with trainer.profiler.profile("run_training_batch"):
+                    if trainer.lightning_module.automatic_optimization:
+                        # in automatic optimization, there can only be one optimizer
+                        batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
+                    else:
+                        batch_output = self.manual_optimization.run(kwargs)
 
         self.batch_progress.increment_processed()

@@ -358,6 +361,10 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
         if self._num_ready_batches_reached():
             self.update_lr_schedulers("epoch", update_plateau_schedulers=False)
 
+        if should_skip_rest_of_epoch:
+            # Only raise StopIteration now so that the training epoch loop can finish
+            raise StopIteration
+
         if using_dataloader_iter:
             # update the hook kwargs now that the step method might have consumed the iterator
             batch = data_fetcher._batch