Exponential Moving Average (EMA)
Maintaining a slow-moving copy of model weights that tracks the training weights with a lag: θ_ema ← τ·θ_ema + (1−τ)·θ. The EMA weights are smoother, more stable, and often generalise better than the raw training weights. Used for evaluation in diffusion models, target networks in RL/contrastive learning, and stochastic weight averaging.
Intuition
During training, model weights bounce around noisily due to stochastic gradient updates. Each step makes progress on average, but any single checkpoint is noisy — it reflects whatever batch it just saw. The EMA weights are a running average that smooths out this noise, like a low-pass filter on the weight trajectory.
With a typical decay of τ = 0.999, the EMA weights are roughly the average of the last ~1000 sets of training weights (the effective window is 1/(1−τ) steps). This averaging has a regularising effect: it smooths out sharp minima and favours flat regions of the loss landscape, which tend to generalise better.
The critical insight is that EMA weights are never trained directly — gradients flow only through the regular weights. The EMA is a passive observer that accumulates a smoothed version. This means it adds almost zero compute (just a weighted sum per parameter per step) and doesn’t interfere with the optimizer at all. In RL and contrastive learning, EMA serves a different purpose: the EMA copy provides a stable target that doesn’t change as fast as the online model, preventing the “moving target” problem in bootstrapped learning.
EMA update (decay τ, current weights θ_t, EMA weights θ_ema):

θ_ema ← τ·θ_ema + (1−τ)·θ_t
Equivalently:

θ_ema ← θ_ema + (1−τ)·(θ_t − θ_ema)
Effective averaging window: ≈ 1/(1−τ) steps. τ = 0.999 gives a window of ~1000 steps; τ = 0.9999 gives ~10,000 steps.
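As a quick numeric sanity check on the 1/(1−τ) window claim (an illustrative sketch, not from the original): the weight the EMA assigns to the parameter value from k steps ago is (1−τ)·τ^k, and most of that mass falls inside the window.

```python
# Weight on the parameter value from k steps ago is (1-τ)·τ^k;
# summing over the last 1/(1-τ) steps shows where the mass concentrates.
tau = 0.999
window = round(1 / (1 - tau))  # effective window ≈ 1000 steps
mass_in_window = sum((1 - tau) * tau**k for k in range(window))
# mass_in_window ≈ 0.63: about 1 - 1/e of the total weight lies inside the window
```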
Bias correction (optional, for early steps when the EMA is initialised to zero rather than to θ_0):

θ̂_ema = θ_ema / (1 − τ^t)
This is the same correction Adam uses for its moment estimates. Often skipped in practice when τ is close to 1 and training is long.
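A minimal numeric check of the correction (illustrative; a zero-initialised scalar EMA of a constant signal):

```python
# EMA of a constant signal x = 1.0 with a zero-initialised accumulator:
# after t steps the raw EMA equals 1 - τ^t, i.e. biased toward zero early on.
tau, x = 0.999, 1.0
ema, t = 0.0, 100
for _ in range(t):
    ema = tau * ema + (1 - tau) * x

raw = ema                       # ≈ 0.095: badly biased after only 100 steps
corrected = ema / (1 - tau**t)  # ≈ 1.0: the Adam-style correction removes the bias
```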
Common defaults: τ = 0.9999 for diffusion models (DDPM). τ = 0.9999 for large-scale image generation (Imagen, DALL-E). τ = 0.995 (i.e., τ = 0.005 in the “soft update” convention) for RL target networks.
Note: RL papers often write θ′ ← τ·θ + (1−τ)·θ′, flipping the meaning of τ. In that convention, τ = 0.005 is a slow update. Always check which convention a paper uses.
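The two conventions produce identical updates when the rates are complements, as this hypothetical scalar example shows:

```python
# One scalar weight, two notational conventions for the same update.
theta_online, theta_ema = 2.0, 1.0

decay = 0.999    # EMA convention: keep `decay` of the old average
a = decay * theta_ema + (1 - decay) * theta_online

tau = 1 - decay  # RL soft-update convention: move `tau` toward the online weights
b = (1 - tau) * theta_ema + tau * theta_online

# a == b: identical update once τ_RL = 1 - τ_EMA
```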
```python
import copy

import torch

# ── Manual EMA (the standard approach) ──────────────────────────
ema_model = copy.deepcopy(model)  # initialise EMA to current weights
ema_decay = 0.999

for p_ema, p in zip(ema_model.parameters(), model.parameters()):
    p_ema.requires_grad_(False)  # EMA weights are never trained

# In the training loop, after optimizer.step():
with torch.no_grad():
    for p_ema, p in zip(ema_model.parameters(), model.parameters()):
        p_ema.lerp_(p, 1 - ema_decay)  # θ_ema = τ·θ_ema + (1-τ)·θ
        # lerp_(end, weight) does: self = self + weight * (end - self)

# Use ema_model for evaluation / inference
output = ema_model(test_input)

# ── Using torch.optim.swa_utils ────────────────────────────────
# (AveragedModel is built in since PyTorch 1.6; the EMA helper since 2.0)
from torch.optim.swa_utils import AveragedModel, get_ema_multi_avg_fn

ema_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(0.999))

# After each optimizer step:
ema_model.update_parameters(model)

# WARNING: remember to update batch norm statistics before evaluation:
# torch.optim.swa_utils.update_bn(dataloader, ema_model)
```

Manual Implementation
```python
import numpy as np

class EMA:
    """Exponential moving average of model parameters (numpy)."""

    def __init__(self, params, decay=0.999):
        """
        params: list of numpy arrays (model weights)
        decay: τ, how much to keep of the old EMA (0.999 = slow-moving)
        """
        self.decay = decay
        self.shadow = [p.copy() for p in params]  # initialise to current weights
        self.steps = 0

    def update(self, params):
        """Call after each gradient step with the updated model params."""
        self.steps += 1
        for s, p in zip(self.shadow, params):
            s[:] = self.decay * s + (1 - self.decay) * p  # EMA update

    def get(self, bias_correction=False):
        """Return the EMA weights, optionally with bias correction.

        Note: the 1/(1 - τ^t) correction assumes a zero-initialised shadow;
        since this class initialises to the current weights, leave it off
        unless you change that initialisation.
        """
        if bias_correction and self.steps > 0:
            correction = 1 - self.decay ** self.steps
            return [s / correction for s in self.shadow]
        return [s.copy() for s in self.shadow]

# Usage:
# weights = [np.random.randn(768, 768), np.random.randn(768)]
# ema = EMA(weights, decay=0.999)
# ... after each training step: ema.update(weights)
# ... for evaluation: eval_weights = ema.get()
```

Popular Uses
- Diffusion model inference (DDPM, Stable Diffusion, Imagen): the EMA weights are the ones actually shipped and used for generation — they produce noticeably better samples than the raw training weights
- Target networks in RL (DQN, SAC, TD3): EMA provides a slowly-changing target Q-network for stable bootstrapping; see q-learning/ in this series
- Contrastive learning (MoCo, BYOL): EMA encodes the “momentum” branch, providing stable targets without requiring a stop-gradient on the full pipeline; see contrastive-self-supervising/
- Stochastic Weight Averaging (SWA, Izmailov et al. 2018): equal-weight average of checkpoints along the trajectory — a special case of EMA with τ adjusted per step (τ_t = t/(t+1) yields a uniform average)
- Pseudo-labelling / self-training (Mean Teacher, Noisy Student): EMA teacher generates pseudo-labels for unlabelled data; stability of EMA prevents confirmation bias
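For the RL and teacher–student cases above, the core operation is a soft update over parameter lists. A minimal numpy sketch (the `soft_update` helper is hypothetical; `tau` follows the RL convention, so decay = 1 − tau):

```python
import numpy as np

def soft_update(target_params, online_params, tau=0.005):
    """Polyak/soft update: move the target a small step toward the online weights.

    Note: tau here is the *update rate* (RL convention), i.e. decay = 1 - tau.
    """
    for t, o in zip(target_params, online_params):
        t[:] = (1 - tau) * t + tau * o

# Hypothetical usage: the target tracks the online weights with a lag.
online = [np.ones((2, 2))]
target = [np.zeros((2, 2))]
for _ in range(1000):
    soft_update(target, online, tau=0.005)
# after 1000 steps (~5 time constants) the target has nearly converged to the online weights
```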
Alternatives
| Alternative | When to use | Tradeoff |
|---|---|---|
| Hard target update (copy every N steps) | Classic DQN | Simpler; but creates discontinuities when the target network jumps. EMA is smoother |
| Polyak averaging (uniform average of all checkpoints) | Post-training ensemble | Equal weight to all checkpoints; EMA down-weights old ones, which is better during training when the model is improving |
| Stochastic Weight Averaging (SWA) | End-of-training performance boost | Average only over the final phase of training with cyclical LR. Better generalisation than EMA alone but requires a specific schedule |
| Checkpoint ensembling | Maximum accuracy, inference budget allows | Keep multiple full checkpoints and average their predictions (not weights). Higher compute but captures diverse solutions |
| No EMA (use final weights) | Quick experiments, small models | Saves memory (EMA doubles parameter storage); acceptable when training is stable and short |
Historical Context
Exponential moving averages have been used in signal processing and statistics for decades, but their application to neural network weights was popularised by Polyak and Juditsky (1992) as “iterate averaging” for stochastic approximation. The technique languished in neural network practice until Tarvainen and Valpola (2017, “Mean Teachers are Better Role Models”) demonstrated that EMA teachers dramatically improve semi-supervised learning.
The technique became standard infrastructure after three independent developments: DQN (Mnih et al., 2015) used periodic hard copies for target networks, later replaced by EMA soft updates in DDPG/TD3/SAC; MoCo (He et al., 2020) showed EMA is key to momentum contrastive learning; and DDPM (Ho et al., 2020) demonstrated that EMA weights produce substantially better generated samples. Today, EMA is essentially a free lunch — the memory cost of storing a second copy of the weights is the only downside, and it’s worth it in nearly every setting.