"""
Unified Transformer / Self-Attention Algorithm
================================================
A single file covering: vanilla multi-head attention, causal (GPT-style)
attention, cross-attention, multi-query attention (MQA), and grouped-query
attention (GQA). Plus the full transformer block and stack.
The core computation that NEVER changes:
Attention(Q, K, V) = softmax( score(Q, K) + mask ) · V
Every variant only changes:
1. how Q, K, V are projected — (MHA vs MQA vs GQA)
2. what mask is applied — (none, causal, cross, padding)
3. how position is encoded — (sinusoidal, learned, RoPE, ALiBi)
4. where the norm goes — (pre-norm vs post-norm)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from abc import ABC, abstractmethod
# ═══════════════════════════════════════════════════════════════════
# CORE ATTENTION (the part that NEVER changes)
# ═══════════════════════════════════════════════════════════════════
#
# Every attention variant ends up doing this:
#
# 1. Project inputs → Q, K, V (PLUGGABLE: how many KV heads?)
# 2. scores = Q @ K^T / √d_k (scaled dot-product)
# 3. scores += mask (PLUGGABLE: causal? padding?)
# 4. weights = softmax(scores)
# 5. output = weights @ V
# 6. Concatenate heads, project out
#
# Steps 2, 4, 5 are literally identical across ALL variants.
class AttentionBase(ABC, nn.Module):
    """
    Shared scaled-dot-product attention engine.

    Concrete subclasses decide exactly two things: how K/V are projected
    (`_init_kv_projections`) and how the projected K/V are expanded to one
    tensor per query head (`_reshape_kv`). Everything else — Q projection,
    score computation, masking, softmax, and the output projection — lives
    here and is identical across MHA / MQA / GQA.
    """

    def __init__(self, d_model, n_heads, dropout=0.0):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.dropout = nn.Dropout(dropout)
        # Queries always get one d_k-sized projection per head.
        self.W_q = nn.Linear(d_model, d_model)
        # K/V projections vary by variant (full-width vs. shared heads).
        self._init_kv_projections()
        # Head outputs are concatenated then mixed by a final projection.
        self.W_o = nn.Linear(d_model, d_model)

    @abstractmethod
    def _init_kv_projections(self):
        """Create self.W_k and self.W_v with variant-specific widths."""
        ...

    @abstractmethod
    def _reshape_kv(self, k, v, B, T):
        """Expand K, V to (B, n_heads, T, d_k), repeating heads if shared."""
        ...

    def forward(self, q_input, kv_input=None, mask=None):
        """
        Run attention.

        Args:
            q_input: (B, T_q, d_model) — source of the queries.
            kv_input: (B, T_kv, d_model) — source of keys/values;
                defaults to ``q_input`` (self-attention).
            mask: additive mask broadcastable to (B, n_heads, T_q, T_kv);
                0 where attention is allowed, -inf where blocked.
        Returns:
            (B, T_q, d_model) attended output.
        """
        if kv_input is None:
            kv_input = q_input  # self-attention
        batch, q_len, _ = q_input.shape
        kv_len = kv_input.shape[1]

        # Project queries and split into heads: (B, H, T_q, d_k).
        queries = (self.W_q(q_input)
                   .view(batch, q_len, self.n_heads, self.d_k)
                   .transpose(1, 2))
        # Variant-specific K/V projection + expansion to (B, H, T_kv, d_k).
        keys, values = self._reshape_kv(
            self.W_k(kv_input), self.W_v(kv_input), batch, kv_len)

        # Scaled dot-product scores: (B, H, T_q, T_kv).
        scores = queries @ keys.transpose(-2, -1) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores + mask  # -inf entries vanish after softmax
        attn = self.dropout(F.softmax(scores, dim=-1))

        # Weighted sum over values, then merge heads and project out.
        context = attn @ values  # (B, H, T_q, d_k)
        context = context.transpose(1, 2).contiguous().view(batch, q_len, self.d_model)
        return self.W_o(context)
# ═══════════════════════════════════════════════════════════════════
# ATTENTION VARIANTS (only the KV projection / reshaping changes)
# ═══════════════════════════════════════════════════════════════════
# ── 1. Multi-Head Attention (MHA) — the original ─────────────────
class MultiHeadAttention(AttentionBase):
    """
    Classic multi-head attention (the original transformer form).

    Every query head owns a matching key head and value head, so the
    K and V projections are full d_model × d_model matrices
    (total KV params: 2 × d_model²).
    """

    def _init_kv_projections(self):
        # One d_k-wide slice per head, packed into a single matrix each.
        self.W_k = nn.Linear(self.d_model, self.d_model)
        self.W_v = nn.Linear(self.d_model, self.d_model)

    def _reshape_kv(self, k, v, B, T):
        # (B, T, d_model) → (B, n_heads, T, d_k): split the feature axis
        # into heads, then move heads ahead of the sequence axis.
        k = k.view(B, T, self.n_heads, self.d_k).permute(0, 2, 1, 3)
        v = v.view(B, T, self.n_heads, self.d_k).permute(0, 2, 1, 3)
        return k, v
# ── 2. Multi-Query Attention (MQA) ──────────────────────────────
class MultiQueryAttention(AttentionBase):
    """
    Multi-query attention: a single K head and a single V head are shared
    by every query head, shrinking the KV projections (and the KV cache)
    by a factor of n_heads. Total KV params: 2 × d_model × d_k.
    """

    def _init_kv_projections(self):
        # One shared head: project straight down to d_k.
        self.W_k = nn.Linear(self.d_model, self.d_k)
        self.W_v = nn.Linear(self.d_model, self.d_k)

    def _reshape_kv(self, k, v, B, T):
        # (B, T, d_k): insert a singleton head axis, then broadcast it
        # across all query heads (expand copies no memory).
        k = k.unsqueeze(1).expand(-1, self.n_heads, -1, -1)
        v = v.unsqueeze(1).expand(-1, self.n_heads, -1, -1)
        return k, v
# ── 3. Grouped-Query Attention (GQA) ────────────────────────────
class GroupedQueryAttention(AttentionBase):
    """
    Grouped-query attention: the query heads are partitioned into
    n_kv_heads groups, and each group shares one K/V head.

    n_kv_heads == n_heads recovers MHA; n_kv_heads == 1 recovers MQA.
    Used in LLaMA 2 70B, Mistral, Gemma, etc.
    """

    def __init__(self, d_model, n_heads, n_kv_heads, **kw):
        assert n_heads % n_kv_heads == 0
        # Must be set before super().__init__(): the base constructor
        # calls _init_kv_projections(), which reads these attributes.
        self.n_kv_heads = n_kv_heads
        self.n_groups = n_heads // n_kv_heads  # query heads per KV head
        super().__init__(d_model, n_heads, **kw)

    def _init_kv_projections(self):
        # Project down to n_kv_heads × d_k instead of the full d_model.
        kv_dim = self.n_kv_heads * self.d_k
        self.W_k = nn.Linear(self.d_model, kv_dim)
        self.W_v = nn.Linear(self.d_model, kv_dim)

    def _reshape_kv(self, k, v, B, T):
        def spread(t):
            # (B, T, n_kv × d_k) → (B, n_kv, T, d_k)
            t = t.view(B, T, self.n_kv_heads, self.d_k).transpose(1, 2)
            # Repeat each KV head for its group of query heads:
            # (B, n_kv, T, d_k) → (B, n_kv, n_groups, T, d_k) → (B, n_heads, T, d_k)
            t = t.unsqueeze(2).expand(-1, -1, self.n_groups, -1, -1)
            return t.reshape(B, self.n_heads, T, self.d_k)
        return spread(k), spread(v)
# ═══════════════════════════════════════════════════════════════════
# MASKS (the other axis of variation)
# ═══════════════════════════════════════════════════════════════════
#
# Masks are additive: 0 = attend, -inf = block.
# They're independent of the attention variant — any mask works
# with any of MHA/MQA/GQA.
def causal_mask(T, device="cpu"):
    """Additive causal mask: 0 on/below the diagonal, -inf strictly above.

    Each position may attend only to itself and earlier positions.
    """
    blocked = torch.full((T, T), float("-inf"), device=device)
    return torch.triu(blocked, diagonal=1)
def padding_mask(lengths, max_T, device="cpu"):
    """Build an additive padding mask for variable-length sequences.

    Args:
        lengths: (B,) int tensor of valid lengths per batch element.
        max_T: padded sequence length T.
        device: device of the returned tensor.

    Returns:
        (B, 1, 1, T) float tensor: 0 at valid positions, -inf at padding.

    Note: built with masked_fill rather than multiplying a 0/1 mask by
    -inf — IEEE arithmetic gives 0 * -inf = NaN, which would poison the
    valid positions after being added to the attention scores.
    """
    positions = torch.arange(max_T, device=device).unsqueeze(0)  # (1, T)
    is_pad = positions >= lengths.unsqueeze(1)                   # (B, T) bool
    mask = torch.zeros(is_pad.shape, device=device, dtype=torch.float)
    mask = mask.masked_fill(is_pad, float("-inf"))
    return mask.unsqueeze(1).unsqueeze(2)  # (B, 1, 1, T)
# ═══════════════════════════════════════════════════════════════════
# POSITIONAL ENCODING (how the model knows token order)
# ═══════════════════════════════════════════════════════════════════
#
# Attention is permutation-equivariant — without positional info,
# "the cat sat on the mat" = "mat the on sat cat the."
# Position must be injected. Common approaches:
#
# • Sinusoidal — fixed, no learned params. Original transformer.
# • Learned — a trainable embedding per position. GPT-2.
# • RoPE — rotates Q, K by position-dependent angles.
# Encodes RELATIVE position via rotation.
# Used in LLaMA, Mistral, GPT-NeoX, most modern LLMs.
# • ALiBi — adds a linear bias to attention scores based on
# distance. No modification to Q, K. Used in BLOOM.
class SinusoidalPE(nn.Module):
    """Fixed sinusoidal positional encoding ('Attention Is All You Need').

        pe[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        pe[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    No learned parameters; the table is a buffer so it follows the module
    across devices. Supports odd d_model (the original version crashed
    there because the cosine slice has one fewer column than the
    frequency vector).
    """

    def __init__(self, d_model, max_len=8192):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        # Odd d_model: the 1::2 slice has d_model // 2 columns, one fewer
        # than div — trim div so both assignments always line up.
        pe[:, 1::2] = torch.cos(pos * div[: d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        """x: (B, T, d_model) → x plus the first T rows of the table."""
        return x + self.pe[:, :x.size(1)]
class LearnedPE(nn.Module):
    """Trainable per-position embeddings, GPT-2 style."""

    def __init__(self, d_model, max_len=8192):
        super().__init__()
        self.pe = nn.Embedding(max_len, d_model)

    def forward(self, x):
        """x: (B, T, d_model) → x plus the embedding of each position index."""
        idx = torch.arange(x.size(1), device=x.device)
        return x + self.pe(idx)  # (T, d_model) broadcasts over the batch
class RoPE(nn.Module):
    """
    Rotary Position Embedding.

    Each consecutive (even, odd) feature pair of Q/K is rotated by an
    angle proportional to the token position, so the dot product between
    a rotated query at position m and a rotated key at position n depends
    only on (m − n) — relative position for free. Apply AFTER the Q/K
    projections and BEFORE the score computation.
    """

    def __init__(self, d_k, max_len=8192):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, d_k, 2).float() / d_k))
        angles = torch.arange(max_len).float()[:, None] * inv_freq[None, :]
        # Shaped (1, 1, max_len, d_k/2) to broadcast over (B, H, T, d_k/2).
        self.register_buffer("cos", angles.cos()[None, None])
        self.register_buffer("sin", angles.sin()[None, None])

    def forward(self, x):
        """Rotate x of shape (B, n_heads, T, d_k) pairwise by position."""
        seq = x.size(2)
        c, s = self.cos[:, :, :seq], self.sin[:, :, :seq]
        even, odd = x[..., 0::2], x[..., 1::2]
        # 2-D rotation applied to each (even, odd) pair, written back
        # in interleaved order.
        rotated = torch.empty_like(x)
        rotated[..., 0::2] = even * c - odd * s
        rotated[..., 1::2] = even * s + odd * c
        return rotated
# ═══════════════════════════════════════════════════════════════════
# TRANSFORMER BLOCK (attention + FFN + residual + norm)
# ═══════════════════════════════════════════════════════════════════
#
# The block structure is:
# x → [Norm] → Attention → + residual → [Norm] → FFN → + residual
#
# Norm placement:
# • Post-norm (original paper): LayerNorm(x + Sublayer(x))
# Harder to train, needs warmup. Used in the original transformer.
# • Pre-norm (modern default): x + Sublayer(LayerNorm(x))
# More stable, used by GPT-2+, LLaMA, etc.
#
# FFN variants:
# • Standard: Linear → ReLU → Linear
# • GLU / SwiGLU: Linear_gate → SiLU ⊙ Linear_up → Linear_down
# Used in LLaMA, PaLM, Mistral. Better quality for same param count.
class FeedForward(nn.Module):
    """Position-wise FFN: expand to d_ff, GELU, project back down.

    d_ff defaults to 4 × d_model, the original transformer ratio.
    """

    def __init__(self, d_model, d_ff=None, dropout=0.0):
        super().__init__()
        width = d_ff or 4 * d_model
        stages = [
            nn.Linear(d_model, width),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(width, d_model),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """(B, T, d_model) → (B, T, d_model), applied position-wise."""
        return self.net(x)
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward: w_down(silu(w_gate(x)) * w_up(x)).

    The default d_ff of (8/3) × d_model keeps the three-matrix GLU at
    roughly the parameter count of a standard 4× two-matrix FFN
    (LLaMA-style scaling). All projections are bias-free.
    """

    def __init__(self, d_model, d_ff=None, dropout=0.0):
        super().__init__()
        if not d_ff:
            d_ff = int(4 * d_model * 2 / 3)
        self.w_gate = nn.Linear(d_model, d_ff, bias=False)
        self.w_up = nn.Linear(d_model, d_ff, bias=False)
        self.w_down = nn.Linear(d_ff, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """(B, T, d_model) → (B, T, d_model)."""
        gate = F.silu(self.w_gate(x))
        return self.dropout(self.w_down(gate * self.w_up(x)))
class TransformerBlock(nn.Module):
    """One attention + FFN block with residual connections and LayerNorm.

    pre_norm=True (default) normalizes BEFORE each sublayer, the stable
    modern arrangement; pre_norm=False is the original post-norm layout
    that normalizes after the residual add.
    """

    def __init__(self, attention: AttentionBase, ffn: nn.Module,
                 d_model: int, dropout=0.0, pre_norm=True):
        super().__init__()
        self.attention = attention
        self.ffn = ffn
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.drop1 = nn.Dropout(dropout)
        self.drop2 = nn.Dropout(dropout)
        self.pre_norm = pre_norm

    def forward(self, x, mask=None, kv=None):
        """x: (B, T, d_model); kv: optional encoder output for cross-attn."""
        if not self.pre_norm:
            # Post-norm (original transformer): normalize AFTER the
            # residual add.
            x = self.norm1(x + self.drop1(self.attention(x, kv, mask)))
            return self.norm2(x + self.drop2(self.ffn(x)))
        # Pre-norm (GPT-2+, LLaMA): normalize before each sublayer.
        x = x + self.drop1(self.attention(self.norm1(x), kv, mask))
        return x + self.drop2(self.ffn(self.norm2(x)))
# ═══════════════════════════════════════════════════════════════════
# FULL TRANSFORMER (stack of blocks + embeddings)
# ═══════════════════════════════════════════════════════════════════
class Transformer(nn.Module):
    """
    Configurable decoder-style transformer stack.

    Swap attn_cls / ffn_cls / pos_enc_cls / pre_norm to reproduce
    different architectures (see gpt2_small / llama_style below).
    The token embedding and the output head share weights (weight tying).
    """

    def __init__(self, vocab_size, d_model, n_heads, n_layers,
                 d_ff=None, dropout=0.1, max_len=8192,
                 attn_cls=MultiHeadAttention, attn_kw=None,
                 ffn_cls=FeedForward, pos_enc_cls=LearnedPE,
                 pre_norm=True):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_enc = pos_enc_cls(d_model, max_len)
        self.drop = nn.Dropout(dropout)
        extra = attn_kw or {}
        blocks = []
        for _ in range(n_layers):
            blocks.append(TransformerBlock(
                attention=attn_cls(d_model, n_heads, dropout=dropout, **extra),
                ffn=ffn_cls(d_model, d_ff, dropout=dropout),
                d_model=d_model, dropout=dropout, pre_norm=pre_norm,
            ))
        self.layers = nn.ModuleList(blocks)
        # Pre-norm stacks need one final norm before the LM head.
        self.final_norm = nn.LayerNorm(d_model) if pre_norm else nn.Identity()
        self.head = nn.Linear(d_model, vocab_size, bias=False)
        # Weight tying: embedding matrix doubles as the output projection.
        self.head.weight = self.tok_emb.weight

    def forward(self, token_ids, mask=None):
        """
        Args:
            token_ids: (B, T) long tensor of vocabulary indices.
            mask: optional (T, T) or (B, 1, T, T) additive attention mask.
        Returns:
            (B, T, vocab_size) raw next-token logits.
        """
        h = self.drop(self.pos_enc(self.tok_emb(token_ids)))
        for block in self.layers:
            h = block(h, mask=mask)
        return self.head(self.final_norm(h))
# ═══════════════════════════════════════════════════════════════════
# EXAMPLE CONFIGURATIONS (how real models map to these pieces)
# ═══════════════════════════════════════════════════════════════════
def gpt2_small():
    """Build a GPT-2 Small-shaped model (~124M params).

    Recipe: MHA + learned positions + GELU FFN + pre-norm.
    """
    config = dict(
        vocab_size=50257, d_model=768, n_heads=12, n_layers=12,
        attn_cls=MultiHeadAttention, ffn_cls=FeedForward,
        pos_enc_cls=LearnedPE, pre_norm=True,
    )
    return Transformer(**config)
def llama_style():
    """Build a LLaMA-flavoured stack: GQA + SwiGLU FFN + pre-norm.

    Real LLaMA applies RoPE inside attention rather than at the
    embedding layer; a learned PE stands in here.
    """
    config = dict(
        vocab_size=32000, d_model=1024, n_heads=16, n_layers=16,
        attn_cls=GroupedQueryAttention, attn_kw={"n_kv_heads": 4},
        ffn_cls=SwiGLUFFN,
        pos_enc_cls=lambda d, m: LearnedPE(d, m),
        pre_norm=True,
    )
    return Transformer(**config)