Learning Rate Warmup
Linearly increasing the learning rate from ~0 to the target value over the first N steps. Prevents early instability when Adam’s second-moment estimates are not yet calibrated. Standard practice in transformer training (GPT, BERT, ViT).
Intuition
At step 0, Adam’s running estimate of the squared gradient (v_t) is initialised to zero. Because Adam divides each update by sqrt(v_t) + ε, those first updates are divided by something tiny, producing wildly large parameter changes. The model can diverge in the first few hundred steps, before the optimizer has seen enough gradients to form reliable estimates.
Warmup is the fix: start with a near-zero learning rate and ramp it up linearly over, say, 1000 steps. During that ramp, even though the per-parameter scaling in Adam is unreliable, the small global learning rate keeps the actual updates small. By the time the learning rate reaches its full value, the second-moment estimates have accumulated enough history to be trustworthy.
This is why warmup is most critical for Adam-family optimizers. Plain SGD with momentum doesn’t have the second-moment issue, so warmup helps less there (though it can still smooth the initial transient). The deeper or more attention-heavy the model, the more warmup matters — large transformers routinely use 1-5% of total steps for warmup.
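The early-step unreliability is easy to see in a toy simulation (illustrative only: synthetic gradients and made-up hyperparameters, not a training run). The spread of Adam’s per-parameter step size |m_hat| / (sqrt(v_hat) + eps) across parameters is much wider in the first steps than after the estimates settle:

```python
import numpy as np

# Toy illustration: spread of Adam's effective per-parameter step size,
# |m_hat| / (sqrt(v_hat) + eps), early vs. late in training.
rng = np.random.default_rng(0)
n_params = 10_000
beta1, beta2, eps = 0.9, 0.999, 1e-8
m = np.zeros(n_params)
v = np.zeros(n_params)
spread = {}
for t in range(1, 1001):
    g = rng.normal(0.1, 1.0, n_params)   # noisy synthetic gradients
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g ** 2
    m_hat = m / (1 - beta1 ** t)         # bias-corrected estimates
    v_hat = v / (1 - beta2 ** t)
    ratio = np.abs(m_hat) / (np.sqrt(v_hat) + eps)
    if t in (2, 1000):
        spread[t] = ratio.std()          # spread of step sizes across params

print(f"step    2: std of |update|/lr = {spread[2]:.3f}")
print(f"step 1000: std of |update|/lr = {spread[1000]:.3f}")
```

The much wider spread at step 2 is the unreliable per-parameter scaling that a small warmup learning rate papers over.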
Linear warmup schedule (step t, warmup steps T_w, target learning rate η_max):

η_t = η_max · t / T_w,  for t < T_w
After warmup, the learning rate is typically handed off to a decay schedule (cosine annealing, linear decay, etc.). For cosine decay to a floor η_min over T_total total steps:

η_t = η_min + ½ (η_max − η_min) (1 + cos(π · (t − T_w) / (T_total − T_w)))
Common default: T_w = 1-5% of total training steps. GPT-3 warmed up over its first 375 million tokens; ViT used 10k warmup steps.
```python
import torch

# ── Using PyTorch's built-in schedulers ─────────────────────────
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# Linear warmup for 1000 steps, then constant
warmup = torch.optim.lr_scheduler.LinearLR(
    optimizer,
    start_factor=1e-8 / 3e-4,  # ramp from ~0 to 3e-4
    total_iters=1000,
)

# Warmup + cosine decay (the standard transformer recipe)
warmup = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=1e-8 / 3e-4, total_iters=1000
)
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=99_000  # remaining steps after warmup
)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup, cosine], milestones=[1000]
)

# ── In the training loop ────────────────────────────────────────
for step, batch in enumerate(dataloader):
    loss = model(batch).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()  # call AFTER optimizer.step()
```

Manual Implementation
```python
import numpy as np

def warmup_lr(step, lr_max, warmup_steps):
    """
    Returns the learning rate at a given step during linear warmup.

    step:         current training step (0-indexed)
    lr_max:       target learning rate after warmup
    warmup_steps: number of steps to ramp over
    """
    if step < warmup_steps:
        return lr_max * (step / warmup_steps)  # linear ramp
    return lr_max  # constant after

def warmup_cosine_lr(step, lr_max, warmup_steps, total_steps, lr_min=0.0):
    """
    Linear warmup followed by cosine decay — the standard transformer schedule.
    """
    if step < warmup_steps:
        return lr_max * (step / warmup_steps)  # linear ramp
    # Cosine decay phase
    progress = (step - warmup_steps) / (total_steps - warmup_steps)  # 0 → 1
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + np.cos(np.pi * progress))

# Example: 100k steps, 2k warmup, peak lr 3e-4
lrs = [warmup_cosine_lr(t, 3e-4, 2000, 100_000, 1e-5) for t in range(100_000)]
# lrs[0] = 0, lrs[2000] = 3e-4, lrs[99_999] ≈ 1e-5
```

Popular Uses
- LLM pre-training (GPT, LLaMA, Chinchilla): linear warmup + cosine decay is the de facto standard schedule
- Vision transformers (ViT, DeiT, Swin): warmup is critical because self-attention layers amplify the early-step instability
- BERT / masked language modelling: original BERT paper used 10k warmup steps out of 1M total
- Fine-tuning large models: shorter warmup (100-500 steps) helps stabilise the first few gradient updates on a new task
- Diffusion models (DDPM, Stable Diffusion): warmup used in UNet training to prevent early divergence
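For the fine-tuning case above, a short linear warmup can be wired into any PyTorch optimizer with `LambdaLR`, which scales the base learning rate by a multiplier you supply (a sketch; the 300-step warmup and toy model are illustrative choices, not a recommendation):

```python
import torch

model = torch.nn.Linear(10, 2)   # stand-in for a pretrained model
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

warmup_steps = 300               # illustrative short fine-tuning warmup
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    # Multiplier on the base lr: ramps 1/300 → 1, then stays at 1
    lr_lambda=lambda step: min(1.0, (step + 1) / warmup_steps),
)

lrs = []
for step in range(600):
    optimizer.step()             # (no backward here: schedule demo only)
    lrs.append(optimizer.param_groups[0]["lr"])
    scheduler.step()
```

After the loop, `lrs` ramps linearly to 2e-5 by step 300 and stays flat thereafter.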
Alternatives
| Alternative | When to use | Tradeoff |
|---|---|---|
| No warmup (constant LR) | SGD with momentum on CNNs | Works for simpler optimizers; Adam without warmup risks early divergence |
| Exponential warmup | When linear ramp is too slow | Faster ramp but harder to tune; less common in practice |
| RAdam | Drop-in Adam replacement | Automatically corrects the variance bias in Adam’s early steps, removing the need for explicit warmup. Slightly higher compute per step |
| Gradual unfreezing | Transfer learning / fine-tuning | Warms up capacity rather than learning rate — unfreeze layers one at a time. Complementary to LR warmup |
| Learning rate probing | Unknown good LR range | LR range test (Smith 2017) sweeps LR to find the right max before setting the warmup target |
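As a concrete point of comparison for the table, here is a minimal sketch of exponential (geometric) warmup; the function name `exp_warmup_lr` and the `lr_start` floor are illustrative assumptions, not from any specific paper:

```python
def exp_warmup_lr(step, lr_max, warmup_steps, lr_start=1e-7):
    """Geometric ramp from lr_start to lr_max over warmup_steps, then constant."""
    if step >= warmup_steps:
        return lr_max
    # Equal multiplicative increments: reaches lr_max at step == warmup_steps
    return lr_start * (lr_max / lr_start) ** (step / warmup_steps)
```

Because the increments are multiplicative, lr_start acts as an extra knob alongside warmup_steps, which is part of why this variant is harder to tune than the linear ramp.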
Historical Context
The need for warmup was first identified empirically in the transformer paper (Vaswani et al., 2017, “Attention Is All You Need”), which used a specific schedule: warmup over 4000 steps followed by inverse-square-root decay. The authors didn’t explain it theoretically — it was a practical fix that made training converge.
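That original schedule has a closed form, lr = d_model^-0.5 · min(step^-0.5, step · warmup_steps^-1.5), which a few lines reproduce (a sketch; d_model = 512 matches the base model in the paper):

```python
def transformer_lr(step, d_model=512, warmup_steps=4000):
    """Vaswani et al. (2017) schedule: linear rise for warmup_steps,
    then inverse-square-root decay."""
    step = max(step, 1)  # the formula is undefined at step 0
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

# The two branches of the min() cross exactly at step == warmup_steps,
# which is where the learning rate peaks.
```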
The theoretical justification came from Liu et al. (2020, “On the Variance of the Adaptive Learning Rate and Beyond”), who showed that Adam’s adaptive learning rate has excessively high variance in early training because the second-moment estimate is biased toward zero. Their RAdam optimizer corrects this analytically, but linear warmup remains the dominant practical solution because it’s simpler, well-understood, and composes cleanly with any decay schedule.