Optimization

Adam Optimizer

Adaptive Moment Estimation — the standard optimizer for training transformers and fine-tuning LLMs. Combines momentum with adaptive per-parameter learning rates.



SECTION 01

Why Adam?

Vanilla SGD uses the same learning rate for all parameters. Adam maintains per-parameter adaptive learning rates — parameters that rarely get large gradients get a relatively higher effective learning rate. This makes it much more robust to sparse gradients and different feature scales.
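To make the adaptivity concrete, here is a minimal pure-Python sketch (no framework, illustrative only) computing the size of Adam's update for two hypothetical parameters: one receiving consistently large gradients, one receiving consistently tiny ones. Because Adam divides by √v̂, both end up with an effective step near the base learning rate, whereas SGD's steps would differ by a factor of 1000.

```python
import math

def adam_step_size(grads, beta1=0.9, beta2=0.999, eps=1e-8, lr=1e-3):
    """Magnitude of the final Adam update after a gradient history."""
    m = v = 0.0
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1 - beta1) * g          # EMA of gradients
        v = beta2 * v + (1 - beta2) * g * g      # EMA of squared gradients
    m_hat = m / (1 - beta1 ** t)                 # bias correction
    v_hat = v / (1 - beta2 ** t)
    return abs(lr * m_hat / (math.sqrt(v_hat) + eps))

dense = adam_step_size([10.0] * 100)    # large gradients -> step ~ lr
sparse = adam_step_size([0.01] * 100)   # tiny gradients  -> step ~ lr too
```

Both `dense` and `sparse` come out at roughly 1e-3 (the base lr): the √v̂ denominator normalizes away the gradient scale, which is exactly the robustness to feature scale described above.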

SECTION 02

Adam Algorithm

```python
import torch

# Adam update rule (manual implementation for clarity)
# Parameters: β₁=0.9 (momentum), β₂=0.999 (RMSProp), ε=1e-8
beta1, beta2, eps, lr = 0.9, 0.999, 1e-8, 1e-3

# State (initialized to zero)
m = torch.zeros_like(param)  # 1st moment (momentum)
v = torch.zeros_like(param)  # 2nd moment (squared-gradient EMA)

for t in range(1, num_steps + 1):
    g = param.grad  # gradient at step t

    # Update biased moment estimates
    m = beta1 * m + (1 - beta1) * g         # EMA of gradients
    v = beta2 * v + (1 - beta2) * g.pow(2)  # EMA of squared gradients

    # Bias correction (compensates for zero init)
    m_hat = m / (1 - beta1**t)
    v_hat = v / (1 - beta2**t)

    # Update parameter
    param = param - lr * m_hat / (v_hat.sqrt() + eps)

# PyTorch equivalent:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3,
                             betas=(0.9, 0.999), eps=1e-8)
```
SECTION 03

AdamW: Weight Decay Fix

The original Adam applies L2 regularization by adding λw to the gradient. This couples weight decay with the adaptive learning rate — parameters with large gradients get less regularization. AdamW decouples them: apply weight decay directly to the parameter, not through the gradient.

```python
import torch

# Adam (coupled L2 weight decay):
#   param = param - lr * (grad + lambda * param) / (sqrt(v) + eps)
#   -> weight decay gets scaled by 1/sqrt(v), inconsistent across params!
# AdamW (decoupled):
#   param = param - lr * grad / (sqrt(v) + eps) - lr * lambda * param
#   -> weight decay always applied at rate lr * lambda, clean regularization

# Always use AdamW for fine-tuning and pre-training
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=2e-5,            # For fine-tuning
    betas=(0.9, 0.95),  # β₂=0.95 is common for LLM pre-training (e.g. Chinchilla)
    eps=1e-8,
    weight_decay=0.1,   # Decoupled weight decay
)

# In HuggingFace (deprecated, use torch.optim.AdamW instead):
# from transformers import AdamW
# optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
```
SECTION 04

Hyperparameter Guide

| Setting | Pre-training | Fine-tuning | Notes |
|---|---|---|---|
| Learning rate | 1e-4 to 3e-4 | 1e-5 to 3e-4 | Use warmup + cosine decay |
| β₁ (momentum) | 0.9 | 0.9 | Rarely changed |
| β₂ (RMSProp) | 0.95 | 0.999 | 0.95 for LLM pre-training (Chinchilla); 0.999 for fine-tuning |
| ε (epsilon) | 1e-8 | 1e-8 | Increase to 1e-6 if NaNs appear in bfloat16 |
| Weight decay | 0.1 | 0.01–0.1 | Applied to all params except norms and biases |
| Gradient clip | 1.0 | 1.0 | max_norm=1.0 is standard |
SECTION 05

Memory-Efficient Variants

```python
import torch

# Standard Adam memory: 2 fp32 optimizer states per param
# For 7B params: 7B × 2 × 4 bytes = 56 GB just for optimizer states!

# AdaFactor: factored second moment, far smaller state; used to train T5, PaLM
from transformers.optimization import Adafactor
optimizer = Adafactor(
    model.parameters(),
    scale_parameter=True,
    relative_step=True,  # Adaptive LR, no manual LR needed
    warmup_init=True,
)

# 8-bit Adam: store optimizer states in 8 bits (bitsandbytes)
import bitsandbytes as bnb
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-4)
# Memory: 56 GB -> 14 GB of optimizer state for a 7B model

# Paged AdamW: optimizer states offloaded to CPU (uses pinned memory)
optimizer = bnb.optim.PagedAdamW32bit(model.parameters(), lr=1e-4)
# Used in QLoRA to fine-tune 65B on a single A100
```
SECTION 06

Common Mistakes

```python
import torch
from transformers import get_cosine_schedule_with_warmup

# Mistake 1: learning rate too high, a common cause of loss spikes.
# Fix: linear warmup over the first 1-5% of training steps.
scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=100, num_training_steps=10000
)

# Mistake 2: forgetting to exclude norm/bias params from weight decay.
def no_decay_params(model):
    no_decay = ["bias", "layer_norm", "layernorm", "rmsnorm"]
    return [
        {"params": [p for n, p in model.named_parameters()
                    if not any(nd in n.lower() for nd in no_decay)],
         "weight_decay": 0.1},
        {"params": [p for n, p in model.named_parameters()
                    if any(nd in n.lower() for nd in no_decay)],
         "weight_decay": 0.0},
    ]

optimizer = torch.optim.AdamW(no_decay_params(model), lr=1e-4)

# Mistake 3: using one LR for every param group when fine-tuning.
# Fix: lower LR for pretrained weights, higher for a newly initialized head.
optimizer = torch.optim.AdamW([
    {"params": model.lm_head.parameters(), "lr": 1e-4},      # new head
    {"params": model.transformer.parameters(), "lr": 1e-5},  # pretrained
], weight_decay=0.01)
```
Rule of thumb: For fine-tuning: start with AdamW, lr=2e-5, weight_decay=0.01, cosine schedule, 100-step warmup. Adjust lr up if training is slow, down if loss is unstable.

Adam variants for LLM training

Several Adam variants have been developed to reduce the memory overhead of standard Adam's two optimizer state vectors. AdaFactor replaces the full second-moment matrix with a factored row/column approximation, shrinking the state for an n×m weight matrix from O(nm) to O(n+m) at a small cost in convergence quality. CAME (Confidence-guided Adaptive Memory Efficient Optimization) builds on AdaFactor, adding a confidence-guided update strategy to recover much of that lost convergence quality. For 7B+ model training, where optimizer states dominate the memory budget, these reduced-memory variants enable batch sizes that would OOM with standard AdamW.
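The factored approximation can be sketched in a few lines of pure Python (illustrative only, not the library implementation): store only row sums and column sums of the squared-gradient matrix, and reconstruct each entry as their normalized outer product. The reconstruction is exact when the matrix is rank 1, and an approximation otherwise.

```python
def factored_second_moment(v):
    """AdaFactor-style rank-1 approximation of a second-moment matrix.

    Stores only row sums and column sums (O(n+m) state) instead of the
    full n x m matrix, reconstructing v[i][j] ~= row[i] * col[j] / total.
    """
    n, m = len(v), len(v[0])
    row = [sum(v[i]) for i in range(n)]                        # O(n) state
    col = [sum(v[i][j] for i in range(n)) for j in range(m)]   # O(m) state
    total = sum(row)
    return [[row[i] * col[j] / total for j in range(m)] for i in range(n)]

# Exact when the squared-gradient matrix happens to be rank 1:
v = [[r * c for c in (1.0, 2.0, 3.0)] for r in (4.0, 5.0)]
approx = factored_second_moment(v)
```

For a 4096×4096 weight matrix this replaces ~16.8M stored values with ~8.2K, which is where the memory savings in the table below come from.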

Learning rate scheduling with Adam

Adam's per-parameter learning rates interact with the global learning rate schedule in non-obvious ways. Warmup schedules that linearly increase the learning rate over the first 5–10% of training steps are critical for transformer training — starting at full learning rate causes large parameter updates before the second moment estimates have stabilized, producing divergence or poor early convergence. Cosine annealing after warmup smoothly reduces the learning rate to near zero by the end of training, typically outperforming step decay schedules for language model fine-tuning by 1–2 perplexity points.
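A warmup-plus-cosine schedule is simple enough to write directly. The sketch below is a framework-free version (assumed shape; libraries like `transformers` provide equivalents such as `get_cosine_schedule_with_warmup`): linear ramp to the peak learning rate over the warmup fraction, then cosine decay to a floor.

```python
import math

def warmup_cosine_lr(step, total_steps, peak_lr, warmup_frac=0.1, min_lr=0.0):
    """Linear warmup to peak_lr, then cosine decay to min_lr."""
    warmup_steps = max(1, int(total_steps * warmup_frac))
    if step < warmup_steps:
        return peak_lr * step / warmup_steps           # linear ramp from 0
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))

# 1000-step fine-tuning run at peak lr 2e-5 with 10% warmup:
lrs = [warmup_cosine_lr(s, 1000, 2e-5) for s in range(1001)]
```

The schedule starts at 0, peaks at exactly `peak_lr` when warmup ends, and decays smoothly to `min_lr` at the final step.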

| Variant | Memory vs Adam | Convergence | Use case |
|---|---|---|---|
| AdamW | Same | Better (decoupled weight decay) | Default for fine-tuning |
| 8-bit Adam | ~4x reduction | Comparable | Limited-VRAM fine-tuning |
| AdaFactor | ~4x reduction | Slightly worse | Pre-training, TPUs |
| Lion | ~2x reduction (momentum only) | Competitive | Vision models |

Adam Algorithm Variants and Adaptive Learning Rate Mechanics

Adam (Adaptive Moment Estimation) maintains per-parameter learning rates using first-moment (momentum) and second-moment (uncentered variance) estimates: m_t = β₁·m_{t-1} + (1-β₁)·g_t and v_t = β₂·v_{t-1} + (1-β₂)·g_t², with bias-corrected m̂_t = m_t/(1-β₁^t) and v̂_t = v_t/(1-β₂^t), and update θ_t = θ_{t-1} - α·m̂_t/(√v̂_t + ε). The defaults β₁=0.9, β₂=0.999 bias the momentum toward recent gradients while keeping a long-horizon variance estimate. Variants address specific limitations: AMSGrad tracks the running maximum of v_t to prevent the effective learning rate from growing when v_t shrinks; AdaBound adds lower and upper bounds on the adaptive rates, transitioning from Adam-like behavior (early training) to SGD-like behavior (convergence); AdamW decouples weight decay from the gradient update (a crucial fix). RAdam (Rectified Adam) targets Adam's unstable early steps: the zero-initialized moment estimates (m₀=0, v₀=0) make the adaptive rate high-variance at first, so RAdam falls back to SGD-with-momentum-style updates until enough samples have accumulated, then switches to adaptive updates, reportedly improving convergence and final accuracy by 1–3% in large-batch training. The Lookahead meta-optimizer wraps Adam, maintaining slow weights (updated infrequently) alongside Adam's fast weights; it reportedly improves stability and generalization by 2–5% on language models. In practice, torch.optim.AdamW with weight_decay=0.01 and lr around 1e-3 (lower for fine-tuning) is the industrial default for transformer training: simple and effective.
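The instability that RAdam addresses is easy to demonstrate numerically. This pure-Python sketch (illustrative, not any library's implementation) computes the size of Adam's very first update with and without bias correction: without it, the step is roughly √((1-β₁)²/(1-β₂)) ≈ 3.16 times the learning rate, regardless of the gradient's magnitude.

```python
import math

def first_adam_step(g, beta1=0.9, beta2=0.999, eps=1e-8, lr=1e-3,
                    bias_correct=True):
    """Size of Adam's very first update, with or without bias correction."""
    m = (1 - beta1) * g        # m_1 from m_0 = 0
    v = (1 - beta2) * g * g    # v_1 from v_0 = 0
    if bias_correct:
        m /= (1 - beta1)       # divide by (1 - beta1**t) at t = 1
        v /= (1 - beta2)
    return abs(lr * m / (math.sqrt(v) + eps))

corrected = first_adam_step(5.0)                         # ~lr
uncorrected = first_adam_step(5.0, bias_correct=False)   # ~3.16x lr
```

With correction the first step is ~lr; without it, the mismatched decay rates of m and v inflate it by ~3.16x, which is exactly the kind of oversized early update that warmup and RAdam guard against.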

Learning Rate Scheduling, Warmup, and Tuning for Optimal Convergence

The raw Adam learning rate α directly affects convergence: too high causes training instability (loss diverges), too low stalls optimization. The standard approach is a warmup schedule that linearly increases α from 0 to the target over the first N steps, preventing oversized updates while the moment estimates are still biased by their zero initialization. A warmup ratio around 10% (warmup=1k steps if total steps=10k) is an empirically solid default; 5% is often insufficient for large models, while 20% wastes steps. Cosine annealing then decays α from its peak to a minimum over training; with periodic warm restarts (SGDR), it can improve final loss by 1–2% over fixed-rate schedules by exploring different basins. PyTorch's lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2) anneals over the first 10 epochs, then restarts with the period doubled each cycle (10, 20, 40, ...); despite the name, it performs no warmup, so pair it with a separate warmup phase. For large-scale training (>1B tokens), a common heuristic for Adam is square-root learning-rate scaling, α_eff = α × √(batch_size / base_batch_size), which accounts for the reduction in gradient noise variance: doubling the batch size scales α by √2. This scaling preserves convergence behavior across batch sizes; violating it causes either instability (small batch with large α) or slow convergence (large batch with small α). A typical tuning loop: (1) find the peak α via coarse search (e.g. 1e-3, 1e-2, 1e-1), (2) refine the warmup ratio (5%, 10%, 20%), (3) select a schedule (cosine, linear decay, constant). Hyperparameter sweeps with Ray Tune or Weights & Biases typically pin down the optimal α to within a factor of 2 in 10–20 trials. For fine-tuning, start from the pre-training α and reduce it roughly 10× (e.g. 1e-4 → 1e-5) to avoid catastrophic forgetting.
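The square-root scaling rule above is a one-liner; a small sketch (hypothetical helper, assumed base batch size) makes the arithmetic explicit:

```python
import math

def scaled_lr(base_lr, base_batch, new_batch):
    """Square-root LR scaling heuristic for Adam-family optimizers."""
    return base_lr * math.sqrt(new_batch / base_batch)

# Doubling the batch size from 256 to 512 scales the LR by sqrt(2):
lr = scaled_lr(1e-4, 256, 512)  # ~1.41e-4
```

Note this is the heuristic for adaptive optimizers; for plain SGD the common counterpart is linear scaling with batch size.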

AdamW vs. AdaFactor and Trade-offs in Memory and Compute

AdamW (Adam with decoupled weight decay) is standard for transformer training; it maintains separate momentum and variance estimates for every parameter, an O(2n) memory overhead where n is the parameter count. For a 70B-parameter model in fp32, weights (280GB) plus momentum (280GB) plus variance (280GB) total roughly 840GB, which fits only on multi-GPU clusters with optimizer-state sharding (e.g. ZeRO-3 across 8×A100-80GB nodes). AdaFactor is memory-efficient: it approximates the second moment using only row and column statistics instead of the full matrix, reducing variance memory for an n×m weight matrix from O(nm) to O(n+m). This large memory reduction comes at the cost of slightly worse convergence (roughly 0.5–1% higher final loss) and slower convergence (20–30% more steps in some reports). For language-model pre-training, AdaFactor is increasingly preferred: it enables single-GPU training of 7B models that previously required a multi-GPU setup, critical for research groups without massive compute budgets. The compute trade-off: AdaFactor's factored update adds per-step overhead for computing row and column statistics, making per-step compute modestly higher than Adam's; total wall-clock time is comparable or slightly better because reduced memory pressure allows larger micro-batches and fewer gradient-accumulation steps. In production: AdamW for fine-tuning and instruction-tuning (smaller models, ample memory), AdaFactor for pre-training and research (memory-constrained, long runs). A hybrid recipe, warming up with plain SGD, switching to AdamW for stability, and optionally finishing with AdaFactor, has been reported to improve final model quality by 1–2% on GLUE benchmarks.
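The memory arithmetic in this comparison can be packaged as a small back-of-the-envelope estimator (a sketch with assumed, simplified state layouts: two fp32 states per parameter for AdamW, two 1-byte states for 8-bit Adam; weights, gradients, and AdaFactor's small row/column statistics are excluded):

```python
def optimizer_state_bytes(n_params, optimizer="adamw"):
    """Rough optimizer-state memory in bytes, excluding weights and grads."""
    bytes_per_param = {
        "adamw": 2 * 4,     # momentum + variance, fp32
        "adam8bit": 2 * 1,  # momentum + variance, quantized to 1 byte each
    }[optimizer]
    return n_params * bytes_per_param

GB = 1e9
adamw_7b = optimizer_state_bytes(7e9, "adamw") / GB       # ~56 GB
adam8bit_7b = optimizer_state_bytes(7e9, "adam8bit") / GB # ~14 GB
```

This reproduces the 56 GB figure from the memory-efficient variants section and the 4x reduction from 8-bit quantization; scaling n_params to 70e9 gives the ~560 GB of AdamW state that motivates sharding or AdaFactor at that scale.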