Foundations

Flash Attention

An IO-aware, fused-kernel attention implementation that eliminates the O(N²) memory bottleneck and delivers 2–4× speedups

O(N) Memory
2–4× Speedup
v3 Latest Version
1. The Attention Bottleneck

The standard attention mechanism has computational complexity O(N²), where N is the sequence length. The compute itself is manageable (GPUs are fast at matrix multiplication), but the memory bottleneck is severe. During attention computation, the algorithm materializes the full N×N attention matrix in high-bandwidth memory (HBM), requiring O(N²) memory. For a 4096-token sequence, each fp16 N×N attention matrix is 32MB; summed across heads and layers this adds up to gigabytes (and with 12,288-dimensional embeddings, Q, K, and V are another ~100MB each).

The real bottleneck isn't compute FLOPs; it's data movement. Modern GPUs have far more compute capacity than memory bandwidth: a V100 sustains roughly 15.7 TFLOPS of FP32 compute (over 100 TFLOPS on its FP16 tensor cores) but only ~900 GB/s of HBM bandwidth. Reading and writing data to HBM therefore dominates execution time. Standard attention requires multiple passes through HBM: reading Q, K, V; writing the attention scores; reading them back for softmax; writing and re-reading the probabilities; and finally writing the output.

Flash Attention solves this by using tiling and online softmax computation, eliminating the need to materialize the full N×N matrix. Instead, the algorithm processes attention in blocks, computing partial sums and max values on-the-fly. This dramatically reduces memory bandwidth requirements and allows sequences that were previously impossible to fit in GPU memory.

The Roofline Model Insight: GPU utilization is limited by arithmetic intensity (FLOPs per byte of memory traffic). Attention has low arithmetic intensity, making it memory-bound. Flash Attention improves arithmetic intensity through blocking and fusion, allowing better GPU utilization.
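The roofline argument can be made concrete with a back-of-the-envelope calculation. The sketch below (pure Python; the V100-class FLOPS and bandwidth figures are illustrative, and the 5-FLOPs-per-element softmax cost is a rough assumption) estimates arithmetic intensity for two stages of standard attention: the QKᵀ matmul sits above the machine balance (compute-bound), while the softmax over a materialized N×N matrix sits far below it (memory-bound), which is what drags the unfused pipeline down.

```python
def matmul_arithmetic_intensity(n: int, d: int, bytes_per_el: int = 2) -> float:
    """FLOPs per HBM byte for S = Q @ K^T with S written out to HBM:
    2*N^2*d FLOPs against reading Q, K and writing the N x N scores."""
    flops = 2 * n * n * d
    traffic = bytes_per_el * (2 * n * d + n * n)
    return flops / traffic

def softmax_arithmetic_intensity(n: int, bytes_per_el: int = 2,
                                 flops_per_el: int = 5) -> float:
    """FLOPs per HBM byte for a row softmax that reads the N x N score
    matrix from HBM and writes the probabilities back."""
    flops = flops_per_el * n * n          # max-subtract, exp, sum, divide
    traffic = 2 * bytes_per_el * n * n    # read scores, write probabilities
    return flops / traffic

# Illustrative V100-class machine balance: ~15.7 TFLOPS / ~900 GB/s
balance = 15.7e12 / 900e9                 # ~17 FLOPs must ship per HBM byte

print(f"QK^T matmul: {matmul_arithmetic_intensity(4096, 128):.0f} FLOPs/byte")
print(f"softmax:     {softmax_arithmetic_intensity(4096):.2f} FLOPs/byte")
print(f"balance:     {balance:.1f} FLOPs/byte")
```

Fusing the stages into one kernel, as Flash Attention does, keeps the N² intermediate in SRAM so the low-intensity elementwise work never pays the HBM toll.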

# Memory Requirements Comparison
# Sequence Length: N = 4096 tokens, Embedding Dim: d = 12288, fp16

Standard Attention Memory:
  Q: 4096 × 12288 × 2 bytes = 100MB
  K: 4096 × 12288 × 2 bytes = 100MB
  V: 4096 × 12288 × 2 bytes = 100MB
  Attention Matrix (QKᵀ): 4096 × 4096 × 2 bytes = 32MB
  Softmax output:         4096 × 4096 × 2 bytes = 32MB
  ────────────────────────────────────
  Total: ~365MB — the 64MB of N×N matrices are the part that grows as O(N²)

Flash Attention Memory:
  Q: 100MB (kept in HBM)
  K: 100MB (kept in HBM)
  V: 100MB (kept in HBM)
  Intermediate buffers: ~50MB
  ────────────────────────────────────
  Total: ~350MB — O(N) scaling, no N×N matrix materialized

For N = 32768: the two N×N matrices alone are ~4.3GB, pushing standard
attention to ~7GB, while Flash Attention stays around ~2.5GB
2. Flash Attention Algorithm

Flash Attention computes attention using a tiling strategy combined with online softmax. The core insight is that softmax can be computed incrementally without materializing the full score matrix. By processing attention in blocks (tiles), you can compute max and sum statistics on-the-fly and rescale previous outputs as new blocks arrive.

The algorithm divides the N tokens into blocks of size B_r (e.g., 128), and for each block of queries, iterates through blocks of keys. For each (Q_block, K_block, V_block) triple, it computes a partial attention matrix of size B_r × B_c, runs softmax on this smaller matrix, and accumulates the weighted output. The key technical detail is that when moving to the next KV block, the running softmax statistics (row maximum and normalizer) must be rescaled to account for the new score values.
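A minimal single-head NumPy sketch of this blocked loop (illustrative block sizes; no masking, dropout, or backward pass — a sketch of the arithmetic, not the CUDA kernel):

```python
import numpy as np

def flash_attention_forward(Q, K, V, Br=128, Bc=128):
    """Tiled attention with online softmax: for each query block, stream
    over KV blocks, tracking a running row max m and normalizer l."""
    N, d = Q.shape
    O = np.zeros((N, d))
    scale = 1.0 / np.sqrt(d)
    for i in range(0, N, Br):
        Qi = Q[i:i+Br]
        m = np.full(Qi.shape[0], -np.inf)   # running row maximum
        l = np.zeros(Qi.shape[0])           # running softmax denominator
        Oi = np.zeros((Qi.shape[0], d))     # unnormalized output accumulator
        for j in range(0, N, Bc):
            S = (Qi @ K[j:j+Bc].T) * scale      # Br x Bc partial scores
            m_new = np.maximum(m, S.max(axis=1))
            P = np.exp(S - m_new[:, None])      # stable partial exponentials
            alpha = np.exp(m - m_new)           # rescale factor for old stats
            l = alpha * l + P.sum(axis=1)
            Oi = alpha[:, None] * Oi + P @ V[j:j+Bc]
            m = m_new
        O[i:i+Br] = Oi / l[:, None]             # normalize once per Q block
    return O

# Matches the naive reference up to floating-point error:
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((256, 64)) for _ in range(3))
S = (Q @ K.T) / np.sqrt(64)
P = np.exp(S - S.max(axis=1, keepdims=True))
ref = (P / P.sum(axis=1, keepdims=True)) @ V
print(np.allclose(flash_attention_forward(Q, K, V, 64, 64), ref))  # True
```

In the real kernel, Qi, the current K/V blocks, and the running statistics all live in SRAM; only the final normalized output block is written back to HBM.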

# Using PyTorch 2.0+ Flash Attention
import torch
import torch.nn.functional as F

batch, seq_len, nheads, head_dim = 2, 4096, 16, 64

# scaled_dot_product_attention expects [batch, heads, seq, head_dim]
q = torch.randn(batch, nheads, seq_len, head_dim,
                device='cuda', dtype=torch.float16)
k = torch.randn(batch, nheads, seq_len, head_dim,
                device='cuda', dtype=torch.float16)
v = torch.randn(batch, nheads, seq_len, head_dim,
                device='cuda', dtype=torch.float16)

# Flash Attention via scaled_dot_product_attention
output = F.scaled_dot_product_attention(
    q, k, v,
    dropout_p=0.1,
    is_causal=True,  # for autoregressive models
)

# With the flash-attn library (fastest); note it expects the
# [batch, seq, heads, head_dim] layout instead
from flash_attn import flash_attn_func
output = flash_attn_func(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
    dropout_p=0.1, causal=True,
)

This online softmax computation is exact (up to floating-point rounding) thanks to the softmax recurrence property: when combining two blocks with different running maxima, the previous partial sums and outputs are rescaled by exp(m_old − m_new), producing the correct combined normalization factor. This lets Flash Attention maintain numerical stability while processing attention incrementally.
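The recurrence is easy to check numerically. Below, one softmax row is split into two blocks; merging the per-block (max, sum) statistics as m = max(m₁, m₂) and l = e^(m₁−m)·l₁ + e^(m₂−m)·l₂ reproduces the normalizer of the full row (pure Python, illustrative values):

```python
import math

scores = [1.0, 3.0, 2.0, 0.5]            # one softmax row, split in two blocks
block1, block2 = scores[:2], scores[2:]

def block_stats(block):
    m = max(block)                        # block max, for numerical stability
    l = sum(math.exp(s - m) for s in block)
    return m, l

m1, l1 = block_stats(block1)
m2, l2 = block_stats(block2)
m = max(m1, m2)
l = math.exp(m1 - m) * l1 + math.exp(m2 - m) * l2   # merged normalizer

# Identical to the normalizer computed over the full row:
m_full, l_full = block_stats(scores)
print(math.isclose(l, l_full))  # True
```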

Tiling Benefit: if you process Q in blocks of size B_r and K, V in blocks of size B_c, then Q and the output are read/written once, while K and V are re-read ⌈N/B_r⌉ times — once per query block. Total HBM traffic is O(N·d·⌈N/B_r⌉) instead of the O(N²) cost of materializing the score matrix, a significant reduction when blocks are sized to fill SRAM.
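Under these assumptions (each K/V block re-read once per query block, fp16, one head, softmax statistics kept in SRAM and ignored), a deliberately simplified traffic model looks like:

```python
import math

def standard_traffic_bytes(n: int, d: int, b: int = 2) -> int:
    """Read Q, K, V; write then re-read the N x N scores around softmax;
    write then re-read the softmax output; write the output."""
    return b * (3 * n * d + 4 * n * n + n * d)

def flash_traffic_bytes(n: int, d: int, Br: int = 128, b: int = 2) -> int:
    """Read Q and write O once; re-read K and V once per query block."""
    passes = math.ceil(n / Br)
    return b * (2 * n * d + 2 * n * d * passes)

n, d = 4096, 64   # one head with head_dim 64
ratio = standard_traffic_bytes(n, d) / flash_traffic_bytes(n, d)
print(f"traffic reduction: {ratio:.1f}x")  # ~3.9x for these sizes
```

The model understates the real savings (it charges Flash Attention full HBM price for every K/V re-read and ignores kernel-launch and softmax-statistics effects), but it shows the shape of the win: the N² terms vanish and only N·d terms remain.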

3. IO-Aware Analysis

Flash Attention's core innovation is IO-awareness: designing the algorithm with memory access patterns as a first-class optimization target. Traditional algorithms optimize for FLOPs, but Flash Attention optimizes for bytes transferred (IO). This requires detailed understanding of GPU memory hierarchy: fast but small SRAM (Shared Memory), larger but slower HBM, and how data flows between them.

The original Flash Attention paper proved that with block sizes tuned to the SRAM size M, HBM reads and writes drop to O(N²d²/M), where d is the head dimension — provably optimal up to constant factors for exact attention when SRAM is the only fast memory. The practical speedup ranges from 2× to 10× depending on sequence length and hardware.

Key Metric: the speedup comes from moving fewer bytes, not from doing fewer FLOPs. Flash Attention actually performs slightly more FLOPs than standard attention (it recomputes scores in the backward pass), yet runs several times faster because it cuts HBM traffic by roughly an order of magnitude.

# HBM Traffic Comparison (per head)

# Standard Attention: O(N²) traffic
#   read Q, K, V:                     3·N·d elements
#   write, then re-read QKᵀ scores:   2·N² elements
#   write, then re-read softmax out:  2·N² elements
#   write output:                     N·d elements
#   Total: O(N² + N·d) ≈ O(N²) for large N

# Flash Attention: no N² matrix ever touches HBM
#   read Q, write O once:             2·N·d elements
#   re-read K, V once per Q block:    2·N·d·⌈N/B_r⌉ elements
#   Total: O(N²·d²/M) for SRAM size M — far below O(N²) in practice
4. Using Flash Attention in PyTorch

PyTorch 2.0+ includes a scaled_dot_product_attention function that automatically uses Flash Attention when available on compatible hardware. For maximum performance, use the flash-attn library directly.

# PyTorch built-in Flash Attention
import torch
import torch.nn.functional as F

# Flash Attention is selected automatically when available
# (CUDA tensors, fp16/bf16, supported head dims)
output = F.scaled_dot_product_attention(
    query, key, value,
    dropout_p=0.1,
    is_causal=True,
)
5. Flash Attention 2 Improvements

Flash Attention 2 improved practical speedups from 2.5x to 3-4x on A100 GPUs through work partitioning, better GPU utilization, and optimized causal masking. The main innovation was reorganizing block computation for better cache efficiency and reduced register pressure.

v2 Improvements: Better GPU occupancy, optimized causal masking, 3-4x speedup on A100. Slight numerical differences in backward pass but imperceptible for most models.

6. Flash Attention 3 (H100)

Flash Attention 3, optimized for the H100, introduces asynchronous overlap of GEMM and softmax work, FP8 support, and use of Hopper-specific tensor core instructions (WGMMA). It achieves 4–10× speedups over standard attention on H100, with the largest gains at longer sequences, where the async pipelining pays off most.

H100 Performance: 4-5x for short sequences, 6-8x for medium sequences, 8-10x for long sequences (32k+ tokens).

7. Impact on Model Training

Flash Attention enables long context windows (100k+ tokens), faster training throughput (3-4x for 4k-8k sequences), and better memory efficiency allowing larger batch sizes. This has been crucial for improvements in model quality over 2023-2024.

Memory savings enable larger batches, which improve training efficiency and gradient stability. Combined with context length scaling, Flash Attention has fundamentally changed what's possible in LLM training and inference.

Training Note: FlashAttention-2 requires compute capability 8.0+ (Ampere: A100 or newer). Older GPUs get no benefit or may even be slower. Always profile on your target hardware.
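A defensive check along these lines can fail fast on unsupported hardware. The (8, 0) gate is a simplification for FlashAttention-2 — exact support depends on the flash-attn build — and the pure comparison function is a helper introduced here for illustration:

```python
def supports_flash_attention(capability: tuple) -> bool:
    """FlashAttention-2 kernels target compute capability 8.0+ (Ampere)."""
    return tuple(capability) >= (8, 0)

try:
    import torch
    if torch.cuda.is_available():
        cap = torch.cuda.get_device_capability()
        print("flash-attn supported:", supports_flash_attention(cap))
    else:
        print("no CUDA device found")
except ImportError:
    print("torch not installed")

# The pure comparison works without a GPU:
print(supports_flash_attention((9, 0)))  # H100: True
print(supports_flash_attention((7, 0)))  # V100: False
```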

8. Flash Attention Version Comparison

| Version | Key Innovation | Speedup | Hardware Target |
| --- | --- | --- | --- |
| Flash Attention 1 | Tiling + recomputation; IO-aware | 2–4× vs standard (training) | A100, V100 |
| Flash Attention 2 | Better parallelism, fewer non-matmul FLOPs | 2× over FA1 | A100, A10G |
| Flash Attention 3 | Async WGMMA + pingpong scheduling | 1.5–2× over FA2 | H100 only |

Flash Attention is now the default attention implementation in most major frameworks (PyTorch 2.0+, HuggingFace Transformers, vLLM). Enable it in PyTorch with F.scaled_dot_product_attention, which auto-selects FlashAttention when the inputs are on CUDA and meet dtype and alignment requirements. For HuggingFace models, set attn_implementation="flash_attention_2" in from_pretrained(). To verify it is active, check that no large temporary N×N attention matrices appear in the CUDA memory trace, or restrict SDPA to the flash backend (torch.nn.attention.sdpa_kernel with SDPBackend.FLASH_ATTENTION), which fails loudly rather than silently falling back to the math kernel.

FA3 requires H100 GPUs and a FlashAttention-3 build. It is not backward-compatible with A100 because it relies on H100-specific WGMMA (Warp Group Matrix Multiply-Accumulate) instructions. For mixed GPU fleets, stay on FA2 to maintain portability.

FlashAttention's recomputation strategy trades compute for memory: rather than storing the full N×N attention matrix in HBM (high-bandwidth memory), it recomputes attention scores during the backward pass from the stored output and softmax normalization statistics. This is profitable when N is large enough that the HBM read/write savings outweigh the extra compute. For sequences shorter than ~512 tokens, standard attention may be faster due to the overhead of the tiling logic. Profile on your specific sequence-length distribution before assuming FlashAttention is always faster in your workload.
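For that comparison, a small timing harness is usually enough. This is a pure-Python sketch; the commented `standard_attn` / `flash_attn_func` calls at the bottom are placeholders for whichever implementations you are comparing:

```python
import time

def benchmark(fn, *args, warmup=3, iters=10):
    """Median wall-clock seconds per call. For CUDA kernels, wrap fn so it
    calls torch.cuda.synchronize() before returning — otherwise you only
    measure kernel launch overhead, not kernel execution."""
    for _ in range(warmup):          # warm caches, trigger lazy compilation
        fn(*args)
    times = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn(*args)
        times.append(time.perf_counter() - t0)
    times.sort()
    return times[len(times) // 2]    # median is robust to stray slow runs

# Usage sketch, on your real sequence lengths:
#   benchmark(lambda: standard_attn(q, k, v))
#   benchmark(lambda: flash_attn_func(q, k, v))
```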

FlashAttention also needs a sufficiently recent version for grouped-query attention (GQA), the architecture used in Llama, Mistral, and Gemma: use flash-attn 2.3+ for GQA support. ALiBi positional encodings and custom attention masks likewise require special handling; always check the flash-attn documentation for your specific model architecture before assuming out-of-the-box compatibility.