The Attention Mechanism Explained
How attention computes Q, K, V; the dot-product attention formula; why it captures long-range dependencies; and a Python from-scratch implementation.
What Is Attention?
Before transformers, sequence models like LSTMs processed tokens one by one, compressing history into a fixed-size hidden state. By the time you reached token 50, information from token 1 was heavily diluted. Attention was introduced to solve this: every token can directly attend to every other token in a single step, regardless of distance.
The core intuition: given a query, look up the most relevant keys in a database, retrieve the corresponding values, and return a weighted sum of those values.
The Three Matrices: Q, K, V
Given an input sequence X of shape (seq_len, d_model), we project it into three separate spaces:
Q = X @ W_Q  # Queries, shape: (seq_len, d_k)
K = X @ W_K  # Keys, shape: (seq_len, d_k)
V = X @ W_V  # Values, shape: (seq_len, d_v)
- W_Q, W_K are (d_model, d_k) learned weight matrices
- W_V is (d_model, d_v); usually d_v = d_k = d_model / h, where h is the number of heads
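To make the shapes concrete, here is a minimal sketch of the three projections. It assumes random matrices as stand-ins for the learned weights, and the sizes (seq_len = 4, d_model = 16, h = 2) are illustrative choices, not values from this lesson.

```python
import numpy as np

rng = np.random.default_rng(0)

seq_len, d_model, h = 4, 16, 2      # illustrative sizes (assumed)
d_k = d_v = d_model // h            # the usual per-head choice

X = rng.standard_normal((seq_len, d_model))

# Random stand-ins for the learned projection matrices
W_Q = rng.standard_normal((d_model, d_k))
W_K = rng.standard_normal((d_model, d_k))
W_V = rng.standard_normal((d_model, d_v))

Q = X @ W_Q
K = X @ W_K
V = X @ W_V

print(Q.shape, K.shape, V.shape)    # (4, 8) (4, 8) (4, 8)
```

Each row of Q, K, and V still corresponds to one token; only the feature dimension changes.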
Think of it like a search engine:
- Query: what you are looking for
- Key: the index of each document
- Value: the actual content you retrieve
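The lookup analogy can be sketched for a single query as a "soft" dictionary lookup. The function name `soft_lookup` and the toy keys and values are hypothetical, chosen only for illustration, and the sqrt(d_k) scaling is left out here for brevity.

```python
import numpy as np

def soft_lookup(query, keys, values):
    """Return a softmax-weighted blend of `values`, weighted by query-key similarity."""
    scores = keys @ query                      # one similarity score per key
    weights = np.exp(scores - scores.max())    # subtract the max for numerical stability
    weights = weights / weights.sum()          # normalise into a probability distribution
    return weights @ values, weights

# Toy "database" of three key/value pairs (assumed data)
keys = np.array([[1.0, 0.0],
                 [0.0, 1.0],
                 [-1.0, 0.0]])
values = np.array([[10.0], [20.0], [30.0]])
query = np.array([5.0, 0.0])   # points strongly towards the first key

result, weights = soft_lookup(query, keys, values)
print(weights.round(3))
print(result.round(3))
```

Unlike a hard dictionary lookup, every value contributes to the result; the first value dominates here because its key is most similar to the query.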
Scaled Dot-Product Attention
The formula for scaled dot-product attention is:
Attention(Q, K, V) = softmax( Q × K^T / sqrt(d_k) ) × V
Step by step:
- Dot product: Q × K^T produces a (seq_len, seq_len) matrix of raw scores. Each entry [i, j] measures how much token i attends to token j.
- Scale: divide by sqrt(d_k). Without this, dot products grow large in magnitude as d_k increases, pushing softmax into saturation regions with near-zero gradients.
- Softmax: convert scores into a probability distribution over keys. Each row sums to 1.
- Weighted sum: multiply by V. Each output is a blend of all value vectors, weighted by attention scores.
Why Scaling Matters
If d_k = 64, sqrt(d_k) = 8. When the components of a query and a key are independent with zero mean and unit variance, their dot product has variance d_k, so its standard deviation is sqrt(d_k). Dividing by sqrt(d_k) brings the variance back to 1, keeping softmax in a well-behaved gradient regime.
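The variance argument can be checked empirically. This is a rough sketch that assumes query and key components are independent standard normal samples; the sample count is an arbitrary choice.

```python
import numpy as np

rng = np.random.default_rng(0)
d_k, n_samples = 64, 20_000

# Components drawn independently with zero mean and unit variance (assumption)
q = rng.standard_normal((n_samples, d_k))
k = rng.standard_normal((n_samples, d_k))

raw = np.sum(q * k, axis=1)        # unscaled dot products, one per sample pair
scaled = raw / np.sqrt(d_k)        # scaled as in the attention formula

print("raw variance:   ", raw.var())      # close to d_k
print("scaled variance:", scaled.var())   # close to 1
```

Trained projections will not produce exactly this distribution, so treat the numbers as a sanity check on the argument rather than a property of any real model.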
Why Attention Captures Long-Range Dependencies
In an RNN, the path from token 1 to token 50 requires 49 sequential steps, and the gradient signal must travel backward through all of them. Attention computes all pairwise interactions in a single matrix multiply: the path length between any two tokens is always 1. This makes long-range dependencies easy to learn.
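One way to see the length-1 path is to nudge the first token's value vector and watch the last token's output move in a single step. This sketch holds the attention weights fixed and uses a random row-stochastic matrix in their place, which is a simplifying assumption; in a real layer the weights also depend on the input.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d = 50, 8

# Stand-in attention weights: any matrix whose rows sum to 1 (assumption)
scores = rng.standard_normal((seq_len, seq_len))
A = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
V = rng.standard_normal((seq_len, d))

out = A @ V

# Perturb only the first token's value vector
delta = np.ones(d)
V_perturbed = V.copy()
V_perturbed[0] += delta
out_perturbed = A @ V_perturbed

# The last token's output shifts by A[-1, 0] * delta: a direct, one-step dependency
shift = out_perturbed[-1] - out[-1]
print(np.allclose(shift, A[-1, 0] * delta))   # True
```

The influence of token 1 on token 50 is a single weight, A[-1, 0], with no chain of intermediate states in between.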
Masking
In decoder (autoregressive) attention, we add a causal mask so token i cannot attend to token j > i:
mask[i, j] = -inf if j > i else 0
We add this mask to the scores before softmax, so masked positions receive zero weight after exponentiation.
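A small sketch of the additive causal mask, using random scores and an illustrative sequence length of 4 (both assumed):

```python
import numpy as np

seq_len = 4
rng = np.random.default_rng(0)
scores = rng.standard_normal((seq_len, seq_len))

# Additive causal mask: -inf strictly above the diagonal (j > i), 0 elsewhere
mask = np.triu(np.full((seq_len, seq_len), -np.inf), k=1)

masked = scores + mask
# Row-wise softmax; exp(-inf) evaluates to 0, so future positions get zero weight
weights = np.exp(masked - masked.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)

print(weights.round(3))
```

The printed matrix is lower-triangular: the first token can only attend to itself, and each later row spreads its weight over the positions up to and including its own.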
Python Implementation from Scratch
import numpy as np
import math

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Compute scaled dot-product attention.

    Args:
        Q: (batch, seq_q, d_k)
        K: (batch, seq_k, d_k)
        V: (batch, seq_k, d_v)
        mask: optional (batch, seq_q, seq_k) boolean mask
              True means "mask this position" (score set to -1e9,
              a finite stand-in for -inf)
    Returns:
        output: (batch, seq_q, d_v)
        attention_weights: (batch, seq_q, seq_k)
    """
    d_k = Q.shape[-1]
    # Step 1: dot product Q @ K^T -> (batch, seq_q, seq_k)
    scores = np.matmul(Q, K.transpose(0, 2, 1))
    # Step 2: scale
    scores = scores / math.sqrt(d_k)
    # Step 3: apply mask (causal or padding)
    if mask is not None:
        scores = np.where(mask, -1e9, scores)
    # Step 4: softmax over the key dimension
    # (subtracting the row max keeps exp() from overflowing)
    scores_max = scores.max(axis=-1, keepdims=True)
    exp_scores = np.exp(scores - scores_max)
    attention_weights = exp_scores / exp_scores.sum(axis=-1, keepdims=True)
    # Step 5: weighted sum of values
    output = np.matmul(attention_weights, V)
    return output, attention_weights
# -- Demo --
np.random.seed(42)
batch, seq_len, d_k, d_v = 1, 5, 8, 8
Q = np.random.randn(batch, seq_len, d_k)
K = np.random.randn(batch, seq_len, d_k)
V = np.random.randn(batch, seq_len, d_v)
output, weights = scaled_dot_product_attention(Q, K, V)
print("Output shape:", output.shape)  # (1, 5, 8)
print("Attention weights:\n", weights.round(3))
Full Linear-Projection Attention Layer in PyTorch
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

class SingleHeadAttention(nn.Module):
    """Single-head scaled dot-product attention with learned projections."""

    def __init__(self, d_model: int, d_k: Optional[int] = None):
        super().__init__()
        d_k = d_k or d_model  # default to the model width
        self.d_k = d_k
        self.W_Q = nn.Linear(d_model, d_k, bias=False)
        self.W_K = nn.Linear(d_model, d_k, bias=False)
        self.W_V = nn.Linear(d_model, d_k, bias=False)
        self.W_O = nn.Linear(d_k, d_model, bias=False)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """
        x: (batch, seq_len, d_model)
        mask: boolean (batch, seq_len, seq_len) or None;
              True means "mask this position"
        """
        Q = self.W_Q(x)  # (B, T, d_k)
        K = self.W_K(x)  # (B, T, d_k)
        V = self.W_V(x)  # (B, T, d_k)
        # Scaled dot-product
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, V)  # (B, T, d_k)
        return self.W_O(out), attn  # project back to d_model

# Quick test
model = SingleHeadAttention(d_model=64)
x = torch.randn(2, 10, 64)
out, attn = model(x)
print("Output:", out.shape)  # torch.Size([2, 10, 64])
print("Attn: ", attn.shape)  # torch.Size([2, 10, 10])
Visualising Attention Weights
Attention weight matrices are interpretable. After training, you can extract the (seq_q, seq_k) matrix and visualise which positions each token attends to:
import matplotlib.pyplot as plt
import seaborn as sns

def plot_attention(weights, tokens):
    """
    weights: (seq_len, seq_len) numpy array
    tokens: list of token strings, one per position
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(weights, xticklabels=tokens, yticklabels=tokens,
                cmap='Blues', ax=ax, vmin=0, vmax=1)
    ax.set_xlabel("Key positions")
    ax.set_ylabel("Query positions")
    ax.set_title("Attention Weight Heatmap")
    plt.tight_layout()
    plt.savefig("attention_heatmap.png", dpi=150)
    plt.show()

tokens = ["The", "cat", "sat", "on", "mat"]
# Run the model on a sequence with one position per token, so every row and
# column of the heatmap has a label. The model is untrained and the input is
# random, so the pattern is only illustrative.
x_demo = torch.randn(1, len(tokens), 64)
_, attn_demo = model(x_demo)
# Grab weights for the first batch item
w = attn_demo[0].detach().numpy()
plot_attention(w, tokens)
Computational Complexity
| Operation | Complexity |
|-----------|-----------|
| Q × K^T | O(n² × d_k) |
| Softmax | O(n²) |
| Attn × V | O(n² × d_v) |
The O(n²) factor in sequence length n is the bottleneck that motivated sparse attention, linear attention, and Flash Attention (covered in later lessons).
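A back-of-the-envelope sketch of how the score matrix grows with sequence length. The helper `score_matrix_bytes` is a hypothetical name, and the figures assume float32 storage for a single head and a single batch item.

```python
def score_matrix_bytes(n: int, bytes_per_element: int = 4) -> int:
    """Bytes needed to hold one (n, n) attention score matrix."""
    return n * n * bytes_per_element

# Assumed sequence lengths, for illustration only
for n in (1_024, 4_096, 16_384):
    mib = score_matrix_bytes(n) / 2**20
    print(f"n = {n:>6}: {mib:,.0f} MiB")
```

Under these assumptions the matrix takes 4 MiB at n = 1,024, 64 MiB at n = 4,096, and 1,024 MiB at n = 16,384: each 4x increase in sequence length costs 16x the memory.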
Key Takeaways
- Attention maps each position to a weighted blend of all positions (itself included) via Q, K, V projections.
- Scaling by sqrt(d_k) keeps gradients stable.
- Causal masking enforces left-to-right generation in decoders.
- The all-pairs interaction makes long-range dependencies easy to learn but costs O(n²) memory and compute.
- Learned weight matrices W_Q, W_K, W_V allow the model to discover which comparisons are useful.