Attention Mechanisms

Self-Attention

The Q·Kᵀ/√d mechanism that lets every token attend to every other token — the single operation that defines transformers. Understand this and the rest of the architecture falls into place.

Time complexity: O(n²)
Scaling factor: √d
Normalisation: Softmax

SECTION 01

The intuition

Self-attention is a mechanism for letting each token in a sequence gather information from all other tokens. Think of it as a soft lookup: each token asks a query ("what information do I need?"), every token broadcasts a key ("here's what I contain"), and when a query matches a key, the token retrieves that token's value ("here's my content").
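The soft lookup can be sketched with a toy example. The vectors below are hand-picked for illustration rather than learned, and the names `keys`, `values`, and `query` are hypothetical.

```python
import numpy as np

# Three tokens, each with a hand-picked 2-d key and a scalar value (illustrative only)
keys = np.array([[1.0, 0.0],    # token 0
                 [0.0, 1.0],    # token 1
                 [1.0, 1.0]])   # token 2
values = np.array([10.0, 20.0, 30.0])

# A query that aligns with token 0's key and opposes token 1's key
query = np.array([3.0, -3.0])

# Query-key similarity, turned into weights that sum to 1 via softmax
scores = keys @ query
weights = np.exp(scores - scores.max())
weights /= weights.sum()

# A hard lookup would return exactly one value; the soft lookup returns a blend
retrieved = weights @ values
print(weights.round(3), round(float(retrieved), 2))
```

Because the query matches token 0's key best, the retrieved value is dominated by token 0's value, with small contributions from the other two tokens.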

In a sentence like "The animal didn't cross the street because it was too tired", the word "it" needs to figure out whether it refers to "animal" or "street". Self-attention allows "it" to attend strongly to "animal" (high query-key compatibility) and weakly to "street", incorporating the animal's representation into the encoding of "it".

This happens simultaneously for every token, in parallel. Unlike RNNs which process tokens left-to-right and struggle with long-range dependencies, self-attention has a direct path between any two tokens regardless of distance — one of the key reasons transformers outperform RNNs.

SECTION 02

Queries, keys, and values

Given a sequence of token embeddings X ∈ ℝ^(n×d) (n tokens, d dimensions), self-attention learns three weight matrices: W_Q, W_K, W_V ∈ ℝ^(d×d_k).

These project X into three spaces:

import numpy as np

d_model = 512   # embedding dimension
d_k = 64        # key/query dimension (often d_model / num_heads)
n = 10          # sequence length

# Random input embeddings (in practice, from an embedding table + pos encoding)
X = np.random.randn(n, d_model)

# Learned projection matrices (initialised randomly, trained by backprop)
W_Q = np.random.randn(d_model, d_k) * 0.02
W_K = np.random.randn(d_model, d_k) * 0.02
W_V = np.random.randn(d_model, d_k) * 0.02

Q = X @ W_Q   # (n, d_k)  — what each token is looking for
K = X @ W_K   # (n, d_k)  — what each token contains / advertises
V = X @ W_V   # (n, d_k)  — what each token will contribute if attended to

Intuitively: Q is "what I'm looking for", K is "what I have to offer", V is "what I'll give you if you attend to me". The attention score between token i and token j is Q[i] · K[j] — how well token i's query matches token j's key.

SECTION 03

Scaled dot-product attention

The full formula: Attention(Q, K, V) = softmax(QKᵀ / √d_k) · V

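Why divide by √d_k? If the components of a query and a key are independent with zero mean and unit variance, their dot product has variance d_k, so raw scores grow in magnitude as d_k increases. Large scores push softmax into saturated regions where gradients vanish. Dividing by √d_k brings the variance back to ≈ 1.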
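The effect of the scaling can be checked empirically. This is a minimal sketch that assumes query and key components are independent standard normal samples, which trained projections only approximate; the helper name `dot_product_variances` is hypothetical.

```python
import numpy as np

def dot_product_variances(d_k: int, n_samples: int = 10_000, seed: int = 0):
    """Empirical variance of q·k before and after dividing by sqrt(d_k)."""
    rng = np.random.default_rng(seed)
    # Assumption: independent components with zero mean and unit variance
    q = rng.standard_normal((n_samples, d_k))
    k = rng.standard_normal((n_samples, d_k))
    dots = (q * k).sum(axis=-1)
    return dots.var(), (dots / np.sqrt(d_k)).var()

for d_k in (16, 64, 256):
    raw, scaled = dot_product_variances(d_k)
    print(f"d_k={d_k:3d}  raw variance ~ {raw:6.1f}  scaled variance ~ {scaled:.2f}")
```

The raw variance should track d_k, while the scaled variance should stay close to 1 for every d_k.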

def scaled_dot_product_attention(Q, K, V, mask=None):
    d_k = Q.shape[-1]

    # Step 1: compute raw attention scores — (n, n) matrix
    scores = Q @ K.T / np.sqrt(d_k)

    # Step 2: apply mask (for causal or padding masking)
    if mask is not None:
        scores = scores + mask   # mask contains -inf where attention is blocked

    # Step 3: softmax — convert scores to probabilities
    # subtract max for numerical stability
    scores -= scores.max(axis=-1, keepdims=True)
    attn_weights = np.exp(scores)
    attn_weights /= attn_weights.sum(axis=-1, keepdims=True)

    # Step 4: weighted sum of values
    output = attn_weights @ V   # (n, d_k)
    return output, attn_weights

output, weights = scaled_dot_product_attention(Q, K, V)
print(f"Output shape: {output.shape}")   # (10, 64)
print(f"Attention weights sum: {weights.sum(axis=-1)}")  # all 1.0
SECTION 04

NumPy implementation

import numpy as np

class SelfAttention:
    def __init__(self, d_model: int, d_k: int):
        self.d_k = d_k
        # Xavier initialisation
        scale = np.sqrt(2.0 / (d_model + d_k))
        self.W_Q = np.random.randn(d_model, d_k) * scale
        self.W_K = np.random.randn(d_model, d_k) * scale
        self.W_V = np.random.randn(d_model, d_k) * scale
        self.W_O = np.random.randn(d_k, d_model) * scale  # output projection

    def forward(self, X: np.ndarray, mask=None):
        # X: (batch, seq_len, d_model)
        Q = X @ self.W_Q   # (batch, seq, d_k)
        K = X @ self.W_K
        V = X @ self.W_V

        # Attention scores: (batch, seq, seq)
        scores = (Q @ K.transpose(0, 2, 1)) / np.sqrt(self.d_k)

        if mask is not None:
            scores += mask

        # Softmax
        scores -= scores.max(axis=-1, keepdims=True)
        weights = np.exp(scores)
        weights /= weights.sum(axis=-1, keepdims=True)

        # Weighted values + output projection
        context = weights @ V   # (batch, seq, d_k)
        output = context @ self.W_O  # (batch, seq, d_model)
        return output, weights

# Test
batch, seq, d_model, d_k = 2, 8, 64, 32
X = np.random.randn(batch, seq, d_model)
attn = SelfAttention(d_model, d_k)
out, w = attn.forward(X)
print(f"Output: {out.shape}")  # (2, 8, 64)
print(f"Weights row sum: {w[0, 0].sum():.4f}")  # 1.0
SECTION 05

Causal masking

In decoder-only models (GPT, Llama), token i should only attend to tokens 0..i — not future tokens. This is enforced with a causal mask: a triangular matrix of -∞ values in the upper triangle.

def causal_mask(seq_len: int) -> np.ndarray:
    # Upper triangle (excluding diagonal) = -inf, lower triangle = 0
    mask = np.triu(np.full((seq_len, seq_len), -np.inf), k=1)
    return mask  # shape: (seq_len, seq_len)

# Visualise for seq_len=4
mask = causal_mask(4)
print(mask)
# [[  0. -inf -inf -inf]
#  [  0.   0. -inf -inf]
#  [  0.   0.   0. -inf]
#  [  0.   0.   0.   0.]]

# After adding to scores and softmax:
# Token 0 attends only to itself
# Token 1 attends to tokens 0 and 1
# Token 3 attends to all 4 tokens

import torch
import torch.nn.functional as F

def causal_attention(Q, K, V):
    seq = Q.shape[-2]
    # Build the mask on the same device and dtype as Q to avoid mismatches
    mask = torch.triu(
        torch.full((seq, seq), float('-inf'), device=Q.device, dtype=Q.dtype),
        diagonal=1,
    )
    scores = (Q @ K.transpose(-2, -1)) / Q.shape[-1] ** 0.5
    return F.softmax(scores + mask, dim=-1) @ V

# PyTorch also has built-in causal attention:
# F.scaled_dot_product_attention(Q, K, V, is_causal=True)
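
One way to sanity-check a causal mask is to change the later tokens and confirm that the outputs at earlier positions do not move. The sketch below uses NumPy only; `causal_self_attention` is a hypothetical helper written for this check, not part of any library.

```python
import numpy as np

def causal_self_attention(Q, K, V):
    """Scaled dot-product attention with an additive causal mask (NumPy sketch)."""
    n, d_k = Q.shape
    mask = np.triu(np.full((n, n), -np.inf), k=1)
    scores = Q @ K.T / np.sqrt(d_k) + mask
    scores -= scores.max(axis=-1, keepdims=True)   # diagonal is never masked, so max is finite
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((6, 8)) for _ in range(3))
baseline = causal_self_attention(Q, K, V)

# Overwrite the keys and values of the last two tokens
K2, V2 = K.copy(), V.copy()
K2[4:] = rng.standard_normal((2, 8))
V2[4:] = rng.standard_normal((2, 8))
changed = causal_self_attention(Q, K2, V2)

print(np.allclose(baseline[:4], changed[:4]))   # True: positions 0..3 never see tokens 4..5
print(np.allclose(baseline[4:], changed[4:]))   # False: positions 4..5 do
```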
SECTION 06

Computational cost

Self-attention has O(n²·d) time and O(n²) memory complexity, because the attention matrix is n×n. For n=512 this is fine; for n=32768 (32K context) it's 32768² ≈ 1.07 billion entries per layer per head.
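The arithmetic behind these numbers, as a small sketch. It assumes 4 bytes per entry (float32); the helper name `attention_matrix_bytes` is hypothetical.

```python
def attention_matrix_bytes(seq_len: int, bytes_per_entry: int = 4) -> int:
    """Bytes needed to store one (seq_len x seq_len) attention matrix."""
    return seq_len * seq_len * bytes_per_entry

for n in (512, 4096, 32768):
    entries = n * n
    gib = attention_matrix_bytes(n) / 2**30
    print(f"n={n:6d}  entries={entries:>13,}  float32 size={gib:8.3f} GiB")
```

Doubling the sequence length quadruples the size of the matrix, which is what makes the 32K case so much heavier than the 512 case.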

This quadratic scaling is why long-context models are expensive, and why alternatives like Flash Attention (IO-aware tiling), linear attention, and SSMs (Mamba) exist.

import torch

# PyTorch's optimised attention — use this in practice
Q = torch.randn(1, 8, 512, 64)   # (batch, heads, seq, d_k)
K = torch.randn(1, 8, 512, 64)
V = torch.randn(1, 8, 512, 64)

# Flash Attention under the hood when available (requires PyTorch 2.0+)
out = torch.nn.functional.scaled_dot_product_attention(
    Q, K, V,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=True,     # enables causal masking without materialising the mask
)
print(out.shape)  # (1, 8, 512, 64)

# Memory comparison: standard vs Flash Attention
# Standard: stores full n×n attention matrix in HBM — O(n²) memory
# Flash Attention: tiles computation, never stores full matrix — O(n) memory
# For n=32768 in float32: standard = 4 GiB per head (n² × 4 bytes); Flash Attention = ~32MB per head
SECTION 07

Gotchas

Softmax in float16 overflows. QKᵀ values can be large; in float16 (max ~65504), exp(score) overflows to inf once a score exceeds about 11, and the normalisation then produces NaN (inf / inf). Always subtract the max before softmax (numerically stable softmax) and consider running attention in float32 even when the rest of the model is float16. PyTorch's scaled_dot_product_attention handles this correctly.
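A small NumPy demonstration of the failure and the fix, using hand-picked scores chosen so that the largest one sits just past the float16 overflow threshold:

```python
import numpy as np

scores = np.array([12.0, 11.0, 3.0], dtype=np.float16)

# Naive softmax: exp(12) is about 162755, above the float16 maximum of 65504
with np.errstate(over="ignore", invalid="ignore"):
    naive = np.exp(scores)        # first entry overflows to inf
    naive = naive / naive.sum()   # inf / inf gives nan

# Stable softmax: after subtracting the max, every exponent is <= 0
shifted = scores - scores.max()
stable = np.exp(shifted)
stable = stable / stable.sum()

print(naive)    # contains nan
print(stable)   # finite, sums to roughly 1 at float16 precision
```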

d_model must be divisible by num_heads. In multi-head attention, d_k is typically d_model / num_heads. If d_model=512 and num_heads=6, d_k ≈ 85.3, which is not an integer. Choose a num_heads that evenly divides d_model (8 or 16 for d_model=512).
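A sketch of how the head split is commonly done, with the divisibility check made explicit. `split_heads` is a hypothetical helper, and the (num_heads, seq_len, d_k) layout is one convention among several.

```python
import numpy as np

def split_heads(X: np.ndarray, num_heads: int) -> np.ndarray:
    """Reshape (seq_len, d_model) into (num_heads, seq_len, d_k)."""
    seq_len, d_model = X.shape
    if d_model % num_heads != 0:
        raise ValueError(
            f"d_model={d_model} is not divisible by num_heads={num_heads}"
        )
    d_k = d_model // num_heads
    # Each head receives a contiguous slice of d_k features per token
    return X.reshape(seq_len, num_heads, d_k).transpose(1, 0, 2)

X = np.zeros((10, 512))
print(split_heads(X, 8).shape)    # (8, 10, 64)

try:
    split_heads(X, 6)
except ValueError as err:
    print(err)
```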

Attention is not the only operation. The full (pre-norm) transformer block is: LayerNorm → Attention → residual add → LayerNorm → FFN → residual add. The FFN (2 linear layers with an activation in between) typically has a hidden dimension of 4× d_model and accounts for ~67% of a block's weight parameters (8·d_model² for the FFN versus 4·d_model² for the attention projections). Don't optimise attention in isolation.
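The parameter split can be worked out directly. This sketch counts only the weight matrices, ignores biases and LayerNorm parameters, and assumes standard multi-head attention with a 4× FFN expansion; `block_param_counts` is a hypothetical helper.

```python
def block_param_counts(d_model: int, ffn_multiplier: int = 4) -> dict:
    """Weight-matrix parameter counts for one transformer block."""
    attention = 4 * d_model * d_model                # W_Q, W_K, W_V, W_O
    ffn = 2 * d_model * (ffn_multiplier * d_model)   # up-projection and down-projection
    return {
        "attention": attention,
        "ffn": ffn,
        "ffn_share": ffn / (attention + ffn),
    }

counts = block_param_counts(512)
print(counts)
print(f"FFN share of block parameters: {counts['ffn_share']:.1%}")
```

Under these assumptions the share is 8/12 regardless of d_model, which is where the two-thirds figure comes from.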

SECTION 08

Attention Variant Comparison

Variant | KV Heads | Memory | Quality | Used In
Multi-Head Attention (MHA) | = query heads | High | Highest | GPT-2, BERT
Multi-Query Attention (MQA) | 1 shared KV head | Very low | Slightly lower | PaLM, Falcon
Grouped-Query Attention (GQA) | H/G groups | Medium | Close to MHA | Llama 2/3, Mistral, Gemma
Sliding Window Attention | = query heads (windowed) | Low at long ctx | Good for local patterns | Mistral, Longformer

Grouped-Query Attention (GQA) has become the dominant choice in modern open-weight models because it achieves near-MHA quality with much lower KV cache memory requirements. For a 7B model with 32 query heads and 8 KV groups, the KV cache is 4x smaller than with MHA, directly reducing memory pressure during inference and enabling larger batch sizes or longer context windows at the same VRAM budget. When comparing model architectures for deployment, always check the number of KV heads: for long-context workloads it can affect serving cost more than the total parameter count does.
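The 4x figure follows from the KV cache size formula. In the sketch below, only the 32 query heads and 8 KV groups come from the text; the layer count (32), head dimension (128), float16 storage, and 4096-token context are assumed values for illustration, and `kv_cache_bytes` is a hypothetical helper.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len,
                   batch_size=1, bytes_per_value=2):
    """Bytes needed to cache keys and values for every layer."""
    # Leading factor of 2: one cached tensor for keys and one for values
    return (2 * num_layers * num_kv_heads * head_dim
            * seq_len * batch_size * bytes_per_value)

# Assumed model shape, for illustration only
mha = kv_cache_bytes(num_layers=32, num_kv_heads=32, head_dim=128, seq_len=4096)
gqa = kv_cache_bytes(num_layers=32, num_kv_heads=8, head_dim=128, seq_len=4096)

print(f"MHA cache: {mha / 2**30:.2f} GiB")
print(f"GQA cache: {gqa / 2**30:.2f} GiB")
print(f"Reduction: {mha // gqa}x")
```

The reduction factor depends only on the ratio of query heads to KV heads, so it holds for any sequence length or batch size under this formula.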

Self-attention's O(n²) cost in sequence length was the central bottleneck in scaling transformers to long contexts. Flash Attention does not change the O(n²) arithmetic; it restructures the computation into SRAM-friendly tiles so that the full n×n matrix is never written to HBM, dramatically reducing the number of HBM reads and writes. The mathematical result is identical to standard attention, but the memory access pattern is far more cache-efficient, giving both faster execution and lower peak memory usage.
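The idea that makes tiling possible is an online softmax: keys and values are processed block by block while a running maximum and running denominator are kept per query. The NumPy sketch below illustrates only that idea. It is not the Flash Attention kernel, it tiles over keys only, and the function names and `block_size` are hypothetical.

```python
import numpy as np

def standard_attention(Q, K, V):
    """Reference implementation that materialises the full (n, n) matrix."""
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)
    return (w / w.sum(axis=-1, keepdims=True)) @ V

def tiled_attention(Q, K, V, block_size=4):
    """Process keys/values one block at a time with a running softmax."""
    n, d_k = Q.shape
    running_max = np.full((n, 1), -np.inf)   # largest score seen so far, per query
    running_sum = np.zeros((n, 1))           # softmax denominator so far
    acc = np.zeros((n, V.shape[-1]))         # unnormalised weighted sum of values

    for start in range(0, K.shape[0], block_size):
        K_blk = K[start:start + block_size]
        V_blk = V[start:start + block_size]
        scores = Q @ K_blk.T / np.sqrt(d_k)          # only an (n, block) slice
        new_max = np.maximum(running_max, scores.max(axis=-1, keepdims=True))
        rescale = np.exp(running_max - new_max)      # re-base earlier blocks
        p = np.exp(scores - new_max)
        running_sum = running_sum * rescale + p.sum(axis=-1, keepdims=True)
        acc = acc * rescale + p @ V_blk
        running_max = new_max

    return acc / running_sum

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((10, 8)) for _ in range(3))
print(np.allclose(standard_attention(Q, K, V), tiled_attention(Q, K, V)))   # True
```

The two functions agree up to floating-point rounding even though the tiled version never holds more than one block of scores at a time.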

Relative position encodings (like RoPE and ALiBi) were developed to give self-attention a better inductive bias for position than absolute sinusoidal encodings. RoPE rotates query and key vectors by position-dependent angles so that their dot product depends only on the relative distance between the two positions, which can help models generalize to sequence lengths not seen during training. This property has made RoPE the de facto standard in modern open-weight models like Llama and Mistral.
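The relative-distance property of RoPE can be checked numerically. This is a minimal sketch that rotates interleaved pairs of dimensions with the commonly used base of 10000; real implementations differ in how they pair dimensions, and the `rope` helper here is hypothetical.

```python
import numpy as np

def rope(x: np.ndarray, position: int, base: float = 10000.0) -> np.ndarray:
    """Rotate consecutive pairs of dimensions of a 1-d vector by position-dependent angles."""
    d = x.shape[-1]
    freqs = base ** (-np.arange(0, d, 2) / d)   # one frequency per pair of dimensions
    angles = position * freqs
    cos, sin = np.cos(angles), np.sin(angles)
    x_even, x_odd = x[0::2], x[1::2]
    out = np.empty_like(x)
    out[0::2] = x_even * cos - x_odd * sin
    out[1::2] = x_even * sin + x_odd * cos
    return out

rng = np.random.default_rng(0)
q, k = rng.standard_normal(8), rng.standard_normal(8)

# Same relative offset (3) at two very different absolute positions
score_a = rope(q, 5) @ rope(k, 2)
score_b = rope(q, 105) @ rope(k, 102)
print(np.isclose(score_a, score_b))   # True
```

Shifting both positions by the same amount leaves the score unchanged, because the two rotations combine into a single rotation by the difference of the positions.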

Multi-head attention's practical benefit over single-head attention is that different heads specialize in capturing different types of relationships simultaneously. Empirical studies show that some heads consistently attend to syntactic structure (subject-verb agreement), others to coreference (tracking which "it" refers to), and still others to local phrase patterns. This specialization emerges from training without explicit supervision, and pruning attention heads that have near-zero contribution can reduce model size with minimal quality loss.