Trade compute for memory by recomputing activations during the backward pass instead of storing them — critical for training on long sequences or large models.
Standard backpropagation stores ALL intermediate activations from the forward pass, since they are needed to compute gradients during the backward pass. For a 32-layer transformer with seq_len=2048, d_model=4096, batch=8, a single fp16 activation tensor of shape (8, 2048, 4096) already occupies 8×2048×4096×2 bytes ≈ 134 MB; with roughly a dozen such tensors per layer plus the O(seq²) attention score matrices, activation memory across all 32 layers runs into tens of gigabytes.
Instead of storing all activations, gradient checkpointing saves them only at regular intervals (every N layers). During the backward pass, activations between checkpoints are recomputed from the nearest checkpoint (see the sketch after the table below).
| Strategy | Activation memory (vs. none) | Speed (vs. none) | Use when |
|---|---|---|---|
| No checkpointing | 100% | 100% (fastest) | Memory allows it |
| Checkpoint every layer | ~20% | ~67% | Very long sequences |
| Checkpoint every N layers | ~√n × per-layer memory | ~75% | Sweet spot for most cases |
| Gradient checkpointing + AMP | ~10% | ~60% | Maximize sequence length |
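A minimal sketch of the checkpoint-every-N-layers strategy using torch.utils.checkpoint; the toy Linear stack, its dimensions, and segment_len=6 are illustrative placeholders, not part of any particular model:

```python
import torch
from torch.utils.checkpoint import checkpoint

layers = torch.nn.ModuleList(torch.nn.Linear(256, 256) for _ in range(32))

def run_segment(segment, x):
    for layer in segment:
        x = torch.relu(layer(x))
    return x

def forward(x, segment_len=6):
    # Wrap each segment of layers in a checkpoint: only segment-boundary
    # activations are stored; segment interiors are recomputed in backward.
    for i in range(0, len(layers), segment_len):
        x = checkpoint(run_segment, layers[i:i + segment_len], x, use_reentrant=False)
    return x

x = torch.randn(8, 256, requires_grad=True)
forward(x).sum().backward()  # gradients flow through the recomputed segments
```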
Not all transformer layers benefit equally from gradient checkpointing. Attention layers store activation tensors proportional to seq_len², while feed-forward layers store activations proportional to seq_len × hidden_dim. For models with long sequences, checkpointing only the attention layers therefore captures most of the memory savings (60–80% reduction) while incurring less recomputation overhead than checkpointing every layer. Hugging Face's gradient_checkpointing_kwargs parameter in TrainingArguments forwards options such as use_reentrant to the model's gradient_checkpointing_enable method; choosing which modules get checkpointed is typically done by wrapping those modules with torch.utils.checkpoint directly.
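A sketch of that attention-only selection at the module level; the hand-rolled Block class and its dimensions are hypothetical stand-ins for a real transformer layer:

```python
import torch
from torch.utils.checkpoint import checkpoint

class Block(torch.nn.Module):
    def __init__(self, d_model=256, nhead=4):
        super().__init__()
        self.attn = torch.nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.ffn = torch.nn.Sequential(
            torch.nn.Linear(d_model, 4 * d_model),
            torch.nn.GELU(),
            torch.nn.Linear(4 * d_model, d_model),
        )

    def forward(self, x):
        # Checkpoint only the attention sublayer, whose activations grow as
        # O(seq_len^2); the feed-forward sublayer stores activations as usual.
        attn_out = checkpoint(
            lambda q: self.attn(q, q, q, need_weights=False)[0],
            x,
            use_reentrant=False,
        )
        x = x + attn_out
        return x + self.ffn(x)

x = torch.randn(2, 128, 256, requires_grad=True)
Block()(x).sum().backward()
```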
Activation offloading extends gradient checkpointing by moving stored activations to CPU RAM instead of recomputing them. This reduces GPU memory usage while avoiding the compute overhead of full recomputation, at the cost of PCIe bandwidth for transferring activations between GPU and CPU memory. Offloading is most beneficial when GPU memory is the binding constraint but PCIe bandwidth is not, as in single-GPU setups with plentiful CPU RAM. DeepSpeed's ZeRO-3 with CPU offload applies the same idea at the parameter level, offloading model weights and optimizer states.
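PyTorch exposes this trade directly: the torch.autograd.graph.save_on_cpu context manager moves tensors saved for backward into (optionally pinned) CPU memory and copies them back on demand. A minimal sketch, assuming a CUDA device is available:

```python
import torch

model = torch.nn.Sequential(*[torch.nn.Linear(1024, 1024) for _ in range(8)]).cuda()
x = torch.randn(64, 1024, device="cuda", requires_grad=True)

# Activations saved for backward live in pinned CPU RAM instead of VRAM;
# each tensor is transferred back over PCIe when its gradient is computed.
with torch.autograd.graph.save_on_cpu(pin_memory=True):
    y = model(x)
y.sum().backward()
```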
| Strategy | Memory savings | Compute overhead | Best for |
|---|---|---|---|
| No checkpointing | None | None | Small models, adequate VRAM |
| Selective (attention only) | 40–60% | ~10% | Long sequences |
| Full checkpointing | 60–80% | ~30–40% | Very large models |
| CPU offloading | 80–90% | Bounded by PCIe transfer | VRAM-constrained single GPU |
Gradient checkpointing becomes essential in fully sharded data parallel (FSDP) setups, where model weights, gradients, and optimizer states are sharded across multiple GPUs. When using FSDP with Hugging Face transformers, combining gradient_checkpointing=True with fsdp="full_shard" dramatically reduces the per-GPU memory footprint: FSDP shrinks the parameter, gradient, and optimizer-state footprint, checkpointing shrinks the activation footprint, and during the backward pass FSDP re-gathers each layer's weight shards just in time for that layer's recomputation. This synergy is crucial for training 70B+ parameter models on 8×H100 clusters; without checkpointing, activation memory alone can exceed per-GPU capacity. FSDP's activation checkpointing works with both use_reentrant=True and use_reentrant=False, though the non-reentrant (False) implementation is recommended for better compatibility with torch.compile and communication/computation overlap.
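A configuration sketch for the Hugging Face Trainer; output_dir is a placeholder, and gradient_checkpointing_kwargs requires a recent transformers release:

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",                      # placeholder path
    fsdp="full_shard auto_wrap",           # shard params, grads, optimizer state
    gradient_checkpointing=True,           # recompute activations in backward
    gradient_checkpointing_kwargs={"use_reentrant": False},
    bf16=True,
)
```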
The checkpoint granularity, i.e. how often activations are saved along the forward path, directly controls the memory-compute trade-off. Checkpointing every layer minimizes recomputation (roughly one extra forward pass in total) but stores a checkpoint per layer. Checkpointing every N layers, with N ≈ √depth for the classic O(√n) memory optimum, balances storage against recomputation. For 32-layer models, N≈6 gives roughly 20% of baseline activation memory at ~33% compute overhead; for 96-layer models at the 175B-parameter scale, about 10 checkpoint segments prove optimal in practice. PyTorch's checkpoint_sequential function takes the segment count directly; segments=4 divides a 96-layer model into four 24-layer chunks. Empirically, activation memory is dominated by attention score matrices (O(seq²)), so selective checkpointing that targets only attention layers reclaims 60–75% of the memory savings with just 10–15% compute overhead, making it the preferred strategy for long-context fine-tuning of models like Llama-2-70b-chat.
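A sketch of segment-count checkpointing with torch.utils.checkpoint.checkpoint_sequential; the toy 32-layer stack and its sizes are illustrative:

```python
import torch
from torch.utils.checkpoint import checkpoint_sequential

model = torch.nn.Sequential(*[torch.nn.Linear(512, 512) for _ in range(32)])
x = torch.randn(8, 512, requires_grad=True)

# segments=6 ≈ sqrt(32): only the six segment inputs are stored; each
# segment's interior activations are recomputed during the backward pass.
out = checkpoint_sequential(model, segments=6, input=x, use_reentrant=False)
out.sum().backward()
```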
Measuring gradient checkpointing's real-world impact requires profiling both peak memory and wall-clock training time. torch.cuda.max_memory_allocated() and torch.cuda.memory_reserved() reveal the true memory footprint; empirically, a 13B model trained on an A100-40GB shows ~35GB peak without checkpointing and ~15GB with full checkpointing (a ~57% reduction). Wall-clock overhead depends on GPU compute density and memory bandwidth: compute-dense transformer layers see ~25–30% slowdown, while memory-bandwidth-limited operations may slow by 40–50%. Adaptive strategies that enable checkpointing only when peak memory approaches capacity, detected by monitoring torch.cuda.memory_allocated() during training, give the best throughput; DeepSpeed's training recipes expose similar memory-aware toggles.
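A profiling sketch along these lines, using a toy model so it runs on any CUDA device; all sizes are placeholders:

```python
import torch

def peak_memory_gib(model, batch):
    """One forward/backward step; return peak GPU memory in GiB."""
    torch.cuda.reset_peak_memory_stats()
    model(batch).sum().backward()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**30

model = torch.nn.Sequential(*[torch.nn.Linear(2048, 2048) for _ in range(16)]).cuda()
batch = torch.randn(256, 2048, device="cuda")

print(f"peak allocated: {peak_memory_gib(model, batch):.2f} GiB")
print(f"reserved:       {torch.cuda.memory_reserved() / 2**30:.2f} GiB")
```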
Gradient checkpointing's memory savings vary dramatically by layer type. Attention layers dominate memory consumption due to attention score matrices of size batch×seq_len×seq_len per head: a single head with seq_len=4096 and batch=8 requires 4096×4096×8×2 bytes = 256MB, and 32 heads bring this to 8GB of scores per layer. A feed-forward layer stores batch×seq_len×4×hidden_dim, much smaller: 8×4096×4×4096×2 bytes = 1GB. This 8:1 storage ratio is what makes attention-only selective checkpointing work. Profiling with torch.cuda.memory_allocated() before and after attention layers reveals the asymmetry; many practitioners checkpoint only attention (~10–15% overhead) instead of all layers (~30–40% overhead), recovering 60–75% of the memory benefit with far less recomputation. For sequence lengths beyond 8K tokens on a single H100, some form of checkpointing becomes unavoidable, and the selective variant keeps training throughput acceptable where full checkpointing's recomputation cost does not.
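The figures above are easy to verify; a back-of-envelope check in plain Python (fp16, 2 bytes per element):

```python
batch, heads, seq, d_model = 8, 32, 4096, 4096

attn_scores = batch * heads * seq * seq * 2   # score matrices, all heads
ffn_acts = batch * seq * 4 * d_model * 2      # 4x-wide FFN intermediate

print(attn_scores / 2**30)  # 8.0 -> ~8 GB of attention scores per layer
print(ffn_acts / 2**30)     # 1.0 -> ~1 GB of FFN activations per layer
```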
Gradient checkpointing during training requires disabling the KV cache (set model.config.use_cache=False), because cached key/value states are incompatible with recomputing the forward pass. At inference time (generation), the KV cache is essential for reducing autoregressive generation cost from O(n²) to O(n): previously computed key/value pairs are reused instead of being recomputed for each new token. There is no real training-inference gap here: caching is numerically equivalent to recomputation, so a model trained with gradient_checkpointing=True generates exactly as well once the cache is re-enabled. The practical implication falls on serving: inference engines such as vLLM must budget GPU memory for the cached states of every in-flight sequence. For a 70B model generating 2k tokens with batch_size=32, the KV cache consumes ~40GB (it scales linearly with sequence length and batch size) and at longer contexts can rival the model weights. The combined inference strategy: use the KV cache for efficient generation, reduce batch size if memory-bound, and optionally offload part of the cache per layer (earlier layers' cache on CPU, later layers' on GPU) for 20–30% memory reduction at 5–10% latency cost.
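A sketch of the train-time/inference-time toggle with the Hugging Face API; gpt2 stands in for any causal LM:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

model.gradient_checkpointing_enable()  # training: recompute activations
model.config.use_cache = False         # caching conflicts with recomputation
# ... training loop ...

model.gradient_checkpointing_disable()
model.config.use_cache = True          # generation: reuse cached K/V pairs
ids = tok("Gradient checkpointing", return_tensors="pt").input_ids
print(tok.decode(model.generate(ids, max_new_tokens=20)[0]))
```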