Learnixo
AI Systems · intermediate

Batch Normalisation

How BatchNorm normalises activations mid-network, its learnable parameters, the difference between train and eval modes, and Layer Norm for transformers.

Asma Hafeez Khan · May 22, 2026 · 6 min read
Deep Learning · Batch Normalisation · Layer Norm · Training · Stability · Interview

What Batch Normalisation Does

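Problem: as training progresses, the distribution of each layer's inputs shifts
because the parameters of earlier layers keep changing (internal covariate shift).
Each layer must keep adapting to a moving input distribution → slower training.
(This was the original motivation for BatchNorm; later analyses suggest much of its benefit comes from making the optimisation landscape smoother instead.)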

BatchNorm solution: normalise the inputs to each layer using batch statistics.

For a mini-batch B = {x_1, ..., x_m}, computed independently for each feature:

  μ_B  = (1/m) Σ x_i          (batch mean)
  σ²_B = (1/m) Σ (x_i - μ_B)² (batch variance)
  x̂_i  = (x_i - μ_B) / √(σ²_B + ε)  (normalise)
  y_i  = γ · x̂_i + β          (scale and shift with learnable γ, β)

γ and β allow the network to undo the normalisation if the optimal representation
is not zero-mean, unit-variance.
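To make the four equations concrete, here is a small worked example in plain Python for a single feature with a batch of m = 4 values. The helper `batchnorm_feature` is a hypothetical illustration, not a library function; it divides by m (the biased variance), as in the formula above, and checks that choosing γ = √(σ²_B + ε) and β = μ_B gives back the original inputs.

```python
import math

def batchnorm_feature(xs, gamma=1.0, beta=0.0, eps=1e-5):
    """Hypothetical helper: the four BatchNorm equations for one feature across a batch."""
    m = len(xs)
    mu = sum(xs) / m                                        # batch mean
    var = sum((x - mu) ** 2 for x in xs) / m                # biased batch variance (divide by m)
    x_hat = [(x - mu) / math.sqrt(var + eps) for x in xs]   # normalise
    return [gamma * xh + beta for xh in x_hat], mu, var     # scale and shift

xs = [1.0, 2.0, 3.0, 4.0]
ys, mu, var = batchnorm_feature(xs)
print(mu, var)                      # 2.5 1.25
print([round(y, 3) for y in ys])    # [-1.342, -0.447, 0.447, 1.342]

# Choosing gamma = sqrt(var + eps) and beta = mu undoes the normalisation
restored, _, _ = batchnorm_feature(xs, gamma=math.sqrt(var + 1e-5), beta=mu)
print([round(r, 6) for r in restored])  # [1.0, 2.0, 3.0, 4.0]
```

In a real network γ and β are learned per feature, so the model can settle anywhere between "fully normalised" and "normalisation undone".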

BatchNorm from Scratch

Python
import torch
import torch.nn as nn

class BatchNorm1DManual(nn.Module):
    """BatchNorm for 1D inputs (tabular features)."""
    
    def __init__(self, num_features: int, eps: float = 1e-5, momentum: float = 0.1):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        
        # Learnable scale (γ) and shift (β), initialised to 1 and 0
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta  = nn.Parameter(torch.zeros(num_features))
        
        # Running statistics for inference (not trained via backprop)
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var",  torch.ones(num_features))
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training:
            # Compute batch statistics (per feature, across the batch dimension)
            batch_mean = x.mean(dim=0)
            batch_var  = x.var(dim=0, unbiased=False)   # biased: used for normalising
            
            # Update running statistics (exponential moving average) outside autograd,
            # in place so the registered buffers are kept. Like PyTorch, store the
            # unbiased variance in running_var.
            with torch.no_grad():
                unbiased_var = x.var(dim=0, unbiased=True)
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * batch_mean)
                self.running_var.mul_(1 - self.momentum).add_(self.momentum * unbiased_var)
            
            # Normalise using batch statistics
            x_norm = (x - batch_mean) / torch.sqrt(batch_var + self.eps)
        else:
            # Use accumulated running statistics for inference
            x_norm = (x - self.running_mean) / torch.sqrt(self.running_var + self.eps)
        
        # Scale and shift
        return self.gamma * x_norm + self.beta

# Compare with PyTorch's implementation
X = torch.randn(32, 10)

bn_manual = BatchNorm1DManual(num_features=10)
bn_torch  = nn.BatchNorm1d(10)

bn_manual.train()
bn_torch.train()

out_manual = bn_manual(X)
out_torch  = bn_torch(X)

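print(f"Manual output: mean={out_manual.mean():.4f}, std={out_manual.std():.4f}")
print(f"Torch output:  mean={out_torch.mean():.4f}, std={out_torch.std():.4f}")
# Both should have mean≈0, std≈1 after normalisation, and should match each other
print(torch.allclose(out_manual, out_torch, atol=1e-5))                        # True
print(torch.allclose(bn_manual.running_var, bn_torch.running_var, atol=1e-5))  # True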
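The running statistics are an exponential moving average of the per-batch statistics. A minimal sketch using PyTorch's own `nn.BatchNorm1d`: it feeds batches whose features have mean 3 and standard deviation 2 (arbitrary values chosen for illustration) and checks that `running_mean` and `running_var` drift towards them. Note that PyTorch's `momentum` is the weight given to the *new* batch, the opposite convention to optimiser momentum.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

bn = nn.BatchNorm1d(4, momentum=0.1)  # running = (1 - 0.1) * running + 0.1 * batch_stat
bn.train()

# Illustrative data: every feature ~ N(mean=3, std=2)
true_mean, true_std = 3.0, 2.0
for _ in range(200):
    batch = true_mean + true_std * torch.randn(256, 4)
    bn(batch)  # each training-mode forward pass updates the running buffers

print(bn.running_mean)         # close to 3.0 for every feature
print(bn.running_var)          # close to 4.0 (= 2.0 ** 2) for every feature
print(bn.num_batches_tracked)  # tensor(200)
```

With momentum 0.1 the estimate is dominated by roughly the last few dozen batches, which is why it tracks the data distribution rather than the initial values of 0 and 1.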

BatchNorm in CNN vs MLP

Python
import torch
import torch.nn as nn

# For MLP (2D: batch × features): nn.BatchNorm1d
class ClinicalMLPWithBN(nn.Module):
    def __init__(self, n_features: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_features, 64),
            nn.BatchNorm1d(64),   # normalise over batch dimension
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Linear(32, 1),
        )
    
    def forward(self, x):
        return self.net(x)

# For CNN (4D: batch × channels × H × W): nn.BatchNorm2d
class ConvBlock(nn.Module):
    def __init__(self, in_c: int, out_c: int):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_c, out_c, 3, padding=1, bias=False),  # no bias (BN has β)
            nn.BatchNorm2d(out_c),  # normalise over (batch, H, W) for each channel
            nn.ReLU(inplace=True),
        )
    
    def forward(self, x):
        return self.block(x)

# Note: bias=False in Conv2d when using BatchNorm
# BatchNorm's β parameter acts as the bias → no need for an extra bias term

mlp = ClinicalMLPWithBN(n_features=20)
X = torch.randn(32, 20)
mlp.train()
out = mlp(X)
print(f"MLP output shape: {out.shape}")  # (32, 1)
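The comments above say that `BatchNorm2d` normalises over (batch, H, W) separately for each channel, and that a convolution bias is redundant in front of it. A quick sketch that checks both claims against `nn.BatchNorm2d` in training mode; the tensor sizes and the per-channel shift values are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 3, 5, 5)  # (batch, channels, H, W)

bn2d = nn.BatchNorm2d(3)
bn2d.train()
out = bn2d(x)  # gamma=1, beta=0 at initialisation

# Manual version: one mean/variance per channel, pooled over batch, height and width
mean = x.mean(dim=(0, 2, 3), keepdim=True)                # shape (1, 3, 1, 1)
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)  # biased, as in training-mode BN
manual = (x - mean) / torch.sqrt(var + bn2d.eps)
print(torch.allclose(out, manual, atol=1e-5))  # True

# Why Conv2d can drop its bias: a per-channel constant added before BN is
# removed by the mean subtraction, so the training-mode output is unchanged.
shift = torch.tensor([5.0, -2.0, 0.5]).view(1, 3, 1, 1)
shifted_out = bn2d(x + shift)
print(torch.allclose(shifted_out, out, atol=1e-4))  # True
```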

Train vs Eval Mode: Critical Difference

Python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(10, 32),
    nn.BatchNorm1d(32),
    nn.ReLU(),
    nn.Linear(32, 1),
)

X_train_batch = torch.randn(64, 10)    # large batch
X_eval_single = torch.randn(1, 10)    # single sample at inference

# ── TRAINING mode ──
model.train()
out_train = model(X_train_batch)
print(f"Training mode, batch of 64: {out_train.shape}")

# WRONG: BatchNorm1d fails with batch_size=1 in training mode!
try:
    model.train()
    out_single_train = model(X_eval_single)
except Exception as e:
    print(f"Error with batch_size=1 in train mode: {type(e).__name__}: {e}")

# CORRECT: switch to eval mode for inference
model.eval()
out_single_eval = model(X_eval_single)
print(f"Eval mode, single sample: {out_single_eval.shape}")   # works!

# Another common bug: forgetting model.eval() in the validation loop
#   → BN normalises with the validation batch's own statistics instead of the running stats
#   → the running stats are also updated with validation data
#   → validation loss no longer reflects what the model will do at deployment
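A minimal sketch of that validation-loop bug, assuming a small model like the one above and made-up validation inputs (`X_val` is hypothetical and deliberately shifted away from zero). It shows that a training-mode forward pass updates the running statistics even under `torch.no_grad()`, while eval mode leaves them alone and makes each sample's output independent of the rest of the batch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 32), nn.BatchNorm1d(32), nn.ReLU(), nn.Linear(32, 1))
bn = model[1]

# Hypothetical validation data, shifted away from the training distribution
X_val = torch.randn(64, 10) + 2.0

# Buggy: validation forward pass left in training mode
before = bn.running_mean.clone()
model.train()
with torch.no_grad():  # no_grad does NOT stop BN from updating its buffers
    model(X_val)
untouched_in_train_mode = torch.equal(before, bn.running_mean)
print(untouched_in_train_mode)  # False: running stats absorbed validation data

# Correct: eval mode uses (and leaves unchanged) the running statistics
before = bn.running_mean.clone()
model.eval()
with torch.no_grad():
    preds_full = model(X_val)
    preds_one = model(X_val[:1])  # same first sample, batch of 1
untouched_in_eval_mode = torch.equal(before, bn.running_mean)
print(untouched_in_eval_mode)  # True
print(torch.allclose(preds_full[:1], preds_one, atol=1e-5))  # True
```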

Layer Norm vs Batch Norm

Python
import torch
import torch.nn as nn

# BatchNorm1d: normalise over BATCH dimension (per-feature statistics)
#   input: (batch, features)
#   normalise: across batch for each feature
#   problem: statistics depend on the batch → fails for batch_size=1 in training mode

# LayerNorm: normalise over FEATURE dimension (per-sample statistics)
#   input: (batch, features)
#   normalise: across features for each sample
#   works with batch_size=1 → standard for Transformers

X = torch.randn(8, 32)   # (batch=8, features=32)

bn = nn.BatchNorm1d(32)
ln = nn.LayerNorm(32)

bn.train()
out_bn = bn(X)
out_ln = ln(X)

print(f"BatchNorm: per-feature mean={out_bn.mean(0)[:3]}")   # ≈0 across batch
print(f"LayerNorm: per-sample mean={out_ln.mean(1)[:3]}")     # ≈0 across features

# Use BatchNorm for: CNNs, MLPs (when batch size ≥ 8, as a rough rule of thumb)
# Use LayerNorm for: Transformers, RNNs, any variable-length or single-sample inference
# Use GroupNorm for: CNNs with small batch sizes (common in object detection)

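# Pre-norm Transformer block: LayerNorm is applied before each sub-layer
class TransformerBlock(nn.Module):
    def __init__(self, d_model: int = 512, n_heads: int = 8, d_ff: int = 2048):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)   # before self-attention
        self.norm2 = nn.LayerNorm(d_model)   # before the feed-forward network
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model); LayerNorm normalises each token's d_model features
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out                 # pre-norm residual (widely used in modern Transformers)
        x = x + self.ff(self.norm2(x))
        return x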
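To see exactly what LayerNorm computes, this sketch reproduces `nn.LayerNorm` by hand on a (batch, seq_len, d_model) tensor, using the biased variance over the last dimension only, and confirms it behaves identically in train and eval mode. It also runs `nn.GroupNorm`, mentioned above for small-batch CNNs, on a single image; the shapes and `num_groups=4` are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 6, 32)  # (batch, seq_len, d_model)

ln = nn.LayerNorm(32)  # normalises over the last dimension only
mean = x.mean(dim=-1, keepdim=True)                # one mean per token
var = x.var(dim=-1, unbiased=False, keepdim=True)  # one (biased) variance per token
manual = (x - mean) / torch.sqrt(var + ln.eps)     # gamma=1, beta=0 at initialisation
print(torch.allclose(ln(x), manual, atol=1e-5))    # True

# LayerNorm keeps no running statistics, so train and eval mode give the same result
single = x[:1]
ln.train()
a = ln(single)
ln.eval()
b = ln(single)
print(torch.equal(a, b))  # True

# GroupNorm: per-sample statistics over groups of channels, so batch_size=1 also works
gn = nn.GroupNorm(num_groups=4, num_channels=16)
img = torch.randn(1, 16, 8, 8)  # a single image
print(gn(img).shape)  # torch.Size([1, 16, 8, 8])
```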

Benefits of Batch Normalisation

1. Faster convergence: allows higher learning rates → usually trains faster (originally attributed to reduced internal covariate shift; later work credits a smoother loss landscape)
2. Regularisation: batch statistics vary from batch to batch, so each sample's output is slightly noisy → mild regularisation effect
3. Gradient flow: keeps pre-activations in a well-scaled range → helps prevent vanishing/exploding gradients
4. Less sensitive to weight initialisation
5. Makes deeper networks easier to train (very deep plain networks still degrade without skip connections, which is why ResNets combine both)

Limitations:
  - Needs reasonably large batches (≥ 8 is a common rule of thumb); small batches give noisy statistics
  - Different behaviour in train vs eval (a common source of bugs)
  - Awkward for recurrent networks: statistics would have to be tracked separately for each timestep, and variable-length sequences leave late timesteps with few samples
  - Online inference with batch_size=1: must use eval mode
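The batch-size limitation and the regularisation effect share one source: batch statistics are only estimates. This sketch uses synthetic standard-normal data (an assumption for illustration) and a hypothetical helper, `spread_of_batch_means`, to measure how much the mini-batch mean fluctuates from batch to batch at different batch sizes.

```python
import torch

torch.manual_seed(0)
# Illustrative population: one feature with mean 0 and std 1
population = torch.randn(100_000)

def spread_of_batch_means(batch_size: int, n_batches: int = 2000) -> float:
    """Hypothetical helper: std of the mini-batch mean across many random batches."""
    idx = torch.randint(0, population.numel(), (n_batches, batch_size))
    batch_means = population[idx].mean(dim=1)  # one mean per sampled batch
    return batch_means.std().item()

spreads = {bs: spread_of_batch_means(bs) for bs in (2, 8, 64)}
for bs, s in spreads.items():
    print(bs, round(s, 3))
```

The spread shrinks roughly like 1/√batch_size, so at batch size 2 the mean each sample is normalised with is far noisier than at 64. A little of that noise regularises; too much makes the statistics unreliable.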

Interview Answer

"BatchNorm normalises the inputs to each layer using mini-batch statistics (mean and variance), then applies learnable scale (γ) and shift (β) parameters. It was introduced to reduce internal covariate shift, the tendency for each layer's input distribution to change as upstream parameters update; whatever the exact mechanism, in practice it enables higher learning rates and faster convergence. BatchNorm also acts as a mild regulariser because batch statistics are noisy (different per mini-batch). The critical practical detail: BatchNorm has two modes. In training mode, it normalises with the current mini-batch statistics and updates running statistics via an exponential moving average. In eval mode, it uses the accumulated running statistics, which is essential because at inference you may have batch_size=1. Forgetting model.eval() during validation makes the model normalise with validation-batch statistics (and update its running statistics with validation data), so validation loss no longer matches production performance. For Transformers, use LayerNorm (which normalises over the features of each sample) rather than BatchNorm: LayerNorm works with any batch size and is sequence-length agnostic."
