Interview: Transformer Architecture (Part 2)
10 more senior-level questions: KV cache, quantization, speculative decoding, scaling laws, MoE, and system design with transformer-based models.
Q1: How does the KV cache work and what are its memory costs?
Answer: During autoregressive generation, each token attends to all previous tokens. Without caching, generating token N requires a full forward pass over all N-1 previous tokens, so generating N tokens costs O(N²) token-level forward computations in total.
The KV cache stores the Key and Value projections from all previous tokens. When generating token N, only token N's Q, K, V need to be computed; all previous K and V vectors are retrieved from cache. Only the new token runs through the network.
Memory cost per token:
bytes = 2 × num_layers × num_kv_heads × head_dim × dtype_bytes
For LLaMA-3-8B (32 layers, 8 KV heads, head_dim=128, float16):
- Bytes per token = 2 × 32 × 8 × 128 × 2 = 131,072 bytes = 128 KB
- For 8192-token context: 128 KB × 8192 = 1 GB per concurrent request
This is why GQA (fewer KV heads than Q heads) is critical for serving: LLaMA-3-8B uses 8 KV heads vs 32 Q heads, reducing KV cache by 4× compared to standard MHA.
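The per-token formula above can be checked with a few lines of Python. This is a minimal sketch; the function and variable names are illustrative, and the configuration values are the LLaMA-3-8B figures quoted in the text.

```python
def kv_cache_bytes_per_token(num_layers: int, num_kv_heads: int,
                             head_dim: int, dtype_bytes: float) -> float:
    """Bytes of KV cache one token occupies across all layers."""
    # Factor 2: one Key vector and one Value vector per layer and KV head.
    return 2 * num_layers * num_kv_heads * head_dim * dtype_bytes


def kv_cache_bytes(context_len: int, **config) -> float:
    """Bytes of KV cache for one request holding `context_len` tokens."""
    return context_len * kv_cache_bytes_per_token(**config)


# LLaMA-3-8B figures quoted in the text (float16 = 2 bytes per value).
llama3_8b = dict(num_layers=32, num_kv_heads=8, head_dim=128, dtype_bytes=2)
per_token = kv_cache_bytes_per_token(**llama3_8b)
per_request = kv_cache_bytes(8192, **llama3_8b)

# The same model if it used full multi-head attention (32 KV heads).
mha = dict(llama3_8b, num_kv_heads=32)
gqa_saving = kv_cache_bytes_per_token(**mha) / per_token
```

Swapping in another model's layer count, KV head count, and dtype size gives its per-request budget the same way.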
Q2: What is the difference between greedy decoding, beam search, and sampling?
Answer:
Greedy: At each step, select the highest-probability token. Fast, deterministic, but often suboptimal for longer sequences: a sequence of locally-best tokens may not be globally optimal.
Beam search: Maintain K candidate sequences ("beams") at each step, expand each by all possible next tokens, keep the K highest-scoring sequences. More thorough than greedy, but more compute. Good for tasks with clear correct answers (translation, summarization). Problem: tends to generate repetitive, "safe" text.
Sampling: Sample from the probability distribution rather than taking the argmax. Introduces stochasticity. Combined with temperature, top-k, and top-p filtering for quality control.
Top-k sampling: Truncate to the top K tokens before sampling. Prevents extremely unlikely tokens from being selected.
Nucleus (top-p) sampling: Truncate to the minimum set of tokens whose cumulative probability exceeds p. More adaptive than fixed K: uses more tokens when the distribution is flat, fewer when it's peaked.
For production LLMs: do_sample=True, temperature=0.7, top_p=0.9 is a common default for chat. do_sample=False (greedy) for code and structured output.
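The decoding strategies above differ only in how the next-token distribution is filtered before a token is picked. The sketch below shows that in plain Python. The function name and signature are illustrative and not taken from any library; treating temperature 0 as greedy is an assumption made here for convenience.

```python
import math
import random


def sample_next_token(logits, temperature=1.0, top_k=None, top_p=None, rng=random):
    """Sample a token id from raw logits with temperature, top-k and top-p."""
    if temperature <= 0:
        # Assumption: treat temperature 0 as greedy decoding (argmax).
        return max(range(len(logits)), key=lambda i: logits[i])

    # Temperature-scaled softmax (subtract the max for numerical stability).
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Candidate ids sorted from most to least probable.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)

    if top_k is not None:
        order = order[:top_k]

    if top_p is not None:
        kept, cumulative = [], 0.0
        for i in order:
            kept.append(i)
            cumulative += probs[i]
            if cumulative >= top_p:
                break  # smallest prefix whose mass reaches top_p
        order = kept

    # random.choices renormalises the surviving weights before drawing.
    weights = [probs[i] for i in order]
    return rng.choices(order, weights=weights, k=1)[0]
```

Passing `rng=random.Random(seed)` makes the sampling reproducible, which helps when comparing settings.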
Q3: You need to serve a 70B LLM to 1000 concurrent users on a budget. What's your architecture?
Answer:
Model configuration:
- Quantize to 4-bit AWQ/GPTQ: reduces weights from ~140GB to ~35GB
- Use tensor parallelism across 4× A100-80GB GPUs (35GB / 4 ≈ 9GB per GPU, fits with room for KV cache)
- GQA is critical: 8 KV heads instead of 64 minimizes KV cache memory per user
Serving framework:
- Use vLLM with PagedAttention, which is critical for handling 1000 concurrent users efficiently
- PagedAttention allocates KV cache in fixed-size pages (like virtual memory), preventing fragmentation
- Continuous batching: don't wait for all requests to finish before starting new ones; interleave them dynamically
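The paging idea can be illustrated with a toy allocator. This is a sketch of the concept only, not vLLM's implementation: the class, its methods, and the default page size are hypothetical names and values chosen for the example.

```python
class PagedKVAllocator:
    """Toy page allocator: each request owns a list of fixed-size pages."""

    def __init__(self, num_pages: int, page_size: int = 16):
        self.page_size = page_size                 # tokens per page
        self.free_pages = list(range(num_pages))   # physical page ids
        self.block_tables = {}                     # request id -> page ids
        self.lengths = {}                          # request id -> tokens stored

    def append_token(self, request_id: str) -> int:
        """Reserve room for one more token; return the physical page used."""
        table = self.block_tables.setdefault(request_id, [])
        length = self.lengths.get(request_id, 0)
        if length == len(table) * self.page_size:
            # Current pages are full (or none yet): grab one more page.
            if not self.free_pages:
                raise MemoryError("KV cache exhausted; preempt or queue request")
            table.append(self.free_pages.pop())
        self.lengths[request_id] = length + 1
        return table[length // self.page_size]

    def release(self, request_id: str) -> None:
        """Return all pages of a finished request to the free pool."""
        self.free_pages.extend(self.block_tables.pop(request_id, []))
        self.lengths.pop(request_id, None)
```

Because a request wastes at most one partially filled page, memory freed by a finished request can be reused immediately by any other request, which is what makes continuous batching practical.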
KV cache budget:
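- With a 4-bit quantized KV cache and a LLaMA-3-70B-style configuration (80 layers, 8 KV heads, head_dim=128): 2 × 80 × 8 × 128 × 0.5 bytes = 80KB per token, instead of 320KB in float16
- 1000 users × 2048 avg tokens × 80KB ≈ 160GB in total, which still fits next to ~35GB of weights in the 320GB of a 4×80GB node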
Optimizations:
- Prefix caching: cache the system prompt's KV pairs (same across all users)
- Speculative decoding with a 7B draft model: often around 2× faster decoding at the same quality, with the largest gains at small batch sizes
- Tensor parallelism within the node; no need for pipeline parallelism at 70B
Scaling: Assuming one 4-GPU node sustains ~200 concurrent users at 50 tokens/second each, 5 such nodes cover 1000 users; add at least one more node if you need redundancy.
Q4: What is catastrophic forgetting and how do LoRA adapters mitigate it?
Answer: Catastrophic forgetting: when a neural network is fine-tuned on new data, gradient updates can overwrite learned representations for the original task. The model "forgets" prior knowledge while learning the new distribution.
Why full fine-tuning causes it:
- Gradient updates can modify every weight in the network
- Weights that encode world knowledge (in FFN layers) get overwritten by domain-specific updates
- The further the fine-tuning distribution is from pretraining, the more forgetting occurs
LoRA's mitigation mechanism:
- The pretrained weights are frozen, so they cannot be modified
- Only the small adapter matrices (rank decomposition ΔW = A × B) are trained
- The original knowledge is preserved in frozen weights
- Domain adaptation is encoded in the delta (ΔW), which is additive
Additionally, LoRA's low-rank constraint limits the "expressiveness" of the update: the adapter can learn domain-specific patterns but cannot learn to fully replace the base representations.
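A small pure-Python sketch makes the frozen-plus-delta structure concrete. It follows the ΔW = A × B naming used above (the LoRA paper writes the same product as BA); the class name, the initialization scale, and the choice to zero-initialize A are assumptions for illustration.

```python
import random


def matvec(w, v):
    """Multiply matrix w (list of rows) by vector v."""
    return [sum(a * b for a, b in zip(row, v)) for row in w]


class LoRALinear:
    """Frozen weight W plus a trainable low-rank delta A x B."""

    def __init__(self, weight, rank, seed=0):
        rng = random.Random(seed)
        d_out, d_in = len(weight), len(weight[0])
        self.weight = weight  # frozen: never updated during fine-tuning
        # A starts at zero so the adapted layer initially equals the base layer.
        self.A = [[0.0] * rank for _ in range(d_out)]          # d_out x r
        self.B = [[rng.gauss(0, 0.01) for _ in range(d_in)]    # r x d_in
                  for _ in range(rank)]

    def forward(self, x):
        base = matvec(self.weight, x)
        # Apply B then A, so the full d_out x d_in delta is never formed.
        delta = matvec(self.A, matvec(self.B, x))
        return [b + d for b, d in zip(base, delta)]

    def trainable_params(self):
        return len(self.A) * len(self.A[0]) + len(self.B) * len(self.B[0])
```

Dropping the adapter (or setting A back to zero) recovers the base model exactly, which is the sense in which the original knowledge stays intact.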
For tasks requiring heavy adaptation (converting a general model to a highly specialized clinical model), full fine-tuning can be justified and some forgetting is acceptable. For tasks requiring both general and specialized ability, LoRA is preferred.
Q5: Explain Flash Attention's key insight and what problem it solves.
Answer: Standard attention stores the full N×N attention weight matrix in GPU high-bandwidth memory (HBM). For long sequences, this is prohibitive: at N=128k, 128k × 128k × 2 bytes = 32GB per attention head just for attention weights.
Flash Attention's insight: Attention doesn't need to be computed all at once. The computation can be tiled into blocks that fit in SRAM (the fast on-chip memory). Using the online softmax trick (iteratively updating running max and sum), the output can be computed incrementally without ever materializing the full N×N matrix in HBM.
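The online softmax trick can be demonstrated for a single query in plain Python. This is a sketch of the arithmetic only (no GPU, no SRAM tiling of queries); the function names and the block size are illustrative.

```python
import math


def attention_reference(q, keys, values):
    """Standard attention for one query: materialises all N scores at once."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(dim)]


def attention_tiled(q, keys, values, block_size=2):
    """Same result, visiting keys/values one block at a time (online softmax)."""
    dim = len(values[0])
    running_max = -math.inf   # largest score seen so far
    running_sum = 0.0         # softmax denominator, relative to running_max
    acc = [0.0] * dim         # unnormalised output, relative to running_max

    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [sum(a * b for a, b in zip(q, k)) for k in k_blk]

        new_max = max(running_max, max(scores))
        # Rescale what was accumulated under the old max to the new max.
        scale = math.exp(running_max - new_max)
        running_sum *= scale
        acc = [a * scale for a in acc]

        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * x for a, x in zip(acc, v)]
        running_max = new_max

    return [a / running_sum for a in acc]
```

Only one block of scores exists at any time, yet the final output agrees with the reference up to floating-point rounding.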
What it achieves:
- Memory: O(N²) → O(N) in HBM (tiles live in SRAM and are never persisted to HBM)
- Speed: 2-4× faster wall-clock due to fewer HBM reads/writes (the bottleneck)
- Mathematically identical output to standard attention (up to floating-point rounding), not an approximation
The key IO-complexity insight: The attention bottleneck is memory bandwidth (HBM reads/writes), not arithmetic FLOPs. The GPU is often waiting for data, not computing. Flash Attention reduces HBM traffic by computing in tiles: roughly the same FLOPs, but far fewer memory round trips.
Q6: What is the Chinchilla scaling law and how should it affect model selection decisions?
Answer: Chinchilla (Hoffmann et al., 2022) found that compute-optimal training uses approximately 20 tokens per parameter:
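For a fixed compute budget C ≈ 6 × N × D (N = parameters, D = training tokens): N* × D* = C / 6
Optimal: D* / N* ≈ 20
GPT-3 (175B params, 300B tokens = 1.7 tokens/param) was massively under-trained by this rule. At GPT-3's compute budget (6 × 175B × 300B ≈ 3.2e23 FLOPs), the rule suggests roughly a 50B model on 1T tokens. Chinchilla itself (70B params, 1.4T tokens) was trained with about the same compute as the 280B-parameter Gopher and outperformed it.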
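The two relations above pin down the compute-optimal size directly: substituting D = 20 × N into C = 6 × N × D gives N = sqrt(C / 120). A minimal sketch, with illustrative function names and the 20 tokens/param ratio exposed as a parameter:

```python
import math


def training_flops(params: float, tokens: float) -> float:
    """Approximate training compute, C = 6 * N * D."""
    return 6.0 * params * tokens


def chinchilla_optimal(compute_flops: float, tokens_per_param: float = 20.0):
    """Return (params, tokens) that spend `compute_flops` at the given ratio.

    From C = 6 * N * D with D = tokens_per_param * N:
    N = sqrt(C / (6 * tokens_per_param)).
    """
    params = math.sqrt(compute_flops / (6.0 * tokens_per_param))
    tokens = tokens_per_param * params
    return params, tokens
```

For example, `chinchilla_optimal(training_flops(70e9, 1.4e12))` recovers about 70B parameters and 1.4T tokens, since that pair already sits at 20 tokens per parameter.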
Practical implications:
- For training: Given your compute budget, don't train a very large model on little data. Train a smaller model for longer: it reaches the same loss at lower inference cost.
- For inference-heavy deployment: Modern practice trains well beyond the Chinchilla optimum. LLaMA-3-8B uses 15T tokens (1875 tokens/param), far beyond the Chinchilla-optimal 160B tokens. This spends extra training compute to make a small model much stronger than a compute-optimal model of the same size, which pays off because serving an 8B model is much cheaper than serving a 70B one.
- For model selection: Don't equate model size with quality. An 8B model trained on 15T tokens (LLaMA-3-8B) can outperform a 65B model trained on 1.4T tokens (original LLaMA-65B) on many benchmarks.
Q7: How does Mixture of Experts achieve better quality without proportionally more compute?
Answer: MoE replaces the dense FFN with N expert FFNs and a router that selects K experts per token (typically K=2, N=8).
Why it works:
- Each expert can specialize in a subset of the data distribution: some experts may handle syntactic patterns, others domain-specific knowledge, others different reasoning types (in practice, observed specialization is often token-level rather than cleanly by topic)
- Because each expert sees only the tokens routed to it, it can learn more focused, specialized representations
- Total model capacity scales with N (more expert parameters), but active compute per token scales with K (only K experts fire)
The key equation:
Capacity = N × expert_params (large)
Active compute = K × expert_params (smaller by factor N/K)
For Mixtral-8×7B: total 47B params, active ~13B params per token. A 47B dense model would require 3.6× more compute per token (less than N/K = 4 because attention and embedding parameters are shared and always active).
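Top-K routing can be sketched in a few lines. In a real layer the router logits come from a learned linear projection of the token; here they are passed in directly to keep the example short, and experts are plain Python callables. All names are illustrative.

```python
import math


def top_k_route(router_logits, k=2):
    """Pick the k highest-scoring experts and softmax over just those k."""
    chosen = sorted(range(len(router_logits)),
                    key=lambda e: router_logits[e], reverse=True)[:k]
    m = max(router_logits[e] for e in chosen)
    exps = [math.exp(router_logits[e] - m) for e in chosen]
    total = sum(exps)
    return [(e, w / total) for e, w in zip(chosen, exps)]


def moe_forward(x, router_logits, experts, k=2):
    """Weighted sum of the outputs of the k selected experts only."""
    out = [0.0] * len(x)
    for expert_id, gate in top_k_route(router_logits, k):
        y = experts[expert_id](x)          # only k of the N experts run
        out = [o + gate * yi for o, yi in zip(out, y)]
    return out
```

The other N - K experts are never called for this token, which is where the compute saving comes from, while their parameters still have to sit in memory.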
Limitations: Expert parallelism requires routing tokens to different GPUs (all-to-all communication), which adds latency. Load balancing is needed to prevent expert collapse (most tokens routing to the same few experts). Memory is still proportional to total parameters.
Q8: What is speculative decoding and when is it most effective?
Answer: A small "draft" model proposes K tokens; the large "target" model verifies all K tokens in a single parallel forward pass. Accepted tokens are kept; at the first rejection, a replacement token is resampled from an adjusted version of the target's distribution, so the overall output distribution matches the target model's.
Why it's faster: Decoding is memory-bandwidth-bound, so the large model's forward pass takes roughly the same time whether it processes 1 or K tokens (up to compute and memory limits). If K=4 tokens are proposed and each is accepted with probability 0.8, the expected number of tokens per target forward pass is (1 - 0.8^5) / (1 - 0.8) ≈ 3.4, versus 1 token without speculative decoding.
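Both pieces of this argument fit in a short sketch: the expected number of tokens per target pass, and the accept-or-resample rule for one drafted token. The function names are illustrative, and the expectation assumes each drafted token is accepted independently with the same probability, which is a simplification.

```python
import random


def expected_tokens_per_pass(accept_rate: float, k: int) -> float:
    """Expected tokens emitted per target pass with k drafted tokens.

    Assumes independent acceptance with probability accept_rate. Every pass
    yields one extra token: the correction after a rejection, or a bonus
    token when all k drafts are accepted.
    """
    if accept_rate == 1.0:
        return k + 1.0
    return (1.0 - accept_rate ** (k + 1)) / (1.0 - accept_rate)


def verify_token(token, p_target, q_draft, rng=random):
    """Accept or replace one drafted token via rejection sampling.

    p_target and q_draft map token -> probability at this position.
    Returns (token, accepted).
    """
    accept_prob = min(1.0, p_target.get(token, 0.0) / q_draft[token])
    if rng.random() < accept_prob:
        return token, True
    # Rejected: resample from the residual distribution max(0, p - q).
    residual = {t: max(0.0, p - q_draft.get(t, 0.0)) for t, p in p_target.items()}
    tokens = list(residual)
    weights = [residual[t] for t in tokens]
    return rng.choices(tokens, weights=weights, k=1)[0], False
```

With K=4 and an acceptance rate of 0.8, `expected_tokens_per_pass` returns about 3.36. Resampling from the residual rather than from the raw target distribution is what keeps the combined procedure equivalent to sampling from the target alone.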
Most effective when:
- Draft and target share the same architecture family and tokenizer (Llama-3-8B draft for Llama-3-70B target)
- The task is factual/constrained (code, Q&A) → high acceptance rate
- Temperature is low (more deterministic distribution → draft matches target better)
Least effective when:
- Creative writing at high temperature → distributions diverge, acceptance rate drops
- Draft and target are from different training distributions
- Very short outputs (overhead of draft model dominates)
Production usage: vLLM and TGI support speculative decoding natively. Typical speedup: 2-3×, with no change to the output distribution when exact rejection sampling is used.
Q9: System design: build a real-time medical transcription system that extracts structured drug information
Scenario: Physicians speak during patient visits. System must transcribe speech, extract drug mentions, dosages, and interactions, and return structured JSON in under 2 seconds.
Answer:
Pipeline:
Audio stream → Whisper (streaming ASR) → Text chunks → LLM extraction → JSON output
Component selection:
- ASR: Whisper large-v3 run in a chunked, near-streaming mode (process each 3-second audio chunk)
- Extraction LLM: A quantized 8B model (LLaMA-3-8B-Instruct in AWQ-4bit) for fast extraction
- Output format: Function calling or constrained generation to guarantee valid JSON
Latency budget (2 second total):
- ASR: ~200ms for 3-second chunk (Whisper on A10G)
- LLM extraction: ~500ms for the 8B model on the transcribed text (typically 50-200 tokens output)
- Network + parsing: ~100ms
- Total: ~800ms after each chunk is captured, well within 2 seconds
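A budget like this is easy to keep honest with a small check that sums per-stage latencies and fails when the total exceeds the target. The function and stage names below are hypothetical, and the numbers are the estimates from the list above, which measured values should replace.

```python
def check_latency_budget(stages_ms: dict, budget_ms: float):
    """Return (total_ms, headroom_ms); raise if the stages exceed the budget."""
    total = sum(stages_ms.values())
    headroom = budget_ms - total
    if headroom < 0:
        raise ValueError(f"Over budget by {-headroom:.0f} ms: {stages_ms}")
    return total, headroom


# Estimates quoted in the latency budget; not measurements.
stages = {"asr": 200, "llm_extraction": 500, "network_and_parsing": 100}
total_ms, headroom_ms = check_latency_budget(stages, budget_ms=2000)
```

The remaining headroom is what absorbs tail latency, queueing under load, and longer-than-usual outputs.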
Structured output:
from pydantic import BaseModel

class DrugExtraction(BaseModel):
    drug_name: str
    dose_mg: float | None
    frequency: str | None
    route: str | None
    indication: str | None
    interactions_flagged: list[str]
Use the outlines library or vLLM's structured output support to constrain generation to valid DrugExtraction JSON.
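Even with constrained generation, a cheap check on the decoded string is a reasonable last guard before the result reaches a clinical system. The sketch below is a dependency-free stand-in that mirrors the DrugExtraction fields using only the standard library; it is not the outlines or vLLM API, and the `SCHEMA` table and function name are hypothetical.

```python
import json

# Field name -> allowed Python types, mirroring the DrugExtraction model.
SCHEMA = {
    "drug_name": (str,),
    "dose_mg": (int, float, type(None)),
    "frequency": (str, type(None)),
    "route": (str, type(None)),
    "indication": (str, type(None)),
    "interactions_flagged": (list,),
}


def parse_drug_extraction(raw: str) -> dict:
    """Parse LLM output and check it has exactly the expected fields and types."""
    data = json.loads(raw)  # raises json.JSONDecodeError on malformed JSON
    if not isinstance(data, dict) or set(data) != set(SCHEMA):
        raise ValueError("unexpected or missing fields")
    for field, allowed in SCHEMA.items():
        value = data[field]
        # bool is a subclass of int in Python, so reject it explicitly.
        if isinstance(value, bool) or not isinstance(value, allowed):
            raise ValueError(f"bad type for {field!r}: {type(value).__name__}")
    if not all(isinstance(item, str) for item in data["interactions_flagged"]):
        raise ValueError("interactions_flagged must be a list of strings")
    return data
```

In a pipeline that already depends on pydantic, validating against the DrugExtraction model itself would serve the same purpose with less code.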
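Scaling: Deploy the LLM on 1× A10G (24GB). The AWQ-quantized 8B model takes roughly 5-6GB (4-bit weights plus higher-precision embeddings), leaving about 18GB for KV cache and batching, which should support on the order of 50 concurrent physician sessions.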
Q10: How do you evaluate whether a transformer model is "aligned" vs just "capable"?
Answer: Capability and alignment are orthogonal dimensions:
- Capable but misaligned: Knows the correct answer but gives a more agreeable wrong answer to please the user
- Aligned but incapable: Refuses clearly harmful requests but also refuses or fails legitimate tasks, because it cannot reliably tell them apart or solve them
- Both: Accurately answers factual questions and appropriately refuses harmful requests
Capability evaluation:
- Standardized benchmarks: MMLU, ARC, HellaSwag (factual recall, reasoning)
- Domain-specific tests (medical USMLE, coding HumanEval)
- Long-context tasks: document QA, code review
Alignment evaluation:
- Truthfulness: TruthfulQA, model-based fact-checking
- Refusal calibration: Does the model refuse harmful requests? Does it also refuse legitimate medical questions it shouldn't?
- Sycophancy tests: Present a wrong answer confidently and see if the model agrees or corrects
- Instruction following: IFEval, which measures exact compliance with structured instructions
- Safety evals: Adversarial red-teaming, prompt injection resistance
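Refusal calibration in particular reduces to two rates that pull against each other. The sketch below computes them from labelled results; the function name and the `(should_refuse, did_refuse)` record format are assumptions for illustration, not part of any benchmark.

```python
def refusal_rates(results):
    """Compute over- and under-refusal rates from labelled eval results.

    Each result is a (should_refuse, did_refuse) pair of booleans.
    """
    harmful = [did for should, did in results if should]
    benign = [did for should, did in results if not should]
    # Under-refusal: harmful prompts the model answered anyway.
    under = sum(1 for did in harmful if not did) / len(harmful) if harmful else 0.0
    # Over-refusal: legitimate prompts the model refused.
    over = sum(1 for did in benign if did) / len(benign) if benign else 0.0
    return {"under_refusal": under, "over_refusal": over}
```

Reporting both numbers side by side makes the trade-off visible: a model can drive one rate to zero simply by inflating the other.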
The key tension: Over-alignment (too many refusals, excessive caveats) reduces utility. Under-alignment allows harm. RLHF/DPO training tries to optimize both simultaneously, but the optimal balance depends on the deployment context.
For medical AI: capability (accurate pharmacology knowledge) is necessary but insufficient. The model must also be calibrated: knowing when to say "I'm not certain" rather than generating confident misinformation.