Learnixo


Multi-Head Attention: Why Multiple Heads?

Why Multiple Heads?

Single-head attention computes one attention pattern per query: a single way of weighting the input positions. But a token might need to attend to different positions for different reasons at the same time:

"The bank by the river overflowed"
      ↑
Parsing "bank" (an idealized illustration; real heads are rarely this clean):
  Head 1 (syntactic):   attends to "The" (its determiner) and "overflowed" (its verb)
  Head 2 (semantic):    attends to "river" (disambiguates the sense of "bank")
  Head 3 (positional):  attends to neighboring words
  Head 4 (coreference): would link "bank" to later mentions such as "it" if the text continued

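Multiple heads allow the model to jointly attend to information from different representation subspaces at different positions. With a single head, the softmax-weighted average mixes all of these signals into one vector per position; separate heads keep them apart until W_O combines them.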
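To make "one attention pattern per head" concrete, here is a minimal sketch in NumPy (an assumption: the lesson's own code uses PyTorch; NumPy just keeps the example short). Two heads, each with its own randomly initialised query/key projections, score the same query token against the same keys and produce different attention distributions. The weights are random, so the patterns carry no linguistic meaning; in a trained model, training shapes each head's projections toward a useful relation.

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    # Numerically stable softmax over the last axis
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
seq_len, d_model, d_k = 6, 16, 4
X = rng.standard_normal((seq_len, d_model))  # stand-in token embeddings (random, not real words)

patterns = []
for head in range(2):
    # Each head has its own (randomly initialised, illustrative) projections
    W_q = rng.standard_normal((d_model, d_k))
    W_k = rng.standard_normal((d_model, d_k))
    q = X[1] @ W_q                  # query vector for token 1 (playing the role of "bank")
    K = X @ W_k                     # key vectors for every token
    patterns.append(softmax(K @ q / np.sqrt(d_k)))  # one distribution over the 6 positions

for i, p in enumerate(patterns):
    print(f"head {i}:", np.round(p, 2))
```

Each printed row sums to 1, but the two heads spread their weight differently over the same six positions.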


How Multi-Head Attention Works

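$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W_O$$

where $\mathrm{head}_i = \mathrm{Attention}(Q W_i^Q,\ K W_i^K,\ V W_i^V)$ and $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(Q K^\top / \sqrt{d_k}\right) V$.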

Each head i has its own projection matrices (W_i^Q, W_i^K, W_i^V) and computes scaled dot-product attention independently in its own dₖ-dimensional subspace. The h head outputs are concatenated back to h·dᵥ = d_model dimensions and projected through W_O.
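The formula can be transcribed almost literally as a per-head loop. This NumPy sketch assumes unbatched self-attention inputs of shape (seq_len, d_model), no masking, and randomly initialised weights; the function names are illustrative, not from any library.

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    # Numerically stable softmax over the last axis
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attention(Q, K, V):
    # Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V
    d_k = Q.shape[-1]
    return softmax(Q @ K.T / np.sqrt(d_k)) @ V

def multi_head(Q, K, V, W_q, W_k, W_v, W_o):
    # W_q, W_k, W_v: lists of h per-head matrices, each (d_model, d_k)
    # W_o: (h * d_v, d_model)
    heads = [attention(Q @ wq, K @ wk, V @ wv)       # each head: (seq_len, d_v)
             for wq, wk, wv in zip(W_q, W_k, W_v)]
    return np.concatenate(heads, axis=-1) @ W_o      # Concat(head_1, ..., head_h) · W_O

rng = np.random.default_rng(0)
seq_len, d_model, h = 5, 512, 8
d_k = d_model // h                                     # 64, as in the original Transformer
X = rng.standard_normal((seq_len, d_model))            # self-attention: Q = K = V = X
scale = 1 / np.sqrt(d_model)                           # keeps random projections well-scaled
W_q = [rng.standard_normal((d_model, d_k)) * scale for _ in range(h)]
W_k = [rng.standard_normal((d_model, d_k)) * scale for _ in range(h)]
W_v = [rng.standard_normal((d_model, d_k)) * scale for _ in range(h)]
W_o = rng.standard_normal((h * d_k, d_model)) * scale

out = multi_head(X, X, X, W_q, W_k, W_v, W_o)
print(out.shape)  # (5, 512)
```

Real implementations fuse the h per-head matrices into one (d_model, d_model) matrix and run all heads with a single batched matmul, which is what the PyTorch class below does.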

Dimensions (original Transformer, d_model=512, h=8 heads):
  dₖ = dᵥ = d_model / h = 512 / 8 = 64 per head

  Each W_i^Q, W_i^K: (512, 64)
  Each W_i^V:        (512, 64)
  W_O:               (512, 512)

  Parameters per head: 3 × (512 × 64) = 98,304
  Total for all heads + W_O: 8 × 98,304 + 512² = 1,048,576 (weights only; biases add 4 × 512 more if used)

The total equals that of a single head with dₖ = d_model = 512 (3 × 512² + 512² = 1,048,576). Multiple heads don't add parameters; they split the same projection matrices into h parallel 64-dimensional subspaces.
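The arithmetic above can be checked with a small helper. `mha_params` is a hypothetical function written for this check; it counts weights exactly as in the text, and the optional bias term assumes one bias vector per projection (Q, K, V, O), which is what `nn.Linear` adds by default. Varying h shows the total does not depend on the number of heads.

```python
def mha_params(d_model: int, h: int, bias: bool = False) -> int:
    # d_k = d_v = d_model / h, so each head's Q, K, V projections are (d_model, d_k)
    assert d_model % h == 0
    d_k = d_model // h
    per_head = 3 * d_model * d_k
    # W_O maps the concatenated h * d_v dimensions back to d_model
    total = h * per_head + (h * d_k) * d_model
    if bias:
        total += 4 * d_model  # one bias vector each for Q, K, V and O projections
    return total

for h in (1, 2, 4, 8, 16):
    print(f"h={h}: {mha_params(512, h)}")  # every line prints 1048576
```

The h in `h * per_head` cancels the 1/h in `d_k`, which is the whole reason the count is independent of h.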


Code

Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # One linear layer for all heads combined (more efficient than separate)
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
        batch, seq_len, _ = x.shape
        x = x.view(batch, seq_len, self.num_heads, self.d_k)
        return x.transpose(1, 2)

    def forward(
        self,
        Q: torch.Tensor,
        K: torch.Tensor,
        V: torch.Tensor,
        mask: torch.Tensor = None
    ) -> torch.Tensor:
        batch = Q.size(0)

        # Project and split into heads
        Q = self.split_heads(self.W_q(Q))  # (batch, heads, seq, d_k)
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))

        # Scaled dot-product attention (all heads in parallel via batched matmul)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # mask must broadcast to (batch, heads, seq_q, seq_k); positions where mask == 0 are blocked
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, V)  # (batch, heads, seq, d_k)

        # Concatenate heads
        out = out.transpose(1, 2).contiguous()
        out = out.view(batch, -1, self.num_heads * self.d_k)  # (batch, seq, d_model)

        return self.W_o(out)
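Why is one fused linear layer equivalent to h separate per-head projections? The sketch below (NumPy, random weights, shapes chosen arbitrarily) compares the fused matmul followed by the same reshape-and-transpose that `split_heads` performs against h separate matmuls on column slices of the same weight, then checks that the concatenation step undoes the split. One layout detail: `nn.Linear` stores its weight as (out_features, in_features) and computes `x @ weight.T + bias`, so in the class above head i corresponds to rows i·d_k to (i+1)·d_k of `W_q.weight`; the NumPy matrix here is (in, out), so the same slice appears as columns.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq_len, d_model, h = 2, 5, 32, 4
d_k = d_model // h

x = rng.standard_normal((batch, seq_len, d_model))
W_fused = rng.standard_normal((d_model, d_model))  # stands in for one (d_model, d_model) projection

# Fused path: one matmul, then split into heads (reshape + transpose, like split_heads)
fused = (x @ W_fused).reshape(batch, seq_len, h, d_k).transpose(0, 2, 1, 3)  # (batch, h, seq, d_k)

# Separate path: head i uses only columns [i*d_k, (i+1)*d_k) of the fused weight
separate = np.stack(
    [x @ W_fused[:, i * d_k:(i + 1) * d_k] for i in range(h)], axis=1
)  # (batch, h, seq, d_k)

# Concatenation undoes the split: transpose back, then merge the (h, d_k) axes
merged = fused.transpose(0, 2, 1, 3).reshape(batch, seq_len, d_model)

print(np.allclose(fused, separate))      # True
print(np.allclose(merged, x @ W_fused))  # True
```

The fused form does the same arithmetic as h small matmuls but as one large one, which GPUs execute more efficiently.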

What Different Heads Learn

Empirical findings from attention analysis (Voita et al., 2019; Clark et al., 2019):

  • Syntactic heads: attend along grammatical relations (e.g., subject-verb, noun-adjective, verb-object)
  • Positional heads: attend to a fixed relative position, usually the previous or next token
  • Rare-word heads: attend disproportionately to the least frequent tokens in the sentence
  • Coreference heads: attend from a mention to its antecedent (e.g., from a pronoun to the noun it refers to)

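How might one decide that a head is "positional"? One simple diagnostic (an assumption here, not necessarily the exact criterion used in the cited papers) is the fraction of query positions whose highest-weighted key lies at a fixed offset. `positional_fraction` is a hypothetical helper; the hand-built matrix below mimics a previous-token head, whereas in practice you would pass in attention weights extracted from a trained layer.

```python
import numpy as np

def positional_fraction(attn: np.ndarray, offset: int) -> float:
    """Fraction of queries whose max-weight key is at position query + offset.

    attn: (seq_len, seq_len) attention weights for one head; rows are queries.
    Queries for which query + offset falls outside the sequence are skipped.
    """
    seq_len = attn.shape[0]
    hits, total = 0, 0
    for q in range(seq_len):
        target = q + offset
        if 0 <= target < seq_len:
            total += 1
            hits += int(np.argmax(attn[q]) == target)
    return hits / total if total else 0.0

# Toy "previous-token head": each query puts most of its weight on the token before it
seq_len = 6
attn = np.full((seq_len, seq_len), 0.02)
for q in range(1, seq_len):
    attn[q, q - 1] = 0.9
attn[0, 0] = 0.9                           # the first token has no predecessor
attn = attn / attn.sum(axis=1, keepdims=True)  # normalise rows like a softmax output

print(positional_fraction(attn, -1))  # 1.0
print(positional_fraction(attn, +1))  # 0.0
```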
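Not all heads are equally important. Voita et al. found that a few specialized heads do most of the work: in their English-Russian translation model, pruning 38 of 48 encoder heads cost only 0.15 BLEU. The heads that survive pruning tend to be the specialized ones, while many of the rest are largely redundant.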
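Mechanically, pruning a head means removing its slice of the concatenated output before W_O. A minimal NumPy sketch, assuming fused (d_model, d_model) weights as in the class above and a hypothetical 0/1 gate per head; the cited work uses more elaborate methods to decide which heads to drop.

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    # Numerically stable softmax over the last axis
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def mha_with_gates(x, W_q, W_k, W_v, W_o, h, gates):
    # x: (seq, d_model); fused weights: (d_model, d_model); gates: (h,) values in {0, 1}
    seq, d_model = x.shape
    d_k = d_model // h

    def split(t):
        # (seq, d_model) -> (h, seq, d_k)
        return t.reshape(seq, h, d_k).transpose(1, 0, 2)

    Q, K, V = split(x @ W_q), split(x @ W_k), split(x @ W_v)
    attn = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(d_k))  # (h, seq, seq)
    heads = attn @ V                                          # (h, seq, d_k)
    heads = heads * gates[:, None, None]                      # gate 0 removes that head's output
    concat = heads.transpose(1, 0, 2).reshape(seq, d_model)   # Concat(head_1, ..., head_h)
    return concat @ W_o

rng = np.random.default_rng(0)
seq, d_model, h = 5, 32, 4
x = rng.standard_normal((seq, d_model))
W_q, W_k, W_v, W_o = (rng.standard_normal((d_model, d_model)) / np.sqrt(d_model) for _ in range(4))

full = mha_with_gates(x, W_q, W_k, W_v, W_o, h, np.ones(h))
pruned = mha_with_gates(x, W_q, W_k, W_v, W_o, h, np.array([1.0, 1.0, 0.0, 1.0]))  # drop head 2
print(np.linalg.norm(full - pruned) / np.linalg.norm(full))  # relative change in the output
```

With random weights every head contributes roughly equally, so the printed relative change says nothing about importance; the same measurement only becomes informative on a trained model.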


Interview Answer

"Multi-head attention runs h parallel attention operations in separate subspaces (each dₖ = d_model/h dimensional), then concatenates and projects the results. This allows the model to attend to different aspects of the input at the same time: one head might focus on syntactic relationships, another on semantic similarity, another on positional patterns. The total parameter count is the same as a single head with dₖ = d_model, but the parallel heads add representational diversity: a single head produces one weighted average per position, so it cannot easily track several relations at once."