Reparameterisation Trick
Rewrites a sample from a parameterised distribution as a deterministic function of the parameters plus fixed noise: $z = g_\theta(\varepsilon)$ where $\varepsilon \sim p(\varepsilon)$. This moves the randomness out of the computational graph so gradients can flow through sampling. The key enabler of VAE training (see variational-inference-vae/).
Intuition
Imagine you’re trying to optimise the position of a dartboard target. You throw darts that land with some random spread around where you aim. With the log-derivative trick, you’d ask “given where this dart landed, should I have aimed more to the left?” — reasoning from one noisy outcome at a time. High variance, slow learning.
The reparameterisation trick says: fix the random jitter of each throw in advance (like pre-rolling dice), then ask “if I shifted my aim slightly, where would this exact same jitter have landed?” Now you have a clean, deterministic function from aim to landing spot, and you can compute exact gradients through it. The randomness is still there — you still average over many throws — but each individual gradient is much less noisy.
Concretely: instead of sampling $z \sim \mathcal{N}(\mu, \sigma^2)$ (which has no gradient w.r.t. $\mu$ or $\sigma$), you sample $\varepsilon \sim \mathcal{N}(0, 1)$ and compute $z = \mu + \sigma\varepsilon$. Now $\partial z / \partial \mu = 1$ and $\partial z / \partial \sigma = \varepsilon$ — clean, well-defined gradients. This only works for distributions where you can express sampling as a differentiable transform of fixed noise, which includes Gaussians, log-normals, and other location-scale families, but NOT categorical or Bernoulli distributions.
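Those two derivatives can be verified directly with autograd; a minimal sketch (the scalar values here are arbitrary):

```python
import torch

# Toy check of the derivatives above: z = mu + sigma * eps
mu = torch.tensor(0.5, requires_grad=True)
sigma = torch.tensor(2.0, requires_grad=True)
eps = torch.randn(())         # fixed noise, a leaf with no gradient
z = mu + sigma * eps          # reparameterised sample
z.backward()
print(mu.grad)                # tensor(1.), since dz/dmu = 1
print(sigma.grad == eps)      # tensor(True), since dz/dsigma = eps
```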
The problem — we want gradients through an expectation:

$$\nabla_\theta \, \mathbb{E}_{z \sim q_\theta(z)}[f(z)]$$

Sampling $z \sim q_\theta(z)$ is not differentiable w.r.t. $\theta$.

The reparameterisation — express $z = g_\theta(\varepsilon)$ where $\varepsilon \sim p(\varepsilon)$ is parameter-free. Now the expectation is over the fixed distribution $p(\varepsilon)$:

$$\nabla_\theta \, \mathbb{E}_{z \sim q_\theta(z)}[f(z)] = \mathbb{E}_{\varepsilon \sim p(\varepsilon)}\big[\nabla_\theta f(g_\theta(\varepsilon))\big]$$

The gradient moves inside the expectation because $p(\varepsilon)$ doesn’t depend on $\theta$.

VAE application — $f(z) = \log p_\theta(x \mid z)$ (the decoder), $g_\phi(\varepsilon) = \mu_\phi(x) + \sigma_\phi(x) \odot \varepsilon$ with $\varepsilon \sim \mathcal{N}(0, I)$, estimated with $L$ Monte Carlo samples:

$$\nabla_\phi \, \mathbb{E}_{q_\phi(z \mid x)}[f(z)] \approx \frac{1}{L} \sum_{l=1}^{L} \nabla_\phi f\big(\mu_\phi(x) + \sigma_\phi(x) \odot \varepsilon^{(l)}\big), \qquad \varepsilon^{(l)} \sim \mathcal{N}(0, I)$$

Typically $L = 1$ suffices during training — one sample per data point.
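The variance advantage over the log-derivative trick is easy to check numerically. A sketch with a toy objective $f(z) = z^2$ under $z \sim \mathcal{N}(\mu, 1)$, whose true gradient is $2\mu$ (the objective and the numbers are illustrative, not from the original):

```python
import numpy as np

rng = np.random.default_rng(0)
mu, n = 1.5, 100_000
eps = rng.standard_normal(n)
z = mu + eps                      # reparameterised samples from N(mu, 1)

# True gradient of E[z^2] w.r.t. mu is 2 * mu = 3.0
grad_pathwise = 2 * z             # reparameterised: d/dmu (mu + eps)^2
grad_score = z**2 * (z - mu)      # REINFORCE: f(z) * d/dmu log N(z; mu, 1)

print(grad_pathwise.mean(), grad_pathwise.var())  # both means are close to 3.0,
print(grad_score.mean(), grad_score.var())        # but the variances differ ~13x
```

Both estimators are unbiased; the pathwise variance here is about 4 while the score-function variance is about 52, which is why REINFORCE needs baselines and many samples.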
```python
import torch

# ── Standard VAE reparameterisation ─────────────────────────────
# The encoder outputs mu and log_var (NOT sigma — log-variance is
# more numerically stable and can be any real number).
mu, log_var = encoder(x)        # each (B, d_latent)
std = torch.exp(0.5 * log_var)  # (B, d_latent) — σ = exp(½ log σ²)
eps = torch.randn_like(std)     # (B, d_latent) — ε ~ N(0, I)
z = mu + std * eps              # (B, d_latent) — reparameterised sample

# z has gradients w.r.t. mu and log_var (and thus w.r.t. encoder params).
# eps is a leaf tensor with no gradient — the randomness is "outside" the graph.
reconstruction = decoder(z)     # (B, ...) — gradients flow all the way back

# WARNING: Never sample z = torch.normal(mu, std) — that detaches the gradient.
# The whole point is z = mu + std * eps, keeping mu and std in the graph.

# ── Using torch.distributions (cleaner API) ────────────────────
dist = torch.distributions.Normal(mu, std)
z = dist.rsample()  # "r" = reparameterised. Has gradients w.r.t. mu, std
# NOT dist.sample() — that would detach, same as torch.normal()
```

Manual Implementation
```python
import numpy as np

def reparameterise_and_grad(mu, log_var, decoder_fn, x_true):
    """
    Forward pass + manual gradient of reconstruction loss w.r.t. mu.

    mu:         (B, D) mean from encoder
    log_var:    (B, D) log-variance from encoder
    decoder_fn: callable, maps z -> reconstruction
    x_true:     (B, ...) target for MSE reconstruction loss

    Returns: z sample, reconstruction loss
    """
    B, D = mu.shape
    std = np.exp(0.5 * log_var)   # (B, D)

    # The trick: sample noise ONCE, reuse for gradient computation
    eps = np.random.randn(B, D)   # (B, D) — fixed noise
    z = mu + std * eps            # (B, D) — differentiable w.r.t. mu, std

    # Forward: decode and compute loss
    x_hat = decoder_fn(z)         # (B, ...)
    recon_loss = np.mean((x_hat - x_true) ** 2)

    # Gradients: ∂z/∂μ = 1, ∂z/∂σ = ε, ∂z/∂log_var = 0.5·σ·ε
    # So ∂L/∂μ = ∂L/∂z · 1 and ∂L/∂log_var = ∂L/∂z · 0.5·σ·ε
    # (The downstream gradient ∂L/∂z comes from the decoder + loss)
    return z, recon_loss

def sample_gaussian_reparam(mu, log_var, n_samples=1):
    """
    Pure reparameterised sampling (no loss computation).

    mu:      (D,) or (B, D)
    log_var: (D,) or (B, D)
    """
    std = np.exp(0.5 * log_var)
    eps = np.random.randn(*([n_samples] + list(mu.shape)))  # (N, ..., D)
    return mu + std * eps         # (N, ..., D)
```

Popular Uses
- Variational autoencoders (VAE, β-VAE, CVAE): the reparameterisation trick is what makes VAE training practical — without it, you’d need REINFORCE with its high variance (see variational-inference-vae/)
- Latent diffusion models (Stable Diffusion): the KL-AE that compresses images to latent space uses reparameterised Gaussian sampling
- Normalising flows (RealNVP, Glow): sample base distribution noise, transform deterministically through invertible layers — same principle
- SAC (Soft Actor-Critic): reparameterised policy sampling allows direct gradient computation for the actor loss (see q-learning/)
- Stochastic neural networks: any model with learned stochastic layers (Bayesian NNs, stochastic depth) uses this to backprop through noise
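For the SAC case, the sampled action stays differentiable via `rsample()`; a minimal sketch, where the tensors stand in for a policy network's outputs and the loss is a placeholder for the Q-value based actor loss:

```python
import torch

mean = torch.zeros(4, requires_grad=True)     # stand-in for policy net output
log_std = torch.zeros(4, requires_grad=True)

dist = torch.distributions.Normal(mean, log_std.exp())
u = dist.rsample()            # reparameterised: graph reaches mean and log_std
action = torch.tanh(u)        # SAC squashes the Gaussian into [-1, 1]

loss = (action ** 2).sum()    # placeholder for -Q(s, action)
loss.backward()
print(mean.grad is not None, log_std.grad is not None)  # True True
```

With `dist.sample()` instead, both gradients would be `None` and the actor could not be trained this way.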
Alternatives
| Alternative | When to use | Tradeoff |
|---|---|---|
| Log-derivative trick (REINFORCE) | Discrete distributions, non-differentiable rewards | Works for any distribution but much higher variance — needs baselines and many samples |
| Gumbel-softmax | Categorical latent variables (discrete VAE) | Continuous relaxation of discrete sampling; introduces temperature-dependent bias |
| Straight-through estimator | Hard quantisation (VQ-VAE) | Zero variance, but biased — the backward pass approximates a non-differentiable forward pass |
| Implicit reparameterisation | Non-standard distributions (truncated, mixture) | Generalises the trick to distributions without simple location-scale form; more complex to implement |
| Pathwise gradient (deterministic) | No stochasticity needed (standard networks) | If you don’t need sampling, just backprop normally — the reparam trick is only needed when randomness is part of the model |
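As a concrete contrast with the first two alternatives, PyTorch’s `gumbel_softmax` combines the Gumbel relaxation with an optional straight-through forward pass; a brief sketch (the logits and loss are illustrative):

```python
import torch
import torch.nn.functional as F

logits = torch.zeros(3, requires_grad=True)       # learnable categorical logits
y = F.gumbel_softmax(logits, tau=0.5, hard=True)  # one-hot forward, soft backward
loss = (y * torch.tensor([1.0, 2.0, 3.0])).sum()
loss.backward()
print(y.sum().item())           # 1.0, a hard one-hot sample
print(logits.grad is not None)  # True: gradients flow despite discreteness
```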
Historical Context
The reparameterisation trick was introduced simultaneously by Kingma & Welling (2014, “Auto-Encoding Variational Bayes”) and Rezende, Mohamed & Wierstra (2014, “Stochastic Backpropagation and Approximate Inference in Deep Generative Models”). Both papers showed that reparameterised gradients had dramatically lower variance than REINFORCE-style estimators, making it practical to train deep generative models with latent variables.
The idea itself — expressing a random variable as a deterministic function of fixed noise — was known in statistics as the “non-centred parameterisation” and in simulation as “common random numbers.” The contribution of the VAE papers was recognising that this old trick, combined with amortised inference (an encoder network), scaled to high-dimensional deep learning. It has since become the default approach for any model requiring gradients through continuous stochastic nodes.