Statistics Inside AI Models
Where descriptive statistics appear inside neural networks and training pipelines — from batch normalisation to loss surfaces to gradient statistics.
Statistics Are Everywhere in ML
Basic descriptive statistics — mean, variance, standard deviation — appear constantly inside neural networks:
Batch Normalisation: normalises activations using batch mean and variance
Layer Normalisation: normalises using per-sample mean and variance
Weight Initialisation: sets initial weights with specific mean and std
Loss Monitoring: tracks mean and std of loss across steps
Gradient Clipping: clips gradients whose norm exceeds a threshold
Optimiser Moments: Adam tracks the mean (1st moment) and uncentred variance (2nd moment) of gradients
Batch Normalisation
import torch
import torch.nn as nn
import numpy as np
# Manual BN forward pass (conceptual)
def batch_norm_manual(
    x: torch.Tensor,  # shape: (batch, features)
    gamma: torch.Tensor,  # learnable scale
    beta: torch.Tensor,  # learnable shift
    eps: float = 1e-5,
) -> torch.Tensor:
    """Batch-normalise ``x`` feature by feature, then scale by ``gamma`` and shift by ``beta``.

    Each feature column is standardised with the mean and the biased
    (population) variance taken over the batch axis (dim 0); ``eps`` keeps the
    square root away from zero.
    """
    # Per-feature statistics across the batch axis
    mu = x.mean(dim=0)
    sigma_sq = x.var(dim=0, unbiased=False)  # divide by N, not N - 1
    std = torch.sqrt(sigma_sq + eps)
    standardized = (x - mu) / std
    # Learned per-feature affine transform
    return standardized * gamma + beta
# PyTorch's built-in BatchNorm over 256 features
bn = nn.BatchNorm1d(num_features=256)
# Internally tracks running_mean and running_var for inference
# Layer Normalisation (used in Transformers)
# LayerNorm normalises across the FEATURE dimension, not the batch
# Mean and variance are computed per sample (per token for sequence input), not per feature
def layer_norm_manual(
    x: torch.Tensor,  # shape: (batch, seq_len, d_model)
    gamma: torch.Tensor,  # shape: (d_model,)
    beta: torch.Tensor,  # shape: (d_model,)
    eps: float = 1e-5,
) -> torch.Tensor:
    """Layer-normalise ``x`` over its last dimension, then apply ``gamma`` / ``beta``.

    Mean and biased variance are taken along dim -1 and kept as a size-1 axis
    so they broadcast back over ``x``; no statistic is shared across the batch.
    """
    centred = x - x.mean(dim=-1, keepdim=True)
    feature_var = x.var(dim=-1, keepdim=True, unbiased=False)  # population variance
    normalised = centred / torch.sqrt(feature_var + eps)
    return normalised * gamma + beta
# PyTorch's built-in LayerNorm over a trailing dimension of size 768
ln = nn.LayerNorm(normalized_shape=768)  # normalise 768-dim embeddings
# Weight Initialisation
# Good initialisation keeps activations and gradients at reasonable scale
# Xavier/Glorot: std = sqrt(2 / (fan_in + fan_out))
# He/Kaiming: std = sqrt(2 / fan_in) — for ReLU
def xavier_std(fan_in: int, fan_out: int, gain: float = 1.0) -> float:
    """Return the Xavier/Glorot initialisation standard deviation.

    std = gain * sqrt(2 / (fan_in + fan_out)). The default ``gain=1.0`` gives
    the plain Glorot value; pass an activation-specific gain (e.g. from
    ``torch.nn.init.calculate_gain``) to match ``nn.init.xavier_normal_``.

    Raises:
        ValueError: if ``fan_in + fan_out`` is not positive (the formula would
            otherwise divide by zero or silently return a complex number).
    """
    fan_sum = fan_in + fan_out
    if fan_sum <= 0:
        raise ValueError(f"fan_in + fan_out must be positive, got {fan_sum}")
    return gain * (2 / fan_sum) ** 0.5
def kaiming_std(fan_in: int, negative_slope: float = 0.0) -> float:
    """Return the He/Kaiming (fan_in mode) initialisation standard deviation.

    std = sqrt(2 / ((1 + negative_slope**2) * fan_in)). With the default
    ``negative_slope=0.0`` this is the ReLU value sqrt(2 / fan_in); a non-zero
    slope gives the leaky-ReLU variant used by
    ``nn.init.kaiming_normal_(..., a=negative_slope)``.

    Raises:
        ValueError: if ``fan_in`` is not positive (the formula would otherwise
            divide by zero or silently return a complex number).
    """
    if fan_in <= 0:
        raise ValueError(f"fan_in must be positive, got {fan_in}")
    return (2 / ((1 + negative_slope ** 2) * fan_in)) ** 0.5
# Linear 512 → 256  (fan_in=512, fan_out=256)
print(f"Xavier std: {xavier_std(512, 256):.4f}")  # 0.0510 = sqrt(2 / 768)
print(f"Kaiming std: {kaiming_std(512):.4f}")  # 0.0625 = sqrt(2 / 512)
layer = nn.Linear(512, 256)
nn.init.xavier_uniform_(layer.weight)  # uniform distribution scaled by Xavier
# Alternative scheme: this re-initialises the SAME weight in place, overwriting
# the Xavier init above — in practice pick one (Kaiming is the ReLU choice).
nn.init.kaiming_normal_(layer.weight)  # normal distribution scaled by He
# Adam Optimiser: Tracking Gradient Moments
# Adam maintains two statistics per parameter:
# m_t: 1st moment (mean of gradients) — direction
# v_t: 2nd moment (mean of squared gradients) — uncentred variance, i.e. gradient scale
class AdamManual:
    """Minimal re-implementation of the Adam optimiser, for illustration.

    Keeps two exponential moving averages per parameter: ``m`` (1st moment,
    mean of gradients) and ``v`` (uncentred 2nd moment, mean of squared
    gradients). Both are bias-corrected before each update.

    Args:
        params: Iterable of tensors to optimise (e.g. ``model.parameters()``).
            It is materialised into a list, so a one-shot generator is safe.
        lr: Step size.
        beta1: Decay rate of the 1st-moment EMA.
        beta2: Decay rate of the 2nd-moment EMA.
        eps: Added to the denominator for numerical stability.
    """

    def __init__(self, params, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        # list() matters: a generator such as model.parameters() would otherwise
        # be exhausted by the first comprehension, leaving v and params empty.
        self.params = list(params)
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.m = [torch.zeros_like(p) for p in self.params]  # 1st moment (mean)
        self.v = [torch.zeros_like(p) for p in self.params]  # 2nd moment (uncentred variance)
        self.t = 0  # step counter used for bias correction

    def step(self):
        """Apply one Adam update to every parameter that currently has a gradient."""
        self.t += 1
        # Bias-correction denominators depend only on t, so compute them once per step
        bias_correction1 = 1 - self.beta1 ** self.t
        bias_correction2 = 1 - self.beta2 ** self.t
        with torch.no_grad():
            for i, p in enumerate(self.params):
                g = p.grad
                if g is None:  # frozen/unused parameter: leave it and its moments untouched
                    continue
                self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * g  # EMA of gradient
                self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * g ** 2  # EMA of gradient²
                # Bias correction (m and v start at 0, so early EMAs are pulled toward 0)
                m_hat = self.m[i] / bias_correction1
                v_hat = self.v[i] / bias_correction2
                # Update: scale gradient by 1/√(2nd moment) (adaptive per-parameter step)
                p.sub_(self.lr * m_hat / (torch.sqrt(v_hat) + self.eps))
# Gradient Statistics for Training Monitoring
def log_gradient_stats(model: nn.Module, step: int) -> dict:
    """Summarise per-parameter gradient L2 norms to help spot vanishing/exploding gradients.

    Args:
        model: Module whose parameters' ``.grad`` tensors are inspected;
            parameters with ``grad is None`` are skipped.
        step: Training step, echoed back under the ``"step"`` key.

    Returns:
        ``{}`` if no parameter has a gradient. Otherwise a dict with ``step``,
        the mean / population std / max / min of the per-parameter norms, and
        ``grad_norm_total``: the L2 norm of all gradients taken together (the
        quantity ``torch.nn.utils.clip_grad_norm_`` compares to its threshold).
    """
    grad_norms = [
        param.grad.norm().item()
        for param in model.parameters()
        if param.grad is not None
    ]
    if not grad_norms:
        return {}
    return {
        "step": step,
        "grad_norm_mean": float(np.mean(grad_norms)),
        "grad_norm_std": float(np.std(grad_norms)),  # population std (ddof=0)
        "grad_norm_max": float(np.max(grad_norms)),
        "grad_norm_min": float(np.min(grad_norms)),
        # sqrt(sum of squared per-parameter norms) == norm of all grads concatenated
        "grad_norm_total": float(np.linalg.norm(grad_norms)),
    }
# Alert if the mean gradient norm is near 0 (vanishing) or the largest exceeds 100 (exploding)
stats = log_gradient_stats(model, step=1000)
# log_gradient_stats returns {} when no parameter has a gradient (e.g. before the
# first backward pass), so guard before indexing to avoid a KeyError.
if stats:
    if stats["grad_norm_mean"] < 1e-4:
        # Clipping only caps large gradients; it cannot help when they vanish.
        print("Vanishing gradients — check initialisation/normalisation or consider an LR increase")
    if stats["grad_norm_max"] > 100:
        print("Exploding gradients — check gradient clipping")
# Interview Answer
"Basic statistics permeate neural network training. Batch Normalisation computes the mean and variance of activations across the batch and normalises them, stabilising the training signal. Layer Normalisation does the same but across the feature dimension of each sample (each token, in a sequence model) — the standard in Transformers. The Adam optimiser maintains exponential moving averages of the gradient (1st moment, the mean) and of the squared gradient (uncentred 2nd moment, which tracks the gradient's scale) to adapt the step size per parameter. Weight initialisation schemes set initial weight standard deviations (Xavier: √(2/(fan_in + fan_out)); He/Kaiming: √(2/fan_in) for ReLU) to keep activations neither saturated nor vanishingly small. Monitoring the mean and standard deviation of gradient norms during training detects vanishing and exploding gradient problems early."
Found this helpful?
Leave a comment
Have a question, correction, or just found this helpful? Leave a note below.