Transformer Architecture Q&A · Lesson 23 of 23
Interview: What Do Different Attention Heads Learn?
Q: Why does multi-head attention use h heads instead of one large head?
One large head with d_model dimensions computes a single attention pattern — one way of weighting the sequence. Multiple smaller heads (each with d_model/h dimensions) compute h different attention patterns in parallel, allowing the model to simultaneously represent multiple types of relationships:
- Syntactic relationships (subject → verb agreement)
- Semantic proximity (related concepts)
- Positional patterns (attending to next/previous token)
- Long-range coreference
The total parameter count is equivalent — distributing d_model into h heads doesn't add parameters, it creates representational diversity within the same budget.
Q: If we double the number of heads (keeping d_model fixed), what happens?
Each head's dimension dₖ = d_model/h halves. The total parameter count (Q, K, V, O matrices) stays the same. Each head now represents a lower-dimensional subspace, which can hurt expressivity per head.
In practice, the ratio of h to d_model is tuned empirically. BERT-base uses h=12, d_model=768, dₖ=64. Modern LLMs typically keep dₖ in the 64-128 range and grow d_model (adding heads as needed) rather than cramming in many tiny heads.
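The parameter-count claim in the two answers above can be checked with a little arithmetic. This is a minimal sketch, assuming the Q, K, V and O projections are plain weight matrices with no biases; the function name mha_shape is made up for illustration.

```python
def mha_shape(d_model, num_heads):
    """Return (d_k, projection weight count) for one multi-head attention layer.

    Assumes d_model splits evenly across heads and ignores bias terms.
    """
    if d_model % num_heads != 0:
        raise ValueError("d_model must be divisible by num_heads")
    d_k = d_model // num_heads
    # Each head owns a d_model x d_k slice of the Q, K and V projections.
    qkv_weights = num_heads * 3 * d_model * d_k
    # The output projection maps the concatenated heads back to d_model.
    out_weights = (num_heads * d_k) * d_model
    return d_k, qkv_weights + out_weights


d_k_12, params_12 = mha_shape(768, 12)  # BERT-base sized layer
d_k_24, params_24 = mha_shape(768, 24)  # same width, twice the heads
print(d_k_12, params_12)
print(d_k_24, params_24)
```

Doubling h halves dₖ while the weight count stays at 4·d_model², which is why the choice of h is a question of how the budget is split rather than how large it is.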
Q: What is grouped-query attention (GQA) and why is it used?
In standard multi-head attention, each query head has its own K and V heads — h heads total for both Q and K/V. In GQA, multiple query heads share a single K/V head:
Standard MHA: h Q heads, h K heads, h V heads
GQA (g groups): h Q heads, g K heads, g V heads (g < h)
Each group of (h/g) Q heads shares one K/V pair
MQA (g=1): h Q heads, 1 K head, 1 V head
Extreme case: all Q heads share one K/V
GQA reduces the KV cache by a factor of h/g. The KV cache is the primary memory bottleneck during inference with long contexts and large batch sizes. LLaMA 2 70B uses GQA; Mistral 7B uses GQA with 8 KV heads for 32 Q heads (4× reduction).
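The h/g reduction can be made concrete with a back-of-the-envelope KV cache calculation. This is a sketch only: the 32 Q heads and 8 KV heads follow the ratio quoted above, while the layer count, head dimension, sequence length and fp16 storage are assumed values, and both function names are hypothetical.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len,
                   batch_size=1, bytes_per_value=2):
    """Bytes needed to cache K and V for every layer, KV head and position."""
    # The leading 2 accounts for storing both K and V.
    return (2 * num_layers * num_kv_heads * head_dim
            * seq_len * batch_size * bytes_per_value)


def kv_head_for_query_head(q_head, num_q_heads, num_kv_heads):
    """Index of the shared K/V head that a given query head reads from."""
    group_size = num_q_heads // num_kv_heads
    return q_head // group_size


# Assumed configuration: 32 layers, head_dim 128, 4096-token context, fp16.
mha_bytes = kv_cache_bytes(num_layers=32, num_kv_heads=32, head_dim=128, seq_len=4096)
gqa_bytes = kv_cache_bytes(num_layers=32, num_kv_heads=8, head_dim=128, seq_len=4096)
print(mha_bytes / 2**30, "GiB with one K/V head per query head")
print(gqa_bytes / 2**30, "GiB with 8 shared K/V heads")
print([kv_head_for_query_head(q, 32, 8) for q in range(8)])
```

Only the cached K and V shrink; the number of query heads, and therefore the attention computation per token, is unchanged.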
Q: Can you prune attention heads after training?
Yes. Studies (Michel et al., 2019; Voita et al., 2019) found:
- Some heads can be pruned with near-zero performance loss
- Different heads have different importance — syntactic and coreference heads tend to be more important
- Up to 70% of heads in some models can be removed with under 1% accuracy drop
Head importance is measured by gradient-based sensitivity or by masking heads and measuring loss change. Pruning reduces both KV cache size and attention computation.
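The masking approach can be sketched in a few lines. Everything model-specific here is an assumption: loss_fn stands in for "run the evaluation set with this head mask and return the loss", and toy_loss is a made-up stand-in so the example runs without a real model.

```python
def head_importance_by_masking(loss_fn, num_heads):
    """Score each head by how much the loss rises when that head alone is masked."""
    full_mask = [1.0] * num_heads
    baseline = loss_fn(full_mask)
    importance = []
    for head in range(num_heads):
        mask = list(full_mask)
        mask[head] = 0.0  # switch off exactly one head
        importance.append(loss_fn(mask) - baseline)
    return importance


def heads_to_prune(importance, tolerance):
    """Heads whose removal costs at most `tolerance` in loss."""
    return [head for head, delta in enumerate(importance) if delta <= tolerance]


# Hypothetical stand-in for a real evaluation loop: each head contributes a
# fixed amount to the loss when it is masked out.
TOY_CONTRIBUTIONS = [0.5, 0.0, 0.01, 0.3]


def toy_loss(mask):
    return 1.0 + sum(c * (1.0 - m) for c, m in zip(TOY_CONTRIBUTIONS, mask))


scores = head_importance_by_masking(toy_loss, num_heads=4)
print(scores)
print(heads_to_prune(scores, tolerance=0.02))
```

Masking one head at a time ignores interactions between heads, so in a real model the scores would normally be recomputed after each round of pruning.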
Q: How does increasing d_model vs increasing depth affect model capacity?
Increasing d_model (width):
More parameters per layer
Each token has richer representation
More attention head capacity, larger FFN
Better in low-data regimes (each layer is more expressive)
Increasing depth (number of layers):
More sequential computation
More levels of abstraction
Better at compositional reasoning
Requires more data to utilise effectively
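A rough parameter estimate shows how differently the two knobs move model size. This sketch assumes the common approximation of about 12·d_model² weights per layer (4·d_model² for attention plus 8·d_model² for an FFN with 4× expansion) and ignores embeddings, biases and layer norms; the function name is hypothetical.

```python
def approx_transformer_params(num_layers, d_model, ffn_multiplier=4):
    """Approximate weight count of a transformer stack, excluding embeddings."""
    attention = 4 * d_model * d_model            # Q, K, V and O projections
    ffn = 2 * ffn_multiplier * d_model * d_model  # up and down projections
    return num_layers * (attention + ffn)


base = approx_transformer_params(num_layers=12, d_model=768)
wider = approx_transformer_params(num_layers=12, d_model=1536)
deeper = approx_transformer_params(num_layers=24, d_model=768)
print(base, wider // base, deeper // base)
```

Under these assumptions, doubling d_model quadruples the weight count while doubling depth only doubles it. Plugging in 96 layers and d_model=12288 gives roughly 174 billion weights.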
Scaling laws (Hoffmann et al., Chinchilla):
Optimal parameter count and training-token count scale together
They fix the size/data balance, not the width/depth split
Neither purely wide nor purely deep models are optimal
In practice: both scale together (GPT-3: 96 layers, d_model=12288)
Q: What is FlashAttention?
FlashAttention (Dao et al., 2022) is an I/O-aware attention algorithm that avoids materialising the full n×n attention matrix in HBM (GPU high-bandwidth memory):
Standard attention:
1. Write scores S = Q·Kᵀ/√dₖ to HBM (O(n²) memory write)
2. Write softmax(S) to HBM (O(n²) memory write)
3. Read S, softmax(S) to compute output (O(n²) memory read)
Total HBM reads/writes: O(n²)
FlashAttention:
Tile Q, K, V into blocks in SRAM
Compute softmax incrementally without storing full S
Never write the n×n matrix to HBM
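Extra memory: O(n) instead of O(n²), a 5-20× memory reduction
HBM reads/writes: far fewer than standard attention, Θ(n²d²/M) for SRAM size M, but not O(n)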
Same output, same math: only the implementation changes
FlashAttention enables longer context windows and larger batch sizes without changing the model architecture.
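The incremental softmax is the part that usually needs a worked example. The sketch below handles a single query vector in pure Python and streams over K/V in blocks, keeping only a running max, a running normaliser and an output accumulator. It illustrates the rescaling trick only: the real algorithm also tiles Q and runs as a fused GPU kernel, and the function names and block_size here are assumptions.

```python
import math


def attention_reference(q, keys, values):
    """Standard softmax attention for one query; materialises every score."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


def attention_streaming(q, keys, values, block_size=2):
    """Same result, but only one block of scores exists at any time."""
    scale = 1.0 / math.sqrt(len(q))
    running_max = -math.inf
    running_sum = 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(keys), block_size):
        block_k = keys[start:start + block_size]
        block_v = values[start:start + block_size]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in block_k]
        new_max = max(running_max, max(scores))
        # Rescale what was accumulated under the old max (0.0 on the first block).
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, block_v):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vj for a, vj in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]


q = [0.1, -0.4, 0.7]
keys = [[0.2, 0.1, -0.3], [1.5, -0.2, 0.4], [-0.7, 0.9, 0.0],
        [0.3, 0.3, 0.3], [2.0, -1.0, 0.5]]
values = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [-1.0, 2.0], [3.0, -3.0]]
print(attention_reference(q, keys, values))
print(attention_streaming(q, keys, values, block_size=2))
```

The two functions agree up to floating-point rounding, which is the sense in which the output and the math are unchanged.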
Q: Describe the difference between sparse and dense attention.
Dense attention: every position attends to every other position
O(n²) complexity. Used by default in all standard transformers.
Sparse attention: each position attends to a SUBSET of positions
Examples:
Local window: attend to ±w neighbours only — O(n·w)
Strided: attend to every k-th token — O(n²/k)
Global + local: a few "global" tokens attend everywhere (Longformer)
Random + window + global: BigBird combines all three fixed patterns
Content-based: Reformer groups similar queries and keys with LSH hashing
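The costs quoted for the fixed patterns above can be checked by counting how many (query, key) pairs each mask allows. This is a small sketch with assumed values of n, w and k; count_attended and the mask helpers are hypothetical names, and the brute-force count is for illustration rather than efficiency.

```python
def count_attended(n, allowed):
    """Number of (query i, key j) pairs for which allowed(i, j) is true."""
    return sum(1 for i in range(n) for j in range(n) if allowed(i, j))


def dense(i, j):
    return True


def local_window(w):
    # Position i sees keys within w positions on either side.
    return lambda i, j: abs(i - j) <= w


def strided(k):
    # Position i sees every k-th key, aligned with its own position.
    return lambda i, j: (i - j) % k == 0


n = 16
print("dense:       ", count_attended(n, dense))
print("local w=2:   ", count_attended(n, local_window(2)))
print("strided k=4: ", count_attended(n, strided(4)))
```

For the window mask the count is slightly below n·(2w+1) because positions near the edges have fewer neighbours.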
Trade-off: cheaper computation vs potential loss of long-range signals.
For most tasks, local context dominates — sparse works fine.
For tasks requiring very long-range reasoning, dense is safer.
Interview Answer Template
"Multi-head attention uses h heads with dₖ = d_model/h dimensions each — same parameter count as one large head, but each head learns a different attention pattern. At large scale, grouped-query attention (GQA) reduces the KV cache by having multiple Q heads share one K/V head. Heads can be pruned post-training (studies show 50-70% removable with minimal loss). FlashAttention keeps complexity the same but reduces HBM memory from O(n²) to O(n) by tiling the computation in SRAM. Sparse attention (local windows, global tokens) trades expressivity for linear complexity."