"""
Unified Generative Adversarial Networks (GANs): Core Algorithm
================================================================
A single skeleton covering: vanilla GAN, DCGAN, WGAN, WGAN-GP,
StyleGAN (conceptual), conditional GAN (cGAN), and pix2pix-style
paired translation.
The core idea shared by ALL GANs:
Two networks play a minimax game:
GENERATOR (G): z → fake_data (tries to fool D)
DISCRIMINATOR (D): data → real_or_fake (tries to catch G)
G wants D to output "real" for fakes.
D wants to output "real" for real data, "fake" for G's output.
At equilibrium: G produces data indistinguishable from real,
and D can't tell the difference (outputs 0.5 for everything).
The pluggable components are:
1. generator_loss() — how G is penalised (minimax, non-saturating, Wasserstein)
2. discriminator_loss() — how D is penalised (BCE, hinge, Wasserstein)
3. regularisation() — gradient penalty, spectral norm, R1, ...
4. conditioning — whether/how class labels or inputs are injected
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from abc import ABC, abstractmethod
# ═══════════════════════════════════════════════════════════════════
# CORE ALGORITHM (the part that NEVER changes)
# ═══════════════════════════════════════════════════════════════════
#
# Every GAN variant follows this exact two-step loop:
#
# Step 1: Update D — make it better at telling real from fake
# Step 2: Update G — make it better at fooling D
#
# The adversarial dynamic IS the algorithm. Unlike every other
# file in this series, there are TWO networks with OPPOSING losses
# trained in alternation. This is what makes GANs both powerful
# and unstable.
class GANAlgorithm(ABC):
    """
    Skeleton for adversarial training shared by every GAN variant.

    A subclass supplies only the loss surfaces:
      * discriminator_loss(real_scores, fake_scores) -> scalar tensor
      * generator_loss(fake_scores)                  -> scalar tensor
      * discriminator_regularisation(D, real, fake)  -> scalar tensor (optional)

    Everything else — noise sampling, the alternating two-phase update,
    and inference-time sampling — lives here and is identical across
    variants.
    """
    def __init__(self, G, D, g_optimizer, d_optimizer,
                 d_latent=128, n_d_steps=1):
        self.G = G                    # generator: z -> fake data
        self.D = D                    # discriminator / critic: data -> score
        self.g_optimizer = g_optimizer
        self.d_optimizer = d_optimizer
        self.d_latent = d_latent      # dimensionality of the latent prior
        # Discriminator updates per generator update
        # (1 for most variants; WGAN-style critics use 5).
        self.n_d_steps = n_d_steps
    # ── Pluggable loss surfaces ───────────────────────────────────
    @abstractmethod
    def discriminator_loss(self, real_scores, fake_scores):
        """Scalar that the discriminator step MINIMISES."""
        ...
    @abstractmethod
    def generator_loss(self, fake_scores):
        """Scalar that the generator step MINIMISES."""
        ...
    def discriminator_regularisation(self, D, real_data, fake_data):
        """Extra penalty added to D's loss (gradient penalty, R1, ...).

        Default is a zero tensor, i.e. no regularisation.
        """
        return torch.tensor(0.0, device=real_data.device)
    def sample_noise(self, batch_size, device):
        """Draw a batch of latents z ~ N(0, I); override for other priors."""
        return torch.randn(batch_size, self.d_latent, device=device)
    # ── Alternating update (shared by every variant) ──────────────
    def train_step(self, real_data):
        """Run n_d_steps discriminator updates, then one generator update.

        Returns a dict of scalar metrics for logging.
        """
        device = real_data.device
        batch = real_data.size(0)
        # Phase 1: discriminator. Fakes are produced under no_grad so no
        # gradient ever reaches G from this phase.
        critic_loss_sum = 0.0
        for _ in range(self.n_d_steps):
            latent = self.sample_noise(batch, device)
            with torch.no_grad():
                synthetic = self.G(latent)
            scores_real = self.D(real_data)
            scores_fake = self.D(synthetic)
            loss_d = self.discriminator_loss(scores_real, scores_fake)
            penalty = self.discriminator_regularisation(self.D, real_data, synthetic)
            total_d = loss_d + penalty
            self.d_optimizer.zero_grad()
            total_d.backward()
            self.d_optimizer.step()
            critic_loss_sum += total_d.item()
        # Phase 2: generator. Gradients flow through D into G, but D's own
        # parameters are untouched because only g_optimizer steps.
        latent = self.sample_noise(batch, device)
        synthetic = self.G(latent)
        scores_fake = self.D(synthetic)
        loss_g = self.generator_loss(scores_fake)
        self.g_optimizer.zero_grad()
        loss_g.backward()
        self.g_optimizer.step()
        return {"g_loss": loss_g.item(),
                "d_loss": critic_loss_sum / self.n_d_steps,
                "d_real": scores_real.mean().item(),
                "d_fake": scores_fake.mean().item()}
    # ── Inference ─────────────────────────────────────────────────
    @torch.no_grad()
    def generate(self, n_samples, device="cpu"):
        """Sample n_samples outputs from G without tracking gradients."""
        latent = self.sample_noise(n_samples, device)
        return self.G(latent)
# ═══════════════════════════════════════════════════════════════════
# TRAINING LOOP
# ═══════════════════════════════════════════════════════════════════
def train(algo: GANAlgorithm, dataloader, n_epochs, device="cpu"):
    """Generic epoch loop over any GANAlgorithm variant.

    Args:
        algo: the GAN variant whose ``train_step`` is called per batch.
        dataloader: yields ``(real, ...)`` tuples; extra items (labels) are
            ignored here — conditional variants have their own step methods.
        n_epochs: number of passes over the dataloader.
        device: device the real batches are moved to.

    Prints per-epoch averages of the metrics returned by ``train_step``,
    weighted by batch size.
    """
    for epoch in range(n_epochs):
        totals = {"g_loss": 0.0, "d_loss": 0.0, "d_real": 0.0, "d_fake": 0.0}
        n = 0
        for real, *_ in dataloader:
            real = real.to(device)
            metrics = algo.train_step(real)
            batch_size = real.size(0)
            for k in totals:
                totals[k] += metrics[k] * batch_size
            n += batch_size
        # Guard: an empty dataloader previously raised ZeroDivisionError.
        if n == 0:
            print(f"Epoch {epoch+1:3d}/{n_epochs} │ no batches")
            continue
        avg = {k: v / n for k, v in totals.items()}
        print(f"Epoch {epoch+1:3d}/{n_epochs} │ "
              f"G {avg['g_loss']:.4f} "
              f"D {avg['d_loss']:.4f} "
              f"D(real) {avg['d_real']:.3f} "
              f"D(fake) {avg['d_fake']:.3f}")
# ═══════════════════════════════════════════════════════════════════
# VARIANT IMPLEMENTATIONS (only the parts that differ)
# ═══════════════════════════════════════════════════════════════════
# ── 1. Vanilla GAN (original minimax / non-saturating) ──────────
class VanillaGAN(GANAlgorithm):
    """
    Original GAN objective (Goodfellow et al. 2014).

    Discriminator maximises  E[log D(x)] + E[log(1 − D(G(z)))],
    implemented below as binary cross-entropy against targets 1 (real)
    and 0 (fake) on raw logits.

    Generator uses the NON-SATURATING form, minimising −E[log D(G(z))]
    rather than the minimax E[log(1 − D(G(z)))]: the minimax form is flat
    when D confidently rejects fakes, so early in training G would get
    almost no gradient exactly when it needs it most.
    """
    def discriminator_loss(self, real_scores, fake_scores):
        # D is a binary classifier on logits: real -> 1, fake -> 0.
        target_real = torch.ones_like(real_scores)
        target_fake = torch.zeros_like(fake_scores)
        loss_on_real = F.binary_cross_entropy_with_logits(real_scores, target_real)
        loss_on_fake = F.binary_cross_entropy_with_logits(fake_scores, target_fake)
        return loss_on_real + loss_on_fake
    def generator_loss(self, fake_scores):
        # Non-saturating: pretend fakes are real and take BCE against 1s,
        # which equals −E[log D(G(z))] on logits.
        target = torch.ones_like(fake_scores)
        return F.binary_cross_entropy_with_logits(fake_scores, target)
# ── 2. WGAN (Wasserstein distance, weight clipping) ─────────────
class WGAN(GANAlgorithm):
    """
    Wasserstein GAN: replaces the JS divergence with the Wasserstein
    (Earth Mover's) distance.

    D becomes a "critic" — it outputs an unbounded score (not a
    probability). The loss is simply the difference in mean scores:
        D maximises: E[D(x)] − E[D(G(z))]   (real scores > fake scores)
        G minimises: −E[D(G(z))]            (push fake scores up)

    Requires D to be Lipschitz-continuous, enforced here (as in the
    original paper) by clamping D's weights to [−c, c] after EACH critic
    update. Note: a previous version clipped only once per train_step,
    after all critic steps AND the generator step, so critic steps 2..n
    ran on unclipped weights — fixed by overriding the training loop and
    clipping inside it.
    """
    def __init__(self, *args, clip_value=0.01, **kw):
        kw.setdefault("n_d_steps", 5)  # WGAN uses 5 critic steps per G step
        super().__init__(*args, **kw)
        self.clip_value = clip_value   # weight-clamp bound c
    def discriminator_loss(self, real_scores, fake_scores):
        # Wasserstein critic: maximise E[D(x)] − E[D(G(z))];
        # minimise the negation.
        return fake_scores.mean() - real_scores.mean()
    def generator_loss(self, fake_scores):
        return -fake_scores.mean()
    def _clip_critic(self):
        """Clamp every critic weight to [−c, c] (crude Lipschitz bound)."""
        with torch.no_grad():
            for p in self.D.parameters():
                p.clamp_(-self.clip_value, self.clip_value)
    def train_step(self, real_data):
        """Same two-phase loop as the base class, but weights are clipped
        after every critic update (WGAN Algorithm 1), not once at the end.
        """
        device = real_data.device
        B = real_data.size(0)
        # ── Phase 1: critic, clipped after each step ──────────────
        d_loss_total = 0.0
        for _ in range(self.n_d_steps):
            z = self.sample_noise(B, device)
            with torch.no_grad():
                fake_data = self.G(z)  # don't track G grads
            real_scores = self.D(real_data)
            fake_scores = self.D(fake_data)
            d_loss = self.discriminator_loss(real_scores, fake_scores)
            d_reg = self.discriminator_regularisation(self.D, real_data, fake_data)
            d_loss_full = d_loss + d_reg
            self.d_optimizer.zero_grad()
            d_loss_full.backward()
            self.d_optimizer.step()
            self._clip_critic()  # FIX: enforce the constraint per critic step
            d_loss_total += d_loss_full.item()
        # ── Phase 2: generator ────────────────────────────────────
        z = self.sample_noise(B, device)
        fake_data = self.G(z)
        fake_scores = self.D(fake_data)  # grads flow to G
        g_loss = self.generator_loss(fake_scores)
        self.g_optimizer.zero_grad()
        g_loss.backward()
        self.g_optimizer.step()
        return {"g_loss": g_loss.item(),
                "d_loss": d_loss_total / self.n_d_steps,
                "d_real": real_scores.mean().item(),
                "d_fake": fake_scores.mean().item()}
# ── 3. WGAN-GP (gradient penalty instead of clipping) ───────────
class WGANGP(GANAlgorithm):
    """
    WGAN with Gradient Penalty.

    Same Wasserstein critic loss as WGAN, but the Lipschitz constraint is
    enforced softly: the critic's gradient norm is penalised toward 1 at
    random points x̂ = αx + (1−α)G(z) on segments between real and fake
    samples. Unlike weight clipping this leaves the critic's capacity
    intact and trains far more stably with deep discriminators.
    """
    def __init__(self, *args, gp_weight=10.0, **kw):
        kw.setdefault("n_d_steps", 5)  # WGAN-style: several critic steps per G step
        super().__init__(*args, **kw)
        self.gp_weight = gp_weight     # λ in the penalty term (paper default 10)
    def discriminator_loss(self, real_scores, fake_scores):
        # Negated Wasserstein estimate: minimising this maximises
        # E[D(x)] − E[D(G(z))].
        return fake_scores.mean() - real_scores.mean()
    def generator_loss(self, fake_scores):
        return -fake_scores.mean()
    def discriminator_regularisation(self, D, real_data, fake_data):
        """Two-sided gradient penalty λ·E[(‖∇D(x̂)‖₂ − 1)²]."""
        batch = real_data.size(0)
        # One mixing coefficient per sample, broadcast over all other dims.
        mix_shape = [batch] + [1] * (real_data.dim() - 1)
        mix = torch.rand(*mix_shape, device=real_data.device)
        blended = (mix * real_data + (1 - mix) * fake_data).requires_grad_(True)
        critic_out = D(blended)
        # create_graph=True so the penalty itself is differentiable and
        # contributes to D's parameter gradients.
        (grad_wrt_input,) = torch.autograd.grad(
            outputs=critic_out, inputs=blended,
            grad_outputs=torch.ones_like(critic_out),
            create_graph=True, retain_graph=True)
        norms = grad_wrt_input.reshape(batch, -1).norm(2, dim=1)
        return self.gp_weight * ((norms - 1) ** 2).mean()
# ── 4. Hinge GAN (spectral norm + hinge loss) ──────────────────
class HingeGAN(GANAlgorithm):
    """
    Hinge-loss GAN (SAGAN, BigGAN, StyleGAN-XL).

        D loss: E[max(0, 1 − D(x))] + E[max(0, 1 + D(G(z)))]
        G loss: −E[D(G(z))]

    Once D is confident past the margin (real scores above +1, fake
    scores below −1) its loss saturates, so D stops improving and cannot
    starve G of gradient. Usually combined with spectral normalisation
    on D's weights, applied at the architecture level (not shown here).
    """
    def discriminator_loss(self, real_scores, fake_scores):
        # Margin losses: only samples inside the margin contribute.
        penalty_real = F.relu(1 - real_scores).mean()
        penalty_fake = F.relu(1 + fake_scores).mean()
        return penalty_real + penalty_fake
    def generator_loss(self, fake_scores):
        # G simply pushes critic scores on fakes upward.
        return -fake_scores.mean()
# ── 5. Conditional GAN (cGAN: class-conditional generation) ─────
class ConditionalGAN(GANAlgorithm):
    """
    Conditional GAN: both G and D receive a condition (class label, text, ...).

        G(z, c) → fake:  "generate a cat" vs "generate a dog"
        D(x, c) → score: "is this a real cat?" not just "is this real?"

    The condition can be injected by concatenation, projection, or
    adaptive normalisation (class-conditional BatchNorm, as in BigGAN
    and StyleGAN). Conditioning is orthogonal to the loss choice, so this
    class wraps any base variant and delegates all three loss surfaces to
    it — including discriminator_regularisation, which was previously NOT
    delegated, silently dropping e.g. a wrapped WGANGP's gradient penalty.
    """
    def __init__(self, G, D, g_optimizer, d_optimizer,
                 base_loss: GANAlgorithm, cond_encoder=None, **kw):
        super().__init__(G, D, g_optimizer, d_optimizer, **kw)
        self.base = base_loss            # variant supplying the loss surfaces
        self.cond_encoder = cond_encoder  # optional embedding of raw conditions
    def discriminator_loss(self, real_scores, fake_scores):
        return self.base.discriminator_loss(real_scores, fake_scores)
    def generator_loss(self, fake_scores):
        return self.base.generator_loss(fake_scores)
    def discriminator_regularisation(self, D, real_data, fake_data):
        # FIX: delegate so the wrapped variant's penalty (gradient penalty,
        # R1, ...) still applies when the inherited unconditional
        # train_step is used.
        return self.base.discriminator_regularisation(D, real_data, fake_data)
    def train_step_conditional(self, real_data, conditions):
        """One conditional step: update D on (x, c), then G on (z, c).

        NOTE(review): base regularisers are not applied in this path
        because they call D(x) positionally while the conditional D takes
        (x, c) — TODO: extend if a conditional variant needs a penalty.
        """
        device = real_data.device
        B = real_data.size(0)
        if self.cond_encoder is not None:
            c = self.cond_encoder(conditions)
        else:
            c = conditions
        # ── Update D ──────────────────────────────────────────────
        z = self.sample_noise(B, device)
        with torch.no_grad():
            fake_data = self.G(z, c)  # detached: no G grads in D's step
        real_scores = self.D(real_data, c)
        fake_scores = self.D(fake_data, c)
        d_loss = self.discriminator_loss(real_scores, fake_scores)
        self.d_optimizer.zero_grad()
        d_loss.backward()
        self.d_optimizer.step()
        # ── Update G ──────────────────────────────────────────────
        z = self.sample_noise(B, device)
        fake_data = self.G(z, c)
        fake_scores = self.D(fake_data, c)  # grads flow to G
        g_loss = self.generator_loss(fake_scores)
        self.g_optimizer.zero_grad()
        g_loss.backward()
        self.g_optimizer.step()
        return {"g_loss": g_loss.item(), "d_loss": d_loss.item(),
                "d_real": real_scores.mean().item(),
                "d_fake": fake_scores.mean().item()}
# ── 6. Pix2Pix-style (paired image translation) ────────────────
class PairedTranslationGAN(GANAlgorithm):
    """
    Image-to-image translation with paired data (Pix2Pix).

    G takes an input IMAGE (not noise) and produces an output image:
        G(x_input) → x_output
    D judges (input, output) pairs — concatenated along the channel
    dimension — rather than outputs alone, so it can reject outputs that
    look realistic but ignore the input. (Typically a PatchGAN critic.)

    G's loss adds an L1 reconstruction term weighted by ``l1_weight`` so
    outputs stay faithful to the paired target.
    """
    def __init__(self, G, D, g_optimizer, d_optimizer,
                 l1_weight=100.0, **kw):
        super().__init__(G, D, g_optimizer, d_optimizer, **kw)
        self.l1_weight = l1_weight  # λ for the L1 term (Pix2Pix paper uses 100)
    def discriminator_loss(self, real_scores, fake_scores):
        """Standard BCE-on-logits GAN loss: real pairs → 1, fake pairs → 0."""
        real_loss = F.binary_cross_entropy_with_logits(
            real_scores, torch.ones_like(real_scores))
        fake_loss = F.binary_cross_entropy_with_logits(
            fake_scores, torch.zeros_like(fake_scores))
        return real_loss + fake_loss
    def generator_loss(self, fake_scores):
        """Non-saturating adversarial term for G."""
        return F.binary_cross_entropy_with_logits(
            fake_scores, torch.ones_like(fake_scores))
    def train_step_paired(self, input_data, target_data):
        """One paired step: update D on (input, output) pairs, then G with
        adversarial + L1 losses. Returns scalar metrics for logging.

        Assumes input/target/fake tensors are concatenable along dim=1
        (channel dimension of NCHW images) — TODO confirm with D's arch.
        (Also removed a previously unused ``device`` local.)
        """
        # ── Update D on (input, output) pairs ─────────────────────
        with torch.no_grad():
            fake_output = self.G(input_data)  # detached for D's step
        real_scores = self.D(torch.cat([input_data, target_data], dim=1))
        fake_scores = self.D(torch.cat([input_data, fake_output], dim=1))
        d_loss = self.discriminator_loss(real_scores, fake_scores)
        self.d_optimizer.zero_grad()
        d_loss.backward()
        self.d_optimizer.step()
        # ── Update G: adversarial + L1 reconstruction ─────────────
        fake_output = self.G(input_data)  # regenerated WITH grad tracking
        fake_scores = self.D(torch.cat([input_data, fake_output], dim=1))
        g_loss_adv = self.generator_loss(fake_scores)
        g_loss_l1 = F.l1_loss(fake_output, target_data)
        g_loss = g_loss_adv + self.l1_weight * g_loss_l1
        self.g_optimizer.zero_grad()
        g_loss.backward()
        self.g_optimizer.step()
        return {"g_loss": g_loss.item(), "d_loss": d_loss.item(),
                "g_l1": g_loss_l1.item()}