Unified Transformer / Self-Attention
Introduction
The structure follows the same pattern as the Q-learning and policy gradient files. The core computation (Q @ K^T / sqrt(d_k), mask, softmax, multiply by V) is five lines that never change. Everything else is pluggable.
The axes of variation are different here, though. Instead of one thing (the algorithm) with a few swappable methods, the transformer is a configuration of independent choices along five axes: KV head structure, masking, positional encoding, norm placement, and FFN type. You mix and match freely: GQA works with any mask, RoPE works with any FFN, and so on.
The key progression across the variants:
MHA → GQA → MQA is purely about inference efficiency. During autoregressive generation you cache K and V for all past tokens (the “KV cache”). MHA caches H heads’ worth; MQA caches one. GQA is the sweet spot: LLaMA 2 70B, Mistral, and most modern LLMs use it because it gives most of MQA’s memory savings with almost no quality loss.
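The savings are easy to quantify. A back-of-the-envelope sketch (toy sizes chosen for illustration, not taken from any particular model) comparing KV-cache memory across the three layouts:

```python
# Hypothetical sizes to compare KV-cache memory across MHA / GQA / MQA.
H = 32          # query heads
G = 8           # KV heads under GQA
d_k = 128       # per-head dimension
T = 4096        # cached sequence length
bytes_fp16 = 2  # bytes per value in fp16

def kv_cache_bytes(n_kv_heads):
    # K and V each store n_kv_heads * T * d_k values.
    return 2 * n_kv_heads * T * d_k * bytes_fp16

mha = kv_cache_bytes(H)   # MHA: H KV heads
gqa = kv_cache_bytes(G)   # GQA: G KV heads -> H/G = 4x smaller
mqa = kv_cache_bytes(1)   # MQA: 1 KV head  -> H   = 32x smaller

print(mha // gqa, mha // mqa)  # 4 32
```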
Post-norm → Pre-norm is about trainability. Moving the LayerNorm before the sublayer lets gradients flow cleanly through the residual stream, which is why GPT-2 and essentially every model since use pre-norm.
RoPE is the most interesting positional encoding: it encodes relative position through rotation, so the attention score between two tokens depends on their distance, not their absolute positions. That is a much better inductive bias for language.
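The relative-position property can be checked directly. A minimal sketch of a single RoPE frequency pair (a 2-D rotation; `theta` is an arbitrary toy frequency, not a value from any model): rotating Q and K by their positions makes the dot product depend only on the offset between them.

```python
import numpy as np

def rotate(v, pos, theta=0.1):
    # Rotate a 2-D vector by pos * theta (one RoPE frequency pair).
    a = pos * theta
    R = np.array([[np.cos(a), -np.sin(a)],
                  [np.sin(a),  np.cos(a)]])
    return R @ v

q = np.array([1.0, 0.5])
k = np.array([0.3, 2.0])

# The score depends only on the position DIFFERENCE:
s1 = rotate(q, 10) @ rotate(k, 7)     # positions (10, 7),   distance 3
s2 = rotate(q, 103) @ rotate(k, 100)  # positions (103, 100), distance 3
print(np.allclose(s1, s2))  # True
```

This works because rotation matrices compose: (R(a)q) · (R(b)k) = qᵀ R(b − a) k, so only b − a survives.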
The “HOW REAL MODELS MAP” table at the bottom shows that the original transformer and a modern LLaMA share the exact same skeleton, just with different choices at each plug point.
Summary: What changes vs. what stays the same
Always the same (core attention)
- scores = Q @ K^T / sqrt(d_k)
- scores += mask
- weights = softmax(scores)
- output = weights @ V
- concat heads → output projection
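Those five lines translate almost verbatim to code. A single-head NumPy sketch (the shape bookkeeping and the numerically stable softmax are the only additions):

```python
import numpy as np

def attention(Q, K, V, mask=None):
    """The five lines that never change (single head, NumPy sketch)."""
    d_k = Q.shape[-1]
    scores = Q @ K.swapaxes(-1, -2) / np.sqrt(d_k)      # (T, T) similarities
    if mask is not None:
        scores = scores + mask                          # -inf blocks attention
    weights = np.exp(scores - scores.max(-1, keepdims=True))
    weights = weights / weights.sum(-1, keepdims=True)  # softmax over keys
    return weights @ V                                  # weighted mix of values

T, d = 4, 8
rng = np.random.default_rng(0)
Q, K, V = rng.normal(size=(3, T, d))
causal = np.triu(np.full((T, T), -np.inf), k=1)  # mask strictly-upper triangle
out = attention(Q, K, V, causal)
print(out.shape)  # (4, 8)
```

With the causal mask, position 0 can only attend to itself, so its output is exactly V[0].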
Always the same (block structure)
- Residual connection around attention
- Residual connection around FFN
- Layer normalisation (placement varies)
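The two norm placements differ only in where the normalisation sits relative to the residual add. A minimal sketch (unparameterised LayerNorm; the attention and FFN sublayers are passed in as callables):

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalise each token vector to zero mean, unit variance.
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def post_norm_block(x, attn, ffn):
    # Original transformer: normalise AFTER the residual add.
    x = layer_norm(x + attn(x))
    x = layer_norm(x + ffn(x))
    return x

def pre_norm_block(x, attn, ffn):
    # Modern default: normalise the INPUT to each sublayer;
    # the residual stream itself is never normalised.
    x = x + attn(layer_norm(x))
    x = x + ffn(layer_norm(x))
    return x
```

With zeroed-out sublayers, the pre-norm block is the identity (the residual stream passes through untouched), while post-norm still normalises — which is exactly the gradient-flow difference described above.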
What varies by variant
| Component | Choice A | Choice B | Choice C |
|---|---|---|---|
| KV heads | MHA (H KV heads) | GQA (G KV heads) | MQA (1 KV head) |
| Mask | None (encoder) | Causal (decoder) | Cross-attention |
| Position | Sinusoidal | Learned | RoPE / ALiBi |
| Norm | Post-norm | Pre-norm | RMSNorm |
| FFN | ReLU / GELU | SwiGLU | MoE |
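The KV-heads column is the only one that changes the attention computation itself. A NumPy sketch of GQA with toy sizes: only G KV heads are stored (and cached), and each is repeated across its group of H/G query heads:

```python
import numpy as np

H, G, T, d_k = 8, 2, 5, 4   # toy sizes: 8 query heads share 2 KV heads
rng = np.random.default_rng(0)
Q = rng.normal(size=(H, T, d_k))
K = rng.normal(size=(G, T, d_k))   # only G KV heads are stored / cached
V = rng.normal(size=(G, T, d_k))

# Each group of H // G query heads attends to the same KV head:
K_full = np.repeat(K, H // G, axis=0)   # (H, T, d_k)
V_full = np.repeat(V, H // G, axis=0)

scores = Q @ K_full.swapaxes(-1, -2) / np.sqrt(d_k)   # (H, T, T)
weights = np.exp(scores - scores.max(-1, keepdims=True))
weights /= weights.sum(-1, keepdims=True)
out = weights @ V_full                                # (H, T, d_k)
print(out.shape)  # (8, 5, 4)
```

G = H recovers MHA and G = 1 recovers MQA, which is why the three layouts are one column rather than three mechanisms.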
Motives for each variant
| Choice | Problem Solved | Intuition for Solution |
|---|---|---|
| Self-Attention (core idea) | RNNs process tokens sequentially — slow to train, struggle with long-range dependencies | Let every token attend to every other token in parallel. O(T²) but fully parallelisable and captures arbitrary pairwise relationships |
| Multi-Head | A single attention pattern can only focus on one type of relationship at a time | Run H independent attention ops in parallel on d/H-dim subspaces, then concatenate. Each head learns a different relationship type |
| GQA / MQA | MHA’s KV cache scales as O(H · T · d_k) — huge memory cost during autoregressive inference | Share K,V projections across groups of query heads. KV cache shrinks by H/G× with minimal quality loss. MQA is the extreme (1 KV head) |
| Causal mask | During generation, future tokens don’t exist yet — the model must not attend to them during training either | Mask upper triangle of attention scores to -inf. Each position can only see itself and earlier positions |
| RoPE | Learned/sinusoidal PE encode absolute position, but relative position (distance) matters more for language | Rotate Q, K vectors by position-dependent angles. The dot product Q·K then depends on the DIFFERENCE in positions, not absolute values |
| Pre-norm | Post-norm is hard to train deep: gradients must flow through norm layers, causing instability at depth | Normalise BEFORE each sublayer. The residual stream carries un-normalised gradients directly, stabilising deep training without warmup hacks |
| SwiGLU | Standard FFN (linear-ReLU-linear) is a blunt instrument — every dimension is treated identically | Use a gated linear unit: one linear path modulates another via SiLU. The gating lets the network learn which features to pass through, improving quality for the same param count |
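The SwiGLU row is the least self-explanatory, so here is a minimal NumPy sketch of both FFN types (the weight names are ours; real implementations differ in whether they keep biases and how they size d_ff):

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))   # SiLU (swish) activation

def ffn_relu(x, W1, W2):
    # Classic FFN: project up, ReLU, project down.
    return np.maximum(x @ W1, 0.0) @ W2

def ffn_swiglu(x, W_gate, W_up, W_down):
    # SwiGLU: the SiLU(x @ W_gate) path gates the x @ W_up path
    # elementwise, so the network learns which features pass through.
    return (silu(x @ W_gate) * (x @ W_up)) @ W_down

d, d_ff = 8, 16
rng = np.random.default_rng(0)
x = rng.normal(size=(4, d))
out = ffn_swiglu(x,
                 rng.normal(size=(d, d_ff)),   # gate projection
                 rng.normal(size=(d, d_ff)),   # up projection
                 rng.normal(size=(d_ff, d)))   # down projection
print(out.shape)  # (4, 8)
```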
How real models map
| Model | Attn | Pos Enc | FFN | Norm | Norm Pos |
|---|---|---|---|---|---|
| Original TF | MHA | Sinusoidal | ReLU | LayerNorm | Post |
| GPT-2 | MHA | Learned | GELU | LayerNorm | Pre |
| LLaMA 2 | GQA | RoPE | SwiGLU | RMSNorm | Pre |
| Mistral | GQA | RoPE | SwiGLU | RMSNorm | Pre |
| PaLM | MQA | RoPE | SwiGLU | RMSNorm | Pre |
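The table can be read literally as a configuration record. A hypothetical sketch (the dataclass and field names are ours, not from any library) showing that the rows are just different settings of the same knobs:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class BlockConfig:
    attn: str      # "MHA" | "GQA" | "MQA"
    pos_enc: str   # "sinusoidal" | "learned" | "rope"
    ffn: str       # "relu" | "gelu" | "swiglu"
    norm: str      # "layernorm" | "rmsnorm"
    norm_pos: str  # "pre" | "post"

# Rows of the table above, transcribed as configs:
original_tf = BlockConfig("MHA", "sinusoidal", "relu", "layernorm", "post")
gpt2        = BlockConfig("MHA", "learned", "gelu", "layernorm", "pre")
llama2      = BlockConfig("GQA", "rope", "swiglu", "rmsnorm", "pre")
mistral     = BlockConfig("GQA", "rope", "swiglu", "rmsnorm", "pre")

print(llama2 == mistral)  # True: same skeleton, same choices
```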