Flash Attention: IO-Aware Attention Algorithm

How Flash Attention reformulates self-attention to minimize GPU memory I/O, enabling 2-4x speedups and linear memory scaling for long sequences.

Asma Hafeez Khan · May 16, 2026 · 5 min read
Tags: Transformers, Flash Attention, GPU, Optimization

The Memory Bottleneck in Standard Attention

Standard scaled dot-product attention computes:

Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V

The bottleneck isn't the FLOPs; it's memory I/O. GPU memory has a hierarchy:

| Memory tier | Size | Bandwidth |
|---|---|---|
| SRAM (L1/shared memory) | ~20MB on A100 | ~19 TB/s |
| HBM (GPU memory) | 40–80GB on A100 | ~2 TB/s |
| CPU RAM | Hundreds of GB | ~50 GB/s |

Standard attention materializes the full N×N attention matrix in HBM:

  1. Write Q, K, V to HBM
  2. Read Q and K from HBM → compute QK^T → write result to HBM
  3. Read QK^T from HBM → apply softmax → write to HBM
  4. Read softmax and V from HBM → compute final output → write to HBM

Each step requires reading from and writing to the slow HBM. For N=8192 in fp16, the attention matrix is 8192×8192×2 bytes = 128MB per head, and it is read and written multiple times.
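To make the four steps concrete, here is a minimal NumPy sketch of standard attention for a single head. The function name `naive_attention` and the toy sizes are assumptions chosen for illustration. NumPy has no notion of SRAM or HBM, so the sketch only shows where the full N×N arrays appear (`scores` and `probs`); on a GPU those are the tensors that would round-trip through HBM.

```python
import numpy as np

def naive_attention(q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Single-head attention that materializes the full N x N matrix."""
    d_k = q.shape[-1]
    scores = q @ k.T / np.sqrt(d_k)                 # (N, N): step 2
    scores -= scores.max(axis=-1, keepdims=True)    # subtract row max for stability
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)      # (N, N): step 3
    return probs @ v                                # (N, d): step 4

rng = np.random.default_rng(0)
n, d = 1024, 64
q, k, v = (rng.standard_normal((n, d)).astype(np.float32) for _ in range(3))
out = naive_attention(q, k, v)

# Size of one fp16 N x N matrix for the N=8192 case discussed above
attn_bytes = 8192 * 8192 * 2
print(out.shape, f"{attn_bytes / 2**20:.0f} MiB")
```

The intermediate arrays grow with N squared while the inputs and output grow only with N, which is why the matrix dominates memory traffic at long sequence lengths.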


The Flash Attention Algorithm

Flash Attention (Dao et al., 2022) fuses all attention operations into a single kernel that tiles computation to fit in SRAM. The key insight: instead of materializing the N×N attention matrix, compute the output in small blocks that fit in the GPU's fast SRAM:

Tile Q into blocks of size B_r
Tile K, V into blocks of size B_c

For each Q block:
    Load Q block into SRAM; initialize running max, running sum, and output accumulator
    For each K/V block:
        Load K block, V block into SRAM
        Compute partial attention scores (this block only)
        Update running max and sum for stable online softmax
        Rescale and accumulate output
    Normalize the accumulator and write this output block from SRAM to HBM

This is mathematically equivalent to standard attention (it is exact, not an approximation) but requires only O(N) additional HBM memory instead of O(N²).
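The tiling loop can be sketched in NumPy to check the equivalence claim numerically. This is an illustration under stated assumptions, not the CUDA kernel: the names `tiled_attention` and `reference_attention`, the block sizes, and the toy shapes are all hypothetical, and NumPy cannot model the SRAM/HBM split. The sketch keeps an unnormalized accumulator and divides by the running sum once per Q block, which is one common way to organize the update.

```python
import numpy as np

def tiled_attention(q, k, v, block_r=64, block_c=64):
    """Block-wise attention with online softmax; never forms the N x N matrix."""
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.empty((n, v.shape[1]), dtype=np.float64)

    for i in range(0, n, block_r):                      # loop over Q blocks
        q_blk = q[i:i + block_r]
        row_max = np.full(q_blk.shape[0], -np.inf)      # running max per query row
        row_sum = np.zeros(q_blk.shape[0])              # running softmax denominator
        acc = np.zeros((q_blk.shape[0], v.shape[1]))    # unnormalized output

        for j in range(0, k.shape[0], block_c):         # loop over K/V blocks
            s = (q_blk @ k[j:j + block_c].T) * scale    # (B_r, B_c) scores only
            new_max = np.maximum(row_max, s.max(axis=1))
            p = np.exp(s - new_max[:, None])
            rescale = np.exp(row_max - new_max)         # re-base old stats on new max
            row_sum = rescale * row_sum + p.sum(axis=1)
            acc = rescale[:, None] * acc + p @ v[j:j + block_c]
            row_max = new_max

        out[i:i + block_r] = acc / row_sum[:, None]     # normalize once per Q block
    return out

def reference_attention(q, k, v):
    """Standard attention with the full score matrix, used only for comparison."""
    s = q @ k.T / np.sqrt(q.shape[1])
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((300, 32)) for _ in range(3))
max_err = np.abs(tiled_attention(q, k, v) - reference_attention(q, k, v)).max()
print(f"max abs difference: {max_err:.2e}")
```

The sequence length of 300 is deliberately not a multiple of the block size, so the last block in each loop is partial; the difference from the reference should still be at the level of floating-point rounding.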


Online Softmax: The Key Mathematical Insight

Tiling attention requires computing softmax over blocks without seeing all values at once. Flash Attention uses the log-sum-exp trick for numerically stable online softmax:

Python
import numpy as np

def online_softmax_update(
    running_max: float,
    running_sum: float,
    running_output: np.ndarray,
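    new_scores: np.ndarray,  # New block of attention scores, shape (B_c,)
    new_values: np.ndarray,  # Corresponding V block, shape (B_c, d)
) -> tuple[float, float, np.ndarray]:
    """Update running softmax statistics for one query row with a new block.

    Start with running_max=-inf, running_sum=0.0, and running_output=zeros(d).
    running_output is kept normalized (already divided by running_sum).
    """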
    # New block's max
    block_max = new_scores.max()

    # Update global max
    new_global_max = max(running_max, block_max)

    # Rescale running statistics to new global max
    rescale_factor = np.exp(running_max - new_global_max)
    new_running_sum = rescale_factor * running_sum + np.exp(new_scores - new_global_max).sum()
    new_running_output = (
        rescale_factor * running_sum * running_output
        + np.exp(new_scores - new_global_max) @ new_values
    ) / new_running_sum

    return new_global_max, new_running_sum, new_running_output

This allows softmax to be computed incrementally across blocks without ever materializing the full N×N matrix.
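A quick way to build confidence in the update rule is to stream one query row through it with different block sizes and compare against a one-shot softmax. The sketch below restates the update in compact form so that it runs on its own; the helper name `stream_softmax_weighted_sum` and the test data are assumptions for illustration. The scores are centered around 900 on purpose, since `exp(900)` overflows float64 and a version without max subtraction would fail.

```python
import numpy as np

def stream_softmax_weighted_sum(scores, values, block_size):
    """softmax(scores) @ values for one query row, consuming one block at a time."""
    running_max, running_sum = -np.inf, 0.0
    acc = np.zeros(values.shape[1])                  # unnormalized weighted sum
    for start in range(0, len(scores), block_size):
        s = scores[start:start + block_size]
        new_max = max(running_max, s.max())
        rescale = np.exp(running_max - new_max)      # 0.0 on the first block
        w = np.exp(s - new_max)                      # every exponent is <= 0
        running_sum = rescale * running_sum + w.sum()
        acc = rescale * acc + w @ values[start:start + block_size]
        running_max = new_max
    return acc / running_sum

rng = np.random.default_rng(1)
scores = rng.standard_normal(1000) * 50 + 900        # large scores on purpose
values = rng.standard_normal((1000, 8))

streamed_128 = stream_softmax_weighted_sum(scores, values, block_size=128)
streamed_7 = stream_softmax_weighted_sum(scores, values, block_size=7)

# One-shot reference that sees all scores at once
w = np.exp(scores - scores.max())
reference = (w / w.sum()) @ values
print(np.allclose(streamed_128, reference), np.allclose(streamed_7, reference))
```

Because the result does not depend on the block size, a kernel is free to pick tile sizes purely on the basis of how much SRAM is available.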


Flash Attention in Practice

PyTorch 2.0+ includes Flash Attention via scaled_dot_product_attention:

Python
import time

import torch
import torch.nn.functional as F
# torch.backends.cuda.sdp_kernel is deprecated in recent PyTorch releases;
# torch.nn.attention.sdpa_kernel is its replacement.
from torch.nn.attention import SDPBackend, sdpa_kernel

def time_backend(backend, q, k, v, warmup: int = 3, iters: int = 10):
    """Return (output, average milliseconds per call) for one SDPA backend."""
    with sdpa_kernel(backend):
        # Warm-up runs keep one-time kernel setup out of the measurement
        for _ in range(warmup):
            out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        torch.cuda.synchronize()  # CUDA launches are asynchronous
        start = time.perf_counter()
        for _ in range(iters):
            out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        torch.cuda.synchronize()
    return out, (time.perf_counter() - start) * 1000 / iters

def benchmark_attention(seq_len: int, num_heads: int = 32, head_dim: int = 128):
    batch = 1
    device = "cuda"
    dtype = torch.float16

    q = torch.randn(batch, num_heads, seq_len, head_dim, device=device, dtype=dtype)
    k = torch.randn(batch, num_heads, seq_len, head_dim, device=device, dtype=dtype)
    v = torch.randn(batch, num_heads, seq_len, head_dim, device=device, dtype=dtype)

    # Standard attention (math backend materializes the full attention matrix)
    out_standard, standard_ms = time_backend(SDPBackend.MATH, q, k, v)

    # Flash attention
    out_flash, flash_ms = time_backend(SDPBackend.FLASH_ATTENTION, q, k, v)

    print(f"seq_len={seq_len}: standard={standard_ms:.1f}ms, flash={flash_ms:.1f}ms, speedup={standard_ms/flash_ms:.1f}x")
    print(f"Outputs match: {torch.allclose(out_standard, out_flash, atol=1e-2)}")

# Flash Attention provides larger speedups for longer sequences
for n in [512, 2048, 8192]:
    benchmark_attention(n)

Enabling Flash Attention in Transformers

Python
from transformers import AutoModelForCausalLM
import torch

# Method 1: ask Transformers for Flash Attention 2 at load time.
# Requires the flash-attn package: pip install flash-attn --no-build-isolation
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype=torch.float16,
    device_map="auto",
    attn_implementation="flash_attention_2",  # Explicit Flash Attention 2
)

# Method 2: call the flash-attn kernel directly
from flash_attn import flash_attn_func

def flash_attention_forward(q, k, v, causal=True, dropout_p=0.0):
    """Direct Flash Attention 2 call."""
    # q, k, v: (batch, seq_len, heads, head_dim), fp16 or bf16 on a CUDA device.
    # Note the different shape convention: F.scaled_dot_product_attention expects
    # (batch, heads, seq_len, head_dim), so use x.transpose(1, 2) to convert.
    return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)

Flash Attention 2 and 3

Flash Attention 2 (2023): Better parallelism across the sequence length dimension. Up to 2× faster than Flash Attention 1. Supports GQA (grouped-query attention). Better GPU utilization on A100/H100.

Flash Attention 3 (2024): Designed specifically for Hopper (H100) GPUs. Hides memory I/O latency by overlapping compute and data movement, and adds FP8 computation on the tensor cores. The paper reports 1.5-2.0× speedups over Flash Attention 2 in FP16 on H100; its 2.6× figure refers to lower numerical error for FP8 compared with a baseline FP8 attention, not to speed.


Memory Complexity Comparison

Python
def memory_comparison(seq_len: int, d_model: int = 4096, dtype_bytes: int = 2):
    """Compare peak attention memory per layer (batch size 1) for standard vs Flash Attention."""
    num_heads = 32
    head_dim = d_model // num_heads

    # Q, K, V, and output: needed by both approaches
    qkvo_bytes = 4 * num_heads * seq_len * head_dim * dtype_bytes

    # Standard: also stores the (heads, seq, seq) attention matrix in HBM
    standard_attn_bytes = num_heads * seq_len * seq_len * dtype_bytes
    standard_total_gb = (standard_attn_bytes + qkvo_bytes) / 1e9

    # Flash Attention: tiles live in SRAM, so HBM holds only Q, K, V, O plus
    # one fp32 log-sum-exp value per query row (kept for the backward pass)
    flash_stats_bytes = num_heads * seq_len * 4
    flash_total_gb = (qkvo_bytes + flash_stats_bytes) / 1e9

    return {
        "seq_len": seq_len,
        "standard_gb": standard_total_gb,
        "flash_gb": flash_total_gb,
        "memory_saved_gb": standard_total_gb - flash_total_gb,
    }

for n in [4096, 32768, 131072]:
    m = memory_comparison(n)
    print(f"N={n:>7}: standard={m['standard_gb']:.1f}GB, flash={m['flash_gb']:.2f}GB, saved={m['memory_saved_gb']:.1f}GB")

When Flash Attention Helps Most

High benefit:

  • Long sequences (8k+ tokens): quadratic scaling makes standard attention impractical
  • Training: the backward pass through standard attention is even more memory-intensive
  • Small batch sizes: memory savings are most impactful when the GPU is memory-constrained

Lower benefit:

  • Very short sequences (512 tokens): the N×N matrix is small anyway
  • Inference with batch size of 1: simpler optimizations (like torch.compile) may suffice
  • CPUs: Flash Attention is CUDA-specific; the memory hierarchy is different on CPU

Flash Attention is now standard in virtually all production LLM training and serving infrastructure. If you're deploying on CUDA hardware, there's almost no reason not to use it.
