# Replay Buffers for Q-Learning — Implementation

"""
Replay Buffers for Q-Learning
==============================
Standard 1-step and N-step replay buffers used by the online training loop.

# sidebar.label: Replay Buffer Implementation
"""

from collections import deque
from typing import Any, NamedTuple, Tuple

import torch


# ─── Shared Data Types ────────────────────────────────────────────

class Batch(NamedTuple):
    """A collated minibatch of transitions, one tensor per field.

    ``B`` below denotes the batch size.
    """

    s: torch.Tensor  # state:           (B, *state_shape)
    a: torch.Tensor  # action:          (B,) or (B, action_dim)
    r: torch.Tensor  # reward:          (B,)
    s_next: torch.Tensor  # next state: (B, *state_shape)
    done: torch.Tensor  # terminal mask (B,); 1.0 where the episode ended


# A single stored step, in field order: (state, action, reward, next_state, done).
Transition = Tuple[Any, int, float, Any, float]


# ─── Replay Buffer ───────────────────────────────────────────────

class ReplayBuffer:
    """Standard 1-step replay buffer backed by a fixed-size ring buffer.

    Transitions are appended until ``capacity`` is reached; after that each
    new transition overwrites the oldest stored one. ``sample`` draws
    indices uniformly at random, independently (i.e. with replacement).
    """

    def __init__(self, capacity: int) -> None:
        """Create an empty buffer holding at most ``capacity`` transitions.

        Raises:
            ValueError: If ``capacity`` is less than 1 (``add`` advances the
                write position modulo ``capacity``, which needs a positive
                modulus).
        """
        if capacity < 1:
            raise ValueError(f"capacity must be >= 1, got {capacity}")
        self.capacity = capacity
        self.buf: list[Transition] = []
        self.pos: int = 0  # slot written next; once full, this is the oldest entry

    def add(self, s: Any, a: int, r: float, s_next: Any, done: float) -> None:
        """Store one transition, overwriting the oldest one if the buffer is full."""
        transition = (s, a, r, s_next, done)
        if len(self.buf) < self.capacity:
            self.buf.append(transition)
        else:
            self.buf[self.pos] = transition
        self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size: int) -> Batch:
        """Return ``batch_size`` transitions drawn uniformly with replacement.

        States, next states, rewards and done flags are collated to float32
        tensors; actions are collated to ``torch.long``.

        Raises:
            ValueError: If ``batch_size`` is less than 1.
            RuntimeError: If the buffer is empty.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if not self.buf:
            # Fail with an explicit message rather than whatever torch.randint
            # reports for an empty sampling range.
            raise RuntimeError("cannot sample from an empty replay buffer")
        # .tolist() yields plain Python ints, so list indexing does not go
        # through a 0-d tensor for every element.
        idxs = torch.randint(len(self.buf), (batch_size,)).tolist()
        ss, aa, rr, ss_next, dd = zip(*(self.buf[i] for i in idxs))
        return Batch(
            s=torch.tensor(ss, dtype=torch.float32),
            a=torch.tensor(aa, dtype=torch.long),
            r=torch.tensor(rr, dtype=torch.float32),
            s_next=torch.tensor(ss_next, dtype=torch.float32),
            done=torch.tensor(dd, dtype=torch.float32),
        )

    def __len__(self) -> int:
        """Number of transitions currently stored (at most ``capacity``)."""
        return len(self.buf)


# ─── N-Step Replay Buffer ───────────────────────────────────────

class NStepReplayBuffer:
    """
    N-step replay buffer. Accumulates n-step discounted returns in a
    rolling window before storing the transition.

    Stores (s_0, a_0, r_nstep, s_n, done) where:
      r_nstep = r_0 + γ r_1 + γ² r_2 + ... + γ^(n-1) r_(n-1)
      s_n     = state n steps after s_0
    Truncates at episode boundaries: when a truthy ``done`` arrives, every
    pending transition is stored with the (shorter) return accumulated up to
    the terminal step, the terminal step's ``s_next`` and its ``done`` flag.

    Every stored transition whose done flag is 0.0 spans exactly ``n_step``
    environment steps, so the learner should discount its bootstrap term by
    ``gamma ** n_step``.

    NOTE(review): episode ends are detected only through ``done``. If the
    caller resets the environment without passing a truthy ``done`` (e.g. on
    time-limit truncation), the pending window will span two episodes —
    confirm the training loop handles this.
    """

    def __init__(self, capacity: int, n_step: int, gamma: float) -> None:
        """Create an empty n-step buffer.

        Args:
            capacity: Maximum number of stored n-step transitions; must be >= 1.
            n_step: Length of the return window; must be >= 1 (1 stores
                ordinary one-step transitions).
            gamma: Per-step discount factor applied inside the window.

        Raises:
            ValueError: If ``capacity`` or ``n_step`` is less than 1.
        """
        if capacity < 1:
            raise ValueError(f"capacity must be >= 1, got {capacity}")
        if n_step < 1:
            # With n_step=0 the window would be empty and _flush_one would
            # store a zero return bootstrapped from the newest state.
            raise ValueError(f"n_step must be >= 1, got {n_step}")
        self.capacity = capacity
        self.n_step = n_step
        self.gamma = gamma
        self.buf: list[Transition] = []
        self.pos: int = 0  # slot written next; once full, this is the oldest entry
        # Rolling window of not-yet-stored transitions, oldest first. A deque
        # makes dropping the oldest entry O(1).
        self.pending: deque[Transition] = deque()

    def add(self, s: Any, a: int, r: float, s_next: Any, done: float) -> None:
        """Queue one environment step and store any transitions it completes.

        A truthy ``done`` flushes the whole window; otherwise the oldest
        pending transition is stored once ``n_step`` steps are queued.
        """
        self.pending.append((s, a, r, s_next, done))

        # Flush at episode boundary: all pending transitions get truncated n-step returns
        if done:
            while self.pending:
                self._flush_one()
            return

        # Once we have n transitions queued, the oldest one is ready
        if len(self.pending) >= self.n_step:
            self._flush_one()

    def _flush_one(self) -> None:
        """Pop the oldest pending transition, compute its n-step return, and store it."""
        k = min(len(self.pending), self.n_step)

        s_0, a_0 = self.pending[0][0], self.pending[0][1]

        # Accumulate discounted return, stopping early if a done is encountered
        r_nstep = 0.0
        for i in range(k):
            _, _, r_i, s_i_next, done_i = self.pending[i]
            r_nstep += (self.gamma ** i) * r_i
            if done_i:
                # Episode ended at step i: use this as the terminal transition
                self._store(s_0, a_0, r_nstep, s_i_next, done_i)
                self.pending.popleft()
                return

        # No terminal state within the window: bootstrap from state k steps ahead.
        # This path is only reached with k == n_step, because a done-flush always
        # ends its window on the done step.
        s_k = self.pending[k - 1][3]  # s_next of the last transition in the window
        self._store(s_0, a_0, r_nstep, s_k, 0.0)
        self.pending.popleft()

    def _store(self, s: Any, a: int, r: float, s_next: Any, done: float) -> None:
        """Write a finished n-step transition into the ring buffer."""
        transition = (s, a, r, s_next, done)
        if len(self.buf) < self.capacity:
            self.buf.append(transition)
        else:
            self.buf[self.pos] = transition
        self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size: int) -> Batch:
        """Return ``batch_size`` stored transitions drawn uniformly with replacement.

        Raises:
            ValueError: If ``batch_size`` is less than 1.
            RuntimeError: If no transition has been stored yet.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if not self.buf:
            # Fail with an explicit message rather than whatever torch.randint
            # reports for an empty sampling range.
            raise RuntimeError("cannot sample from an empty replay buffer")
        # .tolist() yields plain Python ints for list indexing.
        idxs = torch.randint(len(self.buf), (batch_size,)).tolist()
        ss, aa, rr, ss_next, dd = zip(*(self.buf[i] for i in idxs))
        return Batch(
            s=torch.tensor(ss, dtype=torch.float32),
            a=torch.tensor(aa, dtype=torch.long),
            r=torch.tensor(rr, dtype=torch.float32),
            s_next=torch.tensor(ss_next, dtype=torch.float32),
            done=torch.tensor(dd, dtype=torch.float32),
        )

    def __len__(self) -> int:
        """Number of stored n-step transitions (pending ones are not counted)."""
        return len(self.buf)