Deep Learning for AI Interviews · Lesson 50 of 56
Batch Normalization: How and Why It Works
What Batch Normalisation Does
Problem: as training progresses, the distribution of each layer's inputs shifts
because parameters in previous layers change (internal covariate shift).
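This forces each layer to constantly adapt to a changing distribution → slow training.
(This is the motivation given in the original BatchNorm paper, Ioffe & Szegedy 2015; later work — Santurkar et al. 2018, "How Does Batch Normalization Help Optimization?" — argues that the main benefit is a smoother, better-conditioned loss landscape rather than reduced covariate shift.)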
BatchNorm solution: normalise the inputs to each layer using batch statistics.
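For a mini-batch $\mathcal{B} = \{x_1, \ldots, x_m\}$ (statistics are computed separately for each feature — for CNNs, for each channel, pooled over the batch and spatial positions):
$$\mu_\mathcal{B} = \frac{1}{m}\sum_{i=1}^{m} x_i \qquad \text{(batch mean)}$$
$$\sigma^2_\mathcal{B} = \frac{1}{m}\sum_{i=1}^{m} \left(x_i - \mu_\mathcal{B}\right)^2 \qquad \text{(batch variance)}$$
$$\hat{x}_i = \frac{x_i - \mu_\mathcal{B}}{\sqrt{\sigma^2_\mathcal{B} + \epsilon}} \qquad \text{(normalise)}$$
$$y_i = \gamma\,\hat{x}_i + \beta \qquad \text{(scale and shift with learnable } \gamma, \beta\text{)}$$
$\gamma$ and $\beta$ let the network undo the normalisation when the optimal representation is not zero-mean, unit-variance: setting $\gamma = \sqrt{\sigma^2_\mathcal{B} + \epsilon}$ and $\beta = \mu_\mathcal{B}$ recovers the original activations.
BatchNorm from Scratch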
import torch
import torch.nn as nn
class BatchNorm1DManual(nn.Module):
    """BatchNorm for 1D inputs, written to match ``nn.BatchNorm1d``.

    Accepts ``(batch, features)`` tabular input or ``(batch, features, length)``
    sequence input. Statistics are computed per feature over every axis except
    the feature axis (dim 1).

    Args:
        num_features: Size of the feature axis (dim 1 of the input).
        eps: Added to the variance before the square root to avoid division by zero.
        momentum: Weight given to the current batch in the running-statistics
            exponential moving average.
    """

    def __init__(self, num_features: int, eps: float = 1e-5, momentum: float = 0.1):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        # Learnable scale (γ) and shift (β), initialised to 1 and 0 (identity transform).
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        # Running statistics for inference: registered as buffers, so they are saved in
        # state_dict but not trained via backprop.
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` per feature, then apply the learnable scale and shift.

        Training mode normalises with the current batch statistics and updates the
        running statistics; eval mode normalises with the running statistics.

        Raises:
            ValueError: if ``x`` is not 2D/3D, or (in training mode) if there is only
                one value per feature, which gives no usable batch variance.
        """
        if x.dim() not in (2, 3):
            raise ValueError(f"expected 2D or 3D input (got {x.dim()}D input)")
        # Reduce over the batch axis and, for 3D input, the length axis too.
        reduce_dims = (0, *range(2, x.dim()))
        # Shape that broadcasts a per-feature vector against x: (1, C) or (1, C, 1).
        stat_shape = (1, -1) + (1,) * (x.dim() - 2)

        if self.training:
            n = x.numel() // x.size(1)  # number of values per feature in this batch
            if n < 2:
                raise ValueError(
                    f"Expected more than 1 value per channel when training, got input size {x.size()}"
                )
            mean = x.mean(dim=reduce_dims)
            # Biased variance (divide by n) normalises the current batch, as in the paper.
            var = x.var(dim=reduce_dims, unbiased=False)
            # Update the moving averages in place and outside autograd: assigning a
            # graph-connected tensor to the buffer would chain every step's graph
            # together. Like nn.BatchNorm1d, track the *unbiased* variance so that
            # running_var estimates the population variance.
            with torch.no_grad():
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
                self.running_var.mul_(1 - self.momentum).add_(self.momentum * var * n / (n - 1))
        else:
            # Use accumulated running statistics for inference.
            mean, var = self.running_mean, self.running_var

        x_norm = (x - mean.view(stat_shape)) / torch.sqrt(var.view(stat_shape) + self.eps)
        # Scale and shift.
        return self.gamma.view(stat_shape) * x_norm + self.beta.view(stat_shape)
# Compare with PyTorch's implementation on the same (seeded) mini-batch
torch.manual_seed(0)
X = torch.randn(32, 10)
bn_manual = BatchNorm1DManual(num_features=10)
bn_torch = nn.BatchNorm1d(10)
bn_manual.train()
bn_torch.train()
out_manual = bn_manual(X)
out_torch = bn_torch(X)
print(f"Manual output: mean={out_manual.mean():.4f}, std={out_manual.std():.4f}")
print(f"Torch output: mean={out_torch.mean():.4f}, std={out_torch.std():.4f}")
# Summary statistics only show that *some* normalisation happened, so also compare
# element-wise: the training outputs and the updated running statistics.
print(f"Max |manual - torch| output: {(out_manual - out_torch).abs().max():.2e}")
print(f"Max |Δ running_mean|: {(bn_manual.running_mean - bn_torch.running_mean).abs().max():.2e}")
print(f"Max |Δ running_var|: {(bn_manual.running_var - bn_torch.running_var).abs().max():.2e}")
# Both should have mean≈0, std≈1 after normalisation (γ starts at 1 and β at 0)
BatchNorm in CNN vs MLP
import torch
import torch.nn as nn
# For MLP (2D: batch × features): nn.BatchNorm1d
class ClinicalMLPWithBN(nn.Module):
    """Tabular MLP: two (Linear → BatchNorm1d → ReLU) stages followed by a 1-unit head.

    Hidden widths are 64 then 32; the head outputs one value per sample.
    """

    def __init__(self, n_features: int):
        super().__init__()
        layers = []
        width_in = n_features
        for width_out in (64, 32):
            # BatchNorm1d normalises each of the width_out features across the batch dimension.
            layers += [nn.Linear(width_in, width_out), nn.BatchNorm1d(width_out), nn.ReLU()]
            width_in = width_out
        layers.append(nn.Linear(width_in, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
# For CNN (4D: batch × channels × H × W): nn.BatchNorm2d
class ConvBlock(nn.Module):
    """3×3 convolution → BatchNorm2d → ReLU, keeping H and W unchanged (padding=1).

    The convolution is created without a bias because BatchNorm's learnable shift β
    already provides one per channel.
    """

    def __init__(self, in_c: int, out_c: int):
        super().__init__()
        conv = nn.Conv2d(in_c, out_c, 3, padding=1, bias=False)
        # One mean/variance per output channel, pooled over (batch, H, W).
        norm = nn.BatchNorm2d(out_c)
        act = nn.ReLU(inplace=True)
        self.block = nn.Sequential(conv, norm, act)

    def forward(self, x):
        out = self.block(x)
        return out
# Note: bias=False in Conv2d when using BatchNorm
# BatchNorm subtracts the per-channel mean (cancelling any constant bias), and its β
# parameter acts as the bias → no need for an extra bias term
# Smoke test: run the MLP on one random training batch.
mlp = ClinicalMLPWithBN(n_features=20)
X = torch.randn(32, 20)  # 32 samples × 20 tabular features
mlp.train()  # BatchNorm normalises with this batch's statistics
out = mlp(X)
print(f"MLP output shape: {out.shape}")  # (32, 1)
# Train vs Eval Mode: Critical Difference
import torch
import torch.nn as nn
# Small MLP with BatchNorm after the first linear layer.
model = nn.Sequential(
    nn.Linear(10, 32),
    nn.BatchNorm1d(32),
    nn.ReLU(),
    nn.Linear(32, 1),
)
X_train_batch = torch.randn(64, 10)  # large batch
X_eval_single = torch.randn(1, 10)  # single sample at inference
# ── TRAINING mode ──
model.train()
out_train = model(X_train_batch)
print(f"Training mode, batch of 64: {out_train.shape}")
# WRONG: BatchNorm1d fails with batch_size=1 in training mode!
# One sample gives no meaningful per-feature batch variance, so PyTorch raises
# ValueError; catch only that, so any other failure is not silently hidden.
try:
    model.train()
    out_single_train = model(X_eval_single)
except ValueError as e:
    print(f"Error with batch_size=1 in train mode: {type(e).__name__}: {e}")
# CORRECT: switch to eval mode for inference.
# eval() makes BatchNorm use its running statistics; no_grad() skips building the
# autograd graph, which inference does not need.
model.eval()
with torch.no_grad():
    out_single_eval = model(X_eval_single)
print(f"Eval mode, single sample: {out_single_eval.shape}")  # works!
# Another common bug: forgetting model.eval() in the validation loop
# → BN normalises with each validation batch's own statistics instead of the running stats,
#   and also overwrites the running stats with validation data
# → Validation loss differs from what the model will produce at deployment
Layer Norm vs Batch Norm
import torch
import torch.nn as nn
# BatchNorm1d: normalise over BATCH dimension (per-feature statistics)
# input: (batch, features)
# normalise: across batch for each feature
# problem: statistics depend on batch size → fails for batch_size=1
# LayerNorm: normalise over FEATURE dimension (per-sample statistics)
# input: (batch, features)
# normalise: across features for each sample
# works with batch_size=1 → standard for Transformers
X = torch.randn(8, 32)  # (batch=8, features=32)
bn, ln = nn.BatchNorm1d(32), nn.LayerNorm(32)
bn.train()  # batch statistics are what is being compared
out_bn = bn(X)
out_ln = ln(X)
# BatchNorm centres each column (feature); LayerNorm centres each row (sample).
print(f"BatchNorm: per-feature mean={out_bn.mean(0)[:3]}")  # ≈0 across batch
print(f"LayerNorm: per-sample mean={out_ln.mean(1)[:3]}")  # ≈0 across features
# Use BatchNorm for: CNNs, MLPs (when batch size ≥ 8)
# Use LayerNorm for: Transformers, RNNs, variable-length inputs, or training with tiny/single-sample batches
# Use GroupNorm for: CNNs with small batch sizes (common in object detection)
# Pre-norm Transformer block: LayerNorm before each sub-layer (self-attention, then FFN)
class TransformerBlock(nn.Module):
    """Pre-norm Transformer encoder block built around LayerNorm.

    Each sub-layer receives a LayerNorm-ed copy of its input and adds its output back
    through a residual connection: ``x = x + Sublayer(LayerNorm(x))``. LayerNorm
    normalises each token's ``d_model`` features on their own, so the block does not
    depend on batch size.

    Args:
        d_model: Token embedding size.
        d_ff: Hidden size of the position-wise feed-forward network.
        n_heads: Number of attention heads; must divide ``d_model``.

    Input and output shape: ``(batch, seq_len, d_model)`` (attention uses ``batch_first=True``).
    """

    def __init__(self, d_model: int = 512, d_ff: int = 2048, n_heads: int = 8):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)  # applied before self-attention
        self.norm2 = nn.LayerNorm(d_model)  # applied before the feed-forward network
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
        )
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out  # pre-norm style (modern standard)
        x = x + self.ff(self.norm2(x))
        return x
# Benefits of Batch Normalisation
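1. Faster convergence: stabilises each layer's input distribution (originally explained as reducing internal covariate shift; later work points to a smoother loss landscape) → higher learning rates possible
2. Regularisation: adds noise (batch statistics vary per batch) → mild regularisation effect
3. Gradient flow: re-normalises activations → prevents vanishing/exploding activations
4. Less sensitive to initialisation: the scale of the preceding layer's weights is normalised away
5. Makes deeper networks easier to train (though very deep plain networks still degrade — in practice BN is combined with skip connections, as in ResNets)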
Limitations:
- Needs reasonably large batches (rule of thumb: ≥ 8) — small batches → noisy statistics
- Different behaviour in train vs eval (a common source of bugs)
- Awkward for recurrent networks: statistics would have to be tracked per timestep, and sequence lengths vary
- Online inference with batch_size=1: must use eval mode
Interview Answer
"BatchNorm normalises the inputs to each layer using mini-batch statistics (mean and variance), then applies learnable scale (γ) and shift (β) parameters. This reduces internal covariate shift — the tendency for each layer's input distribution to change as upstream parameters update — enabling higher learning rates and faster convergence. BatchNorm also acts as a mild regulariser because batch statistics are noisy (different per mini-batch). The critical practical detail: BatchNorm has two modes. In training mode, it uses the current mini-batch statistics and updates running statistics via exponential moving average. In eval mode, it uses the accumulated running statistics — essential because at inference you may have batch_size=1. Forgetting model.eval() in validation causes the model to use batch statistics from the validation set, making validation loss inconsistent with production performance. For Transformers: use LayerNorm (normalises over features per sample) rather than BatchNorm — LayerNorm works with any batch size and is sequence-length agnostic."