Prepend trainable soft token vectors to each transformer layer's key-value inputs. Only the prefix parameters are updated; the base model is frozen. More expressive than prompt tuning, at the cost of a small, constant prefix overhead in every layer's KV cache.
Prompt tuning (Lester et al., 2021) prepends trainable soft tokens to the input embedding layer only: the prefix exists at the embedding level and propagates through the model via attention. Prefix tuning (Li & Liang, 2021) goes further: trainable prefix vectors are injected at every transformer layer, directly into the key and value matrices of each layer's attention. This gives the prefix direct influence over computation at every depth, making it more expressive than prompt tuning for complex tasks.
Each transformer layer has multi-head attention that computes Q, K, V from the current hidden states. In prefix tuning, two sets of k learnable vectors, P_K and P_V, are prepended to the key and value sequences before attention:
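The prepend-then-attend step can be sketched as a minimal single-head example in NumPy (dimensions and random values here are illustrative, not taken from any real model):

```python
import numpy as np

rng = np.random.default_rng(0)
d, n, k = 8, 5, 3  # head dim, input length, prefix length

# Frozen projections of the real input tokens (stand-ins for W_Q/W_K/W_V outputs)
Q = rng.standard_normal((n, d))
K = rng.standard_normal((n, d))
V = rng.standard_normal((n, d))

# Trainable prefix vectors: the only parameters prefix tuning updates
P_K = rng.standard_normal((k, d))
P_V = rng.standard_normal((k, d))

# Prepend the prefix to keys and values; queries come only from real tokens
K_full = np.concatenate([P_K, K], axis=0)  # (k + n, d)
V_full = np.concatenate([P_V, V], axis=0)  # (k + n, d)

# Standard scaled dot-product attention over the extended key/value sequences
scores = Q @ K_full.T / np.sqrt(d)                            # (n, k + n)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)                # softmax rows
out = weights @ V_full  # (n, d): every token mixes prefix and input values
```

Each real token's output is now a mixture of prefix values and input values, which is exactly how the prefix steers the layer's computation.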
The attention query (from actual input tokens) now attends over both the prefix keys and the real input keys. The prefix vectors act as a kind of "task context" that steers every layer's computation. The prefix length k is a hyperparameter (typically 10–200 tokens). Only P_K and P_V are trained; the base model's weights are frozen.
In practice, a reparameterization trick is used: instead of directly training P_K and P_V (which can be unstable), train a smaller matrix and project it up via a feed-forward network, then discard the projection at inference time.
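A rough sketch of that reparameterization (shapes and the two-layer MLP are illustrative; the actual PEFT implementation differs in detail):

```python
import numpy as np

rng = np.random.default_rng(1)
k, d_small, d_model, n_layers = 20, 64, 512, 12  # illustrative sizes

# Small trainable embedding plus a two-layer MLP: the reparameterization
E  = rng.standard_normal((k, d_small)) * 0.02                   # trained
W1 = rng.standard_normal((d_small, d_small)) * 0.02             # trained
W2 = rng.standard_normal((d_small, 2 * n_layers * d_model)) * 0.02  # trained

# Forward pass: project the small embedding up to full prefix size,
# then reshape into per-layer key/value prefixes
H = np.tanh(E @ W1) @ W2                        # (k, 2 * n_layers * d_model)
prefixes = H.reshape(k, n_layers, 2, d_model)

# After training, only the materialized prefixes are kept for inference;
# E, W1, W2 are discarded
P_K = prefixes[:, :, 0, :]  # (k, n_layers, d_model)
P_V = prefixes[:, :, 1, :]  # (k, n_layers, d_model)
```

Gradients flow through the MLP during training, which smooths the optimization; at inference only the final P_K and P_V tensors remain.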
from peft import PrefixTuningConfig, get_peft_model, TaskType
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
model_name = "meta-llama/Meta-Llama-3-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
base_model = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=torch.bfloat16, device_map="auto"
)
peft_config = PrefixTuningConfig(
task_type=TaskType.CAUSAL_LM,
num_virtual_tokens=20, # prefix length: 20 virtual tokens per layer
encoder_hidden_size=None, # None = use model hidden size
)
model = get_peft_model(base_model, peft_config)
model.print_trainable_parameters()
# trainable params: 983,040 || all params: 8,030,261,248 || trainable%: 0.01%
# Training loop (standard: only prefix params have gradients)
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling
training_args = TrainingArguments(
output_dir="./prefix-tuned",
num_train_epochs=3,
per_device_train_batch_size=4,
learning_rate=3e-2, # higher LR needed for prefix tuning vs LoRA
warmup_ratio=0.1,
bf16=True, # match the bfloat16 base model load above
)
# Trainer setup as usual...
In practice, LoRA has largely superseded prefix tuning for most use cases. LoRA is more stable to train, scales better to large models, and achieves better performance on most benchmarks at equivalent parameter counts. Prefix tuning has one advantage: the prefix vectors are naturally task-specific and can be swapped without any model weight modification. You can switch tasks by simply loading a different prefix, which is useful in multi-task serving scenarios where you need to handle many tasks with a single frozen model.
# Fine-tune with prefix tuning vs LoRA: side-by-side setup
from peft import PrefixTuningConfig, LoraConfig, get_peft_model, TaskType
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from datasets import load_dataset
model_name = "facebook/opt-350m"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# ── Prefix tuning config ──────────────────────────────────────────────────────
prefix_cfg = PrefixTuningConfig(
task_type=TaskType.CAUSAL_LM,
num_virtual_tokens=20, # 20 soft tokens prepended to each layer
encoder_hidden_size=512, # MLP hidden size for reparameterization
)
prefix_model = get_peft_model(
AutoModelForCausalLM.from_pretrained(model_name), prefix_cfg
)
prefix_model.print_trainable_parameters()
# trainable params: ~200K (only the prefix MLP)
# ── LoRA config for comparison ────────────────────────────────────────────────
lora_cfg = LoraConfig(task_type=TaskType.CAUSAL_LM, r=8, lora_alpha=16)
lora_model = get_peft_model(
AutoModelForCausalLM.from_pretrained(model_name), lora_cfg
)
lora_model.print_trainable_parameters()
# trainable params: ~1.3M (larger but typically more expressive)
# ── Both models train with standard Trainer ───────────────────────────────────
args = TrainingArguments(output_dir="./out", num_train_epochs=3,
per_device_train_batch_size=8, learning_rate=3e-4)
Prefix tuning is worth considering when: (1) the base model's weights must stay frozen but you control the serving stack and can inject prefix activations into each layer's attention; (2) you need to serve many tasks from one frozen base model and want the cheapest per-task overhead; (3) you're adapting an architecture that LoRA tooling does not support. For most practitioners with access to the model weights, LoRA is the better default choice.
LoRA also supports model.merge_and_unload() if you need a self-contained model checkpoint.

Prefix tuning prepends a sequence of learnable "virtual tokens" to the input of every transformer layer, conditioning the model's behavior without modifying any of its original weights. Only the prefix parameters are updated during training, keeping the base model frozen and making prefix tuning highly parameter-efficient while allowing task-specific customization through the learned prefix vectors.
| Method | Trainable Params | Where Modified | Inference Overhead | Quality |
|---|---|---|---|---|
| Full fine-tuning | 100% | All weights | None | Highest |
| LoRA | 0.1–1% | Attention projections | Minimal (merge at deploy) | Very good |
| Prefix tuning | 0.1% | KV cache prefix | Extra prefix tokens | Good |
| Prompt tuning | ~0.01% | Input embeddings only | Extra tokens | Moderate |
| Adapter layers | 0.5β3% | Inserted FFN layers | Adapter forward passes | Very good |
Prefix tuning's key advantage over prompt tuning is that it conditions every layer of the transformer through the prefix key-value pairs, not just the input embedding layer. Prompt tuning only prepends soft tokens to the input and relies on the frozen transformer to propagate that signal through its layers. The additional conditioning depth of prefix tuning makes it more expressive for complex tasks, though at the cost of slightly more trainable parameters and a constant inference overhead proportional to the prefix length.
The inference overhead from prefix tokens is a practical consideration for high-throughput deployments. A prefix of length L adds L tokens to every request's KV cache, consuming additional memory and adding a constant latency cost for the attention computation over the prefix. For interactive applications where latency matters, keeping prefix lengths short (8–32 tokens) is important. LoRA has effectively replaced prefix tuning for most use cases because it has no inference overhead when the adapter is merged into the base model weights before deployment.
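The memory side of that overhead is easy to estimate. A back-of-the-envelope sketch, assuming a hypothetical 32-layer model with a 1024-dimensional per-layer KV projection and bf16 storage (these numbers are illustrative, not tied to any specific model):

```python
# Extra KV-cache memory a prefix adds to every sequence (illustrative config)
n_layers = 32      # transformer layers (hypothetical)
d_kv = 1024        # per-layer key/value width (hypothetical)
prefix_len = 20    # virtual tokens
bytes_per_el = 2   # bf16

# One key vector and one value vector per layer per prefix position
extra_bytes = 2 * n_layers * prefix_len * d_kv * bytes_per_el
print(f"{extra_bytes / 2**20:.1f} MiB per sequence")  # 2.5 MiB per sequence
```

A few MiB per sequence is negligible for a single request but multiplies across every concurrent sequence in a batched server, which is why short prefixes matter at high throughput.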
from peft import PrefixTuningConfig, get_peft_model, TaskType
from transformers import AutoModelForSeq2SeqLM
# Load base model
model = AutoModelForSeq2SeqLM.from_pretrained("t5-large")
# Configure prefix tuning
config = PrefixTuningConfig(
task_type=TaskType.SEQ_2_SEQ_LM,
num_virtual_tokens=20, # prefix length
encoder_hidden_size=1024, # reparameterization hidden size
prefix_projection=True # use MLP reparameterization
)
# Wrap model: only prefix params are trainable
model = get_peft_model(model, config)
model.print_trainable_parameters()
# trainable params: 3,686,400 || all params: 740,536,320 || ~0.498%
Reparameterization is a training stabilization technique used in prefix tuning. Rather than directly optimizing the prefix vectors, the original paper trains a small MLP network that projects a low-dimensional input to the full-dimensional prefix vectors. This reparameterization prevents the prefix vectors from getting stuck in poor local minima during early training by providing a smoother optimization landscape. After training, the MLP is discarded and only the final prefix vectors are stored for inference β there is no runtime cost from reparameterization.
Multi-task learning with prefix tuning uses separate prefix parameters for each task while sharing all base model weights. Given K tasks, only K × 2 × num_layers × prefix_length × d_model parameters are added (one key and one value vector per layer per prefix position), a tiny fraction of the base model's size. Task selection at inference time simply routes the request through the appropriate prefix. This sharing of base model weights is far more memory-efficient than maintaining K separately fine-tuned full models, and avoids catastrophic forgetting of other tasks since the original model weights are never modified during training.
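The multi-task setup can be sketched as a plain dictionary of per-task prefixes over one frozen base; task names, sizes, and the storage layout below are illustrative, not a real serving implementation:

```python
import numpy as np

n_layers, d_model, k = 12, 768, 20  # illustrative model and prefix sizes

def new_prefix(seed: int) -> np.ndarray:
    """One task's trainable parameters: a key and a value prefix per layer."""
    rng = np.random.default_rng(seed)
    return rng.standard_normal((n_layers, 2, k, d_model)) * 0.02

# One frozen base model, many tasks -> one small prefix tensor per task
prefix_store = {
    task: new_prefix(i)
    for i, task in enumerate(["summarize", "translate", "classify"])
}

# Inference-time routing: pick the prefix by task name; base weights untouched
active_prefix = prefix_store["translate"]

params_per_task = n_layers * 2 * k * d_model
print(params_per_task)  # 368640 parameters per task
```

Swapping tasks is just a dictionary lookup, which is the serving-side payoff of keeping the base model untouched.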