Live Coding Interview Prep · Lesson 9 of 16
Implement Scaled Dot-Product Attention in NumPy
The Dot Product in Attention
The dot product is the core operation in transformer attention. Self-attention scores how much each token should attend to every other token (itself included) by taking dot products between query and key vectors.
Understanding and implementing this from scratch is a common AI engineer interview task.
Basic Dot Product
```python
def dot_product(a: list[float], b: list[float]) -> float:
    """Standard dot product: sum of element-wise products."""
    if len(a) != len(b):
        raise ValueError("Vectors must have the same dimension")
    return sum(x * y for x, y in zip(a, b))

# Test
q = [1.0, 0.5, -1.0]  # Query vector
k = [0.8, 0.3, 0.2]   # Key vector
score = dot_product(q, k)
print(f"Attention score: {score}")  # 1.0*0.8 + 0.5*0.3 + (-1.0)*0.2 = 0.75
```

Scaled Dot Product Attention
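The first step of scaled attention, Q @ K.T, is the basic dot product applied to every query-key pair at once. A small sketch with made-up random inputs (3 queries and 4 keys, each 3-dimensional, are arbitrary choices) confirms that one matrix product gives the same numbers as an explicit double loop:

```python
import numpy as np

# Made-up inputs: 3 query vectors and 4 key vectors, each 3-dimensional
rng = np.random.default_rng(0)
Q = rng.standard_normal((3, 3))
K = rng.standard_normal((4, 3))

# One matrix product computes every query-key dot product at once
scores = Q @ K.T  # shape (3, 4): scores[i, j] = Q[i] . K[j]

# Same result with an explicit double loop over np.dot
loop_scores = np.array([[np.dot(q, k) for k in K] for q in Q])

print(scores.shape)                      # (3, 4)
print(np.allclose(scores, loop_scores))  # True
```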
The full attention operation from "Attention Is All You Need":
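Attention(Q, K, V) = softmax(Q @ K.T / sqrt(d_k)) @ V

Why the scaling factor sqrt(d_k)? If the components of a query q and a key k are independent with mean 0 and variance 1, then q · k has mean 0 and variance d_k, so raw scores grow like sqrt(d_k) as the dimension increases. Large scores push softmax into saturated regions where one weight is close to 1, the rest are close to 0, and the gradients become extremely small. Dividing by sqrt(d_k) brings the score variance back to 1.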
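A quick numerical check of the variance argument, assuming (as in the argument itself) i.i.d. standard-normal query and key components; learned vectors need not satisfy this. The raw standard deviation should land close to sqrt(d_k) (about 4, 8 and 16 here) while the scaled one stays near 1:

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 10_000
results = {}  # d_k -> (std of raw scores, std of scaled scores)

for d_k in (16, 64, 256):
    # Assumption: query/key components are i.i.d. with mean 0 and variance 1
    q = rng.standard_normal((n_samples, d_k))
    k = rng.standard_normal((n_samples, d_k))
    raw = (q * k).sum(axis=1)       # one q . k dot product per row
    scaled = raw / np.sqrt(d_k)     # the attention scaling step
    results[d_k] = (raw.std(), scaled.std())
    print(f"d_k={d_k:3d}  std(q.k)={raw.std():5.2f}  "
          f"sqrt(d_k)={np.sqrt(d_k):5.2f}  std(q.k/sqrt(d_k))={scaled.std():.2f}")
```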
```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax along `axis`."""
    # Subtracting the max along the axis keeps np.exp from overflowing;
    # the constant cancels in the ratio, so the result is unchanged.
    x_shifted = x - np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(x_shifted)
    return exp_x / exp_x.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(
    Q: np.ndarray,                   # (seq_len_q, d_k)
    K: np.ndarray,                   # (seq_len_k, d_k)
    V: np.ndarray,                   # (seq_len_k, d_v)
    mask: np.ndarray | None = None,  # Optional bool mask, True = may attend
) -> tuple[np.ndarray, np.ndarray]:
    """
    Returns:
        output: attended values, shape (seq_len_q, d_v)
        attention_weights: shape (seq_len_q, seq_len_k), useful for visualization
    """
    d_k = Q.shape[-1]
    # Step 1: Compute raw attention scores
    scores = Q @ K.T  # (seq_len_q, seq_len_k)
    # Step 2: Scale by sqrt(d_k) so softmax doesn't saturate (vanishing gradients)
    scores = scores / np.sqrt(d_k)
    # Step 3: Apply mask (for causal/autoregressive models)
    if mask is not None:
        scores = np.where(mask, scores, -1e9)  # blocked positions get ~0 weight
    # Step 4: Softmax over the key axis, so each row sums to 1
    attention_weights = softmax(scores, axis=-1)
    # Step 5: Weighted sum of values
    output = attention_weights @ V
    return output, attention_weights
```
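The max subtraction in softmax is what keeps it from overflowing on large scores. A minimal standalone comparison on a 1-D input; `naive_softmax` and `stable_softmax` are illustrative names, not part of the lesson's code:

```python
import numpy as np

def naive_softmax(x: np.ndarray) -> np.ndarray:
    # No max subtraction: np.exp overflows to inf once x exceeds about 709
    exp_x = np.exp(x)
    return exp_x / exp_x.sum()

def stable_softmax(x: np.ndarray) -> np.ndarray:
    # Subtracting max(x) multiplies numerator and denominator by the same
    # factor exp(-max(x)), so the result is mathematically unchanged
    exp_x = np.exp(x - np.max(x))
    return exp_x / exp_x.sum()

logits = np.array([1000.0, 1001.0, 1002.0])

with np.errstate(over="ignore", invalid="ignore"):  # silence overflow warnings
    naive = naive_softmax(logits)   # inf / inf -> nan
stable = stable_softmax(logits)

print(naive)   # [nan nan nan]
print(stable)  # roughly [0.090 0.245 0.665], same as softmax([0, 1, 2])
```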
```python
# Example: 4 tokens, 8-dimensional queries/keys, 16-dimensional values
seq_len = 4
d_k = 8
d_v = 16
np.random.seed(42)
Q = np.random.randn(seq_len, d_k)
K = np.random.randn(seq_len, d_k)
V = np.random.randn(seq_len, d_v)
output, weights = scaled_dot_product_attention(Q, K, V)
print(f"Output shape: {output.shape}")                   # (4, 16)
print(f"Weights shape: {weights.shape}")                 # (4, 4)
print(f"Weights sum (each row): {weights.sum(axis=1)}")  # All 1.0
```

Causal Masking (Autoregressive)
For autoregressive language models, each token may attend only to itself and earlier tokens, never to future ones. The causal mask enforces this by blocking every position j > i:
```python
def causal_mask(seq_len: int) -> np.ndarray:
    """
    Lower-triangular mask.
    mask[i, j] = True means position i CAN attend to position j.
    """
    return np.tril(np.ones((seq_len, seq_len), dtype=bool))

seq_len = 5
mask = causal_mask(seq_len)
print(mask)
# [[ True False False False False]
#  [ True  True False False False]
#  [ True  True  True False False]
#  [ True  True  True  True False]
#  [ True  True  True  True  True]]
# Token 0 can only see itself
# Token 4 can see itself and all earlier tokens (0 through 3)

Q = np.random.randn(seq_len, d_k)
K = np.random.randn(seq_len, d_k)
V = np.random.randn(seq_len, d_v)
output, weights = scaled_dot_product_attention(Q, K, V, mask=mask)
print(f"Masked weights (upper triangle should be ~0):\n{weights.round(3)}")
```

Multi-Head Attention Structure
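A single head computes one attention pattern per query. Multi-head attention runs n_heads heads in parallel, each on its own d_k = d_model / n_heads slice of the projected queries, keys, and values, so different heads can learn different kinds of relationships. The head outputs are concatenated and mixed by a final projection W_o: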
```python
def multi_head_attention(
    Q: np.ndarray,    # (seq_len, d_model); for self-attention Q = K = V = X
    K: np.ndarray,    # (seq_len, d_model)
    V: np.ndarray,    # (seq_len, d_model)
    W_q: np.ndarray,  # (d_model, d_model)
    W_k: np.ndarray,  # (d_model, d_model)
    W_v: np.ndarray,  # (d_model, d_model)
    W_o: np.ndarray,  # (d_model, d_model)
    n_heads: int,
    mask: np.ndarray | None = None,  # Optional mask shared by all heads
) -> np.ndarray:
    d_model = Q.shape[-1]
    if d_model % n_heads != 0:
        raise ValueError("d_model must be divisible by n_heads")
    d_k = d_model // n_heads

    # Project to Q, K, V
    Q_proj = Q @ W_q  # (seq_len, d_model)
    K_proj = K @ W_k
    V_proj = V @ W_v

    # Split into heads: (seq_len, d_model) -> (n_heads, seq_len, d_k)
    def split_heads(x: np.ndarray) -> np.ndarray:
        seq_len = x.shape[0]
        x = x.reshape(seq_len, n_heads, d_k)
        return x.transpose(1, 0, 2)

    Q_heads = split_heads(Q_proj)
    K_heads = split_heads(K_proj)
    V_heads = split_heads(V_proj)

    # Attention per head
    head_outputs = []
    for h in range(n_heads):
        head_out, _ = scaled_dot_product_attention(
            Q_heads[h], K_heads[h], V_heads[h], mask=mask
        )
        head_outputs.append(head_out)  # (seq_len, d_k)

    # Concatenate heads: (seq_len, n_heads * d_k) = (seq_len, d_model)
    concat = np.concatenate(head_outputs, axis=-1)
    # Final projection
    return concat @ W_o
```

Interview Follow-Ups
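One natural follow-up to the multi-head code is removing the per-head Python loop. The sketch below assumes self-attention only (a single input X stands in for Q, K, and V) and relies on NumPy's `@` multiplying over the leading head axis. `multi_head_attention_batched` is a hypothetical name, and the random weights are made up; the block checks itself against an explicit per-head loop:

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax along `axis`."""
    exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return exp_x / exp_x.sum(axis=axis, keepdims=True)

def multi_head_attention_batched(X, W_q, W_k, W_v, W_o, n_heads, mask=None):
    """Self-attention sketch: X is (seq_len, d_model); all heads in one batched matmul."""
    seq_len, d_model = X.shape
    if d_model % n_heads != 0:
        raise ValueError("d_model must be divisible by n_heads")
    d_k = d_model // n_heads

    def split_heads(x: np.ndarray) -> np.ndarray:
        # (seq_len, d_model) -> (n_heads, seq_len, d_k)
        return x.reshape(seq_len, n_heads, d_k).transpose(1, 0, 2)

    Q = split_heads(X @ W_q)
    K = split_heads(X @ W_k)
    V = split_heads(X @ W_v)

    # @ batches over the leading head axis: (n_heads, seq_len, seq_len)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)
    if mask is not None:
        scores = np.where(mask, scores, -1e9)  # (seq_len, seq_len) mask broadcasts over heads
    weights = softmax(scores, axis=-1)
    heads = weights @ V  # (n_heads, seq_len, d_k)

    # Merge heads: (n_heads, seq_len, d_k) -> (seq_len, n_heads * d_k)
    concat = heads.transpose(1, 0, 2).reshape(seq_len, d_model)
    return concat @ W_o

# Usage with made-up shapes and random (not learned) weights
rng = np.random.default_rng(0)
seq_len, d_model, n_heads = 5, 16, 4
X = rng.standard_normal((seq_len, d_model))
W_q, W_k, W_v, W_o = (rng.standard_normal((d_model, d_model)) / np.sqrt(d_model)
                      for _ in range(4))
out = multi_head_attention_batched(X, W_q, W_k, W_v, W_o, n_heads)
print(out.shape)  # (5, 16)

# Cross-check against an explicit loop over heads (column slices of each projection)
d_k = d_model // n_heads
ref_heads = []
for h in range(n_heads):
    cols = slice(h * d_k, (h + 1) * d_k)
    q, k, v = X @ W_q[:, cols], X @ W_k[:, cols], X @ W_v[:, cols]
    w = softmax(q @ k.T / np.sqrt(d_k), axis=-1)
    ref_heads.append(w @ v)
ref = np.concatenate(ref_heads, axis=-1) @ W_o
print(np.allclose(out, ref))  # True
```

The reshape/transpose pair is the whole trick: head h owns columns h*d_k through (h+1)*d_k - 1 of each projection, which is exactly what the loop slices out.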
Q: Why does attention use Q, K, V instead of just comparing tokens directly?
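The projection matrices W_q, W_k, W_v are learned. They map each token into a "query space" (what am I looking for?), a "key space" (what do I offer?), and a "value space" (what information do I contribute?). This lets the model learn which aspects of a token matter for matching, separately from what information gets passed along. It also makes scores asymmetric: how strongly token i attends to token j need not equal how strongly j attends to i, whereas a plain dot product between raw tokens is always symmetric.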
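A small illustration of that asymmetry, using random matrices in place of learned weights (an assumption for the demo only): raw token-to-token dot products form a symmetric matrix, while projected query-key scores generally do not.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_model = 4, 8
X = rng.standard_normal((seq_len, d_model))  # made-up token embeddings

# Comparing tokens directly: direct[i, j] == direct[j, i], always
direct = X @ X.T

# Separate query/key projections (random here, learned in a real model)
W_q = rng.standard_normal((d_model, d_model))
W_k = rng.standard_normal((d_model, d_model))
projected = (X @ W_q) @ (X @ W_k).T  # projected[i, j] = q_i . k_j

print(np.allclose(direct, direct.T))        # True
print(np.allclose(projected, projected.T))  # False
```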
Q: What is the computational complexity of self-attention?
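O(n² × d) time, where n is the sequence length and d is the query/key dimension: Q @ K.T computes n² dot products of length d, and multiplying the weights by V costs about the same again. The score matrix also takes O(n²) memory per head. The quadratic dependence on n is why long contexts are expensive: doubling the sequence length quadruples the attention computation.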
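The arithmetic behind "doubling n quadruples the cost" can be checked directly. `attention_cost` is a hypothetical helper that counts only the two attention matmuls for one head and assumes d_v = d; the Q/K/V/O projections are left out because they cost O(n × d²), which is linear in n.

```python
def attention_cost(n: int, d: int) -> tuple[int, int]:
    """Rough per-head cost of attention itself, ignoring the projections."""
    # Q @ K.T:     n * n dot products of length d -> n * n * d multiply-adds
    # weights @ V: (n x n) @ (n x d)               -> n * n * d multiply-adds
    multiply_adds = 2 * n * n * d
    score_entries = n * n  # size of the materialized attention matrix
    return multiply_adds, score_entries

for n in (1024, 2048, 4096):
    macs, entries = attention_cost(n, d=64)
    print(f"n={n:5d}  multiply-adds={macs:>14,}  score entries={entries:>11,}")
```

Each doubling of n multiplies both columns by 4, while doubling d only doubles the multiply-adds and leaves the score matrix size unchanged.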
Q: How does Flash Attention improve on standard attention?
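Flash Attention computes exact attention in tiles to reduce memory traffic. Standard attention materializes the full (n × n) attention matrix in HBM (the GPU's main memory). Flash Attention instead loads blocks of Q, K, and V into on-chip SRAM (small but fast), combines the per-block results with a running (online) softmax, and never writes the full attention matrix out. This reduces the extra memory from O(n²) to O(n) and cuts HBM reads and writes, which dominate the runtime of standard attention. The backward pass recomputes attention blocks instead of storing them, so Flash Attention performs more FLOPs yet still runs faster.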
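The online-softmax idea that makes tiling possible can be sketched in NumPy. This is not Flash Attention itself (which is a fused GPU kernel that also tiles the queries and keeps tiles in SRAM); it only shows how attention can be accumulated over key/value blocks without ever building the n × n matrix. `attention_tiled`, `attention_reference`, and `block_size` are illustrative names, and the inputs are random:

```python
import numpy as np

def attention_reference(Q, K, V):
    """Standard attention: materializes the full (n_q x n_k) score matrix."""
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)
    return (w / w.sum(axis=-1, keepdims=True)) @ V

def attention_tiled(Q, K, V, block_size=32):
    """Online-softmax sketch: visit K/V in blocks, never build the full score matrix."""
    n_q, d_k = Q.shape
    scale = 1.0 / np.sqrt(d_k)
    out = np.zeros((n_q, V.shape[-1]))  # running unnormalized output
    row_max = np.full(n_q, -np.inf)     # running max score per query
    row_sum = np.zeros(n_q)             # running softmax denominator

    for start in range(0, K.shape[0], block_size):
        K_blk = K[start:start + block_size]
        V_blk = V[start:start + block_size]
        s = Q @ K_blk.T * scale                   # (n_q, block) scores only
        new_max = np.maximum(row_max, s.max(axis=-1))
        correction = np.exp(row_max - new_max)    # rescales old accumulators (0 on first block)
        p = np.exp(s - new_max[:, None])
        row_sum = row_sum * correction + p.sum(axis=-1)
        out = out * correction[:, None] + p @ V_blk
        row_max = new_max

    return out / row_sum[:, None]

rng = np.random.default_rng(0)
n, d = 200, 16
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))
print(np.allclose(attention_tiled(Q, K, V), attention_reference(Q, K, V)))  # True
```

Whenever a new block raises a query's running maximum, the earlier partial sums are rescaled by exp(old_max - new_max), so the final division yields exactly the softmax-weighted sum; the only per-block scratch space is an (n_q × block_size) tile.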