SGD, Adam, AdamW, and beyond — how gradient-based optimisers shape training dynamics and convergence
All modern optimisers are variants of gradient descent: θ ← θ - α∇L(θ), where α is the learning rate. The differences are in how you accumulate gradients and adapt the learning rate.
Batch GD: Compute loss over the entire dataset. Exact, low-variance gradient (averaged over all N examples). But each update is expensive in memory and compute, and the feedback loop is slow.
Mini-batch GD: Compute loss over k examples. The standard in deep learning. Good tradeoff between variance and speed.
Stochastic GD (SGD): k = 1. Noisier gradients. But noise acts as regularisation — helps escape local minima. Cheaper per step.
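The three variants differ only in how many examples feed each gradient estimate. A minimal NumPy sketch of mini-batch GD on a least-squares problem (all names illustrative, not from any library):

```python
import numpy as np

# Mini-batch SGD on least squares, y = X @ w_true (noiseless, so it converges exactly).
rng = np.random.default_rng(0)
X = rng.normal(size=(256, 4))
w_true = np.array([1.0, -2.0, 0.5, 3.0])
y = X @ w_true

w = np.zeros(4)
lr, batch_size = 0.1, 32
for step in range(500):
    idx = rng.integers(0, len(X), size=batch_size)   # sample a mini-batch of k examples
    Xb, yb = X[idx], y[idx]
    grad = 2 * Xb.T @ (Xb @ w - yb) / batch_size     # MSE gradient on the batch
    w -= lr * grad                                   # θ ← θ − α∇L(θ)
```

Setting `batch_size = len(X)` recovers batch GD; `batch_size = 1` recovers pure SGD.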
| Optimiser | Adaptive LR | Memory | Convergence | Used in |
|---|---|---|---|---|
| SGD + momentum | No | O(θ) | Slow, stable | Vision CNNs, baselines |
| RMSProp | Yes | O(θ) | Medium | RNNs, Transformer warmup |
| Adam | Yes | 2×O(θ) | Fast, sometimes overfits | BERT, GPT, default LLM choice |
| AdamW | Yes | 2×O(θ) | Fast, better generalization | Most LLMs, state-of-the-art |
| Adafactor | Yes | O(θ) | Fast, memory-efficient | T5, large models on limited GPU |
| Sophia | Yes | 2×O(θ) | Fast (reported ~2× speedup) | Research, scaling studies |
Adaptive LR: Optimiser adjusts learning rate per parameter based on gradient history. Good for heterogeneous loss landscapes (different parameters converge at different rates).
Memory cost: Extra buffers (momentum, adaptive rates) require more VRAM. AdamW stores two extra buffers the size of the weights (2×). Adafactor stores only factored row/column statistics for the second moment — sub-linear in the weight size.
Adam (Adaptive Moment Estimation) maintains running averages of the first (m, momentum) and second (v, squared gradient) moments of the gradient. It divides the update by √v, so parameters with high gradient variance take smaller steps.
Update rule:
m ← β₁·m + (1-β₁)·g
v ← β₂·v + (1-β₂)·g²
m̂ ← m / (1-β₁ᵗ) [bias correction]
v̂ ← v / (1-β₂ᵗ)
θ ← θ - α·m̂ / (√v̂ + ε)
Default: β₁=0.9 (momentum), β₂=0.999 (second-moment decay, as in RMSProp), ε=1e-8 (numerical stability).
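The update rule above translates directly into code. A minimal sketch (illustrative, not a library implementation), run here on f(θ) = θ², whose gradient is g = 2θ:

```python
import numpy as np

# One Adam step, matching the update rule above (defaults β₁=0.9, β₂=0.999, ε=1e-8).
def adam_step(theta, g, m, v, t, alpha=0.1, b1=0.9, b2=0.999, eps=1e-8):
    m = b1 * m + (1 - b1) * g                   # first moment (momentum)
    v = b2 * v + (1 - b2) * g**2                # second moment (squared gradient)
    m_hat = m / (1 - b1**t)                     # bias correction (t starts at 1)
    v_hat = v / (1 - b2**t)
    theta = theta - alpha * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v

# Minimise f(θ) = θ² from θ = 5; the gradient passed in is g = 2θ.
theta, m, v = 5.0, 0.0, 0.0
for t in range(1, 501):
    theta, m, v = adam_step(theta, 2 * theta, m, v, t)
```

Note that the step size is roughly α·sign(g) early on (since m̂/√v̂ ≈ ±1 for a consistent gradient), which is why Adam's per-step movement is bounded by the learning rate rather than the raw gradient magnitude.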
The small constant ε in the denominator prevents division by zero when v is tiny. But the default ε=1e-8 can be too small: when v̂ is tiny, updates blow up and destabilise early training. Many practitioners use ε=1e-6 or tune it. For mixed-precision training, ε=1e-5 is common (1e-8 is below float16's smallest subnormal and underflows to zero).
Without bias correction, m and v are biased toward zero in early steps (they're initialized at 0). The correction terms (1-β₁ᵗ) and (1-β₂ᵗ) warm up from step t=1, fixing the bias. Critical for early training stability.
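The bias is easy to see numerically: after t steps with a constant gradient g, the uncorrected EMA equals (1 − β₁ᵗ)·g, and dividing by (1 − β₁ᵗ) recovers g exactly. A small pure-Python sketch:

```python
# Feed a constant gradient g into the EMA m ← β₁·m + (1−β₁)·g, starting from m = 0.
b1, g = 0.9, 2.0
m = 0.0
for t in range(1, 6):
    m = b1 * m + (1 - b1) * g
    m_hat = m / (1 - b1**t)   # bias-corrected estimate
# uncorrected m underestimates g in early steps; m_hat recovers g exactly
```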
L2 regularisation: Add λ·‖θ‖² to the loss. Gradient becomes ∇L + 2λ·θ. Because the penalty flows through the gradient, an adaptive optimiser rescales it along with everything else — parameters with a large gradient history are effectively penalised less.
Weight decay: Subtract λ·θ directly from parameters: θ ← (1-λ)·θ - α·∇L. This is decoupled — all parameters decay uniformly, regardless of gradient magnitude.
For vanilla SGD they're equivalent (up to a rescaling of λ). But for Adam, they diverge: Adam divides by √v, so the L2 penalty is scaled differently across parameters. AdamW decouples them: weight decay applies uniformly, the adaptive learning rate applies separately.
For LLM training: use AdamW with weight_decay=0.01–0.1. Too high (>0.1) hurts training; too low (<0.001) allows overfitting.
Fixed learning rate rarely works. LR should be high early (explore fast) and low late (refine). Scheduling adapts LR during training.
Linear warmup + cosine decay: Most popular for LLMs. The warmup phase (first few percent of training) ramps LR from ~0 to the target. Then cosine annealing — LR follows 0.5·(1 + cos(π·t/T)) — decays to 0 or a small value (min_lr).
Linear decay: LR decreases linearly. Simpler than cosine, works ok but not optimal.
OneCycleLR: One cycle: ramp up, then ramp down. Good for shorter training runs.
Constant with restarts: Keep LR fixed for N epochs, reset optimiser, repeat. "Cyclical" approach. Less common now.
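The warmup + cosine schedule fits in a few lines. A sketch assuming warmup over the first 5% of `total_steps` and decay from `peak_lr` down to `min_lr` (all parameter names and values illustrative):

```python
import math

def lr_at(step, total_steps, peak_lr=3e-4, min_lr=3e-5, warmup_frac=0.05):
    """Linear warmup then cosine decay; `step` is 0-indexed."""
    warmup_steps = max(1, int(warmup_frac * total_steps))
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps        # linear ramp from ~0 to peak
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    cosine = 0.5 * (1 + math.cos(math.pi * progress))     # 1 → 0 over the decay phase
    return min_lr + (peak_lr - min_lr) * cosine
```

Called once per optimiser step, this gives the LR to set before the update; most frameworks wrap the same formula in a scheduler object.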
Gradient statistics (mean, variance) change wildly in the first few steps, so Adam's adaptive rates are unstable. Warmup keeps the model moving slowly while the gradient estimates stabilise. Critical for Transformers. Empirically, warmup over roughly 1–5% of total training steps works well.
Transformers can have "exploding gradients" — backpropagation multiplies Jacobians through many layers, and a bad batch can spike the loss. Clipping caps the gradient norm to prevent divergence.
Max-norm clipping: If ‖∇L‖ > clip_value, rescale gradients: ∇ ← (clip_value / ‖∇L‖) · ∇. Keeps direction, reduces magnitude.
Typical clip_value: 1.0 for Transformers. Too aggressive (0.1) slows training. Too lenient (10) allows spikes.
Log gradient norms during training. If many steps exceed the clip value, either increase clip_value or reduce learning rate (gradient is too large even before clipping). Gradient norm divergence signals training instability.
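Max-norm clipping over all parameter gradients is a few lines; returning the pre-clip norm makes it easy to log, as recommended above. A NumPy sketch (function name illustrative; frameworks ship an equivalent, e.g. PyTorch's `clip_grad_norm_`):

```python
import numpy as np

def clip_grad_norm(grads, clip_value=1.0):
    """Rescale the list of gradient arrays so their global L2 norm <= clip_value."""
    total_norm = np.sqrt(sum(np.sum(g**2) for g in grads))   # global norm across tensors
    if total_norm > clip_value:
        scale = clip_value / total_norm                      # shrink magnitude, keep direction
        grads = [g * scale for g in grads]
    return grads, total_norm    # log total_norm to monitor training stability

grads = [np.array([3.0, 4.0])]          # norm 5 > clip_value 1.0, so it gets rescaled
clipped, norm = clip_grad_norm(grads)   # clipped direction unchanged, norm now 1.0
```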
Adafactor: Factorises the second moment. For an n×m weight matrix, instead of storing the full v, store per-row and per-column statistics: O(n + m) memory instead of O(n·m). Slightly slower convergence, but acceptable.
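A one-shot sketch of the factorisation idea (the real optimiser maintains exponential moving averages of these statistics; shapes here are illustrative):

```python
import numpy as np

# Factored second moment for a 2-D weight of shape (n, m): keep row sums R
# (n values) and column sums C (m values) of g², then reconstruct the
# rank-1 approximation v ≈ outer(R, C) / sum(R).
rng = np.random.default_rng(0)
g = rng.normal(size=(512, 256))
g2 = g**2                                # one step's squared gradient
R = g2.sum(axis=1)                       # per-row statistics, shape (512,)
C = g2.sum(axis=0)                       # per-column statistics, shape (256,)
v_approx = np.outer(R, C) / R.sum()      # rank-1 reconstruction of g²
# Memory: 512 + 256 floats instead of 512 * 256 — the source of the savings.
```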
8-bit Adam (bitsandbytes): Quantises momentum and second moment to 8 bits using block-wise quantisation. Reduces optimiser-state memory by ~4×. Negligible accuracy loss in practice.
GaLore (Gradient Low-Rank Projection): Projects gradients onto a low-rank subspace and keeps the optimiser state (both moments) in that subspace. Substantially reduces optimiser memory. Good for fine-tuning, and even pre-training, on limited VRAM.
Use Adafactor for on-device fine-tuning (< 24GB VRAM). Use 8-bit Adam for large-scale training (cuts optimiser-state VRAM by ~4×). Both trade a small compute overhead for large memory savings — worthwhile for expensive models.