Learnixo

Transformer Architecture Q&A · Lesson 23 of 23

Interview: What Do Different Attention Heads Learn?

Q: Why does multi-head attention use h heads instead of one large head?

One large head with d_model dimensions computes a single attention pattern — one way of weighting the sequence. Multiple smaller heads (each with d_model/h dimensions) compute h different attention patterns in parallel, allowing the model to simultaneously represent multiple types of relationships:

  • Syntactic relationships (subject → verb agreement)
  • Semantic proximity (related concepts)
  • Positional patterns (attending to next/previous token)
  • Long-range coreference

The total parameter count is the same: the Q, K, V and O projection matrices remain d_model × d_model regardless of h. Splitting d_model across h heads adds no parameters; it creates representational diversity within the same budget.
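To make the head split concrete, here is a minimal pure-Python sketch. It assumes unbatched input, no biases, no masking and random (untrained) weights; `matmul`, `softmax` and `multi_head_attention` are illustrative helpers, not a library API. Each head slices its own dₖ columns out of the shared Q, K, V projections and produces its own n×n attention pattern.

```python
import math
import random

def matmul(a, b):
    """Multiply an (n x k) by a (k x m) matrix, both stored as lists of rows."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def softmax(row):
    m = max(row)                      # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def multi_head_attention(x, w_q, w_k, w_v, w_o, h):
    """x is n x d_model; every w_* is d_model x d_model, whatever h is.

    Returns the n x d_model output and one n x n attention pattern per head.
    """
    d_model = len(x[0])
    assert d_model % h == 0, "d_model must be divisible by h"
    d_k = d_model // h
    q, k, v = matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)
    head_outputs, patterns = [], []
    for head in range(h):
        cols = range(head * d_k, (head + 1) * d_k)   # this head's slice of the projections
        qh = [[row[c] for c in cols] for row in q]
        kh = [[row[c] for c in cols] for row in k]
        vh = [[row[c] for c in cols] for row in v]
        scores = [[sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d_k) for kj in kh]
                  for qi in qh]
        weights = [softmax(r) for r in scores]       # this head's attention pattern
        patterns.append(weights)
        head_outputs.append(matmul(weights, vh))
    # Concatenate the h head outputs (each n x d_k) back to n x d_model, then apply W_O.
    concat = [sum((ho[i] for ho in head_outputs), []) for i in range(len(x))]
    return matmul(concat, w_o), patterns

random.seed(0)
n, d_model, h = 4, 8, 2

def rand(rows, cols):
    return [[random.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]

x = rand(n, d_model)
w_q, w_k, w_v, w_o = (rand(d_model, d_model) for _ in range(4))
out, patterns = multi_head_attention(x, w_q, w_k, w_v, w_o, h)
print(patterns[0][0])  # head 0's weights for token 0
print(patterns[1][0])  # head 1's weights for token 0: a different distribution
```

Even with random weights the two heads weight the sequence differently; training is what turns these independent patterns into the specialised relationships listed above.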


Q: If we double the number of heads (keeping d_model fixed), what happens?

Each head's dimension dₖ = d_model/h halves. The total parameter count (Q, K, V, O matrices) stays the same. Each head now represents a lower-dimensional subspace, which can hurt expressivity per head.

In practice, the balance between h and d_model is tuned empirically. BERT-base uses h=12, d_model=768, dₖ=64. Modern LLMs typically keep dₖ in the 64-128 range and add heads as d_model grows (GPT-3: 96 heads × dₖ=128 = 12288) rather than cramming in many tiny heads.
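The "dₖ halves, parameters stay the same" claim can be checked with a few lines of arithmetic. This sketch assumes bias-free Q, K, V, O projections; `head_config` is a hypothetical helper.

```python
def head_config(d_model, h):
    """Return (d_k, number of projection parameters) for multi-head attention.

    Assumes W_Q, W_K, W_V, W_O with no biases.
    """
    if d_model % h:
        raise ValueError("d_model must be divisible by h")
    d_k = d_model // h
    qkv = h * 3 * d_model * d_k   # each head owns a d_model x d_k slice of W_Q, W_K, W_V
    out = (h * d_k) * d_model     # W_O maps the concatenated heads back to d_model
    return d_k, qkv + out

for h in (12, 24):                # BERT-base head count, then doubled, with d_model = 768
    d_k, params = head_config(768, h)
    print(f"h={h} d_k={d_k} params={params:,}")
# h=12 d_k=64 params=2,359,296
# h=24 d_k=32 params=2,359,296
```

Because h · dₖ always equals d_model, the per-head slices simply partition the same 4 · d_model² weights.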


Q: What is grouped-query attention (GQA) and why is it used?

In standard multi-head attention, each query head has its own K and V heads — h heads total for both Q and K/V. In GQA, multiple query heads share a single K/V head:

Standard MHA:   h Q heads, h K heads, h V heads
GQA (g groups): h Q heads, g K heads, g V heads  (g < h)
  Each group of (h/g) Q heads shares one K/V pair

MQA (g=1):      h Q heads, 1 K head,  1 V head
  Extreme case: all Q heads share one K/V

The KV cache (the stored K and V vectors of every past token) is the primary memory bottleneck during inference with long contexts and large batch sizes, and GQA shrinks it by a factor of h/g. LLaMA 2 70B uses GQA with 8 KV heads for 64 Q heads (8× reduction); Mistral 7B uses 8 KV heads for 32 Q heads (4× reduction).
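The h/g saving is easy to check numerically. The sketch below assumes a Mistral-7B-like shape (32 layers, head dimension 128, 32 query heads), fp16 storage (2 bytes per value), a 4096-token context and batch size 1; `kv_cache_bytes` and `kv_head_for` are hypothetical helpers, not a library API. The second function shows which shared K/V head each query head reads under GQA.

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, batch, bytes_per_value=2):
    """K and V caches: one K and one V vector per layer, per KV head, per token."""
    return 2 * n_layers * n_kv_heads * d_head * seq_len * batch * bytes_per_value

def kv_head_for(q_head, h, g):
    """Index of the shared K/V head used by query head `q_head` in GQA with g groups."""
    assert h % g == 0, "h must be a multiple of g"
    return q_head // (h // g)   # consecutive blocks of h/g query heads share one K/V head

for name, g in [("MHA", 32), ("GQA", 8), ("MQA", 1)]:
    size = kv_cache_bytes(n_layers=32, n_kv_heads=g, d_head=128, seq_len=4096, batch=1)
    print(f"{name}: {size / 2**20:.0f} MiB")
# MHA: 2048 MiB
# GQA: 512 MiB
# MQA: 64 MiB

print([kv_head_for(i, h=32, g=8) for i in range(8)])  # [0, 0, 0, 0, 1, 1, 1, 1]
```

The cache grows linearly with both sequence length and batch size, which is why the h/g factor matters most in exactly the long-context, large-batch settings described above.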


Q: Can you prune attention heads after training?

Yes. Studies (Michel et al., 2019; Voita et al., 2019) found:

  • Many heads can be pruned at test time with near-zero performance loss
  • Heads differ in importance: Voita et al. found specialised positional, syntactic and rare-word heads to be the most important
  • In some models most heads can go: Voita et al. pruned 38 of 48 encoder heads in an English-Russian translation model at a cost of only 0.15 BLEU

Head importance is measured by gradient-based sensitivity or by masking heads and measuring loss change. Pruning reduces both KV cache size and attention computation.
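The masking approach can be sketched on a toy model. This assumes the model's output is simply the sum of per-head contribution vectors and the loss is squared error on a single example; `loss_with_mask`, `head_importance` and `prune` are hypothetical helpers. Gradient-based methods such as Michel et al.'s estimate the same "loss change when a head is masked" with gradients with respect to head-mask variables, averaged over a dataset, instead of ablating each head exactly.

```python
def loss_with_mask(head_outputs, target, mask):
    """Toy loss: squared error between the masked sum of head outputs and a target."""
    pred = [sum(m * out[j] for m, out in zip(mask, head_outputs))
            for j in range(len(target))]
    return sum((p - t) ** 2 for p, t in zip(pred, target))

def head_importance(head_outputs, target):
    """Importance of each head = increase in loss when that head alone is masked."""
    h = len(head_outputs)
    base = loss_with_mask(head_outputs, target, [1.0] * h)
    scores = []
    for i in range(h):
        mask = [1.0] * h
        mask[i] = 0.0                # zero out head i, keep all others
        scores.append(loss_with_mask(head_outputs, target, mask) - base)
    return scores

def prune(head_outputs, target, keep):
    """Return a 0/1 head mask that keeps the `keep` most important heads."""
    scores = head_importance(head_outputs, target)
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    kept = set(ranked[:keep])
    return [1.0 if i in kept else 0.0 for i in range(len(scores))]

head_outputs = [[1.0, 0.0], [0.0, 1.0], [0.01, 0.0], [0.0, -0.01]]  # heads 2, 3 barely contribute
target = [1.0, 1.0]
print(head_importance(head_outputs, target))
print(prune(head_outputs, target, keep=2))  # [1.0, 1.0, 0.0, 0.0]
```

Heads 2 and 3 get slightly negative scores (removing them even lowers this toy loss), mirroring the finding that some heads can be dropped at no cost.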


Q: How does increasing d_model vs increasing depth affect model capacity?

Increasing d_model (width):
  More parameters per layer
  Each token has richer representation
  More attention head capacity, larger FFN
  Better in low-data regimes (each layer is more expressive)

Increasing depth (number of layers):
  More sequential computation
  More levels of abstraction
  Better at compositional reasoning
  Requires more data to utilise effectively

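Scaling laws:
  Kaplan et al. (2020): loss depends only weakly on the width/depth ratio over a wide range of shapes
  Hoffmann et al. (2022, Chinchilla): for a fixed compute budget, parameter count and training tokens should scale roughly in proportion
  Neither purely wide nor purely deep models are optimal; extreme shapes do worse
  In practice: width and depth grow together (GPT-3: 96 layers, d_model=12288)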
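A back-of-the-envelope parameter count shows why width and depth scale the budget differently. The rule of thumb assumed here (about 12 · d_model² parameters per layer: 4 · d_model² for the attention projections plus 8 · d_model² for an FFN with hidden size 4 · d_model, ignoring embeddings, biases and LayerNorm) is a common approximation rather than a formula from this lesson; `approx_transformer_params` is a hypothetical helper.

```python
def approx_transformer_params(n_layers, d_model, ffn_mult=4):
    """Rough non-embedding parameter count of a decoder-only transformer.

    Per layer: 4*d_model^2 for W_Q, W_K, W_V, W_O plus 2*ffn_mult*d_model^2 for the
    two FFN matrices. Biases, LayerNorm and embeddings are ignored.
    """
    per_layer = 4 * d_model**2 + 2 * ffn_mult * d_model**2
    return n_layers * per_layer

base = approx_transformer_params(96, 12288)        # GPT-3 shape quoted above
wider = approx_transformer_params(96, 2 * 12288)   # double the width
deeper = approx_transformer_params(192, 12288)     # double the depth
print(f"{base / 1e9:.0f}B, wider x{wider / base:.0f}, deeper x{deeper / base:.0f}")
# 174B, wider x4, deeper x2
```

Parameters grow quadratically in d_model but only linearly in depth, so doubling width costs twice as much as doubling depth under this approximation; the GPT-3 estimate lands close to its published 175B.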

Q: What is FlashAttention?

FlashAttention (Dao et al., 2022) is an I/O-aware attention algorithm that avoids materialising the full n×n attention matrix in HBM (GPU high-bandwidth memory):

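Standard attention:
  1. Compute S = Q·Kᵀ/√dₖ and write it to HBM        (O(n²) write)
  2. Read S, compute P = softmax(S), write P to HBM  (O(n²) read + write)
  3. Read P and V to compute O = P·V                 (O(n²) read)
  Extra memory: O(n²); HBM traffic is dominated by the n×n matrices

FlashAttention:
  Load blocks of Q, K, V from HBM into on-chip SRAM
  Compute softmax incrementally (online softmax) without storing the full S
  Never write the n×n matrix to HBM; recompute the needed blocks in the backward pass
  Extra memory: O(n) instead of O(n²), with savings that grow with sequence length
  Far fewer HBM reads/writes, so faster in wall-clock time, although FLOPs are still O(n²)
  Same output, same math: only the implementation changes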

FlashAttention enables longer context windows and larger batch sizes without changing the model architecture.
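The core trick, an online softmax over blocks of keys, can be sketched in pure Python for a single query vector. This illustrates the algorithmic idea only: the real FlashAttention kernel also tiles Q, runs in GPU SRAM and has a fused backward pass, and `naive_attention` and `tiled_attention` are hypothetical names. The tiled version never holds more than one block of scores, yet returns the same output up to floating-point rounding.

```python
import math
import random

def naive_attention(q, K, V):
    """Reference: materialise every score for one query, softmax, then weight V."""
    d = len(q)
    s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K]
    m = max(s)
    p = [math.exp(x - m) for x in s]
    z = sum(p)
    return [sum(p[j] * V[j][c] for j in range(len(V))) / z for c in range(len(V[0]))]

def tiled_attention(q, K, V, block=2):
    """Same result, but keys/values are visited one block at a time (online softmax).

    Only running statistics survive between blocks: the max score so far, the
    running softmax denominator, and an unnormalised weighted sum of V rows.
    """
    d = len(q)
    running_max, denom = float("-inf"), 0.0
    acc = [0.0] * len(V[0])
    for start in range(0, len(K), block):
        k_blk, v_blk = K[start:start + block], V[start:start + block]
        s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in k_blk]
        new_max = max(running_max, max(s))
        scale = math.exp(running_max - new_max)   # rescale earlier blocks to the new max
        p = [math.exp(x - new_max) for x in s]
        denom = denom * scale + sum(p)
        acc = [a * scale + sum(p[j] * v_blk[j][c] for j in range(len(v_blk)))
               for c, a in enumerate(acc)]
        running_max = new_max
    return [a / denom for a in acc]

random.seed(1)
q = [random.gauss(0, 1) for _ in range(4)]
K = [[random.gauss(0, 1) for _ in range(4)] for _ in range(7)]
V = [[random.gauss(0, 1) for _ in range(3)] for _ in range(7)]
ref = naive_attention(q, K, V)
tiled = tiled_attention(q, K, V, block=3)   # blocks of 3, 3 and 1 keys
print(max(abs(a - b) for a, b in zip(ref, tiled)))  # tiny floating-point difference
```

On the first block `running_max` is minus infinity, so `scale` is exp(-inf) = 0.0 and the empty initial statistics contribute nothing.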


Q: Describe the difference between sparse and dense attention.

Dense attention: every position attends to every other position
  O(n²) complexity. Used by default in all standard transformers.

Sparse attention: each position attends to a SUBSET of positions
  Examples:
    Local window: attend to ±w neighbours only — O(n·w)
    Strided: attend to every k-th token — O(n²/k)
    Global + local: a few "global" tokens attend to, and are attended by, every position (Longformer; BigBird also adds random links)
    Content-based: Reformer (LSH buckets), Routing Transformer (clustering) dynamically select keys

Trade-off: cheaper computation vs potential loss of long-range signals.
  For many tasks, local context dominates, so sparse attention works well.
  For tasks requiring very long-range reasoning, dense is safer.
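To see where the complexity figures come from, the sketch below counts how many query-key scores each pattern computes for a 256-token sequence. The predicates `dense`, `local`, `strided` and `global_local` are hypothetical helpers; the strided rule (offset from i is a multiple of k) is one common reading of "every k-th token" and is an assumption here.

```python
def count_pairs(n, allowed):
    """How many (query i, key j) scores a pattern computes for a length-n sequence."""
    return sum(1 for i in range(n) for j in range(n) if allowed(i, j))

def dense(i, j):
    return True                       # every position scores every other: n^2 pairs

def local(w):
    # Sliding window: attend to neighbours within distance w (O(n*w) pairs).
    return lambda i, j: abs(i - j) <= w

def strided(k):
    # Attend to positions whose offset from i is a multiple of k (O(n^2/k) pairs).
    return lambda i, j: (i - j) % k == 0

def global_local(w, global_ids):
    # Longformer-style: a local window plus a few global tokens that attend to,
    # and are attended by, every position.
    g = set(global_ids)
    return lambda i, j: abs(i - j) <= w or i in g or j in g

n = 256
for name, pattern in [("dense", dense), ("local w=4", local(4)),
                      ("strided k=16", strided(16)),
                      ("global+local", global_local(4, [0]))]:
    print(f"{name:>13}: {count_pairs(n, pattern):6d} pairs")
```

The local window computes about n·(2w+1) scores instead of n², and a single global token adds only about 2n more, which is why these patterns scale to much longer sequences.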

Interview Answer Template

"Multi-head attention uses h heads with dₖ = d_model/h dimensions each: the same parameter count as one large head, but each head learns a different attention pattern. At large scale, grouped-query attention (GQA) shrinks the KV cache by having multiple Q heads share one K/V head. Heads can be pruned post-training; in some models most heads can be removed with minimal loss. FlashAttention keeps the O(n²) compute but cuts extra memory from O(n²) to O(n), and reduces HBM traffic, by tiling the computation in SRAM. Sparse attention (local windows, global tokens) trades some long-range expressivity for near-linear complexity."