Learnixo

LLMs Deep Dive · Lesson 6 of 24

Attention Mechanism: Math and Intuition

Attention Variant Taxonomy

Multi-Head Attention (MHA) — original:
  h Q heads, h K heads, h V heads
  KV cache: 2 × h × seq_len × d_head per layer

Multi-Query Attention (MQA) — extreme reduction:
  h Q heads, 1 K head, 1 V head (shared by all Q heads)
  KV cache: 2 × 1 × seq_len × d_head per layer
  Used by: PaLM, Falcon-7B

Grouped-Query Attention (GQA) — balanced:
  h Q heads, g K heads, g V heads  (1 < g < h, h divisible by g; g = h is MHA, g = 1 is MQA)
  Each K/V head is shared by h/g Q heads
  KV cache: 2 × g × seq_len × d_head per layer
  Used by: LLaMA 2 70B (g=8, h=64), Mistral 7B (g=8, h=32)
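The three variants differ only in how many K/V heads the h query heads share. A minimal plain-Python sketch, assuming contiguous groups (the helper names `kv_head_for_q_head` and `head_mapping` are hypothetical): query head q reads K/V head q // (h/g), so g = h gives MHA, g = 1 gives MQA, and anything in between is GQA. This is the same contiguous grouping that `repeat_interleave` produces in the GQA code below.

```python
from typing import List


def kv_head_for_q_head(q_head: int, num_q_heads: int, num_kv_heads: int) -> int:
    """Index of the K/V head used by query head `q_head` (hypothetical helper).

    Assumes contiguous grouping: Q heads 0..h/g-1 share K/V head 0, and so on.
    """
    assert num_q_heads % num_kv_heads == 0, "h must be divisible by g"
    group_size = num_q_heads // num_kv_heads  # Q heads per K/V head
    return q_head // group_size


def head_mapping(num_q_heads: int, num_kv_heads: int) -> List[int]:
    """K/V head index for each of the h query heads."""
    return [kv_head_for_q_head(q, num_q_heads, num_kv_heads) for q in range(num_q_heads)]


print(head_mapping(8, 8))  # MHA: [0, 1, 2, 3, 4, 5, 6, 7]
print(head_mapping(8, 1))  # MQA: [0, 0, 0, 0, 0, 0, 0, 0]
print(head_mapping(8, 2))  # GQA: [0, 0, 0, 0, 1, 1, 1, 1]
```

With Mistral 7B's shape (h=32, g=8), `head_mapping(32, 8)` assigns each K/V head to 4 consecutive query heads.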

GQA in Detail

Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class GroupedQueryAttention(nn.Module):
    """Grouped-query attention: num_q_heads query heads share num_kv_heads K/V heads."""

    def __init__(self, d_model: int, num_q_heads: int, num_kv_heads: int):
        super().__init__()
        assert d_model % num_q_heads == 0, "d_model must be divisible by num_q_heads"
        assert num_q_heads % num_kv_heads == 0, "num_q_heads must be a multiple of num_kv_heads"
        self.num_q_heads  = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.num_groups   = num_q_heads // num_kv_heads  # Q heads per K/V head
        self.d_head       = d_model // num_q_heads

        # K and V project to only num_kv_heads heads: this is what shrinks the KV cache
        self.W_q = nn.Linear(d_model, num_q_heads  * self.d_head, bias=False)
        self.W_k = nn.Linear(d_model, num_kv_heads * self.d_head, bias=False)
        self.W_v = nn.Linear(d_model, num_kv_heads * self.d_head, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, mask=None):
        # x: (B, T, d_model); mask: broadcastable to (B, num_q_heads, T, T), 0 = blocked
        B, T, _ = x.shape
        Q = self.W_q(x).view(B, T, self.num_q_heads, self.d_head).transpose(1, 2)   # (B, h, T, d_head)
        K = self.W_k(x).view(B, T, self.num_kv_heads, self.d_head).transpose(1, 2)  # (B, g, T, d_head)
        V = self.W_v(x).view(B, T, self.num_kv_heads, self.d_head).transpose(1, 2)  # (B, g, T, d_head)

        # Expand K/V to match Q heads: KV head j serves Q heads j*num_groups .. (j+1)*num_groups - 1
        K = K.repeat_interleave(self.num_groups, dim=1)  # (B, num_q_heads, T, d_head)
        V = V.repeat_interleave(self.num_groups, dim=1)

        # Scaled dot-product scores; blocked positions get -inf so softmax gives them weight 0
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_head)  # (B, num_q_heads, T, T)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        out = (F.softmax(scores, dim=-1) @ V)            # (B, num_q_heads, T, d_head)
        out = out.transpose(1, 2).reshape(B, T, -1)      # (B, T, d_model)
        return self.W_o(out)

KV Cache Memory Impact

For LLaMA 2 70B serving with max_seq_len=4096, batch_size=1:

MHA (hypothetical 70B with 64 heads):
  KV cache = 2 × 80 layers × 64 heads × 4096 × 128 dim × 2 bytes (fp16)
           = 2 × 80 × 64 × 4096 × 128 × 2 = 10.7 GB

GQA (actual 70B with 8 KV heads):
  KV cache = 2 × 80 × 8 × 4096 × 128 × 2 = 1.3 GB
           → 8× reduction from MHA

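For batch_size=32:
  MHA: ≈ 344 GB, more than four A100-80GB GPUs (320 GB) hold, before counting weights
  GQA:  ≈ 43 GB, fits in one A100's 80 GB on its own; the ~140 GB of fp16 weights
        still have to be sharded across GPUs either way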
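These totals follow from one formula, so a small calculator is handy for checking other model shapes. A minimal sketch; the helper name `kv_cache_bytes` is hypothetical, and it counts only the K and V tensors (no paging or fragmentation overhead). Results are shown in decimal GB (10^9 bytes); the same numbers are exact multiples of GiB (2^30 bytes).

```python
def kv_cache_bytes(num_layers: int, num_kv_heads: int, seq_len: int, d_head: int,
                   bytes_per_elem: int = 2, batch_size: int = 1) -> int:
    """Bytes for the KV cache: 2 (K and V) x layers x KV heads x tokens x d_head x bytes x batch."""
    return 2 * num_layers * num_kv_heads * seq_len * d_head * bytes_per_elem * batch_size


# LLaMA 2 70B shape from the text: 80 layers, d_head=128, fp16 (2 bytes), seq_len=4096
mha = kv_cache_bytes(80, 64, 4096, 128)  # hypothetical MHA with 64 K/V heads
gqa = kv_cache_bytes(80, 8, 4096, 128)   # actual GQA with 8 K/V heads

print(round(mha / 1e9, 1), round(gqa / 1e9, 2), mha // gqa)  # 10.7 1.34 8
print(round(kv_cache_bytes(80, 64, 4096, 128, batch_size=32) / 1e9, 1))  # 343.6
print(round(kv_cache_bytes(80, 8, 4096, 128, batch_size=32) / 1e9, 1))   # 42.9
```

Because the cache is linear in every factor, the GQA saving is exactly h/g regardless of batch size or sequence length.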

FlashAttention

FlashAttention rewrites the attention kernel to avoid materialising the O(n²) attention matrix in HBM:

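Problem: for seq_len=4096, the attention matrix is 4096×4096 ≈ 16.8M entries
         at fp16 = 32 MiB per head; for an illustrative 32-head, 80-layer model
         that is 32 MiB × 32 × 80 = 80 GiB (≈ 86 GB) per sequence just for
         intermediates, if every layer's matrix is kept for the backward pass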
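The score-matrix footprint is simple arithmetic, and mixing binary (MiB/GiB) and decimal (MB/GB) units is an easy way to get it wrong. A minimal sketch; the helper name `score_matrix_bytes` is hypothetical, and it assumes one full seq_len × seq_len matrix per head per layer, stored in fp16.

```python
def score_matrix_bytes(seq_len: int, num_heads: int = 1, num_layers: int = 1,
                       bytes_per_elem: int = 2, batch_size: int = 1) -> int:
    """Bytes to hold every full seq_len x seq_len attention matrix (one per head per layer)."""
    return seq_len * seq_len * bytes_per_elem * num_heads * num_layers * batch_size


per_head = score_matrix_bytes(4096)
total = score_matrix_bytes(4096, num_heads=32, num_layers=80)

print(per_head / 2**20)       # 32.0  (MiB per head)
print(total / 2**30)          # 80.0  (GiB in total)
print(round(total / 1e9, 1))  # 85.9  (decimal GB in total)
```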

FlashAttention solution:
  Process Q, K, V in tiles that fit in on-chip SRAM (much faster than HBM)
  Compute softmax incrementally using the online softmax trick
  Never write the full score matrix to HBM; recompute tiles in the backward pass instead
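The online softmax trick can be sketched for a single query row in plain Python. This is an illustration of the math only, under stated assumptions: one query vector, keys and values as lists of vectors, and a hypothetical `block_size` standing in for the SRAM tile size. The real FlashAttention kernel also tiles over queries and fuses all of this into one GPU kernel. The function names are hypothetical.

```python
import math
from typing import List, Sequence


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def naive_attention(q, keys, values) -> List[float]:
    """Standard attention for one query: materialise all scores, then softmax."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [dot(q, k) * scale for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    d_v = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(d_v)]


def tiled_attention(q, keys, values, block_size: int) -> List[float]:
    """Online-softmax attention for one query: visit K/V one tile at a time, keeping
    only a running max (m), running denominator (l) and running weighted sum (acc)."""
    scale = 1.0 / math.sqrt(len(q))
    m = float('-inf')               # largest score seen so far
    l = 0.0                         # sum of exp(score - m) so far
    acc = [0.0] * len(values[0])    # sum of exp(score - m) * v so far
    for start in range(0, len(keys), block_size):
        k_tile = keys[start:start + block_size]
        v_tile = values[start:start + block_size]
        scores = [dot(q, k) * scale for k in k_tile]  # only this tile's scores exist
        m_new = max(m, max(scores))
        correction = math.exp(m - m_new)              # rescale old stats to the new max
        l *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_tile):
            p = math.exp(s - m_new)
            l += p
            acc = [a + p * vj for a, vj in zip(acc, v)]
        m = m_new
    return [a / l for a in acc]  # normalise once at the end


q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5], [0.3, -0.2]]
values = [[1.0, 2.0], [3.0, -1.0], [0.0, 0.5], [2.0, 2.0], [-1.0, 1.0]]
print(naive_attention(q, keys, values))
print(tiled_attention(q, keys, values, block_size=2))  # same result up to float rounding
```

Only one tile of scores exists at any time, yet the output matches the naive version for every tile size, which is why FlashAttention is exact rather than an approximation.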

Result:
  Memory: O(n) instead of O(n²) for intermediates
  Speed: 2-4× faster wall-clock on A100 for typical seq lengths
  Exact: same output as standard attention (not an approximation)

Sliding Window Attention

Mistral 7B uses sliding window attention to handle long contexts efficiently:

Standard attention: position i attends to all positions 0..i
  Memory O(n²), slow for long sequences

Sliding window: position i attends to positions max(0, i-w)..i
  Window size w (Mistral 7B uses w = 4096)
  Cost O(n·w): linear in n for a fixed w
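The window rule translates directly into a boolean mask. A minimal plain-Python sketch (the helper names `sliding_window_mask` and `allowed_entries` are hypothetical) using the range from the text, max(0, i-w)..i inclusive, so each row allows at most w+1 positions and the total grows as O(n·w) rather than O(n²).

```python
from typing import List


def sliding_window_mask(n: int, w: int) -> List[List[bool]]:
    """mask[i][j] is True when position i may attend to position j: max(0, i - w) <= j <= i."""
    return [[max(0, i - w) <= j <= i for j in range(n)] for i in range(n)]


def allowed_entries(mask: List[List[bool]]) -> int:
    """Number of (query, key) pairs that actually get a score."""
    return sum(sum(row) for row in mask)


mask = sliding_window_mask(6, 2)
for row in mask:
    print(''.join('x' if ok else '.' for ok in row))
# x.....
# xx....
# xxx...
# .xxx..
# ..xxx.
# ...xxx
print(allowed_entries(mask))  # 15
```

Converted to a tensor with 1 = allowed and 0 = blocked, a mask like this has the convention the `masked_fill(mask == 0, ...)` line in the GQA code expects.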

For n=32K, w=4096:
  Standard: 32K × 32K = 1B entries per head
  Sliding:  32K × 4096 ≈ 131M entries, 8× fewer

Information beyond the window propagates through multiple layers:
  Layer 1: sees up to w positions back
  Layer 2: sees up to 2w positions back (via layer 1's representations)
  Layer k: sees up to k×w positions back
  Mistral 7B (32 layers, w=4096): up to 32 × 4096 ≈ 131K positions
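The layer-by-layer growth of the receptive field can be checked by brute force: start from one output position and repeatedly expand to every position it can read in the layer below. A minimal plain-Python sketch, assuming the window rule max(0, j-w)..j; the helper name `receptive_field_start` is hypothetical.

```python
def receptive_field_start(i: int, w: int, num_layers: int) -> int:
    """Earliest input position that can influence position i after `num_layers`
    sliding-window layers, found by propagating the reachable set layer by layer."""
    reachable = {i}
    for _ in range(num_layers):
        # each position j reads positions max(0, j - w)..j from the layer below
        reachable = {p for j in reachable for p in range(max(0, j - w), j + 1)}
    return min(reachable)


i, w = 100, 4
for k in (1, 2, 3):
    print(k, i - receptive_field_start(i, w, k))  # span grows by w per layer: 4, 8, 12
```

The span is k×w until it hits the start of the sequence. For Mistral 7B (w = 4096, 32 layers) the same rule gives 32 × 4096 = 131,072 positions, which is a theoretical reach; how much information actually survives that many hops is a separate question.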

Interview Answer

"Modern LLMs use three attention variants: MHA (one K/V head per Q head), MQA (one K/V head shared by all Q heads), and GQA (one K/V head per group of Q heads). GQA is the dominant choice — LLaMA 2 70B uses 8 KV heads for 64 Q heads, reducing the KV cache 8×. FlashAttention is a hardware-aware kernel that avoids materialising the O(n²) score matrix in HBM, giving O(n) memory for intermediates and 2-4× speedup — same output, not an approximation. For very long contexts, sliding window attention limits each position to a local window, with information propagating across layers."