Multi-Head Attention
Why multi-head attention uses parallel heads, how heads are split and concatenated, what different heads learn, and the full architecture with code.
Why Multiple Heads?
Single-head attention computes one attention pattern — one way of weighting the input. But a token might need to attend to different positions for different reasons simultaneously:
"The bank by the river overflowed"
      ↑
One hypothetical division of labor when encoding "bank":
Head 1 (syntactic): attends to "The" (determiner) and "overflowed" (verb agreement)
Head 2 (semantic): attends to "river" (disambiguates the sense)
Head 3 (positional): attends to neighboring words
Head 4 (coreference): attends to pronouns or related nouns
Multiple heads allow the model to jointly attend to information from different representation subspaces at different positions.
How Multi-Head Attention Works
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) · W_O
where head_i = Attention(Q·W_i^Q, K·W_i^K, V·W_i^V)
Each head has its own Q, K, V projection matrices (W_i^Q, W_i^K, W_i^V) and computes scaled dot-product attention independently. The outputs are concatenated and projected through W_O.
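The formula can be checked numerically. The following is a minimal NumPy sketch, not a reference implementation: it computes each head with its own projection matrices, concatenates the results, and applies W_O, then confirms that slicing one fused (d_model, d_model) projection into h column blocks gives the same output. The function names, the random weights, and the toy sizes (d_model=8, h=2) are assumptions made for illustration; NumPy is used only to keep the sketch dependency-light.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(q, k, v):
    # Scaled dot-product attention over the last two axes.
    d_k = q.shape[-1]
    scores = q @ np.swapaxes(k, -2, -1) / np.sqrt(d_k)
    return softmax(scores, axis=-1) @ v

def multi_head_per_head(x, W_q, W_k, W_v, W_o):
    # W_q, W_k, W_v: lists of h matrices, each (d_model, d_k).
    heads = [attention(x @ wq, x @ wk, x @ wv)
             for wq, wk, wv in zip(W_q, W_k, W_v)]
    return np.concatenate(heads, axis=-1) @ W_o

def multi_head_fused(x, Wq, Wk, Wv, W_o, h):
    # Wq, Wk, Wv: single (d_model, d_model) matrices shared by all heads.
    seq_len, d_model = x.shape
    d_k = d_model // h

    def split(t):
        # (seq, d_model) -> (h, seq, d_k)
        return t.reshape(seq_len, h, d_k).transpose(1, 0, 2)

    out = attention(split(x @ Wq), split(x @ Wk), split(x @ Wv))
    # (h, seq, d_k) -> (seq, d_model): this is the concatenation step.
    out = out.transpose(1, 0, 2).reshape(seq_len, d_model)
    return out @ W_o

def column_blocks(W, h):
    # Column block i of a fused matrix is head i's projection matrix.
    d_k = W.shape[1] // h
    return [W[:, i * d_k:(i + 1) * d_k] for i in range(h)]

rng = np.random.default_rng(0)
seq_len, d_model, h = 5, 8, 2
x = rng.normal(size=(seq_len, d_model))
Wq, Wk, Wv, W_o = (rng.normal(size=(d_model, d_model)) for _ in range(4))

out_loop = multi_head_per_head(
    x, column_blocks(Wq, h), column_blocks(Wk, h), column_blocks(Wv, h), W_o
)
out_fused = multi_head_fused(x, Wq, Wk, Wv, W_o, h)
print(np.allclose(out_loop, out_fused))  # True
```

The two paths agree because a reshape of the fused projection's output is just a relabeling of the per-head column blocks, which is why implementations typically use one large matrix multiply per projection instead of h small ones.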
Dimensions (original Transformer, d_model=512, h=8 heads):
d_k = d_v = d_model / h = 512 / 8 = 64 per head
Each W_i^Q, W_i^K, W_i^V: (512, 64)
W_O: (512, 512)
Parameters per head: 3 × (512 × 64) = 98,304
Total for all heads + W_O: 8 × 98,304 + 512² = 1,048,576
Ignoring biases, this total equals that of a single attention head with d_k = d_model (4 × 512² = 1,048,576). Multiple heads don't add parameters; they redistribute them into parallel subspaces.
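The arithmetic above can be reproduced for any head count. The helper below is a hypothetical function written for this check (its name and the bias flag are not from any library); it simply counts weight entries and shows that the total does not depend on h as long as d_k = d_model / h.

```python
def mha_param_count(d_model: int, num_heads: int, bias: bool = False) -> int:
    """Count the parameters of multi-head attention with d_k = d_v = d_model / num_heads."""
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    d_k = d_model // num_heads
    per_head = 3 * d_model * d_k       # W_i^Q, W_i^K, W_i^V for one head
    output_proj = d_model * d_model    # W_O
    total = num_heads * per_head + output_proj
    if bias:
        # One bias vector of length d_model for each of the four projections.
        total += 4 * d_model
    return total

for h in (1, 2, 8, 16):
    print(h, mha_param_count(512, h))
# Every line prints the same total: 1048576
```

If each of the four projections also carries a bias vector, as PyTorch's nn.Linear does by default, the count grows by 4 × 512 = 2,048 regardless of h.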
Code
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        # One linear layer for all heads combined (more efficient than separate)
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
        batch, seq_len, _ = x.shape
        x = x.view(batch, seq_len, self.num_heads, self.d_k)
        return x.transpose(1, 2)

    def forward(
        self,
        Q: torch.Tensor,
        K: torch.Tensor,
        V: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        batch = Q.size(0)
        # Project and split into heads
        Q = self.split_heads(self.W_q(Q))  # (batch, heads, seq, d_k)
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        # Scaled dot-product attention (all heads in parallel via batched matmul)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # mask must broadcast to (batch, heads, seq_q, seq_k); 0 marks blocked positions
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, V)  # (batch, heads, seq, d_k)
        # Concatenate heads
        out = out.transpose(1, 2).contiguous()
        out = out.view(batch, -1, self.num_heads * self.d_k)  # (batch, seq, d_model)
        return self.W_o(out)

What Different Heads Learn
Empirical findings from attention analysis (Voita et al., Clark et al.):
- Syntactic heads: attend to grammatical relations (subject-verb, noun-adjective)
- Positional heads: attend to fixed offsets (next token, previous token)
- Rare token heads: disproportionately attend to rare or out-of-vocabulary tokens
- Coreference heads: track pronoun references across long distances
Not all heads are equally important. Pruning studies (including Voita et al.) find that many heads can be removed with little loss in quality, while a small number of specialized heads account for most of the contribution.
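The mechanics of removing a head can be sketched in a few lines. The NumPy example below is an illustration under stated assumptions, not the procedure used in the cited studies: it zeroes one head's output before the output projection and measures how much the layer output changes. The function name, the head_mask argument, and the random weights are hypothetical, and because the weights are untrained the printed numbers carry no meaning about real models. In practice the quantity of interest would be the change in a task metric on held-out data.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(x, Wq, Wk, Wv, Wo, num_heads, head_mask=None):
    """Self-attention over x (seq, d_model); head_mask has one 0/1 entry per head."""
    seq_len, d_model = x.shape
    d_k = d_model // num_heads

    def split(t):
        # (seq, d_model) -> (heads, seq, d_k)
        return t.reshape(seq_len, num_heads, d_k).transpose(1, 0, 2)

    q, k, v = split(x @ Wq), split(x @ Wk), split(x @ Wv)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d_k)
    heads = softmax(scores, axis=-1) @ v  # (heads, seq, d_k)
    if head_mask is not None:
        # Multiplying a head's output by 0 removes its contribution entirely.
        heads = heads * np.asarray(head_mask, dtype=heads.dtype)[:, None, None]
    concat = heads.transpose(1, 0, 2).reshape(seq_len, d_model)
    return concat @ Wo

rng = np.random.default_rng(0)
seq_len, d_model, num_heads = 6, 16, 4
x = rng.normal(size=(seq_len, d_model))
Wq, Wk, Wv, Wo = (rng.normal(size=(d_model, d_model)) for _ in range(4))

full = multi_head_attention(x, Wq, Wk, Wv, Wo, num_heads)
for i in range(num_heads):
    mask = np.ones(num_heads)
    mask[i] = 0.0  # ablate head i only
    ablated = multi_head_attention(x, Wq, Wk, Wv, Wo, num_heads, head_mask=mask)
    change = np.linalg.norm(full - ablated) / np.linalg.norm(full)
    print(f"head {i}: relative output change {change:.3f}")
```

Because the output projection is linear and this sketch has no bias, the full output is exactly the sum of the single-head contributions, which is what makes per-head ablation a clean way to attribute the layer's output to individual heads.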
Interview Answer
"Multi-head attention runs h parallel attention operations in separate subspaces (each dₖ = d_model/h dimensional), then concatenates and projects the results. This allows the model to simultaneously attend to different aspects of the input — one head might focus on syntactic relationships, another on semantic similarity, another on positional patterns. The total parameter count is equivalent to one large attention layer, but the parallel heads create representational diversity that a single head cannot achieve."