Open
Labels
bug, needs triage, ver: 2.5.x
Description
Bug description
When training a ResNet18 (which contains `BatchNorm2d` layers) with `max_epochs=-1` for infinite training, `trainer.fit` ends early because `StochasticWeightAveraging` increments `max_epochs`.
After the sanity check, the trainer prints the following message and stops training:
`Trainer.fit` stopped: `max_epochs=0` reached.
The logged `max_epochs=0` differs from the configured `max_epochs=-1`, which shows that the value was incremented unexpectedly.
In `stochastic_weight_avg.py`, the following branch runs when the trained model contains a `_BatchNorm` module:
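The failure reduces to simple arithmetic. The following is a minimal, Lightning-free sketch under stated assumptions: `is_max_limit_reached` is a hypothetical stand-in for Lightning's internal `_is_max_limit_reached`, assumed to treat `-1` as "no limit" and otherwise stop once `current >= maximum`; `guarded_max_epochs` is one possible guard for illustration, not an official fix.

```python
# Lightning-free sketch of the max_epochs off-by-one.
# ASSUMPTION: `is_max_limit_reached` is a hypothetical stand-in for Lightning's
# internal `_is_max_limit_reached`: -1 disables the limit, otherwise stop once
# `current >= maximum`.


def is_max_limit_reached(current: int, maximum: int = -1) -> bool:
    """Return True if a limit is enabled (maximum != -1) and has been reached."""
    return maximum != -1 and current >= maximum


def swa_adjusted_max_epochs(max_epochs: int, contains_batch_norm: bool) -> int:
    """Mimic the unconditional `max_epochs += 1` applied for BatchNorm models."""
    return max_epochs + 1 if contains_batch_norm else max_epochs


def guarded_max_epochs(max_epochs: int, contains_batch_norm: bool) -> int:
    """Hypothetical guard: only add the extra epoch when a finite limit is set."""
    if contains_batch_norm and max_epochs != -1:
        return max_epochs + 1
    return max_epochs


configured = -1  # infinite training

buggy = swa_adjusted_max_epochs(configured, contains_batch_norm=True)
print(buggy)  # 0
# Zero epochs processed, yet the limit check already reports "reached".
print(is_max_limit_reached(current=0, maximum=buggy))  # True

guarded = guarded_max_epochs(configured, contains_batch_norm=True)
print(guarded)  # -1
print(is_max_limit_reached(current=0, maximum=guarded))  # False
```

With such a guard, a finite `max_epochs` still gets its extra batch-norm epoch (e.g. `10` becomes `11`), while `-1` keeps meaning infinite training. The sketch only addresses the early stop; it does not check how SWA's other `max_epochs`-based settings behave with `-1`.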
```python
self._max_epochs = trainer.max_epochs
if self._model_contains_batch_norm:
    # virtually increase max_epochs to perform batch norm update on latest epoch.
    assert trainer.fit_loop.max_epochs is not None
    trainer.fit_loop.max_epochs += 1
```
This increment turns `max_epochs=-1` (no limit) into `max_epochs=0`, a limit that is already reached before any epoch is processed. The fit loop then stops when it checks the following condition in `fit_loop.py`:
```python
# `processed` is increased before `on_train_epoch_end`, the hook where checkpoints are typically saved.
# we use it here because the checkpoint data won't have `completed` increased yet
assert isinstance(self.max_epochs, int)
stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs)
if stop_epochs:
    # in case they are not equal, override so `trainer.current_epoch` has the expected value
    self.epoch_progress.current.completed = self.epoch_progress.current.processed
    rank_zero_info(f"`Trainer.fit` stopped: `max_epochs={self.max_epochs!r}` reached.")
    return True
```
What version are you seeing the problem on?
v2.5
Reproduced in studio
No response
How to reproduce the bug
Error messages and logs
Environment
Current environment
- PyTorch Lightning version: 2.5.1
- PyTorch version: 2.6.0
- Python version: 3.12.10
- OS: Linux
- CUDA/cuDNN version: 11.8
- GPU models and configuration:
- How you installed Lightning (`conda`, `pip`, source): poetry
More info
No response