Numerical Instability
Floating-point operations produce NaN, Inf, or silently lose precision during training. The model appears to train normally until a specific combination of values triggers overflow in exp(), underflow in log(), or catastrophic cancellation in subtraction. Often intermittent and hard to reproduce.
Intuition
Floating-point numbers are an approximation of real numbers with finite precision (32-bit floats have ~7 decimal digits of precision, 16-bit have ~3). Most arithmetic works fine, but certain operations amplify tiny errors into catastrophic failures:
- exp(x) for large x: exp(100) ≈ 2.7 × 10⁴³, exp(710) = Inf in float64, exp(89) = Inf in float32
- log(x) for x near 0: log(0) = -Inf, and log of a very small but positive number is a very large negative number
- Subtraction of nearly equal numbers: 1.000001 - 1.000000 loses most significant digits (catastrophic cancellation)
These aren’t edge cases — they occur naturally in deep learning. Softmax involves exp() of logits, which can overflow. Cross-entropy involves log() of probabilities, which can underflow. Attention scores involve dot products that grow with embedding dimension. Every one of these has a well-known numerical trick that avoids the instability, and modern frameworks use them by default — but the moment you write a custom operation, you’re on your own.
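These failure modes are easy to reproduce directly. A minimal NumPy sketch (illustrative only; the thresholds follow from the float32 and float64 formats):

```python
import numpy as np

with np.errstate(over="ignore", divide="ignore"):
    # exp() overflow: representable in float64, Inf in float32
    print(np.exp(np.float64(100.0)))   # ~2.7e43, still finite
    print(np.exp(np.float32(89.0)))    # inf: exceeds float32 max (~3.4e38)
    print(np.exp(np.float64(710.0)))   # inf: exceeds float64 max (~1.8e308)

    # log() near zero: log(0) = -inf
    print(np.log(np.float32(0.0)))     # -inf

# catastrophic cancellation: subtracting nearly equal numbers
a = np.float32(1.000001)
b = np.float32(1.000000)
print(a - b)  # not 1e-06: most of the significant digits are gone
```

The last line is the subtle one: both operands are stored to ~7 digits, so their difference retains only the digits in which they disagree, leaving a result with a relative error of several percent.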
Manifestation
- Loss suddenly becomes NaN or Inf after training was progressing normally — often caused by a single unlucky batch
- Gradients become NaN and propagate through the entire model within a single step
- Mixed-precision (fp16) training is more fragile — the smaller dynamic range makes overflow/underflow more likely
- Specific operations are hot spots: softmax on large logits, log of small probabilities, division by small numbers, exp() in attention
- The problem may be non-deterministic — different random seeds or batch orderings trigger it at different times
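The second symptom, NaN spreading through the model within a single step, follows directly from IEEE-754 semantics: any arithmetic involving NaN yields NaN. A toy NumPy illustration:

```python
import numpy as np

weights = np.zeros((4, 4))
grad = np.ones((4, 4))
grad[2, 3] = np.nan                 # one bad element, e.g. from 0 * log(0)

weights -= 0.1 * grad               # the update plants a single NaN
print(np.isnan(weights).sum())      # 1

out = weights @ np.ones((4, 4))     # one matmul later...
print(np.isnan(out).sum())          # 4: the NaN now owns an entire row
```

Each subsequent layer multiplies the poisoned row into every output, so after a forward and backward pass the whole parameter set is typically NaN.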
Where It Appears
- Transformer (transformer/): attention scores grow as √d_k — the 1/√d_k scaling is specifically to prevent softmax overflow; without it, attention becomes numerically unstable for large embedding dimensions
- NN training (nn-training/): cross-entropy loss must be computed from logits (not probabilities) to avoid log(0) — the logsumexp trick is essential
- Diffusion (diffusion/): noise schedules that approach 0 or 1 can cause log(0) or division-by-zero in the loss computation
- GANs (gans/): log(1 - D(G(z))) in the vanilla GAN loss is numerically unstable when D(G(z)) ≈ 1 — the non-saturating loss reformulation addresses this
- VAE (variational-inference-vae/): log-variance parameterisation (outputting log σ² instead of σ) avoids requiring σ > 0 and handles small variances stably
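The logits-versus-probabilities point can be made concrete. A sketch of a stable log-softmax via the logsumexp trick (NumPy; `log_softmax` here is an illustrative helper, not a framework API):

```python
import numpy as np

def log_softmax(logits):
    """Stable log-softmax: log softmax(x)_i = x_i - logsumexp(x), where
    logsumexp(x) = max(x) + log(sum(exp(x - max(x)))).
    After subtracting the max, no exp() argument is positive, so exp()
    cannot overflow."""
    m = np.max(logits, axis=-1, keepdims=True)
    return logits - (m + np.log(np.sum(np.exp(logits - m), axis=-1, keepdims=True)))

logits = np.array([1000.0, 0.0, -1000.0])

# naive route overflows: exp(1000) = inf, and inf/inf = nan
with np.errstate(over="ignore", invalid="ignore"):
    naive = np.log(np.exp(logits) / np.sum(np.exp(logits)))
print(naive)                # contains nan
print(log_softmax(logits))  # [0., -1000., -2000.] — finite and exact
```

Cross-entropy is then just the negative log-softmax at the target index, computed without ever materialising a probability that could round to zero.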
Solutions at a Glance
| Solution | Mechanism | Where documented |
|---|---|---|
| Logsumexp trick | Subtract max before exp() in softmax/log-softmax: numerically identical, overflow-free | atomic-concepts/mathematical-tricks/logsumexp-trick.md |
| Log-space computation | Work with log-probabilities throughout, avoiding exp() and log() of extreme values | (standard practice) |
| Epsilon guards | Add small ε (e.g., 1e-8) to denominators and arguments of log() | (standard practice) |
| Scaled attention | Divide dot products by √d_k before softmax | atomic-concepts/architectural-primitives/scaled-dot-product-attention.md |
| Loss scaling (mixed precision) | Scale loss up before backward pass, scale gradients down after — keeps fp16 gradients in representable range | (AMP in PyTorch) |
| Gradient clipping | Caps gradient norms, preventing a single NaN-causing update | atomic-concepts/optimisation-primitives/gradient-clipping.md |
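The scaled-attention row can be checked numerically. A rough sketch of why the 1/√d_k factor matters at fp16 precision (random inputs; the ~11 threshold is an assumption derived from the fp16 maximum of ~65504, since exp(x) exceeds it for x > ~11):

```python
import numpy as np

rng = np.random.default_rng(0)
d_k = 4096
q = rng.standard_normal((64, d_k))
k = rng.standard_normal((64, d_k))

raw = q @ k.T                  # entries have std ≈ sqrt(d_k) = 64
scaled = raw / np.sqrt(d_k)    # entries have std ≈ 1

with np.errstate(over="ignore"):
    # fp16 max is ~65504, so exp(x) is inf for x beyond ~11
    print(np.isinf(np.exp(raw.astype(np.float16))).any())     # True
    print(np.isinf(np.exp(scaled.astype(np.float16))).any())  # False
```

The dot product of two d_k-dimensional unit-variance vectors has variance d_k, so unscaled scores routinely exceed the overflow threshold; dividing by √d_k restores unit variance before softmax ever sees them.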
Historical Context
Numerical stability concerns predate deep learning by decades — Goldberg’s “What Every Computer Scientist Should Know About Floating-Point Arithmetic” (1991) is the classic reference. In deep learning, the logsumexp trick for stable softmax was one of the earliest and most impactful practical contributions. Mixed-precision training (Micikevicius et al., 2018) made numerical stability a mainstream concern by introducing fp16, which has only 5 bits of exponent (range ±65504) versus fp32’s 8 bits (range ±3.4 × 10³⁸). The introduction of bfloat16 (Brain Floating Point) by Google was specifically motivated by numerical stability — it has the same exponent range as fp32 but with reduced mantissa precision, making overflow much less likely than fp16.