Cross-attention lets a decoder query information from an encoder's output, pulling in whatever parts of the source sequence are most relevant to the token being generated.
In self-attention, queries, keys, and values all come from the same sequence: each position attends to every other position of that sequence. In cross-attention, the queries come from one sequence (typically the decoder's hidden states) while the keys and values come from a different sequence (typically the encoder's output). This lets the decoder look up which parts of the source are most relevant for generating the next token.
Cross-attention is the mechanism that connects encoder and decoder in seq2seq architectures like the original Transformer (Vaswani et al. 2017), T5, and BART. It also reappears in multimodal models (LLaVA, Flamingo) to fuse visual features into a language decoder.
Given decoder hidden states H_dec ∈ ℝ^(T_dec × d_model) and encoder output H_enc ∈ ℝ^(T_enc × d_model), form Q = H_dec W_Q, K = H_enc W_K, V = H_enc W_V and compute:
Attention(Q, K, V) = softmax(QKᵀ / √d_k) · V
The result has shape T_dec × d_v: each decoder position gets a weighted combination of encoder positions. The attention weights (shape T_dec × T_enc) show which source positions the decoder is attending to when generating each target token.
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class CrossAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, dec_hidden, enc_output, enc_mask=None):
        # dec_hidden: (B, T_dec, d_model) -- queries come from the decoder
        # enc_output: (B, T_enc, d_model) -- keys/values come from the encoder
        B, T_dec, _ = dec_hidden.shape
        T_enc = enc_output.shape[1]
        Q = self.W_q(dec_hidden).view(B, T_dec, self.n_heads, self.d_head).transpose(1, 2)
        K = self.W_k(enc_output).view(B, T_enc, self.n_heads, self.d_head).transpose(1, 2)
        V = self.W_v(enc_output).view(B, T_enc, self.n_heads, self.d_head).transpose(1, 2)
        scale = math.sqrt(self.d_head)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / scale  # (B, H, T_dec, T_enc)
        if enc_mask is not None:
            # enc_mask: (B, T_enc), 1 for real tokens, 0 for padding
            scores = scores.masked_fill(enc_mask.unsqueeze(1).unsqueeze(2) == 0, -1e9)
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, V)  # (B, H, T_dec, d_head)
        out = out.transpose(1, 2).contiguous().view(B, T_dec, -1)
        return self.W_o(out), attn  # return attn weights for visualisation
```
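For reference, once Q, K, and V are formed, the core computation matches PyTorch's built-in `F.scaled_dot_product_attention`. A minimal functional check (sizes here are arbitrary):

```python
import torch
import torch.nn.functional as F

B, H, T_dec, T_enc, d_head = 2, 4, 5, 9, 16
Q = torch.randn(B, H, T_dec, d_head)  # queries from the decoder
K = torch.randn(B, H, T_enc, d_head)  # keys from the encoder
V = torch.randn(B, H, T_enc, d_head)  # values from the encoder

# Manual computation, as in the module above
scores = Q @ K.transpose(-2, -1) / d_head ** 0.5
manual = scores.softmax(-1) @ V       # (B, H, T_dec, d_head)

# Built-in fused implementation (default scale is 1/sqrt(d_head))
fused = F.scaled_dot_product_attention(Q, K, V)

print(torch.allclose(manual, fused, atol=1e-5))  # True
```

Note that every output position attends over T_enc encoder positions, regardless of T_dec.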
In the original Transformer decoder, each layer has three sublayers: (1) masked self-attention over previously generated tokens, (2) cross-attention over the encoder output, and (3) a feed-forward network. The cross-attention sublayer is what allows the model to focus on relevant source tokens when generating each target token.
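The three-sublayer structure can be sketched with `nn.MultiheadAttention` standing in for both attention variants (pre-layer-norm placement is an assumption here; the original paper used post-norm, and the dimensions are arbitrary):

```python
import torch
import torch.nn as nn


class DecoderLayer(nn.Module):
    """Sketch of one Transformer decoder layer: masked self-attention,
    cross-attention over the encoder output, then a feed-forward net."""

    def __init__(self, d_model=64, n_heads=4, d_ff=256):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model)
        )
        self.norm1, self.norm2, self.norm3 = (nn.LayerNorm(d_model) for _ in range(3))

    def forward(self, x, enc_out):
        T = x.shape[1]
        # Boolean mask: True = position may NOT be attended to (future tokens)
        causal = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
        h = self.norm1(x)
        x = x + self.self_attn(h, h, h, attn_mask=causal)[0]  # (1) masked self-attn
        h = self.norm2(x)
        x = x + self.cross_attn(h, enc_out, enc_out)[0]       # (2) Q from decoder, K/V from encoder
        return x + self.ff(self.norm3(x))                     # (3) feed-forward


layer = DecoderLayer()
out = layer(torch.randn(2, 5, 64), torch.randn(2, 9, 64))
print(out.shape)  # torch.Size([2, 5, 64])
```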
Modern variants like T5 follow the same pattern. During inference, the encoder processes the input once; the decoder runs cross-attention at every step, re-attending to encoder states. This is efficient because encoder keys and values can be cached for the entire generation.
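The caching works because K and V depend only on the encoder output, which is fixed for the whole generation; only the query changes per step. A single-head sketch with arbitrary sizes:

```python
import torch

torch.manual_seed(0)
d_model, T_enc = 64, 9
enc_output = torch.randn(1, T_enc, d_model)
W_q, W_k, W_v = (torch.randn(d_model, d_model) for _ in range(3))

# Encoder-side projections computed once, before generation starts
K_cache, V_cache = enc_output @ W_k, enc_output @ W_v


def cross_attend(dec_state, K, V):
    scores = (dec_state @ W_q) @ K.transpose(-2, -1) / d_model ** 0.5
    return scores.softmax(-1) @ V


# At each decoding step only the new query changes; K and V are reused
for step in range(3):
    dec_state = torch.randn(1, 1, d_model)  # stand-in for the newest decoder position
    cached = cross_attend(dec_state, K_cache, V_cache)
    fresh = cross_attend(dec_state, enc_output @ W_k, enc_output @ W_v)
    assert torch.allclose(cached, fresh)
print("cached K/V match recomputation at every step")
```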
Cross-attention is the dominant mechanism for fusing modalities. In Flamingo, visual features from a frozen vision encoder are injected into a language model via cross-attention layers inserted between transformer blocks. In LLaVA, visual tokens are projected into the language model's embedding space and prepended to the text, effectively using the language model's self-attention as a form of cross-attention. In Stable Diffusion's UNet, cross-attention fuses CLIP text embeddings with image features at multiple resolutions to condition the denoising process on the text prompt.
When you inspect cross-attention weights in a trained translation model, clear patterns emerge: when generating a verb in German, the model strongly attends to the corresponding verb in the English source. In summarisation models, attention heads focus on key sentences in the source document. This interpretability is one reason cross-attention is preferred over alternatives like concatenation for fusing two sequences.
```python
import matplotlib.pyplot as plt


def plot_cross_attn(attn_weights, src_tokens, tgt_tokens, head=0):
    # attn_weights: (n_heads, T_dec, T_enc) for a single example
    fig, ax = plt.subplots(figsize=(8, 6))
    # detach() in case the weights still carry gradients
    ax.imshow(attn_weights[head].detach().cpu().numpy(), cmap="Blues", aspect="auto")
    ax.set_xticks(range(len(src_tokens)))
    ax.set_xticklabels(src_tokens, rotation=45, ha="right")
    ax.set_yticks(range(len(tgt_tokens)))
    ax.set_yticklabels(tgt_tokens)
    ax.set_xlabel("Source (encoder)")
    ax.set_ylabel("Target (decoder)")
    plt.tight_layout()
    plt.show()
```
Cross-attention allows one sequence to attend to another, enabling transformer architectures to condition generation on external information. While self-attention queries a sequence against itself, cross-attention queries the target sequence against a separate source sequence, making it the mechanism for encoder-decoder interaction in translation and summarization architectures and for conditioning image generation on text prompts in diffusion models.
| Architecture | Cross-Attention Role | Source Sequence | Target Sequence |
|---|---|---|---|
| Encoder-decoder (T5) | Attend to encoder output | Input document | Output summary |
| Diffusion (Stable Diffusion) | Condition on text | CLIP text embedding | Latent image |
| MLA (DeepSeek) | Low-rank KV compression | Compressed latent | Query tokens |
| Retrieval-augmented (RETRO) | Attend to retrieved docs | Retrieved chunks | Current generation |
In encoder-decoder architectures like T5 and BART, every decoder layer uses cross-attention to query the full encoder output, giving the decoder direct access to all encoded representations of the input at every generation step. This is distinct from decoder-only architectures (GPT-family), which include the source document in the same context as the generated output. The encoder-decoder design has separate computational paths for understanding the input and generating the output, which can be more efficient for tasks like translation where the full source must be processed before output generation begins.
Cross-attention in image generation models serves as the conditioning mechanism that binds text semantics to spatial features in the image latent. The text prompt is encoded into a sequence of embeddings, and attention layers throughout the UNet query these text embeddings through cross-attention. This conditioning is present at every denoising step, continuously guiding the diffusion process toward the text-described content. ControlNet extends this with conditioning on structural images (depth maps, edge maps, poses) alongside the text: a trainable copy of the UNet encoder processes the condition image, and its features are added into the frozen UNet through zero-initialized convolutions, enabling joint conditioning on both spatial structure and semantic content.
```python
import torch
import torch.nn as nn


class CrossAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_model)  # from target
        self.k_proj = nn.Linear(d_model, d_model)  # from source
        self.v_proj = nn.Linear(d_model, d_model)  # from source
        self.out_proj = nn.Linear(d_model, d_model)
        self.n_heads = n_heads
        self.scale = (d_model // n_heads) ** -0.5

    def forward(self, target, source, source_mask=None):
        B, T, D = target.shape
        S = source.shape[1]
        q = self.q_proj(target).view(B, T, self.n_heads, -1).transpose(1, 2)
        k = self.k_proj(source).view(B, S, self.n_heads, -1).transpose(1, 2)
        v = self.v_proj(source).view(B, S, self.n_heads, -1).transpose(1, 2)
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, H, T, S)
        if source_mask is not None:
            # source_mask: (B, S) boolean, True for real tokens; broadcast to
            # (B, 1, 1, S) so it applies across all heads and query positions
            attn = attn.masked_fill(~source_mask[:, None, None, :], float('-inf'))
        return self.out_proj((attn.softmax(-1) @ v).transpose(1, 2).reshape(B, T, D))
```
Cross-attention masking strategies differ from self-attention masking. Self-attention uses causal masks for autoregressive generation, preventing tokens from attending to future positions. Cross-attention uses padding masks on the source sequence to prevent attention to padding tokens when source sequences in a batch have different lengths. For retrieval-augmented architectures that attend to retrieved documents, cross-attention masking can additionally implement document-level isolation โ preventing the model from attending across document boundaries within the retrieved context pool.
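A sketch of both mask types, matching the compact module above where `source_mask` is a boolean `(B, S)` tensor with True for real tokens (the document-id layout below is hypothetical):

```python
import torch

# Padding mask from per-example source lengths, batch padded to S = 5
lengths = torch.tensor([5, 3])
S = 5
source_mask = torch.arange(S)[None, :] < lengths[:, None]  # (B, S), True = real token
print(source_mask)
# tensor([[ True,  True,  True,  True,  True],
#         [ True,  True,  True, False, False]])

# Document-level isolation: each target position may only attend to source
# positions belonging to its assigned retrieved document
doc_ids = torch.tensor([0, 0, 1, 1, 1])  # document id per source position
tgt_doc = torch.tensor([0, 1, 1])        # document each target position may read
isolation = tgt_doc[:, None] == doc_ids[None, :]  # (T, S), True = attention allowed
print(isolation[0])  # tensor([ True,  True, False, False, False])
```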
The computational cost of cross-attention scales as O(T_target × T_source), where T_target and T_source are the lengths of the target and source sequences respectively. For architectures where the source sequence is long (a full document in summarization, a dense point cloud in 3D vision) this quadratic scaling creates the same memory and compute bottleneck as long-sequence self-attention. Efficient cross-attention variants that use sparse attention patterns, local windowed attention, or linear attention approximations are active areas of research for applications requiring attention over very long source sequences.
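A back-of-the-envelope estimate of the quadratic term, counting only the materialised score matrix in fp32 (fused kernels like FlashAttention avoid materialising it):

```python
def attn_matrix_bytes(t_target, t_source, n_heads, batch=1, bytes_per_el=4):
    """Memory for the (batch, n_heads, T_target, T_source) score matrix."""
    return batch * n_heads * t_target * t_source * bytes_per_el


# Summarising a 50k-token document into 500 tokens with 16 heads:
gib = attn_matrix_bytes(500, 50_000, 16) / 2**30
print(f"{gib:.2f} GiB per example")  # 1.49 GiB per example
```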