
Residual Connections: Why Skip Connections Help

What a Residual Connection Is

A residual connection (He et al., 2015, ResNet) adds the input of a sublayer directly to its output:

Without residual:   output = SubLayer(x)
With residual:      output = x + SubLayer(x)
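Here is a minimal NumPy sketch of the two formulas. It assumes a toy sublayer `tanh(x @ W)`, a hypothetical stand-in for attention or an FFN. The addition only works when the sublayer returns the same shape as its input. That is why every transformer sublayer maps `d_model` back to `d_model`.

```python
import numpy as np

def plain(sublayer):
    """Wrap a sublayer without a skip connection: output = SubLayer(x)."""
    return lambda x: sublayer(x)

def residual(sublayer):
    """Wrap a sublayer with a skip connection: output = x + SubLayer(x)."""
    def block(x):
        out = sublayer(x)
        # The skip path requires SubLayer to preserve the input shape.
        assert out.shape == x.shape, "residual needs matching shapes"
        return x + out
    return block

rng = np.random.default_rng(0)
d_model = 8
W = rng.normal(scale=0.1, size=(d_model, d_model))  # hypothetical toy weights
toy_sublayer = lambda x: np.tanh(x @ W)

x = rng.normal(size=(4, d_model))  # 4 tokens, each of width d_model
y_plain = plain(toy_sublayer)(x)
y_res = residual(toy_sublayer)(x)
```

The residual output differs from the plain output by exactly the input `x`.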

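In transformers, every sublayer (both attention and the FFN) uses a residual connection. The original Transformer (Vaswani et al., 2017) used Post-LN, output = LayerNorm(x + SubLayer(x)). Many modern models, including GPT-2 and GPT-3, instead use Pre-LN, which normalises each sublayer's input and leaves the skip path untouched: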

Encoder block (Pre-LN):
  a   = x + Attention(LayerNorm(x))
  out = a + FFN(LayerNorm(a))

The input x is preserved and added back after the transformation.
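Where LayerNorm sits decides whether the skip path stays a pure identity. The following NumPy sketch makes two simplifying assumptions: LayerNorm has no learnable scale or shift, and `attn`/`ffn` are hypothetical callables. With zeroed sublayers, the Pre-LN block returns `x` unchanged. The Post-LN ordering, `LayerNorm(x + SubLayer(x))`, still normalises the stream.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """LayerNorm over the last axis, without learnable gain/bias (simplification)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def pre_ln_block(x, attn, ffn):
    a = x + attn(layer_norm(x))    # residual 1: skip path bypasses the norm
    return a + ffn(layer_norm(a))  # residual 2

def post_ln_block(x, attn, ffn):
    a = layer_norm(x + attn(x))    # Post-LN: norm applied after the addition
    return layer_norm(a + ffn(a))

zero = lambda h: np.zeros_like(h)  # sublayers that "do nothing"
x = np.random.default_rng(1).normal(loc=3.0, scale=2.0, size=(4, 8))

pre = pre_ln_block(x, zero, zero)    # identical to x
post = post_ln_block(x, zero, zero)  # re-centred and re-scaled, not x
```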


The Vanishing Gradient Problem

Without residuals, gradient flow through a deep network is multiplicative:

∂L/∂x₀ = ∂L/∂xₙ · ∂xₙ/∂xₙ₋₁ · ... · ∂x₁/∂x₀

For n = 96 layers, if each Jacobian has norm below 1:
  Product shrinks exponentially → gradients near zero at layer 0
  Early layers barely learn
  (If the norms exceed 1, the same product explodes instead)
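The exponential shrinkage is easy to reproduce numerically. This sketch uses an idealised setup: each layer's Jacobian is a scaled random orthogonal matrix `scale * Q`, so the norm of the product is known exactly. Multiplying 96 of them with `scale = 0.9` gives a spectral norm of 0.9^96, roughly 4e-5. A scale of 1.1 shows the opposite failure, exploding gradients.

```python
import numpy as np

def random_orthogonal(d, rng):
    """Random orthogonal matrix via QR decomposition (all singular values = 1)."""
    q, _ = np.linalg.qr(rng.normal(size=(d, d)))
    return q

def chain_norm(scale, n_layers=96, d=16, seed=0):
    """Spectral norm of J_n ... J_1, where each J_i = scale * (random orthogonal)."""
    rng = np.random.default_rng(seed)
    product = np.eye(d)
    for _ in range(n_layers):
        product = (scale * random_orthogonal(d, rng)) @ product
    return np.linalg.norm(product, ord=2)

shrink = chain_norm(0.9)  # equals 0.9**96: gradients vanish
grow = chain_norm(1.1)    # equals 1.1**96: gradients explode
```

Real Jacobians are not orthogonal. The point is only that any per-layer factor consistently below (or above) 1 compounds geometrically with depth.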

With residuals:
  ∂L/∂xᵢ = ∂L/∂xᵢ₊₁ · (I + ∂SubLayer(xᵢ)/∂xᵢ)

The identity I creates an additive path. Expanding the product over all layers yields a term that carries ∂L/∂xₙ to every layer unchanged, so at least one gradient path escapes multiplicative decay.
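Here is a sketch comparing the two backprop chains under one assumption: small sublayer Jacobians `J_i = 0.05 * Q_i`, with `Q_i` random orthogonal. This is a hypothetical, idealised choice. The plain chain `J_n ... J_1` collapses to 0.05^96. Each residual factor `I + J_i` has singular values between 0.95 and 1.05, so the residual chain's smallest singular value can never fall below 0.95^96.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_layers, c = 16, 96, 0.05

plain = np.eye(d)  # accumulates J_n ... J_1
resid = np.eye(d)  # accumulates (I + J_n) ... (I + J_1)
for _ in range(n_layers):
    q, _ = np.linalg.qr(rng.normal(size=(d, d)))  # random orthogonal matrix
    J = c * q                                     # small sublayer Jacobian
    plain = J @ plain
    resid = (np.eye(d) + J) @ resid

# Smallest singular value = weakest direction a gradient can travel in
plain_min = np.linalg.svd(plain, compute_uv=False).min()
resid_min = np.linalg.svd(resid, compute_uv=False).min()
```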

The Identity Shortcut Interpretation

Residuals implement an important inductive bias: each layer only needs to learn the residual — the difference from the identity:

If SubLayer(x) = 0 (zero function):
  output = x + 0 = x  (identity — the layer is a no-op)

The model can "choose" to make a sublayer do nothing
by keeping its weights small. This is much easier to learn
than trying to learn the identity mapping from scratch.
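This sketch shows why the identity is "cheap" for a residual layer. It assumes a single hypothetical linear sublayer `x @ W`. With `W = 0` the residual block is exactly the identity. A plain linear layer would need `W = I` instead, and its zero-initialised version erases the input.

```python
import numpy as np

d = 8
x = np.random.default_rng(2).normal(size=(3, d))

def plain_layer(x, W):
    return x @ W      # must learn W = I to pass x through

def residual_layer(x, W):
    return x + x @ W  # W = 0 already passes x through

W_zero = np.zeros((d, d))
plain_out = plain_layer(x, W_zero)     # all zeros: input information lost
resid_out = residual_layer(x, W_zero)  # exactly x: a no-op layer

# Distance in weight space from a zero init to an identity mapping
dist_plain = np.linalg.norm(np.eye(d) - W_zero)  # Frobenius norm = sqrt(d)
dist_resid = np.linalg.norm(W_zero - W_zero)     # already there: 0
```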

Consequence: sublayers that aren't useful can settle near zero output, leaving the stream almost unchanged.
             Layers that matter learn strong non-identity functions.

This makes training more robust — adding extra layers doesn't necessarily hurt performance.


Gradient Flow Visualisation

Forward pass (depth 4):
  x → [Layer1] → h1 → [Layer2] → h2 → [Layer3] → h3 → [Layer4] → h4 → loss
  (Jk = Jacobian of Layer k's sublayer transformation)

Without residuals (backprop):
  ∂L/∂x = ∂L/∂h4 · J4 · J3 · J2 · J1   (product of Jacobians, may vanish)

With residuals (backprop):
  ∂L/∂x = ∂L/∂h4 · (I + J4) · (I + J3) · (I + J2) · (I + J1)
  Expanding: I + (J4 + J3 + J2 + J1) + (J4J3 + J4J2 + ...) + ... + J4J3J2J1

  The I term means there is ALWAYS a path that carries ∂L/∂h4 to x unchanged,
  regardless of how many layers exist.
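You can check the expansion directly. This sketch uses random 3×3 matrices as hypothetical stand-ins for J1 to J4. It builds one term per subset of layers, keeping the factors in backprop order (J4 before J3, and so on) because matrix products do not commute. It then confirms that the 16 terms reproduce the full product. The empty subset is the identity path.

```python
import itertools
import numpy as np

rng = np.random.default_rng(3)
d = 3
J = {k: 0.3 * rng.normal(size=(d, d)) for k in (1, 2, 3, 4)}  # stand-ins for J1..J4
I = np.eye(d)

# Direct product, outermost layer on the left: (I + J4)(I + J3)(I + J2)(I + J1)
direct = I.copy()
for k in (4, 3, 2, 1):
    direct = direct @ (I + J[k])

# Expansion: one term per subset of layers, factors ordered from J4 down to J1
expanded = np.zeros((d, d))
n_paths = 0
for r in range(5):
    for subset in itertools.combinations((4, 3, 2, 1), r):
        term = I.copy()  # empty subset -> identity path
        for k in subset:
            term = term @ J[k]
        expanded += term
        n_paths += 1
```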

Code Illustration

Python
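import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Generic Pre-LN wrapper: output = x + sublayer(LayerNorm(x))."""
    def __init__(self, d_model: int, sublayer: nn.Module):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        # Pre-LN style: only the sublayer's input is normalised; the skip path is untouched
        return x + self.sublayer(self.norm(x), **kwargs)

class FFN(nn.Module):
    """Position-wise feed-forward network: d_model -> d_ff -> d_model."""
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

# In practice, transformer blocks inline this:
class EncoderBlock(nn.Module):
    def __init__(self, d_model: int, num_heads: int, d_ff: int):
        super().__init__()
        # PyTorch's built-in attention; batch_first=True means (batch, seq, d_model)
        self.attn  = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.ff    = FFN(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor:
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, attn_mask=mask, need_weights=False)
        x = x + attn_out                # residual 1
        x = x + self.ff(self.norm2(x))  # residual 2
        return x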

Residuals Enable Depth Scaling

Without residuals (hypothetically):
  A 12-layer stack (BERT-base depth) is already difficult to train stably
  A 96-layer stack (GPT-3 depth) would be expected to fail entirely

With residuals + Pre-LN:
  GPT-3: 96 layers, 175B parameters — trains stably with AdamW
  PaLM:  118 layers, 540B parameters
  The residual connection is the infrastructure enabling modern scale

Residuals as Ensemble

An alternative interpretation: a deep residual network is like an implicit ensemble of shallow networks:

2-layer residual network:
  output = x + f1(x) + f2(x + f1(x))
  Expanded: x + f1(x) + f2(x) + f2(f1(x))  (exact only if f2 is linear; otherwise a rough heuristic)

This is like having paths of depth 0, 1, 1, 2 all contributing.
An ensemble over many depths is more robust to removing or damaging individual layers.
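This sketch makes two quick checks of the ensemble view. First, with linear sublayers (matrices `A1` and `A2`, a simplifying assumption) the expansion above is exact. Second, unrolling n residual blocks gives 2^n paths, and the number of paths of depth k is the binomial coefficient C(n, k). `path_depths` is a hypothetical helper written for illustration.

```python
import itertools
from math import comb
import numpy as np

rng = np.random.default_rng(4)
d = 5
A1, A2 = rng.normal(size=(d, d)), rng.normal(size=(d, d))
f1 = lambda v: A1 @ v  # linear sublayers (assumption)
f2 = lambda v: A2 @ v
x = rng.normal(size=d)

nested = x + f1(x) + f2(x + f1(x))        # 2-layer residual network
expanded = x + f1(x) + f2(x) + f2(f1(x))  # paths of depth 0, 1, 1, 2

def path_depths(n):
    """Depth of every path through n residual blocks (skip = 0, sublayer = 1)."""
    return sorted(sum(choice) for choice in itertools.product((0, 1), repeat=n))

depths_2 = path_depths(2)
depths_12 = path_depths(12)
```

Most of the 4,096 paths through a 12-block stack have intermediate depth, which is the sense in which a deep residual network behaves like a mix of shallower ones.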

Interview Answer

"Residual connections add the sublayer's input directly to its output: output = x + SubLayer(x). The key benefit is gradient flow. Without residuals, backpropagation multiplies Jacobians across all layers, which can make gradients vanish (or explode) exponentially with depth. With residuals, there is always an identity path that carries gradients directly to any layer. They also encode a useful inductive bias: each layer only needs to learn the residual (the deviation from identity), making it easy for layers to become no-ops if they aren't useful. Residuals are the structural prerequisite for training transformers with 24, 96, or even 118 layers."