
Numerical Instability

Floating-point operations produce NaN, Inf, or silently lose precision during training. The model appears to train normally until a specific combination of values triggers overflow in exp(), underflow in log(), or catastrophic cancellation in subtraction. Often intermittent and hard to reproduce.

Floating-point numbers are an approximation of real numbers with finite precision (32-bit floats have ~7 decimal digits of precision, 16-bit have ~3). Most arithmetic works fine, but certain operations amplify tiny errors into catastrophic failures:

  • exp(x) for large x: exp(100) ≈ 2.7 × 10⁴³, exp(710) = Inf in float64, exp(89) = Inf in float32
  • log(x) for x near 0: log(0) = -Inf, and log of a very small but positive number is a very large negative number
  • Subtraction of nearly equal numbers: 1.000001 - 1.000000 loses most significant digits (catastrophic cancellation)
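A minimal NumPy sketch makes all three failure modes concrete (float32 throughout):

```python
import numpy as np

# Overflow: float32 tops out near 3.4e38, so exp(89) is already Inf
big = np.exp(np.float32(89.0))          # inf

# log near zero: log(0) is -Inf; log(1e-38) is about -87.5
log_zero = np.log(np.float32(0.0))      # -inf

# Catastrophic cancellation: the true difference is 1e-06, but float32
# spacing near 1.0 is ~1.19e-07, so the result is off by several percent
diff = np.float32(1.000001) - np.float32(1.000000)
print(big, log_zero, diff)              # inf -inf ~9.5e-07
```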

These aren’t edge cases — they occur naturally in deep learning. Softmax involves exp() of logits, which can overflow. Cross-entropy involves log() of probabilities, which can underflow. Attention scores involve dot products that grow with embedding dimension. Every one of these has a well-known numerical trick that avoids the instability, and modern frameworks use them by default — but the moment you write a custom operation, you’re on your own.
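The softmax case can be sketched directly — subtracting the max from the logits before exp() is the standard trick frameworks apply by default:

```python
import numpy as np

def softmax_naive(z):
    e = np.exp(z)                      # overflows for large logits
    return e / e.sum()

def softmax_stable(z):
    # Subtracting the max is mathematically a no-op (the shift cancels in
    # the ratio) but keeps every exp() argument <= 0, so nothing overflows.
    e = np.exp(z - z.max())
    return e / e.sum()

logits = np.array([1000.0, 1001.0, 1002.0], dtype=np.float32)
print(softmax_naive(logits))           # [nan nan nan] — exp(1000) = inf, inf/inf = nan
print(softmax_stable(logits))          # ≈ [0.090 0.245 0.665]
```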

  • Loss suddenly becomes NaN or Inf after training was progressing normally — often caused by a single unlucky batch
  • Gradients become NaN and propagate through the entire model within a single step
  • Mixed-precision (fp16) training is more fragile — the smaller dynamic range makes overflow/underflow more likely
  • Specific operations are hot spots: softmax on large logits, log of small probabilities, division by small numbers, exp() in attention
  • The problem may be non-deterministic — different random seeds or batch orderings trigger it at different times
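Because the failure is intermittent, a cheap check after each backward pass helps localise the first non-finite gradient. A framework-agnostic sketch — the helper name is hypothetical, with NumPy standing in for a framework's tensors:

```python
import numpy as np

def first_nonfinite(named_arrays):
    """Return the name of the first array containing NaN/Inf, else None."""
    for name, arr in named_arrays:
        if not np.all(np.isfinite(arr)):
            return name
    return None

grads = [("layer1.w", np.ones(3)),
         ("layer2.w", np.array([1.0, np.nan, 2.0]))]
print(first_nonfinite(grads))          # layer2.w
```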
  • Transformer (transformer/): attention scores grow as O(√d_k) — the 1/√d_k scaling is specifically to prevent softmax overflow; without it, attention becomes numerically unstable for large embedding dimensions
  • NN training (nn-training/): cross-entropy loss must be computed from logits (not probabilities) to avoid log(0) — the logsumexp trick is essential
  • Diffusion (diffusion/): noise schedules that approach 0 or 1 can cause log(0) or division-by-zero in the loss computation
  • GANs (gans/): log(1 - D(G(z))) in the vanilla GAN loss is numerically unstable when D(G(z)) ≈ 1 — the non-saturating loss reformulation addresses this
  • VAE (variational-inference-vae/): log-variance parameterisation (outputting log σ² instead of σ) avoids requiring σ > 0 and handles small variances stably
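The logsumexp trick behind the cross-entropy case above can be sketched in a few lines — hypothetical helper names, NumPy as a stand-in for a framework:

```python
import numpy as np

def log_softmax(z):
    # logsumexp trick: log(sum exp(z_j)) = m + log(sum exp(z_j - m)),
    # with m = max(z). Every exp() argument is <= 0 and the sum is >= 1,
    # so there is no overflow and no log(0).
    m = z.max()
    return z - (m + np.log(np.sum(np.exp(z - m))))

def cross_entropy_from_logits(z, target):
    # Works directly on logits; probabilities are never materialised,
    # so log(0) cannot occur.
    return -log_softmax(z)[target]

z = np.array([1000.0, 1001.0, 1002.0])
print(cross_entropy_from_logits(z, 2))   # ≈ 0.4076, finite despite huge logits
```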
| Solution | Mechanism | Where documented |
| --- | --- | --- |
| Logsumexp trick | Subtract max before exp() in softmax/log-softmax: numerically identical, overflow-free | atomic-concepts/mathematical-tricks/logsumexp-trick.md |
| Log-space computation | Work with log-probabilities throughout, avoiding exp() and log() of extreme values | (standard practice) |
| Epsilon guards | Add small ε (e.g., 1e-8) to denominators and arguments of log() | (standard practice) |
| Scaled attention | Divide dot products by √d_k before softmax | atomic-concepts/architectural-primitives/scaled-dot-product-attention.md |
| Loss scaling (mixed precision) | Scale loss up before backward pass, scale gradients down after — keeps fp16 gradients in representable range | (AMP in PyTorch) |
| Gradient clipping | Caps gradient norms, preventing a single NaN-causing update | atomic-concepts/optimisation-primitives/gradient-clipping.md |
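Two of the "standard practice" fixes combine naturally — global-norm gradient clipping with an epsilon guard on the division. A minimal sketch with a hypothetical helper name:

```python
import numpy as np

def clip_by_global_norm(grads, max_norm=1.0, eps=1e-8):
    # Joint L2 norm over all gradient arrays; the eps guard protects the
    # division when the norm is (near) zero — an epsilon guard in action.
    total = np.sqrt(sum(np.sum(g * g) for g in grads))
    scale = min(1.0, max_norm / (total + eps))
    return [g * scale for g in grads], total

grads = [np.full(4, 3.0), np.full(4, 4.0)]      # global norm = 10
clipped, norm = clip_by_global_norm(grads, max_norm=1.0)
print(norm)                                     # 10.0; clipped norm ≈ 1.0
```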

Numerical stability concerns predate deep learning by decades — Goldberg’s “What Every Computer Scientist Should Know About Floating-Point Arithmetic” (1991) is the classic reference. In deep learning, the logsumexp trick for stable softmax was one of the earliest and most impactful practical stability tricks. Mixed-precision training (Micikevicius et al., 2018) made numerical stability a mainstream concern by putting fp16 into everyday training: fp16 has only 5 exponent bits (largest representable value 65504) versus fp32’s 8 (largest ≈ 3.4 × 10³⁸). The introduction of bfloat16 (Brain Floating Point) by Google was specifically motivated by numerical stability — it has the same exponent range as fp32 but with reduced mantissa precision, making overflow much less likely than with fp16.