AI Systems · Intermediate

Interview Q&A: Attention Heads and Scaling

Common interview questions on multi-head attention design choices, head pruning, grouped-query attention, and how scaling affects head count and model capacity.

Asma Hafeez Khan · May 16, 2026 · 5 min read
Tags: Transformers, Multi-Head Attention, GQA, Scaling, Interview

Q: Why does multi-head attention use h heads instead of one large head?

One large head with d_model dimensions computes a single attention pattern — one way of weighting the sequence. Multiple smaller heads (each with d_model/h dimensions) compute h different attention patterns in parallel, allowing the model to simultaneously represent multiple types of relationships:

  • Syntactic relationships (subject → verb agreement)
  • Semantic proximity (related concepts)
  • Positional patterns (attending to next/previous token)
  • Long-range coreference

The total parameter count is equivalent — distributing d_model into h heads doesn't add parameters, it creates representational diversity within the same budget.
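
The parameter claim can be checked with a few lines of arithmetic. The sketch below is illustrative only: it assumes bias-free Q, K, V and output projections and a head dimension of d_model/h, and the function name `mha_param_count` is made up for this example.

```python
def mha_param_count(d_model: int, num_heads: int) -> int:
    """Weights in the Q, K, V and output projections (biases ignored)."""
    if d_model % num_heads != 0:
        raise ValueError("d_model must be divisible by num_heads")
    d_head = d_model // num_heads
    # Each head owns a (d_model x d_head) slice of W_Q, W_K and W_V.
    qkv = 3 * num_heads * d_model * d_head
    # The output projection maps the concatenated heads back to d_model.
    out = (num_heads * d_head) * d_model
    return qkv + out


for h in (1, 12, 24):
    # Head dimension shrinks as h grows, but the total stays at 4 * d_model**2.
    print(f"h={h:2d}  d_head={768 // h:3d}  params={mha_param_count(768, h):,}")
```

With d_model=768 every row reports the same total, so changing h only changes how the same weights are partitioned into subspaces.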


Q: If we double the number of heads (keeping d_model fixed), what happens?

Each head's dimension dₖ = d_model/h halves. The total parameter count (Q, K, V, O matrices) stays the same. Each head now represents a lower-dimensional subspace, which can hurt expressivity per head.

In practice, the head count for a given d_model is tuned empirically. BERT-base uses h=12, d_model=768, dₖ=64. Modern LLMs keep dₖ in a similar range (often 64-128) and increase h along with d_model rather than cramming many tiny heads into a fixed width.


Q: What is grouped-query attention (GQA) and why is it used?

In standard multi-head attention, each query head has its own K and V heads — h heads total for both Q and K/V. In GQA, multiple query heads share a single K/V head:

Standard MHA:   h Q heads, h K heads, h V heads
GQA (g groups): h Q heads, g K heads, g V heads  (g < h)
  Each group of (h/g) Q heads shares one K/V pair

MQA (g=1):      h Q heads, 1 K head,  1 V head
  Extreme case: all Q heads share one K/V

GQA shrinks the KV cache by a factor of h/g. The KV cache is the primary memory bottleneck during inference with long contexts and large batch sizes. LLaMA 2 70B uses GQA; Mistral 7B uses GQA with 8 KV heads for 32 Q heads (4× reduction).
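
To make the h/g reduction concrete, here is a rough sizing sketch. The 32 Q heads and 8 KV heads come from the Mistral 7B example above. The layer count (32), head dimension (128), sequence length, batch size and fp16 storage are assumptions chosen for illustration, and `kv_cache_bytes` is a hypothetical helper rather than any library's API.

```python
def kv_cache_bytes(num_layers: int, num_kv_heads: int, d_head: int,
                   seq_len: int, batch_size: int, bytes_per_value: int = 2) -> int:
    """Size of the cached K and V tensors across all layers."""
    # Per layer: one K and one V tensor of shape [batch, kv_heads, seq_len, d_head].
    per_layer = 2 * batch_size * num_kv_heads * seq_len * d_head
    return num_layers * per_layer * bytes_per_value


# Assumed configuration (illustrative): 32 layers, d_head=128, 8192 tokens, batch 1, fp16.
common = dict(num_layers=32, d_head=128, seq_len=8192, batch_size=1)
mha_bytes = kv_cache_bytes(num_kv_heads=32, **common)  # one K/V head per Q head
gqa_bytes = kv_cache_bytes(num_kv_heads=8, **common)   # 4 Q heads share each K/V head
mqa_bytes = kv_cache_bytes(num_kv_heads=1, **common)   # all Q heads share one K/V head

for name, size in (("MHA", mha_bytes), ("GQA", gqa_bytes), ("MQA", mqa_bytes)):
    print(f"{name}: {size / 2**20:.0f} MiB")
```

Only `num_kv_heads` changes between the three calls, which is why the saving is exactly h/g. The number of Q heads does not appear in the formula at all.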


Q: Can you prune attention heads after training?

Yes. Studies (Michel et al., 2019; Voita et al., 2019) found:

  • Some heads can be pruned with near-zero performance loss
  • Different heads have different importance: Voita et al. found that positional, syntactic, and rare-word heads were the last to be pruned
  • Up to 70% of heads in some models can be removed with under 1% accuracy drop

Head importance is measured by gradient-based sensitivity or by masking heads and measuring loss change. Pruning reduces both KV cache size and attention computation.
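
The masking approach can be sketched as a small loop. This is a toy, not the procedure from either paper: `toy_loss` and `TOY_CONTRIBUTION` are hypothetical stand-ins for evaluating a real model on a validation set, and masking one head at a time ignores interactions between heads.

```python
from typing import Callable, List, Sequence


def head_importance_by_masking(loss_fn: Callable[[Sequence[int]], float],
                               num_heads: int) -> List[float]:
    """Loss increase when each head is masked on its own (1 = keep, 0 = mask)."""
    keep_all = [1] * num_heads
    base_loss = loss_fn(keep_all)
    scores = []
    for i in range(num_heads):
        mask = list(keep_all)
        mask[i] = 0
        scores.append(loss_fn(mask) - base_loss)
    return scores


def heads_to_prune(scores: Sequence[float], fraction: float) -> List[int]:
    """Indices of the least important heads, up to the requested fraction."""
    k = int(len(scores) * fraction)
    order = sorted(range(len(scores)), key=lambda i: scores[i])
    return sorted(order[:k])


# Hypothetical per-head contribution to the loss; a real setup would run the model.
TOY_CONTRIBUTION = [0.50, 0.01, 0.30, 0.00, 0.02, 0.25, 0.01, 0.00]


def toy_loss(mask: Sequence[int]) -> float:
    return 1.0 + sum(c for c, keep in zip(TOY_CONTRIBUTION, mask) if keep == 0)


scores = head_importance_by_masking(toy_loss, len(TOY_CONTRIBUTION))
pruned = heads_to_prune(scores, fraction=0.5)
print("prune heads:", pruned)
```

In a real model the memory and compute savings only materialise once the pruned heads' weight slices are actually removed, not merely multiplied by zero.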


Q: How does increasing d_model vs increasing depth affect model capacity?

Increasing d_model (width):
  More parameters per layer
  Each token has richer representation
  More attention head capacity, larger FFN
  Better in low-data regimes (each layer is more expressive)

Increasing depth (number of layers):
  More sequential computation
  More levels of abstraction
  Better at compositional reasoning
  Requires more data to utilise effectively

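Scaling laws:
  Hoffmann et al. (Chinchilla): for a fixed compute budget, optimal parameter count and training-token count scale together
  Kaplan et al. (2020): at a fixed parameter count, loss depends only weakly on the width/depth ratio
  Neither purely wide nor purely deep models are optimal
  In practice: both scale together (GPT-3: 96 layers, d_model=12288)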
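
A back-of-the-envelope count shows how width and depth enter the parameter budget differently. The sketch below assumes a standard block with four d_model × d_model attention projections and an FFN with 4× expansion, and it ignores embeddings, biases and layer norms. `approx_transformer_params` is a name invented for this example.

```python
def approx_transformer_params(num_layers: int, d_model: int, ffn_mult: int = 4) -> int:
    """Rough non-embedding parameter count for a stack of transformer blocks."""
    attn = 4 * d_model * d_model            # W_Q, W_K, W_V, W_O
    ffn = 2 * ffn_mult * d_model * d_model  # up-projection and down-projection
    return num_layers * (attn + ffn)


gpt3 = approx_transformer_params(num_layers=96, d_model=12288)
print(f"GPT-3 shape: ~{gpt3 / 1e9:.0f}B parameters")

# Depth enters linearly, width quadratically.
print(approx_transformer_params(192, 12288) / gpt3)  # double the layers
print(approx_transformer_params(96, 24576) / gpt3)   # double d_model
```

Under these assumptions the GPT-3 shape lands close to its quoted 175B parameters, and doubling d_model costs four times as many parameters as the baseline while doubling depth costs twice as many.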

Q: What is FlashAttention?

FlashAttention (Dao et al., 2022) is an I/O-aware attention algorithm that avoids materialising the full n×n attention matrix in HBM (GPU high-bandwidth memory):

Standard attention:
  1. Write scores S = Q·Kᵀ/√dₖ to HBM  (O(n²) memory write)
  2. Write softmax(S) to HBM             (O(n²) memory write)
  3. Read S, softmax(S) to compute output (O(n²) memory read)
  Total HBM reads/writes: O(n²)

FlashAttention:
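  Load blocks of Q, K, V from HBM into on-chip SRAM
  Compute softmax incrementally (running max and running sum) without storing full S
  Never write the n×n matrix to HBM
  Extra memory: O(n) instead of O(n²) (5-20× memory reduction)
  HBM reads/writes: O(n²·d²/M) for SRAM size M, far fewer than standard attention for typical d and M
  Same output, same math (exact attention); only the implementation changes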

FlashAttention enables longer context windows and larger batch sizes without changing the model architecture.
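
The incremental softmax is the part that tends to come up in follow-up questions. The sketch below shows the idea for a single query in pure Python: keys and values are processed block by block, keeping only a running maximum, a running sum and an output accumulator. It is a minimal illustration of the rescaling trick, not the GPU kernel. The function names and the toy data are made up for the example.

```python
import math


def _scores(q, keys):
    d_k = len(q)
    return [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in keys]


def attention_reference(q, keys, values):
    """Standard attention for one query: materialises all scores at once."""
    scores = _scores(q, keys)
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


def attention_tiled(q, keys, values, block_size):
    """Same result, but only one block of scores exists at any time."""
    d_v = len(values[0])
    running_max = float("-inf")
    running_sum = 0.0
    acc = [0.0] * d_v
    for start in range(0, len(keys), block_size):
        block_scores = _scores(q, keys[start:start + block_size])
        new_max = max(running_max, max(block_scores))
        # Rescale earlier partial results to the new maximum (0.0 on the first block).
        scale = math.exp(running_max - new_max)
        running_sum *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(block_scores, values[start:start + block_size]):
            w = math.exp(s - new_max)
            running_sum += w
            for j in range(d_v):
                acc[j] += w * v[j]
        running_max = new_max
    return [a / running_sum for a in acc]


q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0], [-1.0, 0.5], [0.3, 0.3], [4.0, -3.0]]
values = [[1.0, 2.0], [0.0, -1.0], [3.0, 0.5], [-2.0, 1.0], [0.5, 0.5], [1.5, -0.5]]
print(attention_reference(q, keys, values))
print(attention_tiled(q, keys, values, block_size=4))
```

The two printed vectors agree up to floating-point rounding for any block size, which is what "same output, same math" means in practice.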


Q: Describe the difference between sparse and dense attention.

Dense attention: every position attends to every other position
  O(n²) complexity. Used by default in all standard transformers.

Sparse attention: each position attends to a SUBSET of positions
  Examples:
    Local window: attend to ±w neighbours only — O(n·w)
    Strided: attend to every k-th token — O(n²/k)
    Global + local: a few "global" tokens attend everywhere (Longformer)
    Random + window + global: BigBird (fixed pattern, not learned)
    Content-based: Reformer buckets similar queries and keys with LSH

Trade-off: cheaper computation vs potential loss of long-range signals.
  For many tasks, local context carries most of the signal, so sparse patterns lose little.
  For tasks requiring very long-range reasoning, dense is safer.
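
The cost difference between dense and local-window attention can be counted directly. The sketch below counts (query, key) pairs for a window of ±w neighbours plus the position itself. The sequence length 4096 and window 128 are arbitrary illustration values, and the helper names are made up for this example.

```python
def local_window_mask(n: int, w: int):
    """Boolean n x n mask: True where query i may attend to key j."""
    return [[abs(i - j) <= w for j in range(n)] for i in range(n)]


def local_window_pairs(n: int, w: int) -> int:
    """Number of attended pairs without building the mask."""
    # Each position sees itself plus up to w neighbours on each side,
    # clipped at the sequence boundaries.
    return sum(min(n - 1, i + w) - max(0, i - w) + 1 for i in range(n))


def dense_pairs(n: int) -> int:
    return n * n


n, w = 4096, 128
sparse = local_window_pairs(n, w)
dense = dense_pairs(n)
print(f"dense:  {dense:,} pairs")
print(f"window: {sparse:,} pairs ({dense / sparse:.1f}x fewer)")
```

The window count grows as roughly n·(2w+1), so doubling n doubles the cost, while the dense count quadruples.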

Interview Answer Template

"Multi-head attention uses h heads with dₖ = d_model/h dimensions each: same parameter count as one large head, but each head can learn a different attention pattern. At large scale, grouped-query attention (GQA) shrinks the KV cache by a factor of h/g by having each group of Q heads share one K/V head. Heads can be pruned post-training (studies report up to about 70% removable in some models with minimal loss). FlashAttention computes exactly the same attention but avoids writing the n×n score matrix to HBM by tiling the computation in SRAM, cutting attention memory from O(n²) to O(n). Sparse attention (local windows, global tokens) trades some long-range expressivity for near-linear cost."
