Residual Connections
Why residual (skip) connections are essential for deep transformers, how they solve the vanishing gradient problem, and what the identity shortcut provides architecturally.
What a Residual Connection Is
A residual connection (He et al., 2015, ResNet) adds the input of a sublayer directly to its output:
Without residual: output = SubLayer(x)
With residual: output = x + SubLayer(x)
In transformers, every sublayer (both attention and FFN) uses a residual connection:
Encoder block (Pre-LN):
a = x + Attention(LayerNorm(x))
out = a + FFN(LayerNorm(a))
The input x is preserved and added back after the transformation.
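Here is a framework-free sketch of the two Pre-LN equations above. It makes several simplifying assumptions:

- Plain Python lists stand in for tensors.
- LayerNorm has no learnable scale or shift.
- The sublayers (`toy_attention`, `toy_ffn`, `zero`) are hypothetical stand-ins that act on a single vector, so the sketch ignores how real attention mixes tokens.

With both sublayers set to zero, the block returns x unchanged. That is the skip path at work.

```python
import math


def layer_norm(v: list[float], eps: float = 1e-5) -> list[float]:
    """LayerNorm without learnable scale/shift: zero mean, unit variance."""
    mean = sum(v) / len(v)
    var = sum((t - mean) ** 2 for t in v) / len(v)
    return [(t - mean) / math.sqrt(var + eps) for t in v]


def add(u: list[float], v: list[float]) -> list[float]:
    return [a + b for a, b in zip(u, v)]


def pre_ln_block(x, attention, ffn):
    """a = x + Attention(LayerNorm(x)); out = a + FFN(LayerNorm(a))."""
    a = add(x, attention(layer_norm(x)))   # residual 1: x rides the skip path
    return add(a, ffn(layer_norm(a)))      # residual 2: a rides the skip path


# Hypothetical stand-in sublayers acting on a single vector
def toy_attention(v):
    return [0.1 * t for t in reversed(v)]


def toy_ffn(v):
    return [max(0.0, 0.5 * t) for t in v]


def zero(v):
    return [0.0] * len(v)


x = [1.0, -2.0, 0.5, 3.0]
print(pre_ln_block(x, toy_attention, toy_ffn))
print(pre_ln_block(x, zero, zero))  # [1.0, -2.0, 0.5, 3.0]
```

Normalisation is applied only to the sublayer input. The addition itself always uses the raw x (or a).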
The Vanishing Gradient Problem
Without residuals, gradient flow through a deep network is multiplicative:
∂L/∂x₀ = ∂L/∂xₙ · ∂xₙ/∂xₙ₋₁ · ... · ∂x₁/∂x₀
For n = 96 layers, if each Jacobian ∂xₖ/∂xₖ₋₁ has spectral norm < 1 (it shrinks every vector it multiplies):
Product shrinks exponentially → gradients near zero at layer 0
Early layers receive almost no learning signal and learn almost nothing
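The toy below puts a number on this. It is an assumed setup, not taken from the original. It backpropagates through a plain stack of scalar layers xₖ₊₁ = tanh(w·xₖ) with w = 0.9.

Each local derivative w·(1 − tanh²) is at most 0.9. After 96 layers the gradient is therefore at most 0.9⁹⁶ ≈ 4 × 10⁻⁵.

```python
import math


def plain_chain_forward(x0: float, w: float, depth: int) -> float:
    """Forward pass of the plain stack x_{k+1} = tanh(w * x_k)."""
    x = x0
    for _ in range(depth):
        x = math.tanh(w * x)
    return x


def plain_chain_gradient(x0: float, w: float, depth: int) -> float:
    """d x_depth / d x_0 via the chain rule: one local derivative per layer."""
    grad, x = 1.0, x0
    for _ in range(depth):
        y = math.tanh(w * x)
        grad *= w * (1.0 - y * y)  # local derivative of tanh(w * x), at most |w|
        x = y
    return grad


for depth in (1, 12, 96):
    print(depth, plain_chain_gradient(0.5, 0.9, depth))

# Sanity check of the chain-rule product against a central finite difference
h = 1e-6
fd = (plain_chain_forward(0.5 + h, 0.9, 3) - plain_chain_forward(0.5 - h, 0.9, 3)) / (2 * h)
print(fd, plain_chain_gradient(0.5, 0.9, 3))
```

The finite-difference comparison on a shallow stack confirms that the per-layer product is the true derivative.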
With residuals:
∂L/∂xᵢ = ∂L/∂xᵢ₊₁ · (I + ∂SubLayer(xᵢ)/∂xᵢ),  because xᵢ₊₁ = xᵢ + SubLayer(xᵢ)
The identity I creates an additive path → gradients flow directly
from the loss to any layer along the identity term, without multiplicative decay.
The Identity Shortcut Interpretation
Residuals implement an important inductive bias: each layer only needs to learn the residual, the difference from the identity:
If SubLayer(x) = 0 (zero function):
output = x + 0 = x (identity: the layer is a no-op)
The model can "choose" to make a sublayer do nothing
by keeping its weights small. This is much easier to learn
than trying to learn the identity mapping from scratch.
Consequence: layers that aren't useful can stay close to zero, acting as near no-ops.
Layers that matter learn strong non-identity functions.
This makes training more robust: adding extra layers doesn't necessarily hurt performance.
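Here is a minimal sketch of this argument. It assumes a toy two-layer ReLU sublayer W2·ReLU(W1·x) without biases, with plain Python lists and hypothetical weights.

- **Zero output projection, residual layer:** setting W2 to zero turns the residual layer into the identity, whatever W1 is.
- **Zero output projection, plain layer:** the same zero weights make a plain layer output all zeros.
- **Identity in a plain layer:** a plain layer can still represent the identity, but only at specific weights. One choice is W1 = [I; −I], W2 = [I, −I], using ReLU(t) − ReLU(−t) = t.

```python
def matvec(W, v):
    """Matrix (list of rows) times vector."""
    return [sum(w * t for w, t in zip(row, v)) for row in W]


def relu(v):
    return [max(0.0, t) for t in v]


def sublayer(x, W1, W2):
    # Toy two-layer ReLU sublayer without biases: W2 · relu(W1 · x)
    return matvec(W2, relu(matvec(W1, x)))


def residual_layer(x, W1, W2):
    # output = x + SubLayer(x)
    return [xi + fi for xi, fi in zip(x, sublayer(x, W1, W2))]


d, hidden = 3, 6
x = [0.5, -1.0, 2.0]
W1 = [[0.1 * (i + j) - 0.3 for j in range(d)] for i in range(hidden)]  # arbitrary values
W2_zero = [[0.0] * hidden for _ in range(d)]

# Residual layer with a zeroed output projection: identity, whatever W1 is
print(residual_layer(x, W1, W2_zero))  # [0.5, -1.0, 2.0]
# Plain layer with the same weights: the zero map, x is lost
print(sublayer(x, W1, W2_zero))        # [0.0, 0.0, 0.0]

# A plain ReLU layer hits the identity only at specific weights,
# via relu(t) - relu(-t) = t: W1 = [I; -I], W2 = [I, -I]
eye = [[1.0 if i == j else 0.0 for j in range(d)] for i in range(d)]
W1_id = eye + [[-v for v in row] for row in eye]
W2_id = [row + [-v for v in row] for row in eye]
print(sublayer(x, W1_id, W2_id))       # [0.5, -1.0, 2.0]
```

The residual form makes "do nothing" the easy target: any W2 near zero works. The plain form has to land on one particular weight configuration to achieve the same thing.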
Gradient Flow Visualisation
Forward pass (depth 4):
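x → [Layer1] → h1 → [Layer2] → h2 → [Layer3] → h3 → [Layer4] → h4 → loss
Without residuals (backprop), with h0 = x and Jₖ = ∂hₖ/∂hₖ₋₁:
∂L/∂x = ∂L/∂h4 · J4 · J3 · J2 · J1   (product of Jacobians, may vanish)
With residuals (backprop), hₖ = hₖ₋₁ + fₖ(hₖ₋₁), and Jₖ now denotes the sublayer Jacobian ∂fₖ/∂hₖ₋₁:
∂L/∂x = ∂L/∂h4 · (I + J4) · (I + J3) · (I + J2) · (I + J1)
Expanding (order matters, since matrices do not commute):
(I + J4)(I + J3)(I + J2)(I + J1) = I + (J1 + J2 + J3 + J4) + (J2J1 + J3J1 + ... + J4J3) + ... + J4J3J2J1
That is 2⁴ = 16 terms, one for each path through the four layers.
The I term means there is ALWAYS a path on which the gradient passes through unchanged,
regardless of how many layers exist.
Code Illustration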
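Before the PyTorch modules, here is a framework-free check of the gradient-flow argument. The toy is an assumed setting, not from the original:

- It pushes a gradient through 96 scalar sublayers fₖ(x) = tanh(wₖ·x).
- The weights are small and alternate in sign (±0.05).
- It runs once as a plain chain and once with skip connections, multiplying local derivatives exactly as in the expansion above.

```python
import math


def sublayer(x: float, w: float) -> float:
    # Toy scalar sublayer f(x) = tanh(w * x)
    return math.tanh(w * x)


def sublayer_grad(x: float, w: float) -> float:
    # Local derivative f'(x) = w * (1 - tanh(w * x)**2)
    y = math.tanh(w * x)
    return w * (1.0 - y * y)


def chain_gradient(x0: float, weights: list[float], residual: bool) -> float:
    """d x_n / d x_0 through one toy sublayer per weight.

    plain:    x_{k+1} = f_k(x_k)        -> multiply by f_k'(x_k)
    residual: x_{k+1} = x_k + f_k(x_k)  -> multiply by 1 + f_k'(x_k)
    """
    grad, x = 1.0, x0
    for w in weights:
        local = sublayer_grad(x, w)  # evaluated at the layer's input
        if residual:
            grad *= 1.0 + local
            x = x + sublayer(x, w)
        else:
            grad *= local
            x = sublayer(x, w)
    return grad


# 96 sublayers with small alternating-sign weights (an assumed toy setting)
weights = [0.05 if k % 2 == 0 else -0.05 for k in range(96)]
print("plain:   ", chain_gradient(1.0, weights, residual=False))
print("residual:", chain_gradient(1.0, weights, residual=True))
```

The two results differ sharply:

- **Plain chain:** the product is bounded in magnitude by 0.05⁹⁶.
- **Residual chain:** each factor is 1 plus a small term, so the product stays of order 1.

The PyTorch modules below express the same x + SubLayer(x) pattern on tensors.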
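```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, d_model: int, sublayer: nn.Module):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        # Pre-LN style: normalise only the sublayer input; the skip path carries x unchanged
        return x + self.sublayer(self.norm(x), **kwargs)

# Minimal stand-ins so the example runs (assumed; not defined in this article)
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.mha = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

    def forward(self, x, mask=None):
        out, _ = self.mha(x, x, x, attn_mask=mask, need_weights=False)  # self-attention
        return out

class FFN(nn.Sequential):
    def __init__(self, d_model, d_ff):
        super().__init__(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))

# In practice, transformer blocks inline this:
class EncoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ff = FFN(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        x = x + self.attn(self.norm1(x), mask=mask)  # residual 1
        x = x + self.ff(self.norm2(x))               # residual 2
        return x
```
Residuals Enable Depth Scaling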
Without residuals:
BERT-base (12 layers): already difficult to train stably
GPT-3 (96 layers): training would be expected to fail
With residuals + Pre-LN:
GPT-3: 96 layers, 175B parameters, trains stably with AdamW
PaLM 540B: 118 layers, trains successfully
The residual connection is the infrastructure enabling modern scale.
Residuals as Ensemble
An alternative interpretation: a deep residual network is like an implicit ensemble of shallow networks:
2-layer residual network:
output = x + f1(x) + f2(x + f1(x))
Expanded: x + f1(x) + f2(x) + f2(f1(x))   (exact if f2 is linear, approximate otherwise)
This is like having paths of depth 0, 1, 1, 2 all contributing.
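The path counting can be checked directly. The sketch below uses plain Python with hypothetical f1 and f2.

- **Path count:** each of n residual blocks is either skipped or taken, so there are 2ⁿ paths.
- **Depth distribution:** the number of paths of depth k is C(n, k). For 96 blocks the most common depth is 48.
- **Expansion check:** the last lines test the 2-layer expansion. It is exact for a linear f2 and only approximate for a near-linear one such as tanh at small inputs.

```python
from itertools import product
from math import comb, tanh


def path_depths(n_blocks: int) -> list[int]:
    """Depth of every path through n residual blocks, sorted.

    Each block is either skipped (identity branch) or taken (sublayer branch),
    so there are 2**n paths; a path's depth is the number of sublayers it uses.
    """
    return sorted(sum(choice) for choice in product((0, 1), repeat=n_blocks))


print(path_depths(2))  # [0, 1, 1, 2]

# For 96 blocks there are 2**96 paths; comb(96, k) of them have depth k
most_common_depth = max(range(97), key=lambda k: comb(96, k))
print(most_common_depth)  # 48


# Hypothetical sublayers for checking the 2-layer expansion
def f1(t):
    return 0.5 * t


def f2_linear(t):
    return 2.0 * t


x = 1.5
exact = x + f1(x) + f2_linear(x + f1(x))
expanded = x + f1(x) + f2_linear(x) + f2_linear(f1(x))
print(exact == expanded)  # True: exact when f2 is linear

# With a nonlinear f2 (tanh) the split is only approximate
print(tanh(0.1 + f1(0.1)), tanh(0.1) + tanh(f1(0.1)))
```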
Ensemble of many depths (n blocks give 2ⁿ paths) → more robust to the failure of individual layers.
Interview Answer
"Residual connections add the sublayer's input directly to its output: output = x + SubLayer(x). The key benefit is gradient flow: without residuals, backpropagation multiplies Jacobians across all layers, which can make gradients vanish exponentially with depth. With residuals, there is always an identity path that carries gradients directly to any layer. They also encode a useful inductive bias: each layer only needs to learn the residual (the deviation from identity), making it easy for layers to become no-ops if they aren't useful. Residuals are the structural prerequisite for training transformers with 24, 96, or even 118 layers."