AI Systems · intermediate

Weight Initialisation

Why weight initialisation matters for training, the Xavier and Kaiming schemes, what happens with bad initialisation, and PyTorch defaults.

Asma Hafeez Khan · May 21, 2026 · 5 min read
Deep Learning · Weight Initialisation · Xavier · Kaiming · Training Stability

Why Initialisation Matters

Training starts at the initial weights and uses gradient descent to improve. Bad starting weights can cause:

All-zeros initialisation:
  Every neuron computes the same output → same gradient → same update
  Symmetry is never broken → all neurons in a layer remain identical
  Network stays stuck: no feature diversity

Very large weights:
  Large pre-activations → sigmoid saturates at 0 or 1, tanh at -1 or +1
  Saturated neurons → near-zero gradients → vanishing gradients
  Backprop signal doesn't reach early layers

Very small weights:
  Activations shrink toward zero with each layer
  Deep networks → activations die out completely
  Again → vanishing gradient problem

Goal: weights initialised so that activations and gradients
have reasonable magnitude throughout the network.
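
The symmetry problem can be seen in a few lines of plain Python. The sketch below is an illustration under stated assumptions: the tiny tanh network, the helper name `hidden_grads`, and the constant value 0.1 are all made up for this example. It uses a non-zero constant rather than exact zeros because, with exact zeros, the hidden-layer gradients in this network are themselves zero; the constant case shows the same failure (identical neurons, identical updates) more visibly.

```python
import math
import random

def hidden_grads(W1, w2, x, target):
    """Gradient of 0.5*(y - target)^2 w.r.t. each row of W1 for a tiny tanh net."""
    z = [sum(w * xi for w, xi in zip(row, x)) for row in W1]   # hidden pre-activations
    h = [math.tanh(zj) for zj in z]                            # hidden activations
    y = sum(w2j * hj for w2j, hj in zip(w2, h))                # scalar output
    err = y - target
    # dL/dW1[j][i] = err * w2[j] * (1 - tanh(z_j)^2) * x[i]
    return [[err * w2j * (1 - hj * hj) * xi for xi in x] for w2j, hj in zip(w2, h)]

x, target, n_hidden = [0.5, -1.0, 2.0], 1.0, 4

# Constant init: every hidden unit starts with the same weights.
W1_const = [[0.1] * len(x) for _ in range(n_hidden)]
w2_const = [0.1] * n_hidden
g_const = hidden_grads(W1_const, w2_const, x, target)
all_rows_equal = all(row == g_const[0] for row in g_const)

# Random init: hidden units start different, so their gradients differ.
rng = random.Random(0)
W1_rand = [[rng.gauss(0, 0.5) for _ in x] for _ in range(n_hidden)]
w2_rand = [rng.gauss(0, 0.5) for _ in range(n_hidden)]
g_rand = hidden_grads(W1_rand, w2_rand, x, target)
rows_differ = any(row != g_rand[0] for row in g_rand)

print("constant init, identical gradients per neuron:", all_rows_equal)
print("random init, gradients differ per neuron:", rows_differ)
```

Because every hidden unit receives the same gradient under constant initialisation, a gradient step leaves them identical, and the layer behaves like a single neuron no matter how wide it is.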

The Problem: Activation Variance

Consider a layer: z = W x + b

If we initialise W ~ N(0, 1) and the inputs x_i are independent with zero mean and variance σ²_x:
  Var(z_j) = n × σ²_W × σ²_x     (sum of n independent terms)
  
  With n = 512 inputs: Var(z) = 512 × 1 × σ²_x = 512 × σ²_x
  
  Variance EXPLODES as we go deeper (ignoring non-linearities, every layer of width 512):
    Layer 1: σ² = 512 × σ²_x
    Layer 2: σ² = 512² × σ²_x
    Layer 10: σ² = 512¹⁰ × σ²_x    → activations blow up

To keep variance constant through layers:
  We need: Var(z) ≈ Var(x)
  → n × σ²_W = 1
  → σ²_W = 1/n
  → σ_W = 1/√n
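
The variance argument can be checked numerically. The following is a minimal sketch, assuming bias-free linear layers with no non-linearity (the same simplification the derivation uses); the function name `variance_after_layers` and the choice of width 100 and depth 5 are illustrative, not from the original.

```python
import random
import statistics

def variance_after_layers(n, depth, weight_std, seed=0):
    """Push one random input through `depth` bias-free linear layers of width n."""
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(n)]          # unit-variance input
    for _ in range(depth):
        # z_j = sum_i W_ji * x_i with W_ji ~ N(0, weight_std^2), sampled on the fly
        x = [sum(rng.gauss(0.0, weight_std) * xi for xi in x) for _ in range(n)]
    return statistics.pvariance(x)

n, depth = 100, 5
var_unit = variance_after_layers(n, depth, weight_std=1.0)
var_scaled = variance_after_layers(n, depth, weight_std=n ** -0.5)

print(f"sigma_W = 1         : variance after {depth} layers = {var_unit:.3e}")
print(f"sigma_W = 1/sqrt(n) : variance after {depth} layers = {var_scaled:.3f}")
```

With unit-variance weights the variance grows by roughly a factor of n per layer, while the 1/√n scaling keeps it on the order of the input variance.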

Xavier / Glorot Initialisation

Designed for tanh and sigmoid activations:

Xavier uniform:
  W ~ Uniform(-a, a) where a = √(6 / (fan_in + fan_out))

Xavier normal:
  W ~ Normal(0, σ²) where σ = √(2 / (fan_in + fan_out))

Balances both forward variance (fan_in) and backward gradient variance (fan_out).

When to use: tanh, sigmoid activations, transformers, MLPs with tanh
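
The uniform and normal variants above describe the same target variance, since Uniform(-a, a) has variance a²/3. A short sketch to confirm this, with illustrative helper names (`xavier_uniform_bound`, `xavier_normal_std`) and an arbitrary 512→256 layer shape:

```python
import math
import random
import statistics

def xavier_uniform_bound(fan_in, fan_out):
    return math.sqrt(6.0 / (fan_in + fan_out))

def xavier_normal_std(fan_in, fan_out):
    return math.sqrt(2.0 / (fan_in + fan_out))

fan_in, fan_out = 512, 256
a = xavier_uniform_bound(fan_in, fan_out)
sigma = xavier_normal_std(fan_in, fan_out)

# Uniform(-a, a) has variance a^2 / 3, which equals 2 / (fan_in + fan_out)
uniform_std = a / math.sqrt(3.0)

# Empirical check on samples drawn from Uniform(-a, a)
rng = random.Random(0)
samples = [rng.uniform(-a, a) for _ in range(50_000)]
empirical_std = statistics.pstdev(samples)

print(f"bound a = {a:.5f}, normal std = {sigma:.5f}")
print(f"std implied by Uniform(-a, a) = {uniform_std:.5f}, sampled = {empirical_std:.5f}")
```

So choosing between the two variants changes the shape of the weight distribution, not its scale.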

Kaiming / He Initialisation

Designed for ReLU and its variants:

ReLU zeroes out roughly half of the pre-activations (negative → 0), which halves the signal's second moment
Compensate by using variance 2/fan_in instead of 1/fan_in

Kaiming normal:
  W ~ Normal(0, σ²) where σ = √(2 / fan_in)

Kaiming uniform:
  W ~ Uniform(-a, a) where a = √(6 / fan_in)

When to use: ReLU, LeakyReLU, ELU, GELU (most modern MLPs)
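
The factor of 2 can be seen in a small simulation. This is a sketch under simplifying assumptions (bias-free layers of equal width, Gaussian weights sampled on the fly, a single random input); the function name `relu_signal_after` and the width and depth are chosen only for illustration. It tracks the mean squared pre-activation at the last layer of a ReLU stack.

```python
import random

def relu_signal_after(n, depth, weight_var, seed=0):
    """Mean squared pre-activation at the last of `depth` bias-free ReLU layers."""
    rng = random.Random(seed)
    std = weight_var ** 0.5
    x = [rng.gauss(0.0, 1.0) for _ in range(n)]
    mean_sq = 0.0
    for _ in range(depth):
        z = [sum(rng.gauss(0.0, std) * xi for xi in x) for _ in range(n)]
        mean_sq = sum(v * v for v in z) / n
        x = [max(0.0, v) for v in z]        # ReLU discards the negative half
    return mean_sq

n, depth = 128, 6
plain = relu_signal_after(n, depth, weight_var=1.0 / n)   # variance 1/fan_in
he = relu_signal_after(n, depth, weight_var=2.0 / n)      # variance 2/fan_in (Kaiming)

print(f"1/fan_in: mean z^2 at layer {depth} = {plain:.4f}")
print(f"2/fan_in: mean z^2 at layer {depth} = {he:.4f}")
```

With 1/fan_in the signal roughly halves at every ReLU layer, whereas 2/fan_in keeps it at a stable scale.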

PyTorch Implementation

Python
import torch
import torch.nn as nn
import numpy as np

# PyTorch defaults:
#   nn.Linear: kaiming_uniform_ with a=sqrt(5), i.e. Uniform(-1/sqrt(fan_in), 1/sqrt(fan_in))
#   nn.Conv2d: same scheme (kaiming_uniform_ with a=sqrt(5))
#   nn.Embedding: Normal(0, 1)

# Verify PyTorch defaults
layer = nn.Linear(512, 256)
print(f"Default weight std: {layer.weight.data.std():.4f}")
# Expected for the default with fan_in=512: std = 1/sqrt(3 * 512) ≈ 0.0255
# (Kaiming normal for ReLU would instead give sqrt(2/512) = 0.0625)

# Manual initialisation
def init_weights(module: nn.Module) -> None:
    if isinstance(module, nn.Linear):
        nn.init.xavier_normal_(module.weight)    # suits tanh/sigmoid; for ReLU layers prefer nn.init.kaiming_normal_
        nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)

class MLP(nn.Module):
    def __init__(self, d_in: int, d_hidden: int, d_out: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_in, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_out),
        )
        self.apply(init_weights)   # apply init to all submodules

model = MLP(128, 256, 10)


# Common schemes as formulas
def xavier_std(fan_in: int, fan_out: int) -> float:
    return (2 / (fan_in + fan_out)) ** 0.5

def kaiming_std(fan_in: int, nonlinearity: str = "relu") -> float:
    # Gains follow nn.init.calculate_gain; std = gain / sqrt(fan_in)
    gain = {"relu": 2 ** 0.5, "tanh": 5 / 3, "sigmoid": 1.0, "linear": 1.0}[nonlinearity]
    return gain / fan_in ** 0.5

print(f"Xavier std (512→256): {xavier_std(512, 256):.5f}")
print(f"Kaiming std (512, ReLU): {kaiming_std(512):.5f}")
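
To compare the `nn.Linear` default with an explicit ReLU-oriented initialisation, a quick check along these lines can help. It is a sketch that assumes a recent PyTorch release, where `nn.Linear` is reset with `kaiming_uniform_(a=sqrt(5))`; the variable names are illustrative.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
fan_in, fan_out = 512, 256
layer = nn.Linear(fan_in, fan_out)

# Default init: weights lie in [-1/sqrt(fan_in), 1/sqrt(fan_in)]
default_std = layer.weight.std().item()
expected_default = 1.0 / math.sqrt(3 * fan_in)     # std of that uniform distribution
max_abs = layer.weight.abs().max().item()

# Explicit re-initialisation for a ReLU network
nn.init.kaiming_normal_(layer.weight, mode="fan_in", nonlinearity="relu")
relu_std = layer.weight.std().item()
expected_relu = nn.init.calculate_gain("relu") / math.sqrt(fan_in)   # sqrt(2 / fan_in)

print(f"default std = {default_std:.4f} (expected {expected_default:.4f}), max |w| = {max_abs:.4f}")
print(f"kaiming_normal_ std = {relu_std:.4f} (expected {expected_relu:.4f})")
```

The default is noticeably smaller than the ReLU-specific Kaiming scale, which is one reason to initialise explicitly when training deep ReLU stacks without normalisation layers.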

Verifying Activation Health

Python
def check_activation_stats(model: nn.Module, x: torch.Tensor) -> None:
    """Hook to log activation statistics through the network."""
    hooks = []
    
    def make_hook(name):
        def hook(module, input, output):
            print(f"{name}: mean={output.mean():.4f}, std={output.std():.4f}")
        return hook
    
    for name, module in model.named_modules():
        if isinstance(module, (nn.Linear, nn.ReLU, nn.BatchNorm1d)):
            hooks.append(module.register_forward_hook(make_hook(name)))
    
    with torch.no_grad():
        model(x)
    
    for h in hooks:
        h.remove()

# Healthy activations: mean near 0, std near 1 (before non-linearities)
# ReLU output: mean > 0 (negative inputs are clamped to 0, positives pass through)
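
A variant that returns the statistics instead of printing them is easier to assert on in a test or log from a training script. The sketch below is one possible shape for that; the function name `collect_activation_stats`, the small `nn.Sequential` model, and the batch size are assumptions made for the example.

```python
import torch
import torch.nn as nn

def collect_activation_stats(model: nn.Module, x: torch.Tensor) -> dict:
    """Return {module_name: (mean, std)} of each Linear/ReLU output for one batch."""
    stats, hooks = {}, []

    def make_hook(name):
        def hook(module, inputs, output):
            stats[name] = (output.mean().item(), output.std().item())
        return hook

    for name, module in model.named_modules():
        if isinstance(module, (nn.Linear, nn.ReLU)):
            hooks.append(module.register_forward_hook(make_hook(name)))
    with torch.no_grad():
        model(x)
    for h in hooks:
        h.remove()          # always detach hooks so later forward passes are unaffected
    return stats

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10))
for m in net.modules():
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
        nn.init.zeros_(m.bias)

stats = collect_activation_stats(net, torch.randn(64, 128))
for name, (mean, std) in stats.items():
    print(f"layer {name}: mean={mean:.3f}, std={std:.3f}")
```

In a `nn.Sequential`, the module names are the positional indices, so the keys here are "0", "1" and "2".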

What Goes Wrong Without Proper Initialisation

Deep network (20 layers, width n), tanh activation, N(0,1) init:

Layer 1: pre-activation std ≈ √n for unit-variance inputs (≈ 22.6 for n = 512)
Layer 2: inputs are already near ±1, so pre-activations stay at std ≈ √n and tanh stays saturated
Layer 5: most activations near +1 or -1 (saturated)
Layer 10: activations still pinned near ±1
Gradients: tanh'(x) = 1 - tanh²(x) ≈ 0 when |tanh(x)| ≈ 1
→ vanishing gradient: no learning in early layers

With Xavier init:
  Each layer's activation variance stays at a moderate scale instead of exploding
  tanh operates mostly in its near-linear region (around 0), where its gradient is close to 1
  Gradient flows through all layers
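
The contrast between the two regimes can be reproduced with a small stdlib simulation. This is a rough sketch, assuming bias-free tanh layers of equal width with Gaussian weights sampled on the fly; the helper name `mean_tanh_grad` and the width of 64 are illustrative choices. It reports the average local derivative tanh'(z) at the last layer, which is the factor backpropagation multiplies by at that layer.

```python
import math
import random

def mean_tanh_grad(n, depth, weight_std, seed=0):
    """Average tanh'(z) = 1 - tanh(z)^2 over the units of the last layer."""
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(n)]
    grad = 1.0
    for _ in range(depth):
        z = [sum(rng.gauss(0.0, weight_std) * xi for xi in x) for _ in range(n)]
        x = [math.tanh(v) for v in z]
        grad = sum(1.0 - a * a for a in x) / n
    return grad

n, depth = 64, 20
grad_unit = mean_tanh_grad(n, depth, weight_std=1.0)                          # N(0, 1) init
grad_xavier = mean_tanh_grad(n, depth, weight_std=math.sqrt(2.0 / (n + n)))   # Xavier normal

print(f"N(0,1) init : mean tanh'(z) at layer {depth} = {grad_unit:.3f}")
print(f"Xavier init : mean tanh'(z) at layer {depth} = {grad_xavier:.3f}")
```

A small average derivative at every layer compounds multiplicatively during backpropagation, which is why the saturated network barely updates its early layers.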

Interview Answer

"Weight initialisation sets the starting point for gradient descent. All-zeros fails because symmetry prevents neurons from differentiating. Random weights must be scaled to keep activation variance stable across layers. Xavier initialisation (σ = √(2/(fan_in + fan_out))) is designed for tanh/sigmoid: it balances forward activation variance and backward gradient variance. Kaiming/He initialisation (σ = √(2/fan_in)) compensates for ReLU zeroing half of its inputs by doubling the variance. PyTorch's nn.Linear defaults to Kaiming uniform with a = √5, which works out to Uniform(-1/√fan_in, 1/√fan_in). In practice: use Kaiming for ReLU networks (most CNNs), Xavier for transformer attention layers and tanh networks, and always verify activation statistics in the first training step to catch initialisation problems early."
