Attention in LLMs: Deep Dive
Multi-query, grouped-query, and multi-head attention variants in modern LLMs: how they differ, their KV cache implications, and the FlashAttention implementation.
Attention Variant Taxonomy
Multi-Head Attention (MHA), the original:
h Q heads, h K heads, h V heads
KV cache: 2 × h × seq_len × d_head elements per layer (the 2 counts K and V)
Multi-Query Attention (MQA), the extreme reduction:
h Q heads, 1 K head, 1 V head (shared by all Q heads)
KV cache: 2 × 1 × seq_len × d_head elements per layer
Used by: PaLM, Falcon 7B
Grouped-Query Attention (GQA), the balanced middle ground:
h Q heads, g K heads, g V heads (1 < g < h)
Each K/V head is shared by h/g Q heads (g = h recovers MHA, g = 1 recovers MQA)
KV cache: 2 × g × seq_len × d_head elements per layer
Used by: LLaMA 2 70B (g=8, h=64), Mistral 7B (g=8, h=32)
GQA in Detail
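The grouping rule reduces to a single index mapping. The sketch below assumes contiguous groups, which is the assignment that `repeat_interleave` on the head dimension produces. The function name `kv_head_for` is illustrative and does not come from any library. MHA and MQA fall out as the g = h and g = 1 cases.

```python
def kv_head_for(q_head: int, num_q_heads: int, num_kv_heads: int) -> int:
    """Return the K/V head that query head `q_head` reads from (hypothetical helper).

    num_kv_heads == num_q_heads -> MHA (one K/V head per Q head)
    num_kv_heads == 1           -> MQA (all Q heads share one K/V head)
    otherwise                   -> GQA (contiguous groups of num_q_heads // num_kv_heads)
    """
    assert num_q_heads % num_kv_heads == 0
    group_size = num_q_heads // num_kv_heads
    return q_head // group_size


# Mistral 7B-style shape: h=32 Q heads, g=8 K/V heads -> groups of 4
mapping = [kv_head_for(i, 32, 8) for i in range(32)]
print(mapping[:8])  # [0, 0, 0, 0, 1, 1, 1, 1]
```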
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model: int, num_q_heads: int, num_kv_heads: int):
        super().__init__()
        assert d_model % num_q_heads == 0, "d_model must be divisible by num_q_heads"
        assert num_q_heads % num_kv_heads == 0, "num_q_heads must be divisible by num_kv_heads"
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.num_groups = num_q_heads // num_kv_heads  # Q heads per K/V head
        self.d_head = d_model // num_q_heads
        self.W_q = nn.Linear(d_model, num_q_heads * self.d_head, bias=False)
        # K and V projections are smaller: only num_kv_heads heads
        self.W_k = nn.Linear(d_model, num_kv_heads * self.d_head, bias=False)
        self.W_v = nn.Linear(d_model, num_kv_heads * self.d_head, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, mask=None):
        B, T, _ = x.shape
        # (B, T, d_model) -> (B, heads, T, d_head)
        Q = self.W_q(x).view(B, T, self.num_q_heads, self.d_head).transpose(1, 2)
        K = self.W_k(x).view(B, T, self.num_kv_heads, self.d_head).transpose(1, 2)
        V = self.W_v(x).view(B, T, self.num_kv_heads, self.d_head).transpose(1, 2)
        # At inference, only these un-expanded K/V (num_kv_heads heads) need caching.
        # Expand K/V to match Q heads: Q head i uses K/V head i // num_groups
        K = K.repeat_interleave(self.num_groups, dim=1)  # (B, num_q_heads, T, d_head)
        V = V.repeat_interleave(self.num_groups, dim=1)
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_head)  # (B, num_q_heads, T, T)
        if mask is not None:
            # mask must broadcast to scores; 0 marks a blocked position
            scores = scores.masked_fill(mask == 0, float("-inf"))
        out = F.softmax(scores, dim=-1) @ V  # (B, num_q_heads, T, d_head)
        out = out.transpose(1, 2).reshape(B, T, -1)  # (B, T, d_model)
        return self.W_o(out)
```
KV Cache Memory Impact
For LLaMA 2 70B serving with max_seq_len=4096, batch_size=1:
MHA (hypothetical 70B with 64 heads):
KV cache = 2 × 80 layers × 64 heads × 4096 × 128 dim × 2 bytes (fp16)
= 2 × 80 × 64 × 4096 × 128 × 2 = 10.7 GB
GQA (actual 70B with 8 KV heads):
KV cache = 2 × 80 × 8 × 4096 × 128 × 2 = 1.3 GB
→ 8× reduction from MHA
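This arithmetic can be checked with a small helper. `kv_cache_bytes` is a hypothetical name. The shape parameters (80 layers, d_head = 128, 4096 tokens, 2 bytes per fp16 element) are the ones used in the calculation, and decimal GB (10⁹ bytes) is assumed.

```python
def kv_cache_bytes(num_layers: int, num_kv_heads: int, seq_len: int,
                   d_head: int, batch_size: int = 1, bytes_per_elem: int = 2) -> int:
    """Bytes to cache K and V (the leading 2) for every layer and every sequence."""
    return 2 * num_layers * num_kv_heads * seq_len * d_head * batch_size * bytes_per_elem


# LLaMA 2 70B-like shape: 80 layers, d_head=128, 4096 tokens, fp16
mha = kv_cache_bytes(80, 64, 4096, 128)  # hypothetical 64 K/V heads
gqa = kv_cache_bytes(80, 8, 4096, 128)   # actual 8 K/V heads
print(f"MHA: {mha / 1e9:.1f} GB, GQA: {gqa / 1e9:.2f} GB, ratio {mha // gqa}x")
# MHA: 10.7 GB, GQA: 1.34 GB, ratio 8x
print(f"batch 32 -> MHA: {32 * mha / 1e9:.0f} GB, GQA: {32 * gqa / 1e9:.0f} GB")
# batch 32 -> MHA: 344 GB, GQA: 43 GB
```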
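For batch_size=32:
MHA: ≈344 GB → infeasible: the cache alone exceeds four A100s (4 × 80 GB), before adding ~140 GB of fp16 weights
GQA: ≈43 GB → the cache alone fits on one A100; with ~140 GB of fp16 weights (70B × 2 bytes), cache plus weights (~183 GB) fits across 4× A100 80 GB
FlashAttention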
FlashAttention rewrites the attention kernel to avoid materialising the O(n²) attention matrix in HBM:
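Problem: for seq_len=4096, the attention matrix is 4096 × 4096 ≈ 16.8M entries
at fp16 = 32 MiB per head; × 32 heads × 80 layers = 80 GiB (≈86 GB) of intermediates per sequence if every layer's matrix is kept, as standard attention does for the backward pass in training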
FlashAttention solution:
Process Q, K, and V in tiles that fit in on-chip SRAM (much faster than HBM)
Compute softmax incrementally using the online softmax trick (a running max and normaliser per query row)
Never write the full score matrix to HBM
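The online softmax trick can be illustrated in plain Python for a single query vector. This is a sketch of the math only, using tiny hand-written inputs. The function names are hypothetical, and a real FlashAttention kernel also tiles over queries, keeps tiles in SRAM, and runs fused on the GPU. The key step is rescaling the running normaliser and accumulator whenever a new tile raises the running maximum.

```python
import math


def attention_reference(q, keys, values):
    """Standard softmax attention for one query: materialises every score."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    d_v = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / z for j in range(d_v)]


def attention_online(q, keys, values, tile_size=2):
    """Same result, visiting K/V one tile at a time with a running max and normaliser."""
    d_v = len(values[0])
    m = float("-inf")   # running max of scores seen so far
    z = 0.0             # running sum of exp(score - m)
    acc = [0.0] * d_v   # running sum of exp(score - m) * value
    for start in range(0, len(keys), tile_size):
        k_tile = keys[start:start + tile_size]
        v_tile = values[start:start + tile_size]
        s_tile = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in k_tile]
        m_new = max(m, max(s_tile))
        scale = math.exp(m - m_new)  # rescale old partial sums to the new max (0.0 on first tile)
        z *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(s_tile, v_tile):
            p = math.exp(s - m_new)
            z += p
            acc = [a + p * vj for a, vj in zip(acc, v)]
        m = m_new
    return [a / z for a in acc]


q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5], [0.3, 0.3]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
ref = attention_reference(q, keys, values)
out = attention_online(q, keys, values, tile_size=2)
print(all(abs(a - b) < 1e-9 for a, b in zip(ref, out)))  # True
```

Any `tile_size` gives the same result up to floating-point rounding, which is why the tiled computation is exact rather than an approximation.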
Result:
Memory: O(n) instead of O(n²) for intermediates
Speed: 2-4× faster wall-clock on A100 for typical seq lengths
Exact: same output as standard attention (not an approximation)
Sliding Window Attention
Mistral 7B uses sliding window attention to handle long contexts efficiently:
Standard attention: position i attends to all positions 0..i
Memory O(n²), slow for long sequences
Sliding window: position i attends to positions max(0, i-w)..i
Window size w (e.g., 4096)
Memory O(n·w): linear in n for a fixed window w
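A sliding window is a band-shaped causal mask. The sketch below follows the max(0, i-w)..i definition, so each row allows at most w + 1 positions, which is where the O(n·w) bound comes from. The helper name is hypothetical, and the mask is built as a dense Python list purely for illustration; an efficient kernel would only compute scores inside the band. It uses the same convention as the `mask` argument of the GQA module (true/1 = may attend).

```python
def sliding_window_mask(n: int, w: int) -> list[list[bool]]:
    """mask[i][j] is True when position i may attend to position j (causal, window w)."""
    return [[max(0, i - w) <= j <= i for j in range(n)] for i in range(n)]


mask = sliding_window_mask(n=6, w=2)
for row in mask:
    print("".join("x" if allowed else "." for allowed in row))
# x.....
# xx....
# xxx...
# .xxx..
# ..xxx.
# ...xxx
```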
For n=32K, w=4096:
Standard: 32K × 32K ≈ 1.07B entries per head
Sliding: 32K × 4096 ≈ 134M entries → 8× fewer
Information beyond the window propagates through multiple layers:
Layer 1: sees w positions
Layer 2: sees up to 2 × w positions (via layer 1's representations)
Layer k: sees up to k × w positions; Mistral 7B's 32 layers with w=4096 give a theoretical span of ≈131K tokens
Interview Answer
"Modern LLMs use three attention variants: MHA (one K/V head per Q head), MQA (one K/V head shared by all Q heads), and GQA (one K/V head per group of Q heads). GQA is the dominant choice: LLaMA 2 70B uses 8 KV heads for 64 Q heads, reducing the KV cache 8×. FlashAttention is a hardware-aware kernel that avoids materialising the O(n²) score matrix in HBM, giving O(n) memory for intermediates and a 2-4× speedup with the same output; it is not an approximation. For very long contexts, sliding window attention limits each position to a local window, with information propagating across layers."