
Multi-Head Attention

Why multiple heads let the model learn different relationship types; splitting Q/K/V into h heads; concat and project; head dimension = d_model/h; Python implementation.

Asma Hafeez Khan · May 15, 2026 · 6 min read
Tags: Transformers, Multi-Head Attention, Attention Heads, Architecture, PyTorch

Why Single-Head Attention Is Not Enough

Single-head attention produces one set of attention weights: one pattern per token pair. In practice, a sentence requires many simultaneous patterns:

  • Syntactic agreement: subject and verb
  • Coreference: pronoun and antecedent
  • Semantic proximity: words with related meaning
  • Local structure: adjacent words

Multi-head attention runs h attention operations in parallel, each with its own learned projections. Each head can specialise in a different type of relationship.

The Formula

Given h heads and d_model-dimensional input:

head_i = Attention(Q @ W_Q_i, K @ W_K_i, V @ W_V_i)

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) @ W_O

Where:

  • W_Q_i, W_K_i are (d_model, d_k), each with d_k = d_model / h
  • W_V_i is (d_model, d_v), each with d_v = d_model / h
  • W_O is (h × d_v, d_model) and projects the concatenated heads back to d_model

The key constraint: d_k × h = d_model. The h per-head projections together hold exactly as many weights as one (d_model, d_model) matrix, so the total parameter count matches a single large attention with d_k = d_model.
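
To make the shapes concrete, here is a minimal NumPy sketch that follows the formula literally, with one separate (W_Q_i, W_K_i, W_V_i) triple per head and no batching. The sizes (d_model = 16, h = 4, T = 5) and the random weights are illustrative assumptions, not values from a trained model.

```python
import numpy as np


def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def attention(Q, K, V):
    """Scaled dot-product attention for a single head (no batch dimension)."""
    d_k = Q.shape[-1]
    return softmax(Q @ K.T / np.sqrt(d_k)) @ V


d_model, h, T = 16, 4, 5          # illustrative sizes (assumption)
d_k = d_v = d_model // h
rng = np.random.default_rng(0)
x = rng.standard_normal((T, d_model))

# One projection triple per head, exactly as in the formula
W_Q = [rng.standard_normal((d_model, d_k)) for _ in range(h)]
W_K = [rng.standard_normal((d_model, d_k)) for _ in range(h)]
W_V = [rng.standard_normal((d_model, d_v)) for _ in range(h)]
W_O = rng.standard_normal((h * d_v, d_model))

heads = [attention(x @ W_Q[i], x @ W_K[i], x @ W_V[i]) for i in range(h)]
concat = np.concatenate(heads, axis=-1)   # (T, h * d_v)
out = concat @ W_O                        # (T, d_model)

n_params = sum(w.size for w in W_Q + W_K + W_V) + W_O.size
print(concat.shape, out.shape, n_params)  # (5, 16) (5, 16) 1024
```

The parameter total equals 4 × d_model² (here 4 × 16² = 1024), which is what a single head with d_k = d_model plus an output projection would use.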

Head Dimension: d_model / h

| d_model | h (heads) | d_k per head |
|---------|-----------|--------------|
| 512 | 8 | 64 |
| 768 | 12 | 64 |
| 1024 | 16 | 64 |
| 4096 | 32 | 128 |

Most models land on d_k = 64 or d_k = 128. Notice: GPT-3 with d_model = 12288 and h = 96 gives d_k = 128.
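
The table rows can be reproduced with a few lines of Python. The helper name `head_dim` is hypothetical and exists only for this illustration; it also rejects configurations where d_model is not divisible by h.

```python
def head_dim(d_model: int, h: int) -> int:
    """Return d_k = d_model / h, rejecting configurations that do not divide evenly."""
    if d_model % h != 0:
        raise ValueError(f"d_model={d_model} is not divisible by h={h}")
    return d_model // h


# The four table rows plus the GPT-3 configuration mentioned above
configs = [(512, 8), (768, 12), (1024, 16), (4096, 32), (12288, 96)]
for d_model, h in configs:
    print(f"d_model={d_model:>5}  h={h:>2}  d_k={head_dim(d_model, h)}")
```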

Python Implementation (NumPy, from Scratch)

Python
import numpy as np
import math


def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def scaled_dot_product(Q, K, V, mask=None):
    d_k = Q.shape[-1]
    scores = Q @ K.swapaxes(-1, -2) / math.sqrt(d_k)
    if mask is not None:
        # Boolean mask: True marks positions to hide from attention
        scores = np.where(mask, -1e9, scores)
    return softmax(scores) @ V


class MultiHeadAttentionNumPy:
    """Multi-head attention using only NumPy (educational implementation)."""

    def __init__(self, d_model: int, h: int, rng=None):
        assert d_model % h == 0, "d_model must be divisible by h"
        rng = np.random.default_rng(0) if rng is None else rng
        self.h = h
        self.d_k = d_model // h

        scale = 0.02
        # Stack all head projections into one matrix for efficiency
        self.W_Q = rng.standard_normal((d_model, d_model)) * scale
        self.W_K = rng.standard_normal((d_model, d_model)) * scale
        self.W_V = rng.standard_normal((d_model, d_model)) * scale
        self.W_O = rng.standard_normal((d_model, d_model)) * scale

    def forward(self, x, mask=None):
        """
        x: (batch, seq_len, d_model)
        Returns: (batch, seq_len, d_model)
        """
        B, T, D = x.shape
        h, d_k = self.h, self.d_k

        # Project: (B, T, D)
        Q = x @ self.W_Q   # (B, T, D)
        K = x @ self.W_K
        V = x @ self.W_V

        # Split into heads: (B, T, h, d_k) -> (B, h, T, d_k)
        def split_heads(t):
            return t.reshape(B, T, h, d_k).transpose(0, 2, 1, 3)

        Q = split_heads(Q)   # (B, h, T, d_k)
        K = split_heads(K)
        V = split_heads(V)

        # Attend (per head)
        out = scaled_dot_product(Q, K, V, mask)   # (B, h, T, d_k)

        # Concatenate heads: (B, T, h*d_k) = (B, T, D)
        out = out.transpose(0, 2, 1, 3).reshape(B, T, D)

        # Final projection
        return out @ self.W_O


# Test
rng = np.random.default_rng(42)
mha = MultiHeadAttentionNumPy(d_model=128, h=8)
x = rng.standard_normal((2, 10, 128))
out = mha.forward(x)
print("Output shape:", out.shape)   # (2, 10, 128)
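
The `mask` argument is not exercised in the test above. As a hedged sketch of how it might be used, the snippet below builds a causal mask with the same convention as `scaled_dot_product` (True means "hide this position") and applies it to a single head. The sizes and random inputs are illustrative assumptions.

```python
import numpy as np


def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


T, d_k = 4, 8                      # illustrative sizes (assumption)
rng = np.random.default_rng(0)
Q = rng.standard_normal((T, d_k))
K = rng.standard_normal((T, d_k))

# True above the diagonal: key position j > i is hidden from query i
causal_mask = np.triu(np.ones((T, T), dtype=bool), k=1)

scores = Q @ K.T / np.sqrt(d_k)
scores = np.where(causal_mask, -1e9, scores)   # same convention as scaled_dot_product
weights = softmax(scores)

print(np.round(weights, 3))        # lower-triangular: each row sums to 1
```

Because the mask has shape (T, T), NumPy broadcasting lets the same array be passed as `mha.forward(x, mask=causal_mask)`, where it is applied to every batch element and every head.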

PyTorch Implementation

Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class MultiHeadAttention(nn.Module):
    """
    Multi-head scaled dot-product attention.
    Efficient implementation that fuses all heads into a single matmul.
    """

    def __init__(self, d_model: int, h: int, dropout: float = 0.0):
        super().__init__()
        assert d_model % h == 0
        self.h = h
        self.d_k = d_model // h

        # Fused projection: compute all heads at once
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        """(B, T, D) -> (B, h, T, d_k)"""
        B, T, D = t.shape
        return t.view(B, T, self.h, self.d_k).transpose(1, 2)

    def _merge_heads(self, t: torch.Tensor) -> torch.Tensor:
        """(B, h, T, d_k) -> (B, T, D)"""
        B, h, T, d_k = t.shape
        return t.transpose(1, 2).contiguous().view(B, T, h * d_k)

    def forward(
        self,
        query: torch.Tensor,   # (B, T_q, D)
        key: torch.Tensor,     # (B, T_k, D)
        value: torch.Tensor,   # (B, T_k, D)
        mask: torch.Tensor = None,
    ):
        Q = self._split_heads(self.W_Q(query))   # (B, h, T_q, d_k)
        K = self._split_heads(self.W_K(key))
        V = self._split_heads(self.W_V(value))

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            # mask: boolean, (B, 1, T_q, T_k) or (B, h, T_q, T_k); True marks positions to hide
            scores = scores.masked_fill(mask, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        out = torch.matmul(attn_weights, V)        # (B, h, T_q, d_k)
        out = self._merge_heads(out)               # (B, T_q, D)
        return self.W_O(out), attn_weights


# Sanity check
mha = MultiHeadAttention(d_model=256, h=8)
q = torch.randn(4, 15, 256)
k = torch.randn(4, 20, 256)
v = torch.randn(4, 20, 256)
out, w = mha(q, k, v)
print("out shape:", out.shape)   # (4, 15, 256)
print("attn shape:", w.shape)    # (4, 8, 15, 20)

What Each Head Learns

Interpretability research (e.g., Voita et al. 2019 on translation models, Clark et al. 2019 on BERT) found that different attention heads in trained Transformers specialise in:

  • Positional heads: attend to adjacent tokens (local pattern)
  • Syntactic heads: subject-verb agreement, noun-adjective agreement
  • Rare token heads: attend to rare or unusual tokens
  • Delimiter heads: attend to [CLS], [SEP], or period tokens

You can inspect heads by extracting the attention-weight tensors from a trained model:

Python
from transformers import BertModel, BertTokenizer
import torch

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased", output_attentions=True)

text = "The cat sat on the mat"
inputs = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# attentions: tuple of (1, 12, seq_len, seq_len) per layer
attentions = outputs.attentions
print(f"Layers: {len(attentions)}")
print(f"Layer 0 shape: {attentions[0].shape}")   # (1, 12, 8, 8)

# Head index 4 in layer index 5 (both 0-based): inspect what it attends to
head_weights = attentions[5][0, 4].numpy()
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
print("Tokens:", tokens)
print("Head weights:\n", head_weights.round(3))
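
A raw weight matrix is hard to read at a glance. One simple way to summarise a head, sketched here under the assumption that the most-attended key per query is a useful first look, is to list the top key token for each query token. The helper `top_attended` and the 4-token matrix are hypothetical; with the BERT example, `weights` would be `head_weights` and `tokens` the tokenizer output.

```python
import numpy as np


def top_attended(tokens, weights):
    """For each query token, return (query, most-attended key, weight)."""
    weights = np.asarray(weights)
    best = weights.argmax(axis=-1)          # index of the strongest key per query row
    return [(tokens[i], tokens[j], float(weights[i, j])) for i, j in enumerate(best)]


# Hypothetical 4-token example; each row is one query's attention distribution
toy_tokens = ["[CLS]", "the", "cat", "[SEP]"]
toy_weights = np.array([
    [0.70, 0.10, 0.10, 0.10],
    [0.05, 0.15, 0.70, 0.10],
    [0.05, 0.60, 0.25, 0.10],
    [0.10, 0.10, 0.10, 0.70],
])

for query, key, w in top_attended(toy_tokens, toy_weights):
    print(f"{query:>6} -> {key:<6} ({w:.2f})")
```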

Pruning Attention Heads

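Heads can be pruned without major quality loss. Michel et al. 2019 found that a large fraction of attention heads can be removed at test time without significantly hurting performance, and that some layers can be reduced to a single head; for BERT they report pruning roughly 40% of heads with no noticeable accuracy drop. This is because: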

  1. Not all heads learn distinct patterns.
  2. The output projection W_O can compensate if one head collapses to near-uniform weights.

Pruning is done by zeroing out entire head blocks in W_Q, W_K, W_V, and masking the corresponding slice of W_O.
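
The sketch below illustrates the W_O side of this on random data: zeroing one head's output before merging gives the same result as zeroing the slice of W_O that reads from that head. The shapes, the random stand-in for the per-head attention output, and the choice of head 1 are all assumptions for illustration.

```python
import numpy as np

B, h, T, d_k = 2, 4, 3, 8          # illustrative sizes (assumption)
D = h * d_k
rng = np.random.default_rng(0)
head_out = rng.standard_normal((B, h, T, d_k))   # stand-in for per-head attention output
W_O = rng.standard_normal((D, D))

pruned = 1                                       # head to remove (hypothetical choice)
head_mask = np.ones(h)
head_mask[pruned] = 0.0

# Option A: zero the pruned head's output, then merge heads and project
masked = head_out * head_mask[None, :, None, None]
out_a = masked.transpose(0, 2, 1, 3).reshape(B, T, D) @ W_O

# Option B: keep all head outputs, zero the rows of W_O that read from the pruned head
W_O_pruned = W_O.copy()
W_O_pruned[pruned * d_k:(pruned + 1) * d_k, :] = 0.0
out_b = head_out.transpose(0, 2, 1, 3).reshape(B, T, D) @ W_O_pruned

print(np.allclose(out_a, out_b))   # True
```

The row slice matches the NumPy convention `out @ W_O` used earlier. With `nn.Linear`, the weight is stored as (out_features, in_features), so the corresponding slice would be on the columns of `W_O.weight`.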

Parameter Count

For a transformer with d_model = 768, h = 12, n_layers = 12:

  • Per-layer MHA params: 4 × d_model² = 4 × 768² = 2,359,296
  • Total MHA params (12 layers): roughly 28M out of BERT-base's 110M

The remainder goes to feed-forward layers (roughly 56M) and embeddings.
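
These figures can be checked with a short calculation. The sketch counts weight matrices only (no biases or layer norms) and assumes a feed-forward width of d_ff = 4 × d_model = 3072, which is the BERT-base setting; the function names are hypothetical.

```python
def mha_params(d_model: int) -> int:
    """Weights in W_Q, W_K, W_V and W_O (biases ignored)."""
    return 4 * d_model * d_model


def ffn_params(d_model: int, d_ff: int) -> int:
    """Weights in the two feed-forward matrices (biases ignored)."""
    return 2 * d_model * d_ff


d_model, n_layers = 768, 12
d_ff = 4 * d_model   # assumption: 4x expansion, i.e. 3072 for BERT-base

per_layer = mha_params(d_model)
mha_total = n_layers * per_layer
ffn_total = n_layers * ffn_params(d_model, d_ff)

print(f"MHA per layer: {per_layer:,}")   # 2,359,296
print(f"MHA total:     {mha_total:,}")   # 28,311,552
print(f"FFN total:     {ffn_total:,}")   # 56,623,104
```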

Key Takeaways

  • h parallel attention heads each use d_k = d_model / h dimensions.
  • Splitting and merging is a reshape + transpose, with no extra parameters.
  • The output projection W_O mixes information across all heads.
  • Different heads empirically specialise in syntactic, positional, and semantic patterns.
  • Heads can be pruned; not all are equally important for every task.
