Self-attention, the softmax(QKᵀ/√d_k) mechanism that lets every token attend to every other token, is the single operation that defines transformers. Understand it and the rest of the architecture falls into place.
Self-attention is a mechanism for letting each token in a sequence gather information from all other tokens. Think of it as a soft lookup: each token asks a query ("what information do I need?"), every token broadcasts a key ("here's what I contain"), and where a query matches a key, the querying token retrieves the matching token's value ("here's my content").
In a sentence like "The animal didn't cross the street because it was too tired", the word "it" needs to figure out whether it refers to "animal" or "street". Self-attention allows "it" to attend strongly to "animal" (high query-key compatibility) and weakly to "street", incorporating the animal's representation into the encoding of "it".
This happens simultaneously for every token, in parallel. Unlike RNNs, which process tokens left-to-right and struggle with long-range dependencies, self-attention has a direct path between any two tokens regardless of distance — one of the key reasons transformers outperform RNNs.
Given a sequence of token embeddings X ∈ ℝ^(n×d) (n tokens, d dimensions), self-attention learns three weight matrices: W_Q, W_K, W_V ∈ ℝ^(d×d_k).
These project X into three spaces:
```python
import numpy as np

d_model = 512  # embedding dimension
d_k = 64       # key/query dimension (often d_model / num_heads)
n = 10         # sequence length

# Random input embeddings (in practice, from an embedding table + pos encoding)
X = np.random.randn(n, d_model)

# Learned projection matrices (initialised randomly, trained by backprop)
W_Q = np.random.randn(d_model, d_k) * 0.02
W_K = np.random.randn(d_model, d_k) * 0.02
W_V = np.random.randn(d_model, d_k) * 0.02

Q = X @ W_Q  # (n, d_k) — what each token is looking for
K = X @ W_K  # (n, d_k) — what each token contains / advertises
V = X @ W_V  # (n, d_k) — what each token will contribute if attended to
```
Intuitively: Q is "what I'm looking for", K is "what I have to offer", V is "what I'll give you if you attend to me". The attention score between token i and token j is Q[i] · K[j] — how well token i's query matches token j's key.
The full formula: Attention(Q, K, V) = softmax(QKᵀ / √d_k) · V
Why divide by √d_k? Without scaling, the dot products grow large as d_k increases, pushing softmax into saturation regions where gradients vanish. Dividing by √d_k keeps variance ≈ 1.
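This can be checked numerically. With unit-variance components, a d_k-dimensional dot product has variance d_k; dividing by √d_k brings it back to ≈ 1 (a quick sketch):

```python
import numpy as np

rng = np.random.default_rng(0)
d_k = 64
# 10,000 random query/key pairs with unit-variance components
q = rng.standard_normal((10_000, d_k))
k = rng.standard_normal((10_000, d_k))

raw = (q * k).sum(axis=-1)       # unscaled dot products
scaled = raw / np.sqrt(d_k)      # scaled dot products

print(f"var(raw)    ≈ {raw.var():.1f}")     # ≈ d_k = 64
print(f"var(scaled) ≈ {scaled.var():.2f}")  # ≈ 1.0
```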
```python
def scaled_dot_product_attention(Q, K, V, mask=None):
    d_k = Q.shape[-1]
    # Step 1: compute raw attention scores — (n, n) matrix
    scores = Q @ K.T / np.sqrt(d_k)
    # Step 2: apply mask (for causal or padding masking)
    if mask is not None:
        scores = scores + mask  # mask contains -inf where attention is blocked
    # Step 3: softmax — convert scores to probabilities
    # (subtract the row max for numerical stability)
    scores -= scores.max(axis=-1, keepdims=True)
    attn_weights = np.exp(scores)
    attn_weights /= attn_weights.sum(axis=-1, keepdims=True)
    # Step 4: weighted sum of values
    output = attn_weights @ V  # (n, d_k)
    return output, attn_weights

output, weights = scaled_dot_product_attention(Q, K, V)
print(f"Output shape: {output.shape}")                   # (10, 64)
print(f"Attention weights sum: {weights.sum(axis=-1)}")  # all 1.0
```
```python
import numpy as np

class SelfAttention:
    def __init__(self, d_model: int, d_k: int):
        self.d_k = d_k
        # Xavier initialisation
        scale = np.sqrt(2.0 / (d_model + d_k))
        self.W_Q = np.random.randn(d_model, d_k) * scale
        self.W_K = np.random.randn(d_model, d_k) * scale
        self.W_V = np.random.randn(d_model, d_k) * scale
        self.W_O = np.random.randn(d_k, d_model) * scale  # output projection

    def forward(self, X: np.ndarray, mask=None):
        # X: (batch, seq_len, d_model)
        Q = X @ self.W_Q  # (batch, seq, d_k)
        K = X @ self.W_K
        V = X @ self.W_V
        # Attention scores: (batch, seq, seq)
        scores = (Q @ K.transpose(0, 2, 1)) / np.sqrt(self.d_k)
        if mask is not None:
            scores += mask
        # Softmax (row max subtracted for stability)
        scores -= scores.max(axis=-1, keepdims=True)
        weights = np.exp(scores)
        weights /= weights.sum(axis=-1, keepdims=True)
        # Weighted values + output projection
        context = weights @ V        # (batch, seq, d_k)
        output = context @ self.W_O  # (batch, seq, d_model)
        return output, weights

# Test
batch, seq, d_model, d_k = 2, 8, 64, 32
X = np.random.randn(batch, seq, d_model)
attn = SelfAttention(d_model, d_k)
out, w = attn.forward(X)
print(f"Output: {out.shape}")                   # (2, 8, 64)
print(f"Weights row sum: {w[0, 0].sum():.4f}")  # 1.0
```
In decoder-only models (GPT, Llama), token i should only attend to tokens 0..i — not future tokens. This is enforced with a causal mask: a triangular matrix of -∞ values in the upper triangle.
```python
def causal_mask(seq_len: int) -> np.ndarray:
    # Upper triangle (excluding diagonal) = -inf, lower triangle = 0
    mask = np.triu(np.full((seq_len, seq_len), -np.inf), k=1)
    return mask  # shape: (seq_len, seq_len)

# Visualise for seq_len=4
mask = causal_mask(4)
print(mask)
# [[  0. -inf -inf -inf]
#  [  0.   0. -inf -inf]
#  [  0.   0.   0. -inf]
#  [  0.   0.   0.   0.]]

# After adding to scores and softmax:
#   Token 0 attends only to itself
#   Token 1 attends to tokens 0 and 1
#   Token 3 attends to all 4 tokens
```
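To verify the mask end to end, here is a self-contained sketch that applies it inside scaled dot-product attention; the weight matrix comes out strictly lower-triangular with every row summing to 1:

```python
import numpy as np

def causal_mask(seq_len):
    # -inf above the diagonal blocks attention to future tokens
    return np.triu(np.full((seq_len, seq_len), -np.inf), k=1)

rng = np.random.default_rng(0)
n, d_k = 4, 8
Q, K, V = (rng.standard_normal((n, d_k)) for _ in range(3))

scores = Q @ K.T / np.sqrt(d_k) + causal_mask(n)
scores -= scores.max(axis=-1, keepdims=True)  # stable softmax
w = np.exp(scores)                            # exp(-inf) = 0 above the diagonal
w /= w.sum(axis=-1, keepdims=True)

print(np.triu(w, k=1).max())  # 0.0 — no weight on future tokens
print(w.sum(axis=-1))         # all 1.0
```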
```python
import torch
import torch.nn.functional as F

def causal_attention(Q, K, V):
    seq = Q.shape[-2]
    mask = torch.triu(torch.full((seq, seq), float('-inf')), diagonal=1)
    scores = (Q @ K.transpose(-2, -1)) / Q.shape[-1] ** 0.5
    return F.softmax(scores + mask, dim=-1) @ V

# PyTorch also has built-in causal attention:
# F.scaled_dot_product_attention(Q, K, V, is_causal=True)
```
Self-attention has O(n²·d) time and O(n²) memory complexity — the attention matrix is n×n. For n=512 this is fine; for n=32768 (32K context) it's 32768² ≈ 1.07 billion entries per layer per head.
This quadratic scaling is why long-context models are expensive, and why alternatives like Flash Attention (IO-aware tiling), linear attention, and SSMs (Mamba) exist.
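The numbers are easy to reproduce: a float32 attention matrix costs n² × 4 bytes per head per layer. A back-of-the-envelope sketch:

```python
def attn_matrix_bytes(n, dtype_bytes=4):
    # one n×n score matrix per head per layer (float32 = 4 bytes per entry)
    return n * n * dtype_bytes

for n in (512, 4096, 32768):
    mib = attn_matrix_bytes(n) / 2**20
    print(f"n={n:>6}: {mib:>8.1f} MiB per head per layer")
# n=512 is ~1 MiB; n=32768 is 4096 MiB (4 GiB) — a 4096× increase for a 64× longer context
```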
```python
import torch

# PyTorch's optimised attention — use this in practice
Q = torch.randn(1, 8, 512, 64)  # (batch, heads, seq, d_k)
K = torch.randn(1, 8, 512, 64)
V = torch.randn(1, 8, 512, 64)

# Uses Flash Attention under the hood when available (requires PyTorch 2.0+)
out = torch.nn.functional.scaled_dot_product_attention(
    Q, K, V,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=True,  # enables causal masking without materialising the mask
)
print(out.shape)  # (1, 8, 512, 64)

# Memory comparison: standard vs Flash Attention
#   Standard: stores the full n×n attention matrix in HBM — O(n²) memory
#   Flash Attention: tiles the computation, never stores the full matrix — O(n) memory
#   For n=32768: standard = 4GB per head (float32); Flash Attention = ~32MB per head
```
Softmax in float16 overflows. QKᵀ values can be large; in float16 (max ~65504), exp(score) overflows to NaN. Always subtract the max before softmax (numerically stable softmax) and consider running attention in float32 even when the rest of the model is float16. PyTorch's scaled_dot_product_attention handles this correctly.
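The failure mode is easy to reproduce in NumPy: exp of any float16 score above about 11.1 exceeds 65504 and becomes inf, and the resulting inf/inf division produces NaN, while the max-subtracted version stays finite. An illustrative sketch:

```python
import numpy as np

scores = np.array([20.0, 10.0, 5.0], dtype=np.float16)

with np.errstate(over='ignore', invalid='ignore'):
    # Naive softmax: exp(20) ≈ 4.9e8 overflows float16's max of 65504 → inf
    naive = np.exp(scores) / np.exp(scores).sum()
    # Stable softmax: subtract the max so the largest exponent is exp(0) = 1
    shifted = scores - scores.max()
    stable = np.exp(shifted) / np.exp(shifted).sum()

print(naive)   # contains nan (inf / inf)
print(stable)  # finite, sums to 1
```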
d_model must be divisible by num_heads. In multi-head attention, d_k is typically d_model / num_heads. If d_model=512 and num_heads=6, d_k ≈ 85.3 — not an integer. Choose a num_heads that evenly divides d_model (8 or 16 for d_model=512).
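The constraint comes from how multi-head attention splits the model dimension: activations of shape (n, d_model) are reshaped to (n, num_heads, d_k) with d_k = d_model // num_heads, which only works when the division is exact. A minimal sketch:

```python
import numpy as np

d_model, num_heads = 512, 8
assert d_model % num_heads == 0, "num_heads must evenly divide d_model"
d_k = d_model // num_heads  # 64

x = np.random.randn(10, d_model)
heads = x.reshape(10, num_heads, d_k)  # split the model dim into per-head slices
print(heads.shape)  # (10, 8, 64)
```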
Attention is not the only operation. The full transformer block is: LayerNorm → Attention → residual add → LayerNorm → FFN → residual add. The FFN (2 linear layers with activation) typically uses 4× the dimension of attention and accounts for ~67% of parameters. Don't optimise attention in isolation.
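The ~67% figure follows from a simple count: per block, the four attention projections contribute about 4·d², while a 4×-expanded FFN contributes about 8·d² (biases ignored), so the FFN holds 8/12 ≈ 67% of block parameters:

```python
def block_params(d_model, ffn_mult=4):
    # Attention: W_Q, W_K, W_V, W_O, each d×d (biases ignored)
    attn = 4 * d_model * d_model
    # FFN: two linear layers, d → ffn_mult·d and ffn_mult·d → d
    ffn = 2 * ffn_mult * d_model * d_model
    return attn, ffn

attn, ffn = block_params(768)  # GPT-2-small-sized dimension
print(f"FFN share of block params: {ffn / (attn + ffn):.0%}")  # 67%
```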
| Variant | KV Heads | Memory | Quality | Used In |
|---|---|---|---|---|
| Multi-Head Attention (MHA) | = query heads | High | Highest | GPT-2, BERT |
| Multi-Query Attention (MQA) | 1 shared KV head | Very low | Slightly lower | PaLM, Falcon |
| Grouped-Query Attention (GQA) | H/G groups | Medium | Close to MHA | Llama 2/3, Mistral, Gemma |
| Sliding Window Attention | = query heads (windowed) | Low at long ctx | Good for local patterns | Mistral, Longformer |
Grouped-Query Attention (GQA) has become the dominant choice in modern open-weight models because it achieves near-MHA quality with much lower KV cache memory requirements. For a 7B model with 32 query heads and 8 KV groups, the KV cache is 4× smaller than MHA, directly reducing memory pressure during inference and enabling larger batch sizes or longer context windows at the same VRAM budget. When comparing model architectures for deployment, always check the number of KV heads; it has a larger impact on serving cost than total parameter count for long-context workloads.
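The cache saving is simple arithmetic: the KV cache stores one K and one V vector per layer, per KV head, per token, so it scales linearly with the number of KV heads. A rough sketch (the config numbers here are illustrative assumptions for a 7B-class model, not any specific checkpoint):

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, dtype_bytes=2):
    # 2 tensors (K and V), each of shape (n_layers, n_kv_heads, seq_len, head_dim),
    # in float16 (2 bytes per entry)
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * dtype_bytes

# Assumed 7B-class config: 32 layers, head_dim 128, 4096-token context
mha = kv_cache_bytes(32, n_kv_heads=32, head_dim=128, seq_len=4096)
gqa = kv_cache_bytes(32, n_kv_heads=8,  head_dim=128, seq_len=4096)
print(f"MHA: {mha / 2**30:.1f} GiB, GQA: {gqa / 2**30:.1f} GiB")  # 4× smaller
```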
The computational complexity of self-attention — O(n²) in sequence length — was the central bottleneck in scaling transformers to long contexts. Flash Attention addressed this by restructuring the attention computation into SRAM-friendly tiles, dramatically reducing the number of HBM memory reads and writes. The mathematical result is identical to standard attention, but the memory access pattern is far more cache-efficient, enabling both faster execution and reduced peak memory usage.
Relative position encodings (like RoPE and ALiBi) were developed to give self-attention a better inductive bias for position than absolute sinusoidal encodings. RoPE rotates query and key vectors in a position-dependent manner so that the dot product naturally encodes relative distance, allowing models to generalize better to sequence lengths not seen during training. This property has made RoPE the de facto standard in modern open-weight models like Llama and Mistral.
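The relative-position property can be verified in the simplest 2-D case: rotating q and k by angles proportional to their positions leaves a dot product that depends only on the position difference. This is a minimal sketch of the core identity, not a full RoPE implementation (which applies such rotations pairwise across the head dimension at multiple frequencies):

```python
import numpy as np

def rotate(v, pos, theta=0.1):
    # 2-D rotation by angle pos * theta, the core RoPE operation for one pair
    c, s = np.cos(pos * theta), np.sin(pos * theta)
    return np.array([c * v[0] - s * v[1], s * v[0] + c * v[1]])

q = np.array([1.0, 2.0])
k = np.array([0.5, -1.0])

# Same relative offset (3 positions apart), different absolute positions
a = rotate(q, 5) @ rotate(k, 2)
b = rotate(q, 103) @ rotate(k, 100)
print(np.isclose(a, b))  # True — the score depends only on relative distance
```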
Multi-head attention's practical benefit over single-head attention is that different heads specialize in capturing different types of relationships simultaneously. Empirical studies show that some heads consistently attend to syntactic structure (subject-verb agreement), others to coreference (tracking which "it" refers to), and still others to local phrase patterns. This specialization emerges from training without explicit supervision, and pruning attention heads that have near-zero contribution can reduce model size with minimal quality loss.