Self-Attention vs Cross-Attention
Self-attention draws queries, keys, and values from the same sequence. Cross-attention takes queries from the decoder and keys and values from the encoder. This article shows how both are used in encoder-decoder models, with code examples.
The Distinction at a Glance
Both self-attention and cross-attention use the same scaled dot-product formula:
Attention(Q, K, V) = softmax( Q × K^T / sqrt(d_k) ) × V

The difference is where Q, K, and V come from:
| Variant | Q source | K, V source |
|---------|----------|-------------|
| Self-attention | Same sequence | Same sequence |
| Cross-attention | Decoder hidden states | Encoder hidden states |
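To make the routing difference concrete, here is a minimal sketch in which one attention function serves both variants. The learned projections W_Q, W_K, W_V are left out for brevity, and the tensor names and sizes (`x`, `dec`, `enc`) are illustrative assumptions rather than anything a particular model prescribes.

```python
import math

import torch
import torch.nn.functional as F


def scaled_dot_product_attention(Q, K, V):
    """Compute softmax(Q K^T / sqrt(d_k)) V for batched inputs."""
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # (B, T_q, T_kv)
    weights = F.softmax(scores, dim=-1)                # each row sums to 1
    return weights @ V, weights


torch.manual_seed(0)
x = torch.randn(2, 5, 16)    # one sequence of 5 tokens
dec = torch.randn(2, 3, 16)  # stand-in decoder hidden states, 3 tokens
enc = torch.randn(2, 7, 16)  # stand-in encoder output, 7 tokens

# Self-attention: Q, K, and V all come from the same sequence
self_out, self_w = scaled_dot_product_attention(x, x, x)

# Cross-attention: Q from the decoder, K and V from the encoder
cross_out, cross_w = scaled_dot_product_attention(dec, enc, enc)

print(self_w.shape)   # torch.Size([2, 5, 5])
print(cross_w.shape)  # torch.Size([2, 3, 7])
```

The self-attention weight matrix is square (T × T), while the cross-attention one is rectangular (T_dec × T_enc). The function itself never changes, only its arguments do.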
Self-Attention
In self-attention, a single input sequence X is projected into all three matrices:
Q = X @ W_Q
K = X @ W_K
V = X @ W_V

Every token attends to every other token in the same sequence. This lets the model build context-aware representations: the word "bank" in "river bank" attends differently to surrounding words than "bank" in "bank account".
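The context dependence can be checked with a toy experiment. The sketch below uses random vectors as stand-ins for word embeddings and random, untrained projection matrices, so it only illustrates the mechanism: the same input vector placed between different neighbours yields a different output. Positional encodings are omitted, and all names and sizes are assumptions made for the example.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_model = 8

# Random, untrained projections, scaled to keep the softmax from saturating
W_Q = torch.randn(d_model, d_model) / math.sqrt(d_model)
W_K = torch.randn(d_model, d_model) / math.sqrt(d_model)
W_V = torch.randn(d_model, d_model) / math.sqrt(d_model)


def self_attention(X):
    """Single-head self-attention over one unbatched sequence X of shape (T, d_model)."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    weights = F.softmax(Q @ K.T / math.sqrt(d_model), dim=-1)
    return weights @ V


bank = torch.randn(d_model)          # stand-in embedding for "bank"
context_a = torch.randn(2, d_model)  # stand-in neighbours, first sentence
context_b = torch.randn(2, d_model)  # stand-in neighbours, second sentence

# Put the identical "bank" vector in the middle of two different contexts
seq_a = torch.cat([context_a[:1], bank[None], context_a[1:]])  # (3, d_model)
seq_b = torch.cat([context_b[:1], bank[None], context_b[1:]])

out_a = self_attention(seq_a)[1]  # contextualised "bank", first sentence
out_b = self_attention(seq_b)[1]  # contextualised "bank", second sentence

print(torch.allclose(out_a, out_b))  # False: same input vector, different output
```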
Self-Attention in Encoders (BERT)
Encoder self-attention is bidirectional — token i can attend to any token j, including those that come after it. This is ideal for understanding tasks like classification, named entity recognition, or building embeddings.
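The encoder code in this section accepts a `padding_mask` that is True at padded positions. As a minimal sketch, assuming each batch comes with a tensor of true sequence lengths, such a mask could be built as follows. The helper name `make_padding_mask` and the example lengths are illustrative and not part of any library.

```python
import torch


def make_padding_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    """Return a (B, max_len) bool mask that is True at padded positions."""
    positions = torch.arange(max_len, device=lengths.device)  # (max_len,)
    return positions.unsqueeze(0) >= lengths.unsqueeze(1)     # (B, max_len)


lengths = torch.tensor([4, 2])  # hypothetical batch: 4 and 2 real tokens
padding_mask = make_padding_mask(lengths, max_len=4)
print(padding_mask.tolist())
# [[False, False, False, False], [False, False, True, True]]

# Applying the mask to dummy (B, T, T) scores: padded keys get zero weight
scores = torch.zeros(2, 4, 4)
scores = scores.masked_fill(padding_mask.unsqueeze(1), float("-inf"))
weights = torch.softmax(scores, dim=-1)
print(weights[1, 0].tolist())  # [0.5, 0.5, 0.0, 0.0]
```

The mask is applied along the key dimension, so real tokens never attend to padding. The rows that belong to padded query positions are still computed and are normally discarded downstream.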
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class SelfAttention(nn.Module):
    """Bidirectional self-attention (encoder-style)."""

    def __init__(self, d_model: int, d_k: int):
        super().__init__()
        self.d_k = d_k
        self.W_Q = nn.Linear(d_model, d_k, bias=False)
        self.W_K = nn.Linear(d_model, d_k, bias=False)
        self.W_V = nn.Linear(d_model, d_k, bias=False)

    def forward(self, x: torch.Tensor, padding_mask: torch.Tensor = None):
        """
        x: (B, T, d_model)
        padding_mask: (B, T), True where padding tokens live
        """
        Q = self.W_Q(x)  # (B, T, d_k)
        K = self.W_K(x)
        V = self.W_V(x)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Expand padding mask to (B, 1, T) so it broadcasts across queries
        if padding_mask is not None:
            scores = scores.masked_fill(padding_mask.unsqueeze(1), float('-inf'))

        attn = F.softmax(scores, dim=-1)
        return torch.matmul(attn, V), attn

Causal Self-Attention in Decoders (GPT)
Decoder self-attention adds a causal mask: token i may only attend to tokens j <= i. This ensures autoregressive generation — the model cannot "see the future".
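One way to convince yourself that the causal mask blocks information from the future is to edit the last few tokens and check that earlier outputs do not move. The sketch below is a stripped-down, projection-free attention function written for this check only; the sizes and the edited positions are arbitrary choices.

```python
import math

import torch
import torch.nn.functional as F


def causal_attention(x):
    """Single-head causal self-attention without learned projections."""
    T = x.size(-2)
    scores = x @ x.transpose(-2, -1) / math.sqrt(x.size(-1))  # (B, T, T)
    # True strictly above the diagonal, i.e. where key j lies after query i
    future = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(future, float("-inf"))
    return F.softmax(scores, dim=-1) @ x


torch.manual_seed(0)
x = torch.randn(1, 6, 16)
x_edited = x.clone()
x_edited[:, 4:] = torch.randn(1, 2, 16)  # change only the last two tokens

out = causal_attention(x)
out_edited = causal_attention(x_edited)

# Positions 0..3 cannot see positions 4 and 5, so their outputs are unchanged
print(torch.allclose(out[:, :4], out_edited[:, :4]))  # True
# Positions 4 and 5 themselves do change
print(torch.allclose(out[:, 4:], out_edited[:, 4:]))  # False
```

Without the `masked_fill` line, the first check would fail, which is exactly the bidirectional behaviour of encoder self-attention.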
class CausalSelfAttention(nn.Module):
    """Unidirectional (causal) self-attention, decoder style."""

    def __init__(self, d_model: int, d_k: int, max_len: int = 2048):
        super().__init__()
        self.d_k = d_k
        self.W_Q = nn.Linear(d_model, d_k, bias=False)
        self.W_K = nn.Linear(d_model, d_k, bias=False)
        self.W_V = nn.Linear(d_model, d_k, bias=False)

        # Pre-compute a causal mask (upper-triangular, above diagonal)
        mask = torch.triu(torch.ones(max_len, max_len, dtype=torch.bool), diagonal=1)
        self.register_buffer("causal_mask", mask)

    def forward(self, x: torch.Tensor):
        B, T, _ = x.shape
        Q = self.W_Q(x)
        K = self.W_K(x)
        V = self.W_V(x)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        scores = scores.masked_fill(self.causal_mask[:T, :T], float('-inf'))

        attn = F.softmax(scores, dim=-1)
        return torch.matmul(attn, V), attn


# Test
causal_sa = CausalSelfAttention(d_model=64, d_k=64)
x = torch.randn(2, 8, 64)
out, attn = causal_sa(x)
print("Output:", out.shape)  # (2, 8, 64)
# Lower-left triangle should have non-zero weights; upper-right should be ~0
print("Attn[0, 0, :]:", attn[0, 0, :].detach())  # Only position 0 is non-zero

Cross-Attention
Cross-attention bridges an encoder and a decoder. The decoder generates queries from its own hidden states, but looks up keys and values from the encoder's output.
Q = decoder_hidden @ W_Q # (T_dec, d_k)
K = encoder_output @ W_K # (T_enc, d_k)
V = encoder_output @ W_V # (T_enc, d_v)

The result has shape (T_dec, d_v): each decoder position receives a weighted sum of the encoder's value vectors, with weights set by how well its query matches each encoder key. In translation, for example, this lets each output word draw on the most relevant input words.
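The shapes and the weighted-sum reading can be verified numerically. This is a small unbatched sketch with random, untrained matrices; the sizes are arbitrary, and d_v is deliberately chosen different from d_k to show that only the query and key widths need to match.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
T_dec, T_enc, d_model, d_k, d_v = 3, 5, 16, 8, 12

decoder_hidden = torch.randn(T_dec, d_model)
encoder_output = torch.randn(T_enc, d_model)
W_Q = torch.randn(d_model, d_k) / math.sqrt(d_model)
W_K = torch.randn(d_model, d_k) / math.sqrt(d_model)
W_V = torch.randn(d_model, d_v) / math.sqrt(d_model)

Q = decoder_hidden @ W_Q  # (T_dec, d_k)
K = encoder_output @ W_K  # (T_enc, d_k)
V = encoder_output @ W_V  # (T_enc, d_v)

weights = F.softmax(Q @ K.T / math.sqrt(d_k), dim=-1)  # (T_dec, T_enc)
out = weights @ V                                      # (T_dec, d_v)
print(weights.shape, out.shape)  # torch.Size([3, 5]) torch.Size([3, 12])

# Row 0 of the output, written out as an explicit weighted sum of value vectors
manual = sum(weights[0, j] * V[j] for j in range(T_enc))
print(torch.allclose(out[0], manual, atol=1e-5))  # True

# The encoder position each decoder position attends to most strongly
print(weights.argmax(dim=-1))
```

In a trained translation model, that final `argmax` is roughly what a soft word-alignment plot visualises.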
class CrossAttention(nn.Module):
    """
    Cross-attention: queries come from the decoder,
    keys and values come from the encoder.
    """

    def __init__(self, d_model: int, d_k: int):
        super().__init__()
        self.d_k = d_k
        self.W_Q = nn.Linear(d_model, d_k, bias=False)
        self.W_K = nn.Linear(d_model, d_k, bias=False)
        self.W_V = nn.Linear(d_model, d_k, bias=False)
        self.W_O = nn.Linear(d_k, d_model, bias=False)

    def forward(
        self,
        decoder_hidden: torch.Tensor,               # (B, T_dec, d_model)
        encoder_output: torch.Tensor,               # (B, T_enc, d_model)
        encoder_padding_mask: torch.Tensor = None,  # (B, T_enc)
    ):
        Q = self.W_Q(decoder_hidden)  # (B, T_dec, d_k)
        K = self.W_K(encoder_output)  # (B, T_enc, d_k)
        V = self.W_V(encoder_output)  # (B, T_enc, d_k)

        # Scores: (B, T_dec, T_enc)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if encoder_padding_mask is not None:
            # Expand to (B, 1, T_enc) to broadcast over decoder positions
            scores = scores.masked_fill(encoder_padding_mask.unsqueeze(1), float('-inf'))

        attn = F.softmax(scores, dim=-1)  # (B, T_dec, T_enc)
        out = torch.matmul(attn, V)       # (B, T_dec, d_k)
        return self.W_O(out), attn

Full Encoder-Decoder Block
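A decoder layer in a standard encoder-decoder transformer (T5, BART) combines causal self-attention, cross-attention, and a feed-forward network. The third attention type, bidirectional self-attention, lives in the encoder layers that produce the encoder output consumed here: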
class EncoderDecoderBlock(nn.Module):
    """
    A single decoder block with:
    1. Causal self-attention on decoder sequence
    2. Cross-attention over encoder output
    3. Feed-forward network
    """

    def __init__(self, d_model: int, d_k: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = CausalSelfAttention(d_model, d_k)
        # CausalSelfAttention returns d_k features; project back to d_model
        # so the residual addition below is shape-compatible when d_k != d_model.
        self.self_attn_out = nn.Linear(d_k, d_model, bias=False)
        self.cross_attn = CrossAttention(d_model, d_k)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, decoder_input, encoder_output, encoder_mask=None):
        # 1. Causal self-attention (decoder attends to itself)
        sa_out, _ = self.self_attn(self.norm1(decoder_input))
        x = decoder_input + self.dropout(self.self_attn_out(sa_out))

        # 2. Cross-attention (decoder attends to encoder)
        ca_out, cross_weights = self.cross_attn(
            self.norm2(x), encoder_output, encoder_mask
        )
        x = x + self.dropout(ca_out)

        # 3. Feed-forward
        x = x + self.dropout(self.ff(self.norm3(x)))
        return x, cross_weights


# Demo
d_model, d_k, d_ff = 128, 64, 512
block = EncoderDecoderBlock(d_model, d_k, d_ff)
enc_out = torch.randn(2, 12, d_model)  # encoder produced 12 tokens
dec_in = torch.randn(2, 7, d_model)    # decoder has seen 7 tokens so far
out, cross_w = block(dec_in, enc_out)
print("Decoder block output:", out.shape)         # (2, 7, 128)
print("Cross-attention weights:", cross_w.shape)  # (2, 7, 12)

Attention Flow in Popular Models
| Model | Self-Attention (Encoder) | Self-Attention (Decoder) | Cross-Attention |
|-------|--------------------------|--------------------------|-----------------|
| BERT | Bidirectional | — | — |
| GPT | — | Causal | — |
| T5 | Bidirectional | Causal | Yes |
| BART | Bidirectional | Causal | Yes |
| Whisper | Bidirectional (audio) | Causal (text) | Yes |
When to Use Which
Self-attention (encoder, bidirectional)
- Sentence embeddings, semantic search
- Text classification, NER, QA over fixed context
- BERT-style models
Self-attention (decoder, causal)
- Language modelling, text generation
- GPT-style models
Cross-attention (encoder-decoder)
- Machine translation
- Summarisation
- Speech recognition (audio → text)
- Image captioning (visual features → text)
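For reference, all three usage patterns can also be expressed with PyTorch's built-in `torch.nn.MultiheadAttention`, which switches between them purely through its arguments. The sketch below reuses a single module instance for all three calls only to stay short; a real model would hold separate modules with separate weights. The sizes and the padding pattern are made up for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, n_heads = 32, 4
mha = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

enc = torch.randn(2, 6, d_model)  # stand-in encoder states, 6 positions
dec = torch.randn(2, 4, d_model)  # stand-in decoder states, 4 positions

# 1. Bidirectional self-attention: query, key, and value are the same tensor
enc_out, enc_w = mha(enc, enc, enc)  # weights: (2, 6, 6)

# 2. Causal self-attention: a bool attn_mask, True marks a blocked position
causal = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1)
dec_out, dec_w = mha(dec, dec, dec, attn_mask=causal)  # weights: (2, 4, 4)

# 3. Cross-attention: query from the decoder, key and value from the encoder.
#    key_padding_mask is True where an encoder position is padding.
enc_pad = torch.tensor([[False] * 6, [False] * 4 + [True] * 2])
cross_out, cross_w = mha(dec, enc, enc, key_padding_mask=enc_pad)  # (2, 4, 6)

print(enc_w.shape, dec_w.shape, cross_w.shape)
print(cross_w[1, :, 4:])  # all zeros: padded encoder positions are ignored
```

The returned weights are averaged over heads by default, which is why they have no head dimension here.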
Key Takeaways
- Self-attention and cross-attention share the same math but differ in data routing.
- Causal masking in decoder self-attention enforces left-to-right generation.
- Cross-attention is what allows encoder-decoder architectures to "translate" one sequence into another.
- The encoder padding mask is applied inside cross-attention so that decoder queries place zero weight on the encoder's pad positions.