Attention Mechanisms

Multi-Head Attention

Multi-head attention runs h attention heads in parallel, each with its own learned projections, so that different heads can specialise in different relationship types (syntax, coreference, position). The head outputs are concatenated and projected back to d_model.

At a glance: h parallel heads · per-head dimension d_model/h · 4 weight matrices (W_Q, W_K, W_V, W_O)


SECTION 01

Why multiple heads

A single attention head computes one weighted combination of values per query. But language has several relationship structures at once: word A might be the syntactic subject of B, co-referential with C, and positionally adjacent to D. Because a single head produces only one attention distribution per query, it has to blend all of these relationships into one weighted average instead of representing them separately.

Multi-head attention runs h independent attention functions in parallel, each with its own learned Q, K, V projections. Head 1 might learn to track syntactic dependencies; head 2 might track coreference; head 3 might attend to nearby tokens. Each head projects into a lower-dimensional subspace (d_k = d_model/h), so the total computation cost is similar to one full-dimensional head.
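A rough cost count makes the "similar to one full-dimensional head" claim concrete. The sketch below counts the Q/K/V projection weights and the multiplications needed for the QKᵀ scores. It assumes bias-free projections and ignores the softmax and W_O; the function name is illustrative.

```python
def attention_costs(d_model: int, num_heads: int, seq_len: int) -> dict:
    """Rough cost of the Q/K/V projections and QK^T scores (biases, softmax, W_O ignored)."""
    assert d_model % num_heads == 0, "num_heads must divide d_model"
    d_k = d_model // num_heads
    return {
        # h heads, each with three d_model x d_k projection matrices
        "qkv_proj_params": num_heads * 3 * d_model * d_k,
        # h score matrices of size seq_len x seq_len, each entry a d_k-length dot product
        "score_mults": num_heads * seq_len * seq_len * d_k,
    }

one_head = attention_costs(d_model=512, num_heads=1, seq_len=128)
eight_heads = attention_costs(d_model=512, num_heads=8, seq_len=128)
print(one_head == eight_heads)  # True: splitting into heads leaves these totals unchanged
```

What does change is the number of softmaxes (h score matrices of size N × N instead of one), which is why the cost is "similar" rather than identical.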

The outputs of all heads are concatenated into a d_model-dimensional vector and passed through a final linear projection W_O. This projection learns how to combine the diverse information captured by different heads.
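Because W_O acts on the concatenation, it can be split row-wise into h blocks, one per head, and Concat(head_1, …, head_h) W_O equals the sum over heads of head_i multiplied by its own block. The small PyTorch check below (random tensors, illustrative sizes) shows that the output projection is a learned per-head linear map whose results are added together, which is how information from different heads gets mixed.

```python
import torch

torch.manual_seed(0)
h, d_v, d_model, n = 4, 16, 64, 10  # illustrative sizes with h * d_v == d_model

heads = [torch.randn(n, d_v) for _ in range(h)]  # per-head outputs, each (n, d_v)
W_O = torch.randn(h * d_v, d_model)               # output projection, (h*d_v, d_model)

# Standard form: concatenate along the feature axis, then project
concat_then_project = torch.cat(heads, dim=-1) @ W_O

# Equivalent form: project each head with its own row-block of W_O, then sum
blocks = W_O.split(d_v, dim=0)                    # h blocks of shape (d_v, d_model)
sum_of_projections = sum(head @ block for head, block in zip(heads, blocks))

print(torch.allclose(concat_then_project, sum_of_projections, atol=1e-4))  # True
```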

SECTION 02

The MHA architecture

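For h heads and d_model dimensions, each head i uses its own projection matrices:

$$\mathrm{head}_i = \mathrm{Attention}(Q W_i^Q,\ K W_i^K,\ V W_i^V), \qquad W_i^Q, W_i^K \in \mathbb{R}^{d_{\text{model}} \times d_k},\quad W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_v}$$

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V$$

where d_k = d_v = d_model / h. The heads are then combined as

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W_O$$

The output projection W_O ∈ ℝ^(h·d_v × d_model) maps the concatenated head outputs back to d_model.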
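In the implementation below, the per-head matrices are stored side by side in one d_model × d_model linear layer, and `.view(B, N, h, d_k)` splits its output into heads. A minimal sketch, assuming PyTorch's `nn.Linear` layout (weight of shape (out_features, in_features)), showing that head i's projection W_i^Q is simply a d_k-row slice of the stacked weight:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, h = 64, 4
d_k = d_model // h
x = torch.randn(2, 5, d_model)            # (batch, seq_len, d_model)

W_Q = nn.Linear(d_model, d_model, bias=False)

# Stacked form: one matmul, then split the last axis into h heads of size d_k
Q_stacked = W_Q(x).view(2, 5, h, d_k)     # (B, N, h, d_k)

for i in range(h):
    # nn.Linear stores weight as (out_features, in_features), so W_i^Q is the
    # transpose of rows i*d_k .. (i+1)*d_k of the stacked weight
    W_i = W_Q.weight[i * d_k:(i + 1) * d_k].T        # (d_model, d_k)
    Q_i = x @ W_i                                      # (B, N, d_k)
    assert torch.allclose(Q_stacked[:, :, i], Q_i, atol=1e-5)

print("each head's Q equals x @ W_i^Q")
```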

Total parameters per MHA layer: 4 × d_model² weights (three input projections plus one output projection), plus 4 × d_model if the projections carry biases. For d_model = 768 (BERT-base): 4 × 768² = 2,359,296 ≈ 2.4M weights per MHA layer.
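The count follows directly from the four d_model × d_model matrices. A quick check in Python; the helper name is illustrative, and `bias=True` assumes one bias vector of length d_model per projection, as in BERT:

```python
def mha_params(d_model: int, bias: bool = False) -> int:
    """Parameters in one MHA layer: W_Q, W_K, W_V and W_O, each d_model x d_model."""
    weights = 4 * d_model * d_model
    biases = 4 * d_model if bias else 0
    return weights + biases

print(f"{mha_params(768):,}")             # 2,359,296 (BERT-base width, no biases)
print(f"{mha_params(768, bias=True):,}")  # 2,362,368 (with the four bias vectors)
print(f"{mha_params(512):,}")             # 1,048,576 (d_model=512, as in the PyTorch test below)
```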

SECTION 03

Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        assert d_model % num_heads == 0, "num_heads must divide d_model"
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Each layer stacks the per-head projections of all h heads into one d_model x d_model matrix
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        B, N, D = x.shape
        h = self.num_heads

        # Project and split into heads: (B, N, D) -> (B, h, N, d_k)
        Q = self.W_Q(x).view(B, N, h, self.d_k).transpose(1, 2)
        K = self.W_K(x).view(B, N, h, self.d_k).transpose(1, 2)
        V = self.W_V(x).view(B, N, h, self.d_k).transpose(1, 2)

        # Scaled dot-product attention for all heads at once.
        # SDPA boolean mask convention: True = this query may attend to this key.
        out = F.scaled_dot_product_attention(Q, K, V,
            attn_mask=mask, dropout_p=self.dropout.p if self.training else 0.0)

        # Merge heads: (B, h, N, d_k) -> (B, N, D)
        out = out.transpose(1, 2).contiguous().view(B, N, D)
        return self.W_O(out)

# Test
model = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 16, 512)  # (batch, seq_len, d_model)
out = model(x)
print(out.shape)  # (2, 16, 512)
print(f"Params: {sum(p.numel() for p in model.parameters()):,}")  # 1,048,576 = 4 x 512^2 (no biases)
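The `mask` argument above is passed straight to `F.scaled_dot_product_attention`, whose boolean convention is that `True` marks positions that may be attended to. A minimal sketch (random tensors, illustrative sizes) checking that an explicit lower-triangular boolean mask matches the function's `is_causal=True` flag, and that earlier positions never see later tokens:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, h, N, d_k = 2, 8, 16, 64
Q, K, V = (torch.randn(B, h, N, d_k) for _ in range(3))

# SDPA boolean mask: True = "this query may attend to this key"
causal = torch.tril(torch.ones(N, N, dtype=torch.bool))   # (N, N), broadcasts over B and h

out_mask = F.scaled_dot_product_attention(Q, K, V, attn_mask=causal)
out_flag = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
print(torch.allclose(out_mask, out_flag, atol=1e-4))  # True

# Causality check: perturbing the last key/value must not change earlier positions
K2, V2 = K.clone(), V.clone()
K2[:, :, -1] += 1.0
V2[:, :, -1] += 1.0
out2 = F.scaled_dot_product_attention(Q, K2, V2, is_causal=True)
print(torch.allclose(out_flag[:, :, :-1], out2[:, :, :-1], atol=1e-6))  # True
```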
SECTION 04

What heads specialise in

Research on BERT's attention heads (Clark et al. 2019) found clear specialisation:

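Syntactic heads: heads that connect verbs with their direct objects, or that attend to the heads of noun phrases. Individual heads track particular dependency relations with high accuracy, so the attention weights capture a good part of a dependency parse, even though no single head performs full parsing.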

Coreference heads: certain heads attend from pronouns to their antecedents, as in the "it → animal" example. These heads track entity co-reference, including across sentences.

Positional heads: heads that primarily attend to the next or previous token, implementing something like a local sliding window. Common in early layers.

Separator heads: heads that almost always attend to the [SEP] token. Clark et al. interpret this as a "no-op": when a head has nothing useful to contribute for a token, it can park its attention on [SEP].

Not all heads are equally useful: up to 20% of heads can be pruned from BERT with minimal quality loss. Clearly specialised heads tend to survive pruning longest, while heads with diffuse attention patterns are usually the first to go.

SECTION 05

GQA and MQA

Standard MHA uses h distinct K and V projections, one per head. During autoregressive inference, the keys and values of every previously generated token are kept in a KV cache, which grows as O(n × h × 2 × d_k) per layer for n cached tokens; this becomes expensive for long contexts.

Multi-Query Attention (MQA): all heads share a single K and V projection; only Q has h separate projections. This reduces the KV cache by a factor of h, but can hurt quality on some tasks.

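Grouped-Query Attention (GQA): a compromise between MHA and MQA. The h query heads are split into g groups, and the heads in each group share one K/V head, which shrinks the KV cache by a factor of h/g. Llama 3 70B uses GQA with g = 8 KV heads for its 64 query heads, giving an 8× KV cache reduction vs full MHA with minimal quality loss.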
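The cache arithmetic can be made concrete. The sketch below assumes fp16/bf16 storage (2 bytes per element) and a Llama-3-70B-like shape of 80 layers, 64 query heads and d_head = 128. These configuration values are assumptions for illustration, so check them against the model's published config before relying on the numbers.

```python
def kv_cache_bytes(num_layers, num_kv_heads, d_head, seq_len, batch=1, bytes_per_elem=2):
    """K and V caches: 2 tensors x layers x kv_heads x d_head elements per token."""
    per_token = 2 * num_layers * num_kv_heads * d_head * bytes_per_elem
    return per_token * seq_len * batch

# Assumed Llama-3-70B-like shape: 80 layers, 64 query heads, d_head = 128
layers, q_heads, d_head, ctx = 80, 64, 128, 8192

for name, kv_heads in [("MHA", 64), ("GQA (8 KV heads)", 8), ("MQA", 1)]:
    gib = kv_cache_bytes(layers, kv_heads, d_head, ctx) / 2**30
    print(f"{name:>16}: {gib:6.2f} GiB per 8,192-token sequence, {q_heads // kv_heads}x smaller than MHA")
```

Under these assumptions the cache is 20 GiB per sequence for MHA, 2.5 GiB for GQA and about 0.31 GiB for MQA, which is why GQA allows many more concurrent sequences on the same GPU.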

class GQA(nn.Module):
    def __init__(self, d_model, num_heads, num_kv_heads):
        super().__init__()
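        assert d_model % num_heads == 0, "num_heads must divide d_model"
        assert num_heads % num_kv_heads == 0, "num_kv_heads must divide num_heads"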
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.d_k = d_model // num_heads
        self.groups = num_heads // num_kv_heads

        self.W_Q = nn.Linear(d_model, num_heads * self.d_k, bias=False)
        self.W_K = nn.Linear(d_model, num_kv_heads * self.d_k, bias=False)
        self.W_V = nn.Linear(d_model, num_kv_heads * self.d_k, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        B, N, D = x.shape
        Q = self.W_Q(x).view(B, N, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_K(x).view(B, N, self.num_kv_heads, self.d_k).transpose(1, 2)
        V = self.W_V(x).view(B, N, self.num_kv_heads, self.d_k).transpose(1, 2)
        # Repeat K/V for each group of Q heads
        K = K.repeat_interleave(self.groups, dim=1)
        V = V.repeat_interleave(self.groups, dim=1)
        out = F.scaled_dot_product_attention(Q, K, V)
        return self.W_O(out.transpose(1,2).contiguous().view(B, N, D))
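With `repeat_interleave` along the head axis, query head i reads from KV head i // groups, so consecutive query heads share a KV head. A small sketch with illustrative sizes (8 query heads, 2 KV heads) that checks this mapping:

```python
import torch

num_heads, num_kv_heads = 8, 2
groups = num_heads // num_kv_heads                 # 4 query heads per KV head
B, N, d_k = 1, 3, 4

K = torch.randn(B, num_kv_heads, N, d_k)           # (B, kv_heads, N, d_k)
K_expanded = K.repeat_interleave(groups, dim=1)    # (B, num_heads, N, d_k)

for q_head in range(num_heads):
    kv_head = q_head // groups                     # heads 0-3 -> KV 0, heads 4-7 -> KV 1
    assert torch.equal(K_expanded[:, q_head], K[:, kv_head])

print(K_expanded.shape)  # torch.Size([1, 8, 3, 4])
```

Note that `repeat_interleave` materialises the copies. That keeps the class simple, and the cache itself still stores only `num_kv_heads` heads, but fused kernels that broadcast shared K/V heads avoid the extra copy during attention.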
SECTION 06

PyTorch built-in

import torch
import torch.nn as nn

# PyTorch's nn.MultiheadAttention: production-ready implementation
mha = nn.MultiheadAttention(
    embed_dim=512,
    num_heads=8,
    dropout=0.1,
    bias=True,
    batch_first=True,   # input shape: (batch, seq, features) instead of (seq, batch, features)
)

x = torch.randn(2, 16, 512)

# Self-attention (query = key = value = x)
output, attn_weights = mha(x, x, x)
print(output.shape)        # (2, 16, 512)
print(attn_weights.shape)  # (2, 16, 16), averaged over heads by default

# With causal mask (for decoder): a float mask with -inf above the diagonal
causal_mask = nn.Transformer.generate_square_subsequent_mask(16)
output, _ = mha(x, x, x, attn_mask=causal_mask)

# Cross-attention (query from decoder, key/value from encoder); a real model would use a separate module
enc_out = torch.randn(2, 32, 512)  # encoder output, different seq length
dec_in  = torch.randn(2, 16, 512)  # decoder input
output, _ = mha(dec_in, enc_out, enc_out)  # Q from dec, K/V from enc
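One convention to watch: for `nn.MultiheadAttention`, a boolean `attn_mask` uses `True` to mean "not allowed to attend", the opposite of `F.scaled_dot_product_attention`. The float mask from `generate_square_subsequent_mask` puts `-inf` above the diagonal. A sketch, with dropout disabled via `eval()` so the outputs are deterministic, checking that the two mask forms agree:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N = 16
mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, dropout=0.1, batch_first=True)
mha.eval()  # disable dropout so the two calls are comparable
x = torch.randn(2, N, 512)

float_mask = nn.Transformer.generate_square_subsequent_mask(N)           # 0.0 / -inf
bool_mask = torch.triu(torch.ones(N, N, dtype=torch.bool), diagonal=1)  # True = blocked

with torch.no_grad():
    out_float, _ = mha(x, x, x, attn_mask=float_mask)
    out_bool, _ = mha(x, x, x, attn_mask=bool_mask)

print(torch.allclose(out_float, out_bool, atol=1e-5))  # True
```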
SECTION 07

Gotchas

num_heads must divide d_model evenly. d_k = d_model / num_heads must be an integer. A common mistake: using num_heads=6 with d_model=512 gives d_k=85.3. Standard choices: 8 heads with d_model=512 (d_k=64), 12 heads with d_model=768 (d_k=64), 32 heads with d_model=4096 (d_k=128).
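A small helper (hypothetical, not part of any library) that fails fast on an invalid configuration and otherwise returns d_k, applied to the configurations mentioned above:

```python
def head_dim(d_model: int, num_heads: int) -> int:
    """Return d_k = d_model / num_heads, raising if num_heads does not divide d_model."""
    if d_model % num_heads != 0:
        raise ValueError(f"num_heads={num_heads} does not divide d_model={d_model} "
                         f"(d_model / num_heads = {d_model / num_heads:.2f})")
    return d_model // num_heads

for d_model, num_heads in [(512, 8), (768, 12), (4096, 32)]:
    print(d_model, num_heads, head_dim(d_model, num_heads))  # d_k = 64, 64, 128

try:
    head_dim(512, 6)
except ValueError as err:
    print(err)  # num_heads=6 does not divide d_model=512 (d_model / num_heads = 85.33)
```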

W_O is often forgotten. The output projection W_O is essential: it is the only place inside the attention layer where the different heads' outputs are mixed. Implementing MHA without W_O (just concatenating and returning) leaves each output dimension tied to a single head until some later layer mixes them, removes a learned projection that standard Transformer architectures rely on, and can degrade results.

Attention dropout is applied to weights, not values. Dropout in attention zeroes out individual attention edges (single query-key entries of the weight matrix) and rescales the surviving weights by 1/(1 − p), forcing the model to route information through different paths. It is applied after the softmax and before the weighted sum over V, not to the layer's output.
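Written out by hand, attention dropout sits between the softmax and the multiplication by V. A minimal sketch (the function name is illustrative) that, with dropout disabled, reproduces `F.scaled_dot_product_attention`:

```python
import math
import torch
import torch.nn.functional as F

def attention_with_dropout(Q, K, V, dropout_p: float, training: bool):
    """Scaled dot-product attention with dropout applied to the attention weights."""
    scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))   # (..., N_q, N_k)
    weights = scores.softmax(dim=-1)                            # each row sums to 1
    # Zeroes individual (query, key) entries and rescales survivors by 1 / (1 - p)
    weights = F.dropout(weights, p=dropout_p, training=training)
    return weights @ V                                          # weighted sum of values

torch.manual_seed(0)
Q, K, V = (torch.randn(2, 8, 16, 64) for _ in range(3))

# With dropout off (evaluation), this matches PyTorch's fused implementation
ref = F.scaled_dot_product_attention(Q, K, V)
manual = attention_with_dropout(Q, K, V, dropout_p=0.1, training=False)
print(torch.allclose(manual, ref, atol=1e-4))  # True
```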

Multi-Head Attention Variants Compared

Multi-head attention (MHA) splits the model dimension into H parallel attention heads, allowing the model to jointly attend to information from different representation subspaces. Variants like Multi-Query Attention (MQA) and Grouped-Query Attention (GQA) reduce the KV cache memory burden, a critical bottleneck during inference, by sharing key and value projections across multiple query heads.

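| Variant | KV heads | KV cache size | Quality vs MHA | Used in |
|---|---|---|---|---|
| MHA | = query heads (H) | Full | Baseline | GPT-2, BERT |
| MQA | 1 | 1/H × Full | Slight degradation | PaLM, Falcon |
| GQA | G groups (2–8) | G/H × Full | Near MHA | Llama 2 (70B), Llama 3, Mistral |
| MLA | Low-rank latent | Very small | Near MHA | DeepSeek-V2/V3 |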

Grouped-Query Attention is widely used as a practical balance between quality and efficiency. By letting each group of query heads share a single K/V head, GQA reduces the KV cache by a factor of H/G (the number of query heads per group) while retaining most of the expressive power of full MHA. At inference time, the smaller KV cache means more sequences fit in GPU memory simultaneously, directly improving throughput for batched serving workloads.

Multi-head Latent Attention (MLA), introduced in DeepSeek-V2, takes a different approach: it compresses the KV cache by projecting keys and values into a low-rank latent space, then reconstructing them at attention time. This achieves even smaller KV cache footprints than GQA while maintaining high model quality, at the cost of slightly higher per-token computation during the reconstruction step.
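The core idea of low-rank KV compression can be sketched in a few lines. This is a simplified illustration, not DeepSeek's implementation: it omits the decoupled rotary-position keys and the query compression that MLA also uses, and the names (`W_down`, `W_up_k`, `W_up_v`) and sizes are hypothetical. Only the small latent `c` would be cached per token.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, num_heads, d_head, d_latent = 512, 8, 64, 64   # hypothetical sizes
B, N = 2, 16
x = torch.randn(B, N, d_model)

W_down = nn.Linear(d_model, d_latent, bias=False)             # compress hidden state to latent
W_up_k = nn.Linear(d_latent, num_heads * d_head, bias=False)  # rebuild keys from latent
W_up_v = nn.Linear(d_latent, num_heads * d_head, bias=False)  # rebuild values from latent

c = W_down(x)                                  # (B, N, d_latent): the only thing cached
K = W_up_k(c).view(B, N, num_heads, d_head)    # keys reconstructed at attention time
V = W_up_v(c).view(B, N, num_heads, d_head)    # values reconstructed at attention time

per_token_mha = 2 * num_heads * d_head         # K and V entries cached per token in MHA
per_token_latent = d_latent                    # latent entries cached per token here
print(per_token_mha, per_token_latent)         # 1024 64 -> 16x smaller cache in this sketch
```

The trade-off is visible in the shapes: every head's keys and values are now constrained to a d_latent-dimensional subspace, and the up-projections add work at attention time.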

The number of attention heads is typically set to H = d_model / d_head, with d_head usually fixed at 64 or 128 dimensions per head. Smaller d_head values (64) give more heads for a given width and are sometimes reported to train more stably, while larger d_head values (128) reduce the number of heads needed for the same model dimension, lowering the per-head overhead of splitting and concatenating. Llama 3, for example, uses d_head = 128, with 32 heads at 8B and 64 heads at 70B.

Flash Attention implementation details interact with multi-head variants in important ways. For GQA, the FlashAttention-2 kernel supports shared K/V heads natively: it reads the same key and value tensors for every query head in a group instead of materializing separate copies. As a result, GQA saves more than KV cache memory at inference time; it also reduces the memory bandwidth spent on K and V during training, which can raise training throughput relative to full MHA. Speed-ups of roughly 15–25% have been cited, though the actual gain depends on model shape, sequence length and hardware.

Attention head pruning studies consistently find that a significant fraction of attention heads in trained models contribute negligibly to output quality and can be removed with minimal degradation. Structured pruning methods that target entire heads (rather than individual weights) produce models that are both smaller and faster, since removing heads directly shrinks the Q, K, V projection matrices and the matching input columns of W_O. Models compressed through head pruning often require brief continued pre-training to recover any quality loss.
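A minimal sketch of structured head pruning, assuming the stacked layout of the `MultiHeadAttention` class from the implementation section (bias-free projections in `nn.Linear` layout, head i occupying feature slice i·d_k to (i+1)·d_k). To stay self-contained it works on raw weight tensors, and the helper names are hypothetical. It removes one head by slicing rows out of W_Q, W_K, W_V and the matching columns out of W_O, then checks that the smaller layer reproduces the full layer with that head's output zeroed.

```python
import torch
import torch.nn.functional as F

def mha(x, W_Q, W_K, W_V, W_O, d_k, head_mask=None):
    """Bias-free MHA on raw weights (nn.Linear layout: (out_features, in_features))."""
    B, N, _ = x.shape
    h = W_Q.shape[0] // d_k

    def split(t):                                   # (B, N, h*d_k) -> (B, h, N, d_k)
        return t.view(B, N, h, d_k).transpose(1, 2)

    out = F.scaled_dot_product_attention(split(x @ W_Q.T), split(x @ W_K.T), split(x @ W_V.T))
    if head_mask is not None:                       # (h,) tensor of 0/1 per head
        out = out * head_mask.view(1, h, 1, 1)
    return out.transpose(1, 2).reshape(B, N, h * d_k) @ W_O.T

def prune_heads(W_Q, W_K, W_V, W_O, d_k, keep):
    """Keep only heads in `keep`: slice output rows of Q/K/V and input columns of W_O."""
    idx = torch.cat([torch.arange(i * d_k, (i + 1) * d_k) for i in keep])
    return W_Q[idx], W_K[idx], W_V[idx], W_O[:, idx]

torch.manual_seed(0)
d_model, h = 64, 4
d_k = d_model // h
W_Q, W_K, W_V, W_O = (torch.randn(d_model, d_model) / d_model**0.5 for _ in range(4))
x = torch.randn(2, 10, d_model)

keep = [0, 1, 3]                                    # drop head 2
masked = mha(x, W_Q, W_K, W_V, W_O, d_k, head_mask=torch.tensor([1.0, 1.0, 0.0, 1.0]))
pruned_weights = prune_heads(W_Q, W_K, W_V, W_O, d_k, keep)
pruned = mha(x, *pruned_weights, d_k)
print(torch.allclose(masked, pruned, atol=1e-5))   # True
```

The pruned layer has three quarters of the projection weights yet produces the same output as the masked full layer, which is why head-level pruning translates directly into speed and memory savings.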