Learnixo
AI Systems · Intermediate

Scaled Dot-Product Attention

The complete attention computation: dot products, scaling, masking, softmax, and value aggregation. Step-by-step with shapes and code.

Asma Hafeez Khan · May 16, 2026 · 4 min read
Tags: Transformers, Attention, Scaled Dot-Product, Deep Learning, Interview

The Full Formula

Attention(Q, K, V) = softmax(QKᵀ / √dₖ) · V

This is the core operation of every transformer. It breaks down into five steps, one of which (masking) is optional:


Step 1: Dot-Product Scores

S = Q · Kᵀ

Shapes:
  Q: (batch, heads, seq_len, dₖ)
  K: (batch, heads, seq_len, dₖ)
  Kᵀ: (batch, heads, dₖ, seq_len)
  S: (batch, heads, seq_len, seq_len)  ← attention score matrix

S[i, j] = Q[i] · K[j]ᵀ = similarity of query at position i to key at position j

Each row of S is a query position asking: "how relevant is every key position to me?"
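A minimal sketch of Step 1 in PyTorch, using small illustrative sizes (assumed, not taken from any real model). It checks the output shape and that a single entry of S is just the dot product of one query row with one key row:

```python
import torch

torch.manual_seed(0)

# Illustrative sizes (assumed for this example only)
batch, heads, seq_len, d_k = 2, 4, 5, 8

Q = torch.randn(batch, heads, seq_len, d_k)
K = torch.randn(batch, heads, seq_len, d_k)

# Transpose the last two dims of K so the matmul contracts over d_k
S = torch.matmul(Q, K.transpose(-2, -1))
print(S.shape)  # torch.Size([2, 4, 5, 5])

# S[b, h, i, j] is the dot product of query i with key j
i, j = 1, 3
manual = torch.dot(Q[0, 0, i], K[0, 0, j])
print(torch.allclose(S[0, 0, i, j], manual, atol=1e-6))
```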


Step 2: Scale by √dₖ

S = S / √dₖ

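Why: if the components of a query and a key vector are independent with mean 0 and variance 1, their dot product q·k = Σᵢ qᵢkᵢ has mean 0 and variance dₖ (the query/key dimension per head), so its typical magnitude grows like √dₖ. Large scores push softmax into saturation: almost all the weight lands on one key and the gradients become tiny. Dividing by √dₖ brings the variance back to 1, keeping scores in a range where gradients flow well.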

Example: with dₖ=64, divide by 8. With dₖ=128, divide by ~11.3.
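The variance argument can be checked empirically. This sketch assumes unit-variance Gaussian queries and keys (the idealised setting of the argument, not real learned vectors): the unscaled variance tracks dₖ, the scaled variance stays near 1, and the last two lines show how much flatter softmax becomes once the scores are divided by √dₖ:

```python
import math
import torch

torch.manual_seed(0)
n_samples = 20_000
variances = {}

for d_k in (16, 64, 256):
    # Independent, unit-variance components: the assumption behind the √d_k argument
    q = torch.randn(n_samples, d_k)
    k = torch.randn(n_samples, d_k)
    dots = (q * k).sum(dim=-1)          # n_samples independent dot products
    scaled = dots / math.sqrt(d_k)
    variances[d_k] = (dots.var().item(), scaled.var().item())
    print(f"d_k={d_k:3d}  var(q·k)={variances[d_k][0]:7.2f}  var(scaled)={variances[d_k][1]:.2f}")

# Saturation: scores a few multiples of 8 (the typical size of unscaled
# products at d_k=64) versus the same scores divided by √64 = 8
unscaled = torch.tensor([8.0, 16.0, 24.0])
print(torch.softmax(unscaled, dim=-1))        # ≈ [1e-7, 3e-4, 0.9997], almost one-hot
print(torch.softmax(unscaled / 8.0, dim=-1))  # ≈ [0.090, 0.245, 0.665]
```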


Step 3: Masking (optional)

In decoder self-attention, future positions must be masked — the model should not attend to tokens it hasn't yet generated:

Mask: -∞ in the strictly upper triangle (every entry above the diagonal)

S_masked[i, j] = S[i, j]      if j ≤ i  (position j is in the past or present)
               = -∞            if j > i  (position j is in the future)

After softmax, -∞ → 0 (these positions contribute nothing)

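Encoder self-attention: no causal mask needed (every position can see every other). Cross-attention: no causal mask either; each decoder query attends to all encoder positions, and because K and V come from the encoder, their sequence length can differ from Q's. Padding masks, which hide pad tokens, may still be applied in all three cases.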
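A minimal sketch of the causal mask in PyTorch, assuming one head of one sequence; the random `scores` tensor is a stand-in for QKᵀ/√dₖ. After softmax, every future position gets exactly zero weight:

```python
import torch

torch.manual_seed(0)
seq_len = 4
scores = torch.randn(seq_len, seq_len)  # stand-in for QKᵀ/√d_k

# True strictly above the diagonal (j > i), i.e. at future positions
future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
masked = scores.masked_fill(future, float("-inf"))

weights = torch.softmax(masked, dim=-1)
print(weights)
# Row i is non-zero only in columns 0..i; row 0 is exactly [1, 0, 0, 0]
```

The same `masked_fill` pattern works for a padding mask: mark the pad-token columns instead of the upper triangle.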


Step 4: Softmax → Attention Weights

A = softmax(S, dim=-1)   ← softmax over last dimension (keys)

A[i, j] ∈ [0, 1]
Σⱼ A[i, j] = 1           ← each row sums to 1

A is now an attention weight matrix: A[i, j] is "how much attention position i pays to position j."
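Softmax is usually computed after subtracting each row's maximum: the shift cancels in the ratio, so the weights are unchanged, but exp() can no longer overflow. Library softmax kernels typically do this internally; the hand-written `stable_softmax` below is a hypothetical helper for illustration, checked against `torch.softmax`:

```python
import torch

def stable_softmax(s: torch.Tensor) -> torch.Tensor:
    # Subtract the row max: mathematically a no-op, numerically prevents overflow
    s = s - s.max(dim=-1, keepdim=True).values
    e = torch.exp(s)
    return e / e.sum(dim=-1, keepdim=True)

torch.manual_seed(0)
S = torch.randn(2, 3, 5, 5) * 10       # (batch, heads, seq_len, seq_len) scores
A = stable_softmax(S)
print(torch.allclose(A, torch.softmax(S, dim=-1), atol=1e-6))
print(A.sum(dim=-1))                   # every row sums to 1 (up to float rounding)

big = torch.tensor([1000.0, 1001.0])
print(torch.exp(big) / torch.exp(big).sum())  # naive version: inf / inf -> nan
print(stable_softmax(big))                    # tensor([0.2689, 0.7311])
```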


Step 5: Weighted Sum of Values

Output = A · V

A: (batch, heads, seq_len, seq_len)
V: (batch, heads, seq_len, dᵥ)
Output: (batch, heads, seq_len, dᵥ)

Output[i] = Σⱼ A[i,j] · V[j]

Each output position is a weighted blend of all value vectors, where high-attention positions contribute more.
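A short check, with illustrative sizes, that the batched matmul A · V is the same as the explicit sum Output[i] = Σⱼ A[i,j] · V[j]; the `einsum` form spells out the contraction over the key index j:

```python
import torch

torch.manual_seed(0)
batch, heads, seq_len, d_v = 1, 2, 4, 3
A = torch.softmax(torch.randn(batch, heads, seq_len, seq_len), dim=-1)
V = torch.randn(batch, heads, seq_len, d_v)

out = torch.matmul(A, V)  # (batch, heads, seq_len, d_v)

# Explicit weighted sum for one query position i
i = 2
manual = sum(A[0, 0, i, j] * V[0, 0, j] for j in range(seq_len))
print(torch.allclose(out[0, 0, i], manual, atol=1e-6))

# einsum: sum over the key index j
out_einsum = torch.einsum("bhij,bhjd->bhid", A, V)
print(torch.allclose(out, out_einsum, atol=1e-6))
```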


Complete Code

Python
import torch
import torch.nn.functional as F
import math

def scaled_dot_product_attention(
    Q: torch.Tensor,          # (batch, heads, seq_len, d_k)
    K: torch.Tensor,          # (batch, heads, seq_len, d_k)
    V: torch.Tensor,          # (batch, heads, seq_len, d_v)
    mask: torch.Tensor = None # broadcastable to (batch, heads, seq_len, seq_len); 1/True = attend, 0/False = masked
) -> tuple[torch.Tensor, torch.Tensor]:
    d_k = Q.size(-1)

    # Step 1 + 2: Scaled dot-product scores
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    # scores: (batch, heads, seq_len, seq_len)

    # Step 3: Apply mask
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Step 4: Softmax
    attn_weights = F.softmax(scores, dim=-1)

    # Step 5: Weighted sum
    output = torch.matmul(attn_weights, V)
    # output: (batch, heads, seq_len, d_v)

    return output, attn_weights


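# Example: random inputs with a causal (lower-triangular) mask
Q, K, V = (torch.randn(2, 8, 16, 64) for _ in range(3))
causal = torch.tril(torch.ones(16, 16, dtype=torch.bool))  # True = may attend
out_manual, _ = scaled_dot_product_attention(Q, K, V, mask=causal)

# PyTorch 2.0+ built-in: may dispatch to a fused FlashAttention or memory-efficient
# kernel when hardware and inputs allow, otherwise to a plain math path. It returns
# only the output (no weights); a boolean attn_mask uses True = "may attend", as above.
output = F.scaled_dot_product_attention(Q, K, V, attn_mask=causal)
print(torch.allclose(out_manual, output, atol=1e-5))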

Complexity

Time:   O(n² · d)   : both QKᵀ (n² · dₖ) and A · V (n² · dᵥ) scale with n²
Space:  O(n²)       : the score matrix S and the attention matrix A each have n² entries
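The quadratic scaling can be made concrete with a rough cost counter. This is a hypothetical helper under a stated convention (one multiply-add = 2 FLOPs; softmax, scaling and masking ignored because they are O(n²) with a small constant); it shows that doubling n quadruples both the compute and the size of the attention matrix:

```python
def attention_cost(n: int, d_k: int, d_v: int) -> dict:
    """Rough per-head cost of one attention call on a length-n sequence.

    Assumes a multiply-add counts as 2 FLOPs and ignores the O(n²)
    softmax/scale/mask work, which is small next to the two matmuls.
    """
    qk_flops = 2 * n * n * d_k   # S = Q·Kᵀ: n×n outputs, d_k multiply-adds each
    av_flops = 2 * n * n * d_v   # A·V: n×d_v outputs, n multiply-adds each
    return {
        "flops": qk_flops + av_flops,
        "score_entries": n * n,  # number of entries in S (and in A)
    }

small = attention_cost(n=2048, d_k=64, d_v=64)
large = attention_cost(n=4096, d_k=64, d_v=64)
print(large["flops"] / small["flops"])                   # 4.0
print(large["score_entries"] / small["score_entries"])   # 4.0
```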

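For n=4096 in float32, the attention matrix alone takes:
  4096 × 4096 × 4 bytes = 64 MiB per head per layer, per sequence
  With 96 heads × 96 layers (the shape of GPT-3 175B): 9,216 × 64 MiB = 576 GiB (≈ 618 GB),
  if every layer's weights were held at once, e.g. saved for the backward pass

→ FlashAttention computes the same exact result tile by tile in fast on-chip SRAM, using a
  running (online) softmax, so the full n × n matrix is never materialised in GPU memory (HBM).
  Extra memory drops from O(n²) to O(n); time stays O(n² · dₖ), but the reduced memory
  traffic makes it considerably faster in practice.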
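The core idea behind FlashAttention, the online softmax over key blocks, can be sketched in plain PyTorch. This is an illustration of the algorithm under simplifying assumptions (no mask, keys tiled but queries not), not the FlashAttention kernel: it runs as ordinary PyTorch ops and gains none of the speed of a fused GPU kernel. `tiled_attention` and `block_size` are hypothetical names. Only a (seq_len × block_size) slice of the scores exists at a time, and the result matches the standard formula:

```python
import math
import torch

def tiled_attention(Q, K, V, block_size: int = 128):
    """Exact softmax(QKᵀ/√d_k)·V computed one key block at a time (no mask).

    Per query row, a running max `m` and running denominator `l` let each new
    block rescale what has been accumulated so far.
    """
    scale = 1.0 / math.sqrt(Q.size(-1))
    out = torch.zeros(*Q.shape[:-1], V.size(-1), dtype=Q.dtype, device=Q.device)
    m = torch.full((*Q.shape[:-1], 1), float("-inf"), dtype=Q.dtype, device=Q.device)
    l = torch.zeros((*Q.shape[:-1], 1), dtype=Q.dtype, device=Q.device)

    for start in range(0, K.size(-2), block_size):
        K_blk = K[..., start:start + block_size, :]
        V_blk = V[..., start:start + block_size, :]
        s = torch.matmul(Q, K_blk.transpose(-2, -1)) * scale  # (..., n_q, blk)

        m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
        p = torch.exp(s - m_new)            # unnormalised weights for this block
        correction = torch.exp(m - m_new)   # rescales the earlier partial sums
        l = l * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + torch.matmul(p, V_blk)
        m = m_new

    return out / l  # normalise once at the end

torch.manual_seed(0)
Q = torch.randn(2, 4, 300, 32)
K = torch.randn(2, 4, 300, 32)
V = torch.randn(2, 4, 300, 32)

reference = torch.softmax(Q @ K.transpose(-2, -1) / math.sqrt(32), dim=-1) @ V
print(torch.allclose(tiled_attention(Q, K, V, block_size=64), reference, atol=1e-5))
```

The real kernel also tiles the queries and fuses all of these steps into one GPU kernel that keeps the tiles in SRAM; that fusion is where the speed-up comes from.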

Interview Answer

"Scaled dot-product attention computes attention in five steps: (1) take dot products Q·Kᵀ to get similarity scores, (2) divide by √dₖ to prevent softmax saturation, (3) optionally apply a mask, such as a causal mask that sets future positions to -∞, (4) apply softmax over the keys to get attention weights, and (5) multiply by V to produce the output, a weighted blend of value vectors. The complexity is O(n²·dₖ) in time and O(n²) in space, which is why memory-efficient exact kernels like FlashAttention exist."
