Transformer Architecture

Cross-Attention

Cross-attention lets a decoder query information from an encoder's output. It is the mechanism that connects encoder and decoder in seq2seq models, and it reappears in multimodal architectures to fuse vision and language.

- Q from decoder, KV from encoder
- Seq2seq core: translation, summarisation
- Multimodal: vision-language fusion

SECTION 01

Cross-attention vs self-attention

In self-attention, queries, keys, and values all come from the same sequence; each position attends to every other position in that sequence. In cross-attention, the queries come from one sequence (typically the decoder's hidden states) while the keys and values come from a different sequence (typically the encoder's output). This lets the decoder look up which parts of the source are most relevant for generating the next token.
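
This difference is visible directly in PyTorch's nn.MultiheadAttention, whose forward method takes query, key, and value separately. A minimal sketch with illustrative sizes:

```python
import torch
import torch.nn as nn

d_model, n_heads = 64, 4
dec = torch.randn(2, 5, d_model)  # decoder hidden states: (B, T_dec, d_model)
enc = torch.randn(2, 9, d_model)  # encoder output:        (B, T_enc, d_model)

attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

# Self-attention: Q, K, and V all come from the same sequence.
self_out, self_w = attn(dec, dec, dec)

# Cross-attention: Q from the decoder, K and V from the encoder.
cross_out, cross_w = attn(dec, enc, enc)

assert self_w.shape == (2, 5, 5)     # T_dec x T_dec
assert cross_w.shape == (2, 5, 9)    # T_dec x T_enc
assert cross_out.shape == dec.shape  # one output vector per decoder position
```

The same module computes both; only the sequences passed as key and value change.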

Cross-attention is the mechanism that connects encoder and decoder in seq2seq architectures like the original Transformer (Vaswani et al. 2017), T5, and BART. It also reappears in multimodal models (LLaVA, Flamingo) to fuse visual features into a language decoder.

SECTION 02

Mathematical formulation

Given decoder hidden states H_dec ∈ ℝ^(T_dec × d_model) and encoder output H_enc ∈ ℝ^(T_enc × d_model), project Q = H_dec W_Q, K = H_enc W_K, V = H_enc W_V, and compute:

Attention(Q, K, V) = softmax(Q K^T / √d_k) · V

The result has shape T_dec × d_v: each decoder position gets a weighted combination of encoder positions. The attention weights (shape T_dec × T_enc) show which source positions the decoder is attending to when generating each target token.
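
A quick numerical check of these shapes, single-head and unbatched for clarity (the dimensions are made up):

```python
import torch
import torch.nn.functional as F

T_dec, T_enc, d_k, d_v = 3, 5, 8, 8
Q = torch.randn(T_dec, d_k)  # projected decoder states
K = torch.randn(T_enc, d_k)  # projected encoder output
V = torch.randn(T_enc, d_v)

weights = F.softmax(Q @ K.T / d_k**0.5, dim=-1)  # (T_dec, T_enc)
out = weights @ V                                # (T_dec, d_v)

assert weights.shape == (T_dec, T_enc)
assert out.shape == (T_dec, d_v)
# Each row of the weight matrix is a distribution over encoder positions.
assert torch.allclose(weights.sum(-1), torch.ones(T_dec))
```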

SECTION 03

Implementation in PyTorch

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class CrossAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, dec_hidden, enc_output, enc_mask=None):
        # dec_hidden: (B, T_dec, d_model)
        # enc_output: (B, T_enc, d_model)
        B, T_dec, _ = dec_hidden.shape
        T_enc = enc_output.shape[1]

        Q = self.W_q(dec_hidden).view(B, T_dec, self.n_heads, self.d_head).transpose(1, 2)
        K = self.W_k(enc_output).view(B, T_enc, self.n_heads, self.d_head).transpose(1, 2)
        V = self.W_v(enc_output).view(B, T_enc, self.n_heads, self.d_head).transpose(1, 2)

        scale = math.sqrt(self.d_head)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / scale  # (B, H, T_dec, T_enc)

        if enc_mask is not None:
            scores = scores.masked_fill(enc_mask.unsqueeze(1).unsqueeze(2) == 0, -1e9)

        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, V)  # (B, H, T_dec, d_head)
        out = out.transpose(1, 2).contiguous().view(B, T_dec, -1)
        return self.W_o(out), attn  # return attn weights for visualisation
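
As a sanity check, the core of the forward pass above matches PyTorch's fused F.scaled_dot_product_attention, available since PyTorch 2.0 (shapes here are illustrative):

```python
import math
import torch
import torch.nn.functional as F

B, H, T_dec, T_enc, d_head = 2, 4, 5, 9, 16
Q = torch.randn(B, H, T_dec, d_head)
K = torch.randn(B, H, T_enc, d_head)
V = torch.randn(B, H, T_enc, d_head)

# Manual path, as in the module above.
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_head)
manual = F.softmax(scores, dim=-1) @ V

# Fused path: same math, better memory behaviour.
fused = F.scaled_dot_product_attention(Q, K, V)

assert torch.allclose(manual, fused, atol=1e-5)
```

The fused version does not return the attention weights, so the explicit implementation above remains useful when you want to visualise them.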

SECTION 04

Role in encoder-decoder models

In the original Transformer decoder, each layer has three sublayers: (1) masked self-attention over previously generated tokens, (2) cross-attention over the encoder output, and (3) a feed-forward network. The cross-attention sublayer is what allows the model to focus on relevant source tokens when generating each target token.
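
The three sublayers can be sketched with nn.MultiheadAttention; the pre-norm placement and sizes below are illustrative assumptions, not the original paper's exact (post-norm) configuration:

```python
import torch
import torch.nn as nn

class DecoderLayer(nn.Module):
    def __init__(self, d_model=64, n_heads=4, d_ff=256):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.norm1, self.norm2, self.norm3 = (nn.LayerNorm(d_model) for _ in range(3))

    def forward(self, x, enc_out):
        T = x.shape[1]
        # True above the diagonal = those positions are masked out.
        causal = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
        # (1) masked self-attention over previously generated tokens
        h = self.norm1(x)
        x = x + self.self_attn(h, h, h, attn_mask=causal)[0]
        # (2) cross-attention: Q from decoder, K/V from encoder output
        x = x + self.cross_attn(self.norm2(x), enc_out, enc_out)[0]
        # (3) position-wise feed-forward network
        return x + self.ff(self.norm3(x))

layer = DecoderLayer()
out = layer(torch.randn(2, 5, 64), torch.randn(2, 9, 64))
assert out.shape == (2, 5, 64)
```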

Modern variants like T5 follow the same pattern. During inference, the encoder processes the input once; the decoder runs cross-attention at every step, re-attending to encoder states. This is efficient because encoder keys and values can be cached for the entire generation.
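
The caching idea can be sketched directly: project the encoder output to K and V once, then reuse them at every decoding step (single-head, with made-up projection layers and a toy state update):

```python
import math
import torch
import torch.nn as nn

d_model = 64
W_q, W_k, W_v = (nn.Linear(d_model, d_model, bias=False) for _ in range(3))

enc_out = torch.randn(1, 9, d_model)           # encoder runs once per input
K_cache, V_cache = W_k(enc_out), W_v(enc_out)  # computed once, reused every step

contexts = []
dec_state = torch.randn(1, 1, d_model)  # stand-in for the current decoder state
for step in range(4):                   # each step reuses the cached K/V
    Q = W_q(dec_state)
    w = torch.softmax(Q @ K_cache.transpose(-2, -1) / math.sqrt(d_model), dim=-1)
    ctx = w @ V_cache                   # (1, 1, d_model) context for this step
    contexts.append(ctx)
    dec_state = ctx                     # toy update; a real decoder embeds the new token
```

Only the query projection is recomputed per step; the encoder-side work is amortised over the whole generation.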

SECTION 05

Cross-attention in multimodal models

Cross-attention is the dominant mechanism for fusing modalities. In Flamingo, visual features from a frozen vision encoder are injected into a language model via cross-attention layers inserted between transformer blocks. In LLaVA, visual tokens are projected into the language model's embedding space and prepended to the text, effectively using the language model's self-attention as a form of cross-attention. In Stable Diffusion's UNet, cross-attention fuses CLIP text embeddings with image features at multiple resolutions to condition the denoising process on the text prompt.
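
The LLaVA-style approach (project, then prepend) amounts to one projection and one concatenation; the dimensions below are made up for illustration:

```python
import torch
import torch.nn as nn

d_vision, d_model = 1024, 768
n_visual, n_text = 16, 12

# Maps vision-encoder features into the LM's embedding space.
projector = nn.Linear(d_vision, d_model)

visual_feats = torch.randn(1, n_visual, d_vision)  # from a (frozen) vision encoder
text_embeds = torch.randn(1, n_text, d_model)      # token embeddings of the prompt

visual_tokens = projector(visual_feats)                    # (1, 16, 768)
lm_input = torch.cat([visual_tokens, text_embeds], dim=1)  # prepend: (1, 28, 768)
# The LM's ordinary self-attention now lets text positions attend to visual positions.
```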

SECTION 06

Attention patterns visualised

When you inspect cross-attention weights in a trained translation model, clear patterns emerge: when generating a verb in German, the model strongly attends to the corresponding verb in the English source. In summarisation models, attention heads focus on key sentences in the source document. This interpretability is one reason cross-attention is preferred over alternatives like concatenation for fusing two sequences.

import matplotlib.pyplot as plt

def plot_cross_attn(attn_weights, src_tokens, tgt_tokens, head=0):
    # attn_weights: (n_heads, T_dec, T_enc)
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.imshow(attn_weights[head].cpu().numpy(), cmap="Blues", aspect="auto")
    ax.set_xticks(range(len(src_tokens)))
    ax.set_xticklabels(src_tokens, rotation=45, ha="right")
    ax.set_yticks(range(len(tgt_tokens)))
    ax.set_yticklabels(tgt_tokens)
    ax.set_xlabel("Source (encoder)"); ax.set_ylabel("Target (decoder)")
    plt.tight_layout(); plt.show()

SECTION 07

Applications and gotchas

Cross-attention allows one sequence to attend to another, enabling a transformer to condition generation on external information. Where self-attention queries a sequence against itself, cross-attention queries the target sequence against a separate source sequence. This makes it the mechanism for encoder-decoder interaction in translation and summarisation, and for conditioning image generation on text prompts in diffusion models.

Architecture                 | Cross-attention role     | Source sequence     | Target sequence
Encoder-decoder (T5)         | Attend to encoder output | Input document      | Output summary
Diffusion (Stable Diffusion) | Condition on text        | CLIP text embedding | Latent image
MLA (DeepSeek)               | Low-rank KV compression  | Compressed latent   | Query tokens
Retrieval-augmented (RETRO)  | Attend to retrieved docs | Retrieved chunks    | Current generation

In encoder-decoder architectures like T5 and BART, every decoder layer uses cross-attention to query the full encoder output, giving the decoder direct access to all encoded representations of the input at every generation step. This is distinct from decoder-only architectures (the GPT family), which include the source document in the same context as the generated output: the encoder-decoder design has separate computational paths for understanding the input and generating the output, which can be more efficient for tasks like translation where the full source must be processed before output generation begins.

Cross-attention in image generation models serves as the conditioning mechanism that binds text semantics to spatial features in the image latent. The text prompt is encoded into a sequence of embeddings, and each convolutional or transformer layer in the UNet queries these text embeddings through cross-attention. This conditioning is present at every denoising step, continuously guiding the diffusion process toward the text-described content. ControlNet's architectural contribution is to add additional cross-attention pathways for structural condition images (depth maps, edge maps, poses) alongside the text cross-attention, enabling joint conditioning on both spatial structure and semantic content.

import torch
import torch.nn as nn

class CrossAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_model)   # from target
        self.k_proj = nn.Linear(d_model, d_model)   # from source
        self.v_proj = nn.Linear(d_model, d_model)   # from source
        self.out_proj = nn.Linear(d_model, d_model)
        self.n_heads = n_heads
        self.scale = (d_model // n_heads) ** -0.5

    def forward(self, target, source, source_mask=None):
        B, T, D = target.shape
        S = source.shape[1]
        q = self.q_proj(target).view(B, T, self.n_heads, -1).transpose(1,2)
        k = self.k_proj(source).view(B, S, self.n_heads, -1).transpose(1,2)
        v = self.v_proj(source).view(B, S, self.n_heads, -1).transpose(1,2)
        attn = (q @ k.transpose(-2,-1)) * self.scale
        if source_mask is not None:
            # source_mask: (B, S), True at valid source positions.
            # Broadcast to (B, 1, 1, S) so it applies across heads and target positions.
            attn = attn.masked_fill(~source_mask.unsqueeze(1).unsqueeze(2), float('-inf'))
        return self.out_proj((attn.softmax(-1) @ v).transpose(1,2).reshape(B,T,D))

Cross-attention masking strategies differ from self-attention masking. Self-attention uses causal masks for autoregressive generation, preventing tokens from attending to future positions. Cross-attention uses padding masks on the source sequence to prevent attention to padding tokens when source sequences in a batch have different lengths. For retrieval-augmented architectures that attend to retrieved documents, cross-attention masking can additionally implement document-level isolation: preventing the model from attending across document boundaries within the retrieved context pool.
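
A source padding mask of this kind can be built from per-example lengths; a minimal sketch:

```python
import torch

lengths = torch.tensor([9, 5, 7])  # true source lengths in a batch of 3
S = int(lengths.max())             # padded source length

# True where the position holds a real token, False where it is padding.
source_mask = torch.arange(S) < lengths.unsqueeze(1)  # (B, S)

assert source_mask.shape == (3, 9)
assert source_mask[0].all()                 # longest sequence: nothing masked
assert source_mask[1].sum().item() == 5     # positions 5..8 are padding
```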

The computational cost of cross-attention scales as O(T_target × T_source), where T_target and T_source are the lengths of the target and source sequences respectively. For architectures where the source sequence is long (a full document in summarisation, a dense point cloud in 3D vision), this quadratic scaling creates the same memory and compute bottleneck as long-sequence self-attention. Efficient cross-attention variants that use sparse attention patterns, local windowed attention, or linear attention approximations are active areas of research for applications requiring attention over very long source sequences.
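
To make the scaling concrete, the attention-weight tensor alone for one layer holds B × H × T_target × T_source entries; at fp32 that adds up quickly (the sizes below are made up):

```python
B, H = 8, 16
T_target, T_source = 1024, 65536  # e.g. a short output attending over a long document
bytes_per_entry = 4               # fp32

attn_matrix_bytes = B * H * T_target * T_source * bytes_per_entry
print(f"{attn_matrix_bytes / 2**30:.1f} GiB")  # 32.0 GiB for one layer's weights
```

This is the tensor that sparse, windowed, and linear-attention variants avoid materialising in full.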