01 — Challenge
The Memory Problem
Training a 7B model in FP32 requires: 7B × 4 bytes = 28 GB for weights alone. Add gradients (28 GB) + optimizer states (56 GB for Adam) = 112 GB minimum — exceeding a single A100's 80 GB VRAM. Solutions include mixed precision (halve weight memory), gradient checkpointing (trade compute for activation memory), and distributed training (spread across GPUs).
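The arithmetic above generalizes to any parameter count via a 16-bytes-per-parameter rule of thumb (4 B weights + 4 B gradients + 8 B Adam states, activations excluded). A minimal pure-Python sketch; the function name is illustrative:

```python
def fp32_adam_memory_gb(n_params: float) -> dict:
    """Rough training-memory budget for FP32 weights + Adam, in GB.

    Per parameter: 4 B weights + 4 B gradients + 8 B Adam states
    (momentum + variance) = 16 B. Activations are excluded.
    """
    gb = 1e9  # 10^9 bytes per GB, as in the text
    return {
        "weights": 4 * n_params / gb,
        "gradients": 4 * n_params / gb,
        "adam_states": 8 * n_params / gb,
        "total": 16 * n_params / gb,
    }

budget = fp32_adam_memory_gb(7e9)
print(budget["total"])  # 112.0 GB for a 7B model
```

Swapping in 13e9 or 70e9 reproduces the scaling argument for larger models.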
Memory Breakdown for 7B Model Training
| Component | FP32 | BF16 mixed precision |
| --- | --- | --- |
| Weights | 28 GB | 14 GB (BF16) + 28 GB (FP32 master) |
| Gradients | 28 GB | 14 GB (BF16) |
| Adam optimizer states | 56 GB | 56 GB (FP32 always) |
| Activations (seq 2K) | 20–40 GB | 10–20 GB |
| Total (rough) | 132–172 GB | 122–132 GB |
⚠️
Adam optimizer states are always FP32 regardless of mixed precision. They consume 2× the model size by themselves. This is why optimizer state sharding (ZeRO) is critical for scaling.
02 — Optimization
Mixed Precision Training
BF16 vs FP16: BF16 has the same dynamic range as FP32 (8 exponent bits), while FP16 has only 5 exponent bits (max finite value ≈ 65,504), so overflow/underflow is far more likely and loss scaling becomes necessary. Use BF16 on A100/H100; use FP16 with a GradScaler on older V100s. FP8 (H100) halves memory again vs BF16 but requires per-tensor scaling.
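The range gap follows directly from the exponent widths. A quick sketch computing the largest finite value of each format from its bit layout (5 exponent / 10 mantissa bits for FP16, 8 / 7 for BF16, 8 / 23 for FP32):

```python
def max_finite(exp_bits: int, mantissa_bits: int) -> float:
    """Largest finite value of an IEEE-754-style float format."""
    max_exp = 2 ** (exp_bits - 1) - 1  # largest unbiased exponent
    return (2 - 2 ** -mantissa_bits) * 2.0 ** max_exp

print(max_finite(5, 10))   # FP16 -> 65504.0
print(max_finite(8, 7))    # BF16 -> ~3.39e38
print(max_finite(8, 23))   # FP32 -> ~3.40e38 (same range as BF16)
```

This is why a 70,000-scale activation overflows to infinity in FP16 but is unremarkable in BF16.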
The mixed precision recipe: (1) forward pass in BF16, (2) compute the loss in FP32, (3) backward pass in BF16, (4) update the FP32 master weights, (5) cast the updated weights back to BF16 for the next forward pass.
PyTorch torch.amp Setup
import torch

model = MyModel().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
# scaler = torch.amp.GradScaler("cuda")  # loss scaling: needed for FP16 only, not BF16

for batch in dataloader:
    optimizer.zero_grad()

    # Forward pass under BF16 autocast (parameters stay FP32)
    with torch.amp.autocast("cuda", dtype=torch.bfloat16):
        outputs = model(batch["input_ids"])
        loss = outputs.loss

    # Backward pass: gradients land in FP32, matching the parameters
    loss.backward()

    # Gradient clipping (essential for stability)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
✓
Always use gradient clipping (max_norm=1.0) with mixed precision. Without it, a single large gradient can corrupt training.
03 — Memory Trade-off
Gradient Checkpointing
A transformer's activations (intermediate outputs saved for backward pass) consume 10–40 GB for long sequences. Gradient checkpointing trades 30–40% extra compute for 5–10× less activation memory by recomputing activations during backward instead of saving them.
Selective checkpointing often works well: checkpoint only the expensive layers (attention, MLP), which dominate activation memory.
Checkpointing Example
from torch.utils.checkpoint import checkpoint

# Without checkpointing: saves all activations (~40 GB for 7B)
outputs = model(input_ids)

# With gradient checkpointing (HuggingFace):
model.gradient_checkpointing_enable()
outputs = model(input_ids)  # activations recomputed during backward

# Manual per-layer checkpointing:
def forward(self, x):
    x = checkpoint(self.attention_layer, x, use_reentrant=False)  # recompute in backward
    x = checkpoint(self.mlp_layer, x, use_reentrant=False)
    return x
Checkpointing Trade-offs
| Strategy | Memory saved | Compute overhead | Use case |
| --- | --- | --- | --- |
| None | 0% | 0% | Fits in VRAM, inference |
| Full (all layers) | ~60% | +30–40% | Default for memory-constrained training |
| Selective (attn only) | ~40% | +15–20% | When compute is the bottleneck |
| Activation offload | ~80% | +50% | Extreme memory pressure |
04 — Distribution
Data Parallelism: DDP and FSDP
Data parallelism: Each GPU holds a full copy of the model, processes a different mini-batch, gradients are averaged (all-reduce) across GPUs. DDP (Distributed Data Parallel) is PyTorch's standard. Each GPU: full model copy + full gradient copy + full optimizer state. Limited by single-GPU memory.
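Data parallelism is correct because of a simple identity: for a mean loss over equal-size shards, averaging the per-shard gradients equals the full-batch gradient. A numpy sketch simulating two "GPUs" on a least-squares loss (the model and data here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
X, y = rng.normal(size=(8, 3)), rng.normal(size=8)
w = rng.normal(size=3)

def grad(Xs, ys, w):
    """Gradient of the mean squared error 0.5*mean((Xw - y)^2) w.r.t. w."""
    return Xs.T @ (Xs @ w - ys) / len(ys)

# Single-GPU: gradient over the full batch
full = grad(X, y, w)

# "Two GPUs": each computes grads on its half, then all-reduce (average)
g0 = grad(X[:4], y[:4], w)
g1 = grad(X[4:], y[4:], w)
averaged = (g0 + g1) / 2

assert np.allclose(full, averaged)  # all-reduce reproduces the full-batch gradient
```

DDP performs exactly this average with an NCCL all-reduce instead of an in-process sum.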
FSDP (Fully Sharded Data Parallel) shards model weights + gradients + optimizer states across all GPUs. Each GPU holds 1/N of everything. Enables training models larger than single-GPU memory.
FSDP Training Setup
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy

# assumes torch.distributed.init_process_group("nccl") has already run
model = MyTransformer()

# Shard weights, gradients, AND optimizer states across all GPUs
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # ZeRO-3 equivalent
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    ),
)
DDP vs FSDP vs DeepSpeed ZeRO
| Method | Memory savings | Communication overhead | Ease of use |
| --- | --- | --- | --- |
| DDP | None | Low | Easy |
| FSDP (SHARD_GRAD_OP) | Gradients + opt states | Medium | Medium |
| FSDP (FULL_SHARD) | Weights + grads + opt | Higher | Medium |
| DeepSpeed ZeRO-3 | Same as FULL_SHARD | Higher | Complex config |
05 — State Sharding
DeepSpeed ZeRO Optimizer
ZeRO (Zero Redundancy Optimizer) shards training state in three stages. Stage 1 shards optimizer states (~4× memory reduction). Stage 2 also shards gradients (~8×). Stage 3 additionally shards the parameters themselves; the reduction scales linearly with GPU count, enabling arbitrarily large models. ZeRO-Infinity goes further by offloading optimizer states and/or parameters to CPU RAM or NVMe SSD for extreme scale.
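The stage-by-stage savings can be computed from the memory model in the ZeRO paper: 2 bytes/param for FP16 weights, 2 bytes/param for FP16 gradients, and K = 12 bytes/param of FP32 optimizer state (master weights plus Adam m and v). A sketch under those assumptions:

```python
def zero_per_gpu_gb(n_params: float, n_gpus: int, stage: int, k: int = 12) -> float:
    """Per-GPU memory (GB) under ZeRO stages 1-3.

    Memory model from the ZeRO paper: 2 B/param FP16 weights,
    2 B/param FP16 grads, k B/param FP32 optimizer state.
    """
    p, g, o = 2 * n_params, 2 * n_params, k * n_params
    if stage == 1:                      # shard optimizer states only
        total = p + g + o / n_gpus
    elif stage == 2:                    # shard optimizer states + gradients
        total = p + (g + o) / n_gpus
    else:                               # stage 3: shard everything
        total = (p + g + o) / n_gpus
    return total / 1e9

# 7.5B params on 64 GPUs, the configuration tabulated in the ZeRO paper
for s in (1, 2, 3):
    print(f"Stage {s}: {zero_per_gpu_gb(7.5e9, 64, s):.1f} GB")
```

Stage 3 drops per-GPU state below 2 GB in this configuration, which is why it "enables any model size" given enough GPUs.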
DeepSpeed Config for ZeRO-3
# ds_config.json (note: inline comments are not valid JSON)
{
  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "reduce_bucket_size": 5e8,
    "offload_optimizer": {"device": "cpu"},
    "offload_param": {"device": "cpu"}
  },
  "bf16": {"enabled": true},
  "gradient_clipping": 1.0,
  "train_micro_batch_size_per_gpu": 4,
  "gradient_accumulation_steps": 8
}
⚠️
ZeRO-3 + CPU offload enables training 70B models on 8×A100s — but CPU-GPU communication adds 20–40% overhead. Benchmark before committing.
06 — Advanced
Tensor and Pipeline Parallelism
Tensor parallelism (TP): Split individual weight matrices across GPUs (e.g., split attention heads across GPUs). Each layer requires an all-reduce, so TP needs a high-bandwidth, low-latency interconnect (NVLink) and is usually confined within a single node. Megatron-LM pioneered this.
Pipeline parallelism (PP): Split model layers across GPUs (GPU 0 runs layers 0–8, GPU 1 runs 9–16, etc.). Bubbles in the pipeline reduce efficiency. Requires careful microbatch scheduling.
Sequence parallelism: Split the sequence dimension across GPUs for ultra-long context training.
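The core trick of tensor parallelism can be shown without GPUs: split a weight matrix column-wise across workers, let each compute its slice of the output, and reassemble (in a real system a collective such as all-gather replaces the concatenation). A numpy sketch with illustrative shapes:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))      # batch of activations
W = rng.normal(size=(8, 16))     # full weight matrix

# Column-parallel split across 2 "GPUs": each holds half the output features
W0, W1 = W[:, :8], W[:, 8:]
y0 = x @ W0                      # computed on "GPU 0"
y1 = x @ W1                      # computed on "GPU 1"

# All-gather along the feature dim reassembles the full output
y = np.concatenate([y0, y1], axis=1)
assert np.allclose(y, x @ W)     # identical to the unsplit matmul
```

Megatron-style TP alternates column-parallel and row-parallel splits so that only one collective per pair of layers is needed.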
Parallelism Strategies
| Strategy | What's split | Comm pattern | Used for |
| --- | --- | --- | --- |
| Data (DDP/FSDP) | Input batches | All-reduce grads | Default for most |
| Tensor | Weight matrices | All-reduce per layer | Within-node (NVLink), large models |
| Pipeline | Model layers | Point-to-point | Very large models |
| Sequence | Sequence length | All-to-all | Long context (>32K) |
| Mixture (3D) | All three | Combined | Frontier training runs |
07 — Stability
Learning Rate Scheduling and Stability
Warmup: Ramp LR from 0 to target over first 1–5% of steps. Prevents early instability when weights are random. Cosine annealing: Decay LR following a cosine curve from target → 10% of target over training. Smooth, widely used.
WSD (Warmup-Stable-Decay): Constant LR for most of training, then a sharp decay at the end. Popularized by recent open-model recipes (e.g. MiniCPM); the flat plateau lets mid-training checkpoints be reused or branched without re-warming.
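A WSD multiplier is easy to express as a step-to-scale function usable with `torch.optim.lr_scheduler.LambdaLR`. A sketch where the 3% warmup, 10% decay tail, and 10%-of-peak floor are illustrative choices, not values from the text:

```python
import math

def wsd_multiplier(step: int, total: int, warmup_frac: float = 0.03,
                   decay_frac: float = 0.10, final_frac: float = 0.10) -> float:
    """Warmup-Stable-Decay LR multiplier: linear warmup, flat plateau,
    cosine decay to final_frac of peak over the last decay_frac of steps."""
    warmup_end = int(warmup_frac * total)
    decay_start = int((1 - decay_frac) * total)
    if step < warmup_end:
        return step / max(1, warmup_end)                     # linear warmup 0 -> 1
    if step < decay_start:
        return 1.0                                           # stable plateau
    t = (step - decay_start) / max(1, total - decay_start)   # 0 -> 1 over the tail
    return final_frac + (1 - final_frac) * 0.5 * (1 + math.cos(math.pi * t))

# e.g. scheduler = LambdaLR(optimizer, lambda s: wsd_multiplier(s, total_steps))
```

Because the multiplier is 1.0 for ~87% of training, a checkpoint taken anywhere on the plateau can restart decay from scratch.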
Loss spikes: Sudden loss increases during training are caused by bad batches, large gradients, or numerical instability. Fix: gradient clipping, skip-bad-batch logic, lower LR.
Cosine Schedule with Warmup
from transformers import get_cosine_schedule_with_warmup

total_steps = len(dataloader) * num_epochs
warmup_steps = int(0.03 * total_steps)  # 3% warmup

scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
    num_cycles=0.5,  # half a cosine wave: decay from peak to 0
)
Scheduling Strategies
🔥 Warmup Phase
- Ramp from 0 → target over 1–5%
- Prevents gradient explosion early
- Standard 3% warmup for large models
〰️ Decay Phase
- Cosine from target → 10% over rest
- Smooth convergence, no abrupt drops
- Lower LR late in training refines convergence
🛡️ Stability
- Monitor loss for spikes
- Gradient clipping always on
- Skip bad batches if detected
⚙️ Tuning
- Start warmup ~2–3% of total steps
- Cosine decay rest of training
- Monitor learning curve for anomalies
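The "skip bad batches" logic above fits in a few lines: track an exponential moving average of the loss and skip the optimizer step when the current loss exceeds a multiple of it. A hypothetical sketch; the class name, threshold, and decay are illustrative, not from the text:

```python
class SpikeGuard:
    """Skip updates whose loss exceeds `threshold` x the loss EMA."""

    def __init__(self, threshold: float = 2.0, decay: float = 0.99):
        self.threshold, self.decay = threshold, decay
        self.ema = None

    def should_skip(self, loss: float) -> bool:
        if self.ema is None:           # first step: initialize, never skip
            self.ema = loss
            return False
        spike = loss > self.threshold * self.ema
        if not spike:                  # only fold non-spike losses into the EMA
            self.ema = self.decay * self.ema + (1 - self.decay) * loss
        return spike

guard = SpikeGuard()
losses = [2.0, 1.9, 1.8, 9.5, 1.7]     # 9.5 is a spike
skipped = [guard.should_skip(l) for l in losses]
print(skipped)  # [False, False, False, True, False]
```

In a training loop, call `should_skip(loss.item())` after `loss.backward()` and `continue` before `optimizer.step()` when it returns True.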
08 — Further Reading
References
Academic Papers
- Rajbhandari, S. et al. (2019). ZeRO: Memory Optimizations Toward Training Trillion Parameter Models. Microsoft. arXiv:1910.02054.
- Zhao, Y. et al. (2023). PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel. arXiv:2304.11277.
- Shoeybi, M. et al. (2019). Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism. NVIDIA. arXiv:1909.08053.
- Micikevicius, P. et al. (2017). Mixed Precision Training. NVIDIA. arXiv:1710.03740.
Blog & Articles
- Microsoft. (2021). Training Trillion-Parameter Models Using DeepSpeed and Megatron. (microsoft.com)
- PyTorch. (2023). Introducing FSDP for Distributed Training. (pytorch.org/blog)
- HuggingFace. (2023). Distributed Training with HuggingFace and DeepSpeed. (huggingface.co/blog)