Xavier, Kaiming, and scaled initialization strategies that prevent vanishing/exploding gradients and allow deep networks to train from scratch.
Bad initialization causes vanishing or exploding activations on the very first forward pass — before any training occurs. If activations collapse to zero or blow up to infinity in layer 1, gradients through the whole network are broken.
| Architecture | Activation | Weight Init |
|---|---|---|
| Transformer (GPT-style) | GELU/SwiGLU | Normal(0, 0.02), residual layers ÷ √(2L) |
| CNN (ResNet-style) | ReLU | Kaiming Normal, mode=fan_out |
| LSTM/RNN | Tanh | Orthogonal init for recurrent weights |
| MLP with tanh | Tanh/Sigmoid | Xavier Normal |
| Embedding layers | N/A | Normal(0, 0.02) or Uniform(-0.1, 0.1) |
| Output head (new task) | — | Normal(0, 0.02), bias=0 |
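The table above translates directly into a PyTorch init function applied via `model.apply`. A minimal sketch, assuming PyTorch; the layer types covered and the 0.02 std follow the table's conventions:

```python
import torch.nn as nn

def init_weights(module):
    """Apply the table's per-layer-type recommendations."""
    if isinstance(module, nn.Linear):
        # Transformer-style default: small normal init, zero bias
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Conv2d):
        # ResNet-style: Kaiming normal, fan_out mode for ReLU nets
        nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
    elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
        # Norm layers start at identity
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)

model = nn.Sequential(nn.Linear(64, 64), nn.LayerNorm(64), nn.Linear(64, 10))
model.apply(init_weights)
```

`apply` walks every submodule recursively, so one function covers the whole model regardless of nesting.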
Transformers don't strictly follow the He or Xavier rules. In practice, most weights get a small normal initialization and residual projections get a scaled one: GPT-2 uses N(0, 0.02²) with residual-output projections divided by √(2L) for L layers; BERT uses a truncated normal with std=0.02. (Attention dot-product scores are separately divided by √(d_head) at runtime to prevent score inflation; that is an architectural choice, not an initialization one.) Layer norms always start at identity (gamma=1, beta=0).
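A minimal GPT-2-style sketch of this scheme, assuming PyTorch; the `c_proj` name for residual-output projections follows GPT-2's convention, and `n_layer` is an assumed depth:

```python
import math
import torch.nn as nn

n_layer = 12  # transformer depth (assumed)

def gpt2_init(module, name=""):
    if isinstance(module, nn.Linear):
        std = 0.02
        # Residual-output projections are scaled down by sqrt(2 * n_layer)
        if name.endswith("c_proj"):
            std /= math.sqrt(2 * n_layer)
        nn.init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.LayerNorm):
        nn.init.ones_(module.weight)   # gamma = 1
        nn.init.zeros_(module.bias)    # beta = 0

block = nn.ModuleDict({
    "c_attn": nn.Linear(768, 3 * 768),  # QKV projection
    "c_proj": nn.Linear(768, 768),      # residual-output projection
    "ln": nn.LayerNorm(768),
})
for name, mod in block.named_modules():
    gpt2_init(mod, name)
```

The √(2L) divisor accounts for each layer contributing two residual additions (attention and MLP), keeping the residual stream's variance roughly constant with depth.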
Output embeddings (in language models) are sometimes tied to input embeddings, which means initialization of one affects the other. Care is needed to avoid double-initialization bugs. Residual connections reduce the importance of intermediate initialization because gradients skip layers, but the first layer and the final output layer still matter.
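Weight tying and the double-initialization hazard can be sketched as follows, assuming PyTorch; the key is that after tying, both modules share one tensor, so initialization must happen exactly once:

```python
import torch.nn as nn

vocab, d = 1000, 64
emb = nn.Embedding(vocab, d)
head = nn.Linear(d, vocab, bias=False)

# Tie: both modules now point at the same parameter tensor
head.weight = emb.weight

# Any later re-init of the head also rewrites the embedding,
# so initialize once, after tying:
nn.init.normal_(emb.weight, mean=0.0, std=0.02)

assert head.weight.data_ptr() == emb.weight.data_ptr()
```

A common bug is running a blanket `model.apply(init_fn)` after tying, which silently re-initializes the shared tensor twice; order the steps as above or guard the init function against shared parameters.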
If your model trains poorly from scratch, initialization is often the culprit. Check: (1) the activation distribution at layer 0 (should be ~N(0,1) pre-activation), (2) gradient magnitudes across layers (should be roughly equal, neither decaying nor exploding), (3) the loss trajectory over the first 10 steps (should decrease steadily, not blow up or stall). A simple audit: initialize the model, run one batch, and print layer-wise activation/gradient statistics.
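The audit can be sketched with forward hooks, assuming PyTorch; the toy 8-block MLP and He init here are illustrative stand-ins for your model:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
blocks = [nn.Sequential(nn.Linear(128, 128), nn.ReLU()) for _ in range(8)]
model = nn.Sequential(*blocks, nn.Linear(128, 10))

# He init so the audit shows a healthy baseline
for m in model.modules():
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
        nn.init.zeros_(m.bias)

# (1) record per-layer activation stats with forward hooks
stats = []
def record(mod, inp, out):
    stats.append((out.mean().item(), out.std().item()))
for m in model.modules():
    if isinstance(m, nn.Linear):
        m.register_forward_hook(record)

# (2) one batch forward + backward, then inspect gradient magnitudes
x = torch.randn(32, 128)
model(x).square().mean().backward()
grads = [p.grad.abs().mean().item() for p in model.parameters()]

for i, (mean, std) in enumerate(stats):
    print(f"linear {i}: act mean={mean:+.3f} std={std:.3f}")
```

With a healthy init the printed stds stay O(1) across depth; a steady decay or growth from layer to layer is the red flag described above.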
| Red Flag | Likely Cause | Fix |
|---|---|---|
| Activation scale explodes (>1000) in early layers | Weights too large | Scale init by 1/sqrt(fan_in) |
| Gradients vanish (magnitude <1e-5) at deep layers | Poor residual flow or bad layer norm init | Ensure residuals are present, check gamma/beta |
| Loss stays flat for 1000 steps, then crashes | Initialization fine but learning rate too high or unstable architecture | Reduce LR, check residual scales |
| Model underfits but single-layer baseline overfits | Deeper model's init may place it in a poorly conditioned region of parameter space | Increase init scale slightly or use a better parameterization |
Initialization in practice: PyTorch's `nn.Linear` defaults to a Kaiming-uniform variant (bounds ±1/√(fan_in)); TensorFlow's Dense layer uses Glorot uniform. For most modern networks (ResNets, Transformers, Vision Transformers), these defaults work well. But custom layers—attention heads, gating mechanisms, adaptive pools—often need tuning. When in doubt, start with Xavier/Glorot; if training a ReLU network is unstable, try He initialization. For the branch feeding into a skip connection, common practice is to scale the initialization down (e.g. by 1/√(2L), as in the transformer recipe above) rather than up, so the residual stream's variance doesn't grow with depth.
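Overriding the defaults is one call per scheme; a sketch, assuming PyTorch, with the variance formula each scheme targets noted inline:

```python
import math
import torch.nn as nn

tanh_layer = nn.Linear(256, 256)   # tanh/sigmoid net -> Xavier/Glorot
relu_layer = nn.Linear(256, 256)   # ReLU net -> He/Kaiming

# Xavier: Var(W) = 2 / (fan_in + fan_out), keeps signal variance
# stable in both directions for roughly-linear activations like tanh
nn.init.xavier_normal_(tanh_layer.weight)

# He: Var(W) = 2 / fan_in, the extra factor of 2 compensates
# for ReLU zeroing half the activations
nn.init.kaiming_normal_(relu_layer.weight, nonlinearity="relu")
```

The resulting weight stds are √(2/512) ≈ 0.063 and √(2/256) ≈ 0.088 respectively, which you can verify directly on the tensors.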
The initialization lottery is real: two randomly initialized networks with the same architecture, data, and hyperparameters may converge to different solutions with different generalization. Good initialization reduces variance but doesn't eliminate it. For research, reporting multiple runs from different random seeds is standard; for production, ensemble diverse initializations to improve robustness. Some practitioners use "warm start" initialization: copy weights from a similar pre-trained model rather than random init, which almost always converges faster.
Initialization and generalization: Recent research suggests initialization affects not just convergence speed but also generalization. Models initialized with larger weights sometimes converge faster but generalize worse (higher test error). Initialization also interacts with regularization: with strong regularization (high L2 penalty), careful initialization matters less. This suggests the best initialization strategy depends on your specific problem: easy tasks need less tuning, hard tasks need more.
Modern initialization in practice: Most practitioners use framework defaults and adjust only if training is visibly unstable. A simple sanity check: plot the activation distribution (mean and std) per layer after one forward pass. Ideal: activations have mean ~0 and std ~1 pre-activation; after a ReLU the std drops to roughly 1/√2 of that (half the units are zeroed), which is exactly what the He gain of √2 compensates for. If you see explosions (std > 5) or attenuation (std < 0.1), re-initialize.
For very deep networks (50+ layers), initialization is critical. Techniques like careful residual scaling, layer-wise adaptive initialization, and skip-connection initialization become necessary. Vision Transformers and large language models use careful initialization because their depth and width require it. Smaller networks (< 10 layers) are forgiving and work with any reasonable initialization.
Initialization for transfer learning: When fine-tuning pre-trained models, you start with learned (non-random) weights. New layers (added on top of pre-trained backbone) need initialization. Best practice: use small random initialization for new layers so they don't interfere with pre-trained weights. If fine-tuning is slow, try slightly larger initialization or lower learning rates for new layers relative to pre-trained weights.
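Both practices (small init for new layers, separate learning rates) are a few lines in PyTorch; a sketch, where the tiny `backbone` stands in for a real pre-trained model and the learning-rate values are illustrative:

```python
import torch
import torch.nn as nn

# Stand-in for a pre-trained backbone (weights would come from a checkpoint)
backbone = nn.Sequential(nn.Linear(128, 128), nn.ReLU())

# New head: small random init so it doesn't disrupt pre-trained features
head = nn.Linear(128, 5)
nn.init.normal_(head.weight, mean=0.0, std=0.02)
nn.init.zeros_(head.bias)

# Per-group learning rates: lower for pre-trained weights, higher for the new head
optimizer = torch.optim.AdamW([
    {"params": backbone.parameters(), "lr": 1e-5},
    {"params": head.parameters(), "lr": 1e-3},
])
```

The 100x gap between the two rates is a common starting point, not a rule; tune it based on how far the target task is from the pre-training distribution.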
Adapter modules (like LoRA) are initialized to near-zero to ensure backward compatibility: at the start of fine-tuning, adapter weights contribute almost nothing, and the model behaves like the original. As training progresses, adapter weights grow. This initialization strategy (called "zero initialization" or "identity initialization") is crucial for stable fine-tuning.
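The LoRA zero-init property is easy to demonstrate: initialize one of the two low-rank factors to zero, and the adapter's contribution is exactly zero at step 0. A minimal sketch, assuming PyTorch; the `r` and `alpha` defaults are typical LoRA settings, not prescribed values:

```python
import math
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Frozen base layer plus a low-rank adapter whose output starts at zero."""
    def __init__(self, base: nn.Linear, r=8, alpha=16):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False          # freeze the pre-trained weights
        self.A = nn.Parameter(torch.empty(r, base.in_features))
        self.B = nn.Parameter(torch.zeros(base.out_features, r))  # zero init
        nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))          # random init
        self.scale = alpha / r

    def forward(self, x):
        # B @ A is exactly zero at init, so the output equals the base output
        return self.base(x) + self.scale * (x @ self.A.T @ self.B.T)

base = nn.Linear(64, 64)
lora = LoRALinear(base)
x = torch.randn(4, 64)
assert torch.allclose(lora(x), base(x))  # identical behavior at initialization
```

Only B needs to be zero: with A random and B zero, gradients flow to both factors from the first step, whereas zeroing both would leave the adapter stuck at zero.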
Initialization for pruning: if you plan to prune weights after training, some research suggests initializing differently (larger variance to create more variation) helps pruning find better masks. However, this is active research and not yet standard practice. For now, stick with framework defaults unless you're specifically optimizing for pruning efficiency.