"""
Unified Policy Gradient Algorithm
===================================
A single skeleton that covers REINFORCE, REINFORCE with a learned
baseline, A2C (advantage actor-critic with GAE), and PPO with the
clipped surrogate objective. An entropy bonus is built into the
shared update step, so every variant gets it.
The core idea shared by ALL policy gradient methods:
∇J(θ) ≈ E[ Ψ · ∇log π(a|s) ]
where Ψ is some measure of "how good was this action."
Every variant just changes what Ψ is and how the gradient
is used.
The pluggable components are:
1. compute_advantages() — what Ψ is (returns, advantages, GAE, ...)
2. policy_loss() — how the gradient signal is shaped
(vanilla, clipped surrogate, ...)
3. value_loss() — how the value baseline is trained (if any)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from abc import ABC, abstractmethod
from typing import NamedTuple
# ─── Shared Data Types ────────────────────────────────────────────
class Rollout(NamedTuple):
"""A batch of complete trajectories collected by the current policy."""
states: torch.Tensor # (T, *state_shape)
actions: torch.Tensor # (T,) or (T, action_dim)
rewards: torch.Tensor # (T,)
dones: torch.Tensor # (T,) — 1.0 at episode boundaries
log_probs: torch.Tensor # (T,) — log π_old(a|s) at collection time
values: torch.Tensor # (T,) — V(s) estimates (zeros if no baseline)
# ─── Utility: Generalised Advantage Estimation (GAE) ──────────────
#
# GAE is the standard way to compute advantages in modern policy
# gradient methods. It interpolates between:
# λ=0 → TD(0) advantage: A = r + γV(s') - V(s) (low variance, high bias)
# λ=1 → Monte Carlo: A = G_t - V(s) (high variance, unbiased)
#
# Common defaults are λ=0.95 and γ=0.99.
def compute_gae(rewards, values, dones, gamma=0.99, lam=0.95):
T = len(rewards)
advantages = torch.zeros(T)
gae = 0.0
for t in reversed(range(T)):
        # Bootstrap with 0 past the end of the rollout. This is a simplification:
        # if the rollout was truncated mid-episode, bootstrapping with V(s_T)
        # would be more accurate.
        next_val = values[t + 1] if t + 1 < T else 0.0
delta = rewards[t] + gamma * next_val * (1 - dones[t]) - values[t]
gae = delta + gamma * lam * (1 - dones[t]) * gae
advantages[t] = gae
returns = advantages + values # A = R - V, so R = A + V
return advantages, returns
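# Sanity check (illustrative only): with a zero value function, lam=1.0, and no
# episode boundaries, GAE collapses to the plain discounted return, matching the
# λ=1 case described above.
# >>> adv, ret = compute_gae(torch.tensor([1., 1., 1.]),
# ...                        torch.tensor([0., 0., 0.]),
# ...                        torch.tensor([0., 0., 0.]), gamma=0.5, lam=1.0)
# >>> adv
# tensor([1.7500, 1.5000, 1.0000])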
# ─── Utility: Simple discounted returns (no baseline) ─────────────
def compute_discounted_returns(rewards, dones, gamma=0.99):
T = len(rewards)
returns = torch.zeros(T)
R = 0.0
for t in reversed(range(T)):
R = rewards[t] + gamma * R * (1 - dones[t])
returns[t] = R
return returns
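# Sanity check (illustrative only): with rewards [1, 1, 1], no episode
# boundaries, and gamma=0.5, the returns are
#   [1 + 0.5·1 + 0.25·1, 1 + 0.5·1, 1] = [1.75, 1.5, 1.0],
# the same numbers as the GAE check above (λ=1, zero values).
# >>> compute_discounted_returns(torch.tensor([1., 1., 1.]),
# ...                            torch.tensor([0., 0., 0.]), gamma=0.5)
# tensor([1.7500, 1.5000, 1.0000])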
# ═══════════════════════════════════════════════════════════════════
# CORE ALGORITHM (the part that NEVER changes)
# ═══════════════════════════════════════════════════════════════════
class PolicyGradient(ABC):
"""
The universal policy gradient training step.
Every variant inherits this and only overrides:
- compute_advantages(rollout) -> (advantages, returns)
- policy_loss(log_probs, old_log_probs, advantages) -> Tensor
- value_loss(values, returns) -> Tensor (optional)
"""
def __init__(self, policy, optimizer, entropy_coeff=0.01,
value_net=None, value_optimizer=None):
self.policy = policy
self.optimizer = optimizer
self.entropy_coeff = entropy_coeff
self.value_net = value_net
self.value_optimizer = value_optimizer
# ── The three pluggable pieces ────────────────────────────────
@abstractmethod
def compute_advantages(self, rollout: Rollout):
"""Return (advantages, returns) tensors of shape (T,)."""
...
@abstractmethod
def policy_loss(self, log_probs, old_log_probs, advantages):
"""Return the scalar policy loss to minimise."""
...
def value_loss(self, values, returns):
"""Default: MSE. Override if you want clipped value loss, etc."""
return F.mse_loss(values, returns)
# ── Core update step (IDENTICAL for every variant) ────────────
def update(self, rollout: Rollout):
advantages, returns = self.compute_advantages(rollout)
# Normalise advantages (nearly universal, stabilises training)
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
# ── Policy update ─────────────────────────────────────────
dist = self.policy(rollout.states) # action distribution
log_probs = dist.log_prob(rollout.actions) # (T,)
entropy = dist.entropy().mean() # scalar
loss_pi = self.policy_loss(log_probs, rollout.log_probs, advantages)
loss_pi = loss_pi - self.entropy_coeff * entropy # encourage exploration
self.optimizer.zero_grad()
loss_pi.backward()
nn.utils.clip_grad_norm_(self.policy.parameters(), max_norm=0.5)
self.optimizer.step()
# ── Value update (if we have a value network) ─────────────
loss_v = torch.tensor(0.0)
if self.value_net is not None and self.value_optimizer is not None:
v = self.value_net(rollout.states).squeeze(-1) # (T,)
loss_v = self.value_loss(v, returns.detach())
self.value_optimizer.zero_grad()
loss_v.backward()
nn.utils.clip_grad_norm_(self.value_net.parameters(), max_norm=0.5)
self.value_optimizer.step()
return {"policy_loss": loss_pi.item(),
"value_loss": loss_v.item(),
"entropy": entropy.item()}
# ═══════════════════════════════════════════════════════════════════
# TRAINING LOOP (collect rollout → update → repeat)
# ═══════════════════════════════════════════════════════════════════
#
# Key difference from Q-learning: policy gradient methods are
# ON-POLICY — you must collect fresh data with the CURRENT policy,
# use it for one (or a few) updates, then throw it away.
# (This is a large part of why they are typically less sample-efficient than
# off-policy methods such as Q-learning.)
def collect_rollout(policy, value_net, env, n_steps, device="cpu"):
"""Roll out the current policy in the environment for n_steps."""
states, actions, rewards, dones, log_probs, values = [], [], [], [], [], []
s, _ = env.reset()
for _ in range(n_steps):
s_t = torch.tensor(s, dtype=torch.float32, device=device)
with torch.no_grad():
dist = policy(s_t.unsqueeze(0))
a = dist.sample()
lp = dist.log_prob(a)
            v = (value_net(s_t.unsqueeze(0)).squeeze()
                 if value_net is not None else torch.tensor(0.0))
s_next, r, done, trunc, _ = env.step(a.squeeze(0).cpu().numpy())
states.append(s_t)
actions.append(a.squeeze(0))
rewards.append(r)
dones.append(float(done or trunc))
log_probs.append(lp.squeeze(0))
values.append(v)
s = s_next
if done or trunc:
s, _ = env.reset()
return Rollout(
states=torch.stack(states),
actions=torch.stack(actions),
        rewards=torch.tensor(rewards, dtype=torch.float32),
        dones=torch.tensor(dones, dtype=torch.float32),
log_probs=torch.stack(log_probs),
values=torch.stack(values),
)
def train(algo: PolicyGradient, policy, value_net, env,
n_iters=500, rollout_len=2048, device="cpu"):
for i in range(n_iters):
rollout = collect_rollout(policy, value_net, env, rollout_len, device)
metrics = algo.update(rollout)
if (i + 1) % 10 == 0:
print(f"Iter {i+1:4d} │ "
f"π loss {metrics['policy_loss']:+.4f} "
f"V loss {metrics['value_loss']:.4f} "
f"entropy {metrics['entropy']:.3f}")
# ═══════════════════════════════════════════════════════════════════
# VARIANT IMPLEMENTATIONS (only the parts that differ)
# ═══════════════════════════════════════════════════════════════════
# ── 1. REINFORCE (vanilla, no baseline) ──────────────────────────
class REINFORCE(PolicyGradient):
"""
Ψ = G_t (discounted return from time t)
The simplest possible policy gradient. High variance because the
raw return includes reward from the entire episode.
"""
def __init__(self, policy, optimizer, gamma=0.99, **kw):
super().__init__(policy, optimizer, **kw)
self.gamma = gamma
def compute_advantages(self, rollout):
returns = compute_discounted_returns(rollout.rewards, rollout.dones, self.gamma)
# No baseline, so advantages = returns
return returns, returns
def policy_loss(self, log_probs, old_log_probs, advantages):
# Classic REINFORCE: −E[G_t · log π(a|s)]
return -(log_probs * advantages).mean()
# ── 2. REINFORCE with learned baseline ───────────────────────────
class REINFORCEBaseline(PolicyGradient):
"""
Ψ = G_t − V(s_t)
    Subtracting a baseline (the value function) does not change the
    expected gradient, provided the baseline depends only on the state
    and not the action, but it dramatically reduces variance: the signal
    becomes "how much BETTER was this action than average"
    rather than "how good was the entire trajectory."
"""
def __init__(self, policy, optimizer, gamma=0.99, **kw):
super().__init__(policy, optimizer, **kw)
self.gamma = gamma
def compute_advantages(self, rollout):
returns = compute_discounted_returns(rollout.rewards, rollout.dones, self.gamma)
advantages = returns - rollout.values
return advantages, returns
def policy_loss(self, log_probs, old_log_probs, advantages):
return -(log_probs * advantages.detach()).mean()
# ── 3. A2C (Advantage Actor-Critic) ─────────────────────────────
class A2C(PolicyGradient):
"""
Ψ = GAE(δ_t) where δ_t = r + γV(s') − V(s)
Uses GAE instead of full Monte Carlo returns for the advantage.
This reduces variance further at the cost of some bias (controlled
by λ). Otherwise identical to REINFORCE with baseline.
"""
def __init__(self, policy, optimizer, gamma=0.99, lam=0.95, **kw):
super().__init__(policy, optimizer, **kw)
self.gamma = gamma
self.lam = lam
def compute_advantages(self, rollout):
return compute_gae(rollout.rewards, rollout.values,
rollout.dones, self.gamma, self.lam)
def policy_loss(self, log_probs, old_log_probs, advantages):
return -(log_probs * advantages.detach()).mean()
# ── 4. PPO (Proximal Policy Optimisation, clipped) ──────────────
class PPO(PolicyGradient):
"""
Same advantages as A2C (GAE), but changes how the gradient is used.
Problem: vanilla policy gradient takes one big step, which can
destroy the policy (performance collapses and never recovers).
Solution: instead of −log π · A, use a clipped surrogate objective
that prevents the policy ratio π/π_old from moving too far from 1.
        L_clip = E[ min( ratio · A, clip(ratio, 1−ε, 1+ε) · A ) ]
    which is maximised as an objective; the code returns its negative as
    the loss to minimise.
This is the key insight of PPO: constrain the update size without
the complexity of TRPO's KL constraint.
"""
def __init__(self, policy, optimizer, gamma=0.99, lam=0.95,
clip_eps=0.2, n_policy_epochs=4, minibatch_size=64, **kw):
super().__init__(policy, optimizer, **kw)
self.gamma = gamma
self.lam = lam
self.clip_eps = clip_eps
self.n_policy_epochs = n_policy_epochs
self.minibatch_size = minibatch_size
def compute_advantages(self, rollout):
return compute_gae(rollout.rewards, rollout.values,
rollout.dones, self.gamma, self.lam)
def policy_loss(self, log_probs, old_log_probs, advantages):
ratio = (log_probs - old_log_probs).exp() # π_new / π_old
clipped = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps)
return -torch.min(ratio * advantages, clipped * advantages).mean()
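    # Intuition for the clip: with ε = 0.2 and A > 0, the objective stops
    # rewarding ratios above 1.2, so there is no gradient pushing π further
    # from π_old; with A < 0, ratios below 0.8 are similarly flat. The min()
    # always keeps the pessimistic side, so clipping never hides a worsening.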
# ── PPO overrides update() to do multiple epochs on the same data ──
def update(self, rollout: Rollout):
advantages, returns = self.compute_advantages(rollout)
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
# PPO reuses the same rollout for several epochs of minibatch updates.
# This is what makes it more sample-efficient than vanilla PG, while
# the clipping keeps each step safe.
T = len(rollout.states)
metrics = {"policy_loss": 0, "value_loss": 0, "entropy": 0}
n_updates = 0
for _ in range(self.n_policy_epochs):
indices = torch.randperm(T)
for start in range(0, T, self.minibatch_size):
idx = indices[start:start + self.minibatch_size]
dist = self.policy(rollout.states[idx])
log_probs = dist.log_prob(rollout.actions[idx])
entropy = dist.entropy().mean()
# Policy loss (clipped surrogate)
loss_pi = self.policy_loss(
log_probs, rollout.log_probs[idx], advantages[idx])
loss_pi = loss_pi - self.entropy_coeff * entropy
self.optimizer.zero_grad()
loss_pi.backward()
nn.utils.clip_grad_norm_(self.policy.parameters(), max_norm=0.5)
self.optimizer.step()
# Value loss
loss_v = torch.tensor(0.0)
if self.value_net is not None and self.value_optimizer is not None:
v = self.value_net(rollout.states[idx]).squeeze(-1)
loss_v = self.value_loss(v, returns[idx].detach())
self.value_optimizer.zero_grad()
loss_v.backward()
nn.utils.clip_grad_norm_(self.value_net.parameters(), max_norm=0.5)
self.value_optimizer.step()
metrics["policy_loss"] += loss_pi.item()
metrics["value_loss"] += loss_v.item()
metrics["entropy"] += entropy.item()
n_updates += 1
return {k: v / n_updates for k, v in metrics.items()}
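# ═══════════════════════════════════════════════════════════════════
# USAGE SKETCH
# ═══════════════════════════════════════════════════════════════════
#
# A minimal end-to-end wiring example. It assumes `gymnasium` is installed
# and uses the illustrative MLPPolicy / MLPValue sketches defined above;
# the environment choice, learning rates, and iteration counts are untuned
# placeholder values, not recommendations.
if __name__ == "__main__":
    import gymnasium as gym
    env = gym.make("CartPole-v1")
    state_dim = env.observation_space.shape[0]
    n_actions = env.action_space.n
    policy = MLPPolicy(state_dim, n_actions)
    value_net = MLPValue(state_dim)
    algo = PPO(
        policy,
        torch.optim.Adam(policy.parameters(), lr=3e-4),
        value_net=value_net,
        value_optimizer=torch.optim.Adam(value_net.parameters(), lr=1e-3),
    )
    train(algo, policy, value_net, env, n_iters=100, rollout_len=2048)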