Weight Initialisation (Xavier / Kaiming)
Controls the scale of initial weights so that signals neither explode nor vanish as they propagate through layers. Without proper initialisation, a 50-layer network’s activations can overflow to infinity or collapse to zero before training even begins.
Intuition
Imagine passing a message through a chain of people, where each person multiplies the message by a random number. If that number is typically greater than 1, the message grows exponentially. If less than 1, it shrinks to nothing. The fix: choose multipliers that keep the message roughly the same size on average.
That is exactly what happens with layer activations. Each layer multiplies its input by a weight matrix. If the variance of the weights is too high, activations explode; too low, they vanish. The solution is to set the initial weight variance so that the variance of the output equals the variance of the input — “variance preservation.”
Xavier init achieves this for linear/tanh activations by setting Var(W) = 2 / (fan_in + fan_out). Kaiming init adjusts for ReLU, which zeros out half the activations: to compensate, it doubles the variance to Var(W) = 2 / fan_in. This single factor-of-2 correction is what allows deep ReLU networks to train stably.
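The consequence of the factor of 2 is easy to see empirically. The sketch below (layer width 256 and depth 50 are arbitrary illustrative choices) pushes a batch through a deep ReLU stack with Xavier-scaled versus Kaiming-scaled weights:

```python
import numpy as np

rng = np.random.default_rng(0)
fan_in, depth = 256, 50

def final_activation_std(weight_std):
    """Push a batch through `depth` ReLU layers with the given weight std."""
    x = rng.standard_normal((32, fan_in))
    for _ in range(depth):
        W = rng.standard_normal((fan_in, fan_in)) * weight_std
        x = np.maximum(0, x @ W.T)
    return x.std()

xavier_std = np.sqrt(2.0 / (fan_in + fan_in))   # Var(W) = 2/(fan_in + fan_out)
kaiming_std = np.sqrt(2.0 / fan_in)             # Var(W) = 2/fan_in

print(final_activation_std(xavier_std))   # shrinks by ~(1/sqrt(2))^50: vanished
print(final_activation_std(kaiming_std))  # stays O(1)
```

Under ReLU, Xavier loses half the variance at every layer, so activations decay geometrically with depth; Kaiming's doubled variance exactly cancels the loss.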
The problem: for a layer y = Wx with fan_in inputs, each output unit sums fan_in terms, so Var(y) = fan_in · Var(w) · Var(x).
To preserve variance (Var(y) = Var(x)), we need Var(w) = 1 / fan_in.
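The identity Var(y) = fan_in · Var(w) · Var(x) can be checked numerically; the dimensions and variances below are arbitrary illustrative values:

```python
import numpy as np

rng = np.random.default_rng(0)
fan_in, fan_out = 512, 512
var_w, var_x = 0.01, 4.0

W = rng.standard_normal((fan_out, fan_in)) * np.sqrt(var_w)   # Var(w) = 0.01
x = rng.standard_normal((10_000, fan_in)) * np.sqrt(var_x)    # Var(x) = 4.0
y = x @ W.T

# Var(y) should be fan_in * Var(w) * Var(x) = 512 * 0.01 * 4.0 = 20.48
print(y.var())
```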
Xavier / Glorot init (linear, tanh, sigmoid — symmetric activations): Var(W) = 2 / (fan_in + fan_out)
This compromises between preserving variance in the forward pass (which wants 1 / fan_in) and in the backward pass (which wants 1 / fan_out).
Kaiming / He init (ReLU — kills half the activations): Var(W) = 2 / fan_in
The factor of 2 compensates for ReLU zeroing out negative values, which halves the variance.
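The halving follows from symmetry: for a zero-mean symmetric pre-activation, ReLU zeros half the probability mass, so E[ReLU(x)²] = E[x²] / 2. A one-line check:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(1_000_000)              # zero-mean, unit-variance pre-activations

second_moment = np.mean(np.maximum(0, x) ** 2)  # E[ReLU(x)^2]
print(second_moment)                            # ≈ 0.5, half of E[x^2] = 1
```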
Uniform vs normal sampling:
- Uniform U[-a, a]: Var = a² / 3, so a = sqrt(3 · Var(W))
- Normal N(0, σ²): σ = sqrt(Var(W))
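Both parameterisations hit the same target variance, just with different distribution shapes. A quick check, using the Kaiming target as the example:

```python
import numpy as np

rng = np.random.default_rng(0)
fan_in = 256
target_var = 2.0 / fan_in                     # Kaiming target: Var(W) = 2/fan_in

a = np.sqrt(3.0 * target_var)                 # U[-a, a] has variance a^2 / 3
w_uniform = rng.uniform(-a, a, size=1_000_000)
w_normal = rng.normal(0.0, np.sqrt(target_var), size=1_000_000)

print(w_uniform.var(), w_normal.var())        # both ≈ 2/256 ≈ 0.0078
```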
Transformer scaled init (common practice):
GPT-2 scales residual projections by 1/√N, where N is the number of residual layers, to prevent the residual sum from growing with depth.
```python
import torch
import torch.nn as nn

# ── Xavier (Glorot) — for tanh/sigmoid layers ────────────────────
linear = nn.Linear(256, 128)
nn.init.xavier_uniform_(linear.weight)   # U[-a, a], a = sqrt(6/(fan_in+fan_out))
nn.init.xavier_normal_(linear.weight)    # N(0, 2/(fan_in+fan_out))

# ── Kaiming (He) — for ReLU layers ───────────────────────────────
conv = nn.Conv2d(64, 128, 3)
nn.init.kaiming_normal_(conv.weight, mode='fan_in', nonlinearity='relu')
nn.init.kaiming_uniform_(conv.weight, mode='fan_in', nonlinearity='relu')
# mode='fan_out' preserves variance in the backward pass instead

# ── Transformer-style scaled init ────────────────────────────────
d_model = 768
layer = nn.Linear(d_model, d_model)
nn.init.normal_(layer.weight, mean=0.0, std=1 / (d_model ** 0.5))
nn.init.zeros_(layer.bias)

# ── Check what PyTorch does by default ───────────────────────────
default_linear = nn.Linear(256, 128)  # kaiming_uniform_ with a=sqrt(5) by default
# Verify: weight variance is 1/(3·fan_in) = 1/768, not the ReLU-correct 2/fan_in
print(default_linear.weight.var().item())  # ≈ 0.0013
```

Warning: PyTorch’s nn.Linear default is Kaiming uniform with a=√5, a historical choice whose variance (1/(3·fan_in)) is smaller than the ReLU-correct 2/fan_in. If using ReLU, GELU, or SiLU, consider explicit initialisation.
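Why residual scaling matters can be sketched with a toy residual stream: summing N independent unit-variance branch outputs grows the stream’s std roughly as √N, while scaling each branch by 1/√N keeps it O(1). (d_model and n_layers below are illustrative values, and a random vector stands in for each branch’s output.)

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, n_layers = 768, 48

def final_stream_std(scale):
    x = rng.standard_normal(d_model)            # residual stream at init
    for _ in range(n_layers):
        branch = rng.standard_normal(d_model)   # stand-in for one branch's output
        x = x + scale * branch
    return x.std()

print(final_stream_std(1.0))                    # grows to ~sqrt(1 + 48) ≈ 7
print(final_stream_std(1 / np.sqrt(n_layers)))  # stays ~sqrt(2) ≈ 1.4
```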
Manual Implementation
```python
import numpy as np

def xavier_uniform(fan_in, fan_out):
    """Xavier/Glorot uniform init. Best for tanh/sigmoid."""
    a = np.sqrt(6.0 / (fan_in + fan_out))
    return np.random.uniform(-a, a, size=(fan_out, fan_in))  # (fan_out, fan_in)

def xavier_normal(fan_in, fan_out):
    """Xavier/Glorot normal init."""
    std = np.sqrt(2.0 / (fan_in + fan_out))
    return np.random.randn(fan_out, fan_in) * std  # (fan_out, fan_in)

def kaiming_normal(fan_in, fan_out):
    """Kaiming/He normal init. Best for ReLU."""
    std = np.sqrt(2.0 / fan_in)
    return np.random.randn(fan_out, fan_in) * std  # (fan_out, fan_in)

def kaiming_uniform(fan_in, fan_out):
    """Kaiming/He uniform init."""
    a = np.sqrt(6.0 / fan_in)  # sqrt(3 * 2/fan_in)
    return np.random.uniform(-a, a, size=(fan_out, fan_in))  # (fan_out, fan_in)

# Verify variance preservation through a 50-layer ReLU network
x = np.random.randn(32, 256)  # (B, D)
for _ in range(50):
    W = kaiming_normal(256, 256)  # (256, 256)
    x = x @ W.T                   # (B, 256)
    x = np.maximum(0, x)          # ReLU
print(f"Activation std after 50 layers: {x.std():.4f}")  # should be ≈ O(1)
```

Popular Uses
- ResNets (He et al.): Kaiming init enabled training of 100+ layer networks with ReLU; without it, these networks don’t converge
- Transformers (GPT, BERT, LLaMA): scaled normal init with std = 0.02 or std = 1/√d_model; residual projections scaled by an additional 1/√N for N residual layers
- GANs: proper init is critical — DCGAN specifies a zero-mean normal with std 0.02 for all weights
- LSTM / GRU: orthogonal init for recurrent weights preserves gradient norms across time steps
- nn-training entry: the init_weights variant axis demonstrates Xavier vs Kaiming vs scaled init
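The orthogonal-init bullet can be verified directly: an orthogonal matrix preserves vector norms exactly, so repeated multiplication through time neither grows nor shrinks the signal. A sketch using the standard QR-based construction:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 128

# Orthogonal init: QR-decompose a random Gaussian matrix
Q, R = np.linalg.qr(rng.standard_normal((n, n)))
Q = Q * np.sign(np.diag(R))   # sign fix so the distribution is uniform over rotations

x = rng.standard_normal(n)
print(np.linalg.norm(Q @ x), np.linalg.norm(x))   # identical: ||Qx|| == ||x||
```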
Alternatives
| Alternative | When to use | Tradeoff |
|---|---|---|
| Orthogonal init | RNNs, very deep networks | Preserves gradient norms exactly; slightly more expensive to compute |
| LSUV (Layer-Sequential Unit Variance) | Networks with unusual architectures | Data-driven: passes a batch and rescales each layer; more robust but slower |
| Fixup init | ResNets without BatchNorm | Scales residual-branch weights by a depth-dependent factor (L^(−1/(2m−2)) for L residual branches of m layers each); avoids need for normalisation layers |
| Zero init (for residual branches) | Transformer residual projections, ReZero | Identity at init — each layer starts as a no-op; stable training but slower early progress |
| Pretrained weights | Transfer learning, fine-tuning | Bypasses init entirely; best when sufficient pretraining data exists |
Historical Context
The variance preservation idea was formalised by Glorot & Bengio (2010) as “Xavier initialisation,” derived for linear and tanh activations. He et al. (2015) extended this to ReLU networks as “Kaiming initialisation” — the factor-of-2 correction was the key insight that enabled training of very deep residual networks.
Before these principled approaches, practitioners used heuristics like sampling weights from N(0, 0.01²) or U[−0.01, 0.01], which happened to work for shallow networks but failed catastrophically for deep ones. Modern normalisation layers (BatchNorm, LayerNorm) reduce sensitivity to initialisation but don’t eliminate it — scaled init in Transformers remains important, especially for training stability at large model sizes.