Learnixo

Deep Learning for AI Interviews · Lesson 8 of 56

Weight Initialization: Why It Matters

Why Initialisation Matters

Training begins at the initial weights, and gradient descent improves them from there. A poor starting point causes three classic failures:

All-zeros initialisation:
  Every neuron computes the same output → same gradient → same update
  Symmetry breaking doesn't happen → all neurons remain identical
  Network stays stuck: each layer behaves like a single neuron copied many times, with no feature diversity
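A minimal NumPy sketch of the symmetry problem, assuming a tiny two-layer tanh network (no biases) trained with plain gradient descent on random regression data. The network, the data, and the helper `train_two_layer` are hypothetical, chosen only for illustration. It compares constant, all-zeros, and small random initialisation of the hidden layer:

```python
import numpy as np

def train_two_layer(W1, W2, x, y, lr=0.05, steps=20):
    """Plain gradient descent on a tiny tanh MLP with MSE loss (biases omitted)."""
    W1, W2 = W1.copy(), W2.copy()
    for _ in range(steps):
        h = np.tanh(x @ W1.T)               # hidden activations, (batch, hidden)
        y_hat = h @ W2.T                    # predictions, (batch, out)
        d_out = 2 * (y_hat - y) / len(x)    # dL/dy_hat for mean squared error
        grad_W2 = d_out.T @ h               # (out, hidden)
        d_h = (d_out @ W2) * (1 - h**2)     # backprop through tanh
        grad_W1 = d_h.T @ x                 # (hidden, in): one row per hidden neuron
        W1 -= lr * grad_W1
        W2 -= lr * grad_W2
    return W1, W2

rng = np.random.default_rng(0)
x = rng.normal(size=(64, 8))
y = rng.normal(size=(64, 3))
hidden = 16

# Constant init: every hidden neuron starts identical, so every update is identical too.
W1_const, _ = train_two_layer(np.full((hidden, 8), 0.3), np.full((3, hidden), 0.3), x, y)
# All-zeros init: tanh(0) = 0 and W2 = 0, so W1 never receives a nonzero gradient.
W1_zero, _ = train_two_layer(np.zeros((hidden, 8)), np.zeros((3, hidden)), x, y)
# Small random init breaks the symmetry.
W1_rand, _ = train_two_layer(rng.normal(0, 0.3, (hidden, 8)), rng.normal(0, 0.3, (3, hidden)), x, y)

print("constant init, all W1 rows identical:", np.allclose(W1_const, W1_const[0]))
print("zero init, W1 still all zero:", np.allclose(W1_zero, 0.0))
print("random init, all W1 rows identical:", np.allclose(W1_rand, W1_rand[0]))
```

With all zeros the situation is worse than identical updates: the hidden activations and the second-layer weights are both 0, so the first layer's gradient is exactly zero at every step.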

Very large weights:
  Large activations → sigmoid saturates at 0 or 1, tanh at -1 or +1
  Saturated neurons → near-zero gradients → vanishing gradients
  Backprop signal doesn't reach early layers

Very small weights:
  Activations shrink toward zero with each layer
  Deep networks → activations die out completely
  Again → vanishing gradient problem

Goal: weights initialised so that activations and gradients
have reasonable magnitude throughout the network.

The Problem: Activation Variance

Consider a layer: z = W x + b

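If we initialise W ~ N(0, 1), and the n = fan_in inputs x_i are zero-mean with variance σ²_x and independent of W (b = 0):
  Var(z_j) = n × σ²_W × σ²_x     (sum of n independent terms w_ji x_i, each with variance σ²_W σ²_x)

  With n = 512 inputs: Var(z) = 512 × 1 × σ²_x = 512 × σ²_x

  Variance EXPLODES as we go deeper (linear layers, no activation, every layer 512 wide):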
    Layer 1: σ² = 512 × σ²_x
    Layer 2: σ² = 512² × σ²_x
    Layer 10: σ² = 512¹⁰ × σ²_x    → activations blow up
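A quick empirical check of the formula, as a sketch assuming i.i.d. zero-mean Gaussian inputs and weights, no bias, and no activation. The sizes and σ_x = 0.5 are arbitrary illustration values:

```python
import numpy as np

rng = np.random.default_rng(0)
n, sigma_w, sigma_x = 512, 1.0, 0.5

x = rng.normal(0.0, sigma_x, size=(2000, n))   # zero-mean inputs with Var = sigma_x^2
W = rng.normal(0.0, sigma_w, size=(256, n))    # W ~ N(0, sigma_w^2), no bias
z = x @ W.T

predicted = n * sigma_w**2 * sigma_x**2        # 512 * 1 * 0.25 = 128
print(f"empirical Var(z) = {z.var():.1f}, predicted = {predicted:.1f}")

# Stack 512-wide linear layers with N(0, 1) weights: variance grows ~512x per layer.
h = rng.normal(0.0, 1.0, size=(500, n))
variances = []
for _ in range(4):
    h = h @ rng.normal(0.0, 1.0, size=(n, n)).T
    variances.append(h.var())
print(["%.3g" % v for v in variances])
```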

To keep variance constant through layers:
  We need: Var(z) ≈ Var(x)
  → n × σ²_W = 1
  → σ²_W = 1/n
  → σ_W = 1/√n   (known as LeCun initialisation)
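The same kind of sketch, now stacking ten 512-wide linear layers (no activation, no bias, unit-variance Gaussian inputs) to compare σ_W = 1 with σ_W = 1/√n. `final_variance` is a hypothetical helper:

```python
import numpy as np

def final_variance(sigma_w: float, n: int = 512, depth: int = 10, seed: int = 0) -> float:
    """Push unit-variance inputs through `depth` linear layers (no activation, no bias)."""
    rng = np.random.default_rng(seed)
    h = rng.normal(0.0, 1.0, size=(256, n))
    for _ in range(depth):
        W = rng.normal(0.0, sigma_w, size=(n, n))
        h = h @ W.T
    return float(h.var())

n = 512
print(f"sigma_W = 1        : Var after 10 layers = {final_variance(1.0):.3g}")
print(f"sigma_W = 1/sqrt(n): Var after 10 layers = {final_variance(1 / np.sqrt(n)):.3g}")
```

Only the 1/√n scale keeps the output variance near the input variance; with σ_W = 1 it lands around 512¹⁰.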

Xavier / Glorot Initialisation

Designed for tanh and sigmoid activations:

Xavier uniform:
  W ~ Uniform(-a, a) where a = √(6 / (fan_in + fan_out))

Xavier normal:
  W ~ Normal(0, σ²) where σ = √(2 / (fan_in + fan_out))

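Compromise between the forward condition (σ² = 1/fan_in keeps activation variance stable) and the backward condition (σ² = 1/fan_out keeps gradient variance stable): Xavier uses their harmonic mean, σ² = 2/(fan_in + fan_out).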
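A small NumPy check, assuming a 512 → 256 layer, that the two Xavier variants target the same variance: Uniform(-a, a) has variance a²/3, and with a = √(6/(fan_in + fan_out)) that is exactly 2/(fan_in + fan_out), the Xavier normal variance.

```python
import numpy as np

fan_in, fan_out = 512, 256
target_var = 2.0 / (fan_in + fan_out)          # Xavier variance
a = np.sqrt(6.0 / (fan_in + fan_out))          # Xavier uniform bound

rng = np.random.default_rng(0)
W_uniform = rng.uniform(-a, a, size=(fan_out, fan_in))
W_normal = rng.normal(0.0, np.sqrt(target_var), size=(fan_out, fan_in))

# Uniform(-a, a) has variance a^2 / 3, which equals 2 / (fan_in + fan_out) here.
print(f"target {target_var:.6f}  uniform {W_uniform.var():.6f}  normal {W_normal.var():.6f}")
```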

When to use: networks with tanh or sigmoid activations; also common in transformers

Kaiming / He Initialisation

Designed for ReLU and its variants:

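ReLU zeroes every negative pre-activation, so for zero-mean, symmetric z it passes on only half the second moment: E[ReLU(z)²] = Var(z)/2
Compensate by doubling the weight variance: 2/fan_in instead of 1/fan_in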

Kaiming normal:
  W ~ Normal(0, σ²) where σ = √(2 / fan_in)

Kaiming uniform:
  W ~ Uniform(-a, a) where a = √(6 / fan_in)

When to use: ReLU, LeakyReLU, ELU, GELU (most modern MLPs)
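A sketch of why ReLU needs the factor 2, assuming a 10-layer, 512-wide ReLU MLP with Gaussian weights, no bias, and unit-variance Gaussian inputs (`relu_stack_variance` is a hypothetical helper). With σ² = 1/fan_in the pre-activation variance halves at every layer; with σ² = 2/fan_in it stays roughly flat:

```python
import numpy as np

def relu_stack_variance(gain_sq: float, n: int = 512, depth: int = 10, seed: int = 0) -> list[float]:
    """Pre-activation variance at each layer of a ReLU MLP with W ~ N(0, gain_sq / n)."""
    rng = np.random.default_rng(seed)
    a = rng.normal(0.0, 1.0, size=(1000, n))   # unit-variance inputs
    variances = []
    for _ in range(depth):
        W = rng.normal(0.0, np.sqrt(gain_sq / n), size=(n, n))
        z = a @ W.T
        variances.append(float(z.var()))
        a = np.maximum(z, 0.0)                 # ReLU
    return variances

lecun = relu_stack_variance(gain_sq=1.0)    # sigma^2 = 1/fan_in
kaiming = relu_stack_variance(gain_sq=2.0)  # sigma^2 = 2/fan_in (Kaiming)
print("1/fan_in:", ["%.3g" % v for v in lecun])
print("2/fan_in:", ["%.3g" % v for v in kaiming])
```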

PyTorch Implementation

Python
import torch
import torch.nn as nn
import numpy as np

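# PyTorch defaults (reset_parameters):
#   nn.Linear: kaiming_uniform_(weight, a=sqrt(5)) -> Uniform(-1/sqrt(fan_in), 1/sqrt(fan_in))
#   nn.Conv2d: same scheme, with fan_in = (in_channels / groups) * kernel_h * kernel_w
#   nn.Embedding: Normal(0, 1)

# Verify PyTorch defaults
layer = nn.Linear(512, 256)
print(f"Default weight std: {layer.weight.data.std():.4f}")
# Expected: Uniform(-b, b) has std b / sqrt(3), so 1 / sqrt(3 * 512) ≈ 0.0255
# (not the ReLU-gain Kaiming normal value sqrt(2 / 512) = 0.0625)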

# Manual initialisation
def init_weights(module: nn.Module) -> None:
    if isinstance(module, nn.Linear):
        # Kaiming matches the ReLU MLP below; for tanh/sigmoid use nn.init.xavier_normal_(module.weight)
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)

class MLP(nn.Module):
    def __init__(self, d_in: int, d_hidden: int, d_out: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_in, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_out),
        )
        self.apply(init_weights)   # apply init recursively to every submodule

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

model = MLP(128, 256, 10)


# Common schemes as formulas
def xavier_std(fan_in: int, fan_out: int) -> float:
    return (2 / (fan_in + fan_out)) ** 0.5

def kaiming_std(fan_in: int, nonlinearity: str = "relu") -> float:
    # calculate_gain returns the gain itself (sqrt(2) for relu, 5/3 for tanh), so std = gain / sqrt(fan_in)
    gain = nn.init.calculate_gain(nonlinearity)
    return gain / fan_in ** 0.5

print(f"Xavier std (512→256): {xavier_std(512, 256):.5f}")
print(f"Kaiming std (512, ReLU): {kaiming_std(512):.5f}")
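To make the numbers concrete, a standard-library sketch comparing the weight std each scheme gives for a 512 → 256 layer. The PyTorch entry assumes nn.Linear's reset_parameters, which calls kaiming_uniform_ with a = √5 and so samples from Uniform(-1/√fan_in, 1/√fan_in); its std is that bound divided by √3. `reference_stds` is a hypothetical helper, not a PyTorch API:

```python
import math

def uniform_std(bound: float) -> float:
    """Std of Uniform(-bound, bound): its variance is bound**2 / 3."""
    return bound / math.sqrt(3.0)

def reference_stds(fan_in: int, fan_out: int) -> dict:
    return {
        # nn.Linear default: Uniform(-1/sqrt(fan_in), 1/sqrt(fan_in))
        "pytorch_linear_default": uniform_std(1.0 / math.sqrt(fan_in)),
        "lecun": math.sqrt(1.0 / fan_in),
        "xavier": math.sqrt(2.0 / (fan_in + fan_out)),
        "kaiming_relu": math.sqrt(2.0 / fan_in),
    }

for name, std in reference_stds(512, 256).items():
    print(f"{name:24s} {std:.4f}")
```

For fan_in = 512 the PyTorch default comes out near 0.0255, well below the ReLU-gain Kaiming value of 0.0625.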

Verifying Activation Health

Python
def check_activation_stats(model: nn.Module, x: torch.Tensor) -> None:
    """Hook to log activation statistics through the network."""
    hooks = []
    
    def make_hook(name):
        def hook(module, input, output):
            print(f"{name}: mean={output.mean():.4f}, std={output.std():.4f}")
        return hook
    
    for name, module in model.named_modules():
        if isinstance(module, (nn.Linear, nn.ReLU, nn.BatchNorm1d)):
            hooks.append(module.register_forward_hook(make_hook(name)))
    
    with torch.no_grad():
        model(x)
    
    for h in hooks:
        h.remove()

# Healthy signs: Linear outputs have mean near 0 and a std that stays roughly constant from layer to layer
# ReLU outputs: mean > 0 by construction (negatives are clipped to 0), with roughly half the entries exactly 0

What Goes Wrong Without Proper Initialisation

Deep network (20 layers), tanh activation, N(0,1) init, unit-variance inputs:

Layer 1: pre-activation std ≈ √fan_in × input std (≈ 22.6 for fan_in = 512), far outside tanh's linear region
Every later layer: tanh outputs are bounded in [-1, 1], so they cannot blow up; instead most of them sit near +1 or -1 (saturated), and the next layer's pre-activations again have std ≈ √fan_in
Gradients: tanh'(x) = 1 - tanh²(x) ≈ 0 when |tanh(x)| ≈ 1
→ vanishing gradient: almost no learning signal reaches the early layers

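With Xavier init:
  Pre-activation variance starts near the input variance and shrinks only slowly with depth (tanh compresses it slightly at each layer) instead of exploding
  tanh operates in its near-linear region (close to 0), where its gradient is close to 1
  Gradient flows through all layers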
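A NumPy sketch of this comparison, assuming a 20-layer, 512-wide tanh network with no bias, unit-variance Gaussian inputs, and "saturated" defined as |tanh(z)| > 0.99. That threshold and the `tanh_stack` helper are illustrative choices:

```python
import numpy as np

def tanh_stack(sigma_w: float, n: int = 512, depth: int = 20, seed: int = 0):
    """Return (fraction of saturated units, mean tanh'(z)) at the last layer."""
    rng = np.random.default_rng(seed)
    a = rng.normal(0.0, 1.0, size=(256, n))    # unit-variance inputs
    for _ in range(depth):
        W = rng.normal(0.0, sigma_w, size=(n, n))
        a = np.tanh(a @ W.T)
    saturated = float(np.mean(np.abs(a) > 0.99))
    mean_grad = float(np.mean(1.0 - a**2))     # tanh'(z) = 1 - tanh(z)^2
    return saturated, mean_grad

n = 512
print("N(0,1):", tanh_stack(1.0))
print("Xavier:", tanh_stack(np.sqrt(2.0 / (n + n))))
```

Under N(0,1) init most last-layer units end up saturated and the average tanh' is tiny; under Xavier almost none are saturated and tanh' stays close to 1.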

Interview Answer

"Weight initialisation sets the starting point for gradient descent. All-zeros fails because of symmetry: every neuron in a layer receives the same gradient, so the neurons never differentiate. Random weights must be scaled so that activation variance stays stable across layers. Xavier initialisation (σ = √(2/(fan_in + fan_out))) is designed for tanh/sigmoid: it balances forward activation variance and backward gradient variance. Kaiming/He initialisation (σ = √(2/fan_in)) compensates for ReLU zeroing half of its inputs by doubling the variance. PyTorch's nn.Linear defaults to a Kaiming-uniform variant (a = √5), which works out to Uniform(-1/√fan_in, 1/√fan_in). In practice: use Kaiming for ReLU networks (most CNNs), Xavier for transformer attention layers and tanh networks, and check activation statistics on the first forward pass to catch initialisation problems early."