Self-Attention vs Cross-Attention

Self-attention draws queries, keys, and values from the same sequence. Cross-attention draws queries from the decoder and keys and values from the encoder. This post shows how both are used in encoder-decoder models, with code examples.

Asma Hafeez Khan · May 15, 2026 · 6 min read
Tags: Transformers, Self-Attention, Cross-Attention, Encoder-Decoder, Architecture

The Distinction at a Glance

Both self-attention and cross-attention use the same scaled dot-product formula:

Attention(Q, K, V) = softmax( Q × K^T / sqrt(d_k) ) × V

The difference is where Q, K, and V come from:

| Variant | Q source | K, V source |
|---------|----------|-------------|
| Self-attention | Same sequence | Same sequence |
| Cross-attention | Decoder hidden states | Encoder hidden states |
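
As a minimal sketch of this point, the helper below implements the formula once and is then called in both ways. The function name `attention` and the tensor sizes are illustrative assumptions, learned projections are left out for brevity, and the final comparison assumes PyTorch 2.0 or newer, where `F.scaled_dot_product_attention` is available.

```python
import math
import torch
import torch.nn.functional as F


def attention(Q, K, V):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
    weights = F.softmax(scores, dim=-1)
    return weights @ V, weights


torch.manual_seed(0)
x_dec = torch.randn(1, 4, 16)   # hypothetical decoder-side sequence, 4 tokens
x_enc = torch.randn(1, 6, 16)   # hypothetical encoder-side sequence, 6 tokens

# Self-attention: Q, K, and V all come from the same sequence.
self_out, self_w = attention(x_dec, x_dec, x_dec)

# Cross-attention: Q from one sequence, K and V from the other.
cross_out, cross_w = attention(x_dec, x_enc, x_enc)

print(self_w.shape)    # torch.Size([1, 4, 4])
print(cross_w.shape)   # torch.Size([1, 4, 6])

# Sanity check against PyTorch's built-in (assumes PyTorch >= 2.0).
builtin = F.scaled_dot_product_attention(x_dec, x_enc, x_enc)
print(torch.allclose(cross_out, builtin, atol=1e-5))
```

The only thing that changes between the two calls is which tensors are passed in; the attention-weight matrix is square for self-attention and rectangular (decoder length by encoder length) for cross-attention.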

Self-Attention

In self-attention, a single input sequence X is projected into all three matrices:

Q = X @ W_Q
K = X @ W_K
V = X @ W_V

Without a mask, every token attends to every token in the same sequence, including itself. This lets the model build context-aware representations: the word "bank" in "river bank" attends differently to surrounding words than "bank" in "bank account".

Self-Attention in Encoders (BERT)

Encoder self-attention is bidirectional — token i can attend to any token j, including those that come after it. This is ideal for understanding tasks like classification, named entity recognition, or building embeddings.

Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class SelfAttention(nn.Module):
    """Bidirectional self-attention (encoder-style)."""

    def __init__(self, d_model: int, d_k: int):
        super().__init__()
        self.d_k = d_k
        self.W_Q = nn.Linear(d_model, d_k, bias=False)
        self.W_K = nn.Linear(d_model, d_k, bias=False)
        self.W_V = nn.Linear(d_model, d_k, bias=False)

    def forward(self, x: torch.Tensor, padding_mask: torch.Tensor = None):
        """
        x: (B, T, d_model)
        padding_mask: (B, T) — True where padding tokens live
        """
        Q = self.W_Q(x)   # (B, T, d_k)
        K = self.W_K(x)
        V = self.W_V(x)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Expand padding mask to (B, 1, T) so it broadcasts across queries
        if padding_mask is not None:
            scores = scores.masked_fill(padding_mask.unsqueeze(1), float('-inf'))

        attn = F.softmax(scores, dim=-1)
        return torch.matmul(attn, V), attn
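
To make the padding-mask broadcast concrete, here is a small sketch that isolates just that step. The all-zero `scores` tensor is a stand-in for the real `Q K^T / sqrt(d_k)` values, chosen so the resulting weights are easy to read; the batch and sequence sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

B, T = 2, 4
scores = torch.zeros(B, T, T)   # stand-in for Q K^T / sqrt(d_k)

# True marks padding. Sequence 0 has no padding; sequence 1 ends in 2 pad tokens.
padding_mask = torch.tensor([
    [False, False, False, False],
    [False, False, True, True],
])

# (B, T) -> (B, 1, T): the mask applies along the key axis and is
# broadcast over every query row.
masked = scores.masked_fill(padding_mask.unsqueeze(1), float("-inf"))
attn = F.softmax(masked, dim=-1)

print(attn[0, 0])   # tensor([0.2500, 0.2500, 0.2500, 0.2500])
print(attn[1, 0])   # tensor([0.5000, 0.5000, 0.0000, 0.0000])
```

Note that the mask removes padded positions as keys only. Rows belonging to padded query positions still produce outputs, which are normally ignored downstream (for example, when computing the loss or pooling).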

Causal Self-Attention in Decoders (GPT)

Decoder self-attention adds a causal mask: token i may only attend to tokens j <= i, so the representation at position i never depends on later tokens. This is what makes autoregressive generation possible: the model cannot "see the future".

Python
class CausalSelfAttention(nn.Module):
    """Unidirectional (causal) self-attention — decoder style."""

    def __init__(self, d_model: int, d_k: int, max_len: int = 2048):
        super().__init__()
        self.d_k = d_k
        self.W_Q = nn.Linear(d_model, d_k, bias=False)
        self.W_K = nn.Linear(d_model, d_k, bias=False)
        self.W_V = nn.Linear(d_model, d_k, bias=False)

        # Pre-compute a causal mask (upper-triangular, above diagonal)
        mask = torch.triu(torch.ones(max_len, max_len, dtype=torch.bool), diagonal=1)
        self.register_buffer("causal_mask", mask)

    def forward(self, x: torch.Tensor):
        B, T, _ = x.shape
        Q = self.W_Q(x)
        K = self.W_K(x)
        V = self.W_V(x)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        scores = scores.masked_fill(self.causal_mask[:T, :T], float('-inf'))

        attn = F.softmax(scores, dim=-1)
        return torch.matmul(attn, V), attn


# Test
causal_sa = CausalSelfAttention(d_model=64, d_k=64)
x = torch.randn(2, 8, 64)
out, attn = causal_sa(x)
print("Output:", out.shape)      # (2, 8, 64)
# Weights on and below the diagonal are non-zero; weights above it are exactly 0
print("Attn[0, 0, :]:", attn[0, 0, :].detach())  # Only position 0 is non-zero
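
One way to check the "cannot see the future" property is to change later tokens and confirm that earlier outputs do not move. The sketch below does this with a stripped-down causal attention function; `causal_attention` is a hypothetical helper that omits the learned projections, so it illustrates the masking rather than reproducing the class above.

```python
import math
import torch
import torch.nn.functional as F


def causal_attention(x):
    """Causal self-attention without learned projections (sketch only)."""
    T, d = x.size(-2), x.size(-1)
    scores = x @ x.transpose(-2, -1) / math.sqrt(d)
    # True above the diagonal marks "future" positions to be masked out.
    future = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(future, float("-inf"))
    return F.softmax(scores, dim=-1) @ x


torch.manual_seed(0)
x = torch.randn(1, 6, 8)
out_a = causal_attention(x)

# Overwrite the last two tokens and run again.
x_changed = x.clone()
x_changed[:, 4:] = torch.randn(1, 2, 8)
out_b = causal_attention(x_changed)

# Positions 0..3 never attended to positions 4..5, so they are unchanged.
print(torch.allclose(out_a[:, :4], out_b[:, :4]))   # True
# Positions 4..5 attend to the modified tokens, so they differ.
print(torch.allclose(out_a[:, 4:], out_b[:, 4:]))   # False
```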

Cross-Attention

Cross-attention bridges an encoder and a decoder. The decoder generates queries from its own hidden states, but looks up keys and values from the encoder's output.

Q = decoder_hidden @ W_Q     # (T_dec, d_k)
K = encoder_output @ W_K     # (T_enc, d_k)
V = encoder_output @ W_V     # (T_enc, d_v)

The result has shape (T_dec, d_v): each decoder position receives a weighted sum of the value vectors at all encoder positions. This is how translation works: each output word "attends" to the most relevant input words. The implementation below sets d_v equal to d_k for simplicity.
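
The shapes can be traced with plain tensors. This is a sketch only: the sizes are arbitrary, the projection matrices are random stand-ins for learned weights, and d_v is deliberately chosen to differ from d_k so the two dimensions are easy to tell apart.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_model, d_k, d_v = 16, 8, 12   # illustrative sizes; d_v != d_k on purpose
T_dec, T_enc = 3, 5

decoder_hidden = torch.randn(T_dec, d_model)
encoder_output = torch.randn(T_enc, d_model)

# Randomly initialised projection matrices (stand-ins for learned weights).
W_Q = torch.randn(d_model, d_k)
W_K = torch.randn(d_model, d_k)
W_V = torch.randn(d_model, d_v)

Q = decoder_hidden @ W_Q   # (T_dec, d_k)
K = encoder_output @ W_K   # (T_enc, d_k)
V = encoder_output @ W_V   # (T_enc, d_v)

weights = F.softmax(Q @ K.T / math.sqrt(d_k), dim=-1)   # (T_dec, T_enc)
result = weights @ V                                    # (T_dec, d_v)

print(weights.shape)   # torch.Size([3, 5])
print(result.shape)    # torch.Size([3, 12])
```

Each row of `weights` is a probability distribution over the five encoder positions, so the output length follows the decoder while the content comes from the encoder.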

Python
class CrossAttention(nn.Module):
    """
    Cross-attention: queries come from the decoder,
    keys and values come from the encoder.
    """

    def __init__(self, d_model: int, d_k: int):
        super().__init__()
        self.d_k = d_k
        self.W_Q = nn.Linear(d_model, d_k, bias=False)
        self.W_K = nn.Linear(d_model, d_k, bias=False)
        self.W_V = nn.Linear(d_model, d_k, bias=False)
        self.W_O = nn.Linear(d_k, d_model, bias=False)

    def forward(
        self,
        decoder_hidden: torch.Tensor,   # (B, T_dec, d_model)
        encoder_output: torch.Tensor,   # (B, T_enc, d_model)
        encoder_padding_mask: torch.Tensor = None,  # (B, T_enc)
    ):
        Q = self.W_Q(decoder_hidden)     # (B, T_dec, d_k)
        K = self.W_K(encoder_output)     # (B, T_enc, d_k)
        V = self.W_V(encoder_output)     # (B, T_enc, d_k)

        # Scores: (B, T_dec, T_enc)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if encoder_padding_mask is not None:
            # Expand to (B, 1, T_enc)  broadcast over decoder positions
            scores = scores.masked_fill(encoder_padding_mask.unsqueeze(1), float('-inf'))

        attn = F.softmax(scores, dim=-1)   # (B, T_dec, T_enc)
        out = torch.matmul(attn, V)        # (B, T_dec, d_k)
        return self.W_O(out), attn
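
For comparison, PyTorch's built-in `nn.MultiheadAttention` can play the same role when the query and the key/value come from different tensors. The sketch below is an alternative to the single-head class above, not a drop-in copy of it: it uses four heads, includes bias terms by default, and the sizes and mask are made up for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, num_heads = 64, 4
mha = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

decoder_hidden = torch.randn(2, 7, d_model)    # (B, T_dec, d_model)
encoder_output = torch.randn(2, 12, d_model)   # (B, T_enc, d_model)

# True marks encoder padding; here the last 3 positions of sequence 1.
encoder_padding_mask = torch.zeros(2, 12, dtype=torch.bool)
encoder_padding_mask[1, 9:] = True

out, attn = mha(
    query=decoder_hidden,
    key=encoder_output,
    value=encoder_output,
    key_padding_mask=encoder_padding_mask,
)

print(out.shape)    # torch.Size([2, 7, 64])
print(attn.shape)   # torch.Size([2, 7, 12]), averaged over heads
print(attn[1, :, 9:].sum().item())   # 0.0: padded encoder positions get no weight
```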

Full Encoder-Decoder Block

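A decoder layer in an encoder-decoder transformer (T5, BART) combines causal self-attention, cross-attention, and a feed-forward network; the bidirectional self-attention runs in the encoder that produces encoder_output. The block below applies layer normalisation before each sub-layer (pre-norm):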

Python
class EncoderDecoderBlock(nn.Module):
    """
    A single decoder block with:
      1. Causal self-attention on decoder sequence
      2. Cross-attention over encoder output
      3. Feed-forward network
    """

    def __init__(self, d_model: int, d_k: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = CausalSelfAttention(d_model, d_k)
        # CausalSelfAttention returns d_k features, so project back to
        # d_model before the residual add (required whenever d_k != d_model).
        self.self_attn_out = nn.Linear(d_k, d_model, bias=False)
        self.cross_attn = CrossAttention(d_model, d_k)

        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, decoder_input, encoder_output, encoder_mask=None):
        # 1. Causal self-attention (decoder attends to itself)
        sa_out, _ = self.self_attn(self.norm1(decoder_input))
        x = decoder_input + self.dropout(self.self_attn_out(sa_out))

        # 2. Cross-attention (decoder attends to encoder)
        ca_out, cross_weights = self.cross_attn(
            self.norm2(x), encoder_output, encoder_mask
        )
        x = x + self.dropout(ca_out)

        # 3. Feed-forward
        x = x + self.dropout(self.ff(self.norm3(x)))
        return x, cross_weights


# Demo
d_model, d_k, d_ff = 128, 64, 512
block = EncoderDecoderBlock(d_model, d_k, d_ff)

enc_out = torch.randn(2, 12, d_model)   # encoder produced 12 tokens
dec_in  = torch.randn(2, 7,  d_model)   # decoder has seen 7 tokens so far

out, cross_w = block(dec_in, enc_out)
print("Decoder block output:", out.shape)   # (2, 7, 128)
print("Cross-attention weights:", cross_w.shape)  # (2, 7, 12)
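
PyTorch ships a layer with the same three sub-layers, `nn.TransformerDecoderLayer`. The sketch below configures it to resemble the hand-written block (pre-norm, GELU, same d_model and d_ff); the head count of 4 is an assumption, since the block above is single-head, so the two are structurally similar rather than numerically identical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, n_heads, d_ff = 128, 4, 512

layer = nn.TransformerDecoderLayer(
    d_model=d_model,
    nhead=n_heads,
    dim_feedforward=d_ff,
    dropout=0.1,
    activation="gelu",
    batch_first=True,
    norm_first=True,   # pre-norm, as in the hand-written block
)
layer.eval()           # disable dropout for a deterministic forward pass

enc_out = torch.randn(2, 12, d_model)   # encoder output, 12 tokens
dec_in = torch.randn(2, 7, d_model)     # decoder input, 7 tokens

# Boolean causal mask: True means "may not attend".
T = dec_in.size(1)
causal_mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)

with torch.no_grad():
    out = layer(dec_in, enc_out, tgt_mask=causal_mask)

print(out.shape)   # torch.Size([2, 7, 128])
```

Encoder padding can be handled by also passing `memory_key_padding_mask`, which plays the role of `encoder_mask` in the block above.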

Attention Flow in Popular Models

| Model | Self-Attention (Encoder) | Self-Attention (Decoder) | Cross-Attention |
|-------|--------------------------|--------------------------|-----------------|
| BERT | Bidirectional | — | — |
| GPT | — | Causal | — |
| T5 | Bidirectional | Causal | Yes |
| BART | Bidirectional | Causal | Yes |
| Whisper | Bidirectional (audio) | Causal (text) | Yes |

When to Use Which

Self-attention (encoder, bidirectional)

  • Sentence embeddings, semantic search
  • Text classification, NER, QA over fixed context
  • BERT-style models

Self-attention (decoder, causal)

  • Language modelling, text generation
  • GPT-style models

Cross-attention (encoder-decoder)

  • Machine translation
  • Summarisation
  • Speech recognition (audio → text)
  • Image captioning (visual features → text)

Key Takeaways

  • Self-attention and cross-attention share the same math but differ in data routing.
  • Causal masking in decoder self-attention enforces left-to-right generation.
  • Cross-attention is what allows encoder-decoder architectures to "translate" one sequence into another.
  • The encoder padding mask is passed into cross-attention so that decoder queries assign zero weight to encoder pad positions.
