Neural Networks

LayerNorm / BatchNorm

Normalization layers that stabilize and speed up training. LayerNorm is standard in transformers; BatchNorm is common in CNNs.

LayerNorm: Transformers
BatchNorm: CNNs/vision
RMSNorm: LLaMA/Mistral


SECTION 01

Why Normalize?

As weights update, the distribution of activations inside a deep network shifts. The original BatchNorm paper called this "internal covariate shift." Later layers receive inputs whose mean and variance keep changing, so they must constantly re-adapt. Normalization layers re-standardize activations to roughly zero mean and unit variance and then apply a learned scale and shift. This permits higher learning rates and makes much deeper networks trainable.

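Empirical impact: Normalization is what lets deep transformers train stably at practical learning rates, often substantially higher ones than an unnormalized network could tolerate. Without it (or a specialized initialization scheme), deep stacks of layers tend to diverge or train very slowly. For deep networks it is effectively not optional.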
SECTION 02

BatchNorm Explained

import torch
import torch.nn as nn

# BatchNorm normalizes each feature/channel using statistics computed over the BATCH
# For input (batch, features): stats per feature over the batch (dim 0)
# For input (batch, C, H, W): stats per channel over (batch, H, W)
bn1d = nn.BatchNorm1d(num_features=512)  # For (batch, features)
bn2d = nn.BatchNorm2d(num_features=64)   # For (batch, C, H, W), typical in CNNs

# What it does (training mode):
# 1. Compute per-feature mean and variance over the batch
# 2. Normalize: x_hat = (x - mean) / sqrt(var + eps)
# 3. Learnable scale/shift: y = gamma * x_hat + beta
x = torch.randn(32, 512)  # batch=32, features=512
out = bn1d(x)
# Each of the 512 features is normalized to mean≈0, std≈1 (over the batch)

images = torch.randn(8, 64, 28, 28)  # batch=8, channels=64, 28x28
out2d = bn2d(images)                 # each of the 64 channels normalized over (batch, H, W)

# Problems with BatchNorm for LLMs:
# - Needs fairly large batches (e.g. >16) for stable batch statistics
# - Padding in variable-length sequences contaminates the batch statistics
# - Different behavior in training vs. inference (inference uses running stats)
# → That's why transformers use LayerNorm instead
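The train/test difference can be checked directly. This is a minimal sketch that uses a toy 4-feature layer and synthetic data, both arbitrary choices for illustration. In training mode, the output matches a manual per-feature normalization that uses the batch's biased variance. The same forward pass also updates the running statistics, and eval mode then normalizes with those running estimates instead.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(num_features=4)  # toy layer for illustration
x = torch.randn(16, 4) * 3 + 5       # batch of 16; features far from mean 0 / std 1

# Training mode: normalize with the statistics of THIS batch
bn.train()
out_train = bn(x)
mean = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)   # normalization uses the biased variance
manual = (x - mean) / torch.sqrt(var + bn.eps) * bn.weight + bn.bias
print(torch.allclose(out_train, manual, atol=1e-5))

# That forward pass also updated the running estimates (momentum=0.1 by default)
print(bn.running_mean)

# Eval mode: normalize with running_mean / running_var instead of batch statistics
bn.eval()
out_eval = bn(x)
print(torch.allclose(out_train, out_eval, atol=1e-3))  # False: running stats are still far from the batch stats
```

Running statistics only approach the true data statistics after many training steps. This is why a model evaluated before `.eval()` is called (or with too few warm-up batches) behaves differently from what training suggested.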
SECTION 03

LayerNorm Explained

import torch
import torch.nn as nn

# LayerNorm normalizes over the FEATURE dimension (each token independently)
# For input (batch, seq_len, d_model): normalize each token's feature vector
ln = nn.LayerNorm(normalized_shape=768)
x = torch.randn(32, 512, 768)  # (batch, seq_len, d_model)
out = ln(x)
# Each of the 32*512=16384 token vectors is normalized to mean≈0, std≈1

# What it does per token:
# mean = x.mean(dim=-1, keepdim=True)                # mean over 768 features
# var = x.var(dim=-1, keepdim=True, unbiased=False)  # biased variance over 768 features
# x_hat = (x - mean) / torch.sqrt(var + 1e-5)        # eps goes inside the sqrt
# y = gamma * x_hat + beta                           # learnable scale/shift (768-dim each)

# Advantages over BatchNorm for LLMs:
# - Works with any batch size (even 1)
# - Same behavior in training and inference
# - Works with variable-length sequences
# - Stabilizes training of deep transformers
print(ln.weight.shape)  # torch.Size([768]) (gamma)
print(ln.bias.shape)    # torch.Size([768]) (beta)
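Here is a small sketch of the batch-size point, using toy sizes chosen for illustration. A sample's LayerNorm output does not depend on which other samples share its batch. BatchNorm's output in training mode does, and BatchNorm cannot normalize a training batch of one.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
ln = nn.LayerNorm(8)
bn = nn.BatchNorm1d(8)     # training mode by default
batch = torch.randn(4, 8)  # 4 samples, 8 features

# LayerNorm: sample 0 gets the same output alone or inside the batch
ln_alone = ln(batch[:1])
ln_in_batch = ln(batch)[:1]
print(torch.allclose(ln_alone, ln_in_batch, atol=1e-6))

# BatchNorm: sample 0's output depends on which other samples share its batch
bn_pair = bn(batch[:2])[:1]
bn_full = bn(batch)[:1]
print(torch.allclose(bn_pair, bn_full, atol=1e-3))

# BatchNorm cannot compute batch statistics from a single training sample
try:
    bn(batch[:1])
except ValueError as err:
    print("BatchNorm, batch of 1:", err)
```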
SECTION 04

RMSNorm in Modern LLMs

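Root Mean Square Normalization (RMSNorm) drops LayerNorm's mean-centering step and its β shift, and rescales by the root mean square of the activations instead. It is simpler and cheaper to compute, though the exact speedup depends on the model and kernel. It is used in LLaMA, Mistral, and Gemma.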

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """RMSNorm in the style of LLaMA-2/3, Mistral, Qwen (Gemma uses a close variant)."""
    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))  # learnable scale (no bias)
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute RMS over the feature dimension (no mean subtraction)
        rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return self.weight * (x / rms)

# Why RMSNorm over LayerNorm?
# - Drops the mean subtraction and the bias: fewer ops, typically faster
# - Empirically similar quality to LayerNorm
# - One statistic instead of two (implementations such as Hugging Face's
#   LlamaRMSNorm still upcast to float32 to compute it)

# Hugging Face transformers already uses RMSNorm for LLaMA models:
from transformers import LlamaModel
# LlamaModel.norm (the final norm) and each decoder layer's input_layernorm
# are LlamaRMSNorm instances
SECTION 05

Pre-Norm vs Post-Norm

import torch
import torch.nn as nn

# Post-norm (original transformer): normalize AFTER the residual add
#   x = LayerNorm(x + Attn(x))
#   Prone to gradient problems in deep stacks
# Pre-norm (modern standard): normalize the sub-layer INPUT, keep the residual path clean
#   x = x + Attn(LayerNorm(x))
#   More stable gradients → used in GPT-2/3, LLaMA, and most modern LLMs

class TransformerBlock_PreNorm(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        # LayerNorm here; LLaMA-style blocks use RMSNorm (nn.RMSNorm in torch >= 2.4) for both norms
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model))

    def forward(self, x):
        # Pre-norm: normalize BEFORE attention and FFN
        h = self.norm1(x)  # compute the norm once and reuse it for Q, K, V
        x = x + self.attn(h, h, h, need_weights=False)[0]
        x = x + self.ffn(self.norm2(x))
        return x
Use pre-norm. GPT-2 moved layer norm to the input of each sub-block, and nearly every modern LLM uses pre-norm. Post-norm typically needs careful learning-rate warmup to train at all. Pre-norm is far less sensitive to warmup and trains stably at depth.
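For contrast with the pre-norm block, here is a minimal post-norm block in the layout of the original transformer. This is an illustrative sketch: the class name, the toy sizes, and the GELU activation are choices made here to mirror the pre-norm code, whereas the original paper used ReLU. Each sub-layer's output is added to the residual stream first, and the sum is then normalized.

```python
import torch
import torch.nn as nn

class TransformerBlock_PostNorm(nn.Module):
    """Post-norm block (original-transformer layout), for comparison only."""
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model))
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Post-norm: add the residual first, THEN normalize the sum
        x = self.norm1(x + self.attn(x, x, x, need_weights=False)[0])
        x = self.norm2(x + self.ffn(x))
        return x

block = TransformerBlock_PostNorm(d_model=64, n_heads=4)
out = block(torch.randn(2, 10, 64))  # (batch, seq_len, d_model)
print(out.shape)  # torch.Size([2, 10, 64])
```

In post-norm, every layer's output passes through a LayerNorm, so there is no purely additive path from the input to the output. In the pre-norm block, the residual stream itself is never normalized inside the block, which is the usual explanation for its better-behaved gradients at depth.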
SECTION 06

Implementation Gotchas

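import torch
import torch.nn as nn

# Gotcha 1: eps matters in low precision
ln = nn.LayerNorm(768, eps=1e-5)            # PyTorch default; fine for fp32
ln_small_eps = nn.LayerNorm(768, eps=1e-6)
# A smaller eps protects less when the variance is tiny. In fp16, 1e-6 is a subnormal
# (smallest normal ≈ 6e-5), so if stats are computed in fp16 rather than upcast to fp32,
# prefer 1e-5

# Gotcha 2: Don't apply weight decay to norm parameters (or biases)
def get_param_groups(model, weight_decay=0.1):
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # Norm weights/biases and Linear biases are 1-D; matching on "norm" alone misses
        # norms registered under generic names (e.g. "1.weight" in nn.Sequential)
        if param.ndim < 2 or "norm" in name:
            no_decay.append(param)
        else:
            decay.append(param)
    return [{"params": decay, "weight_decay": weight_decay},
            {"params": no_decay, "weight_decay": 0.0}]

model = nn.Sequential(nn.Linear(768, 768), nn.LayerNorm(768))  # toy model for illustration
optimizer = torch.optim.AdamW(get_param_groups(model), lr=1e-4)

# Gotcha 3: Speed. nn.LayerNorm already runs a fused kernel on GPU; in torch 2.x,
# torch.compile can additionally fuse it with neighboring elementwise ops
ln_compiled = torch.compile(nn.LayerNorm(768))
# Third-party packages (e.g. flash-attn) also ship fused LayerNorm/RMSNorm kernels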
SECTION 07

Normalization Comparison

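| Norm Type | Normalizes over | Learnable params (γ, β) | Use Case |
|---|---|---|---|
| BatchNorm | Batch (and spatial) dimensions, per channel | Per channel/feature | CNNs, older architectures |
| LayerNorm | Feature dimension, per sample/token | Per feature | Transformers, RNNs |
| GroupNorm | Groups of channels, per sample | Per channel | Small batch sizes |
| RMSNorm | Feature dimension, no mean-centering | Per feature (γ only, no β) | Modern LLMs (efficient) |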
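To make the table's reduction axes concrete, this sketch standardizes a (batch, channels, H, W) tensor by hand over the axes each layer uses. It then compares the result with PyTorch's built-in layers. The tensor shape and group count are arbitrary choices for illustration, and the comparison relies on γ = 1 and β = 0 at initialization.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 6, 4, 4)  # (batch N, channels C, H, W)
eps = 1e-5                   # PyTorch default for all three layers below

def standardize(t, dims):
    # Zero mean / unit variance over `dims`, using the biased variance as the layers do
    mean = t.mean(dim=dims, keepdim=True)
    var = t.var(dim=dims, keepdim=True, unbiased=False)
    return (t - mean) / torch.sqrt(var + eps)

# BatchNorm2d (training mode): per channel, over (N, H, W)
bn_out = nn.BatchNorm2d(6)(x)
print(torch.allclose(bn_out, standardize(x, (0, 2, 3)), atol=1e-5))

# LayerNorm over (C, H, W): per sample
ln_out = nn.LayerNorm([6, 4, 4])(x)
print(torch.allclose(ln_out, standardize(x, (1, 2, 3)), atol=1e-5))

# GroupNorm with 2 groups of 3 channels: per sample, per group
gn_out = nn.GroupNorm(num_groups=2, num_channels=6)(x)
gn_manual = standardize(x.view(8, 2, 3, 4, 4), (2, 3, 4)).view(8, 6, 4, 4)
print(torch.allclose(gn_out, gn_manual, atol=1e-5))
```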
SECTION 08

Normalization Numerical Stability

import torch
import torch.nn as nn

# LayerNorm numerical stability
x = torch.randn(32, 512)
ln = nn.LayerNorm(512, eps=1e-6)

# Internally, LayerNorm does:
# y = (x - mean) / sqrt(var + eps) * weight + bias
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
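normalized = (x - mean) / torch.sqrt(var + 1e-6)
out = normalized * ln.weight + ln.bias  # weight=1, bias=0 at initialization

# The manual result matches the built-in layer (up to float rounding)
print(torch.allclose(out, ln(x), atol=1e-5))  # True

# Why eps? It prevents division by zero (or huge values) when var is tiny
# Typical eps: 1e-5 (PyTorch default) to 1e-6; if statistics are computed in fp16
# rather than upcast to fp32, the larger value is the safer choice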

Layer normalization in transformers: The original "Attention Is All You Need" paper used post-norm, which normalizes after the residual addition. Later research found pre-norm, which normalizes before the sublayer, more stable to train. Pre-norm applies layer norm to the input of the self-attention or feed-forward sublayer, then adds the sublayer's output to the unnormalized input. The residual path therefore remains a plain identity connection from the bottom of the network to the top. This simple change markedly improves training stability, typically with little or no loss in final quality.

RMSNorm, used in LLaMA and Gemma models, simplifies LayerNorm by removing the mean-centering step and the bias term β. It only rescales by the RMS (root mean square) of the activations, followed by a learned per-feature gain. This reduces computation and often matches or beats LayerNorm in transformer contexts. The two are closely related. When the input already has zero mean across features, the mean of squares equals the variance, so RMSNorm produces the same output as LayerNorm without β. How much the re-centering step matters outside that case is still studied empirically.
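A minimal check of that relationship follows. The helper `rms_norm` is written here for illustration with γ fixed at 1, and LayerNorm is built without affine parameters so that only the normalization itself is compared.

```python
import torch
import torch.nn as nn

def rms_norm(x, eps=1e-5):
    # RMSNorm without the learnable gain (gamma = 1)
    return x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

torch.manual_seed(0)
ln = nn.LayerNorm(16, elementwise_affine=False, eps=1e-5)  # no gamma / beta

x = torch.randn(4, 16) + 3.0                   # inputs with a clearly nonzero mean
x_centered = x - x.mean(dim=-1, keepdim=True)  # same inputs, re-centered to zero mean

print(torch.allclose(rms_norm(x), ln(x), atol=1e-5))                    # False
print(torch.allclose(rms_norm(x_centered), ln(x_centered), atol=1e-5))  # True
```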

Initialization and learning rates interact with normalization layers in subtle ways. Pre-norm models (with LayerNorm or RMSNorm) can often use higher learning rates safely, because normalization keeps each sublayer's inputs at a controlled scale. When debugging training instability, check three things: that the norm sits where intended (pre- vs. post-norm), that eps suits the precision in use, and that the weight initialization is compatible with the norm variant.

EXTRA

Advanced Normalization Techniques

Beyond the standard layers, InstanceNorm normalizes each channel of each sample over its spatial dimensions rather than over the batch. This helps in domains such as style transfer, where individual sample statistics matter more than the batch distribution. GroupNorm divides channels into groups and normalizes within each group, per sample, which keeps it stable with small batches. These layers sit on one spectrum. GroupNorm with a single group containing all channels normalizes like LayerNorm over (C, H, W), and GroupNorm with one group per channel normalizes like InstanceNorm.
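A quick check of these relationships, with tensor sizes chosen arbitrarily for illustration. The equivalences hold at initialization (γ = 1, β = 0). Once trained, the affine parameters differ in shape: GroupNorm learns one γ, β pair per channel, while LayerNorm over (C, H, W) learns one per element.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 8, 5, 5)  # (batch, channels=8, H, W)

# One group holding all channels == LayerNorm over (C, H, W)
gn_one_group = nn.GroupNorm(num_groups=1, num_channels=8)
ln_chw = nn.LayerNorm([8, 5, 5])
print(torch.allclose(gn_one_group(x), ln_chw(x), atol=1e-5))

# One group per channel == InstanceNorm (per sample, per channel, over H, W)
gn_per_channel = nn.GroupNorm(num_groups=8, num_channels=8)
inorm = nn.InstanceNorm2d(8)  # affine=False by default
print(torch.allclose(gn_per_channel(x), inorm(x), atol=1e-5))
```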

Normalization placement matters a great deal. In deep networks, pre-norm architectures (norm before the sublayer) are more stable than post-norm architectures (norm after the residual addition). This seemingly small detail helped make very deep transformers trainable without elaborate initialization schemes. Insights like this came largely from empirical research on architecture design.

Numerical precision interacts with normalization. In bfloat16, computing the mean and variance directly in low precision can lose accuracy. PyTorch's built-in norms handle this: LayerNorm kernels accumulate their statistics in float32 for half-precision inputs, and autocast runs layer_norm in float32. Hand-written norms, such as a custom RMSNorm, have to upcast explicitly. Knowing where these conversions happen prevents silent numerical degradation in mixed-precision training.
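Here is a minimal sketch of explicit upcasting in a hand-written norm. It follows the spirit of Hugging Face's LlamaRMSNorm but is not copied from it. The function name `rms_norm_fp32_stats`, the eps value, and the test data are assumptions for illustration. Statistics are computed in float32, and only the result is cast back to bfloat16.

```python
import torch

def rms_norm_fp32_stats(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Hypothetical helper: RMSNorm whose statistics are computed in float32."""
    input_dtype = x.dtype
    x32 = x.to(torch.float32)  # upcast before squaring and averaging
    rms = torch.sqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
    out32 = weight.to(torch.float32) * (x32 / rms)
    return out32.to(input_dtype)  # cast back so the next layer still sees bfloat16

torch.manual_seed(0)
x = (torch.randn(2, 4096) * 50).to(torch.bfloat16)  # large-magnitude bf16 activations
w = torch.ones(4096, dtype=torch.bfloat16)

out = rms_norm_fp32_stats(x, w)
ref = rms_norm_fp32_stats(x.float(), w.float())  # all-float32 reference

print(out.dtype)                               # torch.bfloat16
print((out.float() - ref).abs().max().item())  # only the final bf16 rounding remains
```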

In very deep networks (100+ layers), small changes in activation scale compound across layers, so layer norm's placement matters more than in shallow networks. Pre-norm stabilizes deep networks by normalizing each sublayer's input while leaving the residual path untouched, which keeps gradient flow well-behaved. Post-norm normalizes after each residual addition and becomes increasingly hard to train as depth grows. BERT and the original GPT kept the post-norm layout. GPT-2 moved to pre-norm, and most later large language models followed.

In mixed-precision training (float32 weights and loss, bfloat16 activations), layer norm's stability is crucial. Centering and rescaling keep activations in reasonable ranges despite the reduced precision, but only if the norm's own statistics are accurate. That is why implementations compute those statistics in float32 internally. It is also why a custom norm layer must do so itself when the framework does not.

Distributed training introduces further layer norm considerations. Batch norm either computes statistics per device, which effectively shrinks the batch it normalizes over, or synchronizes them across devices (as PyTorch's nn.SyncBatchNorm does), which adds communication overhead. Layer norm computes statistics per sample, so it needs no cross-device communication. This is one practical advantage that transformers, which use layer norm, have over batch-norm CNNs when training at scale.
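In PyTorch, switching a batch-norm model to synchronized statistics is a module conversion. This sketch uses a hypothetical toy CNN. The converted layers need a distributed setup (typically GPUs under DistributedDataParallel) when they run in training mode, but the conversion itself runs anywhere. LayerNorm needs no equivalent.

```python
import torch.nn as nn

# Toy CNN for illustration
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
)

# Replace every BatchNorm*d layer with SyncBatchNorm, which all-reduces the
# batch statistics across processes during training
sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
print(type(sync_model[1]).__name__)  # SyncBatchNorm
```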

BOOST

Normalization and Quantization Interactions

Quantizing neural networks (converting float32 weights and activations to int8 or lower) requires careful handling of normalization layers. Layer norm and batch norm keep activations in predictable ranges. Before quantizing, practitioners measure activation ranges (min, max, percentiles) on calibration data to set quantization thresholds, so the statistical behavior of normalization layers directly affects quantization accuracy. Models without normalization can lose considerably more accuracy under quantization. Their activation ranges are less controlled, and a few large values can force a coarse quantization step on everything else.
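Here is a small sketch of range calibration. It uses synthetic data and a simple symmetric per-tensor int8 scheme rather than any specific toolkit's calibrator. The helper `int8_roundtrip`, the 99.9th-percentile threshold, and the injected outliers are all assumptions for illustration. The sketch compares min/max calibration with percentile calibration and measures error on the ordinary, non-outlier values.

```python
import torch

def int8_roundtrip(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Symmetric int8: quantize to [-127, 127], then dequantize back to float
    return torch.clamp(torch.round(x / scale), -127, 127) * scale

torch.manual_seed(0)
acts = torch.randn(10_000)   # stand-in for well-behaved (normalized) activations
acts_outliers = acts.clone()
acts_outliers[:10] = 200.0   # same values plus 10 extreme outliers

results = {}
for name, a in [("normalized", acts), ("outliers", acts_outliers)]:
    max_scale = a.abs().max().item() / 127.0                   # min/max calibration
    pct_scale = torch.quantile(a.abs(), 0.999).item() / 127.0  # 99.9th-percentile calibration
    bulk = a[10:]                                              # the ordinary values
    results[name] = {
        "max": (int8_roundtrip(bulk, max_scale) - bulk).abs().mean().item(),
        "pct": (int8_roundtrip(bulk, pct_scale) - bulk).abs().mean().item(),
    }
    print(name, results[name])
```

With well-behaved activations, the two calibrations give similar scales. A handful of outliers inflates the min/max scale and coarsens the quantization step for every other value. Percentile clipping avoids this at the cost of saturating the outliers.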

Post-training quantization (quantizing a pretrained model without retraining) relies on normalization to keep activations well-behaved. Quantization-aware training instead inserts simulated quantization into the forward pass during training, so the model learns to tolerate rounding before deployment. In both settings, normalization keeps activation ranges predictable, and together these techniques enable efficient on-device AI.
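Simulated quantization is commonly implemented as a quantize-dequantize step with a straight-through estimator. The following is a minimal sketch under several illustrative assumptions: the helper `fake_quant`, a fixed scale that assumes normalized activations mostly lie within ±4, and a toy LayerNorm + Linear model.

```python
import torch
import torch.nn as nn

def fake_quant(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Forward: int8 quantize-dequantize. Backward: identity (straight-through estimator),
    # because round() has zero gradient almost everywhere.
    q = torch.clamp(torch.round(x / scale), -127, 127) * scale
    return x + (q - x).detach()

torch.manual_seed(0)
norm = nn.LayerNorm(32)
proj = nn.Linear(32, 32)
x = torch.randn(8, 32)
scale = 4.0 / 127  # assumes |activations| mostly < 4 after LayerNorm

h = norm(x)                 # normalized activations: a predictable range
h_q = fake_quant(h, scale)  # what the next layer would see after int8 deployment
loss = proj(h_q).pow(2).mean()
loss.backward()
print(norm.weight.grad.abs().sum().item() > 0)  # gradients still reach the norm layer
```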

Knowledge distillation from large models to small quantized models balances several objectives: matching the teacher's outputs, staying quantization-friendly, and minimizing the task loss. Normalization layers help bridge float32 teacher models and quantized student models by keeping their activation ranges compatible.
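The output-matching and task-loss objectives are often combined in the standard soft-target distillation loss, which mixes a temperature-scaled KL term with cross-entropy. Below is a minimal sketch. The name `distillation_loss`, the temperature `T`, and the weight `alpha` are illustrative choices, and the sketch does not model the quantization-feasibility objective.

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    # Soft-target term: KL between temperature-softened teacher and student distributions,
    # scaled by T^2 to keep gradient magnitudes comparable across temperatures
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)
    # Hard-target term: ordinary task loss on the true labels
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1 - alpha) * hard

torch.manual_seed(0)
student = torch.randn(4, 10)  # toy logits: batch of 4, 10 classes
teacher = torch.randn(4, 10)
labels = torch.randint(0, 10, (4,))
print(distillation_loss(student, teacher, labels).item())
```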