Flash Attention: IO-Aware Attention Algorithm
How Flash Attention reformulates self-attention to minimize GPU memory I/O, enabling 2-4x speedups and linear memory scaling for long sequences.
The Memory Bottleneck in Standard Attention
Standard scaled dot-product attention computes:
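Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V

The bottleneck isn't the FLOPs; it's memory I/O. The elementwise steps (scaling, masking, softmax) do very little arithmetic per byte they move, so their runtime is dominated by reads and writes to GPU memory. GPU memory has a hierarchy: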
| Memory tier | Size | Bandwidth |
|---|---|---|
| SRAM (L1/shared memory) | 192KB per SM, ~20MB total on A100 (108 SMs) | ~19 TB/s |
| HBM (GPU memory) | 40-80GB on A100 | ~1.5-2 TB/s |
| CPU RAM | Hundreds of GB | ~50 GB/s |
Standard attention materializes the full N×N attention matrix in HBM:
- Write Q, K, V to HBM
- Read Q and K from HBM → compute QK^T → write result to HBM
- Read QK^T from HBM → apply softmax → write to HBM
- Read softmax and V from HBM → compute final output → write to HBM
Each step requires reading from and writing to the slow HBM. For N=8192, the attention matrix is 8192Γ8192Γ2 bytes = 128MB β read and written multiple times.
The Flash Attention Algorithm
Flash Attention (Dao et al., 2022) fuses all attention operations into a single kernel that tiles computation to fit in SRAM. The key insight: instead of materializing the N×N attention matrix, compute the output in small blocks that fit in the GPU's fast SRAM:
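```text
Tile Q into blocks of B_r rows
Tile K, V into blocks of B_c rows
For each Q block:
    Load the Q block into SRAM; initialize running max m, running sum l, accumulator O
    For each K/V block:
        Load the K block and V block into SRAM
        Compute the B_r x B_c block of scores S = Q_block K_block^T / sqrt(d_k)
        Update running max and sum for stable online softmax
        Rescale the accumulator and add exp(S - m) V_block
    Normalize the accumulator by l and write this Q block's output from SRAM to HBM
```

This is mathematically equivalent to standard attention (exact, not an approximation), but it needs only O(N) extra HBM memory (the per-row softmax statistics) instead of O(N²) for the attention matrix.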
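The loop structure above can be sketched in NumPy for a single head without masking. This is an illustration, not the CUDA kernel: tiled_attention_forward, block_r, and block_c are hypothetical names, and NumPy has no SRAM, so the sketch only shows that working on B_r × B_c score tiles with running statistics reproduces exact attention. It keeps the accumulator unnormalized and divides by the running sum once per Q block, a variant of the per-step normalization in online_softmax_update (next section) that Flash Attention 2 also uses.

```python
import numpy as np


def tiled_attention_forward(Q, K, V, block_r: int = 64, block_c: int = 64) -> np.ndarray:
    """Exact single-head attention computed tile by tile (no causal mask).

    Q, K, V: (N, d). Only one (block_r x block_c) tile of scores exists at a time.
    """
    n, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.empty((n, d))
    for i in range(0, n, block_r):                   # outer loop over Q blocks
        q = Q[i:i + block_r]                         # Q block stays "in SRAM"
        rows = q.shape[0]
        m = np.full(rows, -np.inf)                   # running row max
        l = np.zeros(rows)                           # running row sum of exp
        acc = np.zeros((rows, d))                    # unnormalized output
        for j in range(0, n, block_c):               # inner loop over K/V blocks
            k, v = K[j:j + block_c], V[j:j + block_c]
            s = (q @ k.T) * scale                    # (rows, block_c) score tile
            m_new = np.maximum(m, s.max(axis=1))
            alpha = np.exp(m - m_new)                # rescales old stats (0 at first block)
            p = np.exp(s - m_new[:, None])
            l = alpha * l + p.sum(axis=1)
            acc = alpha[:, None] * acc + p @ v
            m = m_new
        O[i:i + block_r] = acc / l[:, None]          # write this Q block's output once
    return O


# Compare against a direct softmax(QK^T / sqrt(d)) V reference
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((300, 32)) for _ in range(3))  # 300 rows: uneven last tile
S = Q @ K.T / np.sqrt(32)
P = np.exp(S - S.max(axis=1, keepdims=True))
reference = (P / P.sum(axis=1, keepdims=True)) @ V
print(np.allclose(tiled_attention_forward(Q, K, V), reference))  # True
```

Deferring the division to the end saves one rescale of the accumulator per K/V block; the result is identical up to floating-point rounding.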
Online Softmax: The Key Mathematical Insight
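Tiling attention requires computing softmax over blocks without seeing all values at once. Flash Attention uses an online softmax: for each query row it keeps a running max and a running sum of exponentials, and rescales both whenever a new block raises the max (the same max-subtraction that makes the log-sum-exp trick numerically stable):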
```python
import numpy as np


def online_softmax_update(
    running_max: float,
    running_sum: float,
    running_output: np.ndarray,
    new_scores: np.ndarray,  # (B_c,) scaled scores q.k_j / sqrt(d_k) for ONE query row
    new_values: np.ndarray,  # (B_c, d) corresponding V block
) -> tuple[float, float, np.ndarray]:
    """Update running softmax statistics for one query row with a new block.

    Start with running_max=-inf, running_sum=0.0, running_output=zeros(d);
    running_output is always the normalized output over the blocks seen so far.
    """
    # New block's max
    block_max = new_scores.max()
    # Update global max
    new_global_max = max(running_max, block_max)
    # Rescale running statistics to the new global max (factor is 0 on the first block)
    rescale_factor = np.exp(running_max - new_global_max)
    block_exp = np.exp(new_scores - new_global_max)
    new_running_sum = rescale_factor * running_sum + block_exp.sum()
    # Undo the old normalization, add this block's contribution, renormalize
    new_running_output = (
        rescale_factor * running_sum * running_output + block_exp @ new_values
    ) / new_running_sum
    return new_global_max, new_running_sum, new_running_output
```

This allows softmax to be computed incrementally across blocks without ever materializing the full N×N matrix.
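Written out for one query row, with running max m, running sum ℓ, normalized output o, and a new block of scaled scores s_j with value rows v_j, online_softmax_update applies this recurrence, starting from m = -∞, ℓ = 0, o = 0:

```latex
m' = \max\Big(m,\ \max_j s_j\Big), \qquad
\ell' = e^{m - m'}\,\ell + \sum_j e^{s_j - m'}, \qquad
o' = \frac{e^{m - m'}\,\ell\,o + \sum_j e^{s_j - m'}\, v_j}{\ell'}
```

After the last block, ℓ equals the sum of e^{s_j - m} over every key, with m the global max, so o equals softmax(s)V for that row exactly.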
Flash Attention in Practice
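PyTorch 2.0+ includes Flash Attention via torch.nn.functional.scaled_dot_product_attention (SDPA), which selects a Flash Attention kernel automatically when the GPU, dtype (fp16/bf16), and head size allow it. This benchmark forces each backend explicitly: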
```python
import time

import torch
import torch.nn.functional as F


def time_sdpa(q, k, v, iters: int = 10, **backend_flags) -> tuple[float, torch.Tensor]:
    """Average seconds per SDPA call, restricted to the given backend(s)."""
    # torch.backends.cuda.sdp_kernel works on PyTorch 2.0+; newer releases deprecate it
    # in favor of torch.nn.attention.sdpa_kernel.
    with torch.backends.cuda.sdp_kernel(**backend_flags):
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # warm-up
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        torch.cuda.synchronize()  # kernels run asynchronously; wait before stopping the clock
    return (time.perf_counter() - start) / iters, out


def benchmark_attention(seq_len: int, num_heads: int = 32, head_dim: int = 128):
    batch, device, dtype = 1, "cuda", torch.float16  # Flash kernels need fp16/bf16
    q = torch.randn(batch, num_heads, seq_len, head_dim, device=device, dtype=dtype)
    k = torch.randn(batch, num_heads, seq_len, head_dim, device=device, dtype=dtype)
    v = torch.randn(batch, num_heads, seq_len, head_dim, device=device, dtype=dtype)

    # Standard ("math") attention: materializes the full N x N score matrix
    standard_time, out_standard = time_sdpa(
        q, k, v, enable_flash=False, enable_math=True, enable_mem_efficient=False
    )
    # Flash Attention kernel only
    flash_time, out_flash = time_sdpa(
        q, k, v, enable_flash=True, enable_math=False, enable_mem_efficient=False
    )

    print(f"seq_len={seq_len}: standard={standard_time*1000:.1f}ms, "
          f"flash={flash_time*1000:.1f}ms, speedup={standard_time/flash_time:.1f}x")
    print(f"Outputs match: {torch.allclose(out_standard, out_flash, atol=1e-2)}")


# Flash Attention provides larger speedups for longer sequences
for n in [512, 2048, 8192]:
    benchmark_attention(n)
```

Enabling Flash Attention in Transformers
```python
# Both methods need the flash-attn package (CUDA GPU, fp16/bf16 tensors):
# pip install flash-attn --no-build-isolation
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Method 1: ask Transformers to use the Flash Attention 2 kernels
# (attn_implementation="sdpa" uses PyTorch's built-in SDPA instead, no extra package)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype=torch.float16,
    device_map="auto",
    attn_implementation="flash_attention_2",  # Explicit Flash Attention 2
)

# Method 2: call the flash-attn kernel directly
from flash_attn import flash_attn_func


def flash_attention_forward(q, k, v, causal=True, dropout_p=0.0):
    """Direct Flash Attention 2 call."""
    # q, k, v: (batch, seq_len, heads, head_dim); note that PyTorch SDPA
    # expects (batch, heads, seq_len, head_dim) instead
    return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
```

Flash Attention 2 and 3
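Flash Attention 2 (2023): Parallelizes across the sequence-length dimension (not just batch and heads), cuts non-matmul FLOPs, and partitions work between warps more efficiently. Roughly 2× faster than Flash Attention 1, reaching 50-73% of theoretical peak FLOPs/s on A100. Supports multi-query and grouped-query attention (MQA/GQA).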
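Flash Attention 3 (2024): Designed specifically for H100 (Hopper) GPUs. Uses Hopper's asynchronous Tensor Cores and TMA to overlap compute with data movement, interleaves the block-wise matmuls with softmax, and adds FP8 support. 1.5-2.0× faster than Flash Attention 2 in FP16 on H100; its FP8 path also shows 2.6× lower numerical error than a baseline FP8 attention.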
Memory Complexity Comparison
```python
def memory_comparison(seq_len: int, d_model: int = 4096, dtype_bytes: int = 2):
    """Compare peak attention memory for standard vs Flash Attention (one layer, batch=1)."""
    num_heads = 32
    head_dim = d_model // num_heads
    # Standard: one (heads, seq, seq) attention matrix for the layer
    standard_attn_bytes = num_heads * seq_len * seq_len * dtype_bytes
    # Both approaches also hold Q, K, V and the output: 4 x (heads, seq, head_dim)
    qkvo_bytes = 4 * num_heads * seq_len * head_dim * dtype_bytes
    standard_total_gb = (standard_attn_bytes + qkvo_bytes) / 1e9
    # Flash Attention: score tiles live in SRAM, so HBM holds only Q, K, V, O
    # (plus O(N) softmax statistics, ignored here)
    flash_total_gb = qkvo_bytes / 1e9
    return {
        "seq_len": seq_len,
        "standard_gb": standard_total_gb,
        "flash_gb": flash_total_gb,
        "memory_saved_gb": standard_total_gb - flash_total_gb,
    }


for n in [4096, 32768, 131072]:
    m = memory_comparison(n)
    print(f"N={n:>7}: standard={m['standard_gb']:.1f}GB, flash={m['flash_gb']:.2f}GB, saved={m['memory_saved_gb']:.1f}GB")
```

When Flash Attention Helps Most
High benefit:
- Long sequences (8k+ tokens): quadratic scaling makes standard attention impractical
- Training: the backward pass through standard attention is even more memory-intensive
- Small batch sizes: when memory is what forces a small batch, removing the N×N buffer frees room for larger batches or longer sequences
Lower benefit:
- Very short sequences (around 512 tokens): the N×N matrix is small anyway
- Inference with batch size of 1: during token-by-token decoding each step scores a single query against N keys, so there is no N×N matrix to avoid, and simpler optimizations (like torch.compile) may suffice
- CPUs: Flash Attention kernels are written for GPUs (primarily CUDA), and the memory hierarchy is different on CPU
Flash Attention is now standard in virtually all production LLM training and serving infrastructure. If you're deploying on CUDA hardware, there's almost no reason not to use it.