Learnixo
AI Systems · beginner

Statistics Inside AI Models

Where descriptive statistics appear inside neural networks and training pipelines: from batch normalisation to loss monitoring to gradient statistics.

Asma Hafeez Khan · May 21, 2026 · 4 min read
Statistics · Neural Networks · Batch Norm · Training · Interview

Statistics Are Everywhere in ML

Basic descriptive statistics — mean, variance, standard deviation — appear constantly inside neural networks:

Batch Normalisation:   normalises activations using batch mean and variance
Layer Normalisation:   normalises using per-sample mean and variance
Weight Initialisation: sets initial weights with specific mean and std
Loss Monitoring:       tracks mean and std of loss across steps
Gradient Clipping:     clips gradients whose norm exceeds a threshold
Optimiser Moments:     Adam tracks EMAs of the gradient (1st moment, mean) and of
                       the squared gradient (2nd moment, uncentred variance)
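Of the items above, loss monitoring is the only one without code later in this post. The sketch below is one minimal way to do it, assuming a rolling window and a simple z-score rule; `LossMonitor`, `window` and `z_threshold` are hypothetical names chosen for illustration, not part of any library. It keeps the mean and standard deviation of recent losses and flags a new loss that sits far above them.

```python
import statistics
from collections import deque


class LossMonitor:
    """Hypothetical helper: keeps a rolling window of recent losses and flags spikes."""

    def __init__(self, window: int = 100, z_threshold: float = 3.0):
        self.losses = deque(maxlen=window)  # oldest losses drop off automatically
        self.z_threshold = z_threshold

    def update(self, loss: float) -> bool:
        """Record a loss; return True if it is more than z_threshold stds above the window mean."""
        is_spike = False
        if len(self.losses) >= 2:
            recent = list(self.losses)
            mean = statistics.fmean(recent)
            std = statistics.pstdev(recent)  # population std of the window
            is_spike = std > 0 and (loss - mean) / std > self.z_threshold
        self.losses.append(loss)
        return is_spike


monitor = LossMonitor(window=50)
smooth = [1.0 / (1 + 0.1 * step) for step in range(60)]  # a steadily decreasing loss curve
flags = [monitor.update(loss) for loss in smooth]
print(any(flags))            # False
print(monitor.update(10.0))  # True: far above the recent mean
```

A z-score against a rolling window adapts as the loss falls, so the same threshold stays meaningful early and late in training; the window size and threshold are tuning choices, not fixed rules.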

Batch Normalisation

Python
import torch
import torch.nn as nn
import numpy as np

# Manual BN forward pass in training mode (conceptual; at inference BN uses running statistics)
def batch_norm_manual(
    x: torch.Tensor,          # shape: (batch, features)
    gamma: torch.Tensor,       # learnable scale
    beta: torch.Tensor,        # learnable shift
    eps: float = 1e-5,
) -> torch.Tensor:
    # Compute batch statistics
    batch_mean = x.mean(dim=0)             # mean over batch dimension
    batch_var  = x.var(dim=0, unbiased=False)  # variance (population, not sample)
    
    # Normalise
    x_hat = (x - batch_mean) / torch.sqrt(batch_var + eps)
    
    # Scale and shift (learned)
    return gamma * x_hat + beta

# PyTorch BatchNorm
bn = nn.BatchNorm1d(256)           # 256 features
# Internally tracks running_mean and running_var for inference
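To check that the manual version really matches PyTorch, the sketch below (assuming PyTorch 2.x and random toy inputs) compares it against `nn.BatchNorm1d` in training mode. It also checks the running statistics, using the default `momentum=0.1`. One subtlety from the PyTorch docs: training-mode normalisation uses the biased batch variance, but `running_var` is updated with the unbiased one.

```python
import torch
import torch.nn as nn


def batch_norm_manual(x, gamma, beta, eps=1e-5):
    # Same computation as above: biased batch variance, then scale and shift
    batch_mean = x.mean(dim=0)
    batch_var = x.var(dim=0, unbiased=False)
    return gamma * (x - batch_mean) / torch.sqrt(batch_var + eps) + beta


torch.manual_seed(0)
x = torch.randn(32, 256) * 3 + 5   # batch of 32, deliberately not zero-mean/unit-variance

bn = nn.BatchNorm1d(256)           # weight (gamma) = 1, bias (beta) = 0 at init
bn.train()                         # training mode: normalise with batch statistics
out_ref = bn(x)
out_manual = batch_norm_manual(x, bn.weight, bn.bias, eps=bn.eps)
print("matches nn.BatchNorm1d:", torch.allclose(out_ref, out_manual, atol=1e-5))

# Running statistics after one step: running = 0.9 * running + 0.1 * batch_stat,
# starting from running_mean = 0 and running_var = 1 (unbiased variance stored).
expected_mean = 0.9 * torch.zeros(256) + 0.1 * x.mean(dim=0)
expected_var = 0.9 * torch.ones(256) + 0.1 * x.var(dim=0, unbiased=True)
print("running_mean ok:", torch.allclose(bn.running_mean, expected_mean, atol=1e-5))
print("running_var ok:", torch.allclose(bn.running_var, expected_var, atol=1e-4))

# Inference mode: the same input is now normalised with the running statistics
bn.eval()
out_eval = bn(x)
manual_eval = (x - bn.running_mean) / torch.sqrt(bn.running_var + bn.eps) * bn.weight + bn.bias
print("eval uses running stats:", torch.allclose(out_eval, manual_eval, atol=1e-4))
```

This train/eval split is why forgetting `model.eval()` at inference time changes predictions: the layer keeps using per-batch statistics instead of the running ones.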

Layer Normalisation (used in Transformers)

Python
# LayerNorm normalises across the FEATURE dimension, not the batch
# Mean and variance computed per sample (per token for sequences), not per feature

def layer_norm_manual(
    x: torch.Tensor,          # shape: (batch, seq_len, d_model)
    gamma: torch.Tensor,       # shape: (d_model,)
    beta: torch.Tensor,        # shape: (d_model,)
    eps: float = 1e-5,
) -> torch.Tensor:
    # Compute statistics across last dimension (features)
    mean = x.mean(dim=-1, keepdim=True)
    var  = x.var(dim=-1, keepdim=True, unbiased=False)
    
    x_hat = (x - mean) / torch.sqrt(var + eps)
    return gamma * x_hat + beta

ln = nn.LayerNorm(768)  # normalise 768-dim embeddings
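A small check, assuming PyTorch 2.x and random toy activations of shape `(batch, seq_len, d_model)`. It confirms that the manual version matches `nn.LayerNorm`, that every token ends up with mean ≈ 0 and variance ≈ 1 across its features, and that (unlike BatchNorm) one sample's output does not depend on the rest of the batch.

```python
import torch
import torch.nn as nn


def layer_norm_manual(x, gamma, beta, eps=1e-5):
    # Statistics over the last (feature) dimension: one mean/variance pair per token
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    return gamma * (x - mean) / torch.sqrt(var + eps) + beta


torch.manual_seed(0)
x = torch.randn(4, 10, 768) * 2 + 1   # (batch, seq_len, d_model)

ln = nn.LayerNorm(768)                # weight = 1, bias = 0 at init
out_ref = ln(x)
out_manual = layer_norm_manual(x, ln.weight, ln.bias, eps=ln.eps)
print("matches nn.LayerNorm:", torch.allclose(out_ref, out_manual, atol=1e-5))

# Every token now has mean ≈ 0 and variance ≈ 1 across its 768 features
token_means = out_manual.mean(dim=-1)
token_vars = out_manual.var(dim=-1, unbiased=False)
print("max |token mean|:", token_means.abs().max().item())
print("max |token var - 1|:", (token_vars - 1).abs().max().item())

# Normalising the first sample on its own gives the same result as inside the batch
alone = layer_norm_manual(x[:1], ln.weight, ln.bias, eps=ln.eps)
print("batch-independent:", torch.allclose(alone, out_manual[:1]))
```

Batch independence is a large part of why Transformers favour LayerNorm: behaviour is identical at batch size 1, during autoregressive decoding, and with variable-length padded batches.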

Weight Initialisation

Python
# Good initialisation keeps activations and gradients at reasonable scale
# Xavier/Glorot: std = sqrt(2 / (fan_in + fan_out))
# He/Kaiming:    std = sqrt(2 / fan_in)  for ReLU

def xavier_std(fan_in: int, fan_out: int) -> float:
    return (2 / (fan_in + fan_out)) ** 0.5

def kaiming_std(fan_in: int) -> float:
    return (2 / fan_in) ** 0.5

# Linear 512 → 256
print(f"Xavier std: {xavier_std(512, 256):.4f}")  # 0.0510
print(f"Kaiming std: {kaiming_std(512):.4f}")     # 0.0625

layer = nn.Linear(512, 256)
# Pick ONE initialiser: each call overwrites layer.weight in place
nn.init.xavier_uniform_(layer.weight)                       # uniform distribution with Xavier std
nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")  # normal distribution with He std (for ReLU)
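The sketch below (assuming PyTorch 2.x) checks both claims empirically. First, the initialisers really produce the standard deviations from the formulas. Second, the std matters: activations pushed through a deep stack of ReLU layers keep a stable scale with He std but collapse toward zero with a too-small std. `activation_std_after` is a hypothetical helper, and the toy stack omits biases for simplicity.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
fan_in, fan_out = 512, 256
xavier_expected = (2 / (fan_in + fan_out)) ** 0.5   # ≈ 0.0510
kaiming_expected = (2 / fan_in) ** 0.5              # 0.0625

layer = nn.Linear(fan_in, fan_out)                  # weight shape: (256, 512)
nn.init.xavier_uniform_(layer.weight)
xavier_empirical = layer.weight.std().item()
nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
kaiming_empirical = layer.weight.std().item()
print(f"Xavier  expected {xavier_expected:.4f}, empirical {xavier_empirical:.4f}")
print(f"Kaiming expected {kaiming_expected:.4f}, empirical {kaiming_empirical:.4f}")


def activation_std_after(depth: int, weight_std: float, width: int = 512) -> float:
    """Push random inputs through `depth` bias-free Linear+ReLU layers; return the output std."""
    x = torch.randn(256, width)
    for _ in range(depth):
        w = torch.randn(width, width) * weight_std
        x = torch.relu(x @ w)
    return x.std().item()


print("He std, 20 layers:  ", activation_std_after(20, kaiming_expected))  # stays O(1)
print("std=0.01, 20 layers:", activation_std_after(20, 0.01))              # collapses toward 0
```

With He std each layer roughly preserves the second moment of the activations: the variance doubles in the matmul (512 × 2/512 = 2), and ReLU then halves it. With std 0.01 each layer multiplies the second moment by about 512 × 0.0001 / 2 ≈ 0.026, so twenty layers shrink the signal by many orders of magnitude.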

Adam Optimiser: Tracking Gradient Moments

Python
# Adam maintains two statistics per parameter:
# m_t: 1st moment (EMA of gradients) → direction
# v_t: 2nd moment (EMA of squared gradients, an uncentred variance) → scale

class AdamManual:
    def __init__(self, params, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        self.params = list(params)  # materialise so generators like model.parameters() work
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.m = [torch.zeros_like(p) for p in self.params]  # 1st moment (mean)
        self.v = [torch.zeros_like(p) for p in self.params]  # 2nd moment (uncentred variance)
        self.t = 0

    @torch.no_grad()  # parameter updates must not be tracked by autograd
    def step(self):
        self.t += 1
        for i, p in enumerate(self.params):
            if p.grad is None:
                continue  # parameter received no gradient this step
            g = p.grad
            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * g      # EMA of gradient
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * g ** 2  # EMA of gradient²

            # Bias correction (m and v start at 0, so early EMAs are biased toward 0)
            m_hat = self.m[i] / (1 - self.beta1 ** self.t)
            v_hat = self.v[i] / (1 - self.beta2 ** self.t)

            # Update: scale gradient by 1/√v̂ (per-parameter adaptive learning rate)
            p -= self.lr * m_hat / (torch.sqrt(v_hat) + self.eps)
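What "adaptive" buys you shows up clearly on a toy problem. The sketch below assumes PyTorch 2.x and a hypothetical linear loss whose gradient is constant: 100.0 for one parameter and 0.01 for the other. With a constant gradient g, bias correction gives m̂ = g and v̂ = g² exactly, so every Adam step is lr · g / (|g| + ε) ≈ lr · sign(g). Both parameters therefore move by about the same amount. Plain SGD, by contrast, moves each one in proportion to its gradient.

```python
import torch

# Hypothetical toy problem: constant gradients differing by a factor of 10,000
grad_scale = torch.tensor([100.0, 0.01])


def displacement(optimiser_cls, steps: int = 5, **kwargs) -> torch.Tensor:
    """Run `steps` optimiser updates from p = [1, 1]; return how far each entry moved."""
    p = torch.nn.Parameter(torch.ones(2))
    opt = optimiser_cls([p], **kwargs)
    for _ in range(steps):
        opt.zero_grad()
        loss = (grad_scale * p).sum()  # d(loss)/dp == grad_scale
        loss.backward()
        opt.step()
    return torch.ones(2) - p.detach()


adam_move = displacement(torch.optim.Adam, lr=1e-3)
sgd_move = displacement(torch.optim.SGD, lr=1e-3)
print("Adam moved:", adam_move)  # both entries ≈ 5 * lr = 0.005
print("SGD moved: ", sgd_move)   # proportional to the gradient: ≈ 0.5 and 0.00005
```

Dividing by √v̂ normalises away each parameter's gradient scale, which is why a single learning rate works reasonably well across layers whose gradients differ by orders of magnitude.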

Gradient Statistics for Training Monitoring

Python
def log_gradient_stats(model: nn.Module, step: int) -> dict:
    """Compute mean and std of gradient magnitudes — detects vanishing/exploding."""
    grad_norms = []
    for name, param in model.named_parameters():
        if param.grad is not None:
            grad_norm = param.grad.norm().item()
            grad_norms.append(grad_norm)
    
    if not grad_norms:
        return {}
    
    return {
        "step": step,
        "grad_norm_mean": float(np.mean(grad_norms)),
        "grad_norm_std": float(np.std(grad_norms)),
        "grad_norm_max": float(np.max(grad_norms)),
    }

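# Heuristic thresholds (tune per model): a mean norm near 0 suggests vanishing
# gradients, a very large max norm suggests exploding ones.
# `model` is your nn.Module, called after loss.backward().
stats = log_gradient_stats(model, step=1000)
if stats and stats["grad_norm_mean"] < 1e-4:
    print("Vanishing gradients: check initialisation, activations, or normalisation layers")
if stats and stats["grad_norm_max"] > 100:
    print("Exploding gradients: apply gradient clipping or lower the learning rate")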
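Gradient clipping, mentioned at the top of this post, acts on the same statistic. The sketch below assumes PyTorch 2.x, a hypothetical tiny MLP, and random data, with the loss deliberately inflated so the gradients are large. `torch.nn.utils.clip_grad_norm_` computes one global L2 norm, the square root of the sum of the squared per-parameter norms that `log_gradient_stats` collects. When that global norm exceeds `max_norm`, it rescales every gradient by the same factor and returns the pre-clipping norm.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
x, y = torch.randn(64, 16), torch.randn(64, 1)

# Hypothetical: scale the loss up so the gradients are deliberately large
loss = nn.functional.mse_loss(model(x), y) * 1e4
loss.backward()

# Per-parameter gradient norms, combined into one global L2 norm
per_param = torch.stack([p.grad.norm() for p in model.parameters()])
global_norm = per_param.pow(2).sum().sqrt()

# Rescales all gradients in place if the global norm exceeds max_norm; returns the pre-clip norm
returned = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
after = torch.stack([p.grad.norm() for p in model.parameters()]).pow(2).sum().sqrt()

print(f"global norm before: {global_norm.item():.1f} (clip_grad_norm_ returned {returned.item():.1f})")
print(f"global norm after:  {after.item():.4f}")  # at most max_norm
```

Because every gradient is scaled by the same factor, clipping caps the step size without changing the update direction. Logging the returned norm is a cheap way to see how often clipping actually fires.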

Interview Answer

"Basic statistics permeate neural network training. Batch Normalisation computes the mean and variance of each feature across the batch and normalises the activations with them, which keeps layer inputs on a stable scale during training. Layer Normalisation does the same across the feature dimension of each sample, and is the standard choice in Transformers. The Adam optimiser maintains exponential moving averages of the gradient (the mean) and of the squared gradient (an uncentred variance) to adapt the learning rate per parameter. Weight initialisation schemes set the initial weight standard deviation (Xavier: √(2/(fan_in + fan_out)); He: √(2/fan_in) for ReLU) to keep activations neither saturated nor vanishingly small. Monitoring the mean and standard deviation of gradient norms during training detects vanishing and exploding gradients early."
