
Mixed Precision Training

Training with bfloat16/float16 precision reduces GPU memory by ~50% and speeds up training 1.5–2× while maintaining model quality via fp32 master weights.

2× speed boost | 50% memory savings | bf16 recommended


SECTION 01

Why Mixed Precision?

Full-precision (float32) uses 4 bytes per value. Half-precision (float16/bfloat16) uses 2 bytes. Halving the memory footprint of activations and gradients lets you use 2× larger batch sizes, train larger models, or run on smaller GPUs.
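As a rough back-of-the-envelope sketch of where the savings come from (the per-token activation constant below is an illustrative assumption, not a figure from this article):

```python
# Rough activation-memory estimate for one transformer forward pass.
# Assumption: ~12 * hidden_size activation values are kept per token per layer;
# the true constant depends on the architecture and attention implementation.
def activation_memory_gb(batch, seq_len, hidden, layers, bytes_per_value):
    values = batch * seq_len * layers * 12 * hidden
    return values * bytes_per_value / 1e9

cfg = dict(batch=8, seq_len=2048, hidden=4096, layers=32)
print(f"fp32 activations: {activation_memory_gb(**cfg, bytes_per_value=4):.0f} GB")
print(f"bf16 activations: {activation_memory_gb(**cfg, bytes_per_value=2):.0f} GB")  # ~half
```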

SECTION 02

float16 vs bfloat16

|  | float32 | float16 | bfloat16 |
|---|---|---|---|
| Bits | 32 | 16 | 16 |
| Exponent bits | 8 | 5 | 8 |
| Mantissa bits | 23 | 10 | 7 |
| Max value | 3.4e38 | 65,504 | 3.4e38 |
| Overflow risk | None | High (gradients often > 65504) | None |
| Precision | High | Medium | Low |
| Best for | Optimizer states | Inference on older GPUs | Training on A100/H100 |
Use bfloat16 for training. bfloat16 has the same range as float32, so gradients don't overflow. float16's narrower range means it needs loss scaling to keep small gradients from underflowing and to back off when values overflow. On A100/H100/TPU, always use bfloat16.
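A quick way to see the range difference in practice (a small illustrative demo, not from the original text):

```python
import torch

# float16 overflows just past its max value of 65,504;
# bfloat16 shares float32's exponent range, so the same value stays finite.
big = torch.tensor(70000.0)
print(big.to(torch.float16))    # inf (overflow)
print(big.to(torch.bfloat16))   # 70144. (coarse rounding, but finite)

# float16 also underflows on tiny values that bfloat16 can still represent.
tiny = torch.tensor(1e-8)
print(tiny.to(torch.float16))   # 0. (underflow)
print(tiny.to(torch.bfloat16))  # ~1e-08
```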
SECTION 03

PyTorch AMP

```python
import torch
from torch.amp import autocast, GradScaler

device = "cuda"
model = model.to(device)  # model, criterion, and dataloader assumed defined
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# GradScaler is only needed for float16 (not bfloat16)
scaler = GradScaler(device)  # for float16

for batch in dataloader:
    inputs = batch["input_ids"].to(device)
    labels = batch["labels"].to(device)
    optimizer.zero_grad()

    # Autocast: automatically runs compute-heavy ops in bf16/fp16
    with autocast(device_type="cuda", dtype=torch.bfloat16):  # or torch.float16
        outputs = model(inputs)
        loss = criterion(outputs, labels)

    # For bfloat16: standard backward pass (no scaling needed)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

    # For float16 (with loss scaling), replace the three lines above with:
    # scaler.scale(loss).backward()
    # scaler.unscale_(optimizer)
    # torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    # scaler.step(optimizer)
    # scaler.update()
```
SECTION 04

Loss Scaling

float16 cannot represent values much smaller than ~6e-5 (its smallest normal value), so small gradients underflow to zero and the affected parameters stop updating. Loss scaling fixes this by multiplying the loss by a large factor before the backward pass, then dividing the gradients by the same factor before the optimizer step.
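The mechanism itself is simple. Here is a minimal manual sketch, reusing the names from the training loop above (PyTorch's GradScaler, shown next, adds dynamic scale adjustment and overflow skipping on top of this):

```python
scale = 2.0 ** 16  # fixed scale factor, for illustration only

with autocast(device_type="cuda", dtype=torch.float16):
    loss = criterion(model(inputs), labels)

(loss * scale).backward()          # scaled gradients stay above the fp16 underflow threshold

for p in model.parameters():       # unscale before clipping and stepping
    if p.grad is not None:
        p.grad.div_(scale)

torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
```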

```python
import torch
from torch.amp import GradScaler

scaler = GradScaler(
    "cuda",
    init_scale=2**16,       # start with scale factor 65536
    growth_factor=2.0,      # double the scale after growth_interval steps without overflow
    backoff_factor=0.5,     # halve the scale when an overflow is detected
    growth_interval=2000,   # steps between scale increases
    enabled=True,           # set False to disable (e.g. when using bfloat16)
)

# Training loop with float16
loss = compute_loss(model, batch)
scaler.scale(loss).backward()            # backward pass with the scaled loss
scaler.unscale_(optimizer)               # unscale gradients before clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)                   # skips the step if inf/nan gradients are detected
scaler.update()                          # adjusts the scale for the next iteration

# Monitor the scale to detect issues
print(f"Loss scale: {scaler.get_scale():.0f}")
# Should stay high (>= 1024). If it keeps dropping, gradients are unstable.
```
SECTION 05

HuggingFace Integration

```python
from transformers import TrainingArguments, Trainer

# bf16 training: one flag
args = TrainingArguments(
    output_dir="./output",
    bf16=True,                       # use bfloat16 (recommended for A100/H100)
    # fp16=True,                     # use float16 instead (older GPUs)
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    num_train_epochs=3,
)
trainer = Trainer(model=model, args=args, ...)
trainer.train()

# Also useful: bf16_full_eval runs evaluation in bf16 too
args = TrainingArguments(
    bf16=True,
    bf16_full_eval=True,  # speeds up evaluation
)

# Accelerate for more control
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="bf16")
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
# From here on, casting and scaling are handled automatically
```
SECTION 06

Common Issues

```python
import torch
import torch.nn as nn

# Issue 1: NaN loss with float16
# Cause: gradient overflow (gradients > 65504)
# Fix: switch to bfloat16, lower the loss scale, or clip gradients

# Issue 2: NaN loss with bfloat16
# Cause: numerical instability in layer norm at reduced precision
# Fix: run the layer norm computation in float32
class StableLayerNorm(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.ln = nn.LayerNorm(d_model)

    def forward(self, x):
        # Upcast to float32 for LayerNorm, then cast back
        return self.ln(x.float()).to(x.dtype)

# Issue 3: Loss not improving with mixed precision
# Cause: parameters/optimizer accidentally cast to fp16
# Fix: verify no parameter is fp16 (fp32, or bf16 in pure-bf16 setups, is expected)
for group in optimizer.param_groups:
    for p in group["params"]:
        assert p.dtype in (torch.bfloat16, torch.float32)

# Issue 4: Out of memory despite using bf16
# Cause: activation memory still dominated by large batch size / sequence length
# Fix: gradient checkpointing + smaller batch size + gradient accumulation
```
SECTION 07

Hardware Support and Framework Integration

Mixed precision is well supported on modern GPUs: NVIDIA (CUDA, cuDNN), AMD (ROCm), and Intel (Intel Extension for PyTorch). TPUs support bfloat16 natively and are excellent for large-scale training. Most PyTorch training code can enable mixed precision with a few lines via `torch.amp.autocast()`; in JAX, the usual approach is to cast parameters and activations to bfloat16.

Framework support: PyTorch (via AMP), TensorFlow (`mixed_precision.Policy`), JAX (by explicitly casting to bfloat16, e.g. with the jmp library), and HuggingFace Transformers (via the Trainer API). Pre-trained models are often published in float32, so you're responsible for converting and validating before production.
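For example, a brief sketch of the non-PyTorch paths (the TensorFlow/Keras policy call and JAX tree-cast below are standard APIs, but verify them against your installed versions; `params` is assumed to be an existing pytree of fp32 arrays):

```python
# TensorFlow / Keras: set a global mixed-precision policy
import tensorflow as tf
tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")  # or "mixed_float16"

# JAX: mixed precision is usually done by casting params/activations yourself
import jax
import jax.numpy as jnp
params_bf16 = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
```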

Common Pitfalls and Best Practices

Loss scaling prevents gradient underflow in FP16 but adds tuning complexity; start with init_scale=65536 and let PyTorch's dynamic loss scaling adjust it. Always validate numerical stability: log gradient norms, check for NaN/Inf, and compare a few steps of mixed-precision training against full precision. Some operations (softmax, layer norm) are sensitive to precision; running them in float32 is a safe practice.
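A small monitoring helper along those lines (a sketch, not from the original; call it after backward() every N steps):

```python
import math
import torch

def check_gradients(model, step):
    """Log the global gradient norm and flag any NaN/Inf gradients."""
    total_sq = 0.0
    for name, p in model.named_parameters():
        if p.grad is None:
            continue
        if not torch.isfinite(p.grad).all():
            print(f"step {step}: non-finite gradient in {name}")
        total_sq += p.grad.float().pow(2).sum().item()
    grad_norm = math.sqrt(total_sq)
    print(f"step {step}: grad norm = {grad_norm:.4f}")
    return grad_norm
```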

| Common issue | Cause | Fix |
|---|---|---|
| NaN gradients | Loss overflow in FP16 | Lower the loss scale or use dynamic scaling |
| Slower convergence | Gradient variance too high | Use FP32 for layer norm and softmax |
| Training instability | Learning rate too high relative to precision | Reduce the LR or use gradient clipping |
| Model divergence mid-training | Accumulated rounding errors | Checkpoint in float32, resume, continue in FP16 |

Advanced mixed-precision patterns: some practitioners use a three-way split, running the forward pass in FP32 where stability matters, keeping FP16 gradients for memory, and applying weight updates in FP32 for precision. Others use layer-wise precision, keeping robust layers in FP16 and numerically sensitive layers in FP32. Automatic mixed-precision frameworks select a precision per operation based on numerical sensitivity analysis. These advanced techniques squeeze out the last few percentage points of performance but add significant complexity and debugging burden.

Production deployment often uses 8-bit quantization post-training: train in mixed FP16/FP32, then quantize the final model to int8 for inference. This gives you the best of both worlds: training stability and deployment efficiency. Tools like ONNX Runtime, TVM, and TensorRT handle mixed-precision inference, abstracting away low-level complexity. For edge devices, 4-bit and 2-bit quantization are gaining traction, though quality degrades if the model wasn't trained with quantization in mind.
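As one concrete post-training option, here is a sketch using PyTorch's built-in dynamic quantization (TensorRT and ONNX Runtime have their own flows; the model and file names are illustrative):

```python
import torch
import torch.nn as nn

# model: a trained float32/bf16 model (assumed defined); dynamic quantization runs on CPU
model = model.cpu().float().eval()
quantized = torch.quantization.quantize_dynamic(
    model,
    {nn.Linear},        # quantize Linear layer weights to int8
    dtype=torch.qint8,
)
torch.save(quantized.state_dict(), "model_int8.pt")
```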

Mixed precision in different frameworks: PyTorch's automatic mixed precision (AMP) is the most mature and easiest to adopt. TensorFlow's mixed_precision API is similar but requires explicit policy setup. JAX requires more manual control but is very flexible. HuggingFace Transformers abstracts the complexity away: just set `fp16=True` or `bf16=True` in the Trainer config. For production code, using framework defaults is recommended; only customize if you hit specific issues.

Hardware considerations: NVIDIA's Tensor Cores (A100, H100) excel at mixed precision; AMD's MI300 also supports it well. Older hardware (V100, T4) supports FP16 but may have lower throughput. Intel's Gaudi supports bfloat16 natively. If you're buying GPUs for mixed-precision training, A100 or H100 are the safe choices. For inference on edge (mobile, embedded), consider quantization (4-bit, 8-bit) alongside mixed precision for maximum efficiency.

Cost impact: training with mixed precision is about 30–50% faster than float32 and uses roughly half the memory. For a 10-day training run, that translates to 3–5 days saved and proportional cost savings. For inference, the benefits are smaller but still meaningful (20–30% speedup on token generation). For large-scale deployments, mixed precision is financially justifiable.

Mixed precision in distributed training: when training across multiple GPUs (data parallel, distributed data parallel), mixed-precision behavior is consistent: gradients are reduced and accumulated in float32, weights are updated in float32, and the forward/backward passes use FP16. Distributed setups add complexity: make sure all GPUs use the same loss-scaling strategy, otherwise training can diverge. Most frameworks handle this automatically, but it's worth double-checking in the logs.
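A minimal DDP sketch with AMP, assuming the usual torchrun-style launch and the model/criterion/dataloader names from earlier:

```python
import os
import torch
import torch.distributed as dist
from torch.amp import autocast, GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = DDP(model.to(local_rank), device_ids=[local_rank])
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = GradScaler("cuda")   # same dynamic-scaling configuration on every rank

for batch in dataloader:
    inputs = batch["input_ids"].to(local_rank)
    labels = batch["labels"].to(local_rank)
    optimizer.zero_grad()
    with autocast(device_type="cuda", dtype=torch.float16):
        loss = criterion(model(inputs), labels)
    scaler.scale(loss).backward()   # DDP all-reduces the (scaled) gradients
    scaler.step(optimizer)
    scaler.update()
```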

Gradient accumulation: if your batch size is limited by GPU memory, use gradient accumulation (compute gradients on small batches, accumulate them, update weights once every N batches). Combined with mixed precision, this allows effective batch size of 512+ on modest GPUs. Ensure loss scaling is applied correctly with accumulated gradients; frameworks like PyTorch handle this transparently.
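A sketch of gradient accumulation combined with float16 loss scaling (accumulation_steps is illustrative; the other names continue the setup above):

```python
accumulation_steps = 8   # effective batch size = per-device batch size * 8

optimizer.zero_grad()
for step, batch in enumerate(dataloader):
    inputs = batch["input_ids"].to(device)
    labels = batch["labels"].to(device)

    with autocast(device_type="cuda", dtype=torch.float16):
        loss = criterion(model(inputs), labels) / accumulation_steps  # average over micro-batches

    scaler.scale(loss).backward()      # gradients accumulate across micro-batches

    if (step + 1) % accumulation_steps == 0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)         # one optimizer step per accumulation cycle
        scaler.update()
        optimizer.zero_grad()
```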

Checkpoint and recovery: when saving checkpoints for resumption, save the optimizer state in float32 (not an accumulated FP16 state) to ensure correct recovery. Tools like FSDP (fully sharded data parallel) handle this automatically. For long training runs, periodic checkpointing is essential for fault tolerance; see the sketch below.
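A minimal save/resume sketch for the AMP setup above (the file name and step counter are illustrative):

```python
# Save: fp32 model/optimizer state plus the GradScaler state
torch.save(
    {
        "step": step,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),   # Adam moments stay in fp32
        "scaler": scaler.state_dict(),         # current loss scale and growth tracker
    },
    "checkpoint.pt",
)

# Resume
ckpt = torch.load("checkpoint.pt", map_location="cpu")
model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
scaler.load_state_dict(ckpt["scaler"])
step = ckpt["step"]
```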

Mixed precision across different hardware: TPUs (Google Cloud) handle mixed precision exceptionally well and are often the target hardware for large-scale training. NVIDIA's stack (Tensor Cores, cuDNN) is well-optimized. AMD's ROCm implementation is improving but still behind NVIDIA's, and Intel's Gaudi supports it but is less common. For teams deploying across vendors, make sure your mixed-precision strategy works on all target hardware; if you train on TPU but run inference on NVIDIA GPUs, check numerical compatibility (matching precision formats and validated outputs).

Mixed precision isn't a free lunch: it requires careful implementation, validation, and monitoring. But for large models and long training runs, the gains (30–50% faster training, roughly half the memory) justify the effort.