Layer Normalization and Residual Connections
Pre-LN vs Post-LN transformer blocks; residual connections for gradient flow; RMSNorm in modern LLMs like LLaMA; code showing a complete Pre-LN transformer block.
Why Normalisation Matters
Deep networks suffer from internal covariate shift: the distribution of each layer's input changes as earlier layers update their weights. This forces later layers to constantly re-adapt, slowing convergence.
Layer Normalisation (Ba et al. 2016) normalises across the feature dimension for each sample independently, unlike Batch Norm which normalises across the batch.
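The sketch below makes the axis difference concrete. It applies PyTorch's `nn.LayerNorm` and `nn.BatchNorm1d`, both without learnable affine parameters, to the same `(batch, features)` matrix. It then checks which axis ends up with zero mean and unit variance. The shape and the shift/scale of the data are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 16) * 3 + 5   # (batch=8, features=16), deliberately off-centre

ln = nn.LayerNorm(16, elementwise_affine=False)  # statistics per row (per sample)
bn = nn.BatchNorm1d(16, affine=False)            # statistics per column (per feature), train mode

y_ln = ln(x)
y_bn = bn(x)

print(y_ln.mean(dim=1))  # ~0 for every sample
print(y_bn.mean(dim=0))  # ~0 for every feature
print(y_ln.mean(dim=0))  # generally not 0: LayerNorm never looks across the batch
```

Because LayerNorm's statistics come from a single sample, it behaves identically in training and inference. It also needs no running averages. BatchNorm needs both.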
Layer Norm Formula
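Given a vector x of dimension d:
mu = mean(x) # scalar
var = mean((x - mu)^2) # scalar, population variance
x_hat = (x - mu) / sqrt(var + epsilon) # normalised
y = gamma * x_hat + beta # affine transform
Here gamma and beta are learnable parameters initialised to 1 and 0 respectively. They allow the model to undo the normalisation if needed. The small constant epsilon sits inside the square root and guards against division by zero. Its default in PyTorch's nn.LayerNorm is 1e-5.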
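A worked example on a four-element vector traces each line of the formula with concrete numbers. The input values are arbitrary. The result is cross-checked against `torch.nn.functional.layer_norm`.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
eps = 1e-5

mu = x.mean()                                # 2.5
var = ((x - mu) ** 2).mean()                 # 1.25 (population variance: divide by d)
x_hat = (x - mu) / torch.sqrt(var + eps)     # ≈ [-1.342, -0.447, 0.447, 1.342]

gamma, beta = torch.ones(4), torch.zeros(4)  # initial values of the learnable parameters
y = gamma * x_hat + beta

print(y)
print(torch.allclose(y, F.layer_norm(x, (4,), eps=eps), atol=1e-6))  # True
```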
import torch
import torch.nn as nn
import math

class LayerNorm(nn.Module):
    """Manual LayerNorm for educational clarity."""
    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps                                  # same default as nn.LayerNorm
        self.gamma = nn.Parameter(torch.ones(d_model))  # scale, initialised to 1
        self.beta = nn.Parameter(torch.zeros(d_model))  # shift, initialised to 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (..., d_model); normalise over the last dimension
        mu = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)  # population variance, as nn.LayerNorm uses
        x_hat = (x - mu) / torch.sqrt(var + self.eps)       # epsilon inside the square root
        return self.gamma * x_hat + self.beta

# Compare with PyTorch built-in
ln_manual = LayerNorm(256)
ln_torch = nn.LayerNorm(256)
x = torch.randn(4, 32, 256)
out_m = ln_manual(x)
out_t = ln_torch(x)
print("Manual:", out_m.shape)  # torch.Size([4, 32, 256])
print("Max diff:", (out_m - out_t).abs().max().item())  # tiny (floating-point rounding only)
Residual Connections
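A residual connection (He et al. 2015, introduced with ResNets) adds the input of a sub-layer back to its output:
output = sublayer(x) + x
This creates a shortcut path that lets gradients flow backward without passing through the sublayer. The Jacobian of the block is I + J_sublayer, so the identity term carries the gradient even when J_sublayer is small. Without residuals, the gradient is a product of one Jacobian per layer. In a 100-layer network such a product tends to vanish or explode.
Mathematical insight: if sublayer(x) = 0, the block reduces to the identity and the output is still x. Each block therefore only has to learn an incremental update to its input rather than a full transformation.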
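The identity-plus-Jacobian claim can be checked directly with `torch.autograd.functional.jacobian`. The sketch uses a hypothetical one-matrix `tanh` sublayer (`make_residual_block` is an illustrative helper, not a library function). It also confirms that a zero sublayer turns the block into the identity. In a deep stack, the chain rule multiplies one factor (I + J_l) per layer. Expanding that product always leaves a plain identity term, which is why the gradient does not have to vanish.

```python
import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)
d = 4

def make_residual_block(W: torch.Tensor):
    """Hypothetical one-matrix sublayer wrapped in a residual connection."""
    def sublayer(x):
        return torch.tanh(W @ x)
    def block(x):
        return sublayer(x) + x   # output = sublayer(x) + x
    return sublayer, block

sublayer, block = make_residual_block(0.1 * torch.randn(d, d))
x = torch.randn(d)

J_sub = jacobian(sublayer, x)   # d sublayer / d x, shape (d, d)
J_block = jacobian(block, x)    # d block / d x, shape (d, d)
print(torch.allclose(J_block, torch.eye(d) + J_sub, atol=1e-6))  # True: J_block = I + J_sub

# A sublayer that outputs zero makes the block exactly the identity
_, zero_block = make_residual_block(torch.zeros(d, d))
print(torch.equal(zero_block(x), x))  # True
```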
def residual_block_example():
    """Illustrate gradient flow with and without residuals."""
    torch.manual_seed(0)
    d = 64
    n_layers = 20
    # Without residuals: the gradient is multiplied by one Jacobian per layer
    # With residuals: each layer contributes (I + Jacobian); the identity is the shortcut
    x = torch.randn(1, d, requires_grad=True)
    linear = nn.Linear(d, d)  # one layer reused at every depth to keep the comparison simple

    # Without residual
    y = x
    for _ in range(n_layers):
        y = torch.tanh(linear(y))
    y.sum().backward()
    print("Without residual, grad norm:", x.grad.norm().item())
    x.grad = None

    # With residual
    y = x
    for _ in range(n_layers):
        y = y + torch.tanh(linear(y))
    y.sum().backward()
    print("With residual, grad norm:", x.grad.norm().item())

residual_block_example()
Post-LN vs Pre-LN
Original Transformer: Post-LN
The "Attention Is All You Need" paper placed LayerNorm after the residual addition:
x = LayerNorm( x + Attention(x) )
x = LayerNorm( x + FFN(x) )
Problem: there is no clean identity path. Every gradient flowing back from the output passes through a LayerNorm at each layer. At initialisation, the gradients of the parameters near the output are also comparatively large. In deeper stacks this makes early training unstable, so the learning rate must be ramped up carefully with a warm-up schedule.
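For reference, "Attention Is All You Need" used a linear warm-up followed by inverse-square-root decay. The formula is lr = d_model^-0.5 · min(step^-0.5, step · warmup^-1.5), with warmup = 4000 and Adam(β1 = 0.9, β2 = 0.98, ε = 1e-9). The sketch below wires that formula into PyTorch's `LambdaLR`. The `nn.Linear` model is only a placeholder for a real Post-LN Transformer.

```python
import torch

def transformer_lr(step: int, d_model: int = 512, warmup: int = 4000) -> float:
    """lr = d_model^-0.5 * min(step^-0.5, step * warmup^-1.5).
    Rises linearly for `warmup` steps, then decays as 1/sqrt(step)."""
    step = max(step, 1)  # avoid 0 ** -0.5 on the very first call
    return d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

model = torch.nn.Linear(512, 512)  # placeholder standing in for a Post-LN Transformer
optimizer = torch.optim.Adam(model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9)
# With base lr = 1.0, LambdaLR sets the learning rate to exactly transformer_lr(step)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=transformer_lr)
# Training loop (sketch): loss.backward(); optimizer.step(); scheduler.step()

for step in (1, 1000, 4000, 16000):
    print(step, f"{transformer_lr(step):.2e}")
```

The peak learning rate is reached at step = warmup. At four times the warm-up length, the rate has decayed to half of that peak.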
Modern Transformer: Pre-LN
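Pre-LN (Child et al. 2019; popularised by GPT-2 and used widely since) places LayerNorm before the sublayer:
x = x + Attention(LayerNorm(x))
x = x + FFN(LayerNorm(x))
Only the sublayer inputs are normalised. The residual stream itself is never rescaled, so the chain of additions gives an uninterrupted identity path from the last layer back to the first. Gradients bypass the normalisation on that main path, which makes training significantly more stable and less dependent on warm-up. Because the stream is never normalised inside the blocks, Pre-LN models such as GPT-2 apply one final LayerNorm after the last block.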
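The difference in where the residual stream gets normalised is easy to observe numerically. This sketch stacks 24 blocks in both orderings and prints the root-mean-square size of the stream at the end. It uses small randomly initialised MLPs as hypothetical stand-ins for attention/FFN, and the sizes are arbitrary assumptions. Post-LN pins the stream to unit scale after every block. In Pre-LN the stream accumulates the sublayer outputs unnormalised, which is why a final norm is applied once at the end.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d, depth = 64, 24

# Hypothetical stand-in sublayers: small MLPs in place of attention/FFN
sublayers = nn.ModuleList(
    [nn.Sequential(nn.Linear(d, d), nn.GELU(), nn.Linear(d, d)) for _ in range(depth)]
)
norms = nn.ModuleList([nn.LayerNorm(d) for _ in range(depth)])

def rms(t: torch.Tensor) -> float:
    # Root-mean-square over all entries
    return t.pow(2).mean().sqrt().item()

x0 = torch.randn(32, d)
with torch.no_grad():
    post, pre = x0, x0
    for f, norm in zip(sublayers, norms):
        post = norm(post + f(post))  # Post-LN: x = LayerNorm(x + sublayer(x))
        pre = pre + f(norm(pre))     # Pre-LN:  x = x + sublayer(LayerNorm(x))
    final = nn.LayerNorm(d)(pre)     # single final norm after the Pre-LN stack

print(f"input RMS:               {rms(x0):.2f}")
print(f"Post-LN stream RMS:      {rms(post):.2f}")  # held at ~1 by the last LayerNorm
print(f"Pre-LN stream RMS:       {rms(pre):.2f}")   # grows as sublayer outputs accumulate
print(f"Pre-LN after final norm: {rms(final):.2f}")
```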
class PostLNBlock(nn.Module):
    """Original Transformer: LayerNorm after the residual addition."""
    def __init__(self, d_model: int, n_heads: int, d_ff: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Norm after: the residual stream is re-normalised after every sublayer
        x = self.norm1(x + self.attn(x, x, x, need_weights=False)[0])
        x = self.norm2(x + self.ff(x))
        return x

class PreLNBlock(nn.Module):
    """Modern Transformer: LayerNorm before each sublayer (Pre-LN)."""
    def __init__(self, d_model: int, n_heads: int, d_ff: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Norm before: only the sublayer input is normalised; the residual stream is not
        h = self.norm1(x)  # compute once and reuse for query, key and value
        x = x + self.attn(h, h, h, need_weights=False)[0]
        x = x + self.ff(self.norm2(x))
        return x
RMSNorm: Simpler and Faster
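RMSNorm (Zhang & Sennrich 2019) removes the mean-subtraction step entirely and normalises only by the root mean square of the activations:
rms(x) = sqrt( mean(x²) + epsilon )
y = (x / rms(x)) * gamma
No beta (bias) parameter is needed. Skipping the mean and the bias saves computation, so a fused RMSNorm is cheaper than LayerNorm while reaching comparable quality. The exact speed-up depends on the model, hardware and implementation. RMSNorm is used in LLaMA, Mistral, Qwen and many other modern LLMs.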
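A functional sketch of the formula makes the relationship to LayerNorm explicit. The helper `rms_norm` is illustrative only. On an input that already has zero mean per row, RMSNorm and LayerNorm without affine parameters coincide. On an off-centre input they differ, because LayerNorm also removes the mean.

```python
import torch
import torch.nn.functional as F

def rms_norm(x: torch.Tensor, gamma: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # rms(x) = sqrt(mean(x^2) + eps) over the last dimension; y = x / rms(x) * gamma
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x / rms * gamma

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
y = rms_norm(x, torch.ones(4))
# mean(x^2) = (1 + 4 + 9 + 16) / 4 = 7.5, so rms ≈ 2.7386
print(y)  # ≈ [0.3651, 0.7303, 1.0954, 1.4606]

# Zero-mean rows: RMSNorm matches LayerNorm (no affine)
torch.manual_seed(0)
z = torch.randn(16, 32)
z = z - z.mean(dim=-1, keepdim=True)
print(torch.allclose(rms_norm(z, torch.ones(32)), F.layer_norm(z, (32,), eps=1e-6), atol=1e-5))  # True

# Off-centre rows: LayerNorm also subtracts the mean, RMSNorm does not
w = torch.randn(16, 32) + 3.0
print(torch.allclose(rms_norm(w, torch.ones(32)), F.layer_norm(w, (32,), eps=1e-6), atol=1e-5))  # False
```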
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization, as used in LLaMA."""
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(d_model))  # scale only; no beta

    def _rms(self, x: torch.Tensor) -> torch.Tensor:
        # sqrt(mean(x^2) + eps) over the last dimension
        return x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x / self._rms(x) * self.gamma
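# Compare LayerNorm vs RMSNorm speed (forward pass only)
import time
d_model = 4096
batch = 8
seq = 2048
x = torch.randn(batch, seq, d_model)
ln = nn.LayerNorm(d_model)
rms = RMSNorm(d_model)
N = 100
with torch.no_grad():  # time the forward pass without building autograd graphs
    # Warmup
    for _ in range(10):
        _ = ln(x)
        _ = rms(x)
    t0 = time.perf_counter()
    for _ in range(N):
        _ = ln(x)
    t_ln = (time.perf_counter() - t0) / N * 1000
    t0 = time.perf_counter()
    for _ in range(N):
        _ = rms(x)
    t_rms = (time.perf_counter() - t0) / N * 1000
print(f"LayerNorm: {t_ln:.2f} ms")
print(f"RMSNorm: {t_rms:.2f} ms")
# Caveat: nn.LayerNorm runs as a single fused kernel, whereas this RMSNorm is several
# separate elementwise ops, so in eager mode it can come out slower. The RMSNorm
# advantage appears in fused implementations. On a GPU, call torch.cuda.synchronize()
# before reading the clock, since CUDA kernels run asynchronously.
Full Pre-LN Transformer Block with RMSNorm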
import torch.nn.functional as F

class LLaMABlock(nn.Module):
    """
    LLaMA-style transformer block:
    - Pre-LN with RMSNorm (no beta bias)
    - SwiGLU feed-forward (covered in feedforward lesson)
    - Causal self-attention
    Rotary position embeddings, which LLaMA applies to Q and K, are omitted here.
    """
    def __init__(self, d_model: int, n_heads: int, d_ff: int, max_len: int = 4096):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_k = d_model // n_heads
        self.n_heads = n_heads
        # Attention projections (no biases, as in LLaMA)
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)
        # SwiGLU feed-forward: gate, up and down projections
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)
        # Pre-LN with RMSNorm
        self.attn_norm = RMSNorm(d_model)
        self.ff_norm = RMSNorm(d_model)
        # Causal mask: True above the diagonal marks future positions to hide
        mask = torch.triu(torch.ones(max_len, max_len, dtype=torch.bool), diagonal=1)
        self.register_buffer("causal_mask", mask)

    def _split(self, t):
        # (B, T, d_model) -> (B, n_heads, T, d_k)
        B, T, _ = t.shape
        return t.view(B, T, self.n_heads, self.d_k).transpose(1, 2)

    def _merge(self, t):
        # (B, n_heads, T, d_k) -> (B, T, d_model)
        B, h, T, d_k = t.shape
        return t.transpose(1, 2).contiguous().view(B, T, h * d_k)

    def _attention(self, x: torch.Tensor) -> torch.Tensor:
        B, T, _ = x.shape
        Q = self._split(self.W_Q(x))
        K = self._split(self.W_K(x))
        V = self._split(self.W_V(x))
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        scores = scores.masked_fill(self.causal_mask[:T, :T], float('-inf'))
        attn = F.softmax(scores, dim=-1)
        return self.W_O(self._merge(torch.matmul(attn, V)))

    def _swiglu(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: down(silu(gate(x)) * up(x))
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-LN attention with residual
        x = x + self._attention(self.attn_norm(x))
        # Pre-LN feed-forward with residual
        x = x + self._swiglu(self.ff_norm(x))
        return x
# Test the full block
block = LLaMABlock(d_model=256, n_heads=8, d_ff=688)
x = torch.randn(2, 16, 256)
out = block(x)
print("LLaMA block output:", out.shape)  # torch.Size([2, 16, 256])
Comparison of Normalisation Strategies
| Strategy | Norm position | Formula | Used by |
|----------|--------------|---------|---------|
| Post-LN | After residual | norm(x + sub(x)) | Original Transformer |
| Pre-LN | Before sublayer | x + sub(norm(x)) | GPT-2, GPT-3 |
| Pre-LN + RMSNorm | Before sublayer | x + sub(rms_norm(x)) | LLaMA, Mistral |
Key Takeaways
- LayerNorm normalises each sample over the feature dimension, stabilising activations across layers.
- Residual connections give gradients a direct path backward, mitigating vanishing gradients in deep networks.
- Pre-LN (norm before attention/FFN) is more stable for very deep networks than the original Post-LN.
- RMSNorm drops mean subtraction and the beta parameter, making it cheaper to compute; it is used in many modern LLMs.
- The combination of Pre-LN + RMSNorm + residuals is the training-stability foundation of LLaMA and similar models.