Gradient Penalty
Adds a penalty on the norm of the discriminator's gradients with respect to its inputs: $\lambda\,\mathbb{E}_{\hat{x}}\big[(\lVert\nabla_{\hat{x}} D(\hat{x})\rVert_2 - 1)^2\big]$. Enforces the Lipschitz constraint required by Wasserstein GANs without weight clipping. The defining feature of WGAN-GP.
Intuition
The Wasserstein distance requires the discriminator (critic) to be 1-Lipschitz — its output can't change faster than its input. Mathematically, this means $\lVert\nabla_x D(x)\rVert_2 \le 1$ everywhere. Rather than constraining the weights directly (clipping, spectral norm), gradient penalty enforces this constraint where it matters: in the input space, on the data the discriminator actually sees.
The penalty is computed on interpolated points — random points along straight lines between real and fake samples. The theory says the optimal WGAN critic has gradient norm exactly 1 almost everywhere along these lines, so the penalty targets $\lVert\nabla_{\hat{x}} D(\hat{x})\rVert_2 = 1$ rather than just $\lVert\nabla_{\hat{x}} D(\hat{x})\rVert_2 \le 1$. This two-sided penalty pushes gradients toward 1, not just below 1.
The cost: you must compute gradients of the discriminator output with respect to its input, which requires a second backward pass through the discriminator. This makes each step roughly 2-3x more expensive than standard training or spectral normalisation.
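That second backward pass can be seen directly in autograd. A minimal sketch with a hypothetical two-layer critic (Tanh is chosen here because a smooth activation keeps the double backward well defined):

```python
import torch
import torch.nn as nn
import torch.autograd as autograd

torch.manual_seed(0)

# Toy critic: 4-d input -> scalar score.
D = nn.Sequential(nn.Linear(4, 16), nn.Tanh(), nn.Linear(16, 1))

x = torch.randn(8, 4, requires_grad=True)
out = D(x)                                            # (8, 1)

# First backward pass: dD/dx, kept in the graph via create_graph=True.
grads = autograd.grad(out.sum(), x, create_graph=True)[0]  # (8, 4)
penalty = ((grads.norm(2, dim=1) - 1) ** 2).mean()

# Second backward pass: gradients of the penalty w.r.t. D's weights.
penalty.backward()
```

Both passes traverse the discriminator, which is where the extra cost comes from.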
Gradient penalty term:

$$\text{GP} = \lambda\,\mathbb{E}_{\hat{x}}\Big[\big(\lVert\nabla_{\hat{x}} D(\hat{x})\rVert_2 - 1\big)^2\Big]$$

where $\hat{x}$ is sampled along lines between real and fake data:

$$\hat{x} = \epsilon\, x_{\text{real}} + (1 - \epsilon)\, x_{\text{fake}}, \qquad \epsilon \sim U[0, 1]$$

Full WGAN-GP discriminator loss:

$$L_D = \mathbb{E}\big[D(x_{\text{fake}})\big] - \mathbb{E}\big[D(x_{\text{real}})\big] + \lambda\,\mathbb{E}_{\hat{x}}\Big[\big(\lVert\nabla_{\hat{x}} D(\hat{x})\rVert_2 - 1\big)^2\Big]$$

Typical $\lambda$: 10 (from the original paper, rarely changed).
R1 penalty (simplified variant, used in StyleGAN):

$$R_1 = \frac{\gamma}{2}\,\mathbb{E}_{x \sim p_{\text{real}}}\Big[\lVert\nabla_x D(x)\rVert_2^2\Big]$$

Applied only to real data, it penalises the squared gradient magnitude rather than its deviation from 1. Simpler, and sufficient for non-Wasserstein GANs.
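As a sketch of the R1 variant (the function name `r1_penalty` and the flattening convention are my own, not from a particular library), the term can be computed the same way as WGAN-GP's penalty, minus the interpolation and the "minus 1" target:

```python
import torch
import torch.nn as nn
import torch.autograd as autograd

def r1_penalty(discriminator, real, gamma=10.0):
    """R1 regulariser: gamma/2 * E[ ||dD/dx||^2 ] on real samples only."""
    real = real.detach().requires_grad_(True)
    d_out = discriminator(real)                       # (B, 1)
    grads = autograd.grad(
        outputs=d_out.sum(),
        inputs=real,
        create_graph=True,  # penalty must itself be differentiable
    )[0]                                              # (B, *)
    grads = grads.view(grads.size(0), -1)             # (B, flat_dim)
    return (gamma / 2) * grads.pow(2).sum(dim=1).mean()
```

Note there is no target norm: the regulariser simply pushes gradients on real data toward zero, which does not enforce a Lipschitz-1 constraint exactly.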
```python
import torch
import torch.autograd as autograd

def gradient_penalty(discriminator, real, fake, device, lambda_gp=10.0):
    """
    Compute WGAN-GP gradient penalty.
    real: (B, *) real samples
    fake: (B, *) generated samples (detached)
    """
    B = real.size(0)
    # Random interpolation factor per sample
    eps = torch.rand(B, *([1] * (real.dim() - 1)), device=device)  # (B, 1, ..., 1)
    interpolated = (eps * real + (1 - eps) * fake).requires_grad_(True)  # (B, *)

    d_out = discriminator(interpolated)  # (B, 1)

    # Compute gradients of D output w.r.t. interpolated input
    grads = autograd.grad(
        outputs=d_out,
        inputs=interpolated,
        grad_outputs=torch.ones_like(d_out),
        create_graph=True,  # need second-order gradients for backprop
        retain_graph=True,
    )[0]  # (B, *)

    grads = grads.view(B, -1)  # (B, flat_dim)
    penalty = ((grads.norm(2, dim=1) - 1) ** 2).mean()  # scalar
    return lambda_gp * penalty

# WARNING: create_graph=True is essential — without it, the penalty
# term won't receive gradients and the constraint won't be enforced.

# WARNING: Do NOT use batch normalisation in the discriminator with
# gradient penalty. BatchNorm creates dependencies between samples in
# a batch, making the per-sample gradient computation invalid.
# Use layer norm or instance norm instead.
```
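To show where the penalty sits in training, here is a hedged end-to-end sketch of one critic update. The critic architecture, batch shapes, and the random stand-ins for real and generated data are all illustrative; the interpolation and penalty follow the same recipe as the function above:

```python
import torch
import torch.nn as nn
import torch.autograd as autograd

torch.manual_seed(0)

# Toy critic; 'real' and 'fake' stand in for a data batch and G(z).detach().
D = nn.Sequential(nn.Linear(8, 32), nn.LeakyReLU(0.2), nn.Linear(32, 1))
opt_D = torch.optim.Adam(D.parameters(), lr=1e-4, betas=(0.0, 0.9))  # WGAN-GP paper settings

real = torch.randn(16, 8)
fake = torch.randn(16, 8).detach()

# Interpolate and compute the penalty.
eps = torch.rand(16, 1)
x_hat = (eps * real + (1 - eps) * fake).requires_grad_(True)
grads = autograd.grad(D(x_hat).sum(), x_hat, create_graph=True)[0]  # (16, 8)
gp = 10.0 * ((grads.norm(2, dim=1) - 1) ** 2).mean()

# Full WGAN-GP critic loss: E[D(fake)] - E[D(real)] + lambda * GP
loss_D = D(fake).mean() - D(real).mean() + gp

opt_D.zero_grad()
loss_D.backward()
opt_D.step()
```

In a real loop this update typically runs several times per generator step, with `real` and `fake` drawn fresh each time.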
Manual Implementation

```python
import numpy as np

def gradient_penalty_manual(D_fn, D_grad_fn, real, fake, lambda_gp=10.0):
    """
    Compute WGAN-GP gradient penalty (forward only, no autograd).
    D_fn: discriminator forward function (x -> scalar per sample)
    D_grad_fn: function returning dD/dx at a given x (B, d) -> (B, d)
    real: (B, d) real samples
    fake: (B, d) fake samples
    """
    B, d = real.shape
    eps = np.random.uniform(0, 1, size=(B, 1))       # (B, 1)
    interpolated = eps * real + (1 - eps) * fake     # (B, d)

    grads = D_grad_fn(interpolated)                  # (B, d)
    grad_norms = np.sqrt((grads ** 2).sum(axis=1) + 1e-12)  # (B,)
    penalty = ((grad_norms - 1) ** 2).mean()         # scalar
    return lambda_gp * penalty
```
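The forward-only form is easy to sanity-check in closed form: for a linear critic $D(x) = w^\top x$ the gradient is $w$ at every point, so the penalty must come out to exactly $\lambda(\lVert w\rVert_2 - 1)^2$ no matter where the interpolates fall. A sketch of that check (the linear critic is a hypothetical example):

```python
import numpy as np

rng = np.random.default_rng(0)
B, d = 32, 5
w = rng.normal(size=d)               # linear critic D(x) = w @ x
real = rng.normal(size=(B, d))
fake = rng.normal(size=(B, d))

# For a linear critic, dD/dx = w everywhere, regardless of the interpolate.
eps = rng.uniform(0, 1, size=(B, 1))
interpolated = eps * real + (1 - eps) * fake          # (B, d)
grads = np.tile(w, (B, 1))                            # (B, d)
grad_norms = np.sqrt((grads ** 2).sum(axis=1))        # (B,), all equal ||w||
penalty = 10.0 * ((grad_norms - 1) ** 2).mean()

# Closed-form value the penalty must match.
expected = 10.0 * (np.linalg.norm(w) - 1) ** 2
```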
Popular Uses

- WGAN-GP (Gulrajani et al. 2017): the defining application — replaced weight clipping as the standard way to enforce the Lipschitz constraint
- Progressive GAN (Karras et al. 2018): used GP for stable training during progressive resolution increases
- R1 penalty in StyleGAN (Karras et al. 2019): simplified gradient penalty on real data only, became standard for style-based generators
- Wasserstein autoencoders (WAE): gradient penalty on the discriminator in the latent space
- Domain adaptation: gradient penalty on domain discriminators to enforce smooth class boundaries
Alternatives
| Alternative | When to use | Tradeoff |
|---|---|---|
| Spectral normalisation | Default for GAN discriminators | Much cheaper (no second backward pass); constrains weight space not input space |
| Weight clipping | Never (superseded) | Causes capacity underuse, vanishing/exploding gradients; strictly worse |
| R1 penalty | Non-Wasserstein GANs (StyleGAN) | Simpler — only penalises on real data, no interpolation; doesn’t enforce Lipschitz-1 exactly |
| Dragan penalty | Alternative gradient regularisation | Penalises around real data only with added noise; less theoretically motivated but sometimes effective |
| Consistency regularisation | When you want smooth D outputs for similar inputs | Penalises output differences directly rather than gradients; no second backward pass |
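For comparison, spectral normalisation (the usual cheaper default from the table above) needs only a wrapper around each layer: there is no penalty term and no second backward pass. A minimal sketch using PyTorch's built-in `spectral_norm` wrapper, with an illustrative toy critic:

```python
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

# Each wrapped layer has its weight's spectral norm driven toward 1
# by power iteration during forward passes in training mode.
D = nn.Sequential(
    spectral_norm(nn.Linear(8, 32)),
    nn.LeakyReLU(0.2),
    spectral_norm(nn.Linear(32, 1)),
)

x = torch.randn(4, 8)
out = D(x)  # (4, 1); train with the plain GAN/WGAN loss, no GP term
```

The constraint lives in weight space rather than input space, which is exactly the tradeoff the table describes.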
Historical Context
Gradient penalty was introduced by Gulrajani et al. (2017, “Improved Training of Wasserstein GANs”) as a direct fix for the problems with weight clipping in the original WGAN (Arjovsky et al. 2017). Weight clipping caused the critic to learn very simple functions (weights pushed to the clip boundary), wasting model capacity. Gradient penalty solved this elegantly by enforcing the constraint in input space rather than weight space.
The paper’s choice of $\lambda = 10$ and interpolation between real and fake samples became standard without much subsequent tuning. However, the computational cost — roughly 2-3x per discriminator step due to the second backward pass — motivated the development of cheaper alternatives. Spectral normalisation (Miyato et al. 2018) largely replaced gradient penalty for standard GAN training, while the simpler R1 penalty (Mescheder et al. 2018) became the default for StyleGAN-family models. Gradient penalty remains important conceptually as the clearest implementation of input-space Lipschitz regularisation.