1+ import warnings
12from copy import deepcopy
23from typing import Optional , Union
34
45import torch .nn as nn
56
67from ignite .engine import CallableEventWithFilter , Engine , Events , EventsList
8+ from ignite .handlers .param_scheduler import BaseParamScheduler
9+ from ignite .handlers .state_param_scheduler import LambdaStateScheduler
710
811__all__ = ["EMAHandler" ]
912
1013
14+ class EMAWarmUp :
15+ def __init__ (self , momentum_warmup : float , warmup_iters : int , momentum : float ) -> None :
16+ self .momentum_warmup = momentum_warmup
17+ self .warmup_iters = warmup_iters
18+ self .momentum = momentum
19+
20+ def __call__ (self , event_index : int ) -> float :
21+ denominator = max (1 , self .warmup_iters - 1 )
22+ curr_momentum = self .momentum_warmup + (self .momentum - self .momentum_warmup ) * (event_index - 1 ) / denominator
23+ if self .momentum >= self .momentum_warmup :
24+ return min (self .momentum , curr_momentum )
25+ else :
26+ return max (self .momentum , curr_momentum )
27+
28+
1129class EMAHandler :
1230 r"""Exponential moving average (EMA) handler can be used to compute a smoothed version of model.
1331 The EMA model is updated as follows:
1432
1533 .. math:: \theta_{\text{EMA}, t+1} = (1 - \lambda) \cdot \theta_{\text{EMA}, t} + \lambda \cdot \theta_{t}
1634
1735 where :math:`\theta_{\text{EMA}, t}` and :math:`\theta_{t}` are the EMA weights and online model weights at
18- :math:`t`-th iteration, respectively; :math:`\lambda` is the update momentum. The handler allows for linearly
19- warming up the momentum in the beginning when training process is not stable. Current momentum can be retrieved
36+ :math:`t`-th iteration, respectively; :math:`\lambda` is the update momentum. Current momentum can be retrieved
2037 from ``Engine.state.ema_momentum``.
2138
2239 Args:
2340 model: the online model for which an EMA model will be computed. If ``model`` is ``DataParallel`` or
2441 ``DistributedDataParallel``, the EMA smoothing will be applied to ``model.module`` .
2542 momentum: the update momentum after warmup phase, should be float in range :math:`\left(0, 1 \right)`.
26- momentum_warmup: the initial update momentum during warmup phase, the value should be smaller than
27- ``momentum``. Momentum will linearly increase from this value to ``momentum`` in ``warmup_iters``
28- iterations. If ``None``, no warmup will be performed.
29- warmup_iters: iterations of warmup. If ``None``, no warmup will be performed.
43+ momentum_warmup: the initial update momentum during warmup phase.
44+ warmup_iters: iterations of warmup.
3045
3146 Attributes:
3247 ema_model: the exponential moving averaged model.
3348 model: the online model that is tracked by EMAHandler. It is ``model.module`` if ``model`` in
3449 the initialization method is an instance of ``DistributedDataParallel``.
35- momentum: the update momentum after warmup phase.
36- momentum_warmup: the initial update momentum.
37- warmup_iters: number of warmup iterations.
50+ momentum: the update momentum.
3851
3952 Note:
4053 The EMA model is already in ``eval`` mode. If model in the arguments is an ``nn.Module`` or
@@ -56,8 +69,7 @@ class EMAHandler:
5669 device = torch.device("cuda:0")
5770 model = nn.Linear(2, 1).to(device)
5871 # update the ema every 5 iterations
59- ema_handler = EMAHandler(
60- model, momentum=0.0002, momentum_warmup=0.0001, warmup_iters=10000)
72+ ema_handler = EMAHandler(model, momentum=0.0002)
6173 # get the ema model, which is an instance of nn.Module
6274 ema_model = ema_handler.ema_model
6375 trainer = Engine(train_step_fn)
@@ -89,6 +101,19 @@ def run_validation(engine):
89101
90102 trainer.run(...)
91103
104+ The following example shows how to perform warm-up to the EMA momentum:
105+
106+ .. code-block:: python
107+
108+ device = torch.device("cuda:0")
109+ model = nn.Linear(2, 1).to(device)
110+ # linearly change the EMA momentum from 0.2 to 0.002 in the first 100 iterations,
111+ # then keep a constant EMA momentum of 0.002 afterwards
112+ ema_handler = EMAHandler(model, momentum=0.002, momentum_warmup=0.2, warmup_iters=100)
113+ engine = Engine(step_fn)
114+ ema_handler.attach(engine, name="ema_momentum")
115+ engine.run(...)
116+
92117 The following example shows how to attach two handlers to the same trainer:
93118
94119 .. code-block:: python
@@ -125,25 +150,19 @@ def __init__(
125150 momentum_warmup : Optional [float ] = None ,
126151 warmup_iters : Optional [int ] = None ,
127152 ) -> None :
128- if momentum_warmup is not None and not 0 < momentum_warmup < 1 :
129- raise ValueError (f"Invalid momentum_warmup: { momentum_warmup } " )
130153 if not 0 < momentum < 1 :
131154 raise ValueError (f"Invalid momentum: { momentum } " )
132- if momentum_warmup is not None and not momentum_warmup <= momentum :
133- raise ValueError (
134- f"momentum_warmup should be less than or equal to momentum, but got "
135- f"momentum_warmup: { momentum_warmup } and momentum: { momentum } "
136- )
137- if warmup_iters is not None and not (isinstance (warmup_iters , int ) and warmup_iters > 0 ):
138- raise ValueError (f"Invalid warmup_iters: { warmup_iters } " )
155+ self .momentum = momentum
156+ self ._momentum_lambda_obj : Optional [EMAWarmUp ] = None
157+ if momentum_warmup is not None and warmup_iters is not None :
158+ self .momentum_scheduler : Optional [BaseParamScheduler ] = None
159+ self ._momentum_lambda_obj = EMAWarmUp (momentum_warmup , warmup_iters , momentum )
160+
139161 if not isinstance (model , nn .Module ):
140162 raise ValueError (
141163 f"model should be an instance of nn.Module or its subclasses, but got"
142164 f"model: { model .__class__ .__name__ } "
143165 )
144- self .momentum_warmup = momentum_warmup
145- self .momentum = momentum
146- self .warmup_iters = warmup_iters
147166
148167 if isinstance (model , nn .parallel .DistributedDataParallel ):
149168 model = model .module
@@ -154,22 +173,6 @@ def __init__(
154173 param .detach_ ()
155174 self .ema_model .eval ()
156175
157- def _get_momentum (self , curr_iter : int ) -> float :
158- """Get current momentum, `curr_iter` should be 1-based. When `curr_iter = 1`, `momentum =
159- self.momentum_warmup`; when `curr_iter >= self.warmup_iters`, `momentum = self.momentum`"""
160-
161- # TODO: use ignite's parameter scheduling, see also GitHub issue #2090
162- if curr_iter < 1 :
163- raise ValueError (f"curr_iter should be at least 1, but got { curr_iter } ." )
164-
165- # no warmup
166- if self .momentum_warmup is None or self .warmup_iters is None :
167- return self .momentum
168-
169- denominator = max (1 , self .warmup_iters - 1 )
170- momentum = self .momentum_warmup + (self .momentum - self .momentum_warmup ) * (curr_iter - 1 ) / denominator
171- return min (self .momentum , momentum )
172-
173176 def _update_ema_model (self , engine : Engine , name : str ) -> None :
174177 """Update weights of ema model"""
175178 momentum = getattr (engine .state , name )
@@ -179,36 +182,47 @@ def _update_ema_model(self, engine: Engine, name: str) -> None:
179182 for ema_b , model_b in zip (self .ema_model .buffers (), self .model .buffers ()):
180183 ema_b .data = model_b .data
181184
182- def _update_ema_momentum (self , engine : Engine , name : str ) -> None :
183- """Update momentum in engine.state"""
184- curr_iter = engine .state .iteration
185- momentum = self ._get_momentum (curr_iter )
186- setattr (engine .state , name , momentum )
187-
188185 def attach (
189186 self ,
190187 engine : Engine ,
191188 name : str = "ema_momentum" ,
189+ warn_if_exists : bool = True ,
192190 event : Union [str , Events , CallableEventWithFilter , EventsList ] = Events .ITERATION_COMPLETED ,
193191 ) -> None :
194192 """Attach the handler to engine. After the handler is attached, the ``Engine.state`` will add an new attribute
195- with name ``name``. Then, current momentum can be retrieved by from ``Engine.state`` when the engine runs.
193+ with name ``name`` if the attribute does not exist. Then, the current momentum can be retrieved from
194+ ``Engine.state`` when the engine runs.
195+
196+
197+ Note:
198+ There are two cases where a momentum with name ``name`` already exists: 1. the engine has loaded its
199+ state dict after resuming. In this case, there is no need to initialize the momentum again, and users
200+ can set ``warn_if_exists`` to False to suppress the warning message; 2. another handler has created
201+ a state attribute with the same name. In this case, users should choose another name for the ema momentum.
202+
196203
197204 Args:
198205 engine: trainer to which the handler will be attached.
199206 name: attribute name for retrieving EMA momentum from ``Engine.state``. It should be a unique name since a
200207 trainer can have multiple EMA handlers.
208+ warn_if_exists: if True, a warning will be thrown if the momentum with name ``name`` already exists.
201209 event: event when the EMA momentum and EMA model are updated.
202210
203211 """
204212 if hasattr (engine .state , name ):
205- raise ValueError (
206- f"Attribute: '{ name } ' is already in Engine.state. Thus it might be "
207- f"overridden by other EMA handlers. Please select another name."
208- )
209-
210- setattr (engine .state , name , 0.0 )
211-
212- # first update momentum, then update ema model
213- engine .add_event_handler (event , self ._update_ema_momentum , name )
213+ if warn_if_exists :
214+ warnings .warn (
215+ f"Attribute '{ name } ' already exists in Engine.state. It might because 1. the engine has loaded its "
216+ f"state dict or 2. { name } is already created by other handlers. Turn off this warning by setting"
217+ f"warn_if_exists to False." ,
218+ category = UserWarning ,
219+ )
220+ else :
221+ setattr (engine .state , name , self .momentum )
222+
223+ if self ._momentum_lambda_obj is not None :
224+ self .momentum_scheduler = LambdaStateScheduler (self ._momentum_lambda_obj , param_name = "ema_momentum" )
225+
226+ # first update the momentum and then update the EMA model
227+ self .momentum_scheduler .attach (engine , event )
214228 engine .add_event_handler (event , self ._update_ema_model , name )
0 commit comments