Adaptive Moment Estimation — the standard optimizer for training transformers and fine-tuning LLMs. Combines momentum with adaptive per-parameter learning rates.
Vanilla SGD uses the same learning rate for all parameters. Adam maintains per-parameter adaptive learning rates — parameters that rarely get large gradients get a relatively higher effective learning rate. This makes it much more robust to sparse gradients and different feature scales.
The original Adam applies L2 regularization by adding λw to the gradient. This couples weight decay with the adaptive learning rate — parameters with large gradients get less regularization. AdamW decouples them: apply weight decay directly to the parameter, not through the gradient.
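The difference fits in a few lines. Below is a minimal single-parameter sketch (the function name and defaults are illustrative, not a library API): the coupled variant folds λw into the gradient, so decay passes through the adaptive denominator √v̂; the decoupled variant shrinks the weight directly.

```python
import math

def adamw_step(theta, g, m, v, t, lr=1e-3, b1=0.9, b2=0.999,
               eps=1e-8, wd=0.1, decoupled=True):
    """One Adam/AdamW step on a scalar parameter (illustrative sketch)."""
    if not decoupled:
        # Original Adam + L2: the decay term rides along with the gradient
        # and gets rescaled by the adaptive denominator below.
        g = g + wd * theta
    m = b1 * m + (1 - b1) * g          # first moment (momentum)
    v = b2 * v + (1 - b2) * g * g      # second moment (variance)
    m_hat = m / (1 - b1 ** t)          # bias corrections
    v_hat = v / (1 - b2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    if decoupled:
        # AdamW: decay the weight directly, independent of v.
        theta = theta - lr * wd * theta
    return theta, m, v
```

With a zero gradient, the decoupled step shrinks the weight by exactly lr·wd·θ, while the coupled step's regularization is distorted by the adaptive denominator.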
| Setting | Pre-training | Fine-tuning | Notes |
|---|---|---|---|
| Learning rate | 1e-4 to 3e-4 | 1e-5 to 3e-4 | Use warmup + cosine decay |
| β₁ (momentum) | 0.9 | 0.9 | Rarely changed |
| β₂ (second moment) | 0.95 | 0.999 | 0.95 for LLM pre-training (Chinchilla), 0.999 for fine-tuning |
| ε (epsilon) | 1e-8 | 1e-8 | Increase to 1e-6 if NaN in bfloat16 |
| Weight decay | 0.1 | 0.01–0.1 | Applied to all params except norms and biases |
| Gradient clip | 1.0 | 1.0 | max_norm=1.0 standard |
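The "except norms and biases" rule from the table is usually implemented with two parameter groups. A minimal name-matching sketch, assuming norm parameters can be identified by name (the helper and its matching heuristic are illustrative and must be adapted to the model's actual parameter names):

```python
def split_decay_param_groups(named_params, weight_decay=0.1):
    """Split (name, param) pairs into decay / no-decay groups:
    weight decay applies to everything except norms and biases."""
    decay, no_decay = [], []
    for name, param in named_params:
        if name.endswith(".bias") or "norm" in name.lower():
            no_decay.append(param)
        else:
            decay.append(param)
    # Groups in the format torch.optim.AdamW accepts as `params`.
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
```

The returned list can be passed directly as the first argument to `torch.optim.AdamW`.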
Several Adam variants reduce the memory overhead of standard Adam's two optimizer-state vectors. AdaFactor replaces the full second-moment matrix of each 2-D weight with a factored approximation built from row and column statistics, shrinking that state from O(n·m) to O(n+m) at a small convergence-quality cost. CAME (Confidence-guided Adaptive Memory Efficient optimization) builds on AdaFactor, recovering much of the lost convergence quality through a confidence-guided update strategy. For 7B+ model training, where optimizer states dominate the memory budget, these reduced-memory variants enable batch sizes that would OOM with standard AdamW.
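The factored approximation can be sketched in pure Python (an illustrative helper, not the full AdaFactor update): store only row and column sums of the squared-gradient matrix and reconstruct entries on demand.

```python
def factored_second_moment(sq_grads):
    """AdaFactor-style factored estimate of a 2-D second-moment matrix.
    Stores only row sums and column sums (n + m values) instead of the
    full n*m matrix; entries are reconstructed as row_i * col_j / total."""
    n, m = len(sq_grads), len(sq_grads[0])
    row = [sum(r) for r in sq_grads]
    col = [sum(sq_grads[i][j] for i in range(n)) for j in range(m)]
    total = sum(row)
    return [[row[i] * col[j] / total for j in range(m)] for i in range(n)]
```

The reconstruction is exact when the squared-gradient matrix is rank-1, and an approximation otherwise; that gap is the "small convergence quality cost".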
Adam's per-parameter learning rates interact with the global learning rate schedule in non-obvious ways. Warmup schedules that linearly increase the learning rate over the first 5–10% of training steps are critical for transformer training — starting at full learning rate causes large parameter updates before the second moment estimates have stabilized, producing divergence or poor early convergence. Cosine annealing after warmup smoothly reduces the learning rate to near zero by the end of training, typically outperforming step decay schedules for language model fine-tuning by 1–2 perplexity points.
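The warmup-plus-cosine schedule described above fits in one function. A minimal sketch with illustrative defaults (peak 3e-4, 10% warmup):

```python
import math

def lr_at(step, total_steps, peak_lr=3e-4, warmup_frac=0.1, min_lr=0.0):
    """Linear warmup over the first warmup_frac of steps,
    then cosine decay from peak_lr down to min_lr."""
    warmup_steps = max(1, int(warmup_frac * total_steps))
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))
```

In PyTorch this shape is typically wired up via `torch.optim.lr_scheduler.LambdaLR` with a function like this one divided by `peak_lr`.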
| Variant | Memory vs Adam | Convergence | Use case |
|---|---|---|---|
| AdamW | Same | Better (proper L2) | Default for fine-tuning |
| 8-bit Adam | 4x reduction | Comparable | Limited VRAM fine-tuning |
| AdaFactor | ~4x reduction | Slightly worse | Pretraining, TPUs |
| Lion | ~2x reduction (momentum only) | Competitive | Vision models |
Adam (Adaptive Moment Estimation) maintains per-parameter learning rates using first-moment (momentum) and second-moment (variance) estimates: m_t = β₁ × m_{t-1} + (1−β₁) × g_t and v_t = β₂ × v_{t-1} + (1−β₂) × g_t², with bias corrections m̂_t = m_t / (1−β₁ᵗ) and v̂_t = v_t / (1−β₂ᵗ), giving the update θ_t = θ_{t-1} − α × m̂_t / (√v̂_t + ε). The defaults β₁=0.9, β₂=0.999 bias momentum toward recent gradients while maintaining a long-horizon variance estimate. Variants address specific limitations. AMSGrad replaces v̂_t with the running maximum of past values, so the effective learning rate can never increase when v_t shrinks. AdaBound adds lower and upper bounds on the adaptive rate, transitioning from Adam-like behavior (early training) to SGD-like behavior (convergence). AdamW decouples weight decay from the gradient update, the fix described above. RAdam (Rectified Adam) targets early-training instability: because m₀=0 and v₀=0, the adaptive step has high variance during the first updates; RAdam uses SGD-with-momentum-style updates initially and switches to adaptive updates once enough samples have accumulated, reportedly improving convergence and final accuracy by 1–3% in large-batch training. The Lookahead meta-optimizer wraps Adam: it maintains slow weights (updated infrequently) alongside Adam's fast weights, and has been reported to improve stability and generalization by 2–5% on language models. In practice, torch.optim.AdamW with weight_decay=0.01 and a learning rate in the ranges from the table above is the industrial default for transformer training: simple and effective.
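The update equations above can be checked with a textbook scalar implementation (illustrative, not torch.optim internals), including the bias corrections:

```python
import math

def adam(grad, x0, steps, lr=0.1, b1=0.9, b2=0.999, eps=1e-8):
    """Minimize a scalar objective with plain Adam, matching the
    equations above: m_hat = m/(1-b1^t), v_hat = v/(1-b2^t)."""
    x, m, v = x0, 0.0, 0.0
    for t in range(1, steps + 1):
        g = grad(x)
        m = b1 * m + (1 - b1) * g
        v = b2 * v + (1 - b2) * g * g
        m_hat = m / (1 - b1 ** t)
        v_hat = v / (1 - b2 ** t)
        x -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return x

# Minimize f(x) = x^2 (gradient 2x) from x = 5.0.
x_min = adam(lambda x: 2 * x, 5.0, steps=500)
```

Early on, m̂/√v̂ ≈ 1 regardless of gradient magnitude, so the step size is roughly α; as the long-memory v̂ retains the large early gradients, the step shrinks and the iterate settles near the minimum.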
The raw Adam learning rate α directly determines convergence: too high and the loss diverges; too low and optimization stalls. The standard remedy is a warmup schedule that linearly increases α from 0 to the target over the first N steps, preventing oversized updates while the moment estimates are still biased. A warmup ratio of 10% (warmup = 1k steps if total steps = 10k) is a solid empirical default; 5% is often insufficient for large models, while 20% wastes training budget. Cosine annealing then drops α from its peak to a minimum over training; adding periodic restarts (SGDR) can improve final loss by 1–2% over fixed-rate schedules by letting the optimizer explore different basins. PyTorch's lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2) implements the restart schedule: anneal over the first 10 iterations, restart, and double the period after each restart (note that this scheduler performs no warmup itself). For large-scale training (>1B tokens), a common heuristic for Adam is square-root scaling, α_eff = α × √(batch_size / base_batch_size), which accounts for reduced gradient-noise variance: doubling the batch size multiplies α by √2. Violating this scaling causes either instability (small batch with large α) or slow convergence (large batch with small α). A typical tuning loop: (1) find the peak α via coarse log-scale search (e.g., 1e-3, 1e-2, 1e-1, shifted down an order of magnitude for transformers); (2) refine the warmup ratio (5%, 10%, 20%); (3) select the schedule (cosine, linear decay, constant). Hyperparameter sweeps with Ray Tune or Weights & Biases typically locate the optimal α within a factor of two in 10–20 trials. For fine-tuning, start from the pre-training α and reduce it by ~10× (e.g., 1e-4 → 1e-5) to avoid catastrophic forgetting.
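The square-root scaling rule is a one-liner; here `base_batch` is the batch size at which `base_lr` was tuned (names are illustrative):

```python
import math

def scaled_lr(base_lr, base_batch, new_batch):
    """Square-root learning-rate scaling, a common heuristic for
    Adam-family optimizers: doubling the batch multiplies lr by sqrt(2)."""
    return base_lr * math.sqrt(new_batch / base_batch)
```

For example, an α of 1e-4 tuned at batch size 256 scales to 2e-4 at batch size 1024.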
AdamW (Adam with decoupled weight decay) is the standard for transformer training; it maintains separate momentum and variance estimates for every parameter, so optimizer state alone is twice the parameter count. For a 70B-parameter model in fp32, parameters (280GB) plus momentum (280GB) plus variance (280GB) come to roughly 840GB before gradients and activations, feasible only with optimizer-state sharding across many GPUs (e.g., ZeRO-3). AdaFactor is the memory-efficient alternative: for each n×m weight matrix it approximates the second moment from row and column statistics, cutting that state from n·m values to n+m. The large memory saving costs some quality (roughly 0.5–1% higher final loss) and speed (20–30% more steps to converge). For language model pre-training, AdaFactor is increasingly preferred: it brings 7B-scale training within reach of far smaller GPU budgets, which matters for research groups without massive compute. The compute trade-off is modest: reconstructing the factored estimate adds per-step work (around 20% over Adam in this accounting), but total wall-clock time is comparable or slightly better because reduced memory pressure allows fewer gradient-accumulation steps. In production: AdamW for fine-tuning and instruction-tuning (smaller models, ample memory), AdaFactor for memory-constrained pre-training and research. A hybrid schedule (warm up with SGD without momentum, switch to AdamW for stability, optionally finish with AdaFactor) has been reported to improve final model quality by 1–2% on GLUE benchmarks.
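The memory arithmetic above can be sanity-checked with two small helpers (illustrative; fp32 state only, ignoring gradients, activations, and mixed-precision master weights):

```python
def adamw_memory_gb(n_params, bytes_per=4):
    """Rough AdamW footprint: parameters + momentum + variance,
    each holding n_params values at bytes_per bytes (fp32 = 4)."""
    return 3 * n_params * bytes_per / 1e9

def adafactor_second_moment_floats(shape):
    """Factored second moment for an n x m weight: row sums plus
    column sums (n + m floats) instead of the full n * m matrix."""
    n, m = shape
    return n + m

# 70B parameters in fp32: params + momentum + variance ≈ 840GB.
# A 4096x4096 weight keeps 8,192 floats of second-moment state
# instead of 16,777,216.
```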