Multi-Head Attention
Why multiple heads let the model learn different relationship types; splitting Q/K/V into h heads; concat and project; head dimension = d_model/h; Python implementation.
Why Single-Head Attention Is Not Enough
Single-head attention produces one set of attention weights, that is, one weight per token pair, so every type of relationship has to share the same pattern. In practice, a sentence requires many simultaneous patterns:
- Syntactic agreement: subject and verb
- Coreference: pronoun and antecedent
- Semantic proximity: words with related meaning
- Local structure: adjacent words
Multi-head attention runs h attention operations in parallel, each with its own learned projections. Each head can specialise in a different type of relationship.
The Formula
Given h heads and d_model-dimensional input:
head_i = Attention(Q @ W_Q_i, K @ W_K_i, V @ W_V_i)
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) @ W_O

Where:
- W_Q_i and W_K_i are (d_model, d_k), with d_k = d_model / h
- W_V_i is (d_model, d_v), with d_v = d_model / h
- W_O is (h × d_v, d_model) and projects the concatenated heads back to d_model
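Read literally, the formula gives each head its own three projection matrices. The sketch below transcribes it directly in NumPy for a single unbatched sequence, using a Python loop over heads. The helper names `attention` and `multihead_literal` are illustrative, not from any library, and the sizes are arbitrary:

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the row max for numerical stability
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def attention(Q, K, V):
    # Scaled dot-product attention for one head: Q (T_q, d_k), K (T_k, d_k), V (T_k, d_v)
    d_k = Q.shape[-1]
    return softmax(Q @ K.T / np.sqrt(d_k)) @ V


def multihead_literal(Q, K, V, W_Q, W_K, W_V, W_O):
    """Hypothetical helper: direct transcription of the formula, one projection triple per head.

    W_Q, W_K, W_V: lists of h matrices, each (d_model, d_model // h)
    W_O: (h * d_v, d_model)
    """
    heads = [attention(Q @ wq, K @ wk, V @ wv) for wq, wk, wv in zip(W_Q, W_K, W_V)]
    return np.concatenate(heads, axis=-1) @ W_O  # Concat(head_1, ..., head_h) @ W_O


rng = np.random.default_rng(0)
d_model, h, T = 16, 4, 5
d_k = d_model // h
X = rng.standard_normal((T, d_model))
W_Q = [rng.standard_normal((d_model, d_k)) for _ in range(h)]
W_K = [rng.standard_normal((d_model, d_k)) for _ in range(h)]
W_V = [rng.standard_normal((d_model, d_k)) for _ in range(h)]
W_O = rng.standard_normal((h * d_k, d_model))

out = multihead_literal(X, X, X, W_Q, W_K, W_V, W_O)  # self-attention: Q = K = V = X
print(out.shape)  # (5, 16)
```

The loop makes the structure explicit but is slow. Practical implementations fold the h per-head projections into single matrices instead.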
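The key constraint is d_k × h = d_model. Each per-head projection is h times narrower than a full-width one, so the h query projections together hold h × d_model × d_k = d_model² parameters. That is exactly as many as a single head with d_k = d_model, and the same holds for keys and values.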
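The same arithmetic explains why implementations can store all heads in one matrix. Here is a small NumPy check, assuming d_model = 512 and h = 8. Stacking the h per-head query matrices column by column gives a single (d_model, d_model) matrix. Each head's projection is then just a column slice of the fused product, and the parameter counts match:

```python
import numpy as np

rng = np.random.default_rng(1)
d_model, h = 512, 8
d_k = d_model // h  # 64

# h separate per-head query projections, each (d_model, d_k)
per_head = [rng.standard_normal((d_model, d_k)) for _ in range(h)]

# Stacking them column-wise gives one (d_model, d_model) matrix
fused = np.concatenate(per_head, axis=1)

x = rng.standard_normal((3, d_model))  # 3 tokens
fused_out = x @ fused

# Head i's queries are columns [i*d_k, (i+1)*d_k) of the fused product
for i, w in enumerate(per_head):
    assert np.allclose(fused_out[:, i * d_k:(i + 1) * d_k], x @ w)

# Q/K/V parameters: h heads of (d_model x d_k) vs one head of (d_model x d_model)
multi_head_params = 3 * h * d_model * d_k
single_head_params = 3 * d_model * d_model
print(multi_head_params, single_head_params)  # 786432 786432
```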
Head Dimension: d_model / h
| d_model | h (heads) | d_k per head |
|---------|-----------|--------------|
| 512 | 8 | 64 |
| 768 | 12 | 64 |
| 1024 | 16 | 64 |
| 4096 | 32 | 128 |
Most models settle on d_k = 64 or d_k = 128. GPT-3, for example, uses d_model = 12288 and h = 96, giving d_k = 12288 / 96 = 128.
Python Implementation (NumPy, from Scratch)
```python
import math

import numpy as np


def softmax(x, axis=-1):
    # Subtract the max for numerical stability
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def scaled_dot_product(Q, K, V, mask=None):
    """Q: (..., T_q, d_k), K: (..., T_k, d_k), V: (..., T_k, d_v).
    mask: boolean, broadcastable to (..., T_q, T_k); True = block that position."""
    d_k = Q.shape[-1]
    scores = Q @ K.swapaxes(-1, -2) / math.sqrt(d_k)
    if mask is not None:
        scores = np.where(mask, -1e9, scores)
    return softmax(scores) @ V


class MultiHeadAttentionNumPy:
    """Multi-head attention using only NumPy (educational implementation)."""

    def __init__(self, d_model: int, h: int, rng=None):
        assert d_model % h == 0, "d_model must be divisible by h"
        if rng is None:
            rng = np.random.default_rng(0)
        self.h = h
        self.d_k = d_model // h
        scale = 0.02
        # Stack all head projections into one (d_model, d_model) matrix for efficiency:
        # columns [i*d_k, (i+1)*d_k) belong to head i
        self.W_Q = rng.standard_normal((d_model, d_model)) * scale
        self.W_K = rng.standard_normal((d_model, d_model)) * scale
        self.W_V = rng.standard_normal((d_model, d_model)) * scale
        self.W_O = rng.standard_normal((d_model, d_model)) * scale

    def forward(self, x, mask=None):
        """
        x: (batch, seq_len, d_model)
        Returns: (batch, seq_len, d_model)
        """
        B, T, D = x.shape
        h, d_k = self.h, self.d_k

        # Project all heads at once: (B, T, D)
        Q = x @ self.W_Q
        K = x @ self.W_K
        V = x @ self.W_V

        # Split into heads: (B, T, D) -> (B, T, h, d_k) -> (B, h, T, d_k)
        def split_heads(t):
            return t.reshape(B, T, h, d_k).transpose(0, 2, 1, 3)

        Q = split_heads(Q)
        K = split_heads(K)
        V = split_heads(V)

        # Attend independently within each head: (B, h, T, d_k)
        out = scaled_dot_product(Q, K, V, mask)

        # Concatenate heads: (B, h, T, d_k) -> (B, T, h * d_k) = (B, T, D)
        out = out.transpose(0, 2, 1, 3).reshape(B, T, D)

        # Final projection mixes information across heads
        return out @ self.W_O


# Test
rng = np.random.default_rng(42)
mha = MultiHeadAttentionNumPy(d_model=128, h=8)
x = rng.standard_normal((2, 10, 128))
out = mha.forward(x)
print("Output shape:", out.shape)  # (2, 10, 128)
```

PyTorch Implementation
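The PyTorch class below uses the same fused layout as the NumPy version. Its attention core (scores, mask, softmax, weighted sum) can be cross-checked against PyTorch's built-in `F.scaled_dot_product_attention`. This is a minimal sketch that assumes PyTorch 2.0 or later, where that function was added, and uses random tensors that are already split into heads.

Note that the built-in's boolean `attn_mask` marks positions that *may* be attended. That is the opposite of the `masked_fill` convention used in the class, so the mask is inverted:

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, h, T, d_k = 2, 4, 6, 8
Q = torch.randn(B, h, T, d_k)
K = torch.randn(B, h, T, d_k)
V = torch.randn(B, h, T, d_k)

# Causal mask in the masked_fill convention: True = block this position
block = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)

# Manual computation: scores -> masked_fill -> softmax -> weighted sum of values
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
manual = F.softmax(scores.masked_fill(block, float("-inf")), dim=-1) @ V

# Built-in (PyTorch >= 2.0): boolean attn_mask means True = allowed to attend, so invert
builtin = F.scaled_dot_product_attention(Q, K, V, attn_mask=~block)

print(torch.allclose(manual, builtin, atol=1e-5))  # True
```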
```python
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    """
    Multi-head scaled dot-product attention.
    Efficient implementation that fuses all heads into a single matmul per projection.
    """

    def __init__(self, d_model: int, h: int, dropout: float = 0.0):
        super().__init__()
        assert d_model % h == 0, "d_model must be divisible by h"
        self.h = h
        self.d_k = d_model // h
        # Fused projections: each Linear computes all h heads at once
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        """(B, T, D) -> (B, h, T, d_k)"""
        B, T, D = t.shape
        return t.view(B, T, self.h, self.d_k).transpose(1, 2)

    def _merge_heads(self, t: torch.Tensor) -> torch.Tensor:
        """(B, h, T, d_k) -> (B, T, D)"""
        B, h, T, d_k = t.shape
        # transpose makes the tensor non-contiguous, so copy before view
        return t.transpose(1, 2).contiguous().view(B, T, h * d_k)

    def forward(
        self,
        query: torch.Tensor,  # (B, T_q, D)
        key: torch.Tensor,  # (B, T_k, D)
        value: torch.Tensor,  # (B, T_k, D)
        mask: Optional[torch.Tensor] = None,
    ):
        Q = self._split_heads(self.W_Q(query))  # (B, h, T_q, d_k)
        K = self._split_heads(self.W_K(key))  # (B, h, T_k, d_k)
        V = self._split_heads(self.W_V(value))  # (B, h, T_k, d_k)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # Boolean mask broadcastable to (B, h, T_q, T_k), e.g. (B, 1, T_q, T_k);
            # True = block that position. A fully masked row would produce NaNs.
            scores = scores.masked_fill(mask, float("-inf"))

        attn_weights = F.softmax(scores, dim=-1)  # (B, h, T_q, T_k)
        attn_weights = self.dropout(attn_weights)

        out = torch.matmul(attn_weights, V)  # (B, h, T_q, d_k)
        out = self._merge_heads(out)  # (B, T_q, D)
        return self.W_O(out), attn_weights


# Sanity check: cross-attention with T_q = 15 queries over T_k = 20 keys
mha = MultiHeadAttention(d_model=256, h=8)
q = torch.randn(4, 15, 256)
k = torch.randn(4, 20, 256)
v = torch.randn(4, 20, 256)
out, w = mha(q, k, v)
print("out shape:", out.shape)  # torch.Size([4, 15, 256])
print("attn shape:", w.shape)  # torch.Size([4, 8, 15, 20])
```

What Each Head Learns
Interpretability research (e.g., Voita et al. 2019 on Transformer translation encoders, and Clark et al. 2019 on BERT) found that different attention heads in trained models specialise in:
- Positional heads: attend to adjacent tokens (local pattern)
- Syntactic heads: subject-verb agreement, noun-adjective agreement
- Rare token heads: attend to rare or unusual tokens
- Delimiter heads: attend to [CLS], [SEP], or period tokens
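One rough way to check whether a head behaves like a positional head is to measure how much attention each query puts on a fixed relative offset, such as the previous token. The function below is a simple illustrative heuristic, not the exact criterion used in the papers cited above, and the name `relative_position_mass` is hypothetical. It is demonstrated on synthetic attention matrices:

```python
import numpy as np


def relative_position_mass(attn: np.ndarray, offset: int = -1) -> float:
    """Average attention weight that queries place on the key `offset` positions away.

    attn: (T, T) weights for one head; row q is query q's distribution over keys.
    offset=-1 looks at the previous token, offset=+1 at the next one.
    """
    T = attn.shape[0]
    q = np.arange(T)
    k = q + offset
    valid = (k >= 0) & (k < T)  # skip queries that have no neighbour at this offset
    return float(attn[q[valid], k[valid]].mean())


T = 6
# Synthetic "previous-token" head: query t puts 0.9 of its mass on token t-1
prev_head = np.full((T, T), 0.1 / (T - 1))
for t in range(1, T):
    prev_head[t, t - 1] = 0.9
prev_head[0] = 1.0 / T  # the first token has no predecessor, so make its row uniform

# Uninformative head: every query attends uniformly
uniform_head = np.full((T, T), 1.0 / T)

print(round(relative_position_mass(prev_head), 3))  # 0.9
print(round(relative_position_mass(uniform_head), 3))  # 0.167
```

A trained model's weights could be scored the same way, one (seq_len, seq_len) head slice at a time. Special tokens such as [CLS] and [SEP] can soak up a lot of attention in BERT, so it may be worth excluding them before comparing heads.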
You can inspect individual heads by extracting the attention weights from a trained model:
```python
import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# output_attentions=True makes the model return per-layer attention weights
model = BertModel.from_pretrained("bert-base-uncased", output_attentions=True)

text = "The cat sat on the mat"
inputs = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# attentions: tuple with one (batch, heads, seq_len, seq_len) tensor per layer
attentions = outputs.attentions
print(f"Layers: {len(attentions)}")  # 12
# 6 words + [CLS] + [SEP] = 8 tokens
print(f"Layer 0 shape: {attentions[0].shape}")  # torch.Size([1, 12, 8, 8])

# Layer 5, head 4 (both 0-indexed): row i shows how token i distributes its attention
head_weights = attentions[5][0, 4].numpy()
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
print("Tokens:", tokens)
print("Head weights:\n", head_weights.round(3))
```

Pruning Attention Heads
Heads can be pruned without major quality loss. Michel et al. 2019 found that a large fraction of BERT's heads could be removed at test time with little effect on accuracy, and that many individual layers could be reduced to a single head. This is because:
- Not all heads learn distinct patterns.
- The output projection W_O can compensate if one head collapses to near-uniform weights.
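Pruning a head means removing its contribution. Its column blocks in W_Q, W_K, and W_V are zeroed out or dropped, together with the matching rows of W_O (the slice that reads that head's output). Masking those W_O rows alone already removes the head from the layer's output. Dropping the other blocks as well is what saves computation.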
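The output projection is applied to the concatenation of all heads. As a result, the layer output is a sum of per-head terms, each head's output multiplied by its own d_k rows of W_O. That is why zeroing a head's rows in W_O gives the same result as deleting the head outright.

Here is a NumPy sketch on random per-head outputs. The sizes and the pruned index are arbitrary choices for illustration:

```python
import numpy as np

rng = np.random.default_rng(2)
T, h, d_k = 4, 4, 8
d_model = h * d_k

# Stand-in per-head attention outputs, each (T, d_k), and their concatenation
heads = [rng.standard_normal((T, d_k)) for _ in range(h)]
concat = np.concatenate(heads, axis=-1)  # (T, h * d_k)
W_O = rng.standard_normal((h * d_k, d_model))

prune = 2  # hypothetical head index to prune

# Option 1: zero the rows of W_O that read from head `prune`
W_O_masked = W_O.copy()
W_O_masked[prune * d_k:(prune + 1) * d_k, :] = 0.0
out_masked = concat @ W_O_masked

# Option 2: physically drop the head and its W_O rows (this is what saves compute)
keep = [i for i in range(h) if i != prune]
concat_small = np.concatenate([heads[i] for i in keep], axis=-1)
rows = np.concatenate([np.arange(i * d_k, (i + 1) * d_k) for i in keep])
out_small = concat_small @ W_O[rows, :]

print(np.allclose(out_masked, out_small))  # True
```

Masking is convenient for quickly trying out which heads matter. Physical removal is what shrinks the layer.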
Parameter Count
For a transformer with d_model = 768, h = 12, n_layers = 12:
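- Per-layer MHA params: 4 × d_model² = 4 × 768² = 2,359,296 (the W_Q, W_K, W_V, and W_O weights, ignoring biases)
- Total MHA params (12 layers): 12 × 2,359,296 ≈ 28.3M, out of BERT-base's roughly 110M

The remainder goes mostly to the feed-forward layers (12 × 2 × 768 × 3072 ≈ 56.6M) and the embeddings (about 23.8M).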
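These figures can be reproduced with a few lines of arithmetic. The sketch assumes the standard bert-base-uncased configuration: a 30,522-token vocabulary, 512 positions, 2 segment types, and a feed-forward size of 3072. It also counts biases, LayerNorm parameters, and the pooler layer to reach a full total. Treat it as a back-of-the-envelope check rather than an authoritative count:

```python
# Assumed BERT-base (bert-base-uncased) configuration
d_model, n_layers, d_ff = 768, 12, 3072
vocab, max_pos, n_types = 30522, 512, 2

mha = n_layers * 4 * d_model**2  # W_Q, W_K, W_V, W_O weights only
ffn = n_layers * 2 * d_model * d_ff  # two feed-forward weight matrices per layer
emb = (vocab + max_pos + n_types) * d_model  # token, position, segment embeddings

# Per layer: 4 attention biases, 2 FFN biases, 2 LayerNorms (scale + shift);
# plus the embedding LayerNorm
biases_ln = n_layers * (4 * d_model + d_ff + d_model + 2 * 2 * d_model) + 2 * d_model
pooler = d_model * d_model + d_model  # dense layer applied to [CLS]

total = mha + ffn + emb + biases_ln + pooler
print(f"MHA weights: {mha:,}")  # 28,311,552
print(f"FFN weights: {ffn:,}")  # 56,623,104
print(f"Embeddings:  {emb:,}")  # 23,835,648
print(f"Total:       {total:,}")  # 109,482,240
```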
Key Takeaways
- h parallel attention heads each use d_k = d_model / h dimensions.
- Splitting and merging heads is a reshape plus a transpose, with no extra parameters.
- The output projection W_O mixes information across all heads.
- Different heads empirically specialise in syntactic, positional, and semantic patterns.
- Heads can be pruned; not all are equally important for every task.