Learnixo
AI Systems · Intermediate

The Attention Mechanism Explained

How attention computes Q, K, V; the dot-product attention formula; why it captures long-range dependencies; and a Python from-scratch implementation.

Asma Hafeez Khan · May 15, 2026 · 6 min read
Transformers · Attention · Q K V · Dot-Product Attention · Deep Learning

What Is Attention?

Before transformers, sequence models like LSTMs processed tokens one by one, compressing history into a fixed-size hidden state. By the time you reached token 50, information from token 1 was heavily diluted. Attention was introduced to solve this: every token can directly attend to every other token in a single step, regardless of distance.

The core intuition is a soft lookup: compare a query against every key in a database, turn the similarity scores into weights, and return the weighted sum of the corresponding values. Keys that match the query better contribute more.
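A toy contrast may make the "soft lookup" concrete. The sketch below (plain NumPy; the keys, values, and query are made-up numbers rather than learned vectors) compares a hard lookup, which returns only the best-matching value, with the softmax-weighted blend that attention returns.

```python
import numpy as np

# Toy "database": 3 keys (2-d) and their values (scalars for readability).
# These numbers are illustrative only, not learned vectors.
keys = np.array([[1.0, 0.0],
                 [0.0, 1.0],
                 [0.7, 0.7]])
values = np.array([10.0, 20.0, 30.0])
query = np.array([1.0, 0.2])

scores = keys @ query  # similarity of the query to every key

# Hard lookup: return only the value of the best-matching key.
hard_result = values[np.argmax(scores)]

# Soft lookup (attention): every value contributes, weighted by softmax(scores).
weights = np.exp(scores - scores.max())
weights /= weights.sum()
soft_result = weights @ values

print("scores: ", scores)
print("weights:", weights.round(3))
print("hard:", hard_result, "soft:", round(float(soft_result), 3))
```

Here the first key matches best, so the hard lookup returns 10.0, while the soft lookup blends all three values and lands between 10 and 30.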

The Three Matrices: Q, K, V

Given an input sequence X of shape (seq_len, d_model), we project it into three separate spaces:

Q = X @ W_Q    # Queries  — shape: (seq_len, d_k)
K = X @ W_K    # Keys     — shape: (seq_len, d_k)
V = X @ W_V    # Values   — shape: (seq_len, d_v)
  • W_Q, W_K are (d_model, d_k) learned weight matrices
  • W_V is a (d_model, d_v) learned weight matrix. In multi-head attention the usual choice is d_v = d_k = d_model / h, where h is the number of heads.
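A shape-only sketch of these projections, assuming d_model = 64 and h = 8 heads (so d_k = d_v = 8). Random matrices stand in for the learned weights, so only the shapes are meaningful here:

```python
import numpy as np

rng = np.random.default_rng(0)

seq_len, d_model, h = 5, 64, 8
d_k = d_v = d_model // h  # per-head width (assumed common choice)

X = rng.standard_normal((seq_len, d_model))  # input sequence

# Random stand-ins for the learned projection matrices.
W_Q = rng.standard_normal((d_model, d_k))
W_K = rng.standard_normal((d_model, d_k))
W_V = rng.standard_normal((d_model, d_v))

Q = X @ W_Q  # (seq_len, d_k)
K = X @ W_K  # (seq_len, d_k)
V = X @ W_V  # (seq_len, d_v)

print(Q.shape, K.shape, V.shape)
```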

Think of it like a search engine:

  • Query — what you are looking for
  • Key — the index of each document
  • Value — the actual content you retrieve

Scaled Dot-Product Attention

The formula for scaled dot-product attention is:

Attention(Q, K, V) = softmax( Q Ɨ K^T / sqrt(d_k) ) Ɨ V

Step by step:

  1. Dot product: Q × K^T produces a (seq_len, seq_len) matrix of raw scores. Entry [i, j] measures how well the query of token i matches the key of token j.
  2. Scale: divide by sqrt(d_k). Without this, dot products grow in magnitude as d_k increases, pushing softmax into saturated regions where gradients are close to zero.
  3. Softmax: convert scores into a probability distribution over keys. Each row sums to 1.
  4. Weighted sum: multiply by V. Each output is a blend of all value vectors, weighted by attention scores.
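A tiny worked example of the four steps, with numbers picked by hand (d_k = 4, so sqrt(d_k) = 2) so that each intermediate result can be checked on paper:

```python
import numpy as np

# Two queries, three keys, d_k = 4 (hand-picked so the arithmetic is easy).
Q = np.array([[2.0, 0.0, 0.0, 0.0],
              [0.0, 2.0, 0.0, 0.0]])
K = np.array([[2.0, 0.0, 0.0, 0.0],
              [0.0, 2.0, 0.0, 0.0],
              [0.0, 0.0, 0.0, 0.0]])
# One-hot value vectors, so each output row shows its blend weights directly.
V = np.eye(3)

d_k = Q.shape[-1]

scores = Q @ K.T                # Step 1: [[4, 0, 0], [0, 4, 0]]
scaled = scores / np.sqrt(d_k)  # Step 2: [[2, 0, 0], [0, 2, 0]]

# Step 3: row-wise softmax (subtract the row max for numerical stability).
exp = np.exp(scaled - scaled.max(axis=-1, keepdims=True))
weights = exp / exp.sum(axis=-1, keepdims=True)

output = weights @ V            # Step 4: weighted sum of value vectors

print(weights.round(3))
print(output.round(3))
```

Query 0 puts e²/(e² + 2) ≈ 0.787 of its weight on key 0 and about 0.107 on each other key; because V is the identity matrix, each output row is exactly its weight row.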

Why Scaling Matters

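Suppose the components of a query q and a key k are independent with mean 0 and variance 1. Their dot product q · k = Σ q_i k_i is a sum of d_k such products, each with variance 1, so it has mean 0 and variance d_k, and its standard deviation is sqrt(d_k). With d_k = 64, that is 8. Dividing the scores by sqrt(d_k) brings the variance back to 1, which keeps softmax out of its saturated regime and its gradients well-behaved.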
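A quick empirical check of this argument, assuming query and key entries drawn independently from a standard normal distribution (an idealisation; learned projections need not behave this way). It also compares how peaked softmax becomes with and without scaling:

```python
import numpy as np

rng = np.random.default_rng(0)
d_k, n_pairs, n_keys = 64, 20_000, 10

# Independent standard-normal entries: the idealised assumption from the text.
q = rng.standard_normal((n_pairs, d_k))
k = rng.standard_normal((n_pairs, d_k))
dots = (q * k).sum(axis=1)  # one dot product per (q, k) pair
scaled = dots / np.sqrt(d_k)

print("variance of q.k:            ", dots.var())    # close to d_k = 64
print("variance of q.k / sqrt(d_k):", scaled.var())  # close to 1


def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


# Group the dot products into rows of n_keys scores, a stand-in for
# one query's scores over n_keys keys.
rows = dots.reshape(-1, n_keys)
peak_unscaled = softmax(rows).max(axis=-1).mean()
peak_scaled = softmax(rows / np.sqrt(d_k)).max(axis=-1).mean()
print("mean largest weight, unscaled:", peak_unscaled)  # heavily peaked
print("mean largest weight, scaled:  ", peak_scaled)
```

Without scaling, one key typically takes most of the weight, which is the saturation the text warns about.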

Why Attention Captures Long-Range Dependencies

In an RNN, the path from token 1 to token 50 requires 49 sequential steps, and the gradient signal must travel backward through all of them. Attention computes all pairwise interactions in a single matrix multiply, so the path length between any two tokens is always 1. This short path makes long-range dependencies much easier to learn.

Masking

In decoder (autoregressive) attention, we add a causal mask so token i cannot attend to token j > i:

mask[i, j] = -inf  if j > i else 0

We add this mask before softmax, so masked positions receive zero weight after exponentiation.
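A minimal NumPy sketch of the additive causal mask, applied to random scores for a 4-token sequence (the scores are placeholders, not real attention logits):

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len = 4
scores = rng.standard_normal((seq_len, seq_len))  # placeholder scores

# mask[i, j] = -inf where j > i (strictly above the diagonal), 0 elsewhere.
mask = np.triu(np.full((seq_len, seq_len), -np.inf), k=1)

masked = scores + mask
# The diagonal is never masked, so each row max is finite; exp(-inf) == 0.
exp = np.exp(masked - masked.max(axis=-1, keepdims=True))
weights = exp / exp.sum(axis=-1, keepdims=True)

print(weights.round(3))
```

Every entry above the diagonal comes out as exactly 0, each row still sums to 1, and token 0, which can only see itself, gets the row [1, 0, 0, 0]. The from-scratch implementation in this article uses a boolean mask and a large negative constant (-1e9) instead of -inf; a common reason for that choice is that a row masked entirely with -inf would turn into NaN after softmax.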

Python Implementation from Scratch

Python
import numpy as np
import math

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Compute scaled dot-product attention.

    Args:
        Q: (batch, seq_q, d_k)
        K: (batch, seq_k, d_k)
        V: (batch, seq_k, d_v)
        mask: optional (batch, seq_q, seq_k) boolean mask.
              True means "mask this position" (its score is set to -1e9).

    Returns:
        output: (batch, seq_q, d_v)
        attention_weights: (batch, seq_q, seq_k)
    """
    d_k = Q.shape[-1]

    # Step 1: dot product Q @ K^T  → (batch, seq_q, seq_k)
    scores = np.matmul(Q, K.transpose(0, 2, 1))

    # Step 2: scale
    scores = scores / math.sqrt(d_k)

    # Step 3: apply mask (causal or padding)
    if mask is not None:
        scores = np.where(mask, -1e9, scores)

    # Step 4: softmax over the key dimension
    scores_max = scores.max(axis=-1, keepdims=True)
    exp_scores = np.exp(scores - scores_max)
    attention_weights = exp_scores / exp_scores.sum(axis=-1, keepdims=True)

    # Step 5: weighted sum of values
    output = np.matmul(attention_weights, V)

    return output, attention_weights


# ── Demo ─────────────────────────────────────────────────────────────────────
np.random.seed(42)

batch, seq_len, d_k, d_v = 1, 5, 8, 8

Q = np.random.randn(batch, seq_len, d_k)
K = np.random.randn(batch, seq_len, d_k)
V = np.random.randn(batch, seq_len, d_v)

output, weights = scaled_dot_product_attention(Q, K, V)
print("Output shape:", output.shape)          # (1, 5, 8)
print("Attention weights:\n", weights.round(3))

Full Linear-Projection Attention Layer in PyTorch

Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class SingleHeadAttention(nn.Module):
    """Single-head scaled dot-product attention with learned projections."""

    def __init__(self, d_model: int, d_k: int = None):
        super().__init__()
        d_k = d_k or d_model
        self.d_k = d_k
        self.W_Q = nn.Linear(d_model, d_k, bias=False)
        self.W_K = nn.Linear(d_model, d_k, bias=False)
        self.W_V = nn.Linear(d_model, d_k, bias=False)
        self.W_O = nn.Linear(d_k, d_model, bias=False)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
        """
        x: (batch, seq_len, d_model)
        mask: (batch, seq_len, seq_len) or None
        """
        Q = self.W_Q(x)   # (B, T, d_k)
        K = self.W_K(x)   # (B, T, d_k)
        V = self.W_V(x)   # (B, T, d_k)

        # Scaled dot-product
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            scores = scores.masked_fill(mask, float('-inf'))

        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, V)         # (B, T, d_k)
        return self.W_O(out), attn           # project back to d_model


# Quick test
model = SingleHeadAttention(d_model=64)
x = torch.randn(2, 10, 64)
out, attn = model(x)
print("Output:", out.shape)   # torch.Size([2, 10, 64])
print("Attn:  ", attn.shape)  # torch.Size([2, 10, 10])

Visualising Attention Weights

Attention weight matrices can be inspected directly. Extract the (seq_q, seq_k) matrix for one example and plot which positions each query attends to; the patterns are only meaningful for a trained model:

Python
import matplotlib.pyplot as plt
import seaborn as sns

def plot_attention(weights, tokens):
    """
    weights: (seq_len, seq_len) numpy array
    tokens:  list of token strings
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(weights, xticklabels=tokens, yticklabels=tokens,
                cmap='Blues', ax=ax, vmin=0, vmax=1)
    ax.set_xlabel("Key positions")
    ax.set_ylabel("Query positions")
    ax.set_title("Attention Weight Heatmap")
    plt.tight_layout()
    plt.savefig("attention_heatmap.png", dpi=150)
    plt.show()


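tokens = ["The", "cat", "sat", "on", "mat"]
# Illustration only: random vectors stand in for token embeddings and the model
# above is untrained, so the resulting pattern is essentially random.
x_demo = torch.randn(1, len(tokens), 64)   # one sequence, one row per token
with torch.no_grad():
    _, attn_demo = model(x_demo)
w = attn_demo[0].numpy()                   # (5, 5), matches the 5 tick labels
plot_attention(w, tokens)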

Computational Complexity

| Operation | Complexity  |
|-----------|-------------|
| Q × K^T   | O(n² × d_k) |
| Softmax   | O(n²)       |
| Attn × V  | O(n² × d_v) |

The O(n²) factor in sequence length n is the bottleneck that motivated sparse attention, linear attention, and FlashAttention (covered in later lessons).
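Some back-of-the-envelope arithmetic for the quadratic memory cost. It assumes the full (n, n) weight matrix is materialised in float32 (4 bytes per entry) for one head and one sequence; real frameworks and fused kernels may avoid storing it at all. The helper name is hypothetical:

```python
def attention_matrix_bytes(n: int, bytes_per_entry: int = 4) -> int:
    """Hypothetical helper: size of one (n, n) attention-weight matrix.

    Defaults to 4 bytes per entry (float32).
    """
    return n * n * bytes_per_entry


for n in (512, 2048, 4096, 8192):
    mib = attention_matrix_bytes(n) / 2**20
    print(f"n = {n:5d}: {mib:8.1f} MiB per head per sequence")
```

Doubling n quadruples the memory: 2048 tokens need 16 MiB for one matrix and 4096 tokens need 64 MiB, before multiplying by batch size and number of heads.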

Key Takeaways

  • Attention maps each position to a weighted blend of the value vectors at all positions (including its own) via Q, K, V projections.
  • Scaling by sqrt(d_k) keeps gradients stable.
  • Causal masking enforces left-to-right generation in decoders.
  • The all-pairs interaction makes long-range dependencies easy to learn but costs O(n²) memory and compute.
  • Learned weight matrices W_Q, W_K, W_V allow the model to discover which comparisons are useful.
