Unified Transformer / Self-Attention

The structure follows the same pattern as the Q-learning and policy gradient files. The core computation — Q @ K^T / √d, mask, softmax, @ V — is literally five lines that never change. Everything else is pluggable.
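Those five invariant lines can be sketched directly in NumPy. This is a minimal single-head version (the function and variable names here are illustrative, not from the file itself), with the causal mask passed in as an additive matrix of `-inf`:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V, mask=None):
    # Q, K, V: (T, d_k). These are the five lines that never change.
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)   # (T, T) pairwise similarities
    if mask is not None:
        scores = scores + mask        # -inf where attention is forbidden
    weights = softmax(scores, axis=-1)  # each row sums to 1
    return weights @ V                # weighted mix of value vectors

T, d = 4, 8
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((T, d)) for _ in range(3))
causal = np.triu(np.full((T, T), -np.inf), k=1)  # mask out future positions
out = attention(Q, K, V, causal)  # position 0 can only see V[0]
```

With the causal mask, the first row of the output is exactly `V[0]`: position 0 has nothing earlier to attend to, so all its weight falls on itself.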

The axes of variation are a bit different here though. Instead of one thing (the algorithm) with a few swappable methods, the transformer is more like a configuration of independent choices along five axes: KV head structure, masking, positional encoding, norm placement, and FFN type. You mix and match freely — GQA works with any mask, RoPE works with any FFN, etc.

The key progression across the variants:

MHA → GQA → MQA is purely about inference efficiency. During autoregressive generation you cache K and V for all past tokens (the “KV cache”). MHA caches H heads’ worth; MQA caches just one; GQA caches G shared groups and is the sweet spot — LLaMA 2 70B, Mistral, and most modern LLMs use it because you get most of MQA’s memory savings with almost no quality loss.
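The savings are easy to quantify. A back-of-the-envelope sketch for one layer, using LLaMA-2-70B-like dimensions (64 query heads, head dim 128, 8 KV heads for GQA — the context length and fp16 assumption here are illustrative):

```python
# Per-layer KV cache size as a function of the number of KV heads.
H, d_k, T = 64, 128, 4096   # query heads, head dim, cached context length
bytes_per_val = 2           # fp16

def kv_cache_bytes(n_kv_heads):
    # factor of 2 = one K tensor plus one V tensor
    return 2 * n_kv_heads * T * d_k * bytes_per_val

mha = kv_cache_bytes(64)  # MHA: one KV head per query head
gqa = kv_cache_bytes(8)   # GQA: 8 KV heads (LLaMA 2 70B's choice)
mqa = kv_cache_bytes(1)   # MQA: a single shared KV head
```

GQA with 8 KV heads shrinks the cache 8× versus MHA; MQA shrinks it 64×, but at a larger quality cost.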

Post-norm → Pre-norm is about trainability. Moving the LayerNorm before the sublayer lets gradients flow cleanly through the residual stream, which is why virtually every model from GPT-2 onward uses pre-norm.
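The difference is a one-line reordering. A minimal sketch (function names are mine; `sublayer` stands in for either attention or the FFN):

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def post_norm_block(x, sublayer):
    # Original transformer: the norm sits ON the residual path,
    # so gradients must pass through it at every layer.
    return layer_norm(x + sublayer(x))

def pre_norm_block(x, sublayer):
    # GPT-2 style: the residual stream is left untouched;
    # only the sublayer's input is normalised.
    return x + sublayer(layer_norm(x))

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))
y_post = post_norm_block(x, lambda h: 2 * h)
y_pre = pre_norm_block(x, lambda h: 2 * h)
```

Note the observable consequence: post-norm output is always normalised (zero mean per row), while pre-norm output is the raw residual sum, which is exactly what keeps the gradient path clean in deep stacks.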

RoPE is the most interesting positional encoding — it encodes relative position through rotation, which means the attention score between two tokens depends on their distance, not their absolute positions. That’s a much better inductive bias for language.
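A minimal sketch of the rotation, for a single vector (pairing scheme and frequency base follow the standard RoPE formulation; the helper name is mine). The key property — that the dot product depends only on the positional difference — can be checked numerically:

```python
import numpy as np

def rope(x, pos, base=10000.0):
    # x: (d,) with d even. Rotate consecutive dimension pairs by
    # position-dependent angles; lower pairs rotate faster.
    d = x.shape[-1]
    theta = base ** (-np.arange(0, d, 2) / d)  # per-pair frequencies
    ang = pos * theta
    cos, sin = np.cos(ang), np.sin(ang)
    x1, x2 = x[0::2], x[1::2]
    out = np.empty_like(x)
    out[0::2] = x1 * cos - x2 * sin
    out[1::2] = x1 * sin + x2 * cos
    return out

rng = np.random.default_rng(0)
q, k = rng.standard_normal(8), rng.standard_normal(8)
a = rope(q, 5) @ rope(k, 3)    # query at position 5, key at position 3
b = rope(q, 10) @ rope(k, 8)   # same distance (2), different absolute positions
```

Here `a` and `b` come out equal: rotating both vectors cancels everything except their relative offset, which is the inductive bias the paragraph describes.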

The “HOW REAL MODELS MAP” table at the bottom shows that the original transformer and a modern LLaMA are the exact same skeleton — just different choices at each plug point.

Summary: What changes vs. what stays the same

What stays the same in every variant:

  • scores = Q @ K^T / sqrt(d_k)
  • scores += mask
  • weights = softmax(scores)
  • output = weights @ V
  • concat heads → output projection
  • Residual connection around attention
  • Residual connection around FFN
  • Layer normalisation (placement varies)
What changes — the plug points:

| Component | Choice A | Choice B | Choice C |
| --- | --- | --- | --- |
| KV heads | MHA (H KV heads) | GQA (G KV heads) | MQA (1 KV head) |
| Mask | None (encoder) | Causal (decoder) | Cross-attention |
| Position | Sinusoidal | Learned | RoPE / ALiBi |
| Norm | Post-norm | Pre-norm | RMSNorm |
| FFN | ReLU / GELU | SwiGLU | MoE |
| Choice | Problem Solved | Intuition for Solution |
| --- | --- | --- |
| Self-Attention (core idea) | RNNs process tokens sequentially — slow to train, struggle with long-range dependencies | Let every token attend to every other token in parallel. O(T²) but fully parallelisable and captures arbitrary pairwise relationships |
| Multi-Head | A single attention pattern can only focus on one type of relationship at a time | Run H independent attention ops in parallel on d/H-dim subspaces, then concatenate. Each head learns a different relationship type |
| GQA / MQA | MHA’s KV cache scales as O(H · T · d_k) — huge memory cost during autoregressive inference | Share K, V projections across groups of query heads. KV cache shrinks by H/G× with minimal quality loss. MQA is the extreme (1 KV head) |
| Causal mask | During generation, future tokens don’t exist yet — the model must not attend to them during training either | Mask the upper triangle of the attention scores to -inf. Each position can only see itself and earlier positions |
| RoPE | Learned/sinusoidal PE encode absolute position, but relative position (distance) matters more for language | Rotate Q, K vectors by position-dependent angles. The dot product Q·K then depends on the DIFFERENCE in positions, not absolute values |
| Pre-norm | Post-norm is hard to train deep: gradients must flow through norm layers, causing instability at depth | Normalise BEFORE each sublayer. The residual stream carries un-normalised gradients directly, stabilising deep training without warmup hacks |
| SwiGLU | Standard FFN (linear-ReLU-linear) is a blunt instrument — every dimension is treated identically | Use a gated linear unit: one linear path modulates another via SiLU. The gating lets the network learn which features to pass through, improving quality for the same param count |
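The SwiGLU row above reduces to three matrix multiplies and an element-wise product. A minimal sketch (weight names are illustrative; real models also fold in biasless projections and a 2/3 width adjustment to keep parameter count matched):

```python
import numpy as np

def silu(x):
    # SiLU / swish: x * sigmoid(x)
    return x / (1.0 + np.exp(-x))

def swiglu_ffn(x, W_gate, W_up, W_down):
    # The gate path (silu) modulates the up path element-wise,
    # then the result is projected back down to model width.
    return (silu(x @ W_gate) * (x @ W_up)) @ W_down

d, d_ff = 8, 16
rng = np.random.default_rng(0)
W_gate = rng.standard_normal((d, d_ff))
W_up = rng.standard_normal((d, d_ff))
W_down = rng.standard_normal((d_ff, d))
y = swiglu_ffn(rng.standard_normal((4, d)), W_gate, W_up, W_down)
```

Compare with the standard FFN, `relu(x @ W1) @ W2`: the only structural change is the second parallel projection acting as a learned, per-dimension gate.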
| Model | Attn | Pos Enc | FFN | Norm | Norm Pos |
| --- | --- | --- | --- | --- | --- |
| Original Transformer | MHA | Sinusoidal | ReLU | LayerNorm | Post |
| GPT-2 | MHA | Learned | GELU | LayerNorm | Pre |
| LLaMA 2 | GQA | RoPE | SwiGLU | RMSNorm | Pre |
| Mistral | GQA | RoPE | SwiGLU | RMSNorm | Pre |
| PaLM | MQA | RoPE | SwiGLU | RMSNorm | Pre |