Loss Functions
MSE, MAE, BCE, cross-entropy, focal loss — how each loss measures prediction error and which to use for regression, binary classification, and multi-class problems.
Loss = Feedback Signal for Training
The loss function measures how wrong the model's predictions are.
Its gradient tells the optimiser in which direction to update weights.
Choosing the wrong loss gives misleading gradients → model learns the wrong thing.
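A minimal sketch of that feedback loop, assuming a hypothetical one-parameter model y_hat = w·x trained with MSE. Autograd computes the gradient of the loss with respect to w, and one step against that gradient lowers the loss.

```python
import torch

# Hypothetical one-parameter model y_hat = w * x; the data follow w = 2
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([2.0, 4.0, 6.0])
w = torch.tensor(0.5, requires_grad=True)

loss = ((w * x - y) ** 2).mean()   # MSE
loss.backward()                    # d(loss)/dw is stored in w.grad
print(f"loss={loss.item():.3f}, dL/dw={w.grad.item():.3f}")

with torch.no_grad():
    w -= 0.1 * w.grad              # one gradient-descent step, learning rate 0.1
new_loss = ((w * x - y) ** 2).mean()
print(f"w={w.item():.3f}, new loss={new_loss.item():.3f}")
```

Here the gradient is -14 because the predictions are too small. The step therefore raises w from 0.5 to 1.9, and the loss falls from 10.5 to about 0.047.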
Key principle: match the loss to the output distribution assumption.
Gaussian outputs (regression) → MSE (MLE for Gaussian noise with fixed variance)
Bernoulli outputs (binary class) → BCE (MLE for Bernoulli)
Categorical outputs (multi-class) → Cross-entropy (MLE for Categorical)
Regression Losses
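The Gaussian row can be checked directly. If the noise variance is fixed at 1, nn.GaussianNLLLoss equals exactly half of nn.MSELoss, because by default it omits the constant log(2π) term. Both losses are therefore minimised by the same predictions. The sketch below assumes unit variance and uses the same example values as the regression code:

```python
import torch
import torch.nn as nn

y_pred = torch.tensor([2.5, 0.0, 2.0, 8.0])
y_true = torch.tensor([3.0, -0.5, 2.0, 7.0])
var = torch.ones_like(y_pred)      # assumed fixed noise variance sigma^2 = 1

# Default full=False: 0.5 * (log(var) + (y_pred - y_true)^2 / var), averaged
gauss_nll = nn.GaussianNLLLoss()(y_pred, y_true, var)
mse = nn.MSELoss()(y_pred, y_true)
print(gauss_nll.item(), 0.5 * mse.item())   # both 0.1875
```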
import torch
import torch.nn as nn
import numpy as np
# Simulated predictions and targets
y_pred = torch.tensor([2.5, 0.0, 2.0, 8.0])
y_true = torch.tensor([3.0, -0.5, 2.0, 7.0])
# ── MSE: Mean Squared Error ──
# L = (1/n) Σ (y_pred - y_true)²
# Penalises large errors heavily (squared)
# Sensitive to outliers
mse = nn.MSELoss()
print(f"MSE: {mse(y_pred, y_true):.4f}")
# ── MAE: Mean Absolute Error ──
# L = (1/n) Σ |y_pred - y_true|
# Robust to outliers (linear penalty)
# But: gradient magnitude is constant (undefined at 0), so it does not shrink near the minimum
mae = nn.L1Loss()
print(f"MAE: {mae(y_pred, y_true):.4f}")
# ── Huber Loss (same as Smooth L1 when δ = 1) ──
# Quadratic for small errors, linear for large errors
# Best of both: smooth near the minimum + robust to outliers
# δ sets the error magnitude where quadratic switches to linear
huber = nn.HuberLoss(delta=1.0)
print(f"Huber: {huber(y_pred, y_true):.4f}")
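The comments above describe outlier sensitivity and gradient behaviour, which are easiest to see on a concrete case. The sketch below reuses the same targets but pushes the last prediction to 17.0, giving a hypothetical outlier residual of 10. It then inspects each loss's gradient with respect to the predictions.

```python
import torch
import torch.nn as nn

y_true = torch.tensor([3.0, -0.5, 2.0, 7.0])
# Same targets as above; last prediction pushed to 17.0 (hypothetical outlier residual of 10)
y_pred = torch.tensor([2.5, 0.0, 2.0, 17.0], requires_grad=True)

results = {}
for name, loss_fn in [("MSE", nn.MSELoss()),
                      ("MAE", nn.L1Loss()),
                      ("Huber", nn.HuberLoss(delta=1.0))]:
    y_pred.grad = None                     # clear the gradient from the previous loss
    loss = loss_fn(y_pred, y_true)
    loss.backward()                        # gradient of the loss w.r.t. each prediction
    results[name] = (loss.item(), y_pred.grad.clone())
    print(f"{name:5s} loss={loss.item():7.4f}  grad={y_pred.grad.tolist()}")
```

With these numbers:

- MSE rises to 25.125, and its gradient on the outlier is 5.0, proportional to the residual.
- MAE and Huber both cap that gradient at 0.25 (= 1/n).
- On the small ±0.5 residuals, Huber's gradient shrinks to ±0.125, whereas MAE's stays at ±0.25.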
# When to use which:
# MSE: when outliers are meaningful (they represent important cases)
# MAE: when dataset has real outliers that shouldn't dominate
# Huber: general-purpose robust regression
# Clinical example: predicting INR value
# Outlier INR=9 may be a real dangerous case → use MSE, not MAE
Binary Classification Loss
# ── Binary Cross-Entropy (BCE) ──
# L = -[y·log(σ(z)) + (1-y)·log(1-σ(z))]
# where z is the raw logit, σ is sigmoid
# BCEWithLogitsLoss: takes raw logits (numerically stable)
bce_logits = nn.BCEWithLogitsLoss()
# BCELoss: takes probabilities (apply sigmoid first)
bce_prob = nn.BCELoss()
# Simulated: 4 patients, binary readmission label
logits = torch.tensor([1.5, -0.5, 2.0, -1.0]) # raw network outputs
labels = torch.tensor([1.0, 0.0, 1.0, 0.0]) # ground truth
loss1 = bce_logits(logits, labels)
probs = torch.sigmoid(logits)
loss2 = bce_prob(probs, labels)
print(f"BCEWithLogitsLoss: {loss1:.4f}")
print(f"BCELoss (sigmoid first): {loss2:.4f}") # should match
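The stability difference shows up on confidently wrong predictions. The sketch below uses a hypothetical logit of -200 for a positive label.

- In float32, sigmoid(-200) underflows to exactly 0.
- BCELoss can then only report its clamped value, because PyTorch documents that BCELoss clamps its log outputs at -100. The gradient reaching the logit also vanishes.
- BCEWithLogitsLoss returns the true loss of about 200 with a gradient of -1.

```python
import torch
import torch.nn as nn

# A confidently wrong prediction: large negative logit for a positive label
label = torch.tensor([1.0])

# Stable path: loss computed directly from the logit
z1 = torch.tensor([-200.0], requires_grad=True)
stable = nn.BCEWithLogitsLoss()(z1, label)
stable.backward()

# Unstable path: sigmoid first (underflows to 0.0 in float32), then BCELoss
z2 = torch.tensor([-200.0], requires_grad=True)
p = torch.sigmoid(z2)
unstable = nn.BCELoss()(p, label)
unstable.backward()

print(f"sigmoid(-200) in float32: {p.item()}")
print(f"BCEWithLogitsLoss: loss={stable.item():.1f}, grad={z1.grad.item():.3f}")
print(f"BCELoss:           loss={unstable.item():.1f}, grad={z2.grad.item():.3f}")
```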
# Prefer BCEWithLogitsLoss whenever the model outputs logits: it fuses sigmoid + log (log-sum-exp trick), so it is numerically more stable
Class Imbalance: Weighted BCE and Focal Loss
# ── Weighted BCE: penalise minority class errors more ──
# In clinical data: readmission rate ~10% → roughly 9 negatives per positive
pos_weight = torch.tensor([9.0])  # n_neg / n_pos: weight positive examples 9× more
weighted_bce = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
logits = torch.tensor([1.5, -0.5, 2.0, -1.0])
labels = torch.tensor([1.0, 0.0, 1.0, 0.0])
print(f"Weighted BCE: {weighted_bce(logits, labels):.4f}")
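The value 9 comes from pos_weight = n_neg / n_pos. The sketch below assumes a hypothetical batch with one positive per nine negatives. It also shows that pos_weight multiplies only the positive-label term of the loss, which a manual computation with F.logsigmoid confirms.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical batch: 1 positive, 9 negatives (10% positive rate)
labels = torch.tensor([1.0] + [0.0] * 9)
n_pos = labels.sum()
n_neg = labels.numel() - n_pos
pos_weight = (n_neg / n_pos).reshape(1)   # tensor([9.])

logits = torch.linspace(-2.0, 2.0, steps=10)
builtin = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(logits, labels)

# Manual: log σ(z) = logsigmoid(z) and log(1 - σ(z)) = logsigmoid(-z)
manual = -(pos_weight * labels * F.logsigmoid(logits)
           + (1 - labels) * F.logsigmoid(-logits)).mean()
print(pos_weight, builtin.item(), manual.item())
```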
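# ── Focal Loss: down-weight easy examples ──
# FL(p_t) = -α_t · (1 - p_t)^γ · log(p_t)
# γ (gamma) controls focus: γ=0 gives (α-weighted) BCE, γ=2 is the standard choice
# Well-classified examples (p_t near 1) are down-weighted, so the many easy negatives
# contribute less and training focuses on hard examples (often the rare positives)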
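The modulating factor (1 - p_t)^γ can be examined on its own before the full implementation. The helper focal_terms below is hypothetical and omits α. It is applied to two positive examples:

- an easy one (logit 3, p_t ≈ 0.95)
- a hard one (logit -1, p_t ≈ 0.27)

At γ=0 the terms equal plain BCE. At γ=2 the easy example keeps about 0.2% of its BCE, while the hard one keeps about 53%.

```python
import torch
import torch.nn.functional as F

def focal_terms(logits, labels, gamma):
    """Hypothetical helper: per-example (1 - p_t)^gamma * BCE, without alpha."""
    bce = F.binary_cross_entropy_with_logits(logits, labels, reduction="none")
    p = torch.sigmoid(logits)
    p_t = p * labels + (1 - p) * (1 - labels)   # probability of the true class
    return (1 - p_t) ** gamma * bce

logits = torch.tensor([3.0, -1.0])   # easy positive, hard positive
labels = torch.tensor([1.0, 1.0])
bce = F.binary_cross_entropy_with_logits(logits, labels, reduction="none")

for gamma in (0.0, 1.0, 2.0):
    ratio = focal_terms(logits, labels, gamma) / bce   # fraction of BCE kept
    print(f"γ={gamma}: easy keeps {ratio[0]:.4f}, hard keeps {ratio[1]:.4f}")
```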
import torch.nn.functional as F
def focal_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    gamma: float = 2.0,
    alpha: float = 0.25,  # weight for the positive class (1 - alpha for negatives)
) -> torch.Tensor:
    p = torch.sigmoid(logits)
    p_t = p * labels + (1 - p) * (1 - labels)  # probability assigned to the true class
    alpha_t = alpha * labels + (1 - alpha) * (1 - labels)
    # -log(p_t) computed from logits: numerically stable, no epsilon needed
    ce_loss = F.binary_cross_entropy_with_logits(logits, labels, reduction="none")
    focal_weight = (1 - p_t) ** gamma
    loss = alpha_t * focal_weight * ce_loss
    return loss.mean()
fl = focal_loss(logits, labels, gamma=2.0, alpha=0.75)  # α=0.75 up-weights the rare positive class (tune per dataset)
print(f"Focal loss (γ=2): {fl:.4f}")
Multi-Class Loss
# ── Cross-Entropy Loss ──
# L = -Σ_c y_c · log(softmax(z)_c)
# For a 5-class problem (e.g., clinical severity levels)
criterion = nn.CrossEntropyLoss()
# Takes raw logits: do NOT apply softmax first (log-softmax is applied internally)
logits = torch.randn(16, 5) # (batch=16, n_classes=5)
labels = torch.randint(0, 5, (16,)) # class indices, not one-hot
loss = criterion(logits, labels)
print(f"Cross-entropy loss: {loss:.4f}")
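The computation inside nn.CrossEntropyLoss can be reproduced by hand:

1. Take the log-softmax over the class dimension.
2. Pick the entry for each row's true class.
3. Negate and average.

The sketch below uses a fixed seed for reproducibility. PyTorch 1.10+ also accepts class-probability targets, so a one-hot target gives the same value.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(16, 5)                  # (batch, n_classes)
labels = torch.randint(0, 5, (16,))          # class indices

builtin = nn.CrossEntropyLoss()(logits, labels)

# Manual: log-softmax, pick each row's true-class entry, negate, average
log_probs = F.log_softmax(logits, dim=-1)
manual = -log_probs[torch.arange(16), labels].mean()

# Class-probability targets (here one-hot) are accepted in PyTorch 1.10+
one_hot = F.one_hot(labels, num_classes=5).float()
from_probs = nn.CrossEntropyLoss()(logits, one_hot)
print(builtin.item(), manual.item(), from_probs.item())
```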
# ── Label Smoothing ──
# Instead of hard target [0, 0, 1, 0, 0], use [ε/K, ε/K, 1-ε+ε/K, ε/K, ε/K]
# Discourages overconfidence and often improves calibration
smooth_criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
smooth_loss = smooth_criterion(logits, labels)
print(f"Label-smoothed CE loss: {smooth_loss:.4f}")
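The smoothed target from the label-smoothing comment can be built explicitly and passed through an ordinary soft-target cross-entropy. The sketch below assumes K=5 and ε=0.1, as in the example, and the result matches label_smoothing=0.1.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
K, eps = 5, 0.1
logits = torch.randn(16, K)
labels = torch.randint(0, K, (16,))

# Smoothed target: ε/K on every class, plus 1 - ε on the true class
smoothed = torch.full((16, K), eps / K)
smoothed[torch.arange(16), labels] += 1 - eps
print(smoothed[0])   # 0.02 everywhere except 0.92 at the true class

# Soft-target cross-entropy: -Σ_c target_c · log softmax(z)_c, averaged over the batch
manual = -(smoothed * F.log_softmax(logits, dim=-1)).sum(dim=-1).mean()
builtin = nn.CrossEntropyLoss(label_smoothing=eps)(logits, labels)
print(manual.item(), builtin.item())
```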
# ── NLL Loss (when using log_softmax) ──
log_probs = torch.log_softmax(logits, dim=-1)
nll = nn.NLLLoss()(log_probs, labels)  # equivalent to CrossEntropyLoss on the raw logits
print(torch.allclose(nll, loss))  # True
Loss for Clinical Calibration
import numpy as np
import torch
# Calibration: among patients scored 0.8, are ~80% actually readmitted?
# Poor calibration → clinical decisions are based on misleading probabilities
def expected_calibration_error(
    probs: torch.Tensor,
    labels: torch.Tensor,
    n_bins: int = 10,
) -> float:
    """Binary ECE: bin-size-weighted mean |observed positive rate - mean predicted probability|.
    Lower is better (0 = perfect calibration)."""
    probs_np = probs.detach().cpu().numpy()
    labels_np = labels.detach().cpu().numpy()
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    n = len(probs_np)
    for i in range(n_bins):
        # Bins are [lo, hi); the last bin is closed so that p = 1.0 is counted
        upper_ok = probs_np <= bins[i + 1] if i == n_bins - 1 else probs_np < bins[i + 1]
        mask = (probs_np >= bins[i]) & upper_ok
        if mask.sum() == 0:
            continue
        bin_conf = probs_np[mask].mean()  # mean predicted probability in the bin
        bin_rate = labels_np[mask].mean()  # observed positive rate in the bin
        bin_size = mask.sum() / n  # fraction of samples in the bin
        ece += bin_size * abs(bin_rate - bin_conf)
    return float(ece)
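The per-bin gaps that ECE averages are the same quantities a reliability diagram plots. The numpy sketch below uses hypothetical data:

- true event probabilities q, with outcomes drawn from q
- a calibrated model that reports q
- an overconfident model that reports sigmoid(2·logit(q)), i.e. doubled log-odds

The helper names reliability_table and ece_from_table are illustrative, not library functions. In the overconfident model's table, the bins near 0 and 1 show confidence more extreme than the observed rate.

```python
import numpy as np

def reliability_table(probs, labels, n_bins=10):
    """Illustrative helper: (mean confidence, observed positive rate, count) per non-empty bin."""
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    idx = np.digitize(probs, edges[1:-1])   # bin index 0..n_bins-1; p = 1.0 lands in the last bin
    rows = []
    for b in range(n_bins):
        mask = idx == b
        if mask.any():
            rows.append((probs[mask].mean(), labels[mask].mean(), int(mask.sum())))
    return rows

def ece_from_table(rows):
    """Illustrative helper: bin-size-weighted mean |observed rate - confidence|."""
    n = sum(count for _, _, count in rows)
    return sum(count / n * abs(rate - conf) for conf, rate, count in rows)

rng = np.random.default_rng(0)
q = rng.uniform(0.01, 0.99, size=20_000)                      # hypothetical true event probabilities
labels = rng.binomial(1, q).astype(float)                     # outcomes drawn from q
overconfident = 1 / (1 + np.exp(-2 * np.log(q / (1 - q))))   # doubled log-odds

eces = {}
for name, p in [("calibrated", q), ("overconfident", overconfident)]:
    rows = reliability_table(p, labels)
    eces[name] = ece_from_table(rows)
    print(f"{name}: ECE = {eces[name]:.3f}")
    for conf, rate, count in rows:
        print(f"  confidence {conf:.2f} -> observed {rate:.2f}  (n={count})")
```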
# Simulate a calibrated model: labels drawn as Bernoulli(probs), so ECE reflects only sampling noise
torch.manual_seed(0)
probs = torch.rand(5000)
labels = torch.bernoulli(probs)
ece = expected_calibration_error(probs, labels)
print(f"ECE: {ece:.4f}")  # lower is better; e.g. target < 0.05 for clinical use
Choosing the Right Loss
Task | Loss | PyTorch class
--------------------------------|------------------------|--------------------------------
Regression (normal noise) | MSE | nn.MSELoss()
Regression (robust to outliers) | Huber | nn.HuberLoss()
Regression (conditional median) | MAE                    | nn.L1Loss()
Binary classification           | BCE with logits        | nn.BCEWithLogitsLoss()
Binary + class imbalance        | Weighted BCE           | nn.BCEWithLogitsLoss(pos_weight=w)
Binary + extreme imbalance      | Focal loss             | custom, or torchvision.ops.sigmoid_focal_loss
Multi-class classification      | Cross-entropy          | nn.CrossEntropyLoss()
Multi-class + overconfidence    | Label-smoothed CE      | nn.CrossEntropyLoss(label_smoothing=ε)
Multi-label classification      | BCE per label          | nn.BCEWithLogitsLoss()
Interview Answer
"Loss function choice maps to distributional assumptions: MSE for Gaussian regression, BCE for Bernoulli binary classification, cross-entropy for categorical multi-class. Use BCEWithLogitsLoss (raw logits) rather than BCELoss (probabilities), because it is numerically more stable. For class imbalance (common in clinical data: 5–15% positive readmission rates), either weight the positive class via pos_weight or use focal loss, which down-weights easy, well-classified examples (mostly the abundant negatives). Label smoothing (ε=0.1) discourages overconfidence in multi-class tasks and often improves calibration, which matters in clinical AI where probability outputs drive decision thresholds. Always check that predicted probabilities are calibrated (ECE or reliability diagrams) before clinical deployment, since a model with high AUC can still have severely miscalibrated probabilities."