Learnixo
AI Systems · Intermediate

Layer Normalization and Residual Connections

Pre-LN vs Post-LN transformer blocks; residual connections for gradient flow; RMSNorm in modern LLMs like LLaMA; code showing a complete Pre-LN transformer block.

Asma Hafeez Khan · May 15, 2026 · 7 min read
Transformers · LayerNorm · RMSNorm · Residual Connections · Training Stability

Why Normalisation Matters

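Deep networks are hard to train when the scale and distribution of each layer's inputs keep shifting as earlier layers update their weights (often called internal covariate shift). Later layers must constantly re-adapt, which slows convergence and can destabilise training. Normalisation layers counter this by keeping activations at a controlled scale.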

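Layer Normalisation (Ba et al. 2016) normalises across the feature dimension for each sample independently, unlike Batch Norm, which normalises each feature across the batch. Because its statistics never involve other samples, LayerNorm behaves identically at training and inference time and works with any batch size.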
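A small sketch of the two axes, assuming a toy `(batch=4, features=8)` tensor and turning off the learnable scale and shift (`elementwise_affine=False`, `affine=False`) so only the normalisation itself is compared:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 8)  # (batch=4, features=8)

# LayerNorm: one mean/variance per sample (row), computed over the 8 features.
ln = nn.LayerNorm(8, elementwise_affine=False)
# BatchNorm1d in training mode: one mean/variance per feature (column), over the 4 samples.
bn = nn.BatchNorm1d(8, affine=False)

ln_out = ln(x)
bn_out = bn(x)

print(ln_out.mean(dim=-1))  # close to 0 for every sample
print(bn_out.mean(dim=0))   # close to 0 for every feature

# A sample's LayerNorm output does not depend on the other samples in the batch.
print(torch.allclose(ln(x[:1]), ln_out[:1]))
```

The last line is the practical difference: normalising a single sample gives the same result as normalising it inside a batch, which is why LayerNorm suits variable-length, small-batch transformer workloads.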

Layer Norm Formula

Given a vector x of dimension d:

mu    = mean(x)                           # scalar
var   = mean((x - mu)^2)                  # scalar (biased variance)
x_hat = (x - mu) / sqrt(var + epsilon)    # normalised
y     = gamma * x_hat + beta              # affine transform (element-wise)

Here gamma and beta are learnable vectors of dimension d, initialised to 1 and 0 respectively. They let the model rescale and shift each feature after normalisation, or even undo it if that helps.
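To make the formula concrete, here is a hand-worked example on the vector [1, 2, 3, 4], using the standard form with epsilon inside the square root (as `nn.LayerNorm` does) and gamma and beta at their initial values:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
eps = 1e-5

mu = x.mean()                              # 2.5
var = ((x - mu) ** 2).mean()               # 1.25 (biased: divide by d, not d - 1)
x_hat = (x - mu) / torch.sqrt(var + eps)   # about [-1.342, -0.447, 0.447, 1.342]

gamma, beta = torch.ones(4), torch.zeros(4)  # initial values of the affine parameters
y = gamma * x_hat + beta

print(x_hat)
print(torch.allclose(y, F.layer_norm(x, (4,), gamma, beta, eps)))  # True
```

With gamma = 1 and beta = 0 the affine step is a no-op, so y is just the normalised vector with mean 0 and (biased) variance close to 1.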

Python
import torch
import torch.nn as nn
import math


class LayerNorm(nn.Module):
    """Manual LayerNorm for educational clarity."""

    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps  # same default as nn.LayerNorm
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta  = nn.Parameter(torch.zeros(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (..., d_model); normalise over the last dimension
        mu    = x.mean(dim=-1, keepdim=True)
        var   = x.var(dim=-1, keepdim=True, unbiased=False)  # biased variance, like nn.LayerNorm
        x_hat = (x - mu) / torch.sqrt(var + self.eps)        # eps inside the sqrt
        return self.gamma * x_hat + self.beta


# Compare with PyTorch built-in
ln_manual = LayerNorm(256)
ln_torch  = nn.LayerNorm(256)

x = torch.randn(4, 32, 256)
out_m = ln_manual(x)
out_t = ln_torch(x)
print("Manual:", out_m.shape)  # (4, 32, 256)
print("Max diff:", (out_m - out_t).abs().max().item())  # Very small (fp precision)

Residual Connections

A residual connection (He et al. 2015, originally from ResNets) adds the input of a sub-layer back to its output:

output = sublayer(x) + x

This creates a shortcut path for gradients to flow backward without passing through the sublayer. Without residuals, the backward pass multiplies one Jacobian per layer, so in a very deep plain network (say 100 layers) gradients tend to vanish or explode.

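Mathematical insight: if sublayer(x) outputs zero, the block reduces exactly to the identity, so a block can never do worse than passing its input through unchanged. The network therefore learns incremental updates to x rather than full transformations.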
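The same point follows from the chain rule. Writing block k as x_{k+1} = x_k + F_k(x_k), where F_k is the sublayer (notation introduced here only for this derivation), the Jacobian from an earlier block l to a later block L expands as:

```latex
\frac{\partial x_L}{\partial x_l}
  = \prod_{k=l}^{L-1} \left( I + \frac{\partial F_k}{\partial x_k} \right)
  = I + \sum_{k=l}^{L-1} \frac{\partial F_k}{\partial x_k} + (\text{higher-order products})
```

The identity term I survives no matter how small the sublayer Jacobians become. Without residuals the same quantity is the bare product of the sublayer Jacobians, which shrinks or grows geometrically with depth when those factors are consistently smaller or larger than 1 in norm.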

Python
def residual_block_example():
    """Illustrate gradient flow with and without residuals."""
    torch.manual_seed(0)
    d = 64
    n_layers = 20

    # Without residuals: the backward pass multiplies by one Jacobian J per layer.
    # With residuals: each factor becomes (I + J), so an identity path always survives.

    x = torch.randn(1, d, requires_grad=True)
    linear = nn.Linear(d, d)  # one layer reused at every depth, for simplicity

    # Without residual
    y = x
    for _ in range(n_layers):
        y = torch.tanh(linear(y))
    y.sum().backward()
    print("Without residual, grad norm:", x.grad.norm().item())

    x.grad = None  # reset before the second backward pass

    # With residual
    y = x
    for _ in range(n_layers):
        y = y + torch.tanh(linear(y))
    y.sum().backward()
    print("With residual, grad norm:", x.grad.norm().item())


residual_block_example()

Post-LN vs Pre-LN

Original Transformer: Post-LN

The "Attention Is All You Need" paper placed LayerNorm after the residual addition:

x = LayerNorm( x + Attention(x) )
x = LayerNorm( x + FFN(x) )

Problem: every gradient travelling along the residual path must pass through a LayerNorm at each layer, and at initialisation the gradients of the parameters near the output are large (Xiong et al. 2020). As depth grows, Post-LN training therefore tends to become unstable unless the learning rate is carefully warmed up.
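For reference, the original paper's warm-up schedule increases the learning rate linearly for `warmup_steps` steps and then decays it with the inverse square root of the step. A minimal sketch with `torch.optim.lr_scheduler.LambdaLR`, assuming the paper's base-model values (d_model = 512, warmup_steps = 4000, Adam with betas (0.9, 0.98) and eps 1e-9) and a throwaway `nn.Linear` as a stand-in model:

```python
import torch
import torch.nn as nn


def noam_lr(step: int, d_model: int = 512, warmup_steps: int = 4000) -> float:
    # lr = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
    step = max(step, 1)  # avoid 0 ** -0.5 when the scheduler is created
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


model = nn.Linear(512, 512)  # stand-in model
# base lr = 1.0, so the lambda's value becomes the learning rate directly
optimizer = torch.optim.Adam(model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=noam_lr)

lrs = []
for step in range(6000):
    optimizer.step()  # a real loop would compute a loss and call backward() first
    scheduler.step()
    lrs.append(optimizer.param_groups[0]["lr"])

peak = max(range(len(lrs)), key=lambda i: lrs[i])
print("peak at step", peak + 1, "lr =", lrs[peak])
```

The peak falls exactly at `warmup_steps`, where the two terms inside `min` are equal.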

Modern Transformer: Pre-LN

Pre-LN (used, for example, in the Sparse Transformer of Child et al. 2019 and in GPT-2 and its successors) places LayerNorm before each sublayer, inside the residual branch:

x = x + Attention(LayerNorm(x))
x = x + FFN(LayerNorm(x))

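The residual stream itself is never normalised: each sublayer reads a normalised copy of x and adds its output back. Gradients can therefore flow along the main path through additions alone, which makes training markedly more stable and can reduce or remove the need for warm-up. Because the stream's scale is no longer controlled, Pre-LN models such as GPT-2 add one final LayerNorm after the last block.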

Python
class PostLNBlock(nn.Module):
    """Original Transformer: LayerNorm after residual."""

    def __init__(self, d_model: int, n_heads: int, d_ff: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ff   = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        x = self.norm1(x + self.attn(x, x, x, need_weights=False)[0])
        x = self.norm2(x + self.ff(x))
        return x


class PreLNBlock(nn.Module):
    """Modern Transformer: LayerNorm before sublayer (Pre-LN)."""

    def __init__(self, d_model: int, n_heads: int, d_ff: int):
        super().__init__()
        self.attn  = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ff    = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Norm before: the residual stream stays un-normalised
        h = self.norm1(x)  # normalise once, reuse as query, key and value
        x = x + self.attn(h, h, h, need_weights=False)[0]
        x = x + self.ff(self.norm2(x))
        return x
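A minimal sketch of what "the residual stream stays un-normalised" means in practice. It assumes FFN-only sublayers (attention left out to keep it short) and a hypothetical zero initialisation of each FFN's output projection, so every sublayer starts out returning 0. Under that setup the Pre-LN stack is exactly the identity, while the Post-LN stack still forces every sample to zero mean and unit variance, because its LayerNorm sits on the main path:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, depth = 64, 12


def make_ffn() -> nn.Sequential:
    ffn = nn.Sequential(
        nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
    )
    # Hypothetical setup: zero the output projection so each sublayer starts by returning 0.
    nn.init.zeros_(ffn[2].weight)
    nn.init.zeros_(ffn[2].bias)
    return ffn


ffns = nn.ModuleList([make_ffn() for _ in range(depth)])
norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(depth)])


def post_ln_stack(x: torch.Tensor) -> torch.Tensor:
    for ffn, norm in zip(ffns, norms):
        x = norm(x + ffn(x))  # Post-LN: the residual stream itself is normalised
    return x


def pre_ln_stack(x: torch.Tensor) -> torch.Tensor:
    for ffn, norm in zip(ffns, norms):
        x = x + ffn(norm(x))  # Pre-LN: only the sublayer's input is normalised
    return x


x = 3.0 * torch.randn(8, d_model) + 1.0  # deliberately not zero-mean or unit-variance

with torch.no_grad():
    pre_out = pre_ln_stack(x)
    post_out = post_ln_stack(x)

print("Pre-LN leaves x unchanged:", torch.equal(pre_out, x))
print("Post-LN per-sample std:", post_out.std(dim=-1, unbiased=False).mean().item())
```

With PyTorch's default initialisation (as in the blocks above) the sublayers are not exactly zero, but the structural difference is the same: Pre-LN only ever adds to x, while Post-LN rescales it at every layer.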

RMSNorm: Simpler and Faster

RMSNorm (Zhang & Sennrich 2019) removes the mean-subtraction step entirely, only normalising by the root mean square of the activations:

rms(x) = sqrt( mean(x²) + epsilon )
y = (x / rms(x)) * gamma

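RMSNorm needs no beta (bias) parameter. Skipping the mean and the bias saves work: the original paper reports running-time reductions of roughly 7% to 64% depending on the model, although the actual gain depends on how well each implementation is fused. RMSNorm is used in LLaMA, Mistral, Qwen and many other recent LLMs.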
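One way to see exactly what RMSNorm gives up: on input that already has zero mean per row, mean(x²) equals the biased variance, so RMSNorm and LayerNorm (without beta) coincide. A small sketch, using a hypothetical functional helper `rms_norm` that mirrors the formula above:

```python
import torch
import torch.nn.functional as F


def rms_norm(x: torch.Tensor, gamma: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # y = x / sqrt(mean(x^2) + eps) * gamma: no mean subtraction and no beta
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * gamma


torch.manual_seed(0)
d = 8
gamma = torch.ones(d)
x = torch.randn(4, d)

# Centre each row first: then mean(x^2) equals the biased variance used by LayerNorm
x_centred = x - x.mean(dim=-1, keepdim=True)
same_when_centred = torch.allclose(
    rms_norm(x_centred, gamma), F.layer_norm(x_centred, (d,), eps=1e-6), atol=1e-5
)

# Shift every row away from zero mean: LayerNorm removes the shift, RMSNorm does not
x_shifted = x_centred + 2.0
same_when_shifted = torch.allclose(
    rms_norm(x_shifted, gamma), F.layer_norm(x_shifted, (d,), eps=1e-6), atol=1e-5
)

print("Equal on zero-mean input:", same_when_centred)  # True
print("Equal on shifted input:  ", same_when_shifted)  # False
```

In other words, RMSNorm keeps LayerNorm's re-scaling but drops its re-centring; in practice this has worked well in large models, which is what makes the cheaper operation attractive.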

Python
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization — used in LLaMA."""

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(d_model))

    def _rms(self, x: torch.Tensor) -> torch.Tensor:
        return x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x / self._rms(x) * self.gamma


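# Compare LayerNorm vs RMSNorm speed (forward pass only).
# Note: nn.LayerNorm runs as a single fused kernel, while this RMSNorm launches several
# separate ops (pow, mean, add, sqrt, div, mul), so the naive version may not win here.
# On a GPU, call torch.cuda.synchronize() before reading each timer.
import time

d_model = 4096
batch = 8
seq = 2048
x = torch.randn(batch, seq, d_model)

ln  = nn.LayerNorm(d_model)
rms = RMSNorm(d_model)

with torch.no_grad():  # skip autograd bookkeeping so only the normalisation is timed
    # Warmup
    for _ in range(10):
        _ = ln(x)
        _ = rms(x)

    N = 100
    t0 = time.perf_counter()
    for _ in range(N):
        _ = ln(x)
    t_ln = (time.perf_counter() - t0) / N * 1000

    t0 = time.perf_counter()
    for _ in range(N):
        _ = rms(x)
    t_rms = (time.perf_counter() - t0) / N * 1000

print(f"LayerNorm:  {t_ln:.2f} ms")
print(f"RMSNorm:    {t_rms:.2f} ms")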

Full Pre-LN Transformer Block with RMSNorm

Python
import torch.nn.functional as F


class LLaMABlock(nn.Module):
    """
    LLaMA-style transformer block:
    - Pre-LN with RMSNorm (no beta bias)
    - SwiGLU feed-forward (covered in feedforward lesson)
    - Causal self-attention (LLaMA also applies rotary position
      embeddings to Q and K; omitted here for brevity)
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, max_len: int = 4096):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_k = d_model // n_heads
        self.n_heads = n_heads

        # Attention
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)

        # SwiGLU feed-forward: gate and up projections, then a down projection
        self.gate_proj  = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj    = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj  = nn.Linear(d_ff, d_model, bias=False)

        # Pre-LN with RMSNorm
        self.attn_norm = RMSNorm(d_model)
        self.ff_norm   = RMSNorm(d_model)

        # Causal mask
        mask = torch.triu(torch.ones(max_len, max_len, dtype=torch.bool), diagonal=1)
        self.register_buffer("causal_mask", mask)

    def _split(self, t):
        B, T, _ = t.shape
        return t.view(B, T, self.n_heads, self.d_k).transpose(1, 2)

    def _merge(self, t):
        B, h, T, d_k = t.shape
        return t.transpose(1, 2).contiguous().view(B, T, h * d_k)

    def _attention(self, x: torch.Tensor) -> torch.Tensor:
        B, T, _ = x.shape
        Q = self._split(self.W_Q(x))
        K = self._split(self.W_K(x))
        V = self._split(self.W_V(x))
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        scores = scores.masked_fill(self.causal_mask[:T, :T], float('-inf'))
        attn = F.softmax(scores, dim=-1)
        return self.W_O(self._merge(torch.matmul(attn, V)))

    def _swiglu(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: down( silu(gate(x)) * up(x) )
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-LN attention with residual
        x = x + self._attention(self.attn_norm(x))
        # Pre-LN feed-forward with residual
        x = x + self._swiglu(self.ff_norm(x))
        return x


# Test the full block
block = LLaMABlock(d_model=256, n_heads=8, d_ff=688)
x = torch.randn(2, 16, 256)
out = block(x)
print("LLaMA block output:", out.shape)   # (2, 16, 256)
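The test above uses d_ff=688. A likely reason (an assumption; the article does not say) is the SwiGLU sizing from the LLaMA paper, which uses a hidden size of about (2/3)·4·d_model instead of 4·d_model so that three weight matrices cost about as much as the standard FFN's two: 3·d·(8d/3) = 8d² = 2·d·(4d). The sketch below assumes rounding up to a multiple of 16 and simply counts parameters:

```python
import math

import torch.nn as nn

d_model = 256

# Standard two-matrix FFN with the classic 4x expansion (bias-free, like the block above)
ffn = nn.Sequential(
    nn.Linear(d_model, 4 * d_model, bias=False),
    nn.GELU(),
    nn.Linear(4 * d_model, d_model, bias=False),
)

# Assumed sizing rule: (8/3) * d_model = 682.7, rounded up to a multiple of 16 -> 688
d_ff = math.ceil(8 * d_model / 3 / 16) * 16

# Three-matrix SwiGLU FFN with the same names as in LLaMABlock
swiglu = nn.ModuleDict({
    "gate_proj": nn.Linear(d_model, d_ff, bias=False),
    "up_proj": nn.Linear(d_model, d_ff, bias=False),
    "down_proj": nn.Linear(d_ff, d_model, bias=False),
})

n_ffn = sum(p.numel() for p in ffn.parameters())        # 2 * 256 * 1024
n_swiglu = sum(p.numel() for p in swiglu.parameters())  # 3 * 256 * 688
print(d_ff, n_ffn, n_swiglu, round(n_swiglu / n_ffn, 3))
```

The two feed-forward variants end up within about 1% of each other in parameter count, so swapping GELU for SwiGLU does not quietly enlarge the model.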

Comparison of Normalisation Strategies

| Strategy | Norm position | Formula | Used by |
|----------|---------------|---------|---------|
| Post-LN | After residual | norm(x + sub(x)) | Original Transformer |
| Pre-LN | Before sublayer | x + sub(norm(x)) | GPT-2, GPT-3 |
| Pre-LN + RMSNorm | Before sublayer | x + sub(rms_norm(x)) | LLaMA, Mistral |

Key Takeaways

  • LayerNorm normalises each sample over the feature dimension, stabilising activations across layers.
  • Residual connections give gradients a direct identity path backward, mitigating vanishing gradients in deep networks.
  • Pre-LN (norm before attention/FFN) is more stable for very deep networks than the original Post-LN.
  • RMSNorm drops mean subtraction and the beta parameter: cheaper to compute, and used in many modern LLMs.
  • The combination of Pre-LN + RMSNorm + residuals is the training-stability foundation of LLaMA and similar models.
