Weight Initialisation
Why weight initialisation matters for training, the Xavier and Kaiming schemes, what happens with bad initialisation, and PyTorch defaults.
Why Initialisation Matters
Training starts at the initial weights and uses gradient descent to improve. Bad starting weights can cause:
All-zeros initialisation:
  Every neuron computes the same output → same gradient → same update
  Symmetry is never broken → all neurons remain identical
  Network stays stuck → no feature diversity
Very large weights:
  Large pre-activations → sigmoid saturates at 0 or 1, tanh at -1 or +1
  Saturated neurons → near-zero gradients → vanishing gradients
  Backprop signal doesn't reach early layers
Very small weights:
  Activations shrink toward zero with each layer
  Deep networks → activations die out completely
  Again → vanishing gradient problem
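To make the symmetry argument concrete, here is a minimal pure-Python sketch. The tiny one-hidden-layer tanh network, the constant 0.5, and the helper name hidden_grads are illustrative assumptions, not part of the original text. It gives every weight the same value and computes each hidden neuron's gradient by hand.

```python
import math

def hidden_grads(x, w_hidden, v_out, target):
    """Gradient of the squared error w.r.t. each hidden neuron's weights.

    Network: h_j = tanh(w_j . x), y = sum_j v_j * h_j,
    loss = 0.5 * (y - target)^2.
    """
    h = [math.tanh(sum(wi * xi for wi, xi in zip(w, x))) for w in w_hidden]
    y = sum(vj * hj for vj, hj in zip(v_out, h))
    dy = y - target  # dLoss/dy
    grads = []
    for vj, hj in zip(v_out, h):
        dpre = dy * vj * (1.0 - hj * hj)  # backprop through tanh
        grads.append([dpre * xi for xi in x])
    return grads

x = [0.3, -1.2, 0.7]
n_hidden = 4

# Constant initialisation: every hidden neuron starts out identical.
const_w = [[0.5] * len(x) for _ in range(n_hidden)]
const_v = [0.5] * n_hidden
grads = hidden_grads(x, const_w, const_v, target=1.0)
for j, g in enumerate(grads):
    print(f"neuron {j}: grad = {[round(gi, 4) for gi in g]}")

# All-zeros initialisation: in this network the hidden gradients are all zero.
zero_w = [[0.0] * len(x) for _ in range(n_hidden)]
zero_v = [0.0] * n_hidden
zero_grads = hidden_grads(x, zero_w, zero_v, target=1.0)
print("all-zero init, hidden grads all zero:",
      all(gi == 0 for g in zero_grads for gi in g))
```

Every printed row is the same, so a gradient step moves all neurons by the same amount and they stay identical. Random initialisation is what breaks this tie.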
Goal: weights initialised so that activations and gradients
have reasonable magnitude throughout the network.
The Problem: Activation Variance
Consider a layer: z = W x + b
If we initialise W ~ N(0, 1) and x has variance σ²_x:
  Var(z_j) = n × σ²_W × σ²_x   (sum of n independent terms)
With n = 512 inputs: Var(z) = 512 × 1 × σ²_x = 512 × σ²_x
Variance EXPLODES as we go deeper:
  Layer 1:  σ² = 512 × σ²_x
  Layer 2:  σ² = 512² × σ²_x
  Layer 10: σ² = 512¹⁰ × σ²_x → activations blow up
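The single-layer formula can be checked empirically. The sketch below is an illustration under the same assumptions as the formula (zero-mean, independent weights and inputs); the function name preactivation_variance and the sample sizes are arbitrary choices. It estimates Var(z) for one linear unit and compares it with n × σ²_W × σ²_x. The second call uses the 1/√n scaling that the derivation just below arrives at.

```python
import random
import statistics

def preactivation_variance(n, sigma_w, sigma_x, trials=1000, seed=0):
    """Estimate Var(z) for z = sum_i w_i * x_i, drawing fresh w and x each trial."""
    rng = random.Random(seed)
    zs = []
    for _ in range(trials):
        z = sum(rng.gauss(0.0, sigma_w) * rng.gauss(0.0, sigma_x)
                for _ in range(n))
        zs.append(z)
    return statistics.pvariance(zs)

n = 512

# Unit-variance weights: the variance is multiplied by roughly n.
empirical = preactivation_variance(n, sigma_w=1.0, sigma_x=1.0)
print(f"sigma_W = 1:         Var(z) ~ {empirical:.1f}  (theory: {n})")

# Weights scaled by 1/sqrt(n): the variance stays close to Var(x) = 1.
scaled = preactivation_variance(n, sigma_w=n ** -0.5, sigma_x=1.0)
print(f"sigma_W = 1/sqrt(n): Var(z) ~ {scaled:.3f}  (theory: 1)")
```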
To keep variance constant through layers:
  We need: Var(z) ≈ Var(x)
  → n × σ²_W = 1
  → σ²_W = 1/n
  → σ_W = 1/√n
Xavier / Glorot Initialisation
Designed for tanh and sigmoid activations:
Xavier uniform:
  W ~ Uniform(-a, a)   where a = √(6 / (fan_in + fan_out))
Xavier normal:
  W ~ Normal(0, σ²)    where σ = √(2 / (fan_in + fan_out))
Balances both forward variance (fan_in) and backward gradient variance (fan_out).
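The two variants target the same variance: a uniform distribution on (-a, a) has variance a²/3, so a = √(6 / (fan_in + fan_out)) gives 2 / (fan_in + fan_out), exactly the normal variant's σ². A short sketch checking this numerically; the function names are illustrative and not a library API.

```python
import math
import random
import statistics

def xavier_uniform_bound(fan_in, fan_out):
    return math.sqrt(6.0 / (fan_in + fan_out))

def xavier_normal_std(fan_in, fan_out):
    return math.sqrt(2.0 / (fan_in + fan_out))

fan_in, fan_out = 512, 256
a = xavier_uniform_bound(fan_in, fan_out)
sigma = xavier_normal_std(fan_in, fan_out)

# Var(Uniform(-a, a)) = a^2 / 3, so a / sqrt(3) should equal sigma.
print(f"a = {a:.5f}, a / sqrt(3) = {a / math.sqrt(3):.5f}, sigma = {sigma:.5f}")

# Cross-check by sampling from both distributions.
rng = random.Random(0)
uniform_samples = [rng.uniform(-a, a) for _ in range(50_000)]
normal_samples = [rng.gauss(0.0, sigma) for _ in range(50_000)]
uniform_std = statistics.pstdev(uniform_samples)
normal_std = statistics.pstdev(normal_samples)
print(f"sampled std: uniform {uniform_std:.5f}, normal {normal_std:.5f}")
```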
When to use: tanh or sigmoid activations, transformers, MLPs with tanh
Kaiming / He Initialisation
Designed for ReLU and its variants:
  ReLU zeroes every negative pre-activation (roughly half of them), which halves the signal's second moment
  Need to compensate by using 2/fan_in instead of 1/fan_in
Kaiming normal:
  W ~ Normal(0, σ²)    where σ = √(2 / fan_in)
Kaiming uniform:
  W ~ Uniform(-a, a)   where a = √(6 / fan_in)
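The factor of 2 can be checked directly. For a zero-mean, symmetric pre-activation z, ReLU keeps only the positive half, so E[ReLU(z)²] = ½ E[z²]. The sketch below measures that ratio by sampling and then combines it with the 2/fan_in weight variance. The names and sample size are arbitrary choices for illustration.

```python
import math
import random

def relu(z):
    return z if z > 0.0 else 0.0

rng = random.Random(0)
zs = [rng.gauss(0.0, 1.0) for _ in range(100_000)]

# Second moment before and after ReLU.
second_moment_in = sum(z * z for z in zs) / len(zs)
second_moment_out = sum(relu(z) ** 2 for z in zs) / len(zs)
ratio = second_moment_out / second_moment_in
print(f"E[relu(z)^2] / E[z^2] ~ {ratio:.3f}")  # close to 0.5 for symmetric z

def kaiming_normal_std(fan_in):
    # Weight variance 2 / fan_in (instead of 1 / fan_in) offsets the halving.
    return math.sqrt(2.0 / fan_in)

fan_in = 512
# Per-layer change in second moment: fan_in * sigma_W^2 * (ReLU factor).
per_layer_gain = fan_in * kaiming_normal_std(fan_in) ** 2 * ratio
print(f"per-layer gain with Kaiming scaling ~ {per_layer_gain:.3f}")  # close to 1
```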
When to use: ReLU, LeakyReLU, ELU, GELU (most modern MLPs)
PyTorch Implementation
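Before the listing, one detail about the nn.Linear default is worth working out by hand. In recent PyTorch releases, nn.Linear initialises its weight with kaiming_uniform_ and a = √5. That detail comes from the PyTorch source rather than from this article, so treat it as an assumption to verify against your installed version. The pure-Python sketch below reproduces the arithmetic; the helper name default_linear_std is hypothetical.

```python
import math

def default_linear_std(fan_in, a=math.sqrt(5)):
    """Std and bound implied by Kaiming uniform with leaky-ReLU slope `a`.

    gain  = sqrt(2 / (1 + a^2))
    std   = gain / sqrt(fan_in)
    bound = sqrt(3) * std   (weights drawn from Uniform(-bound, bound))
    """
    gain = math.sqrt(2.0 / (1.0 + a * a))
    std = gain / math.sqrt(fan_in)
    bound = math.sqrt(3.0) * std
    return std, bound

std, bound = default_linear_std(512)
print(f"std = {std:.4f}, bound = {bound:.4f}, "
      f"1/sqrt(fan_in) = {1 / math.sqrt(512):.4f}")
# std = 0.0255, bound = 0.0442, 1/sqrt(fan_in) = 0.0442
```

With a = √5 the bound collapses to 1/√fan_in, so the default weight std for fan_in = 512 is about 0.0255, noticeably smaller than the √(2/512) = 0.0625 that Kaiming normal with the ReLU gain would give.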
import torch
import torch.nn as nn

# PyTorch defaults:
#   nn.Linear:    Kaiming uniform with a=sqrt(5), i.e. U(-1/sqrt(fan_in), 1/sqrt(fan_in))
#   nn.Conv2d:    same scheme as nn.Linear
#   nn.Embedding: Normal(0, 1)

# Verify PyTorch defaults
layer = nn.Linear(512, 256)
print(f"Default weight std: {layer.weight.data.std():.4f}")
# Expected with fan_in=512: std = 1/sqrt(3 * 512) ≈ 0.0255
# (Kaiming normal with the ReLU gain would give sqrt(2/512) = 0.0625 instead)

# Manual initialisation
def init_weights(module: nn.Module) -> None:
    if isinstance(module, nn.Linear):
        # The MLP below uses ReLU, so use Kaiming here; swap in
        # nn.init.xavier_normal_(module.weight) for tanh/sigmoid networks.
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)

class MLP(nn.Module):
    def __init__(self, d_in: int, d_hidden: int, d_out: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_in, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_out),
        )
        self.apply(init_weights)  # apply init to all submodules

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

model = MLP(128, 256, 10)

# Common schemes as formulas
def xavier_std(fan_in: int, fan_out: int) -> float:
    return (2 / (fan_in + fan_out)) ** 0.5

def kaiming_std(fan_in: int, nonlinearity: str = "relu") -> float:
    # Gains follow nn.init.calculate_gain; std = gain / sqrt(fan_in)
    gain = {"relu": 2 ** 0.5, "tanh": 5 / 3, "sigmoid": 1.0, "linear": 1.0}[nonlinearity]
    return gain / fan_in ** 0.5

print(f"Xavier std (512→256): {xavier_std(512, 256):.5f}")
print(f"Kaiming std (512, ReLU): {kaiming_std(512):.5f}")
Verifying Activation Health
def check_activation_stats(model: nn.Module, x: torch.Tensor) -> None:
    """Print activation statistics for one forward pass using forward hooks."""
    hooks = []

    def make_hook(name):
        def hook(module, inputs, output):
            print(f"{name}: mean={output.mean():.4f}, std={output.std():.4f}")
        return hook

    for name, module in model.named_modules():
        if isinstance(module, (nn.Linear, nn.ReLU, nn.BatchNorm1d)):
            hooks.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(x)
    finally:
        # Always detach the hooks, even if the forward pass fails
        for h in hooks:
            h.remove()

# Healthy activations: mean near 0, std near 1 (before non-linearities)
# ReLU output: mean > 0 (half of inputs were positive)
What Goes Wrong Without Proper Initialisation
Deep network (20 layers), tanh activation, N(0,1) init:
  Layer 1: pre-activation std ≈ √fan_in for unit-variance inputs (≈ 22.6 with fan_in = 512)
  Layer 2: tanh has already saturated, so its inputs are ≈ ±1 and the pre-activations stay just as large
  Layer 5: nearly all activations near +1 or -1 (saturated)
  Layer 10: same picture; many activations round to exactly ±1 in float32
  Gradients: tanh'(x) = 1 - tanh²(x) ≈ 0 when |tanh(x)| ≈ 1
  → vanishing gradient → no learning in early layers
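The saturation described above can be reproduced at small scale without any framework. The sketch below is a toy experiment, not a benchmark: the width of 64, depth of 5, batch of 16, and the 0.99 saturation threshold are arbitrary assumptions. It pushes random inputs through a stack of tanh layers and reports how many last-layer activations are saturated, once with N(0, 1) weights and once with the Xavier-scaled weights discussed next.

```python
import math
import random

def saturated_fraction(width, depth, weight_std, batch=16, seed=0):
    """Fraction of last-layer tanh outputs with |a| > 0.99."""
    rng = random.Random(seed)
    # Unit-variance random inputs, one list per example.
    acts = [[rng.gauss(0.0, 1.0) for _ in range(width)] for _ in range(batch)]
    for _ in range(depth):
        weights = [[rng.gauss(0.0, weight_std) for _ in range(width)]
                   for _ in range(width)]
        acts = [[math.tanh(sum(w * a for w, a in zip(row, example)))
                 for row in weights]
                for example in acts]
    flat = [a for example in acts for a in example]
    return sum(abs(a) > 0.99 for a in flat) / len(flat)

width, depth = 64, 5
naive = saturated_fraction(width, depth, weight_std=1.0)
xavier = saturated_fraction(width, depth,
                            weight_std=math.sqrt(2.0 / (width + width)))
print(f"N(0, 1) init: {naive:.1%} of activations saturated")
print(f"Xavier init:  {xavier:.1%} of activations saturated")
```

Where |tanh(x)| > 0.99 the local gradient 1 - tanh²(x) is below 0.02, so the saturated fraction is a rough proxy for how much of the backward signal is being blocked.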
With Xavier init:
  Each layer's pre-activations keep a variance of order 1
  tanh operates mostly in its near-linear region (around 0), where its gradient is close to 1
  Gradient flows through all layers
Interview Answer
"Weight initialisation sets the starting point for gradient descent. All-zeros fails because symmetry prevents neurons from differentiating. Random weights must be scaled to keep activation variance stable across layers. Xavier initialisation (σ = √(2/(fan_in + fan_out))) is designed for tanh/sigmoid: it balances forward activation variance and backward gradient variance. Kaiming/He initialisation (σ = √(2/fan_in)) compensates for ReLU zeroing half of its inputs by doubling the variance. PyTorch's nn.Linear defaults to a Kaiming-uniform variant that samples from U(-1/√fan_in, 1/√fan_in). In practice: use Kaiming for ReLU networks (most CNNs), Xavier for transformer attention layers and tanh networks, and always verify activation statistics in the first training step to catch initialisation problems early."