Transformer Architecture Q&A · Lesson 3 of 23
Scaled Dot-Product Attention: The Math
The Full Formula
Attention(Q, K, V) = softmax(QKᵀ / √dₖ) · V
This is the core operation of every transformer. It breaks down into five steps, and every one of them matters:
Step 1: Dot-Product Scores
S = Q · Kᵀ
Shapes:
Q: (batch, heads, seq_len, dₖ)
K: (batch, heads, seq_len, dₖ)
Kᵀ: (batch, heads, dₖ, seq_len)
S: (batch, heads, seq_len, seq_len) ← attention score matrix
S[i, j] = Q[i] · K[j]ᵀ = similarity of query at position i to key at position j
Each row of S belongs to one query position and asks: "how relevant is every key position to me?"
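A minimal sketch of Step 1, assuming PyTorch and random inputs; the sizes are arbitrary illustration values, not taken from any real model. It checks the shape of S and that a single entry equals the plain dot product of one query row with one key row.

```python
import torch

torch.manual_seed(0)
batch, heads, seq_len, d_k = 2, 4, 5, 8  # illustrative sizes only

Q = torch.randn(batch, heads, seq_len, d_k)
K = torch.randn(batch, heads, seq_len, d_k)

# Kᵀ swaps the last two axes: (batch, heads, d_k, seq_len)
S = Q @ K.transpose(-2, -1)
print(S.shape)  # torch.Size([2, 4, 5, 5])

# S[b, h, i, j] is the dot product of query i with key j (same batch, same head)
b, h, i, j = 1, 2, 3, 0
manual = torch.dot(Q[b, h, i], K[b, h, j])
print(torch.allclose(S[b, h, i, j], manual, atol=1e-6))  # True
```

The batch and head axes simply ride along: the matrix multiply happens independently for every (batch, head) pair.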
Step 2: Scale by √dₖ
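S = S / √dₖ
Why: a dot product sums dₖ terms, so its magnitude grows with dₖ (the query/key dimension per head). If the components of q and k are independent with mean 0 and variance 1, then q · k has variance dₖ, i.e. a standard deviation of √dₖ. Large scores push softmax into saturation, where one weight is close to 1, the rest are close to 0, and gradients become tiny. Dividing by √dₖ brings the scores back to unit variance, a range where gradients flow well.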
Example: with dₖ=64, divide by 8. With dₖ=128, divide by ~11.3.
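The variance argument can be checked numerically. This sketch assumes PyTorch and standard-normal query/key components, which is an idealisation (real activations are not i.i.d. Gaussian); it compares the variance of raw and scaled dot products for a few values of dₖ.

```python
import math
import torch

torch.manual_seed(0)
n_samples = 20_000  # number of random (q, k) pairs per setting

results = {}
for d_k in (16, 64, 256):
    # Assumption: query/key components are i.i.d. with mean 0 and variance 1
    q = torch.randn(n_samples, d_k)
    k = torch.randn(n_samples, d_k)
    raw = (q * k).sum(dim=-1)        # unscaled dot products q · k
    scaled = raw / math.sqrt(d_k)    # Step 2: divide by √d_k
    results[d_k] = (raw.var().item(), scaled.var().item())
    print(f"d_k={d_k:3d}  var(q·k)={results[d_k][0]:7.1f}  var(q·k/√d_k)={results[d_k][1]:.3f}")
```

The unscaled variance should roughly track dₖ (about 16, 64, 256), while the scaled variance stays near 1 regardless of dₖ.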
Step 3: Masking (optional)
In decoder self-attention, future positions must be masked — the model should not attend to tokens it hasn't yet generated:
Mask: -∞ in the strict upper triangle (every entry above the diagonal)
S_masked[i, j] = S[i, j] if j ≤ i (position j is in the past or present)
= -∞ if j > i (position j is in the future)
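After softmax, -∞ → 0 (these positions contribute nothing).
Encoder self-attention: no causal masking needed (every position is visible to every other). Cross-attention: no causal masking either (each query attends to all encoder positions). In both cases a padding mask may still be applied so that padding tokens receive zero weight.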
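A minimal sketch of the causal mask, assuming PyTorch. It uses the same convention as the function in Complete Code (1 = keep, 0 = mask out) and builds the mask with `torch.tril`; the scores are random placeholders.

```python
import torch
import torch.nn.functional as F

seq_len = 4
# 1 on and below the diagonal (j <= i: past or present, keep),
# 0 above it (j > i: future, mask out)
causal_mask = torch.tril(torch.ones(seq_len, seq_len))
print(causal_mask)

torch.manual_seed(0)
scores = torch.randn(seq_len, seq_len)  # placeholder attention scores
masked = scores.masked_fill(causal_mask == 0, float("-inf"))
weights = F.softmax(masked, dim=-1)

# Future positions get exactly zero weight; row 0 can only attend to itself
print(weights)
```

Because exp(-∞) = 0, masked entries drop out of the softmax entirely and the remaining weights in each row still sum to 1.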
Step 4: Softmax → Attention Weights
A = softmax(S, dim=-1) ← softmax over last dimension (keys)
A[i, j] ∈ [0, 1]
Σⱼ A[i, j] = 1 ← each row sums to 1
A is now an attention weight matrix: A[i, j] is "how much attention position i pays to position j."
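Step 4 can be written out by hand to show what `softmax(S, dim=-1)` computes. This is a sketch assuming PyTorch; `row_softmax` is a hypothetical helper name, and subtracting the row maximum is a standard numerical-stability trick (it leaves the result unchanged) rather than part of the formula.

```python
import torch

def row_softmax(S: torch.Tensor) -> torch.Tensor:
    """Softmax over the last (key) dimension, written out explicitly."""
    # Subtracting each row's max avoids overflow in exp; the ratio is unchanged
    shifted = S - S.max(dim=-1, keepdim=True).values
    exp = shifted.exp()
    return exp / exp.sum(dim=-1, keepdim=True)

torch.manual_seed(0)
S = torch.randn(2, 3, 5, 5)  # (batch, heads, seq_len, seq_len)
A = row_softmax(S)

print(torch.allclose(A, torch.softmax(S, dim=-1), atol=1e-6))  # True
print(A.sum(dim=-1))  # every entry is 1 up to rounding
```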
Step 5: Weighted Sum of Values
Output = A · V
A: (batch, heads, seq_len, seq_len)
V: (batch, heads, seq_len, dᵥ)
Output: (batch, heads, seq_len, dᵥ)
Output[i] = Σⱼ A[i, j] · V[j]
Each output position is a weighted blend of all value vectors, where high-attention positions contribute more.
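To see that the matrix product really is the weighted blend Σⱼ A[i, j] · V[j], here is a small sketch assuming PyTorch, with random weights and values for a single head. It computes the output once with explicit loops and once with `A @ V`.

```python
import torch

torch.manual_seed(0)
seq_len, d_v = 4, 3
A = torch.softmax(torch.randn(seq_len, seq_len), dim=-1)  # attention weights, rows sum to 1
V = torch.randn(seq_len, d_v)                              # one value vector per position

# Output[i] = Σⱼ A[i, j] · V[j], written as an explicit loop
out_loop = torch.zeros(seq_len, d_v)
for i in range(seq_len):
    for j in range(seq_len):
        out_loop[i] += A[i, j] * V[j]

# The same computation as a single matrix multiply
out_matmul = A @ V
print(torch.allclose(out_loop, out_matmul, atol=1e-6))  # True
```

The loop is only for illustration; the matrix multiply does the same work in one batched, hardware-friendly operation.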
Complete Code
import math
from typing import Optional

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(
    Q: torch.Tensor,                      # (batch, heads, seq_len, d_k)
    K: torch.Tensor,                      # (batch, heads, seq_len, d_k)
    V: torch.Tensor,                      # (batch, heads, seq_len, d_v)
    mask: Optional[torch.Tensor] = None,  # (batch, 1, seq_len, seq_len) or None; 0 = masked out
) -> tuple[torch.Tensor, torch.Tensor]:
    d_k = Q.size(-1)

    # Step 1 + 2: Scaled dot-product scores
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    # scores: (batch, heads, seq_len, seq_len)

    # Step 3: Apply mask (entries where mask == 0 become -inf, so softmax gives them weight 0)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Step 4: Softmax over the last (key) dimension
    attn_weights = F.softmax(scores, dim=-1)

    # Step 5: Weighted sum of values
    output = torch.matmul(attn_weights, V)
    # output: (batch, heads, seq_len, d_v)

    return output, attn_weights
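# PyTorch 2.0+: built-in equivalent. Depending on hardware, dtype and inputs it may dispatch
# to a fused kernel such as FlashAttention (memory-efficient; same result up to floating-point
# error), but it returns only the output, not the attention weights.
# Its attn_mask convention differs: a boolean mask uses True = attend, while a float mask is
# *added* to the scores, so convert the 0/1 mask above to bool first.
output = F.scaled_dot_product_attention(Q, K, V, attn_mask=None if mask is None else mask.bool())
Complexity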
Time: O(n² · dₖ) per head, where n = seq_len: QKᵀ computes n² dot products of length dₖ (and A · V adds another O(n² · dᵥ))
Space: O(n²) — the attention matrix A has n² entries
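The quadratic space term is easy to tabulate. The helper below, `attention_matrix_bytes`, is hypothetical; it counts only the n × n weight matrices, assumes every one of them is held in memory at the same time, and defaults to float32 (4 bytes per element) and a batch of one sequence.

```python
def attention_matrix_bytes(n: int, heads: int = 1, layers: int = 1,
                           bytes_per_elem: int = 4, batch: int = 1) -> int:
    """Bytes needed to hold every n x n attention matrix at once (hypothetical helper)."""
    return batch * layers * heads * n * n * bytes_per_elem

per_head = attention_matrix_bytes(4096)
total = attention_matrix_bytes(4096, heads=96, layers=96)

print(per_head / 2**20, "MiB per head per layer")     # 64.0 MiB per head per layer
print(total / 2**30, "GiB for 96 heads x 96 layers")  # 576.0 GiB for 96 heads x 96 layers

# Doubling the sequence length quadruples the memory: the O(n²) term
print(attention_matrix_bytes(8192) // attention_matrix_bytes(4096))  # 4
```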
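For n = 4096 tokens, one attention matrix takes:
4096 × 4096 × 4 bytes (float32) = 64 MiB per head per layer (for a single sequence)
With 96 heads × 96 layers (the layout of GPT-3 175B): 9,216 × 64 MiB = 576 GiB (≈ 618 GB) just for attention weights, if all of them were materialised at once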
→ FlashAttention computes attention in tiles held in fast on-chip memory, so it never materialises the full n × n matrix in VRAM,
reducing the extra memory from O(n²) to O(n) while time stays O(n² · dₖ)
Interview Answer
"Scaled dot-product attention computes attention in five steps: (1) take the dot product Q·Kᵀ to get similarity scores, (2) divide by √dₖ to prevent softmax saturation, (3) optionally apply a causal mask (setting future positions to -∞), (4) apply softmax to get attention weights, and (5) multiply by V to produce the output, a weighted blend of value vectors. The complexity is O(n²·dₖ) in time and O(n²) in space, which is why memory-efficient variants like FlashAttention exist."