Foundations · ML Core

Neural Networks Fundamentals

Perceptrons, backpropagation, activation functions, and batch normalisation — the building blocks of deep learning

Contents
  1. Architecture basics
  2. Activation functions
  3. Backpropagation
  4. Normalisation layers
  5. Weight initialisation
  6. Training stability
  7. Tools & frameworks
  8. References
01 — Structure

Architecture Basics

A neural network is a stack of matrix multiplications and nonlinearities. Each layer applies y = Wx + b, where W is a learnable weight matrix, x is the input, and b is a bias vector; a nonlinear activation (ReLU, GELU) is then applied. Composing many such layers yields an expressive function approximator.

PyTorch nn.Linear Example

```python
import torch
from torch import nn

layer = nn.Linear(784, 128)
x = torch.randn(32, 784)
y = layer(x)
print(f"Weight shape: {layer.weight.shape}")  # torch.Size([128, 784])
print(f"Bias shape: {layer.bias.shape}")      # torch.Size([128])
```

The weight matrix is stored with shape (out_features, in_features), so the forward pass computes y = x @ W.T + b, which keeps the matrix multiplication efficient. Stacking many such layers with nonlinearities creates deep networks capable of learning complex mappings.
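A minimal sketch of such a stack, using `nn.Sequential` with illustrative layer sizes (784 → 256 → 128 → 10, e.g. flattened MNIST images to class logits):

```python
import torch
from torch import nn

# A small MLP: each Linear computes y = x @ W.T + b, followed by a nonlinearity
mlp = nn.Sequential(
    nn.Linear(784, 256),
    nn.GELU(),
    nn.Linear(256, 128),
    nn.GELU(),
    nn.Linear(128, 10),
)

x = torch.randn(32, 784)   # batch of 32 flattened inputs
logits = mlp(x)
print(logits.shape)        # torch.Size([32, 10])
```

Without the GELU layers in between, the three linear layers would collapse into a single equivalent matrix multiplication; the nonlinearities are what make depth useful.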

💡 Matrix view: Think of neural networks as compositions of linear transformations and nonlinearities. This algebraic view makes optimization, parallelization, and hardware acceleration clear.
02 — Nonlinearities

Activation Functions

| Function | Formula | Range | Dying ReLU | LLM use |
|---|---|---|---|---|
| Sigmoid | 1 / (1 + e^-x) | (0, 1) | No | Rare |
| Tanh | (e^x - e^-x) / (e^x + e^-x) | (-1, 1) | No | Rare |
| ReLU | max(0, x) | [0, ∞) | Yes | Legacy |
| GELU | x · Φ(x) | ≈ [-0.17, ∞) | No | Standard |
| SiLU | x · sigmoid(x) | ≈ [-0.28, ∞) | No | Modern |
| Swish | x · sigmoid(βx) | ≈ [-0.28, ∞) for β = 1 | No | Common |

Why Activation Matters

ReLU is fast but suffers from dying units: neurons whose pre-activations stay negative output 0 permanently and stop learning. GELU, a smooth ReLU-like function, avoids this and is the standard choice in transformers. SiLU/Swish are modern alternatives with slightly better empirical performance on large models.

LLMs use GELU or SiLU in feedforward layers. Both let gradients flow through the inactive region and produce small negative outputs for small negative x, preventing units from collapsing to a constant zero output.

⚠️ Dying ReLU problem: if many neurons output 0, their weights stop updating. Leaky ReLU (a small negative slope for x < 0) mitigates this. For LLMs, use GELU or SiLU rather than vanilla ReLU.
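The difference is easy to verify with autograd. In this sketch, a ReLU unit at x = -2 gets exactly zero gradient, while GELU at the same point still receives a small nonzero gradient:

```python
import torch
import torch.nn.functional as F

x = torch.tensor(-2.0, requires_grad=True)

# ReLU: zero output AND zero gradient for negative inputs -> no weight update
relu_out = F.relu(x)
relu_out.backward()
relu_grad = x.grad.clone()

# GELU: small negative output, but a nonzero gradient still flows
x.grad = None
gelu_out = F.gelu(x)
gelu_out.backward()
gelu_grad = x.grad.clone()

print(f"ReLU grad at x=-2: {relu_grad.item():.4f}")  # 0.0000
print(f"GELU grad at x=-2: {gelu_grad.item():.4f}")  # small but nonzero
```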
03 — Learning

Backpropagation

Backpropagation computes gradients via the chain rule. The loss L depends on the output y, which depends on a hidden layer h, which depends on the input x and the weights W. Chain rule: ∂L/∂W = (∂L/∂y)(∂y/∂h)(∂h/∂W). PyTorch's autograd automates this: you call backward() and gradients are computed end-to-end.

Autograd Concept

```python
import torch

x = torch.tensor([2.0], requires_grad=True)
y = x ** 2 + 3 * x
loss = y.sum()
loss.backward()
print(f"dx: {x.grad}")  # at x=2: dy/dx = 2*x + 3 = 7
```

Computation Graph

PyTorch builds a dynamic computation graph during forward pass. Each operation records how to compute its gradient. backward() traverses the graph, computing gradients at each node. This is efficient: gradients are computed only for nodes needed to reach the loss.

Key insight: Backprop is automatic differentiation. No hand-coded gradients needed. This enables rapid experimentation and complex architectures.
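To see that autograd really is the chain rule, here is a minimal check on a single linear neuron with a squared loss: the hand-derived gradient ∂L/∂W = 2(y − t) · x matches what backward() computes (all values are illustrative):

```python
import torch

# One linear neuron, y = x @ W, with squared loss L = (y - t)^2
W = torch.tensor([[0.5]], requires_grad=True)
x = torch.tensor([[2.0]])
target = torch.tensor([[3.0]])

y = x @ W                           # y = 2.0 * 0.5 = 1.0
loss = ((y - target) ** 2).sum()    # L = (1 - 3)^2 = 4.0
loss.backward()

# Chain rule by hand: dL/dW = dL/dy * dy/dW = 2*(y - t) * x = 2*(-2)*2 = -8
manual = 2 * (y.detach() - target) * x
print(f"autograd: {W.grad.item()}, manual: {manual.item()}")  # both -8.0
```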

04 — Stability

Normalisation Layers

| Type | Normalises | When used | LLMs |
|---|---|---|---|
| BatchNorm | Per feature, across the batch | CNNs, older RNNs | No |
| LayerNorm | Per sample, across features | Transformers | Yes (standard) |
| RMSNorm | Per sample, by RMS norm (no mean subtraction) | Modern transformers | Yes (modern) |
| GroupNorm | Per group of features | Small or variable batch sizes | Rare |

Why LayerNorm for LLMs

LLMs use LayerNorm (or RMSNorm) because: (1) sequence length varies, so batch statistics are unstable; (2) normalising per sample (not per batch) avoids coupling examples within a batch; (3) applied before each attention and FFN block (pre-norm), it stabilises training. LLaMA-style architectures apply RMSNorm in exactly this pre-norm position.

PyTorch Example

```python
import torch
from torch import nn

ln = nn.LayerNorm(512)
x = torch.randn(32, 16, 512)   # (batch, sequence, features)
y = ln(x)
print(f"After LayerNorm, mean: {y.mean():.4f}, std: {y.std():.4f}")
```
💡 LayerNorm vs RMSNorm: LayerNorm subtracts the mean and divides by the standard deviation. RMSNorm only divides by the RMS norm, skipping mean subtraction. RMSNorm is slightly faster and nearly equivalent in practice; modern LLMs prefer it for efficiency.
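RMSNorm is small enough to write out directly. A minimal sketch (the learnable scale and eps follow common practice; exact details vary by codebase):

```python
import torch
from torch import nn

class RMSNorm(nn.Module):
    """Minimal RMSNorm: divide by the RMS of the feature vector, no mean subtraction."""
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learnable per-feature scale

    def forward(self, x):
        # RMS over the last (feature) dimension
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x / rms)

x = torch.randn(32, 16, 512)
rms_norm = RMSNorm(512)
y = rms_norm(x)
print(y.shape)  # torch.Size([32, 16, 512])
```

After normalisation (with the scale at its initial value of 1), every feature vector has RMS ≈ 1, regardless of its original magnitude.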
05 — Initialisation

Weight Initialisation

Poorly initialised weights make gradients vanish or explode, and training fails. Standard approaches: Xavier/Glorot (uniform or normal, scaled by layer size), He/Kaiming (larger variance to compensate for ReLU), and scaled initialisation (custom scaling based on the model architecture).

Why It Matters

At initialization, each layer output has some variance. If variance shrinks layer-by-layer, gradients vanish. If variance grows, gradients explode. Proper initialization keeps variance ~1 across layers, allowing stable gradient flow.
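This effect is easy to demonstrate numerically. A sketch comparing a too-small fixed-std initialisation against Kaiming initialisation over a 20-layer ReLU stack (all sizes illustrative); with std=0.01 the mean squared activation collapses toward zero, while Kaiming keeps it roughly constant:

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(1024, 512)

def mean_sq_after(init_fn, depth=20):
    """Push x through `depth` ReLU layers and return the mean squared activation."""
    h = x
    for _ in range(depth):
        layer = nn.Linear(512, 512, bias=False)
        init_fn(layer.weight)
        h = torch.relu(layer(h))
    return h.pow(2).mean().item()

# Too-small init: activations shrink toward zero layer by layer
small = mean_sq_after(lambda w: nn.init.normal_(w, std=0.01))
# Kaiming init: designed to preserve the signal magnitude under ReLU
kaiming = mean_sq_after(lambda w: nn.init.kaiming_normal_(w, nonlinearity="relu"))

print(f"std=0.01 init: {small:.2e}, Kaiming: {kaiming:.2f}")
```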

PyTorch Defaults

```python
import torch
from torch import nn

# Default nn.Linear initialisation (Kaiming uniform)
layer = nn.Linear(512, 512)
print(f"Weight std: {layer.weight.std().item():.4f}")

# Manual Xavier initialisation
nn.init.xavier_uniform_(layer.weight)
print(f"After Xavier: {layer.weight.std().item():.4f}")
```

LLMs use careful initialisation: weights are drawn from a narrow distribution whose scale depends on layer width (roughly 1/sqrt(N) for width N), with residual projections often scaled down further with depth. This keeps activation variance stable during the forward and backward pass. See the GPT and LLaMA codebases for exact formulas.

06 — Training Dynamics

Training Stability

Gradient Clipping

If the gradient norm exceeds a threshold, scale the gradients down. This prevents exploding gradients after a loss spike. Standard practice: clip the global gradient norm to 1.0 during training. Essential for transformers, and still standard even with modern optimizers such as AdamW.
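In PyTorch this is one call to `clip_grad_norm_`, placed between `backward()` and `optimizer.step()`. A sketch with an artificially inflated loss to force a large gradient:

```python
import torch
from torch import nn

model = nn.Linear(512, 512)
x = torch.randn(8, 512)
loss = model(x).pow(2).sum() * 1e3   # artificially large loss -> large gradients
loss.backward()

# Rescale all gradients so the global norm is at most 1.0;
# returns the total norm as it was BEFORE clipping
total_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

clipped = torch.sqrt(sum(p.grad.pow(2).sum() for p in model.parameters()))
print(f"Norm before clipping: {total_norm:.1f}, after: {clipped:.4f}")
```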

Learning Rate Warmup

Start with a tiny learning rate, increase it gradually over the first few thousand steps (often 1–10K), then decay. Warmup stabilises early training while the loss landscape is still rough. Common schedule: linear warmup to the peak LR, then cosine decay.

Loss Spikes & Fixes

Sudden loss spike: the gradient norm exploded; reduce the learning rate or lower the gradient-clipping threshold. Loss plateau: training is stuck at a saddle point or the learning rate is too small; increase the LR or change the optimizer. Slow convergence: the batch size is too small or the learning rate too low.

⚠️ Common failure mode: Large-scale models can diverge during training without gradient clipping + warmup. Always use both. If loss spikes even with clipping, reduce learning rate 10x and retrain from checkpoint.
07 — Ecosystem

Tools & Frameworks

PyTorch

Dynamic computation graphs, autograd, GPU support. Industry standard for research and production LLMs.

JAX

Functional array library with autograd. Good for custom kernels and XLA compilation. Used at DeepMind.

Flax

JAX-based neural network library. Explicit parameter handling. Growing in LLM space.

TensorFlow

Static/eager graph modes. Less popular for LLMs; strong for production serving.

Keras

High-level API. Good for quick prototyping; less control than raw TensorFlow/PyTorch.

TorchViz

Visualize PyTorch computation graphs. Debugging tool for understanding forward/backward pass.

Weights & Biases

Experiment tracking, hyperparameter logging, model versioning. Standard for ML teams.

Captum

Model interpretability for PyTorch. Understand which inputs/neurons matter for predictions.

08 — Further Reading

References
