Transformer Architecture

Grouped Query Attention

GQA reduces KV-cache memory by sharing key-value heads across multiple query heads. Llama 3, Mistral, and Gemma all use GQA; it's the default attention pattern for modern efficient LLMs.

- MHA → GQA: 32 query heads share 8 KV heads
- 4–8× smaller KV cache
- Same model quality
- Used by Llama 3 and Mistral
- The production default


SECTION 01

MHA vs MQA vs GQA

Three variants of multi-head attention differ in how many key-value (KV) head sets they maintain:

- MHA (multi-head attention): every query head has its own K and V head (H_kv = H).
- MQA (multi-query attention): all query heads share a single K/V head (H_kv = 1).
- GQA (grouped query attention): query heads are partitioned into groups, each group sharing one K/V head (1 < H_kv < H).
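As a minimal sketch, the per-token KV-cache cost of each variant follows directly from its KV head count (the head counts below are Llama-3-8B-style values, assumed for illustration):

```python
# Per-token, per-layer KV-cache elements for each attention variant.
# The factor of 2 counts both K and V; 32 query heads and d_head = 128
# are illustrative Llama-3-8B-like assumptions.

def kv_elements_per_token(n_kv_heads: int, d_head: int) -> int:
    """Elements cached per token per layer: K and V, each n_kv_heads x d_head."""
    return 2 * n_kv_heads * d_head

n_heads, d_head = 32, 128
mha = kv_elements_per_token(n_heads, d_head)  # one KV head per query head
gqa = kv_elements_per_token(8, d_head)        # 8 shared KV heads
mqa = kv_elements_per_token(1, d_head)        # single shared KV head

print(mha, gqa, mqa)  # 8192 2048 256
```

The ratio MHA : GQA : MQA (32 : 8 : 1 here) is exactly the KV-head ratio, which is why the cache savings track the grouping factor.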

SECTION 02

Why KV cache size matters

During autoregressive generation, you cache the K and V tensors for every previously generated token to avoid recomputing them. For a model with H_kv key-value heads of dimension d_head and context length L, the KV cache size per layer is:

2 Γ— L Γ— d_head Γ— H_kv Γ— bytes_per_element

For Llama 3 70B (bf16) with 80 layers, d_head = 128, and 8 KV heads, the total across all layers at context length 8192 is:

2 × 8192 × 128 × 8 × 80 × 2 bytes ≈ 2.7 GB per sequence

With full MHA (64 KV heads) this would be ≈ 21.5 GB per sequence, which on top of roughly 140 GB of weights makes batched long-context serving impractical. GQA makes long-context 70B inference feasible on 2–4 GPUs.

SECTION 03

GQA implementation

import math
import torch
import torch.nn as nn

class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, n_kv_heads: int):
        super().__init__()
        assert n_heads % n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.n_rep = n_heads // n_kv_heads   # query heads per KV group
        self.d_head = d_model // n_heads

        self.W_q = nn.Linear(d_model, n_heads * self.d_head, bias=False)
        self.W_k = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
        self.W_v = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, mask=None):
        B, T, _ = x.shape
        Q = self.W_q(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        K = self.W_k(x).view(B, T, self.n_kv_heads, self.d_head).transpose(1, 2)
        V = self.W_v(x).view(B, T, self.n_kv_heads, self.d_head).transpose(1, 2)

        # Repeat KV heads to match query head count
        K = K.repeat_interleave(self.n_rep, dim=1)   # (B, n_heads, T, d_head)
        V = V.repeat_interleave(self.n_rep, dim=1)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_head)
        if mask is not None:
            # Mask *before* softmax so excluded positions get zero weight
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attn = torch.softmax(scores, dim=-1)

        out = torch.matmul(attn, V).transpose(1, 2).contiguous().view(B, T, -1)
        return self.W_o(out)
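As a self-contained sanity check (a sketch, separate from the module above), repeating the KV heads with repeat_interleave is numerically equivalent to folding the query heads into groups and broadcasting against the un-repeated K and V:

```python
import torch

torch.manual_seed(0)
B, T, n_heads, n_kv, d = 2, 5, 8, 2, 16
rep = n_heads // n_kv  # query heads per KV group

Q = torch.randn(B, n_heads, T, d)
K = torch.randn(B, n_kv, T, d)
V = torch.randn(B, n_kv, T, d)

# Path 1: repeat KV heads up to n_heads, then standard attention.
K_r = K.repeat_interleave(rep, dim=1)
V_r = V.repeat_interleave(rep, dim=1)
attn = torch.softmax(Q @ K_r.transpose(-2, -1) / d ** 0.5, dim=-1)
out1 = attn @ V_r

# Path 2: fold query heads into groups, broadcast against the small K/V.
Qg = Q.view(B, n_kv, rep, T, d)
scores_g = Qg @ K.unsqueeze(2).transpose(-2, -1) / d ** 0.5
attn_g = torch.softmax(scores_g, dim=-1)
out2 = (attn_g @ V.unsqueeze(2)).view(B, n_heads, T, d)

print(torch.allclose(out1, out2, atol=1e-6))  # True
```

Optimized kernels take the second path: they never materialize the repeated K/V tensors, which is where the memory saving comes from at compute time as well as in the cache.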

SECTION 04

Memory savings calculation

def kv_cache_gb(n_layers, seq_len, n_kv_heads, d_head, bytes_per_elem=2):
    # bytes_per_elem: 2 for bf16/fp16, 4 for fp32
    total_bytes = 2 * n_layers * seq_len * n_kv_heads * d_head * bytes_per_elem
    return total_bytes / 1e9

# Llama 3 70B specs
print("MHA (64 KV heads):", round(kv_cache_gb(80, 8192, 64, 128), 2), "GB")
print("GQA (8 KV heads): ", round(kv_cache_gb(80, 8192,  8, 128), 2), "GB")
# MHA (64 KV heads): 21.47 GB
# GQA (8 KV heads):   2.68 GB

SECTION 05

GQA in Llama 3 and Mistral

Most modern efficient LLMs use GQA. Llama 3 8B uses 8 KV heads and 32 query heads (4:1 ratio). Llama 3 70B uses 8 KV heads and 64 query heads (8:1 ratio). Mistral 7B uses 8 KV heads and 32 query heads. Gemma 2 uses GQA with group sizes that vary by model size.

The configuration is in the model's config.json: look for num_key_value_heads vs num_attention_heads. If they differ, it's GQA (or MQA if num_key_value_heads == 1).
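A minimal sketch of that check, using a hand-written dict in place of a real config.json (the key names follow the Hugging Face convention; the values are Llama-3-8B-like and chosen for illustration):

```python
def attention_variant(config: dict) -> str:
    """Classify a HF-style model config by its head counts."""
    n_heads = config["num_attention_heads"]
    # Older MHA configs may omit the key entirely; treat that as MHA.
    n_kv = config.get("num_key_value_heads", n_heads)
    if n_kv == n_heads:
        return "MHA"
    if n_kv == 1:
        return "MQA"
    return f"GQA ({n_heads // n_kv} query heads per KV head)"

llama3_8b_like = {"num_attention_heads": 32, "num_key_value_heads": 8}
print(attention_variant(llama3_8b_like))  # GQA (4 query heads per KV head)
```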

SECTION 06

Training and conversion

GQA models must be trained (or fine-tuned) with GQA from the start β€” you can't simply remove KV heads from a trained MHA model without significant quality loss. However, the Ainslie et al. (2023) GQA paper showed that you can convert an MHA checkpoint to GQA by mean-pooling the KV heads of each group and then fine-tuning for a fraction of the original training compute (<5%), recovering most of the original quality.
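The mean-pooling step can be sketched on a single projection matrix as follows (a toy illustration, not a full checkpoint converter; the brief uptraining afterwards is what recovers quality):

```python
import torch

def mean_pool_kv(W_k: torch.Tensor, n_heads: int, n_kv_heads: int) -> torch.Tensor:
    """Collapse an MHA key (or value) projection to n_kv_heads by
    mean-pooling the per-head weight blocks within each group.
    W_k has shape (n_heads * d_head, d_model), as in nn.Linear.weight."""
    d_model = W_k.shape[1]
    d_head = W_k.shape[0] // n_heads
    group = n_heads // n_kv_heads
    W = W_k.view(n_kv_heads, group, d_head, d_model)
    return W.mean(dim=1).reshape(n_kv_heads * d_head, d_model)

# 32-head MHA key projection (d_head=128, d_model=4096, assumed shapes)
W = torch.randn(32 * 128, 4096)
W_gqa = mean_pool_kv(W, n_heads=32, n_kv_heads=8)
print(W_gqa.shape)  # torch.Size([1024, 4096])
```

The query and output projections are left untouched; only W_k and W_v shrink.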

SECTION 07

Gotchas

SECTION 08

GQA in practice: adoption and benchmarks

GQA has been adopted by major models including Llama 2 70B, Llama 3, Mistral, and recent Qwen variants. Empirical benchmarks show GQA typically preserves quality within 1–2% of full MHA on standard benchmarks while reducing the KV cache by 4–8×; MQA compresses further still (a single KV head) but at a larger quality cost.

# Per-token, per-layer KV-cache elements (K and V), for an illustrative
# model with H = 32 query heads and d_head = 128:
# MHA:  2 × 32 × 128 = 8,192 elements per token
# MQA:  2 ×  1 × 128 =   512 elements per token
# GQA:  2 ×  8 × 128 = 2,048 elements per token (8 KV heads, 4× smaller than MHA)

# For a 7B-class model (32 layers, fp16), seq_len=2048, batch=32:
# MHA:  8,192 × 2 bytes × 2048 × 32 layers × 32 seqs ≈ 34.4 GB
# GQA:  2,048 × 2 bytes × 2048 × 32 layers × 32 seqs ≈  8.6 GB (4× smaller)

The tradeoff is mostly expressiveness rather than raw compute: fewer distinct KV heads means less representational capacity in the attention layer, though in practice the measured quality gap is small. Note also that a naive implementation that materializes the repeated KV heads (as repeat_interleave does above) pays the full MHA memory cost during the attention computation itself; optimized kernels broadcast the shared heads instead of copying them, and the KV cache is stored un-repeated either way.

Variant          KV Cache Size   Inference Latency   Quality vs MHA
MHA (baseline)   Full (1×)       Baseline            100%
MQA              1/H             ~10% faster         98–99%
GQA (G=4)        1/(H/4)         ~5% faster          99–100%
GQA (G=8)        1/(H/8)         ~2% faster          99.5–100%

(G = number of KV-head groups; H = number of query heads)

GQA implementation details: Grouped query attention is essentially a compromise: instead of a single shared key-value head (MQA) for all query heads, you keep a small number of shared KV heads, usually 4 or 8. During attention computation, each KV head is replicated or broadcast to serve its group of query heads. This increases parameter count slightly compared to MQA but retains most of the memory savings. In training, the gradients flowing into each KV head are summed across its group of query heads.

The adoption of GQA in recent large models (Llama 2, Llama 3, Mistral, Qwen) signals industry confidence in the tradeoff. Newer research explores hybrid approaches, such as training with full MHA for expressiveness and then converting to GQA for inference speed. Dynamic GQA, where the number of groups adapts based on sequence length, is an active research area.

GQA performance across sequence lengths: GQA's benefits scale with sequence length. For short sequences (a few hundred tokens), all attention variants are fast and the KV cache size barely matters. But as sequences reach thousands of tokens (long documents, multi-turn conversations), memory bandwidth becomes the bottleneck. GQA's smaller cache reduces memory traffic, yielding latency improvements of roughly 2–5× on long sequences while maintaining quality. This makes GQA especially valuable for long-context models (32K, 128K tokens).

Future work in attention mechanisms explores even more radical compression: query-key shared attention (single head for everything), learnable key-value compression (lossy dimensionality reduction), and sparse attention patterns (attending to only important tokens). These techniques are still research-stage but could eventually replace GQA if they prove practical. For now, GQA represents a mature, well-tested sweet spot.

Practitioners often ask: should I use GQA or MQA? GQA is the safe choice: it maintains quality better while still achieving significant compression. MQA is more aggressive (a single KV head) and sometimes fails on reasoning tasks. If you're designing a model from scratch, start with GQA and only move to MQA if you're certain the quality loss is acceptable.

GQA implementation details in frameworks: PyTorch supports GQA natively via torch.nn.functional.scaled_dot_product_attention (the enable_gqa flag in recent releases), and optimized inference stacks implement it in fused kernels. Hugging Face models (Llama 2, Qwen, Mistral) expose the GQA configuration in config.json: `"num_key_value_heads": 8` with `"num_attention_heads": 32` means 8 KV heads, each shared by a group of 4 query heads. When fine-tuning or quantizing a GQA model, ensure the grouping structure is preserved to avoid silent quality degradation.

Inference optimization: vLLM and TensorRT-LLM both auto-detect and optimize GQA models. If you're using a custom inference engine, ensure KV cache management respects the grouping (storing only n_kv_heads per layer and reusing each KV head across its group of query heads). Failure to do this results in incorrectly shaped attention computations and wrong outputs.

Scaling to production: because the KV cache per sequence shrinks by the grouping factor (8× for Llama 3 70B), GQA models can fit substantially larger batches in the same GPU memory, enabling higher throughput. For serving applications, this translates directly to better cost-per-inference. Benchmarking on your hardware is essential: latency improvements vary with batch size, sequence length, and the memory-bandwidth characteristics of your GPU.
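A back-of-the-envelope sketch of the batch-size effect (the 40 GB KV-cache budget and the Llama-3-70B-like shapes are assumptions; real servers also need room for activations and framework overhead):

```python
def max_batch(budget_gb, n_layers, seq_len, n_kv_heads, d_head, bytes_per_elem=2):
    """How many concurrent full-length sequences fit in a KV-cache budget."""
    # Per-sequence cache: K and V, all layers, full context, bf16 by default.
    per_seq = 2 * n_layers * seq_len * n_kv_heads * d_head * bytes_per_elem
    return int(budget_gb * 1e9 // per_seq)

# 70B-class model, 40 GB left for KV cache, 8192-token context (assumed)
print(max_batch(40, 80, 8192, 64, 128))  # MHA: 1 sequence
print(max_batch(40, 80, 8192, 8, 128))   # GQA: 14 sequences
```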

GQA in the broader attention landscape: The attention mechanism is a bottleneck in transformer inference, and research explores many alternatives: sparse attention (attending only to relevant tokens, selected via learned indices), FlashAttention (hardware-aware kernels that minimize memory transfers), and grouped query attention (fewer KV heads). FlashAttention and GQA are complementary, and many recent models combine both.

GQA is the simpler, more mature approach and will likely remain dominant for several years. Practitioners should treat GQA as a standard technique and know when to apply it: any model with a KV-cache bottleneck, especially long-context models. For new projects, default to GQA unless you have specific reasons not to.