
Trainer.fit completes early due to an unexpected increment in StochasticWeightAveraging when training a model with _BatchNorm and max_epochs=-1 #21347

@3waffel

Bug description

When training ResNet18, which contains BatchNorm2d layers, with max_epochs set to -1 for infinite training, trainer.fit ends early because StochasticWeightAveraging increments max_epochs.

After the sanity check, the trainer prints the following message and stops training:

`Trainer.fit` stopped: `max_epochs=0` reached.

The logged max_epochs is 0 rather than the configured -1, which shows that the value was incremented by one.


The increment happens in stochastic_weight_avg.py, where the following branch is taken when the model contains a _BatchNorm layer:

        self._max_epochs = trainer.max_epochs
        if self._model_contains_batch_norm:
            # virtually increase max_epochs to perform batch norm update on latest epoch.
            assert trainer.fit_loop.max_epochs is not None
            trainer.fit_loop.max_epochs += 1
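
The effect of this increment on the `-1` sentinel can be sketched in plain Python. The helper `bump_max_epochs` and its `guard_infinite` flag are hypothetical names used only for illustration; they are not part of Lightning. The guarded branch is one possible fix, assuming `-1` should always keep meaning "no limit".

```python
def bump_max_epochs(max_epochs: int, contains_batch_norm: bool, guard_infinite: bool = False) -> int:
    """Hypothetical stand-in for the max_epochs increment in the SWA callback."""
    if not contains_batch_norm:
        # Without batch norm, no extra epoch is requested.
        return max_epochs
    if guard_infinite and max_epochs == -1:
        # Hypothetical guard: leave the "train forever" sentinel untouched.
        return max_epochs
    # Mirrors `trainer.fit_loop.max_epochs += 1` from the quoted snippet.
    return max_epochs + 1


print(bump_max_epochs(-1, contains_batch_norm=True))                       # 0
print(bump_max_epochs(-1, contains_batch_norm=True, guard_infinite=True))  # -1
print(bump_max_epochs(10, contains_batch_norm=True, guard_infinite=True))  # 11
```

Without the guard, the sentinel `-1` becomes `0`, which is an ordinary finite limit rather than "no limit".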

Once max_epochs has been incremented from -1 to 0, the fit loop stops at the following check in fit_loop.py:

        # `processed` is increased before `on_train_epoch_end`, the hook where checkpoints are typically saved.
        # we use it here because the checkpoint data won't have `completed` increased yet
        assert isinstance(self.max_epochs, int)
        stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs)
        if stop_epochs:
            # in case they are not equal, override so `trainer.current_epoch` has the expected value
            self.epoch_progress.current.completed = self.epoch_progress.current.processed
            rank_zero_info(f"`Trainer.fit` stopped: `max_epochs={self.max_epochs!r}` reached.")
            return True
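
Why `max_epochs=0` ends training immediately can be illustrated without Lightning. The function below is a simplified stand-in for `_is_max_limit_reached`, written under the assumption that `-1` disables the limit and any other value is compared with `current >= maximum`; the real helper may differ in detail.

```python
def is_max_limit_reached(current: int, maximum: int = -1) -> bool:
    """Simplified stand-in (assumed semantics): -1 means there is no limit."""
    return maximum != -1 and current >= maximum


# No epoch has been processed yet when the fit loop first checks the limit.
processed_epochs = 0

results = {
    max_epochs: is_max_limit_reached(processed_epochs, max_epochs)
    for max_epochs in (-1, 0, 1)
}
print(results)  # {-1: False, 0: True, 1: False}
```

Under these assumptions, the configured value `-1` never triggers the stop, while the incremented value `0` triggers it before the first epoch runs, which matches the `max_epochs=0` message above.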

What version are you seeing the problem on?

v2.5

Reproduced in studio

No response

How to reproduce the bug

Environment

Current environment
- PyTorch Lightning version: 2.5.1
- PyTorch version: 2.6.0
- Python version: 3.12.10
- OS: Linux
- CUDA/cuDNN version: 11.8
- GPU models and configuration:
- How you installed Lightning (`conda`, `pip`, source): poetry

More info

No response

cc @ethanwharris

Labels: bug, needs triage, ver: 2.5.x
