TRAINING INFRASTRUCTURE

LLM Training Techniques

Mixed precision, gradient checkpointing, distributed training, and the engineering that makes large-scale training tractable

BF16 + FP32 master weights: the standard setup
ZeRO Stage 1/2/3: DeepSpeed's sharding progression
10–40× memory savings: checkpointing / recompute trade-off
Contents
  1. The memory problem
  2. Mixed precision
  3. Gradient checkpointing
  4. Data parallelism
  5. ZeRO optimizer
  6. Tensor & pipeline
  7. Learning rate
01 — Challenge

The Memory Problem

Training a 7B model in FP32 requires 7B × 4 bytes = 28 GB for the weights alone. Add gradients (28 GB) and Adam's two FP32 optimizer states (56 GB) for a minimum of 112 GB — exceeding a single A100's 80 GB of VRAM. Solutions include mixed precision (halve weight memory), gradient checkpointing (trade compute for activation memory), and distributed training (spread state across GPUs).

Memory Breakdown for 7B Model Training

Component               FP32         BF16 mixed precision
Weights                 28 GB        14 GB (BF16) + 28 GB (FP32 master)
Gradients               28 GB        14 GB (BF16)
Adam optimizer states   56 GB        56 GB (FP32 always)
Activations (seq 2K)    20–40 GB     10–20 GB
Total (rough)           132–152 GB   122–132 GB
⚠️ Adam optimizer states are always FP32 regardless of mixed precision. They consume 2× the model size by themselves. This is why optimizer state sharding (ZeRO) is critical for scaling.
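The arithmetic behind these numbers can be sketched directly. This is a rough estimate of model-state memory only (activations depend on batch size, sequence length, and architecture); the helper name and structure are illustrative:

```python
def training_memory_gb(n_params_b, bf16=True, master_fp32=True):
    """Rough per-component model-state memory for Adam training, in GB.

    n_params_b: parameter count in billions (1 B params at 1 byte = 1 GB).
    """
    n = n_params_b
    weight_bytes = 2 if bf16 else 4
    # BF16 training keeps an extra FP32 master copy of the weights
    weights = n * weight_bytes + (n * 4 if (bf16 and master_fp32) else 0)
    grads = n * weight_bytes
    adam_states = n * 8  # two FP32 states per param: momentum + variance
    return {"weights": weights, "grads": grads, "adam": adam_states}

fp32 = training_memory_gb(7, bf16=False)
print(fp32, "total:", sum(fp32.values()))  # 28 + 28 + 56 = 112 GB
```

Note that switching to BF16 alone does not shrink the total much (the FP32 master weights and Adam states dominate), which is exactly why optimizer-state sharding matters.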
02 — Optimization

Mixed Precision Training

BF16 vs FP16: BF16 has the same dynamic range as FP32 (8 exponent bits) while FP16 has a smaller range → overflow/underflow more likely. Use BF16 on A100/H100; FP16 on older V100s. FP8 (H100) halves memory again vs BF16 but requires per-tensor scaling.

The mixed precision recipe: (1) forward pass in BF16, (2) compute loss in FP32, (3) backward pass in BF16, (4) update the FP32 master weights, (5) copy the updated master weights back to BF16 for the next forward pass.

PyTorch torch.amp Setup

import torch

model = MyModel().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
# Note: a loss scaler (torch.amp.GradScaler) is needed for FP16, not for BF16

for batch in dataloader:
    optimizer.zero_grad()

    # Forward pass in BF16
    with torch.autocast("cuda", dtype=torch.bfloat16):
        outputs = model(batch["input_ids"])
        loss = outputs.loss

    # Backward pass; gradients accumulate into the FP32 parameters
    loss.backward()

    # Gradient clipping (essential for stability)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
Always use gradient clipping (max_norm=1.0) with mixed precision. Without it, a single large gradient can corrupt training.
03 — Memory Trade-off

Gradient Checkpointing

A transformer's activations (intermediate outputs saved for backward pass) consume 10–40 GB for long sequences. Gradient checkpointing trades 30–40% extra compute for 5–10× less activation memory by recomputing activations during backward instead of saving them.

Selective checkpointing works well: only checkpoint expensive layers (attention, MLP). These are the main culprits for memory consumption.

Checkpointing Example

from torch.utils.checkpoint import checkpoint

# Without checkpointing: saves all activations (~40 GB for 7B)
outputs = model(input_ids)

# With gradient checkpointing (HuggingFace):
model.gradient_checkpointing_enable()
outputs = model(input_ids)  # activations recomputed during backward

# Manual per-layer checkpointing:
def forward(self, x):
    x = checkpoint(self.attention_layer, x, use_reentrant=False)  # recomputed in backward
    x = checkpoint(self.mlp_layer, x, use_reentrant=False)
    return x

Checkpointing Trade-offs

Strategy                Memory saved   Compute overhead   Use case
None                    0%             0%                 Fits in VRAM; inference
Full (all layers)       ~60%           +30–40%            Standard for large-model training
Selective (attn only)   ~40%           +15–20%            When compute is the bottleneck
Activation offload      ~80%           +50%               Extreme memory pressure
04 — Distribution

Data Parallelism: DDP and FSDP

Data parallelism: each GPU holds a full copy of the model and processes a different mini-batch; gradients are averaged across GPUs with an all-reduce. DDP (Distributed Data Parallel) is PyTorch's standard implementation. Each GPU stores a full model copy, full gradients, and full optimizer state, so model size is limited by single-GPU memory.

FSDP (Fully Sharded Data Parallel) shards model weights + gradients + optimizer states across all GPUs. Each GPU holds 1/N of everything. Enables training models larger than single-GPU memory.

FSDP Training Setup

import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)

model = MyTransformer()

# Shard weights, gradients, AND optimizer states across all GPUs
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # ZeRO-3 equivalent
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    ),
)

DDP vs FSDP vs DeepSpeed ZeRO

Method                 Memory savings           Communication overhead   Ease of use
DDP                    None                     Low                      Easy
FSDP (SHARD_GRAD_OP)   Gradients + opt states   Medium                   Medium
FSDP (FULL_SHARD)      Weights + grads + opt    Higher                   Medium
DeepSpeed ZeRO-3       Same as FULL_SHARD       Higher                   Complex config
05 — State Sharding

DeepSpeed ZeRO Optimizer

ZeRO (Zero Redundancy Optimizer): Three stages of sharding: Stage 1 shards optimizer states (4× reduction), Stage 2 shards optimizer states + gradients (8× reduction), Stage 3 shards optimizer states + gradients + parameters (full reduction — enables any model size). ZeRO-Infinity offloads optimizer states and/or parameters to CPU RAM or NVMe SSD for extreme scale.
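The per-GPU cost of each stage can be sketched with the bookkeeping from the ZeRO paper: 2 bytes/param for BF16 weights, 2 for BF16 gradients, and 12 for optimizer states (FP32 master weights + Adam momentum + variance). The function below is an illustrative calculator, not DeepSpeed API:

```python
def zero_per_gpu_gb(n_params_b, n_gpus, stage):
    """Per-GPU model-state memory (GB) for mixed-precision Adam under ZeRO."""
    p, n = n_params_b, n_gpus
    weights, grads, opt = 2 * p, 2 * p, 12 * p  # bytes/param -> GB for p billion params
    if stage == 0:  # plain data parallelism: everything replicated
        return weights + grads + opt
    if stage == 1:  # shard optimizer states
        return weights + grads + opt / n
    if stage == 2:  # shard optimizer states + gradients
        return weights + grads / n + opt / n
    if stage == 3:  # shard optimizer states + gradients + parameters
        return (weights + grads + opt) / n
    raise ValueError(f"unknown stage {stage}")

for s in range(4):
    print(f"stage {s}: {zero_per_gpu_gb(7, 8, s):.2f} GB/GPU")
```

For a 7B model on 8 GPUs this gives 112 GB/GPU unsharded, dropping to 38.5 (Stage 1), 26.25 (Stage 2), and 14 GB/GPU (Stage 3).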

DeepSpeed Config for ZeRO-3

# ds_config.json -- DeepSpeed expects strict JSON, so keep comments out of the
# file itself. The two "offload" entries below enable ZeRO-Infinity (CPU offload).
{
  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "reduce_bucket_size": 5e8,
    "offload_optimizer": {"device": "cpu"},
    "offload_param": {"device": "cpu"}
  },
  "bf16": {"enabled": true},
  "gradient_clipping": 1.0,
  "train_micro_batch_size_per_gpu": 4,
  "gradient_accumulation_steps": 8
}
⚠️ ZeRO-3 + CPU offload enables training 70B models on 8×A100s — but CPU-GPU communication adds 20–40% overhead. Benchmark before committing.
06 — Advanced

Tensor and Pipeline Parallelism

Tensor parallelism (TP): Split individual weight matrices across GPUs (e.g., distribute attention heads across devices). Every layer performs an all-reduce, so TP demands a fast interconnect (NVLink) and is typically kept within a node. Megatron-LM pioneered this.

Pipeline parallelism (PP): Split model layers across GPUs (GPU 0 runs layers 0–8, GPU 1 runs 9–16, etc.). Bubbles in the pipeline reduce efficiency. Requires careful microbatch scheduling.

Sequence parallelism: Split the sequence dimension across GPUs for ultra-long context training.
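The core idea of tensor parallelism can be illustrated with NumPy: shard a linear layer's weight matrix column-wise across devices, let each compute its partial output independently, then concatenate the partials (an all-gather in practice; the row-parallel counterpart instead sums partials via all-reduce). The shapes and 4-way split are illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 512))     # activations: (batch, d_model)
W = rng.standard_normal((512, 2048))  # weight matrix to shard

n_gpus = 4
shards = np.split(W, n_gpus, axis=1)  # each "GPU" holds a 512 x 512 column block

# Each device computes its partial output with no communication
partials = [x @ w for w in shards]

# Concatenating the partials (all-gather) reproduces the full matmul
y_parallel = np.concatenate(partials, axis=1)
assert np.allclose(y_parallel, x @ W)
print(y_parallel.shape)  # (4, 2048)
```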

Parallelism Strategies

Strategy          What's split      Comm pattern          Used for
Data (DDP/FSDP)   Input batches     All-reduce grads      Default for most
Tensor            Weight matrices   All-reduce per layer  Large models (fast interconnect)
Pipeline          Model layers      Point-to-point        Very large models
Sequence          Sequence length   All-to-all            Long context (>32K)
3D (mixture)      All three         Combined              Frontier training runs
07 — Stability

Learning Rate Scheduling and Stability

Warmup: Ramp LR from 0 to target over first 1–5% of steps. Prevents early instability when weights are random. Cosine annealing: Decay LR following a cosine curve from target → 10% of target over training. Smooth, widely used.

WSD (Warmup-Stable-Decay): Constant LR for most of training, then sharp cosine decay at end. Used by Mistral, enables mid-training checkpoint reuse.
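A minimal WSD schedule can be sketched as a plain function of the step. The warmup and decay fractions here are illustrative assumptions, not the exact recipe of any particular model:

```python
import math

def wsd_lr(step, total_steps, peak_lr, warmup_frac=0.03, decay_frac=0.1):
    """Warmup-Stable-Decay: linear warmup, long constant plateau, sharp cosine decay."""
    warmup_steps = int(warmup_frac * total_steps)
    decay_steps = int(decay_frac * total_steps)
    decay_start = total_steps - decay_steps
    if step < warmup_steps:        # linear warmup from 0 to peak
        return peak_lr * step / max(1, warmup_steps)
    if step < decay_start:         # stable plateau for most of training
        return peak_lr
    # sharp cosine decay to 0 over the final fraction of training
    progress = (step - decay_start) / max(1, decay_steps)
    return peak_lr * 0.5 * (1 + math.cos(math.pi * progress))

total = 10_000
print(wsd_lr(0, total, 2e-4), wsd_lr(5_000, total, 2e-4), wsd_lr(10_000, total, 2e-4))
```

The long plateau is what enables checkpoint reuse: any mid-plateau checkpoint can be branched off and given its own short decay phase.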

Loss spikes: Sudden loss increases during training are caused by bad batches, large gradients, or numerical instability. Fix: gradient clipping, skip-bad-batch logic, lower LR.
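One simple form of skip-bad-batch logic tracks an exponential moving average of the loss and skips the update when the current loss spikes past a threshold. The class name, 2.0× multiplier, and EMA decay are illustrative choices:

```python
class SpikeGuard:
    """Skip updates whose loss exceeds `threshold` x the EMA of recent losses."""

    def __init__(self, threshold=2.0, decay=0.99):
        self.threshold, self.decay, self.ema = threshold, decay, None

    def should_skip(self, loss):
        if self.ema is None:  # first batch: accept and initialize the EMA
            self.ema = loss
            return False
        if loss > self.threshold * self.ema:
            return True       # spike detected: skip, and don't poison the EMA
        self.ema = self.decay * self.ema + (1 - self.decay) * loss
        return False

guard = SpikeGuard()
losses = [2.1, 2.0, 1.9, 9.5, 1.8]  # 9.5 is a spiking batch
print([guard.should_skip(l) for l in losses])  # [False, False, False, True, False]
```

In a training loop you would call `optimizer.zero_grad()` and `continue` when `should_skip` returns True, instead of stepping the optimizer.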

Cosine Schedule with Warmup

from transformers import get_cosine_schedule_with_warmup

total_steps = len(dataloader) * num_epochs
warmup_steps = int(0.03 * total_steps)  # 3% warmup

scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
    num_cycles=0.5,  # single half-cosine: decays from peak to 0
)

Scheduling Strategies

🔥 Warmup Phase

  • Ramp from 0 → target over 1–5%
  • Prevents gradient explosion early
  • Standard 3% warmup for large models

〰️ Decay Phase

  • Cosine from target → 10% over rest
  • Smooth convergence, no abrupt drops
  • Proportional gains decrease over time

🛡️ Stability

  • Monitor loss for spikes
  • Gradient clipping always on
  • Skip bad batches if detected

⚙️ Tuning

  • Start warmup ~2–3% of total steps
  • Cosine decay rest of training
  • Monitor learning curve for anomalies

Tools and Frameworks

PyTorch FSDP (Framework): Native fully sharded data parallel.
DeepSpeed (Library): Microsoft. ZeRO optimizer and beyond.
Megatron-LM (Framework): NVIDIA. Tensor and pipeline parallelism.
Axolotl (Training): End-to-end fine-tuning framework.
LLM Foundry (Training): Composer/MosaicML stack.
W&B (Monitoring): Weights & Biases. Experiment tracking.
TensorBoard (Monitoring): TensorFlow. Real-time visualization.
TorchTitan (Framework): PyTorch's native distributed training (formerly torchtrain).
08 — Further Reading

References
