Transformer Architecture Q&A · Lesson 17 of 23
Layer Normalization vs Batch Normalization
What Layer Norm Does
Layer normalisation normalises across the feature dimension of a single sample:
Input x ∈ ℝ^d (one token's representation)
mean: μ = (1/d) Σᵢ xᵢ
variance: σ² = (1/d) Σᵢ (xᵢ - μ)²
normalised: x̂ᵢ = (xᵢ - μ) / √(σ² + ε)
output: yᵢ = γᵢ · x̂ᵢ + βᵢ
γ (scale) and β (shift) are learned parameters, one per feature dimension. They allow the model to undo the normalisation if needed. ε (typically 1e-5) prevents division by zero.
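The sketch below checks the formulas above on a 4-dimensional token vector. It is plain Python with no PyTorch. The function name `layer_norm` and its list-based signature are illustrative, not a library API. By default it uses γ = 1 and β = 0.

```python
import math

def layer_norm(x, gamma=None, beta=None, eps=1e-5):
    """Layer norm of one token vector x (a list of d floats), per the formulas above."""
    d = len(x)
    gamma = [1.0] * d if gamma is None else gamma  # default scale: no rescaling
    beta = [0.0] * d if beta is None else beta     # default shift: no offset
    mu = sum(x) / d                                # mean over the d features
    var = sum((xi - mu) ** 2 for xi in x) / d      # biased (1/d) variance
    denom = math.sqrt(var + eps)
    return [g * (xi - mu) / denom + b for xi, g, b in zip(x, gamma, beta)]

y = layer_norm([1.0, 2.0, 3.0, 4.0])
print([round(v, 4) for v in y])  # [-1.3416, -0.4472, 0.4472, 1.3416]
```

With the defaults, the output has mean 0 and variance just under 1 (ε keeps it slightly below 1). Passing `gamma` and `beta` applies the learned affine step yᵢ = γᵢ · x̂ᵢ + βᵢ.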
Batch Norm vs Layer Norm
Batch Norm: normalises across the BATCH dimension
For feature i: uses mean/variance computed over all samples in the batch
Pros: works well for CNNs, image tasks
Cons: depends on batch size; behaves differently at train vs inference;
problematic for variable-length sequences
Layer Norm: normalises across the FEATURE dimension
For sample s: uses mean/variance computed over all features of that sample
Pros: independent of batch size; same behaviour at train/inference;
works naturally with sequences of any length
Cons: computing a mean and a variance for every token adds some overhead (RMSNorm reduces this)
Layer norm is the natural choice for transformers because sequences have variable length and features (not samples) define the representation.
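The difference in normalisation axis shows up on a tiny batch. This sketch is plain Python, and the helper names are hypothetical. `batch_norm_batch` models batch norm in training mode only: it uses batch statistics, with no running averages and no γ/β. The same sample is placed in two different batches.

```python
import math

def normalise(values, eps=1e-5):
    """Zero-mean, unit-variance normalisation of a flat list (no gamma/beta)."""
    mu = sum(values) / len(values)
    var = sum((v - mu) ** 2 for v in values) / len(values)
    return [(v - mu) / math.sqrt(var + eps) for v in values]

def layer_norm_batch(batch):
    # Layer norm: statistics per sample, over that sample's features (each row alone).
    return [normalise(row) for row in batch]

def batch_norm_batch(batch):
    # Batch norm (training mode): statistics per feature, over all samples (each column).
    columns = [normalise(list(col)) for col in zip(*batch)]
    return [list(row) for row in zip(*columns)]

sample = [1.0, 2.0, 3.0]
batch_a = [sample, [0.0, 0.0, 1.0]]
batch_b = [sample, [10.0, -5.0, 7.0]]

# Layer norm gives the same output for `sample` regardless of its batch-mates...
print(layer_norm_batch(batch_a)[0] == layer_norm_batch(batch_b)[0])  # True
# ...whereas batch norm's output for `sample` changes with the rest of the batch.
print(batch_norm_batch(batch_a)[0] == batch_norm_batch(batch_b)[0])  # False
```

This batch dependence is why batch norm needs running statistics at inference time. It is also why batch norm fits poorly with small or variable-length batches of sequences.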
Code
import torch
import torch.nn as nn

class LayerNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))   # learned scale, one per feature
        self.beta = nn.Parameter(torch.zeros(d_model))   # learned shift, one per feature
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model); statistics are taken over the last (feature) dim
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)  # biased 1/d variance, as in the formula
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * x_norm + self.beta
# PyTorch built-in (normalized_shape is its first argument; there is no d_model keyword):
layer_norm = nn.LayerNorm(512)
Post-LN vs Pre-LN
The original Transformer uses Post-LayerNorm (normalise after adding the residual):
Post-LN (Vaswani et al. 2017):
output = LayerNorm(x + SubLayer(x))
Pre-LN (GPT-2 onward, LLaMA):
output = x + SubLayer(LayerNorm(x))
Why Pre-LN dominates modern models:
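One structural reason can be seen directly. The plain-Python sketch below implements both orderings. `zero_sublayer` is a hypothetical stand-in for attention or the FFN that contributes nothing. This is a toy, not how real blocks are implemented. With that sublayer, Pre-LN returns its input unchanged, while Post-LN re-normalises the residual stream itself.

```python
import math

def layer_norm(x, eps=1e-5):
    # Per-token layer norm without gamma/beta (equivalently gamma = 1, beta = 0).
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return [(v - mu) / math.sqrt(var + eps) for v in x]

def post_ln_block(x, sublayer):
    # Post-LN (Vaswani et al. 2017): output = LayerNorm(x + SubLayer(x))
    return layer_norm([a + b for a, b in zip(x, sublayer(x))])

def pre_ln_block(x, sublayer):
    # Pre-LN: output = x + SubLayer(LayerNorm(x)); x itself bypasses the norm
    return [a + b for a, b in zip(x, sublayer(layer_norm(x)))]

def zero_sublayer(x):
    # Hypothetical sublayer that contributes nothing.
    return [0.0] * len(x)

x = [3.0, 5.0, 7.0, 9.0]
print(pre_ln_block(x, zero_sublayer))   # [3.0, 5.0, 7.0, 9.0]: identity residual path
print(post_ln_block(x, zero_sublayer))  # re-normalised: the residual stream passes through LayerNorm
```

Stacked N deep, the Pre-LN residual stream is a chain of additions, so gradients can flow straight back through it. In Post-LN, that path crosses N LayerNorms.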
Post-LN problem:
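In Post-LN the LayerNorm sits on the residual path, so the gradient reaching
early layers must pass through every LayerNorm in the stack. At initialisation,
gradients near the output are large, and training diverges unless the learning
rate is warmed up carefully; the problem grows with depth.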
Pre-LN advantage:
Each sublayer receives a normalised input → well-scaled gradients at initialisation
The residual path is an identity, so gradients reach every layer without passing through a per-layer LayerNorm
Trains stably without elaborate warmup, even at depth
Slight quality cost vs a well-tuned Post-LN at convergence (marginal in practice)
RMSNorm: Simplifying Layer Norm
LLaMA and many modern models use RMSNorm, which removes the mean-centering step:
RMSNorm(x)ᵢ = γᵢ · xᵢ / RMS(x)
where RMS(x) = √((1/d) Σᵢ xᵢ² + ε)
No β (shift) parameter; only γ (scale).
class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))  # learned scale only, no shift
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Root-mean-square over the feature dimension; no mean subtraction
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x / rms) * self.gamma
RMSNorm is cheaper than LayerNorm (it skips computing and subtracting the mean, and has no β) and has been found empirically to perform comparably on large models.
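The plain-Python sketch below shows what dropping mean-centering means. It compares RMSNorm with layer norm on a zero-mean vector and on the same vector shifted by +10. The helper names are hypothetical; both functions use γ = 1, β = 0 and the same ε.

```python
import math

def rms_norm(x, eps=1e-6):
    # RMSNorm with gamma = 1: divide by root-mean-square; no mean subtraction, no beta.
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]

def layer_norm(x, eps=1e-6):
    # Layer norm with gamma = 1, beta = 0, using the same eps for a fair comparison.
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return [(v - mu) / math.sqrt(var + eps) for v in x]

centred = [-3.0, -1.0, 1.0, 3.0]       # already zero-mean
shifted = [c + 10.0 for c in centred]  # same shape, mean 10

# Zero-mean input: the RMS equals the standard deviation, so the two agree.
print(max(abs(a - b) for a, b in zip(rms_norm(centred), layer_norm(centred))))  # 0.0
# Non-zero mean: RMSNorm keeps the offset (all outputs positive); layer norm removes it.
print(rms_norm(shifted))
print(layer_norm(shifted))
```

On zero-mean inputs the two methods coincide. Elsewhere, RMSNorm rescales without re-centring, and the model relies on the surrounding layers to handle any offset.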
Role in Transformer Training Stability
Without normalisation:
Deep transformer (24+ layers): gradients explode or vanish
Internal covariate shift: each layer's input distribution shifts
as weights in earlier layers change → slow, unstable training
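The compounding effect described above can be illustrated with a deliberately crude toy. Each hypothetical "layer" in `deep_input` just multiplies its input by one shared weight. Changing that weight by 10% then changes the scale of the input that reaches layer 24 by about 1.1²⁴. In such a linear chain, the same product multiplies the gradients, which is the explode/vanish mechanism. Normalising the input removes the scale drift.

```python
import math

def layer_norm(x, eps=1e-5):
    # Per-token layer norm without gamma/beta.
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return [(v - mu) / math.sqrt(var + eps) for v in x]

def rms(x):
    return math.sqrt(sum(v * v for v in x) / len(x))

def deep_input(x, weight, depth):
    # Hypothetical toy network: each "layer" scales its input by the same weight.
    for _ in range(depth):
        x = [weight * v for v in x]
    return x

x = [1.0, -2.0, 0.5, 0.5]
before = deep_input(x, weight=1.00, depth=24)
after = deep_input(x, weight=1.10, depth=24)  # 10% change in every earlier layer's weight

print(rms(after) / rms(before))  # about 9.85 (1.1**24): the input scale moved ~10x
print(rms(layer_norm(after)) / rms(layer_norm(before)))  # about 1.0 after normalisation
```

Real layers are nonlinear and have many weights, so this toy only captures the scale effect, not the full picture.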
With Pre-LN:
Each sublayer input is normalised per token (mean 0, variance ≈ 1 before γ and β), whatever the scale of the residual stream
Gradients are bounded and well-conditioned
Very deep stacks train stably (GPT-3, a Pre-LN model, has 96 layers)
Interview Answer
"Layer norm normalises each token's representation across its feature dimensions: it computes a per-token mean and variance over the d_model features, then applies a learned scale γ and shift β. Unlike batch norm, it's independent of batch size, behaves the same at training and inference, and handles variable-length sequences naturally. The original Transformer uses Post-LN (normalise after the residual addition); most modern models use Pre-LN (normalise the sublayer's input) because it gives better-behaved gradients and needs far less careful warmup. LLaMA and Mistral use RMSNorm, which drops the mean-centering step for speed with comparable quality."