"""
Unified Q-Learning Algorithm
=============================
A single skeleton that covers: Q-Learning, DQN, Double DQN, Dueling DQN,
CQL, IQL, SAC (Q-critic), and more.
The core loop is IDENTICAL across all variants. Only three pluggable
components change:
1. compute_target() — how the bootstrap target y is built
2. compute_loss() — the objective (MSE, Huber, + regularizers)
3. data source — online replay buffer vs. offline dataset
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional
from replay_buffers import Batch, ReplayBuffer, NStepReplayBuffer
# Batch is a NamedTuple of tensors, one row per sample in the mini-batch:
# s (B, *state_shape) — states
# a (B,) — actions taken
# r (B,) — rewards (1-step or n-step discounted return)
# s_next (B, *state_shape) — next states (1 or n steps ahead)
# done (B,) — 1.0 if episode ended, 0.0 otherwise
#
# A replay buffer exposes two methods:
# add(s, a, r, s_next, done) — store a single transition
# sample(batch_size) -> Batch — draw a random mini-batch for training
# See replay_buffers.py for ReplayBuffer (1-step) and NStepReplayBuffer.
# ═══════════════════════════════════════════════════════════════════
# CORE ALGORITHM (the part that NEVER changes)
# ═══════════════════════════════════════════════════════════════════
class QAlgorithm(ABC):
    """
    Shared skeleton for every Q-learning variant.

    Subclasses customise exactly two hooks:
      * compute_target(batch)              -- how the bootstrap target y is built
      * compute_loss(q_a, targets, batch)  -- the training objective

    Everything else (gradient step, target-network sync) is fixed here.
    """
    def __init__(self, Q: nn.Module, Q_target: nn.Module,
                 optimizer: torch.optim.Optimizer, gamma: float = 0.99,
                 target_update_freq: int = 1000, tau: Optional[float] = None) -> None:
        self.Q = Q                    # online network, trained every step
        self.Q_target = Q_target      # lagged copy used for bootstrap targets
        self.optimizer = optimizer
        self.gamma = gamma            # discount factor
        self.target_update_freq = target_update_freq  # steps between hard copies
        self.tau = tau                # soft-update coefficient; None => hard copy
        self._step: int = 0           # number of update() calls so far
    # ── The two pluggable pieces ──────────────────────────────────
    @abstractmethod
    def compute_target(self, batch: Batch) -> torch.Tensor:
        """Return the bootstrap target y, one scalar per batch row."""
        ...
    def compute_loss(self, q_a: torch.Tensor, targets: torch.Tensor,
                     batch: Batch) -> torch.Tensor:
        """Default TD objective (MSE). Variants override to add regularisers."""
        return F.mse_loss(q_a, targets)
    # ── Core update step (IDENTICAL for every variant) ────────────
    def update(self, batch: Batch) -> float:
        """Run one gradient step on `batch`; return the scalar loss value."""
        # Q(s, a) for the actions the agent actually took.             (B,)
        chosen_q = self.Q(batch.s).gather(1, batch.a.unsqueeze(-1)).squeeze(-1)
        # Bootstrap target: the variant-specific part, never differentiated.
        with torch.no_grad():
            y = self.compute_target(batch)                           # (B,)
        objective = self.compute_loss(chosen_q, y, batch)
        # Standard optimiser step.
        self.optimizer.zero_grad()
        objective.backward()
        self.optimizer.step()
        # Keep the target network in sync.
        self._step += 1
        self._update_target_network()
        return objective.item()
    def _update_target_network(self) -> None:
        """Soft (Polyak) update when tau is set; periodic hard copy otherwise."""
        if self.tau is None:
            # Hard copy every target_update_freq steps (DQN, CQL, ...).
            if self._step % self.target_update_freq == 0:
                self.Q_target.load_state_dict(self.Q.state_dict())
            return
        # Polyak averaging:  θ⁻ ← τ·θ + (1−τ)·θ⁻  (SAC, TD3, ...).
        for online_p, target_p in zip(self.Q.parameters(), self.Q_target.parameters()):
            target_p.data.copy_(self.tau * online_p.data + (1 - self.tau) * target_p.data)
# ═══════════════════════════════════════════════════════════════════
# ONLINE TRAINING LOOP (DQN-style: interact → store → sample → learn)
# ═══════════════════════════════════════════════════════════════════
def train_online(algo: QAlgorithm, env: Any,
                 replay_buffer: ReplayBuffer | NStepReplayBuffer,
                 n_steps: int, batch_size: int = 256, warmup: int = 1000,
                 eps_schedule: Optional[Callable[[int], float]] = None) -> None:
    """
    DQN-style online loop: interact with `env`, store transitions, and
    learn from replayed mini-batches once `warmup` transitions exist.

    Args:
        algo:          the Q-learning variant whose Q-net acts and learns.
        env:           Gymnasium-style env (reset() -> (obs, info),
                       step(a) -> (obs, r, terminated, truncated, info)).
        replay_buffer: buffer exposing add(...) / sample(batch_size) / __len__.
        n_steps:       total environment steps to run.
        batch_size:    mini-batch size per gradient update.
        warmup:        minimum buffer size before learning starts.
        eps_schedule:  step -> epsilon; defaults to a constant 0.1.
    """
    s, _ = env.reset()
    for t in range(n_steps):
        # ε-greedy (or swap in Boltzmann, UCB, etc.)
        eps = eps_schedule(t) if eps_schedule else 0.1
        if torch.rand(1).item() < eps:
            a = env.action_space.sample()
        else:
            with torch.no_grad():
                a = algo.Q(torch.tensor(s).unsqueeze(0)).argmax(-1).item()
        s_next, r, done, trunc, _ = env.step(a)
        # FIX: store only true termination as `done`. A time-limit truncation
        # is NOT a terminal state — the value of s_next is still nonzero, so
        # the TD target must keep bootstrapping through it. Conflating the
        # two (the old `done or trunc`) biases Q-values near the time limit.
        replay_buffer.add(s, a, r, s_next, float(done))
        s = s_next
        if done or trunc:
            # Truncation still ends the trajectory for data collection.
            s, _ = env.reset()
        # Learn from replay once enough transitions are stored.
        if len(replay_buffer) >= warmup:
            batch = replay_buffer.sample(batch_size)
            algo.update(batch)
# ═══════════════════════════════════════════════════════════════════
# OFFLINE TRAINING LOOP (CQL / IQL-style: just iterate over dataset)
# ═══════════════════════════════════════════════════════════════════
def train_offline(algo: QAlgorithm, dataset: Any, n_steps: int,
                  batch_size: int = 256) -> None:
    """
    CQL / IQL-style offline loop: no environment interaction — just draw
    random mini-batches from the static `dataset` and run `algo.update`.
    """
    for _ in range(n_steps):
        algo.update(dataset.sample(batch_size))
# ═══════════════════════════════════════════════════════════════════
# VARIANT IMPLEMENTATIONS (only the parts that differ)
# ═══════════════════════════════════════════════════════════════════
# ── 1. Vanilla DQN ───────────────────────────────────────────────
class DQN(QAlgorithm):
    """Vanilla DQN target: y = r + γ · max_a' Q_target(s', a')."""
    def compute_target(self, batch: Batch) -> torch.Tensor:
        # Greedy value of the next state under the frozen target net.   (B,)
        best_next = self.Q_target(batch.s_next).max(dim=-1).values
        # Terminal transitions (done == 1) contribute the reward alone.
        return batch.r + self.gamma * (1 - batch.done) * best_next
# ── 2. Double DQN ────────────────────────────────────────────────
class DoubleDQN(QAlgorithm):
    """
    Double DQN target: y = r + γ · Q_target(s', argmax_a' Q_online(s', a')).
    Decoupling action SELECTION (online net) from EVALUATION (target net)
    removes the maximisation bias of vanilla DQN.
    """
    def compute_target(self, batch: Batch) -> torch.Tensor:
        # Online net picks the greedy next action...                   (B, 1)
        greedy_a = self.Q(batch.s_next).argmax(dim=-1, keepdim=True)
        # ...target net scores it.                                      (B,)
        q_eval = self.Q_target(batch.s_next).gather(1, greedy_a).squeeze(-1)
        return batch.r + self.gamma * (1 - batch.done) * q_eval
# ── 3. CQL (Conservative Q-Learning, offline) ──────────────────
class CQL(QAlgorithm):
    """
    Conservative Q-Learning (offline).

    Uses the vanilla DQN bootstrap target, plus a regulariser that lowers
    Q on all actions (via logsumexp) while raising it on dataset actions:
        loss = TD_loss + α · [ log Σ_a exp Q(s,a) − Q(s, a_data) ]
    """
    def __init__(self, *args: Any, cql_alpha: float = 1.0, **kw: Any) -> None:
        super().__init__(*args, **kw)
        self.cql_alpha = cql_alpha  # weight of the conservative penalty
    def compute_target(self, batch: Batch) -> torch.Tensor:
        # Same bootstrap as vanilla DQN.
        next_best = self.Q_target(batch.s_next).max(-1).values
        return batch.r + self.gamma * (1 - batch.done) * next_best
    def compute_loss(self, q_a: torch.Tensor, targets: torch.Tensor,
                     batch: Batch) -> torch.Tensor:
        td_loss = F.mse_loss(q_a, targets)
        # Conservative term: the soft-max over ALL actions is pushed down,
        # while the Q-values of dataset actions are pushed back up.
        push_down = torch.logsumexp(self.Q(batch.s), dim=-1).mean()
        push_up = q_a.mean()
        return td_loss + self.cql_alpha * (push_down - push_up)
# ── 4. IQL (Implicit Q-Learning, offline) ──────────────────────
class IQL(QAlgorithm):
    """
    Implicit Q-Learning (offline).

    Never queries Q on out-of-distribution actions: a separate state-value
    net V(s) is fit by expectile regression toward Q(s, a_data), and V(s')
    replaces max_a' Q(s', a') in the bootstrap.

    Two losses per update:
      * V-loss : expectile regression of V(s) toward Q_target(s, a_data)
      * Q-loss : standard TD with V(s') as the target (inherited)
    """
    def __init__(self, *args: Any, V: nn.Module, v_optimizer: torch.optim.Optimizer,
                 expectile: float = 0.7, **kw: Any) -> None:
        super().__init__(*args, **kw)
        self.V = V
        self.v_optimizer = v_optimizer
        self.expectile = expectile  # τ ∈ (0.5, 1): biases V toward high Q
    def compute_target(self, batch: Batch) -> torch.Tensor:
        # Bootstrap from V(s') — no max over (possibly unseen) actions.
        v_next = self.V(batch.s_next).squeeze(-1)                    # (B,)
        return batch.r + self.gamma * (1 - batch.done) * v_next
    def update(self, batch: Batch) -> float:
        # 1. Fit V by asymmetric (expectile) regression toward Q_target(s, a).
        with torch.no_grad():
            q_data = (self.Q_target(batch.s)
                      .gather(1, batch.a.unsqueeze(-1)).squeeze(-1))
        residual = q_data - self.V(batch.s).squeeze(-1)
        # Weight τ where V under-estimates, (1−τ) where it over-estimates.
        tau_weight = torch.where(residual > 0, self.expectile, 1 - self.expectile)
        v_loss = (tau_weight * residual.pow(2)).mean()
        self.v_optimizer.zero_grad()
        v_loss.backward()
        self.v_optimizer.step()
        # 2. Standard Q update from the base class (uses compute_target above).
        return super().update(batch)
# ── 5. Soft Q-Learning / SAC critic ─────────────────────────────
class SoftQ(QAlgorithm):
    """
    Soft Q-learning / SAC critic target:
        y = r + γ · (Q_target(s', a') − α·log π(a'|s')),  a' ~ π(·|s')
    The −α·log π term is the entropy bonus that rewards stochastic policies.
    """
    def __init__(self, *args: Any, policy: Any, alpha: float = 0.2, **kw: Any) -> None:
        super().__init__(*args, **kw)
        self.policy = policy  # must expose sample(states) -> (actions, log_probs)
        self.alpha = alpha    # entropy temperature
    def compute_target(self, batch: Batch) -> torch.Tensor:
        # Sample next actions from the current policy.
        # NOTE(review): assumes log_prob comes back shaped (B,) — confirm
        # against the policy implementation.
        next_a, next_logp = self.policy.sample(batch.s_next)
        soft_q = (self.Q_target(batch.s_next)
                  .gather(1, next_a.unsqueeze(-1)).squeeze(-1))      # (B,)
        entropy_adjusted = soft_q - self.alpha * next_logp
        return batch.r + self.gamma * (1 - batch.done) * entropy_adjusted
# ── 6. N-Step DQN ──────────────────────────────────────────────
class NStepDQN(QAlgorithm):
    """
    N-step DQN target: y = r_nstep + γⁿ · max_a' Q_target(s_n, a')

    The n-step discounted return (batch.r) and the n-step-ahead next state
    (batch.s_next) are produced by NStepReplayBuffer, not here. This class
    only replaces the bootstrap discount γ with γⁿ.
    """
    def __init__(self, *args: Any, n_step: int = 3, **kw: Any) -> None:
        super().__init__(*args, **kw)
        self.n_step = n_step
        # Hoisted invariant: γⁿ never changes between updates, so compute
        # it once here instead of on every compute_target call.
        self._gamma_n = self.gamma ** n_step
    def compute_target(self, batch: Batch) -> torch.Tensor:
        # batch.r already holds the n-step discounted return; batch.s_next
        # is the state n steps ahead (see NStepReplayBuffer).
        q_next = self.Q_target(batch.s_next).max(dim=-1).values      # (B,)
        return batch.r + self._gamma_n * (1 - batch.done) * q_next