Encoder vs Decoder Architecture
Encoder: bidirectional for classification/embedding (BERT). Decoder: autoregressive for generation (GPT). Encoder-decoder: translation, summarization (T5). Masked vs unmasked attention.
The Three Transformer Families
Most transformer-based language models follow one of three architectural blueprints:
| Family | Attention | Primary Use | Examples |
|--------|-----------|-------------|----------|
| Encoder-only | Bidirectional | Embedding, classification | BERT, RoBERTa, DeBERTa |
| Decoder-only | Causal (unidirectional) | Text generation | GPT-2, GPT-4, LLaMA |
| Encoder-Decoder | Bidirectional encoder + causal decoder + cross-attention | Translation, summarisation | T5, BART, Whisper |
Encoder Architecture
An encoder reads the entire input simultaneously. Every token can attend to every other token, past and future, making the representation bidirectional and context-rich.
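To make "bidirectional" concrete, here is a minimal sketch (not the article's encoder) that uses a toy single-head attention with identity projections. The helper name `self_attention` and the tensor sizes are assumptions chosen for illustration. It perturbs only the last token and checks whether the first token's output moves, which it should when no mask is applied.

```python
import math
from typing import Optional

import torch
import torch.nn.functional as F


def self_attention(x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Toy single-head self-attention with identity Q/K/V projections."""
    scores = x @ x.transpose(-2, -1) / math.sqrt(x.size(-1))
    if mask is not None:
        # True in the mask means "blocked"
        scores = scores.masked_fill(mask, float("-inf"))
    return F.softmax(scores, dim=-1) @ x


torch.manual_seed(0)
x = torch.randn(1, 5, 8)      # (B, T, d)
x_changed = x.clone()
x_changed[:, -1] += 1.0       # perturb only the final token

out_a = self_attention(x)
out_b = self_attention(x_changed)

# With unmasked attention, position 0 also reads from position 4
first_token_changed = not torch.allclose(out_a[:, 0], out_b[:, 0])
print("First token output changed:", first_token_changed)
```

Because every position reads from every other position, a change anywhere in the input can affect every output vector.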
Encoder Block
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class EncoderBlock(nn.Module):
    """
    One transformer encoder layer.
    Uses Pre-LN (layer norm before attention) for stable training.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_k = d_model // n_heads
        self.n_heads = n_heads
        # Multi-head attention projections
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)
        # Feed-forward
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def _split(self, t: torch.Tensor) -> torch.Tensor:
        B, T, _ = t.shape
        return t.view(B, T, self.n_heads, self.d_k).transpose(1, 2)

    def _merge(self, t: torch.Tensor) -> torch.Tensor:
        B, h, T, d_k = t.shape
        return t.transpose(1, 2).contiguous().view(B, T, h * d_k)

    def forward(self, x: torch.Tensor, padding_mask: torch.Tensor = None):
        """
        x: (B, T, d_model)
        padding_mask: (B, T), True where padding tokens are
        """
        # Pre-LN self-attention
        residual = x
        x_norm = self.norm1(x)
        Q = self._split(self.W_Q(x_norm))
        K = self._split(self.W_K(x_norm))
        V = self._split(self.W_V(x_norm))
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if padding_mask is not None:
            # (B, 1, 1, T): broadcast over heads and query positions
            scores = scores.masked_fill(padding_mask[:, None, None, :], float('-inf'))
        attn = F.softmax(scores, dim=-1)
        out = self._merge(torch.matmul(self.dropout(attn), V))
        x = residual + self.dropout(self.W_O(out))
        # Pre-LN feed-forward
        x = x + self.dropout(self.ff(self.norm2(x)))
        return x
# Encoder stack
class BERTEncoder(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 768,
        n_layers: int = 12,
        n_heads: int = 12,
        d_ff: int = 3072,
        max_len: int = 512,
    ):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.layers = nn.ModuleList(
            [EncoderBlock(d_model, n_heads, d_ff) for _ in range(n_layers)]
        )

    def forward(self, input_ids: torch.Tensor, padding_mask: torch.Tensor = None):
        B, T = input_ids.shape
        pos = torch.arange(T, device=input_ids.device)
        x = self.norm(self.token_emb(input_ids) + self.pos_emb(pos))
        for layer in self.layers:
            x = layer(x, padding_mask)
        return x  # (B, T, d_model)

enc = BERTEncoder(vocab_size=30522)
ids = torch.randint(0, 30522, (2, 64))
out = enc(ids)
print("Encoder output:", out.shape)  # (2, 64, 768)
Use Cases for Encoders
- Text classification: [CLS] token representation → linear head
- Named entity recognition: per-token linear head
- Sentence embeddings: mean-pool over the non-padding token embeddings
- Semantic similarity: compare the embeddings of two texts, for example with cosine similarity
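The last two use cases can be sketched as follows. This is an illustrative example, assuming the same mask convention used in this article (`True` marks padding). The helper name `masked_mean_pool` is hypothetical, and a random tensor stands in for real encoder output.

```python
import torch
import torch.nn.functional as F


def masked_mean_pool(hidden: torch.Tensor, padding_mask: torch.Tensor) -> torch.Tensor:
    """
    Average token vectors over non-padding positions.
    hidden: (B, T, d); padding_mask: (B, T), True where padded.
    """
    keep = (~padding_mask).unsqueeze(-1).to(hidden.dtype)   # (B, T, 1), 1.0 for real tokens
    summed = (hidden * keep).sum(dim=1)                     # (B, d)
    counts = keep.sum(dim=1).clamp(min=1.0)                 # avoid division by zero
    return summed / counts


torch.manual_seed(0)
hidden = torch.randn(2, 6, 16)   # stand-in for encoder output (B, T, d_model)
padding_mask = torch.tensor([
    [False, False, False, False, True, True],   # 4 real tokens, 2 pads
    [False, False, False, False, False, False], # no padding
])

emb = masked_mean_pool(hidden, padding_mask)         # (2, 16) sentence embeddings
sim = F.cosine_similarity(emb[0], emb[1], dim=0)     # scalar in [-1, 1]
print(emb.shape, float(sim))
```

Excluding pad positions matters because a plain `hidden.mean(dim=1)` would let padding vectors dilute the embedding of shorter sequences.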
Decoder Architecture
A decoder generates tokens autoregressively: one at a time, left to right. The causal mask ensures token i only attends to tokens at positions 0 through i.
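The causal property can be checked directly with a small sketch. It assumes a toy single-head attention with identity projections (the helper `causal_self_attention` is hypothetical, not part of the decoder in this article). Changing the final token should leave the outputs at all earlier positions untouched.

```python
import math

import torch
import torch.nn.functional as F


def causal_self_attention(x: torch.Tensor) -> torch.Tensor:
    """Toy single-head causal self-attention with identity Q/K/V projections."""
    T, d = x.size(-2), x.size(-1)
    scores = x @ x.transpose(-2, -1) / math.sqrt(d)
    # True above the diagonal marks future positions, which are blocked
    future = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(future, float("-inf"))
    return F.softmax(scores, dim=-1) @ x


torch.manual_seed(0)
x = torch.randn(1, 5, 8)      # (B, T, d)
x_changed = x.clone()
x_changed[:, -1] += 1.0       # perturb only the final token

out_a = causal_self_attention(x)
out_b = causal_self_attention(x_changed)

earlier_unchanged = torch.allclose(out_a[:, :-1], out_b[:, :-1])
last_changed = not torch.allclose(out_a[:, -1], out_b[:, -1])
print("Earlier positions unchanged:", earlier_unchanged)
print("Last position changed:", last_changed)
```

This independence from future tokens is what allows a decoder to be trained on all positions in parallel while still generating one token at a time at inference.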
class DecoderBlock(nn.Module):
    """GPT-style decoder block (no cross-attention)."""

    def __init__(self, d_model: int, n_heads: int, d_ff: int, max_len: int = 2048):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_k = d_model // n_heads
        self.n_heads = n_heads
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.1)
        # Causal mask: upper triangular, True above the diagonal
        mask = torch.triu(torch.ones(max_len, max_len, dtype=torch.bool), diagonal=1)
        self.register_buffer("causal_mask", mask)

    def _split(self, t):
        B, T, _ = t.shape
        return t.view(B, T, self.n_heads, self.d_k).transpose(1, 2)

    def _merge(self, t):
        B, h, T, d_k = t.shape
        return t.transpose(1, 2).contiguous().view(B, T, h * d_k)

    def forward(self, x: torch.Tensor):
        B, T, _ = x.shape
        residual = x
        xn = self.norm1(x)
        Q = self._split(self.W_Q(xn))
        K = self._split(self.W_K(xn))
        V = self._split(self.W_V(xn))
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        # Apply causal mask: each position attends only to itself and earlier positions
        scores = scores.masked_fill(self.causal_mask[:T, :T], float('-inf'))
        attn = F.softmax(scores, dim=-1)
        out = self._merge(torch.matmul(self.dropout(attn), V))
        x = residual + self.dropout(self.W_O(out))
        x = x + self.dropout(self.ff(self.norm2(x)))
        return x
class GPTDecoder(nn.Module):
    def __init__(self, vocab_size: int, d_model: int = 768, n_layers: int = 12,
                 n_heads: int = 12, d_ff: int = 3072, max_len: int = 1024):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.layers = nn.ModuleList(
            [DecoderBlock(d_model, n_heads, d_ff, max_len) for _ in range(n_layers)]
        )
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, input_ids: torch.Tensor):
        B, T = input_ids.shape
        pos = torch.arange(T, device=input_ids.device)
        x = self.token_emb(input_ids) + self.pos_emb(pos)
        for layer in self.layers:
            x = layer(x)
        logits = self.head(self.norm(x))  # (B, T, vocab_size)
        return logits

gpt = GPTDecoder(vocab_size=50257)
ids = torch.randint(0, 50257, (2, 32))
logits = gpt(ids)
print("GPT logits:", logits.shape)  # (2, 32, 50257)
Masked Attention Explained
"Masking" means two distinct things depending on context:
Padding Mask (Encoder)
Encoders process variable-length sequences padded to a fixed length. Padding tokens should not contribute to attention:
# input_ids with padding (pad_id = 0)
input_ids = torch.tensor([[101, 2003, 1037, 2062, 102, 0, 0, 0]])
padding_mask = (input_ids == 0) # True where padded
# Scores at padded positions → -inf → softmax → 0
Causal Mask (Decoder)
Decoders use an upper-triangular boolean mask:
T = 6
causal = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
print(causal.int())
# tensor([[0, 1, 1, 1, 1, 1],
# [0, 0, 1, 1, 1, 1],
# [0, 0, 0, 1, 1, 1],
# [0, 0, 0, 0, 1, 1],
# [0, 0, 0, 0, 0, 1],
# [0, 0, 0, 0, 0, 0]])
# 1 (True) = masked = attention NOT allowed
Encoder-Decoder: T5 / BART / Whisper
An encoder-decoder model adds cross-attention in each decoder block, letting decoder positions attend to the full encoder output:
Decoder layer:
1. Causal self-attention (decoder → decoder)
2. Cross-attention (decoder query → encoder keys/values)
3. Feed-forward
This architecture shines when the input and output are different sequences: translating English to French, summarising a document, or transcribing audio to text.
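Step 2 of the decoder layer above can be sketched as a standalone sub-layer. This is a minimal illustration rather than the T5, BART, or Whisper implementation. It assumes PyTorch's built-in `nn.MultiheadAttention` with `batch_first=True`, a Pre-LN residual layout to match the blocks in this article, and a hypothetical class name `CrossAttention`.

```python
import torch
import torch.nn as nn


class CrossAttention(nn.Module):
    """Pre-LN cross-attention: queries from the decoder, keys/values from the encoder."""

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

    def forward(self, dec_x, enc_out, enc_padding_mask=None):
        # dec_x: (B, T_dec, d); enc_out: (B, T_enc, d)
        # enc_padding_mask: (B, T_enc), True where the encoder input is padding
        q = self.norm(dec_x)
        out, weights = self.attn(q, enc_out, enc_out, key_padding_mask=enc_padding_mask)
        return dec_x + out, weights   # residual connection; weights: (B, T_dec, T_enc)


torch.manual_seed(0)
layer = CrossAttention(d_model=32, n_heads=4).eval()
enc_out = torch.randn(2, 10, 32)   # encoder output for 10 source tokens
dec_x = torch.randn(2, 3, 32)      # decoder states for 3 target tokens
enc_padding_mask = torch.zeros(2, 10, dtype=torch.bool)
enc_padding_mask[0, 7:] = True     # last 3 source tokens of sample 0 are padding

with torch.no_grad():
    y, w = layer(dec_x, enc_out, enc_padding_mask)
print(y.shape, w.shape)
```

No causal mask is applied here: each decoder position may look at the whole source sequence, and only the encoder's padding positions are blocked.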
Autoregressive Generation (Greedy)
def greedy_generate(model, tokenizer, prompt: str, max_new_tokens: int = 50):
    """Autoregressive greedy decoding with a GPT-style model."""
    model.eval()  # disable dropout so decoding is deterministic
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    for _ in range(max_new_tokens):
        with torch.no_grad():
            logits = model(input_ids)  # (1, T, vocab_size)
        next_token_logits = logits[:, -1, :]  # last position
        next_token = next_token_logits.argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=1)
        if next_token.item() == tokenizer.eos_token_id:
            break
    return tokenizer.decode(input_ids[0])
Which Architecture for Which Task?
| Task | Best Architecture | Why |
|------|-------------------|-----|
| Sentence similarity | Encoder-only | Bidirectional context; good embeddings |
| Text classification | Encoder-only | [CLS] token captures full context |
| Open-ended generation | Decoder-only | Causal LM training matches the task |
| Translation | Encoder-Decoder | Input and output are different sequences |
| Summarisation | Either | Enc-dec for abstractive; decoder-only works too |
| Code completion | Decoder-only | Fill-in-the-middle or left-to-right generation |
Key Takeaways
- Encoders use bidirectional attention: all tokens see all other tokens, which is best for understanding.
- Decoders use causal attention: each token sees only itself and earlier tokens, which is necessary for generation.
- Encoder-decoders combine both with an additional cross-attention layer.
- Padding masks prevent attending to pad tokens; causal masks enforce temporal ordering.
- The architectural choice should follow the task, not just the trend.
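As a compact recap of the two mask types, the sketch below combines them for a padded batch fed to a decoder. It is an illustration under stated assumptions: sequences are right-padded, `True` means "blocked" as elsewhere in this article, and the helper name `combined_mask` is hypothetical.

```python
import torch
import torch.nn.functional as F


def combined_mask(padding_mask: torch.Tensor) -> torch.Tensor:
    """
    padding_mask: (B, T), True where padded.
    Returns (B, 1, T, T), True where attention is blocked by either rule.
    """
    T = padding_mask.size(1)
    causal = torch.triu(
        torch.ones(T, T, dtype=torch.bool, device=padding_mask.device), diagonal=1
    )
    # Broadcast: causal over batch and heads, padding over heads and query positions
    return causal[None, None, :, :] | padding_mask[:, None, None, :]


padding_mask = torch.tensor([[False, False, False, True, True]])  # 3 real tokens, 2 pads
mask = combined_mask(padding_mask)

scores = torch.zeros(1, 1, 5, 5)  # uniform scores make the effect of the mask easy to read
attn = F.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)
print(attn[0, 0])
```

With left-padded batches, some query rows would be fully blocked and the softmax would produce NaN for them, so that layout needs extra handling beyond this sketch.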