Optimization

Weight Initialization

Xavier, Kaiming, and scaled initialization strategies that prevent vanishing/exploding gradients and allow deep networks to train from scratch.

Xavier → sigmoid/tanh · Kaiming → ReLU · σ=0.02 → Transformers


SECTION 01

Why Init Matters

Bad initialization causes vanishing or exploding activations on the very first forward pass — before any training occurs. If activations collapse to zero or blow up to infinity in layer 1, gradients through the whole network are broken.

```python
import torch
import torch.nn as nn

# Bad: uniform random init with a large scale; activations explode
bad_model = nn.Sequential(*[nn.Linear(512, 512) for _ in range(10)])
for layer in bad_model:
    nn.init.uniform_(layer.weight, -1, 1)  # Large random init in every layer

x = torch.randn(1, 512)
for layer in bad_model:
    x = layer(x)
    print(f"Activation std: {x.std().item():.4f}")
# Output: 1.2, 7.8, 45.1, 312.5 ... → explodes!
```
Problem: Each linear layer multiplies input by W. If W has std > 1/√n, activations grow exponentially with depth. If std < 1/√n, they shrink to zero.
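The exponential growth rate follows from the variance of a matrix-vector product; a sketch of the usual i.i.d. argument:

```latex
% One linear layer y = W x, with W_{ij} \sim \mathcal{N}(0, \sigma^2) i.i.d.
% and the x_j i.i.d. with variance \operatorname{Var}(x):
\operatorname{Var}(y_i) = \sum_{j=1}^{n} \operatorname{Var}(W_{ij}\, x_j) = n\, \sigma^2 \operatorname{Var}(x)
% Stacking L such layers:
\operatorname{Var}\big(y^{(L)}\big) \approx \big(n\, \sigma^2\big)^{L} \operatorname{Var}\big(x^{(0)}\big)
% Stable iff n \sigma^2 = 1, i.e. \sigma = 1/\sqrt{n}.
```

Any σ above 1/√n compounds into exponential growth with depth L; any σ below it compounds into exponential decay.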
SECTION 02

Xavier / Glorot Init

```python
import torch
import torch.nn as nn

# Xavier initialization: designed for sigmoid/tanh activations
# W ~ Uniform(-√(6/(fan_in + fan_out)), +√(6/(fan_in + fan_out)))
# or W ~ N(0, 2/(fan_in + fan_out))
# Goal: keep activation variance ~1.0 across layers
linear = nn.Linear(512, 512)
nn.init.xavier_uniform_(linear.weight)  # Uniform version
nn.init.xavier_normal_(linear.weight)   # Normal version (usually similar)
nn.init.zeros_(linear.bias)             # Bias: always init to 0

# Verify activations stay stable through 20 tanh layers
model = nn.Sequential(*[
    nn.Sequential(nn.Linear(512, 512, bias=False), nn.Tanh())
    for _ in range(20)
])
for m in model.modules():
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)

x = torch.randn(1, 512)
for layer in model:
    x = layer(x)
    print(f"Activation std: {x.std().item():.4f}")
# Output: ~0.9, ~0.85, ~0.8 ... stays bounded ✓
```
SECTION 03

Kaiming / He Init

```python
import torch
import torch.nn as nn

# Kaiming initialization: designed for ReLU activations
# ReLU zeroes half the neurons, so compensate for the variance loss
# W ~ N(0, 2/fan_in)  ← the "2" compensates for ReLU killing half
# or W ~ N(0, 2/fan_out) depending on propagation direction
linear = nn.Linear(512, 512)

# mode="fan_in" (default): forward-pass variance preserved
nn.init.kaiming_normal_(linear.weight, mode="fan_in", nonlinearity="relu")
# mode="fan_out": backward-pass variance preserved
nn.init.kaiming_normal_(linear.weight, mode="fan_out", nonlinearity="relu")

# PyTorch's default for Linear is actually Kaiming uniform
# (so you often don't need to manually init Linear layers used with ReLU)

# Verify
model = nn.Sequential(*[
    nn.Sequential(nn.Linear(512, 512, bias=False), nn.ReLU())
    for _ in range(20)
])
for m in model.modules():
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, nonlinearity="relu")

x = torch.randn(1, 512)
for layer in model:
    x = layer(x)
    print(f"Activation std: {x.std().item():.4f}")
# ~0.9 throughout ✓
```
SECTION 04

Transformer Init

```python
import math
import torch
import torch.nn as nn

# GPT-2-style initialization:
# - Linear layers: N(0, 0.02) standard
# - Residual projections: N(0, 0.02 / sqrt(2 * n_layers)) to prevent the
#   residual stream growing with depth

class GPTLinear(nn.Linear):
    """Linear layer with GPT-2-style initialization."""
    def __init__(self, in_features, out_features, n_layers=12, is_residual=False):
        super().__init__(in_features, out_features, bias=True)
        std = 0.02 / math.sqrt(2 * n_layers) if is_residual else 0.02
        nn.init.normal_(self.weight, mean=0.0, std=std)
        nn.init.zeros_(self.bias)

# Embedding init
embedding = nn.Embedding(50257, 768)  # vocab_size, d_model
nn.init.normal_(embedding.weight, mean=0.0, std=0.02)

# In practice, HuggingFace handles all of this for you:
from transformers import GPT2Config, GPT2LMHeadModel
config = GPT2Config(n_embd=768, n_layer=12, n_head=12)
model = GPT2LMHeadModel(config)  # Weights already initialized correctly
```
SECTION 05

Practical Recipes

| Architecture | Activation | Weight Init |
|---|---|---|
| Transformer (GPT-style) | GELU/SwiGLU | Normal(0, 0.02), residual layers ÷ √(2L) |
| CNN (ResNet-style) | ReLU | Kaiming Normal, mode=fan_out |
| LSTM/RNN | Tanh | Orthogonal init for recurrent weights |
| MLP with tanh | Tanh/Sigmoid | Xavier Normal |
| Embedding layers | N/A | Normal(0, 0.02) or Uniform(-0.1, 0.1) |
| Output head (new task) | — | Normal(0, 0.02), bias=0 |
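The recipes above can be folded into a single `apply`-style helper that dispatches on module type. A minimal sketch; the choice of `n_layers` and the GPT-style default for `Linear` are illustrative assumptions, not the only valid recipe:

```python
import torch
import torch.nn as nn

def init_weights(module):
    """Dispatch initialization by module type, following the recipe table."""
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)  # GPT-style default
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
    elif isinstance(module, nn.LSTM):
        for name, param in module.named_parameters():
            if "weight_hh" in name:        # recurrent weights: orthogonal
                nn.init.orthogonal_(param)
            elif "weight_ih" in name:
                nn.init.xavier_normal_(param)
            elif "bias" in name:
                nn.init.zeros_(param)

model = nn.Sequential(
    nn.Embedding(100, 64),
    nn.Linear(64, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
)
model.apply(init_weights)  # applies recursively to every submodule
```

`Module.apply` visits containers too (`nn.Sequential`, `nn.ReLU`); the `isinstance` dispatch simply ignores anything without a matching rule.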
SECTION 06

Debugging Init

```python
import torch
import torch.nn as nn

def check_init(model):
    """Print activation stats after init to catch vanishing/exploding."""
    x = torch.randn(8, 512)
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf modules only
            x = module(x)
            print(f"{name}: std={x.std():.4f}, mean={x.mean():.4f}")

# Red flags:
# - std > 100 or std < 0.01 → init is wrong
# - NaN on first forward pass → init too large + overflow

# Fix pattern: detect and re-init suspicious parameters
def fix_init(model):
    for name, param in model.named_parameters():
        std = param.data.std().item()
        if std > 1.0 or std < 1e-4:
            print(f"Warning: {name} has std={std:.4f}")
            nn.init.normal_(param.data, 0, 0.02)  # Re-init
```
Quickstart: Using HuggingFace models for fine-tuning? Don't touch initialization — it's already correct. Only worry about init when building a custom architecture from scratch.
SECTION 07

Initialization in modern architectures

Transformers don't strictly follow the He or Xavier rules. Instead, they typically use a small normal initialization for most weights plus scaled initialization on the residual path: GPT-2 uses N(0, 0.02), and BERT uses a truncated normal with std=0.02. Residual output projections are often scaled down by 1/√(2L), with L the number of layers, so the residual stream doesn't grow with depth; attention logits are separately scaled by 1/√(d_head) in the forward pass to prevent score inflation. Layer norms always start at identity (gamma=1, beta=0).
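The identity start for layer norms is PyTorch's default, which is easy to verify:

```python
import torch
import torch.nn as nn

ln = nn.LayerNorm(768)
print(ln.weight.min().item(), ln.weight.max().item())  # gamma: all 1.0
print(ln.bias.abs().max().item())                      # beta: all 0.0

# At init, LayerNorm therefore only standardizes its input (no scale/shift):
x = torch.randn(4, 768)
out = ln(x)
print(out.mean(dim=-1).abs().max().item())  # per-row mean ~0
```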

Output embeddings (in language models) are sometimes tied to input embeddings, which means initialization of one affects the other. Care is needed to avoid double-initialization bugs. Residual connections reduce the importance of intermediate initialization because gradients skip layers, but the first layer and the final output layer still matter.
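A minimal sketch of weight tying and the ordering it forces on initialization (the variable names are illustrative):

```python
import torch
import torch.nn as nn

vocab_size, d_model = 1000, 64
embedding = nn.Embedding(vocab_size, d_model)
lm_head = nn.Linear(d_model, vocab_size, bias=False)

nn.init.normal_(embedding.weight, mean=0.0, std=0.02)
lm_head.weight = embedding.weight  # tie: both modules now share one tensor

# Re-initializing either module after tying silently re-initializes both —
# the classic double-initialization bug. Init first, tie last, and guard:
assert lm_head.weight.data_ptr() == embedding.weight.data_ptr()
```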

Debugging initialization issues

If your model trains poorly from scratch, suspicious initialization is often the culprit. Check: (1) activation distribution at layer 0 (should be ~N(0,1) pre-activation), (2) gradient magnitudes across layers (should be roughly equal, not decaying or exploding), (3) loss trajectory in the first 10 steps (should show steady decrease, not blow up or stall). A simple audit: initialize the model, run one batch, and print layer-wise activation/gradient statistics.
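The gradient side of that audit can be sketched with a single backward pass on a dummy loss (the model and shapes here are illustrative):

```python
import torch
import torch.nn as nn

model = nn.Sequential(*[
    nn.Sequential(nn.Linear(128, 128), nn.ReLU()) for _ in range(8)
])
x = torch.randn(16, 128)
loss = model(x).pow(2).mean()  # dummy loss, just to populate gradients
loss.backward()

# Per-layer gradient norms should be roughly the same order of magnitude;
# a steady decay (or growth) across depth points at the initialization.
for name, param in model.named_parameters():
    print(f"{name}: grad_norm={param.grad.norm().item():.2e}")
```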

| Red Flag | Likely Cause | Fix |
|---|---|---|
| Activation scale explodes (>1000) in early layers | Weights too large | Scale init by 1/√(fan_in) |
| Gradients vanish (magnitude <1e-5) at deep layers | Poor residual flow or bad layer norm init | Ensure residuals are present, check gamma/beta |
| Loss stays flat for 1000 steps, then crashes | Initialization fine but learning rate too high or unstable architecture | Reduce LR, check residual scales |
| Model underfits but single-layer baseline overfits | Deeper model's init may favor a solution space that's too narrow | Increase init scale slightly or use better parameterization |

Initialization in practice: PyTorch's `nn.Linear` defaults to Kaiming uniform initialization; TensorFlow's Dense layer uses Glorot uniform. For most modern networks (ResNets, Transformers, Vision Transformers), these defaults work well. But custom layers, such as attention heads, gating mechanisms, and adaptive pools, often need tuning. When in doubt, start with Xavier/Glorot; if training with ReLU-family activations is unstable, try He initialization. For residual branches, scaling the branch's output initialization down (e.g. by 1/√(2L)) helps keep the residual stream stable, since each skip connection adds the branch's output on top of the identity path.

The initialization lottery is real: two randomly initialized networks with the same architecture, data, and hyperparameters may converge to different solutions with different generalization. Good initialization reduces variance but doesn't eliminate it. For research, reporting multiple runs from different random seeds is standard; for production, ensemble diverse initializations to improve robustness. Some practitioners use "warm start" initialization: copy weights from a similar pre-trained model rather than random init, which almost always converges faster.

Initialization and generalization: Recent research suggests initialization affects not just convergence speed but also generalization. Models initialized with larger weights sometimes converge faster but generalize worse (higher test error). Initialization also interacts with regularization: with strong regularization (high L2 penalty), careful initialization matters less. This suggests the best initialization strategy depends on your specific problem: easy tasks need less tuning, hard tasks need more.

Modern initialization in practice: most practitioners use framework defaults and adjust only if training is visibly unstable. A simple sanity check: plot the activation distribution (mean and std) per layer after one forward pass. Ideal: activations have mean ~0 and std ~1 pre-activation; post-ReLU, expect the std to drop to roughly 1/√2 of that, since ReLU zeroes half the variance (exactly the loss Kaiming's factor of 2 compensates for). If you see explosions (std > 5) or attenuation (std < 0.1), re-initialize.

For very deep networks (50+ layers), initialization is critical. Techniques like careful residual scaling, layer-wise adaptive initialization, and skip-connection initialization become necessary. Vision Transformers and large language models use careful initialization because their depth and width require it. Smaller networks (< 10 layers) are forgiving and work with any reasonable initialization.

Initialization for transfer learning: When fine-tuning pre-trained models, you start with learned (non-random) weights. New layers (added on top of pre-trained backbone) need initialization. Best practice: use small random initialization for new layers so they don't interfere with pre-trained weights. If fine-tuning is slow, try slightly larger initialization or lower learning rates for new layers relative to pre-trained weights.
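One common pattern, sketched with a stand-in backbone (the module names and learning rates are illustrative assumptions): small-init the new head and give each part its own learning rate via optimizer parameter groups.

```python
import torch
import torch.nn as nn

backbone = nn.Sequential(nn.Linear(768, 768), nn.GELU())  # stand-in for a pre-trained model
head = nn.Linear(768, 10)                                 # new task head

nn.init.normal_(head.weight, mean=0.0, std=0.02)  # small init: don't disturb the backbone
nn.init.zeros_(head.bias)

optimizer = torch.optim.AdamW([
    {"params": backbone.parameters(), "lr": 1e-5},  # gentle on pre-trained weights
    {"params": head.parameters(),     "lr": 1e-3},  # faster on the fresh head
])
```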

Adapter modules (like LoRA) are initialized to near-zero to ensure backward compatibility: at the start of fine-tuning, adapter weights contribute almost nothing, and the model behaves like the original. As training progresses, adapter weights grow. This initialization strategy (called "zero initialization" or "identity initialization") is crucial for stable fine-tuning.
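A minimal LoRA-style adapter illustrating the zero-init trick. This is a sketch of the idea, not the `peft` library's implementation; `rank` and `alpha` are illustrative defaults:

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Frozen base linear plus a low-rank adapter that starts at zero."""
    def __init__(self, base: nn.Linear, rank=8, alpha=16):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # freeze the pre-trained weights
        self.A = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)  # small random
        self.B = nn.Parameter(torch.zeros(base.out_features, rank))        # zero → A·B starts at 0
        self.scale = alpha / rank

    def forward(self, x):
        return self.base(x) + (x @ self.A.T @ self.B.T) * self.scale

base = nn.Linear(64, 64)
lora = LoRALinear(base)
x = torch.randn(2, 64)
# At init the adapter contributes exactly zero: output matches the base layer.
print(torch.allclose(lora(x), base(x)))  # True
```

Because B is all zeros, the product A·B (and hence the adapter's output) is exactly zero at step 0, so the wrapped model is bit-identical to the original; gradients still flow into both A and B.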

Initialization for pruning: if you plan to prune weights after training, some research suggests initializing differently (larger variance to create more variation) helps pruning find better masks. However, this is active research and not yet standard practice. For now, stick with framework defaults unless you're specifically optimizing for pruning efficiency.