Deep Learning for AI Interviews · Lesson 8 of 56
Weight Initialization: Why It Matters
Why Initialisation Matters
Training starts at the initial weights and uses gradient descent to improve. Bad starting weights can cause:
All-zeros initialisation:
Every neuron computes the same output → same gradient → same update
Symmetry breaking doesn't happen → all neurons remain identical
Network stays stuck — no feature diversity
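The symmetry argument above can be checked by hand on a tiny network. The following is a minimal sketch, assuming a two-layer tanh MLP without biases, a mean-squared-error loss, and NumPy; the helper name `hidden_grads` and the toy shapes are illustrative choices, not part of the lesson.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 4))      # 8 samples, 4 features (toy data)
y = rng.normal(size=(8, 1))      # regression targets (toy data)

def hidden_grads(W1, W2):
    """Gradient of 0.5 * mean squared error w.r.t. W1 for a bias-free tanh MLP."""
    h = np.tanh(x @ W1)                  # hidden activations, shape (8, hidden)
    err = h @ W2 - y                     # prediction error, shape (8, 1)
    dh = (err @ W2.T) * (1 - h ** 2)     # backprop through the tanh
    return x.T @ dh / len(x)             # shape (4, hidden): one column per hidden unit

# Constant init: every hidden unit starts identical
g_const = hidden_grads(np.full((4, 3), 0.5), np.full((3, 1), 0.5))
# Zero init: the backward signal through W2 is zero, so W1 gets no gradient at all
g_zero = hidden_grads(np.zeros((4, 3)), np.zeros((3, 1)))
# Random init: hidden units receive different gradients
g_rand = hidden_grads(rng.normal(size=(4, 3)), rng.normal(size=(3, 1)))

print("constant init, units share one gradient:", np.allclose(g_const[:, 0], g_const[:, 1]))
print("zero init, gradient is all zeros:", bool(np.all(g_zero == 0)))
print("random init, units share one gradient:", np.allclose(g_rand[:, 0], g_rand[:, 1]))
```

Each column of the returned matrix is the gradient for one hidden unit, so identical columns mean the units stay identical after the update.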
Very large weights:
Large activations → sigmoid saturates at 0 or 1, tanh at -1 or +1
Saturated neurons → near-zero gradients → vanishing gradients
Backprop signal doesn't reach early layers
Very small weights:
Activations shrink toward zero with each layer
Deep networks → activations die out completely
Again → vanishing gradient problem
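Both failure modes can be seen in a quick forward pass. This is a rough sketch, assuming a 10-layer tanh network of width 256 with unit-variance Gaussian inputs; the function name `tanh_activation_stds` and the two weight scales (0.01 and 1.0) are arbitrary picks for illustration.

```python
import numpy as np

def tanh_activation_stds(weight_std, depth=10, width=256, seed=0):
    """Std of the tanh activations after each layer for a given weight std."""
    rng = np.random.default_rng(seed)
    h = rng.normal(size=(512, width))          # batch of unit-variance inputs
    stds = []
    for _ in range(depth):
        W = rng.normal(scale=weight_std, size=(width, width))
        h = np.tanh(h @ W)
        stds.append(float(h.std()))
    return stds

small = tanh_activation_stds(0.01)   # weights far too small: signal shrinks layer by layer
large = tanh_activation_stds(1.0)    # weights far too large: tanh is pushed to its limits

print("small weights:", [f"{s:.2e}" for s in small[::3]])
print("large weights:", [f"{s:.2f}" for s in large[::3]])
```

With small weights the activation std drops at every layer; with large weights it sits close to 1 because almost every unit is pinned near -1 or +1.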
Goal: weights initialised so that activations and gradients have reasonable magnitude throughout the network.
The Problem: Activation Variance
Consider a layer: z = W x + b
If we initialise W ~ N(0, 1) and the inputs x_i are zero-mean with variance σ²_x:
Var(z_j) = n × σ²_W × σ²_x (sum of n independent, zero-mean terms W_ji × x_i)
With n = 512 inputs: Var(z) = 512 × 1 × σ²_x = 512 × σ²_x
Ignoring the non-linearity, the variance explodes as we go deeper:
Layer 1: σ² = 512 × σ²_x
Layer 2: σ² = 512² × σ²_x
Layer 10: σ² = 512¹⁰ × σ²_x → activations blow up
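The factor-of-n growth per layer is easy to confirm numerically. This is a small sketch, assuming purely linear layers (no bias, no activation), n = 512, and only 4 layers to keep it fast; `linear_variances` is a hypothetical helper name.

```python
import numpy as np

def linear_variances(n=512, depth=4, seed=0):
    """Variance of z = W x after each of `depth` linear layers with W ~ N(0, 1)."""
    rng = np.random.default_rng(seed)
    h = rng.normal(size=(256, n))              # zero-mean, unit-variance inputs
    variances = []
    for _ in range(depth):
        W = rng.normal(size=(n, n))            # sigma_W = 1
        h = h @ W                              # no bias, no non-linearity
        variances.append(float(h.var()))
    return variances

v = linear_variances()
# Growth factor contributed by each layer (the first entry is relative to Var(x) = 1)
ratios = [v[0]] + [b / a for a, b in zip(v, v[1:])]
print("per-layer growth factors:", [f"{r:.0f}" for r in ratios])
```

Each growth factor should land close to n = 512, matching Var(z) = n × σ²_W × σ²_x with σ²_W = 1.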
To keep variance constant through layers:
We need: Var(z) ≈ Var(x)
→ n × σ²_W = 1
→ σ²_W = 1/n
→ σ_W = 1/√n
Xavier / Glorot Initialisation
Designed for tanh and sigmoid activations:
Xavier uniform:
W ~ Uniform(-a, a) where a = √(6 / (fan_in + fan_out))
Xavier normal:
W ~ Normal(0, σ²) where σ = √(2 / (fan_in + fan_out))
Balances both forward variance (fan_in) and backward gradient variance (fan_out).
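The two Xavier variants above describe the same target variance. The following standard-library sketch checks that for one example layer; the function names and the 512→256 layer size are illustrative assumptions.

```python
import math

def xavier_uniform_bound(fan_in: int, fan_out: int) -> float:
    """Bound a for W ~ Uniform(-a, a)."""
    return math.sqrt(6.0 / (fan_in + fan_out))

def xavier_normal_std(fan_in: int, fan_out: int) -> float:
    """Std sigma for W ~ Normal(0, sigma^2)."""
    return math.sqrt(2.0 / (fan_in + fan_out))

fan_in, fan_out = 512, 256
a = xavier_uniform_bound(fan_in, fan_out)
sigma = xavier_normal_std(fan_in, fan_out)

# Uniform(-a, a) has variance a^2 / 3, which is why the uniform form uses 6 instead of 2
print(f"a = {a:.5f}, a^2/3 = {a * a / 3:.6f}, sigma^2 = {sigma ** 2:.6f}")
# The shared variance 2/(fan_in + fan_out) is the harmonic mean of 1/fan_in and 1/fan_out
print(f"1/fan_in = {1 / fan_in:.6f}, 1/fan_out = {1 / fan_out:.6f}")
```

The variance sits between the forward-pass ideal (1/fan_in) and the backward-pass ideal (1/fan_out), which is the compromise the formula encodes.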
When to use: tanh, sigmoid activations, transformers, MLPs with tanh
Kaiming / He Initialisation
Designed for ReLU and its variants:
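ReLU zeroes every negative pre-activation, which for zero-mean inputs is about half of them, so the signal's second moment is halved at each layer
Compensate by using a weight variance of 2/fan_in instead of 1/fan_in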
Kaiming normal:
W ~ Normal(0, σ²) where σ = √(2 / fan_in)
Kaiming uniform:
W ~ Uniform(-a, a) where a = √(6 / fan_in)
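The factor of 2 can be seen by tracking the mean squared activation through a stack of ReLU layers. This is a sketch under assumptions: width 512, depth 10, Gaussian weights and inputs, no biases; `relu_second_moments` is a hypothetical helper.

```python
import numpy as np

def relu_second_moments(var_scale, depth=10, width=512, seed=0):
    """Mean squared activation after each ReLU layer, with Var(W) = var_scale / fan_in."""
    rng = np.random.default_rng(seed)
    h = rng.normal(size=(256, width))          # unit-variance inputs
    moments = []
    for _ in range(depth):
        W = rng.normal(scale=np.sqrt(var_scale / width), size=(width, width))
        h = np.maximum(h @ W, 0.0)             # ReLU
        moments.append(float(np.mean(h ** 2)))
    return moments

plain = relu_second_moments(1.0)   # Var(W) = 1/fan_in: signal roughly halves per layer
he = relu_second_moments(2.0)      # Var(W) = 2/fan_in (Kaiming): signal roughly preserved

print(f"1/fan_in: layer 1 = {plain[0]:.3f}, layer 10 = {plain[-1]:.5f}")
print(f"2/fan_in: layer 1 = {he[0]:.3f}, layer 10 = {he[-1]:.3f}")
```

Because both runs share a seed and ReLU is positively homogeneous, the two results differ by exactly a factor of 2 per layer, which isolates the effect of the Kaiming correction.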
When to use: ReLU, LeakyReLU, ELU, GELU (most modern MLPs)
PyTorch Implementation
import torch
import torch.nn as nn
import numpy as np
# PyTorch defaults:
# nn.Linear: kaiming_uniform_ with a=√5, i.e. Uniform(-1/√fan_in, 1/√fan_in)
# nn.Conv2d: same scheme (fan_in = in_channels per group × kernel area)
# nn.Embedding: Normal(0, 1)

# Verify PyTorch defaults
layer = nn.Linear(512, 256)
print(f"Default weight std: {layer.weight.data.std():.4f}")
# Expected for fan_in=512: std = 1/√(3 × 512) ≈ 0.0255
# (Kaiming normal with the ReLU gain would instead give √(2/512) = 0.0625)
# Manual initialisation
def init_weights(module: nn.Module) -> None:
    if isinstance(module, nn.Linear):
        # Xavier suits tanh/sigmoid; use nn.init.kaiming_normal_ for ReLU layers
        nn.init.xavier_normal_(module.weight)
        if module.bias is not None:  # bias is None when bias=False
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)

class MLP(nn.Module):
    def __init__(self, d_in: int, d_hidden: int, d_out: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_in, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_out),
        )
        self.apply(init_weights)  # apply init to all submodules

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Required so that model(x) works
        return self.net(x)

model = MLP(128, 256, 10)

# Common schemes as formulas
def xavier_std(fan_in: int, fan_out: int) -> float:
    return (2 / (fan_in + fan_out)) ** 0.5

def kaiming_std(fan_in: int, nonlinearity: str = "relu") -> float:
    # Gains multiply the std, following the convention of nn.init.calculate_gain
    gain = {"relu": 2 ** 0.5, "tanh": 5 / 3, "sigmoid": 1.0, "linear": 1.0}[nonlinearity]
    return gain / fan_in ** 0.5

print(f"Xavier std (512→256): {xavier_std(512, 256):.5f}")  # 0.05103
print(f"Kaiming std (512, ReLU): {kaiming_std(512):.5f}")  # 0.06250
Verifying Activation Health
def check_activation_stats(model: nn.Module, x: torch.Tensor) -> None:
    """Hook to log activation statistics through the network."""
    hooks = []

    def make_hook(name):
        def hook(module, input, output):
            print(f"{name}: mean={output.mean():.4f}, std={output.std():.4f}")
        return hook

    for name, module in model.named_modules():
        if isinstance(module, (nn.Linear, nn.ReLU, nn.BatchNorm1d)):
            hooks.append(module.register_forward_hook(make_hook(name)))
    with torch.no_grad():
        model(x)
    for h in hooks:
        h.remove()

# Healthy activations: mean near 0, std near 1 (before non-linearities)
# ReLU output: mean > 0 (half of inputs were positive)
What Goes Wrong Without Proper Initialisation
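Deep network (20 layers), tanh activation, N(0,1) init (width n = 512 assumed for the numbers below):
Layer 1: pre-activation std ≈ √512 ≈ 22.6 for unit-variance input, so most tanh outputs land near ±1
Layer 2: inputs are already close to ±1, so the pre-activation std stays around √512
Layer 5: most activations near +1 or -1 (saturated); the activation std is close to 1, not 0
Layer 10: same picture, with most activations indistinguishable from ±1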
Gradients: tanh'(x) = 1 - tanh²(x) ≈ 0 when |tanh(x)| ≈ 1
→ vanishing gradient — no learning in early layers
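The saturation described above can be measured directly and compared against a Xavier-scaled network. This is a rough sketch, assuming width 512, depth 20, Gaussian inputs, no biases, and a cutoff of |tanh| > 0.99 as the working definition of "saturated"; `tanh_health` is a hypothetical helper name.

```python
import numpy as np

def tanh_health(weight_std, depth=20, width=512, seed=0):
    """Per-layer (saturated fraction, mean tanh derivative) for a bias-free tanh MLP."""
    rng = np.random.default_rng(seed)
    h = rng.normal(size=(128, width))
    report = []
    for _ in range(depth):
        W = rng.normal(scale=weight_std, size=(width, width))
        h = np.tanh(h @ W)
        saturated = float(np.mean(np.abs(h) > 0.99))   # units pinned near -1 or +1
        mean_grad = float(np.mean(1.0 - h ** 2))       # tanh'(z) = 1 - tanh(z)^2
        report.append((saturated, mean_grad))
    return report

width = 512
naive = tanh_health(1.0)                               # W ~ N(0, 1)
xavier = tanh_health(np.sqrt(2.0 / (width + width)))   # Xavier normal with fan_in = fan_out

for name, rep in [("N(0,1)", naive), ("Xavier", xavier)]:
    sat, grad = rep[-1]
    print(f"{name}: layer 20 saturated fraction = {sat:.2f}, mean tanh' = {grad:.3f}")
```

The mean of tanh' is the average factor by which each layer scales the backpropagated signal locally, so a value near 0 at every layer is what starves the early layers of gradient.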
With Xavier init:
Each layer's pre-activations stay at a moderate scale instead of growing by a factor of n per layer
tanh operates mostly in its near-linear region (around 0), where its gradient is close to 1
Gradient flows through all layers
Interview Answer
"Weight initialisation sets the starting point for gradient descent. All-zeros fails because every neuron in a layer computes the same output and receives the same update, so symmetry is never broken. Random weights must be scaled to keep activation variance stable across layers. Xavier initialisation (σ = √(2/(fan_in + fan_out))) is designed for tanh/sigmoid: it balances forward activation variance and backward gradient variance. Kaiming/He initialisation (σ = √(2/fan_in)) compensates for ReLU zeroing roughly half of its inputs by doubling the weight variance. PyTorch's nn.Linear defaults to kaiming_uniform_ with a = √5, which works out to Uniform(-1/√fan_in, 1/√fan_in). In practice: use Kaiming for ReLU networks (most CNNs), Xavier for transformer attention layers and tanh networks, and always verify activation statistics in the first training step to catch initialisation problems early."