Weight Decay
Penalises large weights by adding a penalty on their squared magnitude to the loss, or equivalently, shrinking every weight by a small factor each step. Keeps weights small, reduces overfitting, and improves generalisation. The most universal regulariser in deep learning, used in virtually every modern model.
Intuition
Think of weight decay as a spring pulling every weight back toward zero. Without it, weights are free to grow as large as they want to fit the training data — potentially memorising noise. The spring provides a restoring force: the further a weight drifts from zero, the harder it gets pulled back. Only weights that significantly reduce the loss can justify being large.
This has a smoothing effect on the learned function. Large weights create sharp, spiky decision boundaries that overfit to individual training examples. Small weights produce smoother functions that generalise better. Weight decay is essentially Occam’s razor in optimisation form: prefer the simplest (smallest-weight) model that explains the data.
A critical subtlety: classical L2 regularisation (add to the loss) and decoupled weight decay (subtract from the weight each step) are identical for SGD but different for Adam. Adam scales gradients by their running variance, which also scales the L2 penalty — effectively applying less regularisation to weights with large gradients. AdamW fixes this by applying weight decay directly to the weights, outside the adaptive learning rate. This is why AdamW is the modern default.
L2 regularisation adds a penalty to the loss:

$$\mathcal{L}_{\text{total}}(w) = \mathcal{L}(w) + \frac{\lambda}{2}\|w\|^2$$

The gradient becomes:

$$\nabla \mathcal{L}_{\text{total}}(w) = \nabla \mathcal{L}(w) + \lambda w$$

SGD with L2, the update step:

$$w \leftarrow w - \eta\left(\nabla \mathcal{L}(w) + \lambda w\right) = (1 - \eta\lambda)\,w - \eta\,\nabla \mathcal{L}(w)$$

The factor $(1 - \eta\lambda)$ shrinks weights toward zero each step; this is why it’s called “weight decay.”
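Since each step multiplies by the same factor, the shrinkage compounds geometrically. A minimal sketch (arbitrary hyperparameter values) of how far a weight with zero data gradient decays over 1000 steps:

```python
# How fast does a weight with zero data gradient decay under SGD + L2?
# Each step multiplies the weight by (1 - lr * weight_decay).
lr, weight_decay = 0.1, 1e-2
factor = 1 - lr * weight_decay  # 0.999 per step

w = 1.0
for step in range(1000):
    w *= factor

print(w)  # ≈ 0.368 after 1000 steps (0.999 ** 1000)
```

Even a tiny per-step factor compounds into substantial shrinkage over a long training run, which is why λ interacts with total step count.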
Decoupled weight decay (AdamW):

$$w \leftarrow w - \eta\,\frac{\hat{m}}{\sqrt{\hat{v}} + \epsilon} - \eta\lambda\,w$$

The decay term $\eta\lambda w$ is applied directly to the weights, not routed through Adam’s adaptive scaling.

Key difference: with Adam + L2, the effective regularisation per weight is roughly $\eta\lambda w / (\sqrt{\hat{v}} + \epsilon)$, weakened for weights with high-variance gradients. With AdamW, every weight receives exactly $\eta\lambda w$ of decay regardless of gradient history.
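To make the difference concrete, here is a small sketch (illustrative values; it assumes the second-moment estimate has settled at the squared gradient RMS) comparing the per-step decay each scheme applies to a unit weight:

```python
import numpy as np

lr, wd, eps = 1e-3, 1e-2, 1e-8
w = 1.0  # weight magnitude under consideration

for grad_rms in (0.01, 1.0, 100.0):  # assumed steady-state sqrt(v_hat)
    # Adam + L2: the penalty gradient wd*w is divided by sqrt(v_hat) + eps,
    # so weights with large gradients get much weaker regularisation
    adam_l2_decay = lr * wd * w / (grad_rms + eps)
    # AdamW: the decay bypasses the adaptive scaling entirely
    adamw_decay = lr * wd * w
    print(f"grad_rms={grad_rms}: Adam+L2 {adam_l2_decay:.2e}, AdamW {adamw_decay:.2e}")
```

The Adam + L2 column varies by four orders of magnitude across the three gradient scales, while the AdamW column is constant.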
```python
import torch

# ── AdamW (decoupled weight decay) — the modern default ────────
# weight_decay parameter IS the decay factor λ, applied directly.
optimiser = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)

# ── SGD with weight decay ──────────────────────────────────────
# For SGD, weight_decay and L2 are equivalent.
optimiser = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-4)

# ── Adam with L2 (NOT the same as AdamW) ───────────────────────
# torch.optim.Adam's weight_decay parameter does L2, not decoupled decay.
# WARNING: this is almost never what you want. Use AdamW instead.
optimiser = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.01)

# ── Exclude bias and LayerNorm from weight decay ───────────────
# Common practice: only regularise weight matrices, not biases or norms.
decay_params = [p for n, p in model.named_parameters()
                if "weight" in n and p.dim() >= 2]
no_decay_params = [p for n, p in model.named_parameters()
                   if "weight" not in n or p.dim() < 2]
optimiser = torch.optim.AdamW([
    {"params": decay_params, "weight_decay": 0.01},
    {"params": no_decay_params, "weight_decay": 0.0},
], lr=1e-3)
```

Manual Implementation
```python
import numpy as np

def sgd_step_with_weight_decay(params, grads, lr, weight_decay):
    """
    SGD + L2 regularisation (equivalent to decoupled weight decay for SGD).
    params: list of arrays (model weights)
    grads:  list of arrays (gradients, same shapes)
    """
    for w, g in zip(params, grads):
        w -= lr * (g + weight_decay * w)  # shrink + step
```
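As a quick sanity check of that equivalence, a self-contained toy example (arbitrary values) showing that folding the L2 penalty into the gradient gives the same SGD update as shrinking first and then stepping:

```python
import numpy as np

lr, wd = 0.1, 1e-2
w = np.array([1.0, -2.0, 0.5])
g = np.array([0.3, -0.1, 0.0])

# Formulation 1: L2 penalty folded into the gradient
w1 = w - lr * (g + wd * w)

# Formulation 2: decoupled decay, shrink first then take the plain gradient step
w2 = (1 - lr * wd) * w - lr * g

print(np.allclose(w1, w2))  # → True: identical for SGD
```

This identity is exactly what breaks down once the gradient is rescaled per-weight, as Adam does.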
```python
def adamw_step(params, grads, m_states, v_states, lr, weight_decay,
               beta1, beta2, eps, t):
    """
    Decoupled weight decay (AdamW). One step.
    m_states, v_states: running moment estimates, same shapes as params.
    t: step number (1-indexed for bias correction).
    """
    for w, g, m, v in zip(params, grads, m_states, v_states):
        m[:] = beta1 * m + (1 - beta1) * g        # first moment
        v[:] = beta2 * v + (1 - beta2) * g ** 2   # second moment
        m_hat = m / (1 - beta1 ** t)              # bias correction
        v_hat = v / (1 - beta2 ** t)              # bias correction
        w -= lr * m_hat / (np.sqrt(v_hat) + eps)  # Adam step
        w -= lr * weight_decay * w                # decoupled decay, scaled by lr
```

Popular Uses
- LLM pretraining (GPT, LLaMA, Mistral): AdamW with weight_decay=0.1, excluding biases and layer norms
- Vision transformers (ViT, DeiT): weight_decay=0.05 is typical; critical for ViT which overfits easily without it
- Fine-tuning: weight decay is often retained to curb overfitting on small datasets; note that standard decay pulls weights toward zero rather than toward the pretrained values, so variants that decay toward the pretrained weights are sometimes used against catastrophic forgetting
- CNN training (ResNet): SGD with weight_decay=1e-4, the classic recipe
- GAN training: careful tuning needed — too much decay can destabilise the discriminator/generator balance
Alternatives
| Alternative | When to use | Tradeoff |
|---|---|---|
| L1 regularisation | When you want sparse weights (feature selection) | Drives weights exactly to zero; less common in deep learning, more in linear models |
| Dropout | MLPs, older architectures | Regularises activations not weights; introduces train/test discrepancy |
| Early stopping | When validation loss is monitored | Implicitly limits effective model capacity; no hyperparameter to tune beyond patience |
| Max-norm constraint | RNNs, when weight explosion is a concern | Clips weight magnitude directly; harder to tune than smooth decay |
| Spectral normalisation | GAN discriminators | Constrains the spectral norm specifically; more targeted than general weight decay |
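The table’s claim that L1 drives weights exactly to zero, while weight decay only shrinks them, can be seen with a proximal (soft-thresholding) L1 step. A toy sketch with arbitrary values, ignoring the data gradient to isolate the penalty’s effect:

```python
import numpy as np

lr, lam = 0.1, 0.5
w = np.array([0.03, -0.2, 1.5])

# L2 / weight decay: multiplicative shrink, weights approach zero but never land on it
w_l2 = (1 - lr * lam) * w

# L1 proximal step (soft-thresholding): entries within lr*lam of zero snap to exactly 0
w_l1 = np.sign(w) * np.maximum(np.abs(w) - lr * lam, 0.0)

print(w_l2)  # all entries still nonzero
print(w_l1)  # the 0.03 entry is now exactly 0.0
```

The hard zero is what makes L1 useful for feature selection, and the lack of one is why weight decay keeps all weights small but dense.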
Historical Context
Weight decay dates back to ridge regression in statistics (Hoerl & Kennard, 1970) and was adopted early in neural network training as a standard regulariser. For decades, “L2 regularisation” and “weight decay” were treated as synonymous because they are equivalent under SGD.
The critical insight came from Loshchilov & Hutter (2019, “Decoupled Weight Decay Regularization”), who showed that L2 regularisation and weight decay diverge under adaptive optimisers like Adam. Their AdamW variant — applying decay directly to weights rather than through the gradient — was a simple fix that meaningfully improved generalisation. AdamW is now the default optimiser for virtually all transformer training, and “weight_decay=0.01 to 0.1” is standard in most recipes.