Neural Networks

Backpropagation

The algorithm that computes gradients for every parameter in a neural network via the chain rule — the foundation of all gradient-based learning.

Chain Rule
Core Math
O(n)
Per Layer
1986
Rumelhart et al.

SECTION 01

The Learning Problem

A neural network is a function f(x; W) parameterized by weights W. Training means finding W that minimizes a loss L. Gradient descent requires knowing ∂L/∂W for every parameter. For a network with 7 billion parameters, computing this efficiently is the problem backpropagation solves.

Backprop insight: The chain rule is applied once per operation in the computation graph. For a network with L layers, one backward pass costs O(L) — same asymptotic cost as the forward pass.

Why not finite differences? Numerical gradient: (f(w+ε) - f(w-ε)) / 2ε per parameter. For 7B parameters, that's 14B forward passes per update — completely infeasible. Backprop computes all gradients in one pass.
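To make the contrast concrete, here is a minimal sketch (an illustrative one-parameter example, not from the original) comparing a central finite difference against autograd's analytic gradient:

```python
import torch

# Tiny illustrative example: L(w) = (w*x)^2 with x = 3, so dL/dw = 2*w*x^2.
x = torch.tensor([3.0])
w = torch.tensor([2.0], requires_grad=True)

def loss_fn(w):
    return (w * x).pow(2).sum()

# Analytic gradient via backprop: one forward + one backward pass
loss_fn(w).backward()
analytic = w.grad.item()          # 2 * 2 * 9 = 36

# Numerical gradient: two extra forward passes PER parameter
eps = 1e-4
with torch.no_grad():
    numeric = ((loss_fn(w + eps) - loss_fn(w - eps)) / (2 * eps)).item()

print(analytic, numeric)
```

The two values agree, but the numerical version needs two forward passes for every parameter, while backprop gets all of them from a single backward pass.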

SECTION 02

Forward & Backward Pass

import torch
import torch.nn as nn

# Simple 2-layer network
model = nn.Sequential(
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 10)
)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(32, 128)          # batch of 32 inputs
y = torch.randint(0, 10, (32,))   # labels

# ── FORWARD PASS ──
# Records all ops in the computation graph
output = model(x)                 # (32, 10)
loss = criterion(output, y)       # scalar

# ── BACKWARD PASS ──
optimizer.zero_grad()             # Clear accumulated gradients
loss.backward()                   # Compute ∂L/∂W for EVERY parameter

# ── PARAMETER UPDATE ──
optimizer.step()                  # W ← W - lr * ∂L/∂W

# Check gradients
for name, param in model.named_parameters():
    print(f"{name}: grad norm = {param.grad.norm():.4f}")
SECTION 03

Vanishing & Exploding Gradients

Deep networks multiply many Jacobians together during backprop. If these Jacobians consistently shrink vectors (norms below 1), gradients decay exponentially toward zero (vanishing). If they consistently amplify (norms above 1), gradients grow exponentially (exploding).

import torch
import torch.nn as nn

# Demonstration: gradient vanishing in a deep sigmoid network
depth = 20
model = nn.Sequential(
    *[nn.Sequential(nn.Linear(64, 64), nn.Sigmoid()) for _ in range(depth)]
)

x = torch.randn(1, 64)
y = torch.randn(1, 64)
loss = ((model(x) - y) ** 2).mean()
loss.backward()

# Gradients near the input
print(model[0][0].weight.grad.norm())  # Near zero — vanished!

# Solutions:
# 1. ReLU/GELU instead of sigmoid (doesn't saturate)
# 2. Residual connections: y = x + F(x) — gradient always has +1 from the skip
# 3. LayerNorm — normalizes activations, keeps gradient magnitudes stable
# 4. Careful weight initialization (Xavier/He)
# 5. Gradient clipping for exploding gradients
Transformer solution: Transformers use residual connections + LayerNorm specifically to solve vanishing/exploding gradients. Without them, 100-layer transformers wouldn't train.
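The effect of the skip connection can be seen directly. This sketch (illustrative module names, assumed hyperparameters) compares input-layer gradient norms for a plain deep sigmoid stack versus the same stack with residual connections:

```python
import torch
import torch.nn as nn

class ResBlock(nn.Module):
    """y = x + F(x): the skip path contributes an identity to the Jacobian."""
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.Sigmoid())
    def forward(self, x):
        return x + self.net(x)

torch.manual_seed(0)
depth, dim = 20, 64
plain = nn.Sequential(
    *[nn.Sequential(nn.Linear(dim, dim), nn.Sigmoid()) for _ in range(depth)]
)
resid = nn.Sequential(*[ResBlock(dim) for _ in range(depth)])

x = torch.randn(1, dim)
for model in (plain, resid):
    model(x).sum().backward()

# Gradient norms at the first (input-side) linear layer
g_plain = plain[0][0].weight.grad.norm().item()
g_resid = resid[0].net[0].weight.grad.norm().item()
print(f"plain: {g_plain:.2e}  residual: {g_resid:.2e}")
```

The residual stack keeps an order-one gradient at the input layer, while the plain sigmoid stack's gradient has all but vanished.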
SECTION 04

Gradient Checkpointing

Backprop normally stores every intermediate activation from the forward pass. For a 100-layer network, that means keeping the activations of all 100 layers in memory until the backward pass consumes them. Gradient checkpointing trades memory for compute.

import torch
from torch.utils.checkpoint import checkpoint

# Without checkpointing: all activations stored
output = model(input)  # Stores all layer outputs in memory

# With checkpointing: only checkpoint some layers
# Re-computes activations during backward (costs extra compute)
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
model.gradient_checkpointing_enable()
# Now activations are NOT stored — only checkpoints at segment boundaries
# Memory usage: O(√n) instead of O(n) for n layers
# Speed cost: ~25-30% slower backward pass

# Manual use with torch.utils.checkpoint.checkpoint
def forward_fn(module, x):
    return module(x)

output = checkpoint(forward_fn, layer, x)  # Activations recomputed on backward
SECTION 05

Gradient Clipping

import torch

# Clip gradient norm — prevents exploding gradients
# Standard practice in LLM training (max_norm=1.0 or 0.5)
def training_step(model, optimizer, x, y):
    output = model(x)
    loss = criterion(output, y)
    loss.backward()

    # Clip gradient norm BEFORE optimizer.step()
    total_norm = torch.nn.utils.clip_grad_norm_(
        model.parameters(),
        max_norm=1.0  # Clip to L2 norm ≤ 1.0
    )

    # Log gradient norm for monitoring
    if total_norm > 1.0:
        print(f"Gradient clipped: {total_norm:.2f} → 1.0")

    optimizer.step()
    optimizer.zero_grad()
    return loss.item(), total_norm

# In HuggingFace Trainer — gradient clipping is built in:
from transformers import TrainingArguments
args = TrainingArguments(
    max_grad_norm=1.0,  # Clip to 1.0 automatically
    # ... other training arguments
)
SECTION 06

Practical Training Loop

import torch
import torch.nn as nn

def train_epoch(model, loader, optimizer, scaler, device):
    model.train()
    total_loss = 0
    for batch_idx, batch in enumerate(loader):
        input_ids = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)
        optimizer.zero_grad()

        # Mixed precision forward (float16 needs a GradScaler; bfloat16 does not)
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            outputs = model(input_ids=input_ids, labels=labels)
            loss = outputs.loss

        # Mixed precision backward: scale the loss to avoid fp16 gradient underflow
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # Unscale before clipping

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()
    return total_loss / len(loader)
Debug checklist: Loss not decreasing? Check: (1) optimizer.zero_grad() called, (2) loss.backward() called, (3) optimizer.step() called, (4) model is in train() mode, (5) parameters have requires_grad=True.

Backpropagation through transformer layers

Backpropagation through transformer attention is more memory-intensive than through simple feed-forward networks because the backward pass needs the attention score matrices (sequence_length × sequence_length), so they must be kept in GPU memory from the forward pass for every layer. For long sequences, the memory cost of stored activations can exceed the model parameter memory: a 1024-token sequence through 32 attention layers requires storing 32 attention matrices of 1024×1024 floats, per attention head and per batch element. Gradient checkpointing mitigates this by recomputing activations during the backward pass instead of storing them, trading compute for memory.
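The arithmetic above can be sketched as a back-of-envelope calculation (one score matrix per layer; multiply further by head count and batch size):

```python
# Attention-score memory for the numbers in the text: 1024 tokens, 32 layers.
# Assumes fp32 scores (4 bytes); halve for fp16/bf16.
seq_len = 1024
layers = 32
bytes_per_float = 4

per_layer = seq_len * seq_len * bytes_per_float   # one 1024x1024 score matrix
total = per_layer * layers                        # per head, per batch element

print(f"{per_layer / 2**20:.0f} MiB per layer, {total / 2**20:.0f} MiB total")
```

Each matrix is 4 MiB, so even before multiplying by heads and batch size the scores alone cost 128 MiB, and the cost grows quadratically with sequence length.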

Optimizer state and training memory

The Adam optimizer stores two additional states per parameter: a running mean of gradients (first moment) and a running mean of squared gradients (second moment). A 7B-parameter model in float16 requires 14GB for parameters, but the two Adam moments, stored in float32 for numerical stability, add roughly 56GB more (7B × 4 bytes × 2), so full Adam-based fine-tuning of a 7B model needs 70GB+ of VRAM before counting gradients and activations. AdamW-8bit from bitsandbytes reduces optimizer memory by 4x through 8-bit quantization of optimizer states while maintaining comparable convergence quality.
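The accounting can be written out explicitly (illustrative numbers only, using fp16 = 2 bytes and fp32 = 4 bytes per value):

```python
# Rough training-memory accounting for a 7B-parameter model.
n_params = 7e9

def gb(nbytes):
    return nbytes / 1e9

weights_fp16 = gb(n_params * 2)        # model parameters in fp16
grads_fp16 = gb(n_params * 2)          # gradients in fp16
adam_fp32 = gb(n_params * 4 * 2)       # two fp32 moments per parameter

print(f"params {weights_fp16:.0f} GB, grads {grads_fp16:.0f} GB, "
      f"Adam states {adam_fp32:.0f} GB")
```

The optimizer states alone are four times the size of the fp16 weights, which is why 8-bit optimizer states and parameter-efficient methods like LoRA cut training memory so dramatically.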

| Training component | Memory (7B model) | Reduction technique |
| --- | --- | --- |
| Model parameters (fp16) | ~14GB | Quantize to 4-bit (QLoRA) |
| Adam optimizer states (fp32) | ~56GB | 8-bit Adam, Adafactor |
| Gradients (fp16) | ~14GB | Gradient accumulation |
| Activations | Variable | Gradient checkpointing |

Computational graphs and memory efficiency

Backpropagation operates on a computational graph constructed during the forward pass: each tensor stores references to its parents and the operation that produced it. The graph is a directed acyclic graph (DAG) where nodes are tensors and edges represent operations (matmul, addition, activation). During backpropagation, PyTorch traverses this DAG in reverse topological order, computing gradients.

The memory footprint depends on what's retained: storing activations for every layer enables fast backward computation but consumes O(depth) memory. Techniques like gradient checkpointing trade memory for computation: periodically discard activations during the forward pass, then recompute them during backward (at the cost of extra FLOPS). For models deeper than 100 layers, this reduces memory by 50%+ at ~30% compute overhead, crucial for fitting large models on limited hardware. Production implementations carefully balance: gradient checkpointing on expensive layers (transformer self-attention), retaining activations on cheap layers (linear projections), and profiling memory/compute trade-offs empirically for each architecture.
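The graph is directly inspectable in PyTorch. A small sketch (illustrative tensors, not from the original) showing how each result's `grad_fn` points at the operation that produced it, and `next_functions` at its parents:

```python
import torch

# Build a tiny graph: x --(mul by 2)--> intermediate --(sum)--> y
x = torch.randn(3, requires_grad=True)
y = (x * 2).sum()

# The node that will run first in the backward traversal
node = type(y.grad_fn).__name__

# Its parent operations in the DAG (traversed in reverse topological order)
parents = [type(fn).__name__ for fn, _ in y.grad_fn.next_functions]
print(node, parents)
```

Walking `next_functions` recursively recovers the whole DAG, which is exactly the structure the backward pass traverses.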

Numerical stability and precision management

Backpropagation through many layers accumulates numerical errors: gradients can explode (unstable training) or vanish (gradients near zero after many layers prevent weight updates). Modern techniques address this: layer normalization stabilizes activations before they enter the next layer, preventing activation ranges from diverging; residual connections allow gradients to skip layers, addressing vanishing gradients; and mixed-precision training uses float16 for the forward/backward passes but float32 for gradient accumulation, reducing memory while maintaining numerical precision. Understanding where numerical issues occur requires gradient monitoring: tracking gradient magnitude and variance across layers, and comparing them to activation magnitudes. Libraries like TensorFlow and PyTorch provide debugging tools: `tf.debugging.check_numerics()` halts on NaN/Inf, and PyTorch's anomaly-detection mode (`torch.autograd.set_detect_anomaly(True)`) pinpoints the operation producing non-finite gradients. Production training pipelines often implement monitoring: logging gradient statistics to TensorBoard, alerting when gradient norms diverge beyond expected ranges, and enabling easy rollback when training becomes numerically unstable.
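A minimal sketch of anomaly mode in action (an illustrative NaN-producing op, not from the original): under `detect_anomaly`, the backward pass raises and reports which backward function produced the non-finite values.

```python
import torch

# sqrt(-1) is NaN in the forward pass; its backward output is NaN too,
# which anomaly mode turns into a RuntimeError naming the offending op.
x = torch.tensor([-1.0], requires_grad=True)
caught = False
with torch.autograd.detect_anomaly():
    y = torch.sqrt(x)
    try:
        y.backward()
    except RuntimeError:
        caught = True  # error message points at the sqrt backward function

print("anomaly flagged:", caught)
```

Because anomaly mode checks every backward output, it slows training noticeably; it is a debugging switch, not something to leave on in production runs.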

Custom backward functions and gradient override

PyTorch allows defining custom backward functions via `torch.autograd.Function`, enabling scenarios where standard gradient computation is insufficient. Common uses: straight-through estimators (STE) for discrete operations like quantization or sampling, where gradients don't truly exist but estimates enable learning; reweighting gradients for imbalanced classification (higher weight on rare classes); and approximating expensive-to-compute gradients via cheaper surrogates. Custom backward demands rigor: gradients must satisfy numerical correctness (checked via finite differences: does gradient*epsilon approximately equal loss change?), and must be tested independently of the broader model. For research and specialized applications, this flexibility is powerful: novel loss functions, modified attention mechanisms, and experimental gradient estimators all become feasible. However, mistakes in backward implementation can silently produce wrong gradients that don't trigger errors but degrade convergence subtly, so rigorous testing of custom backward functions is non-negotiable.
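As a concrete instance of the STE pattern described above, here is a minimal sketch of a sign() quantizer whose backward pass simply passes the incoming gradient through unchanged (an identity surrogate; the class name is illustrative):

```python
import torch

class SignSTE(torch.autograd.Function):
    """Straight-through estimator: discrete forward, identity backward."""

    @staticmethod
    def forward(ctx, x):
        return torch.sign(x)          # non-differentiable discrete op

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output            # pretend d(sign)/dx = 1 everywhere

x = torch.randn(4, requires_grad=True)
y = SignSTE.apply(x)
y.sum().backward()
print(x.grad)   # all ones: the surrogate gradient flowed straight through
```

With a true `sign()` the gradient would be zero almost everywhere and learning would stall; the surrogate keeps the upstream gradient flowing, which is exactly the trick used in binarized and quantized networks. Correctness of more elaborate custom backwards should still be checked numerically, e.g. with `torch.autograd.gradcheck` on the continuous parts.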