Training Tech

Gradient Checkpointing

Trade compute for memory by recomputing activations during the backward pass instead of storing them — critical for training on long sequences or large models.

√n memory cost · 33% extra compute · 2×+ longer sequences

SECTION 01

The Memory Problem

Standard backpropagation stores ALL intermediate activations from the forward pass — needed to compute gradients during backward. For a 32-layer transformer with seq_len=2048, d_model=4096, batch=8:

```python
# Memory estimate per layer (rough):
seq_len = 2048
d_model = 4096
batch = 8
n_heads = 32
bytes_per_float = 2  # bfloat16

# Attention activations per layer
q_k_v_activations = 3 * batch * seq_len * d_model * bytes_per_float  # ~400 MB
attn_scores = batch * n_heads * seq_len * seq_len * bytes_per_float  # ~2 GB (!)
ffn_activations = batch * seq_len * 4 * d_model * bytes_per_float    # ~540 MB

per_layer_mb = (q_k_v_activations + attn_scores + ffn_activations) / 1e6
total_activation_gb = 32 * per_layer_mb / 1e3
print(f"Total activation memory: ~{total_activation_gb:.1f} GB for 32 layers")
# → ~100 GB! Impossible on a single 80GB A100
```
Bottleneck: The attention score matrix is O(seq_len²) — doubling sequence length quadruples this memory. At seq_len=8192 (4× longer), the score matrices grow 16×, to roughly 32 GB per layer for the configuration above.
SECTION 02

How Checkpointing Works

Instead of storing all activations, checkpoint at regular intervals (every N layers). During backward, recompute the activations from the nearest checkpoint.

Example: 32-layer model, checkpoint every 4 layers: store 8 checkpoints. During backward, recompute each 4-layer segment (8 × 4-layer forward passes = 1 extra full forward pass = 33% overhead).
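The overhead arithmetic above can be sketched directly. This is a toy calculation (the function name and the "backward costs ~2× forward" assumption are ours, not a library API):

```python
def checkpoint_cost(n_layers: int, every: int):
    """Checkpoints stored and recompute overhead for segment size `every`.

    Each segment's forward is replayed exactly once during backward, so the
    total recomputation is one extra full forward pass regardless of `every`.
    Assuming backward costs ~2x forward, one extra forward is ~1/3 of a step.
    """
    checkpoints = n_layers // every        # segment boundaries stored
    extra_forward_passes = 1.0             # each segment replayed once
    overhead = extra_forward_passes / 3    # forward ≈ 1/3 of a training step
    return checkpoints, overhead

print(checkpoint_cost(32, 4))  # 8 checkpoints, ~33% extra compute
```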
SECTION 03

PyTorch Implementation

```python
import torch
from torch.utils.checkpoint import checkpoint, checkpoint_sequential

# Manual checkpointing for a custom module
class CheckpointedTransformerLayer(torch.nn.Module):
    def __init__(self, layer):
        super().__init__()
        self.layer = layer

    def forward(self, x):
        # checkpoint: recompute activations during backward
        return checkpoint(
            self.layer,
            x,
            use_reentrant=False,  # Modern API — recommended
        )

# checkpoint_sequential for nn.Sequential models
model_sequential = torch.nn.Sequential(*layers)
output = checkpoint_sequential(
    model_sequential,
    segments=4,  # Split into 4 segments, checkpoint between each
    input=x,
    use_reentrant=False,
)

# What use_reentrant=False does:
#  - Modern implementation using saved-tensor hooks
#  - Better compatibility with torch.compile and other features
#  - Always use False for new code
```
SECTION 04

Enabling in Hugging Face

```python
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

# Method 1: Enable on the model directly
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": False}
)

# Check it's enabled
print(model.is_gradient_checkpointing)  # True

# Method 2: Via TrainingArguments
args = TrainingArguments(
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
)

# Important: gradient checkpointing disables model caching
# (can't cache KV states when recomputing activations)
model.config.use_cache = False  # Required with gradient_checkpointing

# PEFT + gradient checkpointing
from peft import get_peft_model, LoraConfig

lora_config = LoraConfig(task_type="CAUSAL_LM")  # set ranks/targets as needed
model = get_peft_model(model, lora_config)
model.enable_input_require_grads()  # Required for PEFT + checkpointing
```
SECTION 05

Memory vs Compute Trade-off

| Strategy | Memory | Speed | Use when |
|---|---|---|---|
| No checkpointing | 100% | Fastest | Memory allows it |
| Checkpoint every layer | ~20% | ~67% of max | Very long sequences |
| Checkpoint every N layers | √n × layer_mem | ~75% of max | Sweet spot for most cases |
| Gradient checkpointing + AMP | ~10% | ~60% of max | Maximize sequence length |
SECTION 06

When to Use It

```python
# Decision tree for gradient checkpointing:
#
# 1. Are you getting OOM errors during training?
#    → Enable gradient checkpointing first
#
# 2. Are you training with very long sequences (>2048 tokens)?
#    → Almost certainly need checkpointing (attention is O(seq^2))
#
# 3. Are you LoRA fine-tuning on a small GPU?
#    → Enable checkpointing + use smaller batch with gradient accumulation

# Example: fine-tuning LLaMA-7B with 4096 context on a single A100 (80GB)
import torch
from transformers import AutoModelForCausalLM, TrainingArguments

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model.gradient_checkpointing_enable()
model.config.use_cache = False

args = TrainingArguments(
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,  # Effective batch = 16
    gradient_checkpointing=True,
    bf16=True,
    # The 4096-token context is set when tokenizing the dataset;
    # TrainingArguments has no max_seq_length field (TRL's SFTConfig does).
)
```
Combined strategy: bf16 + gradient checkpointing + gradient accumulation. Together with LoRA (full fine-tuning's optimizer states alone would overflow 24GB), this lets you fine-tune a 7B model on a single 24GB RTX 4090 with 2048-token sequences.

Selective checkpointing strategies

Not all transformer layers benefit equally from gradient checkpointing. Attention layers store large activation tensors proportional to sequence_length², while feed-forward layers store activations proportional to sequence_length × hidden_dim. For models with long sequences, checkpointing only the attention layers captures most of the memory savings (60–80% reduction) while incurring less recomputation overhead than checkpointing every layer. Note that Hugging Face's gradient_checkpointing_kwargs parameter only forwards options such as use_reentrant to torch.utils.checkpoint; choosing which modules get checkpointed still requires wrapping those modules yourself.
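Selective checkpointing can be done by calling torch's checkpoint on just the attention sub-module. A minimal sketch (the SelectiveCheckpointBlock class and its dimensions are hypothetical stand-ins for a real model's layers):

```python
import torch
from torch.utils.checkpoint import checkpoint

class SelectiveCheckpointBlock(torch.nn.Module):
    """Toy pre-norm transformer block that checkpoints only attention."""

    def __init__(self, d_model=64, n_heads=4):
        super().__init__()
        self.attn = torch.nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ffn = torch.nn.Sequential(
            torch.nn.Linear(d_model, 4 * d_model),
            torch.nn.GELU(),
            torch.nn.Linear(4 * d_model, d_model),
        )
        self.norm1 = torch.nn.LayerNorm(d_model)
        self.norm2 = torch.nn.LayerNorm(d_model)

    def _attn(self, x):
        out, _ = self.attn(x, x, x, need_weights=False)
        return out

    def forward(self, x):
        # Recompute the O(seq^2) attention activations during backward...
        x = x + checkpoint(self._attn, self.norm1(x), use_reentrant=False)
        # ...but store the cheaper FFN activations as usual.
        return x + self.ffn(self.norm2(x))

block = SelectiveCheckpointBlock()
x = torch.randn(2, 16, 64, requires_grad=True)  # (batch, seq, d_model)
block(x).sum().backward()                        # gradients flow through the checkpoint
```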

Activation offloading to CPU

Activation offloading extends gradient checkpointing by moving stored activations to CPU RAM instead of recomputing them. This approach reduces GPU memory usage while avoiding the compute overhead of full recomputation, at the cost of PCIe bandwidth for transferring activations between GPU and CPU memory. It is most beneficial when GPU memory is the binding constraint but PCIe bandwidth is not, and when host RAM is plentiful, as on typical single-GPU workstations. PyTorch exposes this directly via torch.autograd.graph.save_on_cpu; DeepSpeed's activation checkpointing can likewise offload checkpoints to CPU (cpu_checkpointing), while ZeRO-3's CPU offload applies the same idea to model weights and optimizer states rather than activations.
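A minimal sketch of offloading with PyTorch's built-in save_on_cpu hook; tensors saved for backward are copied to host RAM during the forward pass and paged back during backward (the tiny model here is illustrative only):

```python
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(128, 512), torch.nn.ReLU(), torch.nn.Linear(512, 128)
)
x = torch.randn(4, 128, requires_grad=True)

# pin_memory=True enables faster async host<->device copies (matters on GPU;
# harmless but pointless on CPU-only machines).
with torch.autograd.graph.save_on_cpu(pin_memory=torch.cuda.is_available()):
    loss = model(x).sum()   # activations saved for backward live in CPU RAM

loss.backward()             # activations are copied back as backward needs them
```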

| Strategy | Memory savings | Compute overhead | Best for |
|---|---|---|---|
| No checkpointing | None | None | Small models, adequate VRAM |
| Selective (attention only) | 40–60% | ~10% | Long sequences |
| Full checkpointing | 60–80% | ~30–40% | Very large models |
| CPU offloading | 80–90% | PCIe bandwidth | VRAM-constrained single GPU |

FSDP Interaction and Distributed Training

Gradient checkpointing becomes essential in fully sharded data parallel (FSDP) setups, where model weights, gradients, and optimizer states are sharded across GPUs (activations stay local to each rank). When using FSDP with Hugging Face transformers, combining gradient_checkpointing=True with fsdp="full_shard" dramatically reduces per-GPU memory footprint: sharding shrinks the parameter and optimizer footprint, while checkpointing bounds activation memory, which sharding alone does not touch. The two compose cleanly because FSDP already re-gathers each layer's weight shards during the backward pass, so recomputing that layer's forward on the spot adds little extra communication. This synergy is crucial for training 70B+ parameter models on 8×H100 clusters; without checkpointing, activation memory alone can exceed per-GPU capacity. FSDP's activation checkpointing works with both use_reentrant=True and False, though False (the modern implementation) is recommended for better integration with torch.compile and communication overlap.
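A sketch using the activation-checkpoint wrapper that PyTorch's FSDP tutorials use; it wraps matching submodules in place and works on a plain module too (with FSDP you apply it after wrapping the model). The toy Block class is a hypothetical stand-in for a transformer layer:

```python
import functools
import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

class Block(torch.nn.Module):  # stand-in for a transformer layer
    def __init__(self):
        super().__init__()
        self.ff = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.GELU())

    def forward(self, x):
        return x + self.ff(x)

model = torch.nn.Sequential(*[Block() for _ in range(4)])

# Wrap every Block with non-reentrant activation checkpointing, in place.
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=functools.partial(
        checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
    ),
    check_fn=lambda m: isinstance(m, Block),  # which submodules to checkpoint
)

x = torch.randn(2, 64, requires_grad=True)
model(x).sum().backward()
```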

Fine-grained Checkpoint Granularity and Segment Sizing

The checkpoint granularity — how often activations are saved along the forward path — directly shapes the memory-compute trade-off curve. Checkpointing every layer keeps recomputation to a single extra forward pass but stores a checkpoint per layer; checkpointing every N layers (N ≈ √depth for minimal peak memory) balances checkpoint storage against the activations rebuilt inside each segment. For 32-layer models, N≈6 gives ~20% memory usage with ~33% compute overhead. For 96-layer models in the 175B-parameter class, roughly 10 checkpoint segments prove optimal in practice. PyTorch's checkpoint_sequential function takes the segment count directly; segments=4 divides a 96-layer model into 24-layer chunks. Empirically, activation memory is dominated by attention score matrices (O(seq²)), so selective checkpointing that targets only attention layers reclaims 60–75% of the memory savings with just 10–15% compute overhead, making it the preferred strategy for long-context fine-tuning of models such as Llama-2-70B.
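The √depth heuristic plugs straight into checkpoint_sequential. A runnable sketch on a toy layer stack (the stack is a stand-in for real transformer layers):

```python
import math
import torch
from torch.utils.checkpoint import checkpoint_sequential

depth = 32
segments = max(1, round(math.sqrt(depth)))  # ≈ 6 segments for 32 layers

model = torch.nn.Sequential(
    *[torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.Tanh())
      for _ in range(depth)]
)
x = torch.randn(4, 32, requires_grad=True)

# Only segment-boundary activations are stored; each segment's interior
# is recomputed once during backward.
out = checkpoint_sequential(model, segments, x, use_reentrant=False)
out.sum().backward()
```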

Memory vs. Latency Profiling and Adaptive Strategies

Measuring gradient checkpointing's real-world impact requires profiling both peak memory and wall-clock training time. torch.cuda.max_memory_allocated() (with torch.cuda.reset_peak_memory_stats() between runs) reveals the true peak footprint, while torch.cuda.memory_reserved() shows what the caching allocator is holding; empirically, a 13B model trained on an A100-40GB shows ~35GB peak without checkpointing and ~15GB with full checkpointing (a ~57% reduction). Wall-clock overhead depends on how expensive the recomputed forward is relative to the rest of the step: transformer stacks typically slow by ~25–35%, and the penalty grows when the recomputed segments are memory-bandwidth-bound rather than compute-bound. A pragmatic adaptive strategy is to enable checkpointing only when peak memory, monitored via torch.cuda.memory_allocated() during early steps, approaches capacity; in Hugging Face's Trainer, gradient_checkpointing is a plain boolean, so that switch has to live in your own launch or config logic.
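The measurement pattern can be sketched as a small helper (names are ours; peak-memory numbers are only reported when CUDA is available, but the gradient check runs anywhere):

```python
import torch
from torch.utils.checkpoint import checkpoint_sequential

def peak_mem_mb(model, x, segments=0):
    """One forward+backward; return peak CUDA memory in MB (None on CPU)."""
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    out = (
        checkpoint_sequential(model, segments, x, use_reentrant=False)
        if segments
        else model(x)
    )
    out.sum().backward()
    if torch.cuda.is_available():
        return torch.cuda.max_memory_allocated() / 1e6
    return None

model = torch.nn.Sequential(
    *[torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.GELU())
      for _ in range(16)]
)
x = torch.randn(8, 256, requires_grad=True)

baseline = peak_mem_mb(model, x)           # no checkpointing
grad_plain = x.grad.clone()

model.zero_grad(); x.grad = None
ckpt = peak_mem_mb(model, x, segments=4)   # 4 checkpointed segments
grad_ckpt = x.grad.clone()

# Checkpointing changes memory use, never the math:
assert torch.allclose(grad_plain, grad_ckpt, atol=1e-6)
```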

Activation Memory Breakdown by Layer Type

Gradient checkpointing's memory savings vary dramatically by layer type. Attention layers dominate memory consumption due to attention score matrices of size batch×n_heads×seq_len×seq_len — a single attention head with seq_len=4096 and batch=8 requires 4096×4096×8×2 bytes = 256MB, and with 32 heads that is 8GB just for scores. Feed-forward layers consume batch×seq_len×4×hidden_dim, much smaller: 8×4096×4×4096×2 bytes = 1GB. This 8:1 ratio in storage explains selective checkpointing strategies that focus on attention layers exclusively. Empirical profiling with torch.cuda.memory_allocated() before and after attention layers reveals the asymmetry; many practitioners checkpoint only attention (roughly 10–15% slowdown) instead of all layers (30–40% slowdown), recovering 60–75% of the memory benefit with far less recomputation. For training runs with sequences beyond 8K tokens on a single H100, attention-only checkpointing is often what keeps step time acceptable while still fitting in memory.
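The breakdown above is easy to verify with the same arithmetic as the earlier per-layer estimate (bf16, batch=8, seq_len=4096, hidden_dim=4096, 32 heads):

```python
batch, seq_len, hidden, heads, bytes_per = 8, 4096, 4096, 32, 2

attn_scores = batch * heads * seq_len * seq_len * bytes_per  # all heads
ffn_acts = batch * seq_len * 4 * hidden * bytes_per

print(attn_scores / 2**30)      # → 8.0  (GiB of attention scores)
print(ffn_acts / 2**30)         # → 1.0  (GiB of FFN activations)
print(attn_scores // ffn_acts)  # → 8    (the 8:1 ratio)
```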

Checkpointing and KV Cache Behavior in Inference

Gradient checkpointing during training requires disabling the KV cache (set model.config.use_cache=False). At inference time (generation), the KV cache is essential: previously computed key/value pairs are stored so each new token attends to cached states instead of recomputing them, cutting the per-token cost of autoregressive generation from O(n²) to O(n). This does not create a real training-inference gap: whether a model was trained with gradient_checkpointing=True has no effect on generation, since caching is re-enabled independently at inference. The practical implication is memory budgeting in serving stacks: vLLM's and text-generation-webui's KV cache managers must reserve GPU memory for all cached states. For a 70B model generating 2k tokens with batch_size=32, the KV cache can consume on the order of 40GB (it scales linearly with batch size and sequence length), a substantial fraction of the memory budget alongside the weights. The combined inference strategy: keep the KV cache on for generation, reduce batch size when memory-bound, and optionally offload parts of the cache (e.g., some layers' entries) to CPU for a further 20–30% memory reduction at 5–10% latency cost.
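The cache-vs-recompute equivalence is easy to see in miniature. A toy single-head decoder step with a manual KV cache (pure tensor ops, no library generation APIs assumed); the incremental path computes one new K/V pair per step, and its final output matches recomputing everything from scratch:

```python
import torch

torch.manual_seed(0)
d = 16
Wq, Wk, Wv = (torch.randn(d, d) for _ in range(3))

def attend(q, K, V):
    scores = (q @ K.T) / d**0.5             # (1, t)
    return torch.softmax(scores, dim=-1) @ V

tokens = torch.randn(5, d)                   # 5 decoding steps
K_cache = torch.empty(0, d)
V_cache = torch.empty(0, d)
cached_outs = []
for x in tokens:                             # incremental: 1 new K/V per step
    K_cache = torch.cat([K_cache, (x @ Wk).unsqueeze(0)])
    V_cache = torch.cat([V_cache, (x @ Wv).unsqueeze(0)])
    cached_outs.append(attend((x @ Wq).unsqueeze(0), K_cache, V_cache))

# No-cache reference for the last step: recompute every K/V from scratch.
K_full, V_full = tokens @ Wk, tokens @ Wv
ref = attend((tokens[-1] @ Wq).unsqueeze(0), K_full, V_full)
assert torch.allclose(cached_outs[-1], ref, atol=1e-6)
```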