Learnixo

Deep Learning for AI Interviews · Lesson 54 of 56

Transformers: Attention Is All You Need (Simplified)

Why Transformers Replaced RNNs

RNN problems:
  1. Sequential computation: each timestep depends on the previous.
     → Cannot parallelise → slow training on GPUs
  2. Vanishing gradients over long sequences (mitigated, but not eliminated, by LSTM/GRU)
  3. Long-range dependencies: information at position 1 must pass through
     all intermediate positions to influence position 100

Transformer solution:
  1. All positions attend to all other positions simultaneously → fully parallel
  2. Direct paths from any position to any other → the signal does not have to traverse intermediate timesteps, so it does not decay with distance
  3. Self-attention: O(n²) compute and memory in sequence length, but O(1) path length between any two positions
  4. Positional encoding: inject position information, since attention itself has no notion of order

Result: Transformers make full use of parallel hardware, so they scale better (more data + more compute → better models),
and they handle long-range dependencies far more effectively.
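A small experiment makes points 2 and 3 concrete. The sketch below measures how strongly the last output depends on the first input token. In a vanilla tanh RNN that signal must pass through 99 recurrent steps, while single-head self-attention gives the last position a direct path to the first. The sizes, the seed and the projection layers are illustrative assumptions, not part of the lesson.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, d_in, d_hidden = 100, 16, 32

# One input sequence; we track gradients back to its FIRST token.
x = torch.randn(1, seq_len, d_in, requires_grad=True)

# Vanilla tanh RNN: the last hidden state depends on x[0] only through
# 99 recurrent steps (one Jacobian multiplication per step).
rnn = nn.RNN(d_in, d_hidden, batch_first=True)
rnn_out, _ = rnn(x)
rnn_out[0, -1].sum().backward()
grad_rnn = x.grad[0, 0].norm().item()

# Single-head self-attention: the last position reads x[0] directly
# through one key/value lookup, whatever the distance.
x.grad = None
W_q = nn.Linear(d_in, d_hidden, bias=False)
W_k = nn.Linear(d_in, d_hidden, bias=False)
W_v = nn.Linear(d_in, d_hidden, bias=False)
scores = W_q(x) @ W_k(x).transpose(-2, -1) / d_hidden ** 0.5
attn_out = torch.softmax(scores, dim=-1) @ W_v(x)
attn_out[0, -1].sum().backward()
grad_attn = x.grad[0, 0].norm().item()

print(f"|d out[-1] / d x[0]|   RNN: {grad_rnn:.2e}   attention: {grad_attn:.2e}")
```

With randomly initialised weights, the RNN gradient typically comes out many orders of magnitude smaller than the attention gradient, which is the vanishing-over-distance effect in miniature.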

Scaled Dot-Product Attention

Python
import torch
import torch.nn as nn
import math

def scaled_dot_product_attention(
    Q: torch.Tensor,   # (batch, heads, seq_q, d_k)
    K: torch.Tensor,   # (batch, heads, seq_k, d_k)
    V: torch.Tensor,   # (batch, heads, seq_k, d_v)
    mask: torch.Tensor = None,   # 1/True = may attend, 0/False = blocked; broadcastable to scores
    dropout: float = 0.0,
    training: bool = False,      # apply dropout only when True
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Attention(Q, K, V) = softmax(QK.T / √d_k) · V
    
    Q: query (what each position is looking for)
    K: key   (what each position offers to be matched against)
    V: value (what each position returns when it is attended to)
    
    The query-key dot product measures similarity.
    Softmax turns each row of scores into a probability distribution (attention weights).
    The output is the attention-weighted sum of the values.
    """
    d_k = Q.shape[-1]
    
    # Similarity scores: (batch, heads, seq_q, seq_k)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
    
    # Optional masking (causal mask for the decoder, padding mask for variable-length batches).
    # A row whose positions are ALL masked becomes NaN after softmax.
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    
    # Softmax attention weights
    attn_weights = torch.softmax(scores, dim=-1)
    
    # Dropout on the attention weights, during training only. Gate on an explicit
    # flag: torch.is_grad_enabled() is not a reliable train/eval signal.
    if dropout > 0:
        attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=training)
    
    # Weighted sum of values
    output = attn_weights @ V   # (batch, heads, seq_q, d_v)
    
    return output, attn_weights

# Dimension check
batch, heads, seq_len, d_k, d_v = 2, 8, 10, 64, 64
Q = torch.randn(batch, heads, seq_len, d_k)
K = torch.randn(batch, heads, seq_len, d_k)
V = torch.randn(batch, heads, seq_len, d_v)

output, weights = scaled_dot_product_attention(Q, K, V)
print(f"Attention output: {output.shape}")   # (2, 8, 10, 64)
print(f"Attention weights: {weights.shape}") # (2, 8, 10, 10)  seq × seq
print(f"Weights sum to 1? {weights.sum(-1).allclose(torch.ones(batch, heads, seq_len))}")
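The √d_k in the formula can be checked numerically. In this sketch (sample count, seed and d_k = 64 are arbitrary choices), query and key components are independent with unit variance, so each dot product is a sum of d_k unit-variance terms and has variance d_k; dividing by √d_k brings it back to about 1. The second half shows why that matters: larger logits make the softmax peakier, which is where its gradients become small.

```python
import torch

torch.manual_seed(0)
n_pairs, d_k = 10_000, 64

# Random queries and keys with unit-variance components.
q = torch.randn(n_pairs, d_k)
k = torch.randn(n_pairs, d_k)

raw = (q * k).sum(dim=-1)          # q·k for each pair: variance ≈ d_k
scaled = raw / d_k ** 0.5          # q·k / √d_k: variance ≈ 1
raw_var, scaled_var = raw.var().item(), scaled.var().item()
print(f"var(q·k) = {raw_var:.1f}, var(q·k / √d_k) = {scaled_var:.2f}")

# One softmax row: query 0 scored against 10 keys.
logits = k[:10] @ q[0]
p_raw = torch.softmax(logits, dim=-1)               # unscaled: std of logits ≈ 8
p_scaled = torch.softmax(logits / d_k ** 0.5, dim=-1)
print(f"max weight unscaled: {p_raw.max().item():.3f}, scaled: {p_scaled.max().item():.3f}")
```

Dividing logits by a constant greater than 1 acts like raising the softmax temperature, so the largest weight can only shrink; the unscaled row is the one closer to a one-hot (saturated) distribution.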

Multi-Head Attention

Python
import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    """
    Multi-head attention: run attention h times in parallel with different projections.
    Each head can attend to different aspects of the input.
    """
    
    def __init__(self, d_model: int = 512, n_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads   # dimension per head
        
        # Query, Key, Value projection matrices
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)   # output projection
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(
        self,
        Q: torch.Tensor,   # (batch, seq_q, d_model)
        K: torch.Tensor,   # (batch, seq_k, d_model)
        V: torch.Tensor,   # (batch, seq_k, d_model)
        mask: torch.Tensor = None,   # broadcastable to (B, H, seq_q, seq_k); 0 = blocked
    ) -> torch.Tensor:
        batch, seq_q, _ = Q.shape
        
        # Project and reshape to (batch, heads, seq, d_k)
        def project_and_split(x, W):
            return W(x).view(batch, -1, self.n_heads, self.d_k).transpose(1, 2)
        
        Q_proj = project_and_split(Q, self.W_Q)   # (B, H, seq_q, d_k)
        K_proj = project_and_split(K, self.W_K)   # (B, H, seq_k, d_k)
        V_proj = project_and_split(V, self.W_V)   # (B, H, seq_k, d_k)
        
        # Scaled dot-product attention
        d_k = Q_proj.shape[-1]
        scores = Q_proj @ K_proj.transpose(-2, -1) / math.sqrt(d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attn = self.dropout(torch.softmax(scores, dim=-1))
        out = attn @ V_proj   # (B, H, seq_q, d_k)
        
        # Concatenate heads and project back
        out = out.transpose(1, 2).contiguous().view(batch, seq_q, self.d_model)
        return self.W_O(out)

# Test
mha = MultiHeadAttention(d_model=512, n_heads=8)
X = torch.randn(4, 20, 512)   # (batch=4, seq_len=20, d_model=512)
out = mha(X, X, X)             # self-attention: Q=K=V=X
print(f"MHA output: {out.shape}")   # (4, 20, 512)
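PyTorch also ships a built-in `nn.MultiheadAttention` that packs the same four projections. The sketch below (assuming a recent PyTorch release with `batch_first` and head-averaged attention weights) uses it to check two things: splitting d_model across heads does not change the parameter count, and a padding mask zeroes attention to padded keys. Note that its boolean `key_padding_mask` uses the opposite convention from the class above: True means ignore. The padding pattern is made up for illustration.

```python
import torch
import torch.nn as nn

def n_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

d_model = 512
# bias=False mirrors the bias-free W_Q, W_K, W_V, W_O projections above.
mha_8 = nn.MultiheadAttention(embed_dim=d_model, num_heads=8, bias=False, batch_first=True)
mha_1 = nn.MultiheadAttention(embed_dim=d_model, num_heads=1, bias=False, batch_first=True)

# Four d_model × d_model matrices, regardless of the number of heads.
print(n_params(mha_8), n_params(mha_1), 4 * d_model * d_model)

X = torch.randn(4, 20, d_model)
# True marks key positions to IGNORE (opposite of the mask == 0 convention above).
key_padding_mask = torch.zeros(4, 20, dtype=torch.bool)
key_padding_mask[:, 15:] = True          # pretend the last 5 tokens are padding
out, weights = mha_8(X, X, X, key_padding_mask=key_padding_mask)
print(out.shape, weights.shape)          # weights are averaged over heads: (batch, seq_q, seq_k)
```

Each head works in a d_model / n_heads subspace, so more heads means more, smaller attention maps at the same parameter budget.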

Positional Encoding

Python
import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding (Vaswani et al., 2017).
    Adds position-dependent signal to token embeddings.
    
    PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
    PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
    
    Properties:
    - Unique encoding for each position
    - Smooth variation for nearby positions
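    - Defined for any position, so it may extrapolate to sequences longer than those
      seen in training (hypothesised in the paper; in practice extrapolation is often poor)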
    """
    
    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(max_len).unsqueeze(1).float()  # (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        )  # (d_model/2,)
        
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        
        pe = pe.unsqueeze(0)   # (1, max_len, d_model) for broadcasting
        self.register_buffer("pe", pe)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: (batch, seq_len, d_model)"""
        x = x + self.pe[:, :x.shape[1], :]
        return self.dropout(x)

# Alternative: learned positional embeddings (used in BERT, GPT)
class LearnedPositionalEmbedding(nn.Module):
    def __init__(self, max_positions: int, d_model: int):
        super().__init__()
        self.embed = nn.Embedding(max_positions, d_model)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        positions = torch.arange(x.shape[1], device=x.device)
        return x + self.embed(positions)
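The reason positional information is needed at all can be verified directly. Self-attention with no positional signal is permutation-equivariant: reorder the input tokens and the outputs are reordered the same way, so the model cannot tell word order apart. This sketch uses a projection-free single-head attention (an illustrative simplification, not the lesson's MultiHeadAttention) and re-implements the sinusoidal formula above as a plain function so the block runs on its own.

```python
import math
import torch

torch.manual_seed(0)
seq_len, d_model = 6, 16

def self_attention(x: torch.Tensor) -> torch.Tensor:
    """Projection-free single-head self-attention (Q = K = V = x), for illustration."""
    scores = x @ x.transpose(-2, -1) / math.sqrt(x.shape[-1])
    return torch.softmax(scores, dim=-1) @ x

def sinusoidal_pe(seq_len: int, d_model: int) -> torch.Tensor:
    """Same formula as PositionalEncoding above: sin on even dims, cos on odd dims."""
    position = torch.arange(seq_len).unsqueeze(1).float()
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe

x = torch.randn(seq_len, d_model)       # token embeddings, no position info
perm = torch.arange(seq_len).flip(0)    # reverse the token order

# Without positions: shuffling the inputs just shuffles the outputs the same way.
equivariant = torch.allclose(self_attention(x[perm]), self_attention(x)[perm], atol=1e-5)

# With positions added: the reordered tokens receive different position vectors,
# so the outputs are no longer a reordering of the original ones.
pe = sinusoidal_pe(seq_len, d_model)
still_equivariant = torch.allclose(
    self_attention(x[perm] + pe), self_attention(x + pe)[perm], atol=1e-5
)
print(equivariant, still_equivariant)
```

The same argument applies to the learned embedding: any position-dependent vector added to the tokens breaks the symmetry.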

Transformer Encoder Block

Python
import torch
import torch.nn as nn

class TransformerEncoderBlock(nn.Module):
    """
    One transformer encoder layer:
      x → LayerNorm → MultiHeadAttention → + residual
      x → LayerNorm → FeedForward → + residual
    
    Pre-LayerNorm (modern): x + Sublayer(LayerNorm(x)); normalise each sub-layer's input (more stable training)
    Post-LayerNorm (original 2017 paper): LayerNorm(x + Sublayer(x)); normalise after the residual addition
    """
    
    def __init__(
        self,
        d_model: int = 512,
        n_heads: int = 8,
        d_ff: int = 2048,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.attn  = MultiHeadAttention(d_model, n_heads, dropout)
        self.ff    = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )
    
    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        # Self-attention sub-layer (pre-norm)
        h = self.norm1(x)                     # normalise once, reuse for Q, K and V
        x = x + self.attn(h, h, h, mask)
        # Feed-forward sub-layer (pre-norm)
        x = x + self.ff(self.norm2(x))
        return x

block = TransformerEncoderBlock(d_model=512, n_heads=8, d_ff=2048)
X = torch.randn(4, 20, 512)
out = block(X)
print(f"Encoder block: {X.shape} → {out.shape}")   # (4, 20, 512)  same shape
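PyTorch's built-in `nn.TransformerEncoderLayer` implements essentially the same block (it additionally uses biases in its attention projections and dropout on each sub-layer output), and its `norm_first` flag switches between the two layouts described in the docstring. The sketch assumes PyTorch 1.10 or later for `norm_first`. The library's default activation is ReLU, so `activation="gelu"` is set to match the hand-written block; the 6-layer depth and padding pattern are arbitrary choices.

```python
import torch
import torch.nn as nn

d_model, n_heads, d_ff = 512, 8, 2048

# norm_first=True: pre-LN, as in the block above.
# norm_first=False: original post-LN, x = LayerNorm(x + Sublayer(x)).
pre_ln_layer = nn.TransformerEncoderLayer(
    d_model, n_heads, dim_feedforward=d_ff, dropout=0.1,
    activation="gelu", batch_first=True, norm_first=True,
)
post_ln_layer = nn.TransformerEncoderLayer(
    d_model, n_heads, dim_feedforward=d_ff, dropout=0.1,
    activation="gelu", batch_first=True, norm_first=False,
)

# Stack 6 deep copies of the layer. Pre-LN stacks usually add a final
# LayerNorm after the last block, passed here via `norm`.
encoder = nn.TransformerEncoder(pre_ln_layer, num_layers=6, norm=nn.LayerNorm(d_model))

X = torch.randn(4, 20, d_model)
padding_mask = torch.zeros(4, 20, dtype=torch.bool)
padding_mask[0, 16:] = True   # first sequence: last 4 tokens are padding (True = ignore)

out_pre = encoder(X, src_key_padding_mask=padding_mask)
out_post = post_ln_layer(X)
print(out_pre.shape, out_post.shape)
```

Both layouts preserve the (batch, seq_len, d_model) shape, which is what lets blocks be stacked to any depth.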

BERT vs GPT Architecture

Encoder-only (BERT):
  Bidirectional: attends to ALL tokens (left and right context)
  Tasks: classification, NER, question answering
  Pre-training: masked language modelling (predict [MASK] tokens)
  Clinical: BioBERT, ClinicalBERT, PubMedBERT

Decoder-only (GPT):
  Causal (autoregressive): attends only to previous tokens
  Tasks: text generation, completion
  Pre-training: next-token prediction
  Clinical: GPT-style models adapted to clinical text for note generation

Encoder-decoder (T5, BART):
  Encoder: bidirectional for input
  Decoder: causal self-attention for generation, plus cross-attention to the encoder outputs
  Tasks: translation, summarisation, question answering
  Clinical: summarising discharge notes, assigning ICD-10 codes from clinical text
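At the level of code, the difference between a bidirectional encoder and a causal decoder is only the attention mask, and next-token prediction is a one-position shift of the same sequence. A minimal sketch, assuming random scores in place of a real QK.T / √d_k matrix and made-up token ids; the mask follows the lesson's convention (1/True = may attend, positions where the mask is 0 are filled with -inf).

```python
import torch

torch.manual_seed(0)
seq_len = 5
scores = torch.randn(seq_len, seq_len)   # stand-in for QK.T / sqrt(d_k)

# Bidirectional (BERT-style encoder): every token sees every token.
bidirectional = torch.softmax(scores, dim=-1)

# Causal (GPT-style decoder): token i may only attend to tokens 0..i.
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
causal = torch.softmax(scores.masked_fill(causal_mask == 0, float("-inf")), dim=-1)
print(causal)   # zeros above the diagonal; row 0 puts all its weight on token 0

# Next-token prediction: the output at position i is trained to predict token i + 1.
tokens = torch.tensor([101, 7, 42, 9, 13])   # hypothetical token ids
inputs, targets = tokens[:-1], tokens[1:]
print(inputs.tolist(), targets.tolist())
```

Because the causal mask stops any position from seeing later tokens, all next-token predictions for a sequence can be trained in one parallel forward pass.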

Interview Answer

"Transformers replaced RNNs because they process all sequence positions in parallel (RNNs are sequential) and have O(1) path length between any two positions (in an RNN, information must traverse every intermediate timestep). The core mechanism is scaled dot-product attention: Attention(Q,K,V) = softmax(QK.T / √d_k) · V. Q (query) represents what each position is looking for; K (key) represents what each position offers; their dot product measures compatibility, and V (value) is the content that gets mixed according to the resulting weights. Scaling by √d_k keeps the variance of the scores around 1 (for unit-variance query and key components); without it, scores grow with d_k and push the softmax into saturated regions where gradients are tiny. Multi-head attention runs h attention operations in parallel with independent projections, allowing each head to specialise in different relationship types. Positional encoding injects order information because self-attention is permutation-equivariant: shuffling the inputs just shuffles the outputs, so on its own it cannot see word order. The full (pre-LN) transformer block: LayerNorm → Multi-Head Attention → residual → LayerNorm → FFN (Linear-GELU-Linear) → residual. BERT (encoder-only, bidirectional) is used for classification; GPT (decoder-only, causal) for generation; T5 (encoder-decoder) for seq-to-seq tasks."