
KL Divergence

Measures how one probability distribution diverges from a second, reference distribution. The standard regularisation term in VAEs, the objective in knowledge distillation, and the quantity that policy gradient methods like PPO constrain. Also called “relative entropy.”

KL divergence answers: “if the world is actually distributed as P, how many extra bits do I waste, on average, by encoding it with a code optimised for Q instead?” (Bits with log base 2; nats with the natural log, which ML libraries use.) It is always non-negative (Gibbs’ inequality) and equals zero only when P and Q are identical. It is not symmetric, and this asymmetry is its defining feature: KL(P || Q) and KL(Q || P) are generally different numbers and behave very differently during optimisation.
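A small numeric check of these properties. The helper name `kl` and the two-outcome distributions are illustrative choices, not part of any library; `base=2` reports bits as in the coding interpretation above, while the default natural log reports nats.

```python
import numpy as np

def kl(p, q, base=np.e):
    """Discrete KL(P || Q); base=2 reports bits, base=e (the default) reports nats."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    support = p > 0                      # terms with P(i) = 0 contribute nothing
    return np.sum(p[support] * np.log(p[support] / q[support])) / np.log(base)

p = [0.9, 0.1]   # the "true" distribution
q = [0.5, 0.5]   # the distribution we encode with instead

print(kl(p, q, base=2))   # extra bits per symbol when P is coded as if it were Q
print(kl(q, p, base=2))   # the reverse direction gives a different number
print(kl(p, p))           # zero when the distributions are identical
```

Here KL(P || Q) comes out at about 0.531 bits and KL(Q || P) at about 0.737 bits, so swapping the arguments changes the answer.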

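The direction matters in practice. KL(P || Q), “forward KL”, penalises Q heavily wherever P has mass but Q doesn’t. This forces Q to cover all modes of P, even if it wastes probability on regions P ignores (mass-covering). KL(Q || P), “reverse KL”, penalises Q for placing mass where P has none. This encourages Q to lock onto a single mode of P and match it tightly, potentially ignoring other modes entirely (mode-seeking). Maximum-likelihood training, including classification with cross-entropy, minimises forward KL with the data distribution as P. Variational inference, including the VAE’s KL term, uses the reverse direction, with the approximate posterior in the first argument. Policy optimisation uses both: TRPO and PPO’s KL penalty measure KL(π_old || π_new), while KL penalties toward a fixed reference policy (as in RLHF fine-tuning) typically use KL(π || π_ref).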
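To see mass-covering versus mode-seeking behaviour concretely, here is a brute-force sketch (an illustration, not a training recipe). It assumes a bimodal target P, an equal mixture of N(-3, 1) and N(3, 1), and fits a single Gaussian Q = N(μ, σ²) by grid search, approximating both integrals with Riemann sums on a fixed x grid. The target, the grid ranges and the helper names are assumptions chosen for the demo.

```python
import numpy as np

def log_normal(x, mu, sigma):
    """Log-density of N(mu, sigma^2), computed directly so it never hits log(0)."""
    return -0.5 * ((x - mu) / sigma) ** 2 - np.log(sigma * np.sqrt(2 * np.pi))

# Target P: equal mixture of N(-3, 1) and N(3, 1) (illustrative choice).
x = np.linspace(-12.0, 12.0, 1201)
dx = x[1] - x[0]
log_p = np.logaddexp(log_normal(x, -3.0, 1.0), log_normal(x, 3.0, 1.0)) + np.log(0.5)
p = np.exp(log_p)

# Candidate Gaussians Q on a (mu, sigma) grid; arrays broadcast to shape (M, S, X).
mus = np.linspace(-5.0, 5.0, 51)[:, None, None]
sigmas = np.linspace(0.5, 4.0, 36)[None, :, None]
log_q = log_normal(x[None, None, :], mus, sigmas)
q = np.exp(log_q)

# Riemann-sum approximations of the two KL integrals, each of shape (M, S).
forward_kl = np.sum(p * (log_p - log_q), axis=-1) * dx   # KL(P || Q)
reverse_kl = np.sum(q * (log_q - log_p), axis=-1) * dx   # KL(Q || P)

def best_fit(kl_grid):
    """Return the (mu, sigma) grid point with the smallest divergence."""
    i, j = np.unravel_index(np.argmin(kl_grid), kl_grid.shape)
    return mus[i, 0, 0], sigmas[0, j, 0]

mu_fwd, sigma_fwd = best_fit(forward_kl)
mu_rev, sigma_rev = best_fit(reverse_kl)
print(f"forward KL fit: mu={mu_fwd:.1f}, sigma={sigma_fwd:.1f}")
print(f"reverse KL fit: mu={mu_rev:.1f}, sigma={sigma_rev:.1f}")
```

The forward-KL fit should land at μ = 0 with σ close to √10, the mixture's own mean and standard deviation (for a Gaussian Q, minimising forward KL reduces to moment matching), so Q smears mass over the empty valley between the modes. The reverse-KL fit should sit on one of the two modes with σ close to 1 and ignore the other.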

A common source of confusion: minimising cross-entropy H(P, Q) with respect to Q is equivalent to minimising KL(P || Q), because the entropy H(P) is constant. This is why “cross-entropy loss” and “KL minimisation” are often used interchangeably in classification.
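A quick numeric check of this equivalence, using the identity D_KL(P || Q) = H(P, Q) - H(P). The sketch assumes a fixed, strictly positive target P (the helper names are illustrative): whatever Q we try, cross-entropy and KL differ by the same constant H(P), so they rank candidate Qs identically and share the same minimiser.

```python
import numpy as np

def entropy(p):
    """H(P) = -sum_i P(i) log P(i) (assumes P(i) > 0 for all i)."""
    return -np.sum(p * np.log(p))

def cross_entropy(p, q):
    """H(P, Q) = -sum_i P(i) log Q(i)."""
    return -np.sum(p * np.log(q))

def kl(p, q):
    """KL(P || Q) = sum_i P(i) log(P(i) / Q(i))."""
    return np.sum(p * np.log(p / q))

p = np.array([0.7, 0.2, 0.1])  # fixed target, e.g. label-smoothed class probabilities
candidates = [np.array([0.5, 0.3, 0.2]), np.array([0.6, 0.3, 0.1]), p]

for q in candidates:
    ce, d = cross_entropy(p, q), kl(p, q)
    # The gap H(P, Q) - KL(P || Q) is always H(P), whatever Q is.
    print(f"H(P,Q)={ce:.4f}  KL(P||Q)={d:.4f}  gap={ce - d:.4f}  H(P)={entropy(p):.4f}")
```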

General form (discrete):

$$D_{\text{KL}}(P \| Q) = \sum_{i} P(i) \log \frac{P(i)}{Q(i)}$$

Equivalently:

$$D_{\text{KL}}(P \| Q) = H(P, Q) - H(P)$$

where $H(P, Q)$ is the cross-entropy and $H(P)$ is the entropy of P.

Continuous form (e.g. between two Gaussians):

$$D_{\text{KL}}(p \| q) = \int p(x) \log \frac{p(x)}{q(x)} \, dx$$

KL between a Gaussian and the standard normal prior (the closed form used in VAEs), per latent dimension:

$$D_{\text{KL}}\bigl(\mathcal{N}(\mu, \sigma^2) \| \mathcal{N}(0, 1)\bigr) = \frac{1}{2}\bigl(\mu^2 + \sigma^2 - \log \sigma^2 - 1\bigr)$$

This is summed over the latent dimensions and pulls the encoder’s approximate posterior toward the standard normal prior.
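The closed form can be sanity-checked against a Monte Carlo estimate of E_{z~q}[log q(z) - log p(z)]. This is a sketch for a single latent dimension with illustrative values of μ and σ (the helper names are hypothetical); the sampled estimates scatter around the exact value from seed to seed, which is the extra gradient variance the closed form avoids.

```python
import numpy as np

def gaussian_kl_closed_form(mu, sigma):
    """KL(N(mu, sigma^2) || N(0, 1)) for one latent dimension."""
    return 0.5 * (mu**2 + sigma**2 - np.log(sigma**2) - 1.0)

def gaussian_kl_monte_carlo(mu, sigma, n_samples, rng):
    """Estimate E_{z~q}[log q(z) - log p(z)] by sampling z ~ N(mu, sigma^2)."""
    z = rng.normal(mu, sigma, size=n_samples)
    log_q = -0.5 * ((z - mu) / sigma) ** 2 - np.log(sigma) - 0.5 * np.log(2 * np.pi)
    log_p = -0.5 * z**2 - 0.5 * np.log(2 * np.pi)
    return np.mean(log_q - log_p)

rng = np.random.default_rng(0)
mu, sigma = 1.5, 0.5  # illustrative encoder outputs for one dimension
exact = gaussian_kl_closed_form(mu, sigma)
estimates = [gaussian_kl_monte_carlo(mu, sigma, 1_000, rng) for _ in range(5)]
print(exact)
print(estimates)  # noisy estimates scattered around the exact value
```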

With temperature (knowledge distillation):

$$\mathcal{L}_{\text{KD}} = T^2 \cdot D_{\text{KL}}\bigl(\text{softmax}(z_t / T) \| \text{softmax}(z_s / T)\bigr)$$

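Here $z_t$ and $z_s$ are the teacher and student logits and $T$ is the temperature (typically $T > 1$). The $T^2$ factor compensates for the softened distributions: the gradient of the KL term with respect to the student logits scales roughly as $1/T^2$, so multiplying by $T^2$ keeps its magnitude roughly independent of the temperature (and comparable to a hard-label loss when the two are combined).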
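The $T^2$ argument can be checked numerically. This sketch assumes a single example with four classes and made-up logits (`z_t`, `z_s` and `kd_loss_and_grad` are hypothetical names). It uses the analytic gradient of KL(softmax(z_t/T) || softmax(z_s/T)) with respect to the student logits, which is (softmax(z_s/T) - softmax(z_t/T)) / T.

```python
import numpy as np

def softmax(z):
    z = z - z.max()          # shift for numerical stability
    e = np.exp(z)
    return e / e.sum()

def kd_loss_and_grad(student_logits, teacher_logits, T):
    """Return (T^2 * KL, gradient of T^2 * KL, gradient of unscaled KL) w.r.t. student logits."""
    p = softmax(teacher_logits / T)   # softened teacher distribution (P)
    q = softmax(student_logits / T)   # softened student distribution (Q)
    kl = np.sum(p * (np.log(p) - np.log(q)))
    grad_kl = (q - p) / T             # d KL(P || Q) / d student_logits
    return T**2 * kl, T**2 * grad_kl, grad_kl

z_t = np.array([2.0, 1.0, 0.0, -1.0])   # hypothetical teacher logits
z_s = np.array([0.5, 0.0, 0.5, -1.0])   # hypothetical student logits

for T in (1.0, 2.0, 4.0, 8.0, 16.0):
    _, g_scaled, g_raw = kd_loss_and_grad(z_s, z_t, T)
    print(f"T={T:>4}: |grad| unscaled={np.linalg.norm(g_raw):.5f}  "
          f"T^2-scaled={np.linalg.norm(g_scaled):.5f}")
```

Going from T = 8 to T = 16 should shrink the unscaled gradient norm by roughly a factor of four, while the T²-scaled norm stays nearly constant.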

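```python
import torch
import torch.nn.functional as F

# Hypothetical inputs for illustration: batch of 8, 10 classes, temperature 2.
B, C, T = 8, 10, 2.0
student_logits = torch.randn(B, C)
teacher_logits = torch.randn(B, C)

# ── KL between two distributions (from log-probs) ────────────────
# F.kl_div(input, target) computes KL(target || input): `input` must be
# LOG-probabilities (Q, the student) and `target` plain probabilities (P, the
# teacher) unless log_target=True. The argument order is reversed from KL(P || Q).
log_q = F.log_softmax(student_logits / T, dim=-1)  # (B, C) log probs, same T as teacher
p = F.softmax(teacher_logits / T, dim=-1)          # (B, C) probs
# reduction='batchmean' gives the mathematically correct KL per sample.
# Do NOT use reduction='mean': it divides by B*C instead of B.
loss = T**2 * F.kl_div(log_q, p, reduction='batchmean')  # T^2 scaling from L_KD

# ── Gaussian KL for VAEs (closed-form) ───────────────────────────
# mu, log_var come from the encoder: (B, d_latent); random stand-ins here.
mu, log_var = torch.randn(B, 16), torch.randn(B, 16)
kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=-1)  # (B,)
kl_loss = kl.mean()  # scalar
```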
```python
import numpy as np


def kl_divergence_categorical(p, log_q):
    """
    KL(P || Q) for categorical distributions.
    Matches F.kl_div(log_q, p, reduction='batchmean').
    p:     (B, C) probability targets (each row sums to 1)
    log_q: (B, C) log-probability predictions
    """
    # KL = sum_i p_i * (log p_i - log q_i), averaged over the batch.
    # Clamping p avoids log(0); entries with p_i = 0 then contribute 0,
    # provided log_q is finite there.
    log_p = np.log(np.clip(p, 1e-12, None))             # (B, C)
    kl_per_sample = (p * (log_p - log_q)).sum(axis=1)   # (B,)
    return kl_per_sample.mean()


def kl_divergence_gaussian(mu, log_var):
    """
    KL(N(mu, sigma^2) || N(0, I)), the closed-form VAE KL term.
    mu:      (B, D) encoder means
    log_var: (B, D) encoder log-variances, log(sigma^2)
    """
    # -0.5 * (1 + log(sigma^2) - mu^2 - sigma^2), summed over D, averaged over B
    kl = -0.5 * (1 + log_var - mu ** 2 - np.exp(log_var))  # (B, D)
    return kl.sum(axis=1).mean()                           # scalar
```
  • VAE regularisation (see variational-inference-vae/): the KL term pushes the approximate posterior toward the prior, balancing reconstruction quality against latent structure
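  • Knowledge distillation (DistilBERT): KL between the teacher’s and student’s softened distributions transfers “dark knowledge” about inter-class similarities
  • PPO clipping (see policy-gradient/): the clipped objective limits how far the new policy can move from the old one, a cheap proxy for TRPO’s explicit KL constraint between old and new policies; PPO’s penalty variant adds an adaptive KL penalty instead. Both aim to prevent catastrophic policy updates
  • Diffusion models (DDPM): the training objective is a variational bound made of per-timestep KL terms between Gaussian forward-process posteriors and the learned reverse process; the common noise-prediction loss is a reweighted form of these closed-form Gaussian KLs
  • GAN training (see gans/): with an optimal discriminator, the vanilla GAN objective reduces (up to a constant) to minimising the Jensen-Shannon divergence between data and generator distributions; JSD is a symmetrised KL measured against the mixture (P + Q)/2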
| Alternative | When to use | Tradeoff |
|---|---|---|
| Cross-entropy | Fixed target distribution (classification) | Equivalent gradient to KL when P is fixed; simpler API |
| Jensen-Shannon divergence | Need a symmetric measure (e.g. GAN theory) | Bounded in [0, log 2], symmetric, but no closed form for Gaussians |
| Wasserstein distance | Distributions may have non-overlapping support (WGAN) | Meaningful gradients even when KL is infinite; requires a Lipschitz constraint |
| MMD (Maximum Mean Discrepancy) | Kernel-based distribution matching (WAE) | Non-parametric, no density estimation needed; sensitive to kernel choice |
| Reverse KL | Mode-seeking behaviour desired (RL, variational inference) | Avoids covering low-density regions but may miss modes entirely |
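Two of the tradeoffs in the table are easy to see on toy distributions: KL becomes infinite when supports don’t overlap, while Jensen-Shannon stays symmetric and saturates at log 2. The distributions and helper names below are illustrative assumptions, and everything is in nats.

```python
import numpy as np

def kl(p, q):
    """KL(P || Q) in nats; infinite if Q(i) = 0 anywhere P(i) > 0."""
    support = p > 0
    with np.errstate(divide="ignore"):   # log(0) -> -inf is intended here
        return np.sum(p[support] * (np.log(p[support]) - np.log(q[support])))

def js(p, q):
    """Jensen-Shannon divergence: average KL of P and Q to their mixture M."""
    m = 0.5 * (p + q)
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

p = np.array([0.5, 0.5, 0.0, 0.0])
q = np.array([0.0, 0.0, 0.5, 0.5])   # support disjoint from p
print(kl(p, q), js(p, q), np.log(2))  # KL is infinite; JS sits at its log 2 ceiling

r = np.array([0.4, 0.3, 0.2, 0.1])
print(js(p, r), js(r, p))             # JS is symmetric
print(kl(p, r), kl(r, p))             # KL is not: one finite, one infinite
```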

KL divergence was introduced by Solomon Kullback and Richard Leibler in 1951 as a measure of “information for discrimination” between hypotheses. It became central to variational inference through the work of Jordan, Ghahramani, Jaakkola, and Saul in the late 1990s, who showed that approximate Bayesian inference could be cast as minimising KL divergence between the approximate and true posterior.

The closed-form KL between Gaussians made VAE training (Kingma & Welling, 2014) cheaper and lower-variance: without it, the regularisation term would have to be estimated by sampling, adding variance to the gradient. Hinton’s knowledge distillation (2015) brought KL divergence into mainstream deep learning beyond generative models, and PPO (Schulman et al., 2017) made KL constraints a standard tool in reinforcement learning.