GPT Architecture: Inside the Decoder-Only Transformer
Deep dive into GPT's decoder-only architecture: token embeddings, causal attention, FFN layers, residual stream, and how autoregressive generation works end-to-end.
The Decoder-Only Design Choice
GPT (Generative Pre-trained Transformer) uses only the decoder portion of the original encoder-decoder transformer. The key difference from a bidirectional encoder: a causal (unidirectional) attention mask ensures each token attends only to itself and the tokens before it, never to later positions.
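To make the masking concrete, here is a minimal plain-Python sketch (no PyTorch, so the arithmetic stays visible). The helper names `causal_mask` and `masked_softmax` are illustrative and not taken from any library. It builds a lower-triangular mask and applies it to rows of equal raw scores, so every position after the query ends up with exactly zero weight.

```python
import math

def causal_mask(seq_len: int) -> list[list[int]]:
    """Lower-triangular mask: row i has 1s at columns 0..i and 0s after."""
    return [[1 if j <= i else 0 for j in range(seq_len)] for i in range(seq_len)]

def masked_softmax(scores: list[float], mask_row: list[int]) -> list[float]:
    """Softmax over the allowed positions only; masked positions get weight 0."""
    allowed = [s for s, m in zip(scores, mask_row) if m]
    peak = max(allowed)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) if m else 0.0 for s, m in zip(scores, mask_row)]
    total = sum(exps)
    return [e / total for e in exps]

mask = causal_mask(4)
# Equal raw scores, so each visible position gets an equal share of attention.
weights = [masked_softmax([0.0, 0.0, 0.0, 0.0], row) for row in mask]
for row in weights:
    print([round(w, 3) for w in row])
```

The first row puts all weight on position 0, the second splits it evenly over positions 0 and 1, and so on. Setting masked scores to `-inf` before the softmax, as the PyTorch code later in this article does, has the same effect.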
This design choice has significant implications:
- Training objective: Next-token prediction, which is simple, scalable, and needs no labels (every position in raw text supplies its own training target)
- No cross-attention: No separate encoder, so no additional module to align
- Generation is natural: The model generates left-to-right, matching the training objective exactly
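As a small illustration of the next-token objective, the sketch below turns one token sequence into input and target sequences by shifting it one position. The token IDs are made up and `make_training_pair` is a hypothetical helper, not part of any library.

```python
def make_training_pair(tokens: list[int]) -> tuple[list[int], list[int]]:
    """Inputs are all tokens but the last; targets are the same tokens shifted left by one."""
    return tokens[:-1], tokens[1:]

tokens = [464, 3797, 3332, 319, 262, 2603]  # made-up token IDs
inputs, targets = make_training_pair(tokens)

# With a causal mask, position i sees inputs[0..i] and is trained to predict targets[i],
# so one sequence of length n yields n - 1 training examples in a single forward pass.
for i, target in enumerate(targets):
    print(f"context={inputs[: i + 1]} -> predict {target}")
```

At generation time the model does the same thing one step at a time, which is why training and inference line up so directly.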
The original encoder-decoder architecture was designed for sequence-to-sequence tasks (translation). GPT's authors recognized that next-token prediction at scale produces general-purpose language understanding as an emergent property.
Full Architecture: Residual Stream View
The cleanest mental model for GPT is the residual stream:
Input tokens → Embedding → [Residual Stream] → Unembed → Logits
                                   ↑
                         Each transformer block
                         reads from and writes to
                         the same stream:
                           stream = stream + attn(stream)
                           stream = stream + ffn(stream)

Every layer adds to a shared residual stream. This framing (from Anthropic's mechanistic interpretability work) clarifies why residual connections are fundamental: they allow information to flow unchanged through layers that don't need to transform it.
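A toy sketch of this additive view, in plain Python with a four-number vector standing in for the stream. The two layer functions are invented stand-ins, not real attention or FFN computations: one writes nothing, the other reads one coordinate and writes to another.

```python
Vector = list[float]

def add(a: Vector, b: Vector) -> Vector:
    """Element-wise sum, i.e. the residual connection."""
    return [x + y for x, y in zip(a, b)]

def silent_layer(stream: Vector) -> Vector:
    """A layer with nothing to contribute writes zeros."""
    return [0.0] * len(stream)

def writer_layer(stream: Vector) -> Vector:
    """A layer that reads coordinate 0 and writes twice its value into coordinate 3."""
    update = [0.0] * len(stream)
    update[3] = 2.0 * stream[0]
    return update

stream = [1.0, -2.0, 0.5, 0.0]
after_silent = add(stream, silent_layer(stream))               # stream passes through unchanged
after_writer = add(after_silent, writer_layer(after_silent))   # only coordinate 3 changes
print(after_silent, after_writer)
```

Because every layer only adds to the stream, whatever an early layer wrote is still readable by any later layer unless something explicitly overwrites it.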
Layer-by-Layer Implementation
import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn


@dataclass
class GPTConfig:
    """GPT-2-scale model configuration."""
    vocab_size: int = 50257
    n_positions: int = 1024  # Max sequence length
    n_embd: int = 768        # d_model
    n_layer: int = 12        # Number of transformer blocks
    n_head: int = 12         # Number of attention heads
    dropout: float = 0.1
    bias: bool = True

class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with causal mask."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head
        # Q, K, V projections combined for efficiency
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        # Causal mask (lower triangular), stored as a buffer named "bias"
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.n_positions, config.n_positions))
            .view(1, 1, config.n_positions, config.n_positions)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape  # batch, sequence_len, embedding_dim
        # Compute Q, K, V
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        # Reshape for multi-head attention: (B, n_head, T, head_dim)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention
        scale = 1.0 / math.sqrt(self.head_dim)
        attn = (q @ k.transpose(-2, -1)) * scale  # (B, n_head, T, T)
        # Apply causal mask: fill with -inf where mask is 0
        attn = attn.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        attn = torch.softmax(attn, dim=-1)
        attn = self.attn_dropout(attn)
        # Weighted sum of values
        out = attn @ v  # (B, n_head, T, head_dim)
        # Merge heads back into a single embedding dimension
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_dropout(self.c_proj(out))

class MLP(nn.Module):
    """Position-wise feed-forward network."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        # 4× expansion is the GPT-2 convention
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))

class Block(nn.Module):
    """Transformer block: LayerNorm → Attention → residual, LayerNorm → FFN → residual."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-LayerNorm (GPT-2 and most modern LLMs use pre-norm)
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

class GPT(nn.Module):
    """Full GPT model."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict({
            "wte": nn.Embedding(config.vocab_size, config.n_embd),   # Token embeddings
            "wpe": nn.Embedding(config.n_positions, config.n_embd),  # Position embeddings
            "drop": nn.Dropout(config.dropout),
            "h": nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            "ln_f": nn.LayerNorm(config.n_embd),                     # Final LayerNorm
        })
        # Unembed: project back to vocab_size
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying: embedding table = unembed matrix (saves parameters and stabilizes training)
        self.transformer["wte"].weight = self.lm_head.weight

    def forward(
        self,
        idx: torch.Tensor,
        targets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        B, T = idx.shape
        assert T <= self.config.n_positions
        # Build input: token embedding + positional embedding
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
        tok_emb = self.transformer["wte"](idx)  # (B, T, n_embd)
        pos_emb = self.transformer["wpe"](pos)  # (T, n_embd), broadcast over batch
        x = self.transformer["drop"](tok_emb + pos_emb)
        # Pass through transformer blocks
        for block in self.transformer["h"]:
            x = block(x)
        x = self.transformer["ln_f"](x)
        # Compute logits
        logits = self.lm_head(x)  # (B, T, vocab_size)
        # Compute cross-entropy loss if targets provided
        loss = None
        if targets is not None:
            loss = torch.nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=-1,
            )
        return logits, loss

Autoregressive Generation
GPT generates one token at a time, feeding each output back as input:

@torch.no_grad()
def generate(
    model: GPT,
    idx: torch.Tensor,
    max_new_tokens: int,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
) -> torch.Tensor:
    """
    Autoregressive token generation with optional top-k sampling.
    idx: (B, T) input token indices
    """
    for _ in range(max_new_tokens):
        # Crop to context window if needed
        idx_cond = idx if idx.size(1) <= model.config.n_positions else idx[:, -model.config.n_positions:]
        logits, _ = model(idx_cond)
        # Take logits at the last position (the next-token prediction)
        logits = logits[:, -1, :] / temperature  # (B, vocab_size)
        # Optional top-k filtering: set logits outside the top k to -inf
        # so they receive zero probability after the softmax
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = float('-inf')
        probs = torch.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
        idx = torch.cat((idx, idx_next), dim=1)  # Append and continue
    return idx

The key property: the model processes the full context on every forward pass, but only the last token's logits determine the next token. The loop above recomputes everything at each step for simplicity; production implementations add KV caching, which avoids recomputing keys and values for past tokens.
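The idea behind KV caching can be sketched in plain Python for a single attention head in a single layer. Everything here is an assumption made for illustration: the 2×2 projection weights and embeddings are made up, and `KVCache` is a hypothetical class, not the API of PyTorch or any inference library. The point is only that caching past keys and values gives the same output as recomputing them for the whole prefix.

```python
import math

Vector = list[float]

def matvec(matrix: list[Vector], vec: Vector) -> Vector:
    return [sum(w * x for w, x in zip(row, vec)) for row in matrix]

def attend(query: Vector, keys: list[Vector], values: list[Vector]) -> Vector:
    """Single-head scaled dot-product attention for one query position."""
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    dim = len(values[0])
    return [sum(e / total * v[d] for e, v in zip(exps, values)) for d in range(dim)]

# Made-up projection weights and token embeddings.
W_Q = [[1.0, 0.0], [0.5, 1.0]]
W_K = [[0.0, 1.0], [1.0, 0.5]]
W_V = [[1.0, 1.0], [0.0, 2.0]]
embeddings = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

def last_output_no_cache(prefix: list[Vector]) -> Vector:
    """Recompute keys and values for every token in the prefix."""
    keys = [matvec(W_K, x) for x in prefix]
    values = [matvec(W_V, x) for x in prefix]
    return attend(matvec(W_Q, prefix[-1]), keys, values)

class KVCache:
    """Keeps past keys and values so each step projects only the new token."""

    def __init__(self) -> None:
        self.keys: list[Vector] = []
        self.values: list[Vector] = []

    def step(self, x: Vector) -> Vector:
        self.keys.append(matvec(W_K, x))
        self.values.append(matvec(W_V, x))
        return attend(matvec(W_Q, x), self.keys, self.values)

cache = KVCache()
cached_outputs = [cache.step(x) for x in embeddings]
full_outputs = [last_output_no_cache(embeddings[: t + 1]) for t in range(len(embeddings))]
```

No explicit mask is needed in the cached path: the newest query only ever sees keys that are already in the cache, which are by construction the current and earlier positions. The trade-off is memory, since the cache grows with sequence length for every layer and head.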
GPT Scale Variants
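| Model | Layers | d_model | Heads | Parameters |
|---|---|---|---|---|
| GPT-2 Small | 12 | 768 | 12 | 117M |
| GPT-2 Medium | 24 | 1024 | 16 | 345M |
| GPT-2 Large | 36 | 1280 | 20 | 762M |
| GPT-2 XL | 48 | 1600 | 25 | 1.5B |
| GPT-3 | 96 | 12288 | 96 | 175B |
| GPT-4 (estimated) | ~120 | ~16384 | ~128 | ~1T (MoE) |

The GPT-2 parameter counts are the figures from the original paper; the released checkpoints are usually cited as 124M, 355M, 774M, and 1.5B. The GPT-4 row is an unconfirmed outside estimate, since OpenAI has not published its architecture.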
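Parameter count: roughly 12 × n_layer × n_embd² for the transformer blocks. Per layer, that is 4 × n_embd² for the QKV and output projections plus 8 × n_embd² for the two FFN matrices. Token and position embeddings come on top of this.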
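The rule of thumb is easy to check with a few lines of arithmetic. This sketch ignores biases and LayerNorm parameters, and the function names are illustrative; the GPT-2 vocabulary and context sizes are the ones from the configuration earlier in this article.

```python
def approx_block_params(n_layer: int, n_embd: int) -> int:
    """Rule-of-thumb weight count for the transformer blocks, ignoring biases and LayerNorm.

    Per layer: 4 * n_embd^2 for attention (Q, K, V, output projection)
    plus 8 * n_embd^2 for the FFN (two matrices of n_embd x 4*n_embd).
    """
    return 12 * n_layer * n_embd ** 2

def embedding_params(vocab_size: int, n_positions: int, n_embd: int) -> int:
    """Token plus learned position embeddings (the unembedding is tied, so it is not counted twice)."""
    return (vocab_size + n_positions) * n_embd

gpt2_small_blocks = approx_block_params(n_layer=12, n_embd=768)
gpt2_small_total = gpt2_small_blocks + embedding_params(50257, 1024, 768)
gpt3_blocks = approx_block_params(n_layer=96, n_embd=12288)

print(f"GPT-2 Small blocks: {gpt2_small_blocks / 1e6:.1f}M")  # 84.9M
print(f"GPT-2 Small total:  {gpt2_small_total / 1e6:.1f}M")   # 124.3M
print(f"GPT-3 blocks:       {gpt3_blocks / 1e9:.1f}B")        # 173.9B
```

For GPT-2 Small the blocks come to about 85M and the embeddings add roughly 39M, so embeddings are a large share of a small model. For GPT-3 the blocks alone land close to the 175B in the table, which shows how little the embeddings matter at that scale.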
What Each Component Learns
Attention heads learn to:
- Copy information from specific positions (induction heads)
- Attend to syntactically related tokens (subject-verb agreement)
- Implement name-binding (associating pronouns with their referents)
FFN layers learn to:
- Store factual associations (Paris → capital of France)
- Implement "memory" — directly storing key-value pairs from pretraining
- Apply transformations that depend on the current token's meaning
Residual stream enables:
- Early layers to pass information directly to late layers
- Different computations to compose additively
- Gradient flow through hundreds of layers without vanishing
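As an illustration of the induction-head behavior listed above, the following plain-Python function mimics the pattern at the level of tokens: find the most recent earlier occurrence of the current token and copy whatever followed it. This is a behavioral caricature with a hypothetical function name. A real induction head achieves something like it through learned attention weights, not an explicit search.

```python
from typing import Optional

def induction_guess(tokens: list[str]) -> Optional[str]:
    """Predict the next token by copying what followed the last earlier match of the current token."""
    current = tokens[-1]
    # Scan earlier positions from most recent to oldest.
    for i in range(len(tokens) - 2, -1, -1):
        if tokens[i] == current:
            return tokens[i + 1]
    return None  # no earlier occurrence, so nothing to copy

print(induction_guess(["the", "cat", "sat", "on", "the"]))  # cat
print(induction_guess(["a", "b", "c"]))                     # None
```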
Pre-Norm vs Post-Norm
GPT-2 switched to pre-norm (LayerNorm before attention/FFN), unlike the original transformer which used post-norm (LayerNorm after residual addition).
# Post-norm (original transformer; unstable at scale):
x = LayerNorm(x + Attention(x))
# Pre-norm (GPT-2, LLaMA; more stable at large scale):
x = x + Attention(LayerNorm(x))

Pre-norm keeps the residual path clean (no LayerNorm on the highway), which stabilizes gradients during training of deep networks. This is why most modern large LLMs use pre-norm.
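The "clean residual path" claim can be checked with a toy example in plain Python. The sketch below assumes a LayerNorm without learned scale and shift and uses a stand-in sublayer that outputs zeros; both are simplifications for illustration. With pre-norm the input passes through untouched, while post-norm rescales it even though the sublayer contributed nothing.

```python
import math

Vector = list[float]

def layer_norm(x: Vector, eps: float = 1e-5) -> Vector:
    """LayerNorm without learned scale and shift: zero mean, roughly unit variance."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def zero_sublayer(x: Vector) -> Vector:
    """Stand-in for an attention or FFN sublayer that contributes nothing."""
    return [0.0] * len(x)

def post_norm_step(x: Vector) -> Vector:
    # LayerNorm sits on the residual path itself.
    return layer_norm([a + b for a, b in zip(x, zero_sublayer(x))])

def pre_norm_step(x: Vector) -> Vector:
    # LayerNorm only touches the sublayer's input; the residual path is a plain addition.
    return [a + b for a, b in zip(x, zero_sublayer(layer_norm(x)))]

x = [10.0, 20.0, 30.0, 40.0]
pre = pre_norm_step(x)    # identical to x
post = post_norm_step(x)  # rescaled to zero mean and unit variance
print(pre, [round(v, 3) for v in post])
```

Stacking many post-norm layers applies this rescaling repeatedly along the main path, whereas in pre-norm the identity path from input to output stays intact.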