diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index c75c82127afac..d982c131171ff 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -25,6 +25,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for variable batch size in `ThroughputMonitor` ([#20236](https://github.com/Lightning-AI/pytorch-lightning/pull/20236))
 
+- Added `EMAWeightAveraging` callback that extends Lightning's `WeightAveraging` class ([#21260](https://github.com/Lightning-AI/pytorch-lightning/pull/21260))
+
+
 ### Changed
 
 - Default to `RichProgressBar` and `RichModelSummary` if the rich package is available. Fallback to TQDMProgressBar and ModelSummary otherwise ([#20896](https://github.com/Lightning-AI/pytorch-lightning/pull/20896))
diff --git a/src/lightning/pytorch/callbacks/__init__.py b/src/lightning/pytorch/callbacks/__init__.py
index d0ffb7b6a990c..dd96c045d8366 100644
--- a/src/lightning/pytorch/callbacks/__init__.py
+++ b/src/lightning/pytorch/callbacks/__init__.py
@@ -32,7 +32,7 @@
 from lightning.pytorch.callbacks.stochastic_weight_avg import StochasticWeightAveraging
 from lightning.pytorch.callbacks.throughput_monitor import ThroughputMonitor
 from lightning.pytorch.callbacks.timer import Timer
-from lightning.pytorch.callbacks.weight_averaging import WeightAveraging
+from lightning.pytorch.callbacks.weight_averaging import EMAWeightAveraging, WeightAveraging
 
 __all__ = [
     "BackboneFinetuning",
@@ -59,5 +59,6 @@
     "ThroughputMonitor",
     "Timer",
     "TQDMProgressBar",
+    "EMAWeightAveraging",
     "WeightAveraging",
 ]
diff --git a/src/lightning/pytorch/callbacks/weight_averaging.py b/src/lightning/pytorch/callbacks/weight_averaging.py
index f9b8d64eae6a5..0640efed3d87b 100644
--- a/src/lightning/pytorch/callbacks/weight_averaging.py
+++ b/src/lightning/pytorch/callbacks/weight_averaging.py
@@ -21,7 +21,7 @@
 from typing import Any, Optional, Union
 
 import torch
-from torch.optim.swa_utils import AveragedModel
+from torch.optim.swa_utils import AveragedModel, get_ema_avg_fn
 from typing_extensions import override
 
 import lightning.pytorch as pl
@@ -361,3 +361,55 @@ def _copy_average_to_current(self, pl_module: "pl.LightningModule") -> None:
         current_params = itertools.chain(pl_module.parameters(), pl_module.buffers())
         for average_param, current_param in zip(average_params, current_params):
             current_param.data.copy_(average_param.data)
+
+
+class EMAWeightAveraging(WeightAveraging):
+    """Exponential Moving Average (EMA) Weight Averaging callback."""
+
+    def __init__(
+        self,
+        device: Optional[Union[torch.device, str, int]] = None,
+        use_buffers: bool = True,
+        decay: float = 0.999,
+        update_every_n_steps: int = 1,
+        update_starting_at_step: Optional[int] = None,
+        update_starting_at_epoch: Optional[int] = None,
+        **kwargs: Any,
+    ):
+        super().__init__(
+            device=device,
+            use_buffers=use_buffers,
+            avg_fn=get_ema_avg_fn(decay=decay),
+            **kwargs,
+        )
+
+        self.update_every_n_steps = update_every_n_steps
+        self.update_starting_at_step = update_starting_at_step
+        self.update_starting_at_epoch = update_starting_at_epoch
+
+    def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None) -> bool:
+        """Decide when to update the model weights.
+
+        Args:
+            step_idx: The current step index.
+            epoch_idx: The current epoch index.
+        Returns:
+            bool: True if the model weights should be updated, False otherwise.
+
+        """
+        if step_idx is not None:
+            # Check step-based conditions only if we have a valid step_idx
+            meets_step_requirement = self.update_starting_at_step is None or step_idx >= self.update_starting_at_step
+            meets_step_frequency = self.update_every_n_steps > 0 and step_idx % self.update_every_n_steps == 0
+            if meets_step_requirement and meets_step_frequency:
+                return True
+
+        if epoch_idx is not None:
+            # Check epoch-based condition only if we specify one
+            meets_epoch_requirement = (
+                self.update_starting_at_epoch is not None and epoch_idx >= self.update_starting_at_epoch
+            )
+            if meets_epoch_requirement:
+                return True
+
+        return False
diff --git a/tests/tests_pytorch/callbacks/test_weight_averaging.py b/tests/tests_pytorch/callbacks/test_weight_averaging.py
index ec230b2fd6c97..cfb066f023af0 100644
--- a/tests/tests_pytorch/callbacks/test_weight_averaging.py
+++ b/tests/tests_pytorch/callbacks/test_weight_averaging.py
@@ -23,7 +23,7 @@
 from torch.utils.data import DataLoader, Dataset
 
 from lightning.pytorch import LightningModule, Trainer
-from lightning.pytorch.callbacks import WeightAveraging
+from lightning.pytorch.callbacks import EMAWeightAveraging, WeightAveraging
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
 from tests_pytorch.helpers.runif import RunIf
 
@@ -329,3 +329,123 @@ def _train_and_resume(model: TestModel, dataset: Dataset, tmp_path: str, devices
     callback = EMATestCallback(devices=devices)
     _train(model, dataset, tmp_path, callback, devices=devices, checkpoint_path=checkpoint_path, **kwargs)
     return model
+
+
+@pytest.mark.parametrize(
+    ("strategy", "accelerator", "devices"),
+    [
+        ("auto", "cpu", 1),
+        pytest.param("auto", "gpu", 1, marks=RunIf(min_cuda_gpus=1)),
+    ],
+)
+def test_ema_weight_averaging(tmp_path, strategy, accelerator, devices):
+    """Test EMAWeightAveraging with its default update configuration on CPU and GPU."""
+    model = TestModel()
+    dataset = RandomDataset(32, 32)
+
+    # Default settings: update the averaged weights every training step
+    callback = EMAWeightAveraging(decay=0.999, update_every_n_steps=1)
+    _train(model, dataset, tmp_path, callback, strategy=strategy, accelerator=accelerator, devices=devices)
+
+    # Verify the average model was created and updated
+    assert callback._average_model is not None
+    assert callback._average_model.n_averaged > 0
+
+
+def test_ema_weight_averaging_step_frequency(tmp_path):
+    """Test EMAWeightAveraging with a custom step update frequency."""
+    model = TestModel()
+    dataset = RandomDataset(32, 32)
+
+    # Update every 5 steps
+    callback = EMAWeightAveraging(decay=0.95, update_every_n_steps=5)
+    _train(model, dataset, tmp_path, callback)
+
+    assert callback._average_model is not None
+
+
+def test_ema_weight_averaging_starting_step(tmp_path):
+    """Test EMAWeightAveraging with a delayed start based on steps."""
+    model = TestModel()
+    dataset = RandomDataset(32, 32)
+
+    # Start updating at step 10
+    callback = EMAWeightAveraging(decay=0.999, update_every_n_steps=1, update_starting_at_step=10)
+    _train(model, dataset, tmp_path, callback)
+
+    assert callback._average_model is not None
+
+
+def test_ema_weight_averaging_starting_epoch(tmp_path):
+    """Test EMAWeightAveraging with a delayed start based on epochs."""
+    model = TestModel()
+    dataset = RandomDataset(32, 32)
+
+    # Start updating at epoch 3
+    callback = EMAWeightAveraging(decay=0.999, update_every_n_steps=1, update_starting_at_epoch=3)
+    _train(model, dataset, tmp_path, callback)
+
+    assert callback._average_model is not None
+
+
+def test_ema_weight_averaging_should_update(tmp_path):
+    """Test the should_update logic of EMAWeightAveraging."""
+    # Test with step-based updates
+    callback = EMAWeightAveraging(update_every_n_steps=5, update_starting_at_step=10)
+
+    # Before the starting step
+    assert not callback.should_update(step_idx=5)
+    assert not callback.should_update(step_idx=9)
+
+    # From the starting step onward, updates happen only at the configured frequency
+    assert callback.should_update(step_idx=10)  # First update
+    assert not callback.should_update(step_idx=11)
+    assert not callback.should_update(step_idx=14)
+    assert callback.should_update(step_idx=15)  # Second update
+
+    # Test with epoch-based updates
+    callback = EMAWeightAveraging(update_starting_at_epoch=2)
+
+    assert not callback.should_update(epoch_idx=0)
+    assert not callback.should_update(epoch_idx=1)
+    assert callback.should_update(epoch_idx=2)
+    assert callback.should_update(epoch_idx=3)
+
+
+def test_ema_weight_averaging_checkpoint_save_load(tmp_path):
+    """Test that EMAWeightAveraging correctly saves and loads checkpoints."""
+    model = TestModel()
+    model.crash_on_epoch = 2
+    dataset = RandomDataset(32, 32)
+
+    callback = EMAWeightAveraging(decay=0.99, update_every_n_steps=2)
+
+    # Train and create checkpoint
+    _train(model, dataset, tmp_path, callback, will_crash=True)
+
+    # Resume from the checkpoint produced by the crashed run
+    model2 = TestModel()
+    callback2 = EMAWeightAveraging(decay=0.99, update_every_n_steps=2)
+    checkpoint_path = str(next((tmp_path / "checkpoints").glob("*.ckpt")))
+
+    _train(
+        model2,
+        dataset,
+        tmp_path,
+        callback2,
+        checkpoint_path=checkpoint_path,
+    )
+
+    assert callback2._average_model is not None
+
+
+@pytest.mark.parametrize("decay", [0.9, 0.99, 0.999, 0.9999])
+def test_ema_weight_averaging_decay_values(tmp_path, decay):
+    """Test EMAWeightAveraging with different decay values."""
+    model = TestModel()
+    dataset = RandomDataset(32, 32)
+
+    callback = EMAWeightAveraging(decay=decay, update_every_n_steps=1)
+    _train(model, dataset, tmp_path, callback)
+
+    assert callback._average_model is not None
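
For reviewers, a minimal usage sketch of the new callback (not part of the patch). It borrows `BoringModel` and `RandomDataset` from `lightning.pytorch.demos.boring_classes` purely as stand-ins for a real LightningModule and dataloader, and the `update_starting_at_step=10` value is illustrative rather than a default from this PR:

```python
# Hypothetical usage sketch for the EMAWeightAveraging callback added in this PR.
from torch.utils.data import DataLoader

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EMAWeightAveraging
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset

ema = EMAWeightAveraging(
    decay=0.999,                 # smoothing factor forwarded to torch.optim.swa_utils.get_ema_avg_fn
    update_every_n_steps=1,      # refresh the averaged weights every training step
    update_starting_at_step=10,  # skip the first 10 steps before averaging begins (illustrative value)
)

trainer = Trainer(max_epochs=2, callbacks=[ema])
trainer.fit(BoringModel(), DataLoader(RandomDataset(32, 64), batch_size=2))
```

A `decay` closer to 1.0 weights the parameter history more heavily, so the averaged model evolves more slowly; `update_starting_at_step` / `update_starting_at_epoch` let the average skip the noisy early phase of training.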