
Straight-Through Estimator

Approximates the gradient of a non-differentiable discrete operation (like argmax, rounding, or quantisation) by pretending it was the identity function in the backward pass. The forward pass does the hard discrete operation; the backward pass passes gradients straight through as if nothing happened. Key to VQ-VAE and binary neural networks.

Imagine a staircase function: input goes up smoothly, output jumps between flat steps. The true gradient is zero almost everywhere (on the flat parts) and undefined at the jumps — completely useless for learning. The straight-through estimator says: “for the backward pass, pretend the staircase was a ramp.” The gradient of a ramp is 1 everywhere, so gradients flow through unchanged.

This is clearly wrong — the “gradient” you compute doesn’t match the actual function. But it’s wrong in a useful way. It tells the upstream network: “if you increased your output a little, the loss would change by this much, assuming the discrete operation didn’t interfere.” In practice, the parameters adjust to put the continuous values closer to the discrete values they’ll snap to, and training works surprisingly well.
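Both the zero-gradient problem and the fix can be seen in a few lines. A minimal PyTorch sketch, using rounding as the staircase function:

```python
import torch

x = torch.tensor([0.3, 1.7, 2.4], requires_grad=True)

# Naive rounding: the true gradient is zero almost everywhere.
y_hard = torch.round(x)
y_hard.sum().backward()
print(x.grad)            # tensor([0., 0., 0.]) — nothing to learn from

# Straight-through: same forward value, identity backward.
x.grad = None
y_ste = x + (torch.round(x) - x).detach()
y_ste.sum().backward()
print(x.grad)            # tensor([1., 1., 1.]) — gradients flow
```

The first backward pass confirms that the upstream network receives no signal at all; the second shows the STE restoring a usable (if biased) gradient.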

The key insight: a biased gradient that points roughly in the right direction is far more useful than the true gradient of zero. VQ-VAE (van den Oord et al., 2017) relies entirely on this trick — the encoder outputs a continuous vector, it gets snapped to the nearest codebook entry, and gradients flow back through the snap as if it never happened.

Forward pass — apply a non-differentiable function $q$:

$$z_q = q(z_e) \quad \text{(e.g. nearest-neighbour lookup, rounding, argmax)}$$

True gradient — zero almost everywhere:

$$\frac{\partial z_q}{\partial z_e} = 0 \quad \text{(a.e.)}$$

STE approximation — replace the backward pass with identity:

$$\frac{\partial z_q}{\partial z_e} \approx I \quad \Longrightarrow \quad \frac{\partial \mathcal{L}}{\partial z_e} \approx \frac{\partial \mathcal{L}}{\partial z_q}$$

VQ-VAE application — $z_e$ is the encoder output, $e_k$ is the nearest codebook vector:

$$z_q = e_k \quad \text{where} \quad k = \arg\min_j \|z_e - e_j\|^2$$

$$\text{Forward: } z_q = e_k, \qquad \text{Backward: } \nabla_{z_e} \mathcal{L} = \nabla_{z_q} \mathcal{L}$$

The codebook vectors $e_k$ are updated separately via an EMA or a commitment loss, not through the STE.
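The EMA variant can be sketched in a few lines. The function name and the `decay`/`eps` defaults below are illustrative choices, not taken from the paper's code:

```python
import torch

def ema_codebook_update(codebook, cluster_size, ema_sum, z_e, indices,
                        decay=0.99, eps=1e-5):
    """Sketch of an EMA codebook update (name and defaults illustrative).
    cluster_size: (K,)   running count of assignments per code
    ema_sum:      (K, D) running sum of encoder outputs per code
    """
    K = codebook.shape[0]
    one_hot = torch.zeros(indices.shape[0], K).scatter_(1, indices.unsqueeze(1), 1.0)
    # Update the running assignment counts and per-code sums.
    cluster_size.mul_(decay).add_(one_hot.sum(0), alpha=1 - decay)
    ema_sum.mul_(decay).add_(one_hot.t() @ z_e, alpha=1 - decay)
    # Each code becomes the EMA mean of the encoder outputs assigned to it.
    codebook.copy_(ema_sum / (cluster_size.unsqueeze(1) + eps))
    return codebook
```

Because the codebook is updated from running statistics, no gradient ever needs to flow into it; the STE only has to serve the encoder.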

```python
import torch

# ── The core STE pattern in PyTorch ─────────────────────────────
# The trick: z_q = z_e + (quantised - z_e).detach()
# Forward:  z_q = quantised (because z_e - z_e = 0, then + quantised)
# Backward: ∂z_q/∂z_e = 1 (because quantised.detach() is a constant)
z_e = encoder(x)                                # (B, D) continuous
distances = torch.cdist(z_e, codebook.weight)   # (B, K)
indices = distances.argmin(dim=-1)              # (B,) nearest codes
z_q = codebook.weight[indices]                  # (B, D) discrete

# STE: copy gradients from z_q to z_e
z_q_st = z_e + (z_q - z_e).detach()             # (B, D)
# z_q_st has the VALUE of z_q but the GRADIENT path of z_e
reconstruction = decoder(z_q_st)                # gradients flow to encoder

# ── For simple rounding / binarisation ──────────────────────────
x_hard = torch.round(x_soft)                    # no gradient
x_st = x_soft + (x_hard - x_soft).detach()      # STE version

# WARNING: the STE gradient is biased. If the encoder and codebook
# drift apart, the approximation degrades. VQ-VAE uses a commitment
# loss (β‖z_e - z_q.detach()‖²) to keep them close.
```
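That commitment loss, together with the codebook loss it accompanies, can be sketched as follows. The helper name is illustrative; β = 0.25 is the value suggested in the VQ-VAE paper:

```python
import torch
import torch.nn.functional as F

def vq_losses(z_e, z_q, beta=0.25):
    """The two auxiliary VQ-VAE losses that accompany the STE.
    beta = 0.25 is the value suggested in the VQ-VAE paper."""
    # Codebook loss: pull the codebook vectors towards the encoder outputs.
    codebook_loss = F.mse_loss(z_q, z_e.detach())
    # Commitment loss: pull the encoder outputs towards their codes,
    # keeping the STE approximation accurate.
    commitment_loss = beta * F.mse_loss(z_e, z_q.detach())
    return codebook_loss + commitment_loss
```

Note the symmetric use of `.detach()`: each term trains only one side (codebook or encoder), so neither loss interferes with the straight-through reconstruction gradient.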
```python
import numpy as np

def straight_through_quantise(z_e, codebook):
    """
    VQ-VAE style quantisation with straight-through gradient.
    z_e:      (B, D) continuous encoder outputs
    codebook: (K, D) codebook vectors
    Returns:  z_q (B, D) quantised values (forward = discrete, backward = identity)
              indices (B,) which codebook entry was selected
    """
    # Find the nearest codebook entry for each encoder output:
    # ‖z_e - e_k‖² = ‖z_e‖² + ‖e_k‖² - 2·z_e·e_kᵀ
    dist = (
        np.sum(z_e ** 2, axis=1, keepdims=True)           # (B, 1)
        + np.sum(codebook ** 2, axis=1, keepdims=True).T  # (1, K)
        - 2 * z_e @ codebook.T                            # (B, K)
    )                                                     # (B, K)
    indices = np.argmin(dist, axis=1)                     # (B,)
    z_q = codebook[indices]                               # (B, D)
    # STE: in a real backward pass we'd set ∂z_q/∂z_e = I.
    # In NumPy (no autograd) this means: when computing upstream
    # gradients, treat z_q as if it were z_e.
    # grad_z_e = grad_z_q (copy the gradient unchanged)
    return z_q, indices

def ste_backward(grad_output):
    """
    The STE backward pass is literally the identity.
    grad_output: (B, D) gradient flowing into the quantisation
    Returns:     (B, D) gradient flowing to the encoder — unchanged
    """
    return grad_output  # that's it; that's the whole trick
```
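With autograd available, the same identity backward can be written as an explicit `torch.autograd.Function` instead of via the `detach()` trick. A minimal sketch:

```python
import torch

class RoundSTE(torch.autograd.Function):
    """Rounding with a straight-through backward pass."""

    @staticmethod
    def forward(ctx, x):
        return torch.round(x)      # hard, non-differentiable forward

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output         # identity: the whole STE

x = torch.tensor([0.2, 1.8], requires_grad=True)
RoundSTE.apply(x).sum().backward()
print(x.grad)                      # tensor([1., 1.])
```

The explicit form makes the approximation visible at a glance, and is the natural place to add refinements such as gradient clipping in the backward method.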
  • VQ-VAE / VQ-VAE-2 (van den Oord et al.): quantise continuous encoder outputs to discrete codebook entries — the STE is the only way gradients reach the encoder (see variational-inference-vae/)
  • Binary / ternary neural networks (BinaryConnect, XNOR-Net): weights are binarised to {-1, +1} in the forward pass, STE passes gradients to the full-precision latent weights
  • Hard attention mechanisms: argmax selection of attention positions, with STE for gradient flow
  • Neural discrete representation learning (dVAE in DALL-E 1): discrete tokens for image generation use STE or Gumbel-softmax
  • Learned quantisation (neural compression): differentiable rounding for entropy coding
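For the binary-network case, the pattern is the same. A BinaryConnect-style sketch (the loss is a stand-in; BinaryConnect additionally zeroes the gradient where |w| > 1, a "clipped" STE):

```python
import torch

w = torch.randn(5, requires_grad=True)   # full-precision latent weights
w_bin = torch.sign(w)                    # {-1, +1} forward (no gradient)
w_ste = w + (w_bin - w).detach()         # identity backward
loss = (w_ste ** 2).sum()                # stand-in downstream loss
loss.backward()
w.grad[w.abs() > 1] = 0.0                # clipped STE: saturate large weights
```

Only the binarised weights are used in the forward pass; the optimiser accumulates the straight-through gradients into the full-precision copy.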
| Alternative | When to use | Tradeoff |
| --- | --- | --- |
| Gumbel-softmax | Categorical selection with smoother gradients | Lower bias than STE (approaches the true gradient as τ→0) but requires a temperature schedule and is a soft approximation during training |
| Reparameterisation trick | Continuous latent variables | Exact gradients, zero bias, but only works for continuous distributions, not discrete operations |
| REINFORCE / score function | Any discrete operation where unbiased gradients are needed | Unbiased but extremely high variance; impractical for high-dimensional discrete spaces like codebooks |
| EMA codebook update | Updating codebook vectors (used alongside STE) | Avoids backprop through the codebook entirely; more stable than gradient-based codebook updates |
| Finite differences | Debugging, gradient checking | Unbiased but scales as O(D) per parameter; only useful for verification, never training |
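PyTorch's built-in `F.gumbel_softmax` actually combines the first alternative with the STE: with `hard=True` it returns exact one-hot samples in the forward pass but routes gradients through the soft sample via the same `detach()` trick. A short sketch (the downstream loss is arbitrary):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10, requires_grad=True)

# hard=True: one-hot forward, straight-through backward over the
# soft Gumbel-softmax probabilities.
y = F.gumbel_softmax(logits, tau=0.5, hard=True)

loss = (y * torch.arange(10.0)).sum()   # arbitrary downstream loss
loss.backward()                         # gradients reach the logits
```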

The straight-through estimator was introduced by Hinton in his 2012 Coursera lectures as a practical trick for training networks with discrete hidden units. Bengio, Léonard & Courville (2013, “Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation”) formalised and analysed it, showing that despite its bias it works well in practice for binary and discrete stochastic neurons.

Its most influential application came with VQ-VAE (van den Oord et al., 2017), which used STE to train a discrete autoencoder that produced high-quality codebook representations. This architecture became foundational — DALL-E 1 used a discrete VAE (dVAE) with Gumbel-softmax relaxation as an alternative, and later work on audio (SoundStream, Encodec) and language-image models relied on VQ codebooks with STE-based training.