Training with bfloat16/float16 precision reduces GPU memory by ~50% and speeds up training 1.5–2× while maintaining model quality via fp32 master weights.
Full-precision (float32) uses 4 bytes per value. Half-precision (float16/bfloat16) uses 2 bytes. Halving the memory footprint of activations and gradients lets you: use 2× larger batch sizes, train larger models, or run on smaller GPUs.
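The byte counts above make the savings easy to estimate. A minimal sketch (the 7B-parameter model size is an illustrative assumption, and this counts parameters only, not activations, gradients, or optimizer state):

```python
# Memory needed just to hold model parameters at different precisions.
N_PARAMS = 7_000_000_000  # hypothetical 7B-parameter model
BYTES_PER_VALUE = {"float32": 4, "float16": 2, "bfloat16": 2}

def param_memory_gib(n_params: int, dtype: str) -> float:
    """Gibibytes required to store n_params values of the given dtype."""
    return n_params * BYTES_PER_VALUE[dtype] / 2**30

for dtype in BYTES_PER_VALUE:
    print(f"{dtype}: {param_memory_gib(N_PARAMS, dtype):.1f} GiB")
# float32 needs ~26.1 GiB; either 16-bit format needs ~13.0 GiB
```

The same halving applies to activations and gradients, which is where the batch-size headroom comes from.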
| | float32 | float16 | bfloat16 |
|---|---|---|---|
| Bits | 32 | 16 | 16 |
| Exponent bits | 8 | 5 | 8 |
| Mantissa bits | 23 | 10 | 7 |
| Max value | 3.4e38 | 65,504 | 3.4e38 |
| Overflow risk | None | High (values can exceed 65,504) | None |
| Precision | High | Medium | Low |
| Best for | Optimizer states | Inference on older GPUs | Training on A100/H100 |
float16 can't represent values much smaller than ~6e-5 (its smallest normal number) without losing precision; subnormals bottom out near 6e-8. Small gradients underflow to zero, so parameters stop updating. Loss scaling fixes this by multiplying the loss before the backward pass, then dividing the gradients by the same factor before the optimizer step.
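The mechanics can be demonstrated with NumPy: a gradient of 1e-8 underflows to zero in fp16, but multiplying by a scale factor first keeps it representable, and dividing back in fp32 recovers the true value (a toy sketch; 2^16 is a typical initial loss scale):

```python
import numpy as np

SCALE = 2.0 ** 16  # typical initial loss scale (65536)

grad = 1e-8                              # a tiny true gradient value
plain = np.float16(grad)                 # underflows: becomes 0.0
scaled = np.float16(grad * SCALE)        # 6.5536e-4 survives in fp16
recovered = np.float32(scaled) / SCALE   # unscale in fp32: ~1e-8 again

print(plain, recovered)  # 0.0 vs ~1e-8
```

Without the scale, the update is silently lost; with it, the gradient round-trips through fp16 with only a small rounding error.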
Mixed precision is well supported on modern GPUs: NVIDIA (CUDA, cuDNN), AMD (ROCm), and Intel (Intel Extension for PyTorch). TPUs support bfloat16 natively and are excellent for large-scale training. Most PyTorch training code can enable mixed precision with a few lines via `torch.cuda.amp.autocast()`; in JAX, the usual approach is to cast parameters and activations to `jnp.bfloat16` (or use a mixed-precision helper library).
Framework support: PyTorch (via AMP), TensorFlow (mixed_precision.Policy), JAX (default if you avoid float64), and HuggingFace Transformers (via Trainer API). Pre-trained models are often published in float32, so you're responsible for converting and validating before production.
Loss scaling prevents gradient underflow in FP16 but adds tuning complexity; start with init_scale=65536 and let PyTorch's dynamic loss scaling adjust. Always validate numerical stability: log gradient norms, check for NaN/Inf, and compare a few steps of mixed-precision training against full precision. Some operations (softmax, layer norm) are sensitive to precision; keeping them in float32 is a safe practice, and PyTorch's autocast does this automatically for numerically sensitive ops.
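Putting the pieces together, a minimal PyTorch training-loop sketch with autocast plus dynamic loss scaling (the linear model and data are toy stand-ins; on a machine without CUDA the `enabled` flags turn autocast and the scaler into no-ops, so the loop degrades gracefully to fp32):

```python
import torch
from torch import nn

torch.manual_seed(0)
use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"

model = nn.Linear(16, 1).to(device)                     # toy model
opt = torch.optim.SGD(model.parameters(), lr=0.1)
# Dynamic loss scaling, starting at the suggested init_scale.
scaler = torch.cuda.amp.GradScaler(init_scale=65536.0, enabled=use_cuda)

x = torch.randn(8, 16, device=device)
y = torch.randn(8, 1, device=device)

for _ in range(3):
    opt.zero_grad(set_to_none=True)
    # Forward/backward in fp16 where safe; autocast keeps
    # precision-sensitive ops (softmax, norms) in fp32.
    with torch.cuda.amp.autocast(enabled=use_cuda):
        loss = nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()  # multiply loss by the current scale
    scaler.step(opt)               # unscale grads; skip step on inf/NaN
    scaler.update()                # grow/shrink the scale dynamically
```

`scaler.step` silently skips the optimizer step when it detects inf/NaN gradients and `scaler.update` then halves the scale, which is exactly the dynamic adjustment described above.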
| Common Issue | Cause | Fix |
|---|---|---|
| NaN gradients | Loss overflow in FP16 | Increase loss scale or use dynamic scaling |
| Slower convergence | Gradient variance too high | Use FP32 for layer norm, softmax |
| Training instability | Learning rate too high relative to precision | Reduce LR or use gradient clipping |
| Model divergence mid-training | Accumulated rounding errors | Checkpoint in float32, resume, continue in FP16 |
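Two of the fixes in the table (gradient clipping and dynamic scaling) combine in one standard pattern: unscale the gradients before clipping, so the clip threshold applies to true fp32 magnitudes rather than scaled ones. A minimal sketch with a toy model (the scaler disables itself on CPU-only machines):

```python
import torch
from torch import nn

torch.manual_seed(0)
use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"

model = nn.Linear(4, 1).to(device)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

x = torch.randn(2, 4, device=device)
y = torch.randn(2, 1, device=device)

with torch.cuda.amp.autocast(enabled=use_cuda):
    loss = nn.functional.mse_loss(model(x), y)
scaler.scale(loss).backward()
scaler.unscale_(opt)  # bring grads back to their true scale first
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(opt)      # knows grads were already unscaled
scaler.update()
```

Clipping before `unscale_` would compare the (scaled) gradients against `max_norm` and effectively never clip.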
Advanced mixed-precision patterns: Some practitioners use triple-precision: FP32 for forward pass (stability), FP16 gradients (memory), and FP32 weight updates (precision). Others use layer-wise precision: low-rank layers in FP16, high-rank layers in FP32. AutoFP16 frameworks automatically select precision per operation based on numerical sensitivity analysis. These advanced techniques squeeze the last few percentage points of performance but add significant complexity and debugging burden.
Production deployment often uses 8-bit quantization post-training: train in mixed FP16/FP32, then quantize the final model to int8 for inference. This gives you the best of both worlds: training stability and deployment efficiency. Tools like ONNX Runtime, TVM, and TensorRT handle mixed-precision inference, abstracting away low-level complexity. For edge devices, 4-bit and 2-bit quantization are gaining traction, though quality degrades if the model wasn't trained with quantization in mind.
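As a sketch of the post-training quantization path, PyTorch's dynamic quantization converts `nn.Linear` weights to int8 in a single call (a minimal CPU example; production pipelines typically go through ONNX Runtime or TensorRT as noted above):

```python
import torch
from torch import nn

# A toy fp32 model standing in for a trained network.
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))

# Post-training dynamic quantization: Linear weights are stored as
# int8; activations are quantized on the fly at inference time.
qmodel = torch.ao.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)

out = qmodel(torch.randn(1, 64))  # inference runs with int8 weights
```

Dynamic quantization needs no calibration data, which makes it the lowest-effort entry point; static and quantization-aware approaches trade more effort for better accuracy at lower bit widths.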
Mixed precision in different frameworks: PyTorch's automatic mixed precision (AMP) is the most mature and easiest to adopt. TensorFlow's mixed_precision API is similar but requires explicit policy setup. JAX requires more manual control but is very flexible. HuggingFace Transformers abstracts the complexity away: just set `fp16=True` or `bf16=True` in the Trainer config. For production code, using framework defaults is recommended; only customize if you hit specific issues.
Hardware considerations: NVIDIA's Tensor Cores (A100, H100) excel at mixed precision; AMD's MI300 also supports it well. Older hardware (V100, T4) supports FP16 but may have lower throughput. Intel's Gaudi supports bfloat16 natively. If you're buying GPUs for mixed-precision training, A100 or H100 are the safe choices. For inference on edge (mobile, embedded), consider quantization (4-bit, 8-bit) alongside mixed precision for maximum efficiency.
Cost impact: training with mixed precision is about 30–50% faster than float32 and uses half the memory. For a 10-day training run, this translates to 3–5 days saved and proportional cost savings. For inference, benefits are smaller but still meaningful (20–30% speedup on token generation). For large-scale deployments, mixed precision is financially justifiable.
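Back-of-the-envelope, reading "30–50% faster" as a 30–50% reduction in wall-clock time:

```python
def days_saved(baseline_days: float, time_reduction: float) -> float:
    """Days saved if wall-clock time drops by `time_reduction`
    (e.g. 0.3 means the run finishes in 30% less time)."""
    return baseline_days * time_reduction

# A 10-day fp32 run at the savings quoted above:
print(days_saved(10, 0.3))  # 3.0 days saved
print(days_saved(10, 0.5))  # 5.0 days saved
```

At typical cloud GPU rates the same fraction comes straight off the compute bill, which is why the change pays for its engineering cost on long runs.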
Mixed precision in distributed training: When training across multiple GPUs (data parallel, distributed data parallel), mixed precision behavior is consistent: gradients are accumulated in float32, weights updated in float32, but forward/backward passes use FP16. Distributed setups add complexity: ensure all GPUs use the same loss scaling strategy, otherwise training diverges. Most frameworks handle this automatically, but it's worth double-checking in logs.
Gradient accumulation: if your batch size is limited by GPU memory, use gradient accumulation (compute gradients on small batches, accumulate them, update weights once every N batches). Combined with mixed precision, this allows an effective batch size of 512+ on modest GPUs. Ensure loss scaling is applied correctly with accumulated gradients; frameworks like PyTorch handle this transparently.
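A sketch of gradient accumulation combined with AMP (toy model; `ACCUM_STEPS = 4` is illustrative). The loss is divided by the accumulation count so the summed gradients match one large batch, and the scaler only steps and updates on real optimizer steps:

```python
import torch
from torch import nn

torch.manual_seed(0)
use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"
ACCUM_STEPS = 4  # effective batch = micro-batch size * ACCUM_STEPS

model = nn.Linear(8, 1).to(device)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

opt.zero_grad(set_to_none=True)
for step in range(8):
    x = torch.randn(2, 8, device=device)   # small micro-batch
    y = torch.randn(2, 1, device=device)
    with torch.cuda.amp.autocast(enabled=use_cuda):
        # Divide so accumulated grads average over the micro-batches.
        loss = nn.functional.mse_loss(model(x), y) / ACCUM_STEPS
    scaler.scale(loss).backward()          # scaled grads accumulate
    if (step + 1) % ACCUM_STEPS == 0:
        scaler.step(opt)                   # one update per ACCUM_STEPS
        scaler.update()
        opt.zero_grad(set_to_none=True)
```

Because `scaler.update()` runs only once per effective batch, the dynamic loss scale evolves exactly as it would with real large batches.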
Checkpoint and recovery: when saving checkpoints for resumption, save the optimizer state in float32 (not the accumulated FP16 state) to ensure correct recovery. Tools like FSDP (fully sharded data parallel) handle this automatically. For long training runs, periodic checkpoints in mixed precision are essential for fault tolerance.
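A minimal checkpoint/restore sketch for an AMP run: saving the model, the fp32 optimizer state, and the scaler state together means dynamic loss scaling resumes where it left off (the filename is illustrative):

```python
import torch
from torch import nn

model = nn.Linear(4, 1)  # fp32 master weights
opt = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

ckpt = {
    "model": model.state_dict(),      # fp32 master weights
    "optimizer": opt.state_dict(),    # fp32 optimizer state (momentum etc.)
    "scaler": scaler.state_dict(),    # current loss-scale state
}
torch.save(ckpt, "checkpoint.pt")

# On resume, restore all three so training continues consistently.
restored = torch.load("checkpoint.pt")
model.load_state_dict(restored["model"])
opt.load_state_dict(restored["optimizer"])
scaler.load_state_dict(restored["scaler"])
```

Restoring the scaler matters: resuming with a fresh default scale can trigger a burst of skipped steps (or NaNs) right after recovery.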
Mixed precision across different hardware: TPU (Google Cloud) handles mixed precision exceptionally well and is often the target hardware for large-scale training. NVIDIA's approach (Tensor Cores, cuDNN) is well-optimized. AMD's ROCm implementation is improving but still behind NVIDIA. Intel's Gaudi supports it but is less common. For teams deploying internationally, ensure your mixed precision strategy works across target hardware. If training on TPU but running inference on NVIDIA GPUs, ensure numerical compatibility (same precision formats, same libraries). Mixed precision isn't a free lunch: it requires careful implementation, validation, and monitoring. But for large models and long training runs, the gains (30–50% faster, 2× lower memory) justify the effort.