RLHF, DPO, and Constitutional AI compared — how each shapes model behaviour, what it costs, and when to use it.
Attention is the mechanism that lets a transformer weigh how much each token should influence every other token. The core operation: for each token, compute a query vector, then dot it against every other token's key vector to get an attention score. Divide by √d_k so the dot products don't grow with head dimension and saturate the softmax (which would shrink its gradients toward zero), apply softmax to get a probability distribution, then use those weights to sum the value vectors. The result is a context-aware representation of each token.
Modern LLMs use Multi-Head Attention (MHA): run h independent attention heads in parallel, each learning different relationship patterns (syntactic, semantic, positional), then concatenate. Variants like Grouped Query Attention (GQA) and Multi-Query Attention (MQA) reduce the KV cache by sharing key/value heads across multiple query heads.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.0):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_k = d_model // n_heads
        self.n_heads = n_heads
        # Single projection matrix for efficiency: Q, K, V concatenated
        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor,
                mask: torch.Tensor = None) -> torch.Tensor:
        B, T, C = x.shape  # batch, seq_len, d_model

        # Project to Q, K, V
        qkv = self.qkv_proj(x)          # (B, T, 3*C)
        Q, K, V = qkv.split(C, dim=-1)  # each: (B, T, C)

        # Reshape for multi-head: (B, n_heads, T, d_k)
        def split_heads(t):
            return t.reshape(B, T, self.n_heads, self.d_k).transpose(1, 2)
        Q, K, V = split_heads(Q), split_heads(K), split_heads(V)

        # Scaled dot-product attention
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        weights = self.dropout(weights)

        # Weighted sum of values
        out = weights @ V                           # (B, n_heads, T, d_k)
        out = out.transpose(1, 2).reshape(B, T, C)  # (B, T, C)
        return self.out_proj(out)

# Test
attn = MultiHeadAttention(d_model=512, n_heads=8)
x = torch.randn(2, 16, 512)  # batch=2, seq_len=16
out = attn(x)
print(f"Input: {x.shape} → Output: {out.shape}")  # same shape
```
The original Multi-Head Attention (MHA) stores a separate K and V matrix per head, which grows the KV cache linearly with the number of heads. Two efficient variants address this: Multi-Query Attention (MQA) uses a single shared K/V across all heads, drastically reducing cache size. Grouped Query Attention (GQA) — used in Llama 3, Mistral, and Gemma — is a middle ground: groups of query heads share a single K/V pair.
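The K/V sharing in GQA can be sketched in a few lines of PyTorch. The head counts below are illustrative (8 query heads, 2 KV heads), and `repeat_interleave` stands in for the broadcast that a real fused kernel would avoid materializing:

```python
import torch
import torch.nn.functional as F

B, T, n_q_heads, n_kv_heads, d_k = 2, 16, 8, 2, 64
group = n_q_heads // n_kv_heads  # 4 query heads per shared KV head

Q = torch.randn(B, n_q_heads, T, d_k)
K = torch.randn(B, n_kv_heads, T, d_k)  # cache holds only 2 heads, not 8
V = torch.randn(B, n_kv_heads, T, d_k)

# Expand K/V so each group of query heads sees its shared KV head
K_exp = K.repeat_interleave(group, dim=1)  # (B, n_q_heads, T, d_k)
V_exp = V.repeat_interleave(group, dim=1)

scores = Q @ K_exp.transpose(-2, -1) / d_k ** 0.5
out = F.softmax(scores, dim=-1) @ V_exp
print(out.shape)  # torch.Size([2, 8, 16, 64])
```

Setting `n_kv_heads = 1` recovers MQA; setting `n_kv_heads = n_q_heads` recovers MHA.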
FlashAttention is an algorithmic optimization (not an architectural change) that rewrites the attention computation to be IO-aware: tiling the operation to fit in fast SRAM rather than repeatedly reading from HBM. FlashAttention 2 roughly doubles the speed of the original FlashAttention on modern GPUs, and neither version changes the mathematical result. Most production LLM serving frameworks (vLLM, TGI) use FlashAttention by default.
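You rarely need to call FlashAttention directly: PyTorch 2.x's `F.scaled_dot_product_attention` dispatches to a flash kernel on supported GPUs and to a math fallback on CPU. Because it is an IO optimization, not a different computation, the result matches an explicit implementation:

```python
import torch
import torch.nn.functional as F

B, H, T, d_k = 2, 8, 16, 64
Q, K, V = (torch.randn(B, H, T, d_k) for _ in range(3))

# Fused path: flash kernel on supported GPUs, math fallback elsewhere
fast = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

# Reference: explicit causal attention, for comparison
mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
scores = (Q @ K.transpose(-2, -1) / d_k ** 0.5).masked_fill(~mask, float('-inf'))
slow = F.softmax(scores, dim=-1) @ V

print(torch.allclose(fast, slow, atol=1e-5))
```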
| Variant | KV Heads | KV Cache Size | Quality vs MHA | Used In |
|---|---|---|---|---|
| MHA | = num_heads | Baseline | Baseline | GPT-2, BERT |
| MQA | 1 (shared) | num_heads× smaller | Slight drop | PaLM, Falcon |
| GQA | n_groups (e.g. 8) | 8× smaller vs MHA | Near-MHA | Llama 3, Mistral, Gemma 2 |
| FlashAttention | N/A (IO rewrite) | Same as base | Identical | All modern serving |
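The cache sizes in the table come from simple arithmetic: 2 tensors (K and V) × layers × KV heads × head dim × sequence length × bytes per value. A back-of-envelope sketch with Llama-3-70B-like numbers (80 layers, 64 query heads, 8 KV heads, head dim 128, fp16), which are illustrative here:

```python
def kv_cache_bytes(layers, kv_heads, d_k, seq_len, bytes_per=2):
    # K and V tensors per layer, fp16 by default
    return 2 * layers * kv_heads * d_k * seq_len * bytes_per

mha = kv_cache_bytes(80, 64, 128, 8192)  # if every query head kept its own K/V
gqa = kv_cache_bytes(80, 8, 128, 8192)   # GQA with 8 shared KV heads
print(f"MHA: {mha / 1e9:.1f} GB, GQA: {gqa / 1e9:.1f} GB")  # 8x smaller
```

At an 8K context this is the difference between ~21 GB and ~2.7 GB of cache per sequence, which is why GQA matters so much for serving throughput.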
A pretrained LLM predicts the next token — it's a completion engine, not an assistant. Alignment is the process of steering that completion engine toward being helpful, honest, and harmless. It happens after pretraining and after SFT.
Raw pretraining teaches a model to predict plausible continuations of text. But "plausible" includes offensive, factually wrong, or harmful content if it's statistically likely given the prompt. Alignment techniques layer preferences on top of that statistical foundation — telling the model what humans actually want.
Supervised Fine-Tuning on demonstration data always comes first. You show the model (prompt, ideal response) pairs. It's the cheapest alignment step and gives the biggest quality jump. Every downstream alignment technique builds on a well-SFT'd model.
SFT shifts the base model's entire distribution toward assistant-like outputs. It teaches format, style, instruction following, and reasoning chains. Without good SFT, RLHF or DPO training becomes noisy — you're optimizing on top of a weak foundation.
Three practical rules:

- Data quality: even 10,000 high-quality SFT examples beat 100,000 noisy ones. Focus on clarity.
- Diversity: cover instruction types, reasoning styles, and edge cases.
- Iteration: SFT early and often; each refinement compounds.
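Mechanically, SFT is plain next-token cross-entropy with the prompt tokens masked out, so the loss is taken only on the response. A minimal sketch with a made-up 100-token vocabulary and stand-in token IDs:

```python
import torch
import torch.nn.functional as F

vocab, prompt_len = 100, 3
tokens = torch.tensor([[5, 17, 42, 8, 23, 99, 1]])  # prompt + response
logits = torch.randn(1, tokens.size(1) - 1, vocab)   # stand-in model outputs

# Targets are the input shifted left by one; mask predictions of prompt tokens
targets = tokens[:, 1:].clone()
targets[:, : prompt_len - 1] = -100  # ignore_index: no loss on the prompt

loss = F.cross_entropy(logits.reshape(-1, vocab), targets.reshape(-1),
                       ignore_index=-100)
print(loss.item())
```

The `-100` ignore index is the conventional PyTorch sentinel; without the mask the model would also be trained to regenerate prompts, which wastes capacity on the wrong objective.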
RLHF (Reinforcement Learning from Human Feedback) is the alignment method behind ChatGPT, Claude, and GPT-4. It maximizes a learned reward model via PPO (Proximal Policy Optimization) while penalizing divergence from the SFT checkpoint using KL divergence.
Annotators rank model responses, typically pairwise (A vs B). This is expensive: multiple responses must be sampled per prompt, comparisons span thousands of prompts, and each comparison is often labelled by several annotators to ensure quality.
A separate model learns to score responses. Given (prompt, response A, response B), it predicts which humans prefer. This model becomes the ground truth during PPO training.
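The reward model is trained with the standard pairwise (Bradley-Terry) loss: maximize the log-probability that the chosen response outscores the rejected one. The scalar scores below are stand-ins, not real model outputs:

```python
import torch
import torch.nn.functional as F

r_chosen = torch.tensor([1.2, 0.3])    # RM scores for chosen responses
r_rejected = torch.tensor([0.1, 0.9])  # RM scores for rejected responses

# loss = -log sigmoid(r_chosen - r_rejected), averaged over the batch
loss = -F.logsigmoid(r_chosen - r_rejected).mean()
print(loss.item())  # ~0.662
```

Note the second pair is mislabelled relative to the scores (rejected outscores chosen), which is what drives its gradient; a perfectly separating RM would push the loss toward zero.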
Fine-tune the LLM to maximize reward model scores while staying close to the SFT model via KL divergence penalty. The penalty prevents reward hacking and distribution collapse.
Collect new preference data on the updated model, retrain the reward model, run PPO again. Each iteration refines preferences and catches reward model drift.
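The reward signal that PPO actually optimizes in step 3 can be sketched as the RM score minus a KL penalty against the frozen SFT reference. The log-probabilities below are stand-ins, and the per-token difference is the simplest KL estimator:

```python
import torch

beta = 0.1                                      # KL penalty coefficient
rm_score = torch.tensor(0.8)                    # scalar RM score for a sample
logp_policy = torch.tensor([-1.2, -0.7, -2.1])  # policy log-probs per token
logp_ref = torch.tensor([-1.0, -0.9, -2.0])     # SFT reference log-probs

kl_per_token = logp_policy - logp_ref           # simple per-token KL estimate
reward = rm_score - beta * kl_per_token.sum()
print(reward.item())  # ~0.81
```

If the policy drifts far from the reference, the KL term dominates and the effective reward drops: that is the mechanism that prevents reward hacking and distribution collapse.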
Direct Preference Optimization (DPO) reformulates RLHF as a classification problem. Instead of training a reward model and running PPO, DPO directly optimizes the policy on preference pairs: given (prompt, chosen, rejected), update the policy to assign higher probability to chosen over rejected.
No separate reward model and no PPO rollouts: the only extra model is the frozen SFT reference, whose log-probabilities can even be precomputed. DPO is substantially simpler to implement and cheaper to run, and empirically it matches RLHF quality on many benchmarks.
DPO uses implicit reward modeling — the reward is hidden in the loss function. This simplicity comes with tradeoffs: DPO may be less stable than RLHF on very large models, and reward model evaluation is opaque. But for teams without massive annotation budgets or GPU fleets, DPO is often the pragmatic choice.
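The DPO loss makes the implicit reward concrete: the reward of a response is β times its log-probability ratio against the reference, and the loss is a logistic loss on the chosen-minus-rejected margin. Sequence-level log-probs below are stand-ins:

```python
import torch
import torch.nn.functional as F

beta = 0.1
pi_chosen, pi_rejected = torch.tensor(-42.0), torch.tensor(-51.0)    # policy
ref_chosen, ref_rejected = torch.tensor(-44.0), torch.tensor(-49.5)  # SFT ref

# Margin of implicit rewards: beta * (log-ratio chosen - log-ratio rejected)
margin = beta * ((pi_chosen - ref_chosen) - (pi_rejected - ref_rejected))
loss = -F.logsigmoid(margin)
print(loss.item())  # ~0.533
```

Gradient descent on this loss raises the policy's probability of chosen responses relative to rejected ones, with the reference ratios anchoring it to the SFT distribution, so no sampling or reward-model scoring happens during training.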
Constitutional AI (CAI) is Anthropic's approach to replacing human preference labels with AI-generated feedback. A set of principles — the "constitution" — guides a capable model to critique and revise its own outputs. The revised outputs become training data for alignment.
This scales without human annotation. Instead of paying annotators to rank responses, you craft a constitution and let the model self-improve. But it requires a capable enough base model to self-critique reliably — weak models will generate poor feedback.
This trades human effort for LLM compute. The constitution must be well-written and aligned with your values — vague principles lead to vague feedback. And the critique model must be strong enough to notice flaws and suggest improvements.
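The critique-revise loop at the heart of CAI can be sketched as plain control flow. The `generate` callable is a hypothetical stand-in for an LLM call, the prompts and toy model are illustrative only, and the principles are not Anthropic's actual constitution:

```python
def critique_and_revise(generate, prompt, constitution):
    """Draft, then critique and rewrite once per constitutional principle."""
    response = generate(prompt)
    for principle in constitution:
        critique = generate(f"Critique against: {principle}\n{response}")
        response = generate(f"Revise using: {critique}\n{response}")
    return response  # (prompt, response) pairs become alignment training data

# Toy stand-in model (echoes the last line) so the loop runs end to end;
# a real setup would call a capable LLM here.
toy = lambda p: p.splitlines()[-1]
out = critique_and_revise(toy, "2+2=5", ["be accurate", "be harmless"])
print(out)
```

The output of many such loops, run over a large prompt set, is what gets distilled back into the model via SFT or preference training.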
| Method | Human labels needed | Compute cost | Stability | Best when |
|---|---|---|---|---|
| SFT | Demonstrations | Low | High | Always — prerequisite |
| RLHF | Preference pairs + RM | High | Medium | Maximum quality, budget available |
| DPO | Preference pairs only | Medium | High | Simpler RLHF alternative |
| Constitutional AI | Minimal | Medium | Medium | Scaling without labellers |
RLHF achieves the highest quality but at the highest cost. DPO often comes close to RLHF quality at a fraction of the compute. Constitutional AI trades some quality for annotation savings. Most teams should start with SFT + DPO, moving to RLHF only if quality plateaus and budget allows.