"""
Unified Variational Autoencoder (VAE) Algorithm
=================================================
A single skeleton covering: vanilla VAE, β-VAE, VQ-VAE,
conditional VAE (CVAE), and the VAE as used in latent diffusion
(Stable Diffusion's image compressor).
The core idea shared by ALL VAE variants:
1. ENCODE: x → q(z|x) (compress input to a distribution)
2. SAMPLE: z ~ q(z|x) (draw a latent code)
3. DECODE: z → p(x|z) (reconstruct from the code)
4. LOSS: reconstruction + regularisation
loss = E_q[ −log p(x|z) ] + KL[ q(z|x) ‖ p(z) ]
├── reconstruction ──┘ └── regularisation ──┘
"how good is the "how close is the
reconstruction?" learned posterior to
the prior?"
The pluggable components are:
1. encode() — how the posterior q(z|x) is parameterised
2. sample_latent() — how z is drawn (reparameterisation, codebook, ...)
3. decode() — how p(x|z) is parameterised
4. reconstruction_loss() — MSE, BCE, perceptual, ...
5. regularisation_loss() — KL divergence, commitment loss, ...
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from abc import ABC, abstractmethod
# ═══════════════════════════════════════════════════════════════════
# CORE: THE REPARAMETERISATION TRICK
# ═══════════════════════════════════════════════════════════════════
#
# The fundamental problem: we need to SAMPLE z ~ q(z|x), but
# sampling is not differentiable — you can't backprop through
# a random draw.
#
# The trick: instead of sampling z ~ N(μ, σ²), compute
# z = μ + σ · ε, where ε ~ N(0, I)
#
# Now the randomness (ε) is external to the computation graph,
# and gradients flow through μ and σ normally.
#
# This single trick is what makes VAEs trainable. Without it,
# you'd need REINFORCE-style gradient estimators (high variance).
def reparameterise(mu, log_var):
    """Draw z = μ + σ·ε with ε ~ N(0, I).

    The noise is sampled outside the graph, so gradients flow through
    μ and σ (this is the reparameterisation trick).
    """
    sigma = torch.exp(0.5 * log_var)   # log σ² → σ
    noise = torch.randn_like(sigma)    # ε ~ N(0, I)
    return mu + sigma * noise
# ═══════════════════════════════════════════════════════════════════
# CORE: KL DIVERGENCE (analytic for two Gaussians)
# ═══════════════════════════════════════════════════════════════════
#
# When both q(z|x) = N(μ, σ²) and p(z) = N(0, I) are Gaussian,
# the KL has a closed form — no sampling needed:
#
# KL = -½ Σ (1 + log σ² − μ² − σ²)
#
# This is per-sample; we average over the batch.
def kl_divergence_gaussian(mu, log_var):
    """Analytic KL[ N(μ, σ²) ‖ N(0, I) ], summed over latent dims, batch-averaged."""
    var = log_var.exp()
    kl_per_dim = -0.5 * (1 + log_var - mu.pow(2) - var)
    return kl_per_dim.sum(dim=-1).mean()
# ═══════════════════════════════════════════════════════════════════
# CORE ALGORITHM (the part that NEVER changes)
# ═══════════════════════════════════════════════════════════════════
class VAEAlgorithm(ABC):
    """
    Skeleton shared by every VAE variant.

    The training step below is fixed; variants only override the five
    hooks: encode / sample_latent / decode / reconstruction_loss /
    regularisation_loss (plus sample_prior for generation).
    """

    def __init__(self, encoder, decoder, optimizer):
        self.encoder = encoder
        self.decoder = decoder
        self.optimizer = optimizer

    # ── Pluggable pieces ─────────────────────────────────────────────
    @abstractmethod
    def encode(self, x):
        """Map x to the parameters of q(z|x) (e.g. a (mu, log_var) pair)."""
        ...

    @abstractmethod
    def sample_latent(self, latent_params):
        """Draw z ~ q(z|x); the draw must keep gradients flowing."""
        ...

    @abstractmethod
    def decode(self, z):
        """Map a latent code z back to data space."""
        ...

    def reconstruction_loss(self, x_recon, x):
        """Mean-squared error by default; override for BCE, perceptual, ..."""
        return F.mse_loss(x_recon, x, reduction="mean")

    @abstractmethod
    def regularisation_loss(self, latent_params):
        """Prior-matching term (KL divergence, VQ commitment, ...)."""
        ...

    # ── Core training step (identical for every variant) ─────────────
    def train_step(self, x, condition=None):
        """One encode → sample → decode → loss → gradient-update cycle.

        Returns a dict of scalar metrics: total loss plus its two parts.
        """
        params = self.encode(x)
        z = self.sample_latent(params)          # differentiable sample
        if condition is None:
            recon = self.decode(z)
        else:
            recon = self.decode(z, condition)   # conditional variants accept c
        rec_term = self.reconstruction_loss(recon, x)
        reg_term = self.regularisation_loss(params)
        total = rec_term + reg_term
        self.optimizer.zero_grad()
        total.backward()
        self.optimizer.step()
        return {"loss": total.item(),
                "recon": rec_term.item(),
                "reg": reg_term.item()}

    # ── Generation (decode from the prior) ───────────────────────────
    @torch.no_grad()
    def generate(self, n_samples, device="cpu"):
        """Draw z ~ p(z) and decode it into samples."""
        return self.decode(self.sample_prior(n_samples, device))

    @abstractmethod
    def sample_prior(self, n_samples, device):
        """Draw n_samples latent codes from the prior p(z)."""
        ...
# ═══════════════════════════════════════════════════════════════════
# TRAINING LOOP
# ═══════════════════════════════════════════════════════════════════
def train(algo: VAEAlgorithm, dataloader, n_epochs, device="cpu"):
    """Generic training loop shared by every VAE variant.

    Runs ``algo.train_step`` on each batch and prints per-epoch metric
    averages weighted by batch size.

    Batches may be ``(x,)`` or ``(x, condition, ...)``: when a second
    element is present it is forwarded to ``train_step`` as the condition,
    so conditional variants (CVAE) actually receive their labels /
    embeddings. (Previously the condition was silently dropped.)
    Non-conditional variants accept and ignore the keyword, so behaviour
    for them is unchanged.
    """
    for epoch in range(n_epochs):
        totals = {"loss": 0.0, "recon": 0.0, "reg": 0.0}
        n = 0
        for x, *rest in dataloader:
            x = x.to(device)
            condition = rest[0] if rest else None
            # Move tensor conditions to the device; leave raw labels as-is.
            if condition is not None and hasattr(condition, "to"):
                condition = condition.to(device)
            metrics = algo.train_step(x, condition=condition)
            batch_size = x.size(0)
            for k in totals:
                totals[k] += metrics[k] * batch_size  # weight by batch size
            n += batch_size
        avg = {k: v / n for k, v in totals.items()}
        print(f"Epoch {epoch+1:3d}/{n_epochs} │ "
              f"loss {avg['loss']:.4f} "
              f"recon {avg['recon']:.4f} "
              f"reg {avg['reg']:.4f}")
# ═══════════════════════════════════════════════════════════════════
# VARIANT IMPLEMENTATIONS (only the parts that differ)
# ═══════════════════════════════════════════════════════════════════
# ── 1. Vanilla VAE (Gaussian posterior, Gaussian prior) ─────────
class VanillaVAE(VAEAlgorithm):
    """
    The original VAE (Kingma & Welling, 2013): Gaussian posterior via
    the reparameterisation trick, standard-normal prior, analytic KL.

    Minimising reconstruction + KL is exactly maximising the ELBO:

        log p(x) ≥ E_q[log p(x|z)] − KL[q(z|x) ‖ p(z)]
    """

    def __init__(self, encoder, decoder, optimizer, d_latent):
        super().__init__(encoder, decoder, optimizer)
        # Latent dimensionality, needed to sample from the prior.
        self.d_latent = d_latent

    def encode(self, x):
        # The encoder's final layer emits 2·d_latent values:
        # first half is μ, second half is log σ².
        packed = self.encoder(x)
        return tuple(packed.chunk(2, dim=-1))

    def sample_latent(self, latent_params):
        return reparameterise(*latent_params)

    def decode(self, z, condition=None):
        # `condition` is accepted for signature compatibility; unused here.
        return self.decoder(z)

    def regularisation_loss(self, latent_params):
        return kl_divergence_gaussian(*latent_params)

    def sample_prior(self, n_samples, device):
        # p(z) = N(0, I)
        return torch.randn(n_samples, self.d_latent, device=device)
# ── 2. β-VAE (disentangled representations) ────────────────────
class BetaVAE(VanillaVAE):
    """
    β-VAE (Higgins et al.): a vanilla VAE whose KL term is scaled by β.

    β > 1 presses the posterior toward a simple factorised form, which
    encourages "disentangled" latents where individual dimensions track
    meaningful factors of variation (size, rotation, colour, ...).
    β < 1 relaxes the regularisation: better reconstructions, less
    structured latent space. β = 1 recovers the original, statistically
    principled VAE — so β is a reconstruction-vs-structure dial.
    """

    def __init__(self, *args, beta=4.0, **kw):
        super().__init__(*args, **kw)
        self.beta = beta  # KL weight; 1.0 is the vanilla VAE

    def regularisation_loss(self, latent_params):
        # Identical KL, just rescaled.
        return self.beta * super().regularisation_loss(latent_params)
# ── 3. VQ-VAE (discrete latent codes via vector quantisation) ───
class VectorQuantiser(nn.Module):
    """
    Snaps each encoder output vector to its nearest entry in a learned
    codebook (VQ-VAE, van den Oord et al. 2017):

        z_q = codebook[argmin_k ‖z_e − e_k‖²]

    argmin is not differentiable, so the forward pass uses the
    straight-through estimator: values of z_q, gradients of z_e.

    Fixes vs the previous version:
      • the two loss terms were named backwards, which put the 0.25
        weight on the codebook term instead of the commitment term;
        VQ-VAE weights the CODEBOOK loss by 1 and the COMMITMENT loss
        by β (default 0.25)
      • `indices` is reshaped to (B, ...) as the docstring promises,
        instead of being returned flat
      • β is exposed as `commitment_cost` (default 0.25, so existing
        callers are unaffected)

    Args:
        n_codes: number of codebook entries K
        d_code: dimensionality of each code vector
        commitment_cost: β, the weight on the commitment term
    """

    def __init__(self, n_codes, d_code, commitment_cost=0.25):
        super().__init__()
        self.n_codes = n_codes
        self.d_code = d_code
        self.commitment_cost = commitment_cost
        self.codebook = nn.Embedding(n_codes, d_code)
        # Small uniform init keeps early distances (hence assignments) diverse.
        self.codebook.weight.data.uniform_(-1 / n_codes, 1 / n_codes)

    def forward(self, z_e):
        """
        Args:
            z_e: (B, ..., d_code) — encoder output (any spatial shape)
        Returns:
            z_q:     (B, ..., d_code) — quantised values, straight-through grads
            indices: (B, ...)         — index of the chosen code per vector
            vq_loss: scalar           — codebook loss + β · commitment loss
        """
        flat = z_e.reshape(-1, self.d_code)  # (N, d)
        # Squared distances ‖f − e‖² = ‖f‖² − 2 f·e + ‖e‖², all pairs at once.
        dist = (flat.pow(2).sum(1, keepdim=True)
                - 2 * flat @ self.codebook.weight.T
                + self.codebook.weight.pow(2).sum(1, keepdim=True).T)
        indices = dist.argmin(dim=-1).view(z_e.shape[:-1])  # (B, ...)
        z_q = self.codebook(indices).view_as(z_e)           # (B, ..., d)
        # Codebook loss ‖sg[z_e] − z_q‖²: gradient reaches the codes only.
        codebook_loss = F.mse_loss(z_q, z_e.detach())
        # Commitment loss ‖z_e − sg[z_q]‖²: gradient reaches the encoder only.
        commit_loss = F.mse_loss(z_e, z_q.detach())
        vq_loss = codebook_loss + self.commitment_cost * commit_loss
        # Straight-through: forward value is z_q, backward gradient goes to z_e.
        z_q_st = z_e + (z_q - z_e).detach()
        return z_q_st, indices, vq_loss
class VQVAE(VAEAlgorithm):
    """
    Discrete-latent VAE: the encoder output is snapped to the nearest
    codebook vector instead of being drawn from a Gaussian. There is no
    KL term — the finite codebook is itself the information bottleneck.

    The same recipe underlies:
      • DALL-E 1 (image tokens for autoregressive generation)
      • Stable Diffusion's VAE (continuous variant with KL, same idea)
      • AudioLM, SoundStream (audio tokenisation)
    """

    def __init__(self, encoder, decoder, optimizer, quantiser: VectorQuantiser):
        super().__init__(encoder, decoder, optimizer)
        self.quantiser = quantiser

    def encode(self, x):
        # One-tuple keeps the latent-params contract of the base class.
        return (self.encoder(x),)

    def sample_latent(self, latent_params):
        (z_e,) = latent_params
        z_q, indices, vq_loss = self.quantiser(z_e)
        # train_step always calls sample_latent before regularisation_loss,
        # so stashing the quantiser's loss on self here is safe.
        self._vq_loss = vq_loss
        self._indices = indices
        return z_q

    def decode(self, z, condition=None):
        return self.decoder(z)

    def regularisation_loss(self, latent_params):
        # No KL — the quantiser already produced the (codebook + commitment) loss.
        return self._vq_loss

    def sample_prior(self, n_samples, device):
        # Uniform over the codebook: pick random indices, embed them.
        picks = torch.randint(0, self.quantiser.n_codes,
                              (n_samples,), device=device)
        return self.quantiser.codebook(picks)
# ── 4. Conditional VAE (generation conditioned on a label/text) ─
class CVAE(VanillaVAE):
    """
    Conditional VAE: encoder and decoder both see a condition c
    (class label, text embedding, ...).

        Encoder: q(z | x, c)   Decoder: p(x | z, c)   Prior: p(z) = N(0, I)

    At generation time: pick c, draw z ~ p(z), decode [z; c].
    """

    def __init__(self, encoder, decoder, optimizer, d_latent,
                 cond_encoder=None):
        super().__init__(encoder, decoder, optimizer, d_latent)
        # Optional module mapping raw conditions (e.g. integer labels) to vectors.
        self.cond_encoder = cond_encoder

    def train_step(self, x, condition=None):
        if condition is not None and self.cond_encoder is not None:
            condition = self.cond_encoder(condition)
        # q(z | x, c): condition concatenated onto the encoder input.
        enc_input = x if condition is None else torch.cat([x, condition], dim=-1)
        latent_params = self.encode(enc_input)
        z = self.sample_latent(latent_params)
        # p(x | z, c): condition concatenated onto the latent, fed straight
        # to the decoder module (not self.decode, which ignores c).
        dec_input = z if condition is None else torch.cat([z, condition], dim=-1)
        x_recon = self.decoder(dec_input)
        loss_recon = self.reconstruction_loss(x_recon, x)
        loss_reg = self.regularisation_loss(latent_params)
        loss = loss_recon + loss_reg
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return {"loss": loss.item(),
                "recon": loss_recon.item(),
                "reg": loss_reg.item()}
# ── 5. KL-regularised autoencoder (Stable Diffusion's VAE) ─────
class KLAE(VanillaVAE):
    """
    The image compressor used in Latent Diffusion / Stable Diffusion.

    Structurally a VAE, but with two important differences:
      • Very small KL weight (≈1e-6) — almost a plain autoencoder,
        just enough regularisation to keep the latent space smooth
      • Perceptual + adversarial reconstruction loss instead of MSE,
        producing much sharper reconstructions

    The purpose is not generation (diffusion handles that) but
    COMPRESSION: map 512×512×3 images → 64×64×4 latents, making
    the diffusion model 64× cheaper to train and run.
    """
    def __init__(self, encoder, decoder, optimizer, d_latent,
                 kl_weight=1e-6, perceptual_loss_fn=None, disc=None,
                 disc_optimizer=None):
        super().__init__(encoder, decoder, optimizer, d_latent)
        self.kl_weight = kl_weight               # tiny KL coefficient: smoothness, not generation
        self.perceptual_fn = perceptual_loss_fn  # e.g. LPIPS; None disables the term
        self.disc = disc                         # patch discriminator; None disables adversarial training
        self.disc_optimizer = disc_optimizer     # optimiser for the discriminator's parameters
    def regularisation_loss(self, latent_params):
        # Same analytic KL as the vanilla VAE, scaled down by kl_weight.
        mu, log_var = latent_params
        return self.kl_weight * kl_divergence_gaussian(mu, log_var)
    def reconstruction_loss(self, x_recon, x):
        # Pixel-space term; perceptual and adversarial terms are added on top
        # when their modules are configured.
        loss = F.mse_loss(x_recon, x)
        if self.perceptual_fn is not None:
            loss = loss + self.perceptual_fn(x_recon, x)
        if self.disc is not None:
            # Generator wants discriminator to think reconstruction is real
            # (maximise the disc score on x_recon).
            # NOTE(review): this term also backpropagates into the disc's own
            # parameters; presumably `optimizer` holds only encoder/decoder
            # params so the disc is untouched by the generator step — confirm.
            loss = loss + (-self.disc(x_recon)).mean()
        return loss
    def train_step(self, x, condition=None):
        # Phase 1: standard VAE step — updates encoder/decoder via `optimizer`.
        metrics = super().train_step(x)
        # Phase 2 (optional): discriminator step, interleaved GAN-style.
        if self.disc is not None and self.disc_optimizer is not None:
            # Recompute the reconstruction WITHOUT grad: in this phase only the
            # discriminator should receive gradients, and x_recon acts as a
            # fixed "fake" sample.
            with torch.no_grad():
                latent_params = self.encode(x)
                z = self.sample_latent(latent_params)
                x_recon = self.decode(z)
            real_score = self.disc(x)
            fake_score = self.disc(x_recon)
            # Hinge loss: push real scores above +1 and fake scores below −1.
            disc_loss = F.relu(1 - real_score).mean() + F.relu(1 + fake_score).mean()
            self.disc_optimizer.zero_grad()
            disc_loss.backward()
            self.disc_optimizer.step()
            metrics["disc_loss"] = disc_loss.item()
        return metrics