Deep Learning for AI Interviews · Lesson 54 of 56
Transformers: Attention Is All You Need (Simplified)
Why Transformers Replaced RNNs
RNN problems:
1. Sequential computation: each timestep depends on the previous.
→ Cannot parallelise across timesteps → slow training on GPUs
2. Vanishing gradients over long sequences (mitigated, but not eliminated, by LSTM/GRU)
3. Long-range dependencies: information at position 1 must pass through
all intermediate positions to influence position 100
Transformer solution:
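1. Every position attends to every other position in a single step → fully parallel across the sequence during training
2. Direct paths from any position to any other → gradients do not have to pass through every intermediate step, so distance alone does not cause vanishing
3. Self-attention costs O(n²) time and memory in sequence length n, but the path length between any two positions is O(1)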
4. Positional encoding: inject position information since attention has no inherent order
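The contrast between sequential recurrence and one-step attention can be seen in a toy sketch. It assumes a bare tanh RNN with random weights (W_x and W_h are illustrative) and an attention step with no learned Q/K/V projections, purely to show the shapes and the dependency structure:

```python
import torch

torch.manual_seed(0)
n, d = 6, 4  # toy sequence length and feature size (arbitrary)
x = torch.randn(n, d)

# RNN-style recurrence: step t needs h from step t-1, so the n steps must run in order
W_x = torch.randn(d, d) * 0.1
W_h = torch.randn(d, d) * 0.1
h = torch.zeros(d)
states = []
for t in range(n):
    h = torch.tanh(x[t] @ W_x + h @ W_h)
    states.append(h)
rnn_out = torch.stack(states)  # (n, d); token 0 reaches token n-1 only via n-1 updates of h

# Self-attention (no learned projections here): one matmul scores every pair of positions
scores = x @ x.T / d ** 0.5          # (n, n): O(n^2) entries, computed in one parallel op
weights = torch.softmax(scores, dim=-1)
attn_out = weights @ x               # (n, d); row i mixes all tokens directly (path length 1)

print(rnn_out.shape, scores.shape, attn_out.shape)
```

The Python loop cannot be vectorised over t because each iteration reads the previous h. The attention step has no such dependency, at the price of an n × n score matrix.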
Result: Transformers scale better (more data + more compute = better models)
and handle long-range dependencies far more effectively.
Scaled Dot-Product Attention
import math
from typing import Optional

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(
    Q: torch.Tensor,  # (batch, heads, seq_q, d_k)
    K: torch.Tensor,  # (batch, heads, seq_k, d_k)
    V: torch.Tensor,  # (batch, heads, seq_k, d_v)
    mask: Optional[torch.Tensor] = None,  # broadcastable to (batch, heads, seq_q, seq_k); 0 = blocked
    dropout: float = 0.0,
    training: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Attention(Q, K, V) = softmax(Q Kᵀ / √d_k) · V
    Q (query): what I'm looking for
    K (key): what I have
    V (value): what I return if matched
    The query-key dot product measures similarity.
    Softmax converts the scores to a probability distribution (attention weights).
    The output is a weighted sum of the values.
    """
    d_k = Q.shape[-1]
    # Similarity scores: (batch, heads, seq_q, seq_k)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
    # Optional masking (causal mask for decoder, padding mask for encoder).
    # A row that is masked everywhere becomes all -inf, and softmax returns NaN for it.
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    # Softmax over the keys: each query's weights sum to 1
    attn_weights = torch.softmax(scores, dim=-1)
    # Dropout on the attention weights, applied only in training mode
    attn_weights = F.dropout(attn_weights, p=dropout, training=training)
    # Weighted sum of values
    output = attn_weights @ V  # (batch, heads, seq_q, d_v)
    return output, attn_weights

# Dimension check
batch, heads, seq_len, d_k, d_v = 2, 8, 10, 64, 64
Q = torch.randn(batch, heads, seq_len, d_k)
K = torch.randn(batch, heads, seq_len, d_k)
V = torch.randn(batch, heads, seq_len, d_v)
output, weights = scaled_dot_product_attention(Q, K, V)
print(f"Attention output: {output.shape}")   # (2, 8, 10, 64)
print(f"Attention weights: {weights.shape}") # (2, 8, 10, 10): seq × seq
print(f"Weights sum to 1? {weights.sum(-1).allclose(torch.ones(batch, heads, seq_len))}")
Multi-Head Attention
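Most of the multi-head bookkeeping is a reshape. The d_model features are viewed as n_heads slices of size d_k, each head attends over its own slice, and the slices are concatenated back. Below is a minimal sketch of the split and merge steps that the module below performs inline. The names split_heads and merge_heads are hypothetical helpers, not part of the lesson's code:

```python
import torch

def split_heads(x: torch.Tensor, n_heads: int) -> torch.Tensor:
    """(batch, seq, d_model) -> (batch, n_heads, seq, d_k)"""
    batch, seq, d_model = x.shape
    d_k = d_model // n_heads
    return x.view(batch, seq, n_heads, d_k).transpose(1, 2)

def merge_heads(x: torch.Tensor) -> torch.Tensor:
    """(batch, n_heads, seq, d_k) -> (batch, seq, n_heads * d_k)"""
    batch, n_heads, seq, d_k = x.shape
    # transpose() leaves the tensor non-contiguous, so .contiguous() is needed before .view()
    return x.transpose(1, 2).contiguous().view(batch, seq, n_heads * d_k)

x = torch.randn(4, 20, 512)
heads = split_heads(x, n_heads=8)
print(heads.shape)                            # torch.Size([4, 8, 20, 64])
print(torch.equal(merge_heads(heads), x))     # True: the round trip is lossless
# Head 0 works on exactly the first d_k = 64 features of every token
print(torch.equal(heads[:, 0], x[..., :64]))  # True
```

Because the split is only a view, running h heads costs about the same as one full-width attention. The learned W_Q, W_K, W_V projections applied before the split are what let each head specialise.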
import math
from typing import Optional

import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """
    Multi-head attention: run attention h times in parallel with different learned projections.
    Each head can attend to different aspects of the input.
    """
    def __init__(self, d_model: int = 512, n_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # dimension per head
        # Query, Key, Value projection matrices (all heads packed into one d_model × d_model layer)
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)  # output projection
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        Q: torch.Tensor,  # (batch, seq_q, d_model)
        K: torch.Tensor,  # (batch, seq_k, d_model)
        V: torch.Tensor,  # (batch, seq_k, d_model)
        mask: Optional[torch.Tensor] = None,  # broadcastable to (batch, heads, seq_q, seq_k); 0 = blocked
    ) -> torch.Tensor:
        batch, seq_q, _ = Q.shape

        # Project, then split d_model into n_heads × d_k: (batch, heads, seq, d_k)
        def project_and_split(x, W):
            return W(x).view(batch, -1, self.n_heads, self.d_k).transpose(1, 2)

        Q_proj = project_and_split(Q, self.W_Q)  # (B, H, seq_q, d_k)
        K_proj = project_and_split(K, self.W_K)  # (B, H, seq_k, d_k)
        V_proj = project_and_split(V, self.W_V)  # (B, H, seq_k, d_k)

        # Scaled dot-product attention for all heads at once
        scores = Q_proj @ K_proj.transpose(-2, -1) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attn = self.dropout(torch.softmax(scores, dim=-1))  # nn.Dropout is a no-op after .eval()
        out = attn @ V_proj  # (B, H, seq_q, d_k)

        # Concatenate heads: (B, seq_q, H * d_k) = (B, seq_q, d_model), then project back
        out = out.transpose(1, 2).contiguous().view(batch, seq_q, self.d_model)
        return self.W_O(out)

# Test
mha = MultiHeadAttention(d_model=512, n_heads=8)
X = torch.randn(4, 20, 512)  # (batch=4, seq_len=20, d_model=512)
out = mha(X, X, X)  # self-attention: Q = K = V = X
print(f"MHA output: {out.shape}")  # (4, 20, 512)
Positional Encoding
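Positional information is needed because self-attention by itself cannot tell token orders apart: permuting the input tokens simply permutes the outputs in the same way (permutation equivariance). Below is a minimal check using a bare single-head attention. The random matrices W_q, W_k, W_v and the vector `pos` are illustrative stand-ins, not part of the lesson's code:

```python
import torch

torch.manual_seed(0)
seq_len, d = 5, 8
x = torch.randn(seq_len, d)  # token embeddings with no positional information
# Random single-head projections (stand-ins for learned weights)
W_q, W_k, W_v = (torch.randn(d, d) / d ** 0.5 for _ in range(3))

def self_attention(x: torch.Tensor) -> torch.Tensor:
    q, k, v = x @ W_q, x @ W_k, x @ W_v
    weights = torch.softmax(q @ k.T / d ** 0.5, dim=-1)
    return weights @ v

perm = torch.tensor([4, 0, 3, 1, 2])  # a fixed reordering of the 5 tokens
out = self_attention(x)

# Without positional encoding: shuffled input gives the same outputs, just shuffled
print(torch.allclose(self_attention(x[perm]), out[perm], atol=1e-5))  # True

# Adding a distinct vector per position breaks the symmetry (False in general)
pos = torch.randn(seq_len, d)  # stand-in for positional encodings
print(torch.allclose(self_attention(x[perm] + pos), self_attention(x + pos)[perm], atol=1e-5))
```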
import math

import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding (Vaswani et al., 2017).
    Adds a position-dependent signal to token embeddings.
    PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
    PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
    Properties:
    - Unique encoding for each position
    - Smooth variation for nearby positions
    - Fixed rather than learned, so it can be computed for positions longer than any
      training sequence (up to max_len); how well the model extrapolates is not guaranteed
    Assumes d_model is even (sin and cos fill alternating dimensions).
    """
    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(max_len).unsqueeze(1).float()  # (max_len, 1)
        # 1 / 10000^(2i/d_model), computed as exp(-2i · ln(10000) / d_model)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        )  # (d_model/2,)
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        pe = pe.unsqueeze(0)  # (1, max_len, d_model) for broadcasting over the batch
        # Buffer: saved with the model and moved by .to(device), but not trained
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: (batch, seq_len, d_model), with seq_len <= max_len"""
        x = x + self.pe[:, :x.shape[1], :]
        return self.dropout(x)

# Alternative: learned positional embeddings (used in BERT, GPT)
class LearnedPositionalEmbedding(nn.Module):
    def __init__(self, max_positions: int, d_model: int):
        super().__init__()
        self.embed = nn.Embedding(max_positions, d_model)  # one trainable vector per position

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: (batch, seq_len, d_model); seq_len must not exceed max_positions"""
        positions = torch.arange(x.shape[1], device=x.device)
        return x + self.embed(positions)
Transformer Encoder Block
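Both sub-layers of the encoder block below are wrapped with LayerNorm. LayerNorm normalises each token's feature vector independently over the last dimension, so unlike BatchNorm it does not depend on other examples in the batch. Below is a minimal sketch comparing a hand-written version with `nn.LayerNorm` at initialisation, where the learnable weight is 1 and the bias is 0:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 20, 512)  # (batch, seq_len, d_model)
eps = 1e-5                   # nn.LayerNorm's default eps

# Per-token statistics over the feature dimension (biased variance, as LayerNorm uses)
mean = x.mean(dim=-1, keepdim=True)
var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
manual = (x - mean) / torch.sqrt(var + eps)

ln = nn.LayerNorm(512, eps=eps)  # weight initialised to 1, bias to 0
print(torch.allclose(ln(x), manual, atol=1e-5))  # True
```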
from typing import Optional

import torch
import torch.nn as nn

# Reuses the MultiHeadAttention class from the previous section
class TransformerEncoderBlock(nn.Module):
    """
    One transformer encoder layer (pre-LayerNorm ordering):
    x → LayerNorm → MultiHeadAttention → + residual
    x → LayerNorm → FeedForward → + residual
    Pre-LayerNorm (modern): normalise the input of each sub-layer (more stable training)
    Post-LayerNorm (original): normalise after the residual addition
    """
    def __init__(
        self,
        d_model: int = 512,
        n_heads: int = 8,
        d_ff: int = 2048,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, n_heads, dropout)
        # Position-wise feed-forward network: expand to d_ff, apply GELU, project back
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Self-attention sub-layer (pre-norm): normalise once, use the result as Q, K and V
        h = self.norm1(x)
        x = x + self.attn(h, h, h, mask)
        # Feed-forward sub-layer (pre-norm), applied to each position independently
        x = x + self.ff(self.norm2(x))
        return x

block = TransformerEncoderBlock(d_model=512, n_heads=8, d_ff=2048)
X = torch.randn(4, 20, 512)
out = block(X)
print(f"Encoder block: {X.shape} → {out.shape}")  # (4, 20, 512): shape is preserved
BERT vs GPT Architecture
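The encoder block's docstring contrasts pre-LN with post-LN. The original Transformer and BERT use the post-LN ordering. GPT-2 moved LayerNorm to the input of each sub-block, which is the pre-LN ordering used above. Below is a minimal post-LN sketch for comparison. To run on its own, it uses PyTorch's built-in `nn.MultiheadAttention` (with `batch_first=True`, available since PyTorch 1.9) instead of the lesson's class. The class name `PostLNEncoderBlock` is hypothetical:

```python
import torch
import torch.nn as nn

class PostLNEncoderBlock(nn.Module):
    """Post-LN ordering: x -> sub-layer -> + residual -> LayerNorm."""
    def __init__(self, d_model: int = 512, n_heads: int = 8, d_ff: int = 2048, dropout: float = 0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Dropout(dropout), nn.Linear(d_ff, d_model)
        )
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attn_out, _ = self.attn(x, x, x, need_weights=False)
        x = self.norm1(x + self.drop(attn_out))  # normalise AFTER the residual addition
        x = self.norm2(x + self.drop(self.ff(x)))
        return x

block = PostLNEncoderBlock().eval()
out = block(torch.randn(4, 20, 512))
print(out.shape)  # torch.Size([4, 20, 512])
# Every token vector leaving a post-LN block is layer-normalised (mean ~0 at init)
print(out.mean(-1).abs().max().item() < 1e-4)  # True
```

In post-LN, the residual stream is renormalised after every sub-layer. In pre-LN, the residual path from input to output carries no normalisation, which is the usual explanation for why pre-LN trains more stably in deep stacks.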
Encoder-only (BERT):
Bidirectional: attends to ALL tokens (left and right context)
Tasks: classification, NER, question answering
Pre-training: masked language modelling (predict [MASK] tokens)
Clinical: BioBERT, ClinicalBERT, PubMedBERT
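Below is a rough sketch of how masked-language-modelling pairs can be built. A random subset of positions is selected and replaced with a mask token, and the loss is counted only at those positions. The 15% rate follows BERT. BERT also replaces only 80% of the selected tokens with [MASK], swaps 10% for random tokens and leaves 10% unchanged; this sketch skips that step. The vocabulary size, `MASK_ID` and token ids are made-up values:

```python
import torch

torch.manual_seed(0)
VOCAB_SIZE, MASK_ID = 1000, 999             # hypothetical vocabulary size and [MASK] id
input_ids = torch.randint(0, 998, (2, 12))  # (batch, seq_len) fake token ids

mask_prob = 0.15
selected = torch.rand(input_ids.shape) < mask_prob  # positions the model must predict

masked_inputs = input_ids.clone()
masked_inputs[selected] = MASK_ID                   # the encoder sees [MASK] here

labels = input_ids.clone()
labels[~selected] = -100  # -100 is the default ignore_index of torch.nn.CrossEntropyLoss

# The encoder attends bidirectionally over masked_inputs; a loss such as
# torch.nn.CrossEntropyLoss()(logits.view(-1, VOCAB_SIZE), labels.view(-1))
# then counts only the selected positions.
print(masked_inputs)
print(labels)
```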
Decoder-only (GPT):
Causal (autoregressive): attends only to previous tokens
Tasks: text generation, completion
Pre-training: next-token prediction
Clinical: GPT-style decoder models for clinical note generation
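In code, the bidirectional-versus-causal difference comes down to the mask passed to attention. Below is a minimal sketch of a causal (lower-triangular) mask. It uses the same convention as the lesson's attention functions, where `mask == 0` means "blocked". The random `scores` tensor stands in for QKᵀ / √d_k:

```python
import torch

torch.manual_seed(0)
seq_len = 5
scores = torch.randn(seq_len, seq_len)  # stand-in for Q K^T / sqrt(d_k)

# Causal (GPT-style): position i may attend only to positions 0..i
causal_mask = torch.tril(torch.ones(seq_len, seq_len))
causal_weights = torch.softmax(scores.masked_fill(causal_mask == 0, float("-inf")), dim=-1)

# Bidirectional (BERT-style): no mask, every position sees every other position
bidir_weights = torch.softmax(scores, dim=-1)

print(causal_mask)
print(causal_weights[0])  # tensor([1., 0., 0., 0., 0.]): the first token sees only itself
```

With the batched, multi-head tensors used earlier, the same mask can be broadcast as `causal_mask[None, None]` to shape (1, 1, seq, seq).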
Encoder-decoder (T5, BART):
Encoder: bidirectional for input
Decoder: causal for generation
Tasks: translation, summarisation, question answering
Clinical: summarising discharge notes, assigning ICD-10 codes from text
Interview Answer
"Transformers replaced RNNs because they process all sequence positions in parallel during training (unlike RNNs, which are sequential) and have O(1) path length between any two positions (unlike RNNs, where information must traverse every intermediate timestep). The core mechanism is scaled dot-product attention: Attention(Q,K,V) = softmax(QKᵀ / √d_k) · V. Q (query) represents what each position is looking for; K (key) represents what each position offers; the dot product measures compatibility, and the output is the softmax-weighted sum of the values V. Scaling by √d_k stops the scores from growing with dimension: with unit-variance components, a raw dot product has variance d_k, which would push the softmax into saturation (near one-hot weights, tiny gradients) when d_k is large. Multi-head attention runs h attention operations in parallel with independent projections, allowing each head to specialise in different relationship types. Positional encoding injects order information because self-attention is permutation-equivariant: shuffling the input tokens only shuffles the outputs. A pre-LN transformer block is: LayerNorm → Multi-Head Attention → residual → LayerNorm → FFN (Linear-GELU-Linear) → residual. BERT (encoder-only, bidirectional) is used for classification; GPT (decoder-only, causal) for generation; T5 (encoder-decoder) for seq-to-seq tasks."
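The √d_k argument in the answer can be checked numerically. If the entries of q and k are independent with mean 0 and variance 1 (the assumption here), q·k has variance d_k. Unscaled scores therefore spread out as d_k grows, and softmax collapses toward one-hot. A small simulation under that assumption:

```python
import torch

torch.manual_seed(0)
n_keys, n_trials = 10, 2000

for d_k in (4, 64, 512):
    q = torch.randn(n_trials, 1, d_k)
    k = torch.randn(n_trials, n_keys, d_k)
    raw = (q @ k.transpose(-2, -1)).squeeze(1)  # (n_trials, n_keys) unscaled dot products
    scaled = raw / d_k ** 0.5
    # Mean of the largest attention weight: near 1.0 means softmax is nearly one-hot
    max_raw = torch.softmax(raw, dim=-1).max(-1).values.mean().item()
    max_scaled = torch.softmax(scaled, dim=-1).max(-1).values.mean().item()
    print(f"d_k={d_k:4d}  var(raw)={raw.var().item():7.1f}  "
          f"mean max weight: unscaled={max_raw:.3f}  scaled={max_scaled:.3f}")
```

The variance of the raw scores should track d_k. After scaling, the weight distribution should look similar at every d_k.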