PyTorch Basics

Autograd

PyTorch's automatic differentiation engine traces operations on tensors and computes gradients automatically via reverse-mode autodiff.

Mechanism: Tape
Computes: ∂L/∂w
Backward Cost: O(1)


SECTION 01

How Autograd Works

Autograd uses reverse-mode automatic differentiation. During the forward pass, it builds a computation graph (a tape of operations). During backward(), it traverses the graph in reverse, applying the chain rule at each node.

Key insight: Autograd doesn't do symbolic math. It traces actual numerical operations and their local gradients, then multiplies them together via the chain rule. This is called reverse-mode autodiff.
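To make the chain-rule mechanics concrete, here is a minimal sketch comparing autograd's result against a hand-derived derivative for z = (x*w)^2:

```python
import torch

# z = (x * w)^2, so dz/dx = 2*(x*w)*w by the chain rule.
x = torch.tensor(3.0, requires_grad=True)
w = torch.tensor(2.0, requires_grad=True)
y = x * w          # recorded on the tape as a multiply node
z = y ** 2         # recorded as a power node
z.backward()       # traverse the tape in reverse, multiplying local gradients

# Autograd's result matches the hand-derived chain rule:
manual_dx = 2 * (x.item() * w.item()) * w.item()  # 2*6*2 = 24
print(x.grad.item(), manual_dx)  # 24.0 24.0
```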
SECTION 02

requires_grad & Computation Graph

import torch

# Only tensors with requires_grad=True are tracked
x = torch.randn(3, requires_grad=True)   # Will track gradients
w = torch.randn(3, requires_grad=True)   # Learnable weight
b = torch.tensor(0.5, requires_grad=True)

# Forward pass: autograd records each operation
y = (x * w).sum() + b   # sum and add are tracked
z = y ** 2

# Check the graph
print(z.grad_fn)   # PowBackward0, the last op
print(y.grad_fn)   # AddBackward0

# Leaf tensors: user-created tensors (no grad_fn)
print(x.is_leaf)   # True, x is a leaf
print(z.is_leaf)   # False, z was computed

# Model parameters are leaf tensors with requires_grad=True
import torch.nn as nn
linear = nn.Linear(4, 2)
for param in linear.parameters():
    print(param.is_leaf, param.requires_grad)  # True, True
SECTION 03

Backward Pass

import torch

x = torch.randn(4, requires_grad=True)
w = torch.randn(4, requires_grad=True)

# Forward
y = (x * w).sum()
loss = y ** 2

# Backward: computes all gradients.
# retain_graph=True keeps the graph alive for a second backward pass;
# normally the graph is freed as soon as backward finishes.
loss.backward(retain_graph=True)

# Access gradients
print(x.grad)  # ∂loss/∂x
print(w.grad)  # ∂loss/∂w

# Always zero gradients before the next backward!
# (PyTorch accumulates by default, which is useful for gradient accumulation)
x.grad.zero_()
w.grad.zero_()

loss.backward()  # Second pass works because the graph was retained

# Vector-valued outputs: pass a gradient vector
x.grad.zero_()
w.grad.zero_()
y = x * w                       # y is a vector
y.backward(torch.ones_like(y))  # gradient of sum(y)
# equivalent to y.sum().backward()
SECTION 04

Custom Autograd Functions

import torch

class StraightThroughEstimator(torch.autograd.Function):
    """Custom autograd: forward does rounding, backward passes gradient through."""

    @staticmethod
    def forward(ctx, x):
        # Round to nearest integer (not differentiable in the standard sense)
        return x.round()

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: pass gradient unchanged
        return grad_output  # Act as if forward was identity

# Use it
ste = StraightThroughEstimator.apply
x = torch.randn(5, requires_grad=True)
y = ste(x)
y.sum().backward()
print(x.grad)  # Gradient flows through rounding

# Another common use: custom CUDA kernels, masked operations
class MaskedSoftmax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, logits, mask):
        logits = logits.masked_fill(~mask, float('-inf'))
        probs = torch.softmax(logits, dim=-1)
        ctx.save_for_backward(probs, mask)
        return probs

    @staticmethod
    def backward(ctx, grad_output):
        probs, mask = ctx.saved_tensors
        # Standard softmax backward
        grad = probs * (grad_output - (grad_output * probs).sum(-1, keepdim=True))
        return grad, None  # No gradient for mask
SECTION 05

Gradient Accumulation

import torch

# Gradient accumulation: simulate a large batch on a small GPU.
# Instead of batch_size=128, do 8 steps with batch_size=16.
# (Assumes model, criterion, optimizer, and loader are already defined.)
accum_steps = 8

optimizer.zero_grad()
for i, (x, y) in enumerate(loader):
    output = model(x)
    loss = criterion(output, y) / accum_steps  # Scale the loss!
    loss.backward()                            # Gradients accumulate (not zeroed)

    if (i + 1) % accum_steps == 0:
        # Gradient clipping (important for stability) goes just before step()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()        # Update weights
        optimizer.zero_grad()   # Now zero gradients

# This gives effective batch size = 16 * 8 = 128
# while memory usage stays at batch_size=16.
SECTION 06

Debugging Gradients

import torch

# Check for NaN/Inf gradients (after backward, before step)
for name, param in model.named_parameters():
    if param.grad is not None:
        if torch.isnan(param.grad).any():
            print(f"NaN gradient in {name}")
        if torch.isinf(param.grad).any():
            print(f"Inf gradient in {name}")

# Monitor the global gradient norm (useful in TensorBoard/W&B)
total_norm = 0.0
for param in model.parameters():
    if param.grad is not None:
        total_norm += param.grad.norm().item() ** 2
total_norm = total_norm ** 0.5
wandb.log({"grad_norm": total_norm})

# Gradient hooks: intercept gradients for debugging
def print_grad(name):
    def hook(grad):
        print(f"{name}: grad norm = {grad.norm():.4f}")
    return hook

for name, param in model.named_parameters():
    param.register_hook(print_grad(name))

# Detect dead ReLUs (activations, hence gradients, permanently zero)
activation_stats = {}

def relu_hook(module, input, output):
    activation_stats[module] = (output > 0).float().mean().item()  # fraction of active units

for module in model.modules():
    if isinstance(module, torch.nn.ReLU):
        module.register_forward_hook(relu_hook)
Vanishing gradients: If grad norms are < 1e-6, gradients are vanishing. Fixes: use residual connections, layer normalization, or a learning rate warmup schedule.
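As a sketch of how to spot this in practice, the snippet below builds a hypothetical deep Linear+Sigmoid stack (sigmoid saturation is a classic cause of shrinking gradients) and prints per-layer gradient norms:

```python
import torch
import torch.nn as nn

# Toy model for illustration: six Linear+Sigmoid blocks.
# Early layers typically show smaller gradient norms than late ones.
model = nn.Sequential(*[layer for _ in range(6)
                        for layer in (nn.Linear(16, 16), nn.Sigmoid())])
x = torch.randn(8, 16)
model(x).sum().backward()

norms = {name: p.grad.norm().item()
         for name, p in model.named_parameters() if "weight" in name}
for name, n in norms.items():
    flag = "  <- vanishing?" if n < 1e-6 else ""
    print(f"{name}: {n:.2e}{flag}")
```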
SECTION 07

Common Gotchas & Debugging

Understanding autograd pitfalls helps avoid silent bugs in training loops. In-place operations, incorrect graph retention, and unintentionally accumulated gradients can all derail a model without raising an obvious error.

# Gotcha 1: In-place operations can break the graph
x = torch.tensor([1.0, 2.0], requires_grad=True)
y = torch.sigmoid(x)  # sigmoid's backward needs y's value
y += 1  # WRONG: in-place write clobbers a tensor saved for backward
y.sum().backward()  # RuntimeError: a variable needed for gradient
                    # computation has been modified by an inplace operation

# Correct: use out-of-place operations
y = torch.sigmoid(x)
y = y + 1  # OK: creates a new tensor

# Gotcha 2: the graph is freed after backward, and gradients accumulate
x = torch.tensor([1.0, 2.0], requires_grad=True)
y = x ** 2
loss = y.sum()
loss.backward(retain_graph=True)  # retain the graph for a second pass
print(x.grad)  # tensor([2., 4.])
loss.backward()  # without retain_graph above, this would raise:
                 # "Trying to backward through the graph a second time"
print(x.grad)  # tensor([4., 8.]) - gradients accumulated, not replaced
x.grad.zero_()  # Manually zero before the next backward pass
SECTION 08

Autograd Performance Comparison

Operation                           | Memory (MB) | Time (ms) | Notes
Forward only (requires_grad=False)  | 12          | 2.1       | Baseline, no graph
Forward + backward (graph on)       | 45          | 8.3       | Activations cached
Backward with checkpointing         | 18          | 12.1      | Recompute vs memory
Gradient accumulation (8 steps)     | 16          | 9.2       | Effective batch = 128

Gradient computation order: PyTorch builds a dynamic computation graph as operations are executed. Each operation (add, multiply, apply function) becomes a node in this graph. When backward() is called, the graph is traversed in reverse topological order, using the chain rule to compute gradients at each node. This is why operations must be differentiable and why in-place modifications can break the graph structure.
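The reverse traversal can be observed directly: each tensor's grad_fn exposes next_functions, the edges of the graph. A minimal walk from output back to a leaf (node class names such as PowBackward0 are internal and may vary across PyTorch versions):

```python
import torch

x = torch.randn(3, requires_grad=True)
y = (x * 2).sum()
z = y ** 2

# Follow the first parent edge at each node until we reach the leaf.
names = []
node = z.grad_fn
while node is not None:
    names.append(type(node).__name__)
    node = node.next_functions[0][0] if node.next_functions else None

print(names)  # ['PowBackward0', 'SumBackward0', 'MulBackward0', 'AccumulateGrad']
```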

Understanding the computational graph is essential for debugging gradient issues. Use torch.autograd.profiler to visualize graph construction overhead, especially in production inference pipelines. For models with complex control flow or conditional branches, the dynamic graph nature of PyTorch allows these to be naturally represented, unlike static-graph frameworks where conditionals require special handling.
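A minimal profiling sketch using torch.autograd.profiler; the model and tensor shapes here are arbitrary placeholders:

```python
import torch
import torch.nn as nn

model = nn.Linear(64, 64)
x = torch.randn(32, 64)

# Record op-level CPU time for one forward+backward pass.
with torch.autograd.profiler.profile() as prof:
    loss = model(x).sum()
    loss.backward()

# Summarize the most expensive operations.
table = prof.key_averages().table(sort_by="cpu_time_total", row_limit=5)
print(table)
```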

Advanced autograd techniques include gradient checkpointing (recomputing activations during backward to save memory), custom backward rules for performance optimization, and gradient hooks for inspecting or modifying gradients before parameter updates. These techniques are indispensable for training very large models where memory constraints dominate.
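A sketch of checkpointing a single block with torch.utils.checkpoint (use_reentrant=False assumes a reasonably recent PyTorch version):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# The checkpointed block's intermediate activations are NOT stored;
# the block is re-run during backward to recompute them.
block = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32))
x = torch.randn(4, 32, requires_grad=True)

out = checkpoint(block, x, use_reentrant=False)  # forward, activations dropped
out.sum().backward()                             # block reruns its forward here
print(x.grad.shape)  # torch.Size([4, 32])
```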

EXTRA

Autograd Memory Profiling

Understanding memory consumption during backward pass is critical for scaling to large models. The backward pass uses memory for storing intermediate activations (needed for gradient computation), gradient tensors for each parameter, and temporary buffers for gradient accumulation. Tools like torch.profiler and memory_profiler help identify which layers consume the most memory.
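A back-of-envelope sketch of the gradient-memory component: every parameter tensor gets a same-shaped gradient tensor, so gradient storage alone equals parameter storage (the layer sizes below are arbitrary):

```python
import torch.nn as nn

# In fp32, each parameter is 4 bytes; its gradient is another 4 bytes.
model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))
param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
print(f"parameters: {param_bytes / 1e6:.1f} MB, "
      f"gradients: another {param_bytes / 1e6:.1f} MB")
```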

Gradient checkpointing (rematerialization) trades computation for memory: instead of storing every activation from the forward pass, you recompute them during the backward pass. For a deep transformer checkpointed at layer boundaries, activation memory drops roughly in proportion to the number of checkpointed segments (up to ~12x for a 12-layer model in the ideal case), at the cost of roughly 33% more backward computation. Modern frameworks make this transparent through context managers or decorators.

Distributed training amplifies memory concerns. Each device holds replicas of gradients during allreduce operations, and with gradient accumulation across multiple steps, peak memory can exceed single-GPU training by 2-3x. Understanding these trade-offs enables better hardware utilization and larger effective batch sizes within memory constraints.

Autograd's dynamic computation graph enables powerful patterns like hierarchical learning and meta-learning. Techniques like MAML (Model-Agnostic Meta-Learning) leverage autograd to compute gradients of gradients (second-order derivatives), enabling the optimizer to adapt its own learning process. Higher-order derivatives are computationally expensive but essential for these advanced optimization algorithms that learn to learn.
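A minimal second-order example: with create_graph=True, the first backward pass itself builds a graph, so it can be differentiated again. For y = x^3, dy/dx = 3x^2 and d2y/dx2 = 6x:

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x ** 3

# First derivative: build a graph for it so we can differentiate again.
(g,) = torch.autograd.grad(y, x, create_graph=True)   # 3 * 2^2 = 12
# Second derivative: differentiate the first derivative.
(g2,) = torch.autograd.grad(g, x)                     # 6 * 2 = 12
print(g.item(), g2.item())  # 12.0 12.0
```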

When debugging autograd issues, understanding the graph construction is paramount. torch.autograd.gradcheck() validates that gradients are computed correctly: it numerically estimates gradients with finite differences and compares them to autograd's analytic gradients, catching subtle bugs in custom backward implementations that would otherwise cause silent failures. torch.autograd.profiler.profile() complements this by showing where backward time is actually spent.
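A minimal gradcheck sketch; double precision is required because the finite-difference estimate is too noisy in float32:

```python
import torch
from torch.autograd import gradcheck

# gradcheck compares autograd's analytic gradients against
# finite-difference estimates and returns True if they agree.
x = torch.randn(4, dtype=torch.double, requires_grad=True)
ok = gradcheck(lambda t: (t ** 2).sum(), (x,), eps=1e-6, atol=1e-4)
print(ok)  # True
```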

Production systems must handle edge cases carefully: NaN/Inf propagation through the graph, handling of sparse gradients, gradient buffering for distributed training, and numerical stability in mixed precision. Libraries like torch.distributed handle many of these issues transparently, but understanding the underlying mechanisms prevents subtle bugs when extending systems or deploying to custom hardware.

BOOST

Autograd in Distributed and Federated Learning

Distributed training across multiple devices requires careful gradient synchronization. All-reduce operations average gradients from all devices before parameters update. Understanding autograd is essential because gradient computation must complete before all-reduce happens. Some frameworks optimize by overlapping gradient computation and communication—gradients from early layers communicate while later layers still compute, minimizing communication bottleneck impact on total training time.
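The snippet below simulates the averaging step of all-reduce in a single process with two model replicas; it illustrates the arithmetic only and is not a real torch.distributed setup:

```python
import copy
import torch
import torch.nn as nn

# Two "devices" hold identical replicas but see different data shards.
torch.manual_seed(0)
model_a = nn.Linear(8, 1)
model_b = copy.deepcopy(model_a)

xa, xb = torch.randn(4, 8), torch.randn(4, 8)
model_a(xa).sum().backward()
model_b(xb).sum().backward()

# All-reduce step: average gradients so both replicas apply the same update.
for pa, pb in zip(model_a.parameters(), model_b.parameters()):
    avg = (pa.grad + pb.grad) / 2
    pa.grad.copy_(avg)
    pb.grad.copy_(avg)

print(torch.equal(model_a.weight.grad, model_b.weight.grad))  # True
```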

Federated learning takes distribution to the extreme: gradients are computed on edge devices (phones, IoT sensors) and averaged at a central server. Communication constraints make federated learning hostile to large models. Gradient compression, sparsification, and quantization are essential techniques, and autograd must be aware of these transformations to correctly compute parameter updates. Understanding this end-to-end is critical for federated learning systems.
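A toy sketch of top-k gradient sparsification; real systems also transmit the surviving indices and usually keep a local residual of the dropped values, both omitted here:

```python
import torch

def topk_sparsify(grad: torch.Tensor, k: int) -> torch.Tensor:
    """Keep only the k largest-magnitude entries; zero out the rest."""
    flat = grad.flatten()
    _, idx = flat.abs().topk(k)
    sparse = torch.zeros_like(flat)
    sparse[idx] = flat[idx]
    return sparse.view_as(grad)

g = torch.tensor([[0.1, -3.0], [2.0, 0.05]])
print(topk_sparsify(g, 2))  # keeps -3.0 and 2.0, zeros the rest
```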

Gradient accumulation in distributed settings requires synchronization points. Typically, gradients accumulate locally for N steps and are then synchronized globally. This reduces communication frequency at the cost of stale gradients (gradients computed against older parameters). The staleness-consistency trade-off is fundamental to distributed optimization theory and practice.