The deep learning fundamentals that underpin every transformer, fine-tune, and training run
A neural network is a function approximator composed of layers. Each layer applies a linear transformation (matrix multiply) followed by a nonlinearity (activation function). The depth (number of layers) and width (neurons per layer) determine expressiveness.
Neurons: each neuron computes z = Wx + b, then applies an activation a = σ(z). Activation functions (ReLU, the most common, plus tanh and sigmoid) introduce nonlinearity so the network can learn complex patterns. Layers: a stack of neurons; networks have an input layer, hidden layers, and an output layer. Weights & biases: the parameters learned during training, initialized to small random values.
```python
import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 64)
        self.dropout = nn.Dropout(0.2)
        self.fc2 = nn.Linear(64, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

model = SimpleNet()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
loss_fn = nn.MSELoss()

x = torch.randn(32, 10)  # batch of 32, 10 features
y = torch.randn(32, 1)

for epoch in range(10):
    pred = model(x)
    loss = loss_fn(pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if epoch % 5 == 0:
        print(f'Epoch {epoch}: loss={loss.item():.4f}')
```
Backpropagation computes gradients of the loss with respect to all parameters using the chain rule. The gradient tells you how to adjust each weight to reduce loss. It's the engine of all modern deep learning.
Forward pass: Input → layer 1 → layer 2 → ... → output. Compute loss. Backward pass: Compute ∂loss/∂W for every parameter by applying chain rule backwards through the network. Update: W := W - η·∂loss/∂W where η is learning rate.
The chain rule lets you decompose the gradient into local derivatives at each layer. Each layer only needs the gradient of the loss with respect to its own output to compute the gradients for its inputs and parameters. This makes backpropagation computationally efficient even for networks with millions of parameters.
Call loss.backward() and PyTorch computes all gradients for you. You rarely write backprop by hand.
```python
import torch
import torch.nn as nn

# Build a simple 2-layer network and trace gradients
torch.manual_seed(0)
model = nn.Sequential(
    nn.Linear(4, 8), nn.ReLU(),
    nn.Linear(8, 1)
)

x = torch.randn(3, 4)  # batch of 3, 4 features
y = torch.tensor([[1.0], [0.0], [1.0]])

# Forward pass — PyTorch builds a computation graph
pred = model(x)
loss = nn.functional.binary_cross_entropy_with_logits(pred, y)
print(f"Loss: {loss.item():.4f}")

# Backward pass — reverse-mode autodiff through the graph
loss.backward()

# Inspect gradients
for name, param in model.named_parameters():
    if param.grad is not None:
        print(f"{name}: grad_norm={param.grad.norm():.4f}, "
              f"param_norm={param.data.norm():.4f}")

# Gradient update (what the optimizer does)
with torch.no_grad():
    for param in model.parameters():
        param -= 0.01 * param.grad  # SGD step (lr=0.01)
        param.grad.zero_()

# Verify loss decreased
pred2 = model(x)
loss2 = nn.functional.binary_cross_entropy_with_logits(pred2, y)
print(f"Loss after one step: {loss2.item():.4f}")
```
An optimizer takes gradients and updates parameters. Different optimizers use different strategies. SGD is the baseline; Adam adapts learning rate per parameter and usually converges faster.
| Optimizer | Learning rate | Convergence | Memory | Best for |
|---|---|---|---|---|
| SGD | Fixed | Slow but stable | Low | Classical baseline |
| SGD + momentum | Fixed | Faster than SGD | Low | Vision models |
| Adam | Adaptive per param | Fast | 2x (momentum + variance) | LLMs, default choice |
| AdamW | Adaptive per param | Fast + decoupled decay | 2x | Modern standard |
Adam maintains two moments: momentum (first moment) and variance (second moment) of gradients. It adapts the effective learning rate for each parameter based on the variance of its gradient history. Sparse parameters (rare updates) get higher learning rate; frequent parameters get lower rate automatically.
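The two-moment update can be sketched in a few lines of plain Python for a single scalar parameter. This is an illustration, not the optimized PyTorch implementation: the quadratic objective f(θ) = θ² and the learning rate are chosen arbitrarily, and the β values are the standard Adam defaults.

```python
import math

def adam_step(theta, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a single scalar parameter (illustrative sketch)."""
    m = beta1 * m + (1 - beta1) * grad       # first moment (momentum)
    v = beta2 * v + (1 - beta2) * grad ** 2  # second moment (variance)
    m_hat = m / (1 - beta1 ** t)             # bias correction for early steps
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v

# Minimize f(theta) = theta^2, whose gradient is 2 * theta
theta, m, v = 5.0, 0.0, 0.0
for t in range(1, 2001):  # Adam's bias correction assumes t starts at 1
    grad = 2 * theta
    theta, m, v = adam_step(theta, grad, m, v, t, lr=0.05)
print(f"theta after 2000 steps: {theta:.6f}")
```

Note how the effective step size is m̂/√v̂ scaled by the learning rate: a parameter whose gradients are consistently small (low v̂) takes proportionally larger steps, which is the per-parameter adaptation described above.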
```python
import torch
import torch.nn as nn

def train(opt_class, lr=1e-3, weight_decay=0.01, steps=300):
    torch.manual_seed(42)
    model = nn.Sequential(
        nn.Linear(16, 64), nn.ReLU(),
        nn.Linear(64, 64), nn.ReLU(),
        nn.Linear(64, 1)
    )
    # AdamW decouples weight decay from the gradient update;
    # Adam applies weight decay as an L2 penalty (incorrect for adaptive methods)
    opt = opt_class(model.parameters(), lr=lr, weight_decay=weight_decay)
    X = torch.randn(512, 16); y = torch.randn(512, 1)
    X_val = torch.randn(128, 16); y_val = torch.randn(128, 1)
    for step in range(steps):
        loss = nn.functional.mse_loss(model(X), y)
        opt.zero_grad(); loss.backward(); opt.step()
    val_loss = nn.functional.mse_loss(model(X_val), y_val).item()
    return val_loss

adam_loss = train(torch.optim.Adam)
adamw_loss = train(torch.optim.AdamW)
print(f"Adam val loss: {adam_loss:.4f}")
print(f"AdamW val loss: {adamw_loss:.4f}")
# AdamW typically generalises better due to correct weight decay handling
# Difference is larger with deeper models and longer training
```
Regularization prevents overfitting (memorizing training data instead of learning patterns). Two primary techniques: dropout (random neuron zeroing) and weight decay (L2 penalty on weights).
During training, randomly zero out neurons with probability p (typically 0.1–0.5). In the classical formulation, activations are scaled by (1-p) at test time; modern frameworks such as PyTorch instead use inverted dropout, scaling the surviving activations by 1/(1-p) during training so that inference needs no adjustment. Either way, dropout prevents co-adaptation: neurons can't rely on specific other neurons. Effect: ensemble-like behavior from a single network.
Add a penalty term λ||W||² to the loss. This encourages weights to stay small; smaller weights mean less overfitting and a smoother learned function. The interaction with learning rate matters: too much decay prevents learning.
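A quick sketch of dropout's train/eval behavior using PyTorch's nn.Dropout (p=0.5 and the input shape are arbitrary choices for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(1, 10)

drop.train()          # training mode: zero each element with probability p,
out_train = drop(x)   # scale survivors by 1/(1-p) — inverted dropout
print(out_train)      # each element is either 0.0 or 2.0

drop.eval()           # eval mode: dropout is a no-op, no rescaling needed
out_eval = drop(x)
print(out_eval)       # input passes through unchanged
```

Switching between `model.train()` and `model.eval()` is what toggles this behavior for every dropout layer in a network, which is why forgetting `model.eval()` at inference time silently degrades predictions.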
Modern training uses tricks to fit larger models and train faster. Mixed precision: compute in lower precision (float16) to save memory and time; keep master weights and loss scaling in float32 for stability. Gradient checkpointing: recompute activations during the backward pass instead of storing them, trading compute for memory.
Forward pass in float16, backward in float16 (with gradient scaling to prevent underflow), weight updates in float32. Reduces memory ~2x, 2-3x speedup on modern GPUs. PyTorch's torch.autocast handles this automatically.
Store only some activations and recompute the rest during the backward pass. For a network with L layers, activation memory drops from O(L) to O(√L) per example in the batch. Slowdown: roughly 20–30%, but the freed memory allows larger batches or deeper models.
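The trade-off can be sketched with PyTorch's built-in `torch.utils.checkpoint.checkpoint_sequential`, which stores activations only at segment boundaries and recomputes the rest during backward. The layer sizes, depth, and segment count below are arbitrary, and `use_reentrant=False` assumes a reasonably recent PyTorch:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

torch.manual_seed(0)

# A deep stack of blocks; normally every block's activation is stored
layers = nn.Sequential(
    *[nn.Sequential(nn.Linear(64, 64), nn.ReLU()) for _ in range(8)]
)

x = torch.randn(16, 64, requires_grad=True)

# Split the stack into 2 segments: only segment-boundary activations are
# kept; interior activations are recomputed during the backward pass
out = checkpoint_sequential(layers, 2, x, use_reentrant=False)
loss = out.sum()
loss.backward()
print(f"input grad norm: {x.grad.norm():.4f}")
```

The gradients are identical to a normal forward/backward pass; only the memory/compute trade-off changes, which is why checkpointing can be enabled late in a project without affecting results.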
Theory covers how training works; practice is knowing which knobs to turn first when things go wrong. In rough priority order: (1) verify your data pipeline before tuning hyperparameters; (2) monitor gradient norms — exploding or vanishing gradients diagnose most training failures early; (3) use a learning rate finder; (4) enable mixed-precision by default to halve memory use; (5) checkpoint frequently so crashes don't waste GPU hours.
Common failure modes: loss stays flat (learning rate too low or silent data bug), loss explodes (learning rate too high or missing gradient clipping), validation loss diverges early (overfitting — regularize or reduce model size), NaN loss after stable start (numerical instability — enable anomaly detection and check for zero-length inputs).
```python
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

# MyModel, num_epochs, and dataloader are placeholders for your own setup
model = MyModel().cuda()
optimizer = AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)
scaler = torch.cuda.amp.GradScaler()  # mixed precision scaling

for epoch in range(num_epochs):
    model.train()
    for step, batch in enumerate(dataloader):
        optimizer.zero_grad()

        # Mixed precision forward pass
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            outputs = model(**batch)
            loss = outputs.loss

        # Backward with gradient scaling
        scaler.scale(loss).backward()

        # Gradient clipping — prevents exploding gradients
        scaler.unscale_(optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(), max_norm=1.0
        )

        # Diagnose: warn if gradients are unusually large
        if grad_norm > 5.0:
            print(f"E{epoch} S{step}: large grad norm {grad_norm:.2f}")

        scaler.step(optimizer)
        scaler.update()

    scheduler.step()  # cosine schedule advances once per epoch
    print(f"Epoch {epoch}: loss={loss.item():.4f}")
```
Dive deeper into specialized ML topics: