Foundations · Training

Optimisation Algorithms

SGD, Adam, AdamW, and beyond — how gradient-based optimisers shape training dynamics and convergence

6 optimisers
9 sections
PyTorch-first code examples
Contents
  1. Gradient descent variants
  2. Optimiser comparison
  3. Adam deep dive
  4. Weight decay
  5. Learning rate scheduling
  6. Gradient clipping
  7. Memory-efficient optimisers
  8. Tools & libraries
  9. References
01 — Foundation

Gradient Descent Variants

All modern optimisers are variants of gradient descent: θ ← θ - α∇L(θ), where α is the learning rate. The differences are in how you accumulate gradients and adapt the learning rate.

Batch GD: Compute the loss over the entire dataset. The gradient is exact and stable, but each update is expensive in compute and memory, and the feedback loop is slow.

Mini-batch GD: Compute loss over k examples. The standard in deep learning. Good tradeoff between variance and speed.

Stochastic GD (SGD): k = 1. Noisier gradients. But noise acts as regularisation — helps escape local minima. Cheaper per step.
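The three variants differ only in how many examples feed each gradient step. A minimal sketch of mini-batch GD on a toy linear-regression problem (the data, `true_w`, and the hyperparameters are illustrative, not from the text):

```python
import torch

torch.manual_seed(0)
X = torch.randn(256, 3)
true_w = torch.tensor([1.0, -2.0, 0.5])
y = X @ true_w + 0.1 * torch.randn(256)        # noisy linear targets

w = torch.zeros(3, requires_grad=True)
lr, k = 0.1, 32                                # k = mini-batch size
for epoch in range(50):
    perm = torch.randperm(len(X))              # reshuffle each epoch
    for i in range(0, len(X), k):
        idx = perm[i:i + k]
        loss = ((X[idx] @ w - y[idx]) ** 2).mean()
        loss.backward()
        with torch.no_grad():
            w -= lr * w.grad                   # θ ← θ - α∇L(θ)
        w.grad = None

print(w.detach())                              # close to true_w
```

Setting `k = len(X)` gives batch GD; `k = 1` gives classic SGD with much noisier steps.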

💡 The noise insight: Gradient noise from small batches is beneficial for generalisation. Too much noise → instability. Too little → convergence to sharp minima that generalise worse. Optimisers balance this with momentum and adaptive learning rates.
02 — Comparison

Optimiser Comparison

Optimiser | Adaptive LR | Memory | Convergence | Used in
SGD + momentum | No | O(θ) | Slow, stable | Vision CNNs, baselines
RMSProp | Yes | O(θ) | Medium | RNNs, Transformer warmup
Adam | Yes | 2×O(θ) | Fast, sometimes overfits | BERT, GPT, default LLM choice
AdamW | Yes | 2×O(θ) | Fast, better generalisation | Most LLMs, state of the art
Adafactor | Yes | O(θ) | Fast, memory-efficient | T5, large models on limited GPU
Sophia | Yes | 3×O(θ) | Very fast (2–3× speedup) | Research, scaling studies

Adaptive LR: Optimiser adjusts learning rate per parameter based on gradient history. Good for heterogeneous loss landscapes (different parameters converge at different rates).

Memory cost: Extra buffers (momentum, adaptive rates) require more VRAM. AdamW stores two extra buffers the size of the model weights (2×). Adafactor stores far less (factorised second-moment statistics).
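The 2× figure for AdamW can be verified directly by inspecting the optimiser's state dict after one step (the layer size here is arbitrary):

```python
import torch
import torch.nn as nn

model = nn.Linear(512, 512)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# State buffers are allocated lazily, on the first step()
loss = model(torch.randn(8, 512)).sum()
loss.backward()
optimizer.step()

n_params = sum(p.numel() for p in model.parameters())
state_numel = sum(
    s["exp_avg"].numel() + s["exp_avg_sq"].numel()   # m and v buffers
    for s in optimizer.state.values()
)
print(f"parameters: {n_params}, optimiser state: {state_numel}")
```

The state holds exactly twice the parameter count: one `exp_avg` (momentum) and one `exp_avg_sq` (second moment) per parameter tensor.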

03 — Default

Adam: Deep Dive

Adam (Adaptive Moment Estimation) maintains running averages of the first (m, momentum) and second moments (v, squared gradient) of gradients. It divides the update by √v̂, so parameters with high-variance gradients get smaller updates.

Update rule:
m ← β₁·m + (1-β₁)·g
v ← β₂·v + (1-β₂)·g²
m̂ ← m / (1-β₁ᵗ) [bias correction]
v̂ ← v / (1-β₂ᵗ)
θ ← θ - α·m̂ / (√v̂ + ε)

Defaults: β₁=0.9 (momentum decay), β₂=0.999 (second-moment decay), ε=1e-8 (numerical stability).

Why ε Matters

The small constant ε in the denominator prevents division by zero when v̂ is tiny. But the default ε=1e-8 can hurt early training: when gradients are small, √v̂ is small too, and the update approaches the full learning rate regardless of gradient magnitude. Raising ε damps these steps — many practitioners use ε=1e-6 or tune it. For mixed-precision training, ε=1e-5 is common (1e-8 can underflow in fp16).
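For a constant gradient g, m̂ → g and √v̂ → |g| at steady state, so the Adam step size tends to α·g/(g + ε) and ε controls how strongly small gradients are damped. A quick arithmetic sketch (the gradient magnitudes are illustrative):

```python
# Steady-state Adam step for a constant gradient g: α·g / (g + ε)
lr = 1e-3
for g in (1e-3, 1e-6):
    for eps in (1e-8, 1e-5):
        step = lr * g / (g + eps)
        print(f"g={g:.0e}  eps={eps:.0e}  step size={step:.2e}")
```

With ε=1e-8, even a gradient of 1e-6 produces a near-full-size step (~α); with ε=1e-5 the same gradient is damped by roughly 10×.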

Bias Correction

Without bias correction, m and v are biased toward zero in early steps (both are initialised at 0). Dividing by (1-β₁ᵗ) and (1-β₂ᵗ) rescales the estimates — the correction is largest at t=1 and fades as t grows. Critical for early training stability.
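The update rule can be checked numerically against `torch.optim.Adam`: running the five equations by hand on one parameter tensor reproduces PyTorch's trajectory (the toy loss Σp² is an assumption for illustration; its gradient is 2p):

```python
import torch

torch.manual_seed(0)
p_ref = torch.randn(10, requires_grad=True)   # updated by torch.optim.Adam
p_man = p_ref.detach().clone()                # updated by the equations by hand

opt = torch.optim.Adam([p_ref], lr=1e-3, betas=(0.9, 0.999), eps=1e-8)

lr, b1, b2, eps = 1e-3, 0.9, 0.999, 1e-8
m = torch.zeros_like(p_man)
v = torch.zeros_like(p_man)

for t in range(1, 6):
    # Reference: PyTorch Adam on loss = Σ p²
    (p_ref ** 2).sum().backward()
    opt.step()
    opt.zero_grad()

    # Manual: the update equations written out
    g = 2 * p_man                         # ∇ Σ p² = 2p
    m = b1 * m + (1 - b1) * g             # first moment
    v = b2 * v + (1 - b2) * g ** 2        # second moment
    m_hat = m / (1 - b1 ** t)             # bias correction
    v_hat = v / (1 - b2 ** t)
    p_man = p_man - lr * m_hat / (v_hat.sqrt() + eps)

print((p_ref.detach() - p_man).abs().max())   # the two trajectories match
```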

Python: Adam in PyTorch

import torch.optim as optim

optimizer = optim.Adam(
    model.parameters(),
    lr=1e-3,               # learning rate
    betas=(0.9, 0.999),    # (β₁, β₂)
    eps=1e-8,              # numerical stability
    weight_decay=0,        # L2 penalty (don't use with Adam, see AdamW)
)

for epoch in range(num_epochs):
    loss = loss_fn(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()       # m, v updated, then θ updated
⚠️ Adam + weight_decay is broken: Adam's weight_decay applies L2 regularisation to the adaptive update, not true weight decay. Use AdamW (decoupled weight decay) instead. The difference matters: Adam with wd=0.01 is NOT equivalent to AdamW with wd=0.01.
04 — Regularisation

Weight Decay vs L2 Regularisation

L2 regularisation: Add λ·‖θ‖² to the loss. The gradient becomes ∇L + 2λ·θ. Under an adaptive optimiser this penalty term is rescaled along with the rest of the gradient, so parameters with large gradient history are effectively penalised less.

Weight decay: Shrink parameters directly each step: θ ← (1-αλ)·θ - α·∇L. This is decoupled — all parameters decay uniformly, regardless of gradient magnitude.

For vanilla SGD they coincide (up to a rescaling of λ). But for Adam they diverge: Adam divides by √v̂, so the L2 gradient term gets scaled differently per parameter. AdamW decouples them: weight decay applies uniformly, the adaptive learning rate applies only to the loss gradient.
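The divergence is easy to demonstrate: train two copies of the same parameters with the same `weight_decay` value, one under `Adam`, one under `AdamW`, and watch them drift apart (the toy loss and hyperparameters are illustrative):

```python
import torch

torch.manual_seed(0)
w0 = torch.randn(5)
p_adam = w0.clone().requires_grad_(True)    # Adam: L2-style weight_decay
p_adamw = w0.clone().requires_grad_(True)   # AdamW: decoupled decay

opt_adam = torch.optim.Adam([p_adam], lr=1e-2, weight_decay=0.01)
opt_adamw = torch.optim.AdamW([p_adamw], lr=1e-2, weight_decay=0.01)

for _ in range(20):
    for p, opt in ((p_adam, opt_adam), (p_adamw, opt_adamw)):
        (p ** 2).sum().backward()
        opt.step()
        opt.zero_grad()

# Same wd=0.01, same loss, same init — yet the parameters differ
print((p_adam - p_adamw).abs().max())
```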

Python: AdamW (Correct)

import torch.optim as optim

optimizer = optim.AdamW(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    weight_decay=0.01,   # decoupled weight decay
)

# Or with bitsandbytes (8-bit) for large models:
# from bitsandbytes.optim import AdamW8bit
# optimizer = AdamW8bit(model.parameters(), lr=1e-3)

For LLM training: use AdamW with weight_decay in the 0.01–0.1 range. Much higher hurts optimisation; much lower (<0.001) provides little regularisation and allows overfitting.

05 — Dynamics

Learning Rate Scheduling

Fixed learning rate rarely works. LR should be high early (explore fast) and low late (refine). Scheduling adapts LR during training.

Common Schedules

Linear warmup + cosine decay: Most popular for LLMs. A warmup phase (typically the first 1–5% of training) ramps LR from ~0 to the target. Then cosine annealing, 0.5·(1+cos(πt)) with t the fraction of post-warmup training completed, decays it to 0 or a small floor (min_lr).

Linear decay: LR decreases linearly. Simpler than cosine and works reasonably well, but usually slightly worse.

OneCycleLR: One cycle: ramp up, then ramp down. Good for shorter training runs.

Constant with restarts: Keep LR fixed for N epochs, then reset it to its initial value (warm restarts) and repeat. A "cyclical" approach. Less common now.

Why Warmup?

Gradient statistics (mean, variance) change wildly in the first few steps, so Adam's adaptive rates are unreliable early on. Warmup keeps the model moving slowly while the moment estimates stabilise. Critical for Transformers. Empirical rule of thumb: warm up over roughly 1–5% of total training steps.

Python: Warmup + Cosine Decay

import math
from torch.optim.lr_scheduler import LambdaLR

def cosine_schedule(step, total_steps, warmup_steps):
    if step < warmup_steps:
        # Linear warmup
        return float(step) / float(max(1, warmup_steps))
    # Cosine annealing
    progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
    return 0.5 * (1.0 + math.cos(math.pi * progress))

total_steps = 10000
warmup_steps = 500
scheduler = LambdaLR(
    optimizer,
    lambda step: cosine_schedule(step, total_steps, warmup_steps),
)

for epoch in range(num_epochs):
    loss = loss_fn(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()   # update LR

PyTorch Cosine Annealing LR

from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

scheduler = CosineAnnealingLR(
    optimizer,
    T_max=10000,    # total steps
    eta_min=1e-6,   # min LR at end
)

# Use with warmup: chain schedulers
warmup = LinearLR(optimizer, start_factor=0.01, total_iters=500)
cosine = CosineAnnealingLR(optimizer, T_max=9500)
scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[500])
06 — Stability

Gradient Clipping & Stability

Transformers can suffer from "exploding gradients": backpropagation multiplies Jacobians through many layers, and a bad batch can spike both the loss and the gradient. Clipping caps the gradient norm to prevent divergence.

Max-norm clipping: If ‖∇L‖ > clip_value, rescale gradients: ∇ ← (clip_value / ‖∇L‖) · ∇. Keeps direction, reduces magnitude.

Typical clip_value: 1.0 for Transformers. Too aggressive (0.1) slows training. Too lenient (10) allows spikes.
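The max-norm rule can be written out directly (the synthetic "gradient" tensor here is a stand-in for illustration):

```python
import torch

torch.manual_seed(0)
clip_value = 1.0
g = torch.randn(1000) * 5.0           # stand-in gradient with a large norm

norm = g.norm()
if norm > clip_value:
    g = g * (clip_value / norm)       # direction preserved, magnitude capped

print(g.norm())  # ≈ clip_value
```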

Python: Gradient Clipping

import torch

loss = loss_fn(model(x), y)
optimizer.zero_grad()
loss.backward()

# Clip gradient norm to 1.0
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# Or clip element-wise by value (less common):
# torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)

optimizer.step()

Monitoring Gradient Norms

Log gradient norms during training. If many steps exceed the clip value, either increase clip_value or reduce learning rate (gradient is too large even before clipping). Gradient norm divergence signals training instability.
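Conveniently, `clip_grad_norm_` returns the total norm computed before clipping, so the same call both clips and produces the value to log (the tiny model and random batch are illustrative):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(16, 4)
loss = model(torch.randn(8, 16)).pow(2).mean()
loss.backward()

# Return value = total gradient norm BEFORE clipping — log it to see
# how often training hits the clip threshold
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(f"pre-clip gradient norm: {total_norm.item():.4f}")
```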

💡 Gradient clipping is essential for LLMs. Deep Transformers can accumulate enormous gradients. Omitting it risks NaN losses and divergence. Standard practice: clip_norm=1.0.
07 — Scaling

Memory-Efficient Optimisers

Adafactor: Factorises the second-moment matrix. Instead of storing v in full, it stores row and column statistics. Memory: O(n+m) instead of O(nm) for an n×m weight matrix. Slightly slower convergence, but acceptable.
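The factorisation is a rank-1 reconstruction: from row sums R and column sums C of the full second moment V, Adafactor rebuilds V̂ = R·Cᵀ / ΣV. A small sketch of that reconstruction (matrix sizes are arbitrary; this illustrates the storage trick, not the full optimiser):

```python
import torch

torch.manual_seed(0)
G = torch.randn(4, 6)
V = G ** 2                            # full second moment: n·m numbers

# Adafactor stores only row sums R (n numbers) and column sums C (m numbers)
R = V.sum(dim=1)                      # shape (4,)
C = V.sum(dim=0)                      # shape (6,)
V_hat = torch.outer(R, C) / V.sum()   # rank-1 reconstruction, shape (4, 6)

print(V.shape, R.shape, C.shape, V_hat.shape)
```

For a 4096×4096 layer this shrinks the second-moment buffer from ~16.8M numbers to 8192.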

8-bit Adam (bitsandbytes): Quantises momentum and second moment to 8 bits using block-wise dynamic quantisation. Reduces optimiser-state memory by ~4×. Negligible accuracy loss on modern GPUs.

GaLore (Gradient Low-Rank Projection): Projects gradients onto a low-rank subspace and keeps the optimiser state in that subspace, shrinking the moment buffers. Good for fine-tuning on limited VRAM.

Python: Adafactor

from transformers.optimization import Adafactor

# External (manual) learning rate — pair with any PyTorch scheduler:
optimizer = Adafactor(
    model.parameters(),
    lr=1e-3,
    scale_parameter=False,
    relative_step=False,
    warmup_init=False,
)

# Or let Adafactor manage its own relative step sizes:
# from transformers.optimization import AdafactorSchedule
# optimizer = Adafactor(model.parameters(), scale_parameter=True,
#                       relative_step=True, warmup_init=True, lr=None)
# lr_scheduler = AdafactorSchedule(optimizer)

Python: 8-bit Adam

from bitsandbytes.optim import AdamW8bit

optimizer = AdamW8bit(
    model.parameters(),
    lr=1e-3,
    weight_decay=0.01,
)
# Drop-in replacement for torch.optim.AdamW; requires a CUDA GPU

Use Adafactor for on-device fine-tuning (< 24GB VRAM). Use 8-bit Adam for large-scale training (cuts optimiser-state VRAM by ~4×). Both trade a small compute overhead for large memory savings — worthwhile for expensive models.

08 — Ecosystem

Tools & Libraries

PyTorch optim (framework) — SGD, Adam, AdamW, RMSProp. Core optimisers, always available.
bitsandbytes (quantised) — 8-bit Adam, 8-bit SGD. Memory-efficient optimiser states; also NF4 quantisation.
Sophia (optimiser) — Hessian-aware optimiser. Reported 2–3× faster convergence.
Prodigy (optimiser) — Adaptive learning rate. Learning-rate-free training.
schedule_free (scheduling) — Schedule-Free optimisers. No manual LR schedule or decay needed.
GaLore (optimiser) — Gradient low-rank projection. Fine-tuning on limited VRAM.
Apex (framework) — NVIDIA optimisation utilities. Mixed precision, distributed training.
DeepSpeed (framework) — Distributed training. ZeRO optimiser-state sharding.
09 — Further Reading

References
