Learnixo

Layer Normalization vs Batch Normalization

What Layer Norm Does

Layer normalisation normalises across the feature dimension of a single sample:

Input x ∈ ℝ^d  (one token's representation)

mean: μ = (1/d) Σᵢ xᵢ
variance: σ² = (1/d) Σᵢ (xᵢ - μ)²

normalised: x̂ᵢ = (xᵢ - μ) / √(σ² + ε)

output: yᵢ = γᵢ · x̂ᵢ + βᵢ

γ (scale) and β (shift) are learned parameters — one per feature dimension. They allow the model to undo the normalisation if needed. ε (typically 1e-5) prevents division by zero.
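
As a quick check of the formulas above, here is a minimal sketch that normalises a single 4-dimensional token by hand and compares the result with PyTorch's functional layer norm. The helper name `layer_norm_manual` and the example vector are illustrative assumptions, and γ and β are left at their initial values (1 and 0), so only the normalisation step is shown.

```python
import torch
import torch.nn.functional as F


def layer_norm_manual(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Normalise over the last dimension (no learned scale or shift)."""
    mu = x.mean(dim=-1, keepdim=True)
    # unbiased=False divides by d, matching the 1/d in the variance formula.
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    return (x - mu) / torch.sqrt(var + eps)


# One token with d = 4: mean is 2.5, variance is 1.25.
x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
x_hat = layer_norm_manual(x)

# Reference result from PyTorch, normalising over the same last dimension.
reference = F.layer_norm(x, normalized_shape=(4,), eps=1e-5)

print(x_hat)      # roughly [-1.342, -0.447, 0.447, 1.342]
print(reference)
```

The normalised token has mean 0 and variance very close to 1; the small gap from exactly 1 comes from ε inside the square root.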


Batch Norm vs Layer Norm

Batch Norm: normalises across the BATCH dimension
  For feature i: uses mean/variance computed over all samples in the batch
  Pros: works well for CNNs, image tasks
  Cons: depends on batch size; behaves differently at train vs inference;
        problematic for variable-length sequences

Layer Norm: normalises across the FEATURE dimension
  For sample s: uses mean/variance computed over all features of that sample
  Pros: independent of batch size; same behaviour at train/inference;
        works naturally with sequences of any length
  Cons: computes both a mean and a variance for every token, which adds
        some overhead (RMSNorm drops the mean to reduce it)

Layer norm is the natural choice for transformers: each token is normalised using only its own d features, so batch size, sequence length, and padding never affect the statistics.
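
The batch-size independence described above can be checked directly. The sketch below is an illustration under assumed shapes (8 features, batches of 4): it places the same sample in two different batches and compares how each normalisation treats it. `nn.BatchNorm1d` is used in its default training mode, where it normalises with the current batch's statistics.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 8  # illustrative feature size

ln = nn.LayerNorm(d_model)
bn = nn.BatchNorm1d(d_model)  # training mode by default: uses batch statistics

# The same sample placed in two batches with very different companions.
sample = torch.randn(1, d_model)
batch_a = torch.cat([sample, torch.randn(3, d_model)], dim=0)
batch_b = torch.cat([sample, 5.0 * torch.randn(3, d_model) + 2.0], dim=0)

# Layer norm only looks at the sample's own features.
ln_same = torch.allclose(ln(batch_a)[0], ln(batch_b)[0], atol=1e-6)

# Batch norm mixes in the other samples through the per-feature batch mean/variance.
bn_same = torch.allclose(bn(batch_a)[0], bn(batch_b)[0], atol=1e-6)

print("LayerNorm output unchanged:", ln_same)
print("BatchNorm output unchanged:", bn_same)
```

Layer norm returns the same output for the sample in both batches, while batch norm does not, because its statistics depend on whatever else is in the batch.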


Code

Python
import torch
import torch.nn as nn

class LayerNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta  = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model)
        mean = x.mean(dim=-1, keepdim=True)
        var  = x.var(dim=-1, keepdim=True, unbiased=False)
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * x_norm + self.beta

# PyTorch built-in (the first argument is normalized_shape; there is no d_model keyword):
layer_norm = nn.LayerNorm(512, eps=1e-5)

Post-LN vs Pre-LN

The original Transformer uses Post-LayerNorm (normalise after adding the residual):

Post-LN (Vaswani et al. 2017):
  output = LayerNorm(x + SubLayer(x))

Pre-LN (GPT-2 and later, LLaMA):
  output = x + SubLayer(LayerNorm(x))

Why Pre-LN dominates modern models:

Post-LN problem:
  The LayerNorm sits on the residual path, so every gradient must pass
  through it on the way back. At initialisation the gradients near the
  output layers are large, and deep stacks tend to diverge unless the
  learning rate is warmed up carefully.

Pre-LN advantage:
  The sublayer input LayerNorm(x) is normalised → well-scaled gradients
  The residual path x is a pure identity connection, so gradients flow through it unchanged
  Trains stably with little or no warmup, even at depth
  Often-cited trade-off: slight quality cost vs a well-tuned Post-LN model at convergence (marginal in practice)
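
The two placements differ by a single line of code. The sketch below is a minimal illustration, not a full transformer layer: the class names `PostLNBlock` and `PreLNBlock`, the `feed_forward` helper, and the sizes are all assumptions made for the example, with a small feed-forward network standing in for the sublayer.

```python
import torch
import torch.nn as nn


class PostLNBlock(nn.Module):
    """output = LayerNorm(x + SubLayer(x))"""

    def __init__(self, d_model: int, sublayer: nn.Module):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm(x + self.sublayer(x))


class PreLNBlock(nn.Module):
    """output = x + SubLayer(LayerNorm(x))"""

    def __init__(self, d_model: int, sublayer: nn.Module):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.sublayer(self.norm(x))


def feed_forward(d_model: int) -> nn.Module:
    # Hypothetical stand-in sublayer: a two-layer MLP.
    return nn.Sequential(
        nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
    )


torch.manual_seed(0)
x = 10.0 * torch.randn(2, 5, 16)  # (batch, seq_len, d_model), deliberately large scale

post_out = PostLNBlock(16, feed_forward(16))(x)
pre_out = PreLNBlock(16, feed_forward(16))(x)

print("Post-LN output std:", post_out.std().item())
print("Pre-LN output std: ", pre_out.std().item())
```

In Post-LN the block output itself is renormalised, so the residual stream is rescaled at every layer. In Pre-LN only the sublayer input is normalised and x passes through untouched, which is why the Pre-LN output keeps the scale of the input here.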

RMSNorm: Simplifying Layer Norm

LLaMA and many modern models use RMSNorm, which removes the mean-centering step:

RMSNorm(x) = x / RMS(x) · γ

where RMS(x) = √((1/d) Σᵢ xᵢ²)  (implementations add a small ε under the square root, as in the code below)

No β (shift) parameter — only γ (scale)
Python
class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x / rms) * self.gamma

RMSNorm is cheaper to compute (no mean subtraction and no β) and, in reported results, matches LayerNorm quality closely on large models.
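
To see exactly what dropping the mean-centering step changes, the sketch below compares the two normalisations without their learned parameters. The helper `rms_norm` and the input shapes are illustrative assumptions. When the input already has zero mean, RMS(x) equals the standard deviation, so the two coincide; on inputs with a non-zero mean they differ.

```python
import torch
import torch.nn.functional as F


def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """RMSNorm without the learned scale: divide by the root mean square."""
    return x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)


torch.manual_seed(0)
x = torch.randn(4, 64) + 3.0                   # tokens with a clearly non-zero mean
x_centred = x - x.mean(dim=-1, keepdim=True)   # same tokens, mean removed

ln_out = F.layer_norm(x, (64,), eps=1e-6)

# On centred input, RMSNorm reproduces LayerNorm (without γ and β).
matches_on_centred = torch.allclose(rms_norm(x_centred), ln_out, atol=1e-5)

# On the raw input, the missing mean subtraction gives a different result.
differs_on_raw = not torch.allclose(rms_norm(x), ln_out, atol=1e-2)

print("Equal on centred input:", matches_on_centred)
print("Different on raw input:", differs_on_raw)
```

RMSNorm therefore keeps the rescaling part of layer norm and gives up only the re-centering.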


Role in Transformer Training Stability

Without normalisation:
  Deep transformer (24+ layers): activations and gradients tend to explode or vanish
  Each layer's input distribution shifts as weights in earlier layers
    change (often called internal covariate shift) → slow, unstable training

With Pre-LN:
  Each sublayer input has per-token mean ≈ 0 and variance ≈ 1 (before γ and β are applied)
  Gradients stay well scaled through the identity residual path
  Models with 96, 128, even 200+ layers can train stably
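
A rough way to see the effect on scale is to stack residual layers at initialisation and measure how large the activations become. This is only a forward-pass illustration with randomly initialised linear layers standing in for sublayers, not a training experiment, and it says nothing directly about gradients. The helper `final_rms` and the sizes are assumptions chosen for the sketch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d_model, n_layers = 64, 32       # illustrative sizes
x0 = torch.randn(16, d_model)    # 16 tokens with roughly unit-scale features


def final_rms(use_pre_ln: bool) -> float:
    """Push x0 through n_layers residual layers and return the output RMS."""
    x = x0.clone()
    with torch.no_grad():
        for _ in range(n_layers):
            linear = nn.Linear(d_model, d_model)  # stand-in for a sublayer
            h = F.layer_norm(x, (d_model,)) if use_pre_ln else x
            x = x + linear(h)
    return x.pow(2).mean().sqrt().item()


rms_plain = final_rms(use_pre_ln=False)
rms_pre_ln = final_rms(use_pre_ln=True)

print("RMS after stack, no normalisation:", rms_plain)
print("RMS after stack, Pre-LN:          ", rms_pre_ln)
```

Without normalisation each layer's update is proportional to the current activation size, so the scale compounds from layer to layer. With Pre-LN every sublayer sees a unit-scale input, so each layer adds a contribution of similar size and the residual stream grows far more slowly.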

Interview Answer

"Layer norm normalises each token's representation across its feature dimension: it computes a per-token mean and variance over the d_model features, then applies a learned scale γ and shift β. Unlike batch norm, it's independent of batch size and handles variable-length sequences naturally. The original Transformer uses Post-LN (normalise after the residual addition); most modern models use Pre-LN (normalise before the sublayer) because gradients stay stable with far less reliance on warmup. LLaMA and Mistral use RMSNorm, which drops the mean-centering step for speed with comparable quality."