AI Systems · Intermediate

Interview: Transformer Architecture (Part 1)

10 senior-level questions on transformer internals: attention mechanics, positional encodings, normalization, and architectural design choices.

Asma Hafeez Khan · May 16, 2026 · 8 min read
Tags: Transformers, Interview, Architecture, Deep Learning

Q1: Why does scaled dot-product attention divide by sqrt(d_k)?

Answer: Without scaling, the dot products QK^T grow in magnitude with d_k. For large d_k, the softmax receives very large inputs, which pushes it into regions of near-zero gradient (the function becomes nearly step-like — the max element dominates and everything else collapses to zero).

Dividing by sqrt(d_k) keeps the dot products in a range where softmax gradients are healthy. Intuitively: if the entries of Q and K are independent draws from a zero-mean, unit-variance distribution, then each entry of QK^T (a dot product of d_k terms) has variance d_k. Dividing by sqrt(d_k) restores unit variance.
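The variance claim follows in one line. This is a sketch under the idealized assumption that the entries q_i and k_i are mutually independent with zero mean and unit variance; trained projections satisfy this only approximately.

```latex
\operatorname{Var}(q \cdot k)
  = \sum_{i=1}^{d_k} \operatorname{Var}(q_i k_i)
  = \sum_{i=1}^{d_k} \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2]
  = d_k,
\qquad
\operatorname{Var}\!\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = 1.
```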

Python
import torch
import torch.nn.functional as F

d_k = 512
q = torch.randn(1, d_k)
k = torch.randn(10, d_k)

dot_products = q @ k.T  # (1, 10)
print(f"Max without scale: {dot_products.max():.1f}")
print(f"Max with scale:    {(dot_products / d_k**0.5).max():.1f}")

# Without scale: softmax is very peaked (close to one-hot, effectively argmax)
print(F.softmax(dot_products, dim=-1))      # Nearly all mass on one of the 10 keys
print(F.softmax(dot_products / d_k**0.5, dim=-1))  # More spread out

Q2: What is the difference between multi-head attention and single-head attention, and why does multi-head help?

Answer: Single-head attention uses one set of Q, K, V projections. Multi-head attention uses H independent sets of projections (heads), each with smaller head dimension (d_k = d_model / H), then concatenates outputs and projects back to d_model.

Why multi-head helps: each head can learn to attend to different types of relationships simultaneously. In practice:

  • Head A might learn syntactic relationships (verbs attending to subjects)
  • Head B might learn semantic relationships (a drug name attending to its mechanism)
  • Head C might learn positional relationships (each token attending to its neighbors)

Single-head attention forces one attention pattern per layer; multi-head allows the model to extract multiple types of structure at once. Empirically, multi-head attention generally outperforms single-head attention at a comparable parameter count.
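The split, attend, concatenate, and project steps can be sketched in PyTorch as follows. This is a minimal illustration, not a production layer: the class name `MultiHeadSelfAttention`, the fused `qkv` projection, and the sizes used are assumptions made for the example, and masking and dropout are omitted.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadSelfAttention(nn.Module):
    """Illustrative multi-head self-attention (no masking, no dropout)."""

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.d_k = d_model // num_heads              # per-head dimension
        self.qkv = nn.Linear(d_model, 3 * d_model)   # fused Q, K, V projections
        self.out = nn.Linear(d_model, d_model)       # projection back to d_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, n, d_model = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # (b, n, d_model) -> (b, heads, n, d_k): each head attends independently.
        q, k, v = (
            t.reshape(b, n, self.num_heads, self.d_k).transpose(1, 2)
            for t in (q, k, v)
        )
        scores = q @ k.transpose(-2, -1) / self.d_k**0.5   # (b, heads, n, n)
        weights = F.softmax(scores, dim=-1)
        heads = weights @ v                                # (b, heads, n, d_k)
        concat = heads.transpose(1, 2).reshape(b, n, d_model)
        return self.out(concat)


torch.manual_seed(0)
mha = MultiHeadSelfAttention(d_model=64, num_heads=8)
x = torch.randn(2, 10, 64)
y = mha(x)
print(y.shape)  # torch.Size([2, 10, 64])
```

Note that the total projection size stays at d_model regardless of the number of heads; only the way the dimensions are partitioned changes.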


Q3: What is the residual connection's role in transformers?

Answer: The residual (skip) connection adds the input directly to the output of each sub-layer:

output = x + Sublayer(Norm(x))

Two key functions:

  1. Gradient flow: In deep networks, gradients must pass through many layers. Multiplying many small gradients leads to vanishing gradients. Residual connections create "gradient highways" — the gradient can flow directly through the addition without passing through any saturated activation function.

  2. Identity preservation: The sublayer only needs to learn a small "correction" to the input (the residual), not a full transformation. This is easier to learn and more stable. Early in training, the sublayer output is near-zero, so the model starts with near-identity mappings and gradually refines them.

Without residual connections, transformers deeper than roughly 6–8 layers are difficult to train. With them, models as deep as GPT-3 (96 layers) train stably.
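Both functions can be checked on a toy block. In the sketch below, the last layer of the sublayer is zero-initialized (an assumption made purely for illustration, mimicking the "near-zero output early in training" case), so the block is exactly the identity and the gradient reaches the input unchanged through the skip path.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 16

norm = nn.LayerNorm(d_model)
sublayer = nn.Sequential(
    nn.Linear(d_model, 4 * d_model),
    nn.GELU(),
    nn.Linear(4 * d_model, d_model),
)
# Illustration only: zero-init the last layer so Sublayer(.) outputs exactly 0.
nn.init.zeros_(sublayer[2].weight)
nn.init.zeros_(sublayer[2].bias)

x = torch.randn(3, d_model, requires_grad=True)
out = x + sublayer(norm(x))   # residual form from the text
out.sum().backward()

is_identity = torch.allclose(out, x)                        # block starts as identity
grad_is_ones = torch.allclose(x.grad, torch.ones_like(x))   # gradient passes straight through
print(is_identity, grad_is_ones)  # True True
```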


Q4: Explain pre-norm vs post-norm. Why did the field shift to pre-norm?

Answer:

Post-norm (original transformer): x = Norm(x + Sublayer(x))

  • Layer normalization is applied after the residual addition
  • The output has normalized statistics, which is good for the next layer's inputs
  • But the gradient flowing backward must pass through the normalization — which can cause instability at depth

Pre-norm (modern transformers, LLaMA): x = x + Sublayer(Norm(x))

  • Normalization is applied to the input before the sublayer
  • The residual path is clean — gradients flow through the addition unmodified
  • Training is more stable, especially for deep networks and large learning rates

The shift happened after empirical observation that pre-norm models are much more stable at scale. Post-norm can achieve better final accuracy if carefully tuned (warm-up learning rate schedules help), but pre-norm is more robust for large-scale training where instability wastes compute.
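The two layouts differ only in where the normalization sits relative to the addition. The sketch below uses a single `nn.Linear` as a stand-in sublayer; the class names `PostNormBlock` and `PreNormBlock` are illustrative, not taken from any particular library.

```python
import torch
import torch.nn as nn


class PostNormBlock(nn.Module):
    """x = Norm(x + Sublayer(x)), as in the original transformer."""

    def __init__(self, d_model: int, sublayer: nn.Module):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm(x + self.sublayer(x))


class PreNormBlock(nn.Module):
    """x = x + Sublayer(Norm(x)); the skip path bypasses the norm."""

    def __init__(self, d_model: int, sublayer: nn.Module):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.sublayer(self.norm(x))


torch.manual_seed(0)
d_model = 32
x = 5.0 * torch.randn(4, d_model)   # deliberately large-scale input

post = PostNormBlock(d_model, nn.Linear(d_model, d_model))
pre = PreNormBlock(d_model, nn.Linear(d_model, d_model))
post_out, pre_out = post(x), pre(x)

# Post-norm re-normalizes every row; pre-norm leaves the residual stream's scale alone.
print(post_out.std(dim=-1))
print(pre_out.std(dim=-1))
```

In the pre-norm block the input `x` reaches the output through the addition alone, which is the "clean residual path" described above.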


Q5: What is the computational complexity of self-attention and why does it matter?

Answer: Standard self-attention takes O(n² × d) time and O(n²) memory for the attention matrix (per head), where n is the sequence length and d is the model dimension.

  • For each of n tokens, we compute attention weights over all n tokens: n² operations
  • Each attention weight computation involves a d-dimensional dot product
  • The attention matrix has n² elements — each must be stored in memory

Why it matters:

  • Memory: in 16-bit precision, a single fully materialized attention matrix for n = 128k takes 128k × 128k × 2 bytes ≈ 32 GB per layer per head. This is why long-context models rely on Flash Attention or sparse attention
  • Time: doubling sequence length quadruples attention computation time
  • Inference scaling: generating n tokens one at a time requires O(n²) attention operations in total (each new token attends to all previous tokens)

Flash Attention solves the memory problem (O(n) memory via tiling) but not the time complexity (still O(n²) FLOPs). Sparse attention (for example, a sliding window) reduces both to O(n × window_size).
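The memory figures are easy to reproduce with back-of-the-envelope arithmetic. The helper names below are hypothetical, "128k" is taken as 131,072 tokens, and the 4,096-token window is an arbitrary example value.

```python
def attention_matrix_bytes(n_tokens: int, bytes_per_element: int = 2) -> int:
    """Bytes for one dense n x n attention matrix (one head, one layer)."""
    return n_tokens * n_tokens * bytes_per_element


def sliding_window_bytes(n_tokens: int, window: int, bytes_per_element: int = 2) -> int:
    """Bytes when each token stores scores for at most `window` tokens."""
    return n_tokens * min(window, n_tokens) * bytes_per_element


GIB = 1024**3
n = 128 * 1024  # assumption: "128k" taken as 131,072 tokens

dense_gib = attention_matrix_bytes(n) / GIB
windowed_gib = sliding_window_bytes(n, window=4096) / GIB
print(f"dense:    {dense_gib:.1f} GiB")     # dense:    32.0 GiB
print(f"windowed: {windowed_gib:.1f} GiB")  # windowed: 1.0 GiB
```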


Q6: How does the feed-forward network know what to do for each position? Doesn't it just see one token at a time?

Answer: The FFN processes each position independently (position-wise), but by the time a token reaches the FFN, its representation has already been enriched by attention. The attention sub-layer has gathered information from relevant tokens across the sequence and incorporated it into the token's representation.

Original token 5 representation: just "warfarin" embedding
After attention: "warfarin" embedding + context from "patient prescribed [mask]..."
After FFN: transforms this enriched representation using learned knowledge about warfarin

The FFN's role is to transform the attention-enriched representation — applying learned "knowledge" to produce a better next-layer representation. Think of it as: attention determines what context is relevant, FFN applies learned patterns to that contextualized representation.

This is why FFN layers are often called "memory" layers — they store factual knowledge (drug mechanisms, world facts) in their weights.
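"Position-wise" can be verified directly: running the FFN over the whole sequence gives the same result as running it on each token separately. The sizes below are arbitrary, and the random tensor stands in for the output of an attention sub-layer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_ff, seq_len = 16, 64, 5

ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))

x = torch.randn(1, seq_len, d_model)   # stand-in for attention output

whole = ffn(x)   # all positions in one call
per_token = torch.stack([ffn(x[:, i]) for i in range(seq_len)], dim=1)  # one position at a time

same = torch.allclose(whole, per_token, atol=1e-5)
print(same)  # True
```

No information moves between positions inside the FFN; any cross-token context must already be present in `x`.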


Q7: What makes layer normalization different from batch normalization, and why does LN work better for transformers?

Answer:

Batch normalization: Normalizes across the batch dimension for each feature:

μ_feature = mean(x[:, feature])  # across batch
BN(x[sample, feature]) = (x[sample, feature] - μ_feature) / σ_feature

Depends on batch statistics — behavior differs between training and inference, and is unstable with small batches.

Layer normalization: Normalizes across the feature dimension for each sample:

μ_sample = mean(x[sample, :])  # across features
LN(x[sample, feature]) = (x[sample, feature] - μ_sample) / σ_sample

Independent of batch size. Same behavior in training and inference.

Transformers use variable-length sequences with masking, making batch normalization statistics noisy. Layer normalization works on each example independently and is insensitive to sequence length, padding, and batch size, which is essential for language modeling. It also behaves identically at inference time, where the batch size is often 1.
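The two formulas above can be written out by hand and compared with PyTorch. This sketch adds the small `eps` term that implementations use for numerical stability (the formulas in the text omit it) and leaves out the learnable scale and shift parameters.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 8)   # (batch, features)
eps = 1e-5

# Layer norm by hand: statistics per sample, across features.
mu = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)
ln_manual = (x - mu) / torch.sqrt(var + eps)
ln_torch = F.layer_norm(x, normalized_shape=(8,), eps=eps)

# Batch norm by hand (training-mode statistics): per feature, across the batch.
bn_manual = (x - x.mean(dim=0)) / torch.sqrt(x.var(dim=0, unbiased=False) + eps)

# Normalizing sample 0 alone matches normalizing it inside the batch.
ln_alone = F.layer_norm(x[:1], normalized_shape=(8,), eps=eps)

matches_torch = torch.allclose(ln_manual, ln_torch, atol=1e-5)
batch_independent = torch.allclose(ln_alone, ln_torch[:1], atol=1e-5)
print(matches_torch, batch_independent)  # True True
```

The same check fails for `bn_manual`: its output for sample 0 changes whenever the other rows in the batch change.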


Q8: What is weight tying between embedding and output projection layers?

Answer: The token embedding matrix (vocabulary size × d_model) and the final output projection matrix (d_model × vocabulary size) have transposed shapes. Weight tying shares one tensor for both, so the output projection is literally the transpose of the embedding matrix:

Python
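import torch
import torch.nn as nn
vocab_size, d_model = 50_257, 768  # GPT-2 small sizes, used here for illustration
embedding_weight = nn.Parameter(torch.randn(vocab_size, d_model))
# Token lookup: x = embedding_weight[token_id]
# Final projection: logits = hidden_state @ embedding_weight.T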

Why it works: Both layers are learning the same semantic space. The embedding lookup finds a representation for a token; the output projection is checking how well a hidden state matches each token's representation. Sharing weights forces consistency.

Benefits:

  • Saves vocab_size × d_model parameters (for GPT-2: 50k × 768 ≈ 38M parameters)
  • Empirically improves training — the model can't learn inconsistent representations
  • Regularizes the output projection using the same signal as the embedding layer

Used in: GPT-2, BERT (partially), T5. Not always used in LLaMA (depends on version).
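With standard PyTorch modules, tying amounts to assigning one `Parameter` to both layers. The model below is a hypothetical minimal example (`TinyTiedLM` is not a real library class, and the transformer layers between embedding and head are omitted).

```python
import torch
import torch.nn as nn


class TinyTiedLM(nn.Module):
    """Hypothetical minimal model: embedding -> (omitted layers) -> tied output head."""

    def __init__(self, vocab_size: int, d_model: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # nn.Linear stores its weight as (out_features, in_features) = (vocab_size, d_model),
        # the same shape as the embedding table, so the Parameter can be shared directly.
        self.lm_head.weight = self.embed.weight

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        hidden = self.embed(token_ids)   # transformer layers omitted
        return self.lm_head(hidden)      # logits: (..., vocab_size)


model = TinyTiedLM(vocab_size=1000, d_model=32)
n_params = sum(p.numel() for p in model.parameters())
logits = model(torch.tensor([[1, 2, 3]]))
print(n_params)       # 32000: the shared matrix is counted once
print(logits.shape)   # torch.Size([1, 3, 1000])
```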


Q9: Explain the role of the softmax temperature and when you would adjust it at inference.

Answer: Temperature τ divides the logits before softmax, controlling how sharp the distribution is:

Python
import torch

def sample_with_temperature(logits, temperature=1.0):
    # temperature == 0 is treated as greedy decoding to avoid dividing by zero
    if temperature == 0:
        return torch.argmax(logits, dim=-1, keepdim=True)
    scaled_logits = logits / temperature
    probs = torch.softmax(scaled_logits, dim=-1)
    return torch.multinomial(probs, 1)  # sample one token ID

# temperature < 1: sharper distribution, model is more "confident", less diverse
# temperature = 1: model's true learned distribution
# temperature > 1: flatter distribution, more diverse but less coherent

When to adjust:

  • T = 0 (greedy): Always pick the most likely next token. Best for factual Q&A and code generation, where determinism matters more than diversity
  • T ≈ 0.1–0.3: Nearly deterministic, good for structured output
  • T ≈ 0.7–1.0: Balanced quality and diversity, good for general chat
  • T > 1.0: Creative writing, brainstorming; more variety but more hallucination

Temperature is typically combined with top-k (keep only the K most likely candidates) and top-p (nucleus sampling: keep the smallest set of candidates whose cumulative probability reaches p).
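One possible way to combine the three controls is sketched below for a single 1-D logits vector. The function names `filter_top_k_top_p` and `sample` are illustrative, and real inference libraries differ in details such as tie handling and batching.

```python
import torch


def filter_top_k_top_p(logits: torch.Tensor, top_k: int = 0, top_p: float = 1.0) -> torch.Tensor:
    """Return a copy of 1-D `logits` with filtered-out entries set to -inf."""
    logits = logits.clone()
    if top_k > 0:
        kth_value = torch.topk(logits, min(top_k, logits.numel())).values[-1]
        logits[logits < kth_value] = float("-inf")
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        sorted_probs = torch.softmax(sorted_logits, dim=-1)
        mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs
        # Drop a token once the tokens ranked above it already cover top_p;
        # the most likely token is therefore always kept.
        remove = mass_before >= top_p
        logits[sorted_idx[remove]] = float("-inf")
    return logits


def sample(logits: torch.Tensor, temperature: float = 1.0, top_k: int = 0, top_p: float = 1.0) -> int:
    """Temperature scaling, then top-k / top-p filtering, then one multinomial draw."""
    filtered = filter_top_k_top_p(logits / temperature, top_k, top_p)
    probs = torch.softmax(filtered, dim=-1)
    return torch.multinomial(probs, 1).item()


torch.manual_seed(0)
logits = torch.log(torch.tensor([0.5, 0.3, 0.15, 0.05]))
kept_k = torch.isfinite(filter_top_k_top_p(logits, top_k=2)).tolist()
kept_p = torch.isfinite(filter_top_k_top_p(logits, top_p=0.7)).tolist()
print(kept_k)  # [True, True, False, False]
print(kept_p)  # [True, True, False, False]
token = sample(logits, temperature=0.8, top_k=3, top_p=0.9)
```

Filtered entries receive probability zero after the softmax, so the final draw is renormalized over the surviving candidates.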


Q10: What happens when a transformer model processes a text that exceeds its context window?

Answer: The model physically cannot process more tokens than its context window in a single forward pass. Options when text exceeds the limit:

Truncation: Simply discard tokens beyond the limit. Which end to truncate from matters — in RAG, the retrieved documents are usually truncated from the end (keep the beginning). For conversations, the oldest messages are dropped.

Chunking: Split the document and process each chunk separately. For question answering, you can process each chunk and aggregate answers.

Sliding window: Process overlapping chunks. The end of chunk 1 overlaps with the beginning of chunk 2 to preserve cross-boundary context.

Hierarchical processing: Summarize each chunk first (each chunk fits within the context window), then process the combined summaries, which are short enough to fit in a single context.

Long context models: Use a model with a larger context window (for example, 200k tokens). These models still have memory and quality limits at extreme lengths, including the "lost in the middle" phenomenon.

The key practical rule: structure your context so the most important information appears at the beginning (system prompt) and end (the query) — models attend best to these positions.
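The sliding window option described above can be sketched in a few lines. The function name `sliding_window_chunks` and the window and overlap values are illustrative; in practice `tokens` would be token IDs from the model's tokenizer, and the window would leave room for the prompt and the generated answer.

```python
def sliding_window_chunks(tokens, window, overlap):
    """Split `tokens` into chunks of at most `window` items; neighbors share `overlap` items."""
    if not 0 <= overlap < window:
        raise ValueError("overlap must satisfy 0 <= overlap < window")
    stride = window - overlap
    chunks = []
    for start in range(0, len(tokens), stride):
        chunks.append(tokens[start:start + window])
        if start + window >= len(tokens):
            break  # this chunk already reaches the end of the text
    return chunks


tokens = list(range(10))  # stand-in for token IDs
chunks = sliding_window_chunks(tokens, window=4, overlap=2)
print(chunks)
# [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], [6, 7, 8, 9]]
```

Setting `overlap=0` turns this into plain chunking.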
