# Unified Transformer / Self-Attention Algorithm — Implementation

"""
Unified Transformer / Self-Attention Algorithm
================================================
A single file covering: vanilla multi-head attention, causal (GPT-style)
attention, cross-attention, multi-query attention (MQA), and grouped-query
attention (GQA). Plus the full transformer block and stack.

The core computation that NEVER changes:
  Attention(Q, K, V) = softmax( score(Q, K) + mask ) · V

Every variant only changes:
  1. how Q, K, V are projected    — (MHA vs MQA vs GQA)
  2. what mask is applied          — (none, causal, cross, padding)
  3. how position is encoded       — (sinusoidal, learned, RoPE, ALiBi)
  4. where the norm goes           — (pre-norm vs post-norm)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from abc import ABC, abstractmethod


# ═══════════════════════════════════════════════════════════════════
# CORE ATTENTION  (the part that NEVER changes)
# ═══════════════════════════════════════════════════════════════════
#
# Every attention variant ends up doing this:
#
#   1. Project inputs → Q, K, V         (PLUGGABLE: how many KV heads?)
#   2. scores = Q @ K^T / √d_k          (scaled dot-product)
#   3. scores += mask                    (PLUGGABLE: causal? padding?)
#   4. weights = softmax(scores)
#   5. output = weights @ V
#   6. Concatenate heads, project out
#
# Steps 2, 4, 5 are literally identical across ALL variants.

class AttentionBase(ABC, nn.Module):
    """
    Shared scaled-dot-product attention skeleton.

    Subclasses customise exactly two things: how K/V are projected
    (``_init_kv_projections``) and how the projected K/V are brought to
    per-head shape (``_reshape_kv``). Everything else — scaling, masking,
    softmax, value mixing, output projection — lives here and is identical
    across MHA / MQA / GQA.
    """

    def __init__(self, d_model, n_heads, dropout=0.0):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads          # per-head channel width
        self.dropout = nn.Dropout(dropout)

        # Queries are always projected per-head at full model width.
        self.W_q = nn.Linear(d_model, d_model)

        # K/V projections differ by variant — delegated to the subclass.
        self._init_kv_projections()

        # Final mixing of the concatenated heads.
        self.W_o = nn.Linear(d_model, d_model)

    @abstractmethod
    def _init_kv_projections(self):
        """Create self.W_k and self.W_v with appropriate dimensions."""
        ...

    @abstractmethod
    def _reshape_kv(self, k, v, B, T):
        """Reshape K, V to (B, n_heads, T, d_k), repeating if needed."""
        ...

    def forward(self, q_input, kv_input=None, mask=None):
        """
        Args:
            q_input:   (B, T_q, d_model)  — source of the queries
            kv_input:  (B, T_kv, d_model) — source of keys/values;
                       None means self-attention (keys = queries)
            mask:      (T_q, T_kv) or (B, 1, T_q, T_kv) additive mask;
                       -inf blocks a position, 0 allows it

        Returns:
            output of shape (B, T_q, d_model)
        """
        kv_input = q_input if kv_input is None else kv_input
        batch, len_q = q_input.shape[0], q_input.shape[1]
        len_kv = kv_input.size(1)

        # Step 1: project, then split the queries into heads.
        queries = (self.W_q(q_input)
                   .view(batch, len_q, self.n_heads, self.d_k)
                   .transpose(1, 2))                            # (B, H, T_q, d_k)
        keys, values = self._reshape_kv(
            self.W_k(kv_input), self.W_v(kv_input), batch, len_kv
        )                                                       # (B, H, T_kv, d_k)

        # Step 2: similarity logits, scaled by √d_k to keep variance ~1.
        logits = queries @ keys.transpose(-2, -1) / math.sqrt(self.d_k)

        # Step 3: additive mask (-inf entries vanish after the softmax).
        if mask is not None:
            logits = logits + mask

        # Steps 4–5: normalise, (optionally) drop, then mix the values.
        attn = self.dropout(F.softmax(logits, dim=-1))
        mixed = attn @ values                                   # (B, H, T_q, d_k)

        # Step 6: merge heads back together and project out.
        merged = mixed.transpose(1, 2).contiguous().view(batch, len_q, self.d_model)
        return self.W_o(merged)


# ═══════════════════════════════════════════════════════════════════
# ATTENTION VARIANTS  (only the KV projection / reshaping changes)
# ═══════════════════════════════════════════════════════════════════

# ── 1. Multi-Head Attention (MHA) — the original ─────────────────

class MultiHeadAttention(AttentionBase):
    """
    Classic multi-head attention: every head owns its own slice of the
    full-width K and V projections.
    KV parameter cost: 2 × d_model × d_model.
    """

    def _init_kv_projections(self):
        # Full-width projections, later split into n_heads slices.
        self.W_k = nn.Linear(self.d_model, self.d_model)
        self.W_v = nn.Linear(self.d_model, self.d_model)

    def _reshape_kv(self, k, v, B, T):
        def split(t):
            # (B, T, d_model) → (B, n_heads, T, d_k)
            return t.view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        return split(k), split(v)


# ── 2. Multi-Query Attention (MQA) ──────────────────────────────

class MultiQueryAttention(AttentionBase):
    """
    MQA: one shared K head and one shared V head serve every query head.
    KV parameter cost: 2 × d_model × d_k — n_heads× smaller than MHA.
    """

    def _init_kv_projections(self):
        # Single-head K and V, each of width d_k.
        self.W_k = nn.Linear(self.d_model, self.d_k)
        self.W_v = nn.Linear(self.d_model, self.d_k)

    def _reshape_kv(self, k, v, B, T):
        # Insert a singleton head axis, then broadcast it over all heads:
        # (B, T, d_k) → (B, 1, T, d_k) → (B, n_heads, T, d_k) view.
        target = (-1, self.n_heads, -1, -1)
        k = k.view(B, T, 1, self.d_k).transpose(1, 2).expand(*target)
        v = v.view(B, T, 1, self.d_k).transpose(1, 2).expand(*target)
        return k, v


# ── 3. Grouped-Query Attention (GQA) ────────────────────────────

class GroupedQueryAttention(AttentionBase):
    """
    GQA: the query heads are partitioned into groups and each group shares
    one KV head. n_kv_heads = n_heads recovers MHA; n_kv_heads = 1 recovers
    MQA. The scheme used by LLaMA 2 70B, Mistral, Gemma, etc.
    """

    def __init__(self, d_model, n_heads, n_kv_heads, **kw):
        assert n_heads % n_kv_heads == 0
        # Must be assigned before super().__init__(), which invokes
        # _init_kv_projections() and reads these fields.
        self.n_kv_heads = n_kv_heads
        self.n_groups = n_heads // n_kv_heads   # query heads per KV head
        super().__init__(d_model, n_heads, **kw)

    def _init_kv_projections(self):
        width = self.n_kv_heads * self.d_k
        self.W_k = nn.Linear(self.d_model, width)
        self.W_v = nn.Linear(self.d_model, width)

    def _reshape_kv(self, k, v, B, T):
        def spread(t):
            # (B, T, n_kv·d_k) → (B, n_kv, T, d_k), then replicate each
            # KV head n_groups times → (B, n_heads, T, d_k).
            t = t.view(B, T, self.n_kv_heads, self.d_k).transpose(1, 2)
            t = t.unsqueeze(2).expand(-1, -1, self.n_groups, -1, -1)
            return t.reshape(B, self.n_heads, T, self.d_k)
        return spread(k), spread(v)


# ═══════════════════════════════════════════════════════════════════
# MASKS  (the other axis of variation)
# ═══════════════════════════════════════════════════════════════════
#
# Masks are additive: 0 = attend, -inf = block.
# They're independent of the attention variant — any mask works
# with any of MHA/MQA/GQA.

def causal_mask(T, device="cpu"):
    """Additive causal mask: 0 on/below the diagonal, -inf strictly above."""
    blocked = torch.full((T, T), float("-inf"), device=device)
    # triu(diagonal=1) keeps only the strictly-upper triangle of -inf;
    # everything else becomes 0 (attend).
    return blocked.triu(diagonal=1)

def padding_mask(lengths, max_T, device="cpu"):
    """
    Additive padding mask for a batch of variable-length sequences.

    Args:
        lengths: (B,) integer tensor — number of real tokens per sequence.
        max_T:   padded sequence length.
        device:  target device for the mask.

    Returns:
        (B, 1, 1, max_T) float mask: 0 at real tokens, -inf at padding.
        Broadcasts over heads and query positions when added to scores.

    BUG FIX: the previous implementation computed ``bool.float() * -inf``,
    which evaluates to ``0.0 * -inf = NaN`` at every ALLOWED position and
    poisoned the downstream softmax. ``masked_fill`` keeps allowed entries
    at exactly 0.
    """
    positions = torch.arange(max_T, device=device).unsqueeze(0)   # (1, T)
    is_pad = positions >= lengths.unsqueeze(1).to(device)         # (B, T) bool
    mask = torch.zeros(is_pad.shape, dtype=torch.float32, device=device)
    mask = mask.masked_fill(is_pad, float("-inf"))
    return mask.unsqueeze(1).unsqueeze(2)                         # (B, 1, 1, T)


# ═══════════════════════════════════════════════════════════════════
# POSITIONAL ENCODING  (how the model knows token order)
# ═══════════════════════════════════════════════════════════════════
#
# Attention is permutation-equivariant — without positional info,
# "the cat sat on the mat" = "mat the on sat cat the."
# Position must be injected. Common approaches:
#
#   • Sinusoidal   — fixed, no learned params. Original transformer.
#   • Learned      — a trainable embedding per position. GPT-2.
#   • RoPE         — rotates Q, K by position-dependent angles.
#                    Encodes RELATIVE position via rotation.
#                    Used in LLaMA, Mistral, GPT-NeoX, most modern LLMs.
#   • ALiBi        — adds a linear bias to attention scores based on
#                    distance. No modification to Q, K. Used in BLOOM.

class SinusoidalPE(nn.Module):
    """Fixed sin/cos positional encoding from 'Attention Is All You Need'."""

    def __init__(self, d_model, max_len=8192):
        super().__init__()
        positions = torch.arange(max_len, dtype=torch.float32)
        # Geometric frequency ladder: exp(-2i·ln(10000)/d_model).
        inv_freq = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / d_model)
        )
        angles = torch.outer(positions, inv_freq)     # (max_len, d_model/2)
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = angles.sin()                 # even channels ← sin
        table[:, 1::2] = angles.cos()                 # odd channels  ← cos
        # (1, max_len, d_model): moves with the module but is not trained.
        self.register_buffer("pe", table.unsqueeze(0))

    def forward(self, x):
        # Add the table prefix that matches the input's sequence length.
        return x + self.pe[:, :x.size(1)]


class LearnedPE(nn.Module):
    """Trainable absolute position embeddings, as in GPT-2."""

    def __init__(self, d_model, max_len=8192):
        super().__init__()
        # One learned d_model-vector per absolute position.
        self.pe = nn.Embedding(max_len, d_model)

    def forward(self, x):
        seq_len = x.size(1)
        idx = torch.arange(seq_len, device=x.device)
        # Broadcasts (T, d_model) over the batch dimension of x.
        return x + self.pe(idx)


class RoPE(nn.Module):
    """
    Rotary position embedding. Rotates (even, odd) channel pairs of Q and K
    by position-dependent angles AFTER projection and BEFORE the dot
    product; q·k then depends only on the relative offset m − n.
    """

    def __init__(self, d_k, max_len=8192):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, d_k, 2).float() / d_k))
        angles = torch.outer(torch.arange(max_len).float(), inv_freq)
        # Stored as (1, 1, max_len, d_k/2) so they broadcast over (B, H).
        self.register_buffer("cos", angles.cos()[None, None])
        self.register_buffer("sin", angles.sin()[None, None])

    def forward(self, x):
        """x: (B, n_heads, T, d_k) → same shape, channel pairs rotated."""
        T = x.size(2)
        even, odd = x[..., 0::2], x[..., 1::2]
        c, s = self.cos[:, :, :T], self.sin[:, :, :T]
        # 2-D rotation of each (even, odd) pair, re-interleaved by flatten.
        rotated = torch.stack((even * c - odd * s,
                               even * s + odd * c), dim=-1)
        return rotated.flatten(-2)


# ═══════════════════════════════════════════════════════════════════
# TRANSFORMER BLOCK  (attention + FFN + residual + norm)
# ═══════════════════════════════════════════════════════════════════
#
# The block structure is:
#   x → [Norm] → Attention → + residual → [Norm] → FFN → + residual
#
# Norm placement:
#   • Post-norm (original paper): LayerNorm(x + Sublayer(x))
#     Harder to train, needs warmup. Used in the original transformer.
#   • Pre-norm (modern default):  x + Sublayer(LayerNorm(x))
#     More stable, used by GPT-2+, LLaMA, etc.
#
# FFN variants:
#   • Standard:   Linear → ReLU → Linear
#   • GLU / SwiGLU: Linear_gate → SiLU ⊙ Linear_up → Linear_down
#     Used in LLaMA, PaLM, Mistral. Better quality for same param count.

class FeedForward(nn.Module):
    """Position-wise FFN: expand to d_ff, GELU, contract back to d_model."""

    def __init__(self, d_model, d_ff=None, dropout=0.0):
        super().__init__()
        d_ff = d_ff or 4 * d_model   # conventional 4× expansion by default
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),    # up-projection
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),    # down-projection
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # Applied independently at every position.
        return self.net(x)


class SwiGLUFFN(nn.Module):
    """Gated FFN (SwiGLU): down( silu(gate(x)) ⊙ up(x) ). LLaMA-style."""

    def __init__(self, d_model, d_ff=None, dropout=0.0):
        super().__init__()
        # The 2/3 factor keeps parameter count ≈ a standard 4× FFN,
        # since SwiGLU has three weight matrices instead of two.
        d_ff = d_ff or int(4 * d_model * 2 / 3)
        self.w_gate = nn.Linear(d_model, d_ff, bias=False)
        self.w_up   = nn.Linear(d_model, d_ff, bias=False)
        self.w_down = nn.Linear(d_ff, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        gate = F.silu(self.w_gate(x))
        hidden = gate * self.w_up(x)
        return self.dropout(self.w_down(hidden))


class TransformerBlock(nn.Module):
    """Attention + FFN with residuals; pre-norm (default) or post-norm."""

    def __init__(self, attention: AttentionBase, ffn: nn.Module,
                 d_model: int, dropout=0.0, pre_norm=True):
        super().__init__()
        self.attention = attention
        self.ffn = ffn
        self.norm1 = nn.LayerNorm(d_model)   # paired with the attention sublayer
        self.norm2 = nn.LayerNorm(d_model)   # paired with the FFN sublayer
        self.drop1 = nn.Dropout(dropout)
        self.drop2 = nn.Dropout(dropout)
        self.pre_norm = pre_norm

    def forward(self, x, mask=None, kv=None):
        if not self.pre_norm:
            # Post-norm (original transformer): normalise the residual sum.
            x = self.norm1(x + self.drop1(self.attention(x, kv, mask)))
            return self.norm2(x + self.drop2(self.ffn(x)))
        # Pre-norm (GPT-2+, LLaMA): normalise the sublayer's input instead.
        x = x + self.drop1(self.attention(self.norm1(x), kv, mask))
        return x + self.drop2(self.ffn(self.norm2(x)))


# ═══════════════════════════════════════════════════════════════════
# FULL TRANSFORMER  (stack of blocks + embeddings)
# ═══════════════════════════════════════════════════════════════════

class Transformer(nn.Module):
    """
    Full transformer stack: token embedding + positional encoding, a stack
    of TransformerBlocks, a final norm (pre-norm stacks only), and a
    weight-tied vocabulary head. The architecture is configured through
    the constructor (attention variant, FFN, positional encoding, norm
    placement).
    """

    def __init__(self, vocab_size, d_model, n_heads, n_layers,
                 d_ff=None, dropout=0.1, max_len=8192,
                 attn_cls=MultiHeadAttention, attn_kw=None,
                 ffn_cls=FeedForward, pos_enc_cls=LearnedPE,
                 pre_norm=True):
        super().__init__()

        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_enc = pos_enc_cls(d_model, max_len)
        self.drop = nn.Dropout(dropout)

        extra_kw = attn_kw or {}
        blocks = []
        for _ in range(n_layers):
            attn = attn_cls(d_model, n_heads, dropout=dropout, **extra_kw)
            blocks.append(TransformerBlock(
                attention=attn,
                ffn=ffn_cls(d_model, d_ff, dropout=dropout),
                d_model=d_model, dropout=dropout, pre_norm=pre_norm,
            ))
        self.layers = nn.ModuleList(blocks)

        # Pre-norm stacks need one last LayerNorm before the head.
        self.final_norm = nn.LayerNorm(d_model) if pre_norm else nn.Identity()
        self.head = nn.Linear(d_model, vocab_size, bias=False)

        # Weight tying: the output projection reuses the embedding matrix.
        self.head.weight = self.tok_emb.weight

    def forward(self, token_ids, mask=None):
        """
        Args:
            token_ids: (B, T) long tensor of token indices
            mask:      (T, T) or (B, 1, T, T) additive attention mask

        Returns:
            (B, T, vocab_size) logits for next-token prediction
        """
        h = self.drop(self.pos_enc(self.tok_emb(token_ids)))
        for block in self.layers:
            h = block(h, mask=mask)
        return self.head(self.final_norm(h))


# ═══════════════════════════════════════════════════════════════════
# EXAMPLE CONFIGURATIONS  (how real models map to these pieces)
# ═══════════════════════════════════════════════════════════════════

def gpt2_small():
    """GPT-2 Small (124M): MHA + learned PE + GELU FFN + pre-norm."""
    config = dict(
        vocab_size=50257, d_model=768, n_heads=12, n_layers=12,
        attn_cls=MultiHeadAttention, ffn_cls=FeedForward,
        pos_enc_cls=LearnedPE, pre_norm=True,
    )
    return Transformer(**config)

def llama_style():
    """
    LLaMA-flavoured configuration: GQA (16 query / 4 KV heads) + SwiGLU FFN
    + pre-norm.

    NOTE: real LLaMA applies RoPE to Q/K inside every attention layer;
    this Transformer only supports additive input-position encodings, so
    learned absolute positions are used as a stand-in. The previous
    ``lambda d, m: LearnedPE(d, m)`` wrapper was a no-op around the class
    itself and has been removed.
    """
    return Transformer(
        vocab_size=32000, d_model=1024, n_heads=16, n_layers=16,
        attn_cls=GroupedQueryAttention, attn_kw={"n_kv_heads": 4},
        ffn_cls=SwiGLUFFN,
        pos_enc_cls=LearnedPE,  # RoPE would require per-layer wiring
        pre_norm=True,
    )