Learnixo
AI Systems · intermediate

Loss Functions

MSE, MAE, BCE, cross-entropy, focal loss — how each loss measures prediction error and which to use for regression, binary classification, and multi-class problems.

Asma Hafeez Khan · May 22, 2026 · 6 min read
Tags: Deep Learning, Loss Functions, Cross-Entropy, MSE, Focal Loss, Interview

Loss = Feedback Signal for Training

The loss function measures how wrong the model's predictions are.
Its gradient tells the optimiser in which direction to update weights.

Choosing the wrong loss gives misleading gradients → model learns the wrong thing.
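
To make the feedback loop concrete, here is a minimal sketch in plain Python (no PyTorch) of gradient descent on a squared-error loss. The one-weight model `w * x`, the single data point, and the learning rate are illustrative assumptions, not values from the article.

```python
# Toy setup (assumed for illustration): one weight, one training example.
def loss(w: float, x: float, y: float) -> float:
    """Squared error of the prediction w * x against the target y."""
    return (w * x - y) ** 2

def grad(w: float, x: float, y: float) -> float:
    """Derivative of the loss with respect to w."""
    return 2.0 * x * (w * x - y)

x, y = 2.0, 4.0      # the weight that fits this example exactly is w = 2
w, lr = 0.0, 0.1     # initial weight and learning rate
history = []
for _ in range(5):
    history.append(loss(w, x, y))
    w -= lr * grad(w, x, y)   # step against the gradient

print(history)  # the loss shrinks at every step
print(w)        # approaches 2.0
```

The sign of the gradient decides the direction of the update and its magnitude decides the step size, which is why a loss with misleading gradients steers the weights toward the wrong solution.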

Key principle: match the loss to the output distribution assumption.
  Gaussian outputs (regression) → MSE (MLE for Gaussian)
  Bernoulli outputs (binary class) → BCE (MLE for Bernoulli)
  Categorical outputs (multi-class) → Cross-entropy (MLE for Categorical)
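
The "MLE for Gaussian" claim can be checked numerically. The sketch below, in plain Python, assumes a fixed noise level `sigma = 1` and two made-up candidate predictions (`pred_a`, `pred_b`). It shows that the gap in Gaussian negative log-likelihood between the candidates equals half the gap in summed squared error, so minimising one minimises the other.

```python
import math

def gaussian_nll(y_true, y_pred, sigma=1.0):
    """Negative log-likelihood of the targets under N(y_pred, sigma^2), summed over samples."""
    return sum(
        0.5 * math.log(2 * math.pi * sigma**2) + (t - p) ** 2 / (2 * sigma**2)
        for t, p in zip(y_true, y_pred)
    )

def sse(y_true, y_pred):
    """Sum of squared errors."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred))

y_true = [3.0, -0.5, 2.0, 7.0]
pred_a = [2.5, 0.0, 2.0, 8.0]   # illustrative candidate predictions
pred_b = [3.0, 0.0, 1.0, 6.0]

# The log(2*pi*sigma^2) term does not depend on the predictions, so it cancels.
nll_gap = gaussian_nll(y_true, pred_a) - gaussian_nll(y_true, pred_b)
sse_gap = 0.5 * (sse(y_true, pred_a) - sse(y_true, pred_b))
print(nll_gap, sse_gap)  # the two gaps agree up to floating-point error
```

The same argument applies to BCE and cross-entropy: each is the negative log-likelihood of the labels under a Bernoulli or categorical model.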

Regression Losses

Python
import torch
import torch.nn as nn
import numpy as np

# Simulated predictions and targets
y_pred = torch.tensor([2.5, 0.0, 2.0, 8.0])
y_true = torch.tensor([3.0, -0.5, 2.0, 7.0])

# ── MSE: Mean Squared Error ──
# L = (1/n) Σ (y_pred - y_true)²
# Penalises large errors heavily (squared)
# Sensitive to outliers
mse = nn.MSELoss()
print(f"MSE: {mse(y_pred, y_true):.4f}")

# ── MAE: Mean Absolute Error ──
# L = (1/n) Σ |y_pred - y_true|
# Robust to outliers (linear penalty)
# But: gradient magnitude is constant, so updates do not shrink near the minimum
mae = nn.L1Loss()
print(f"MAE: {mae(y_pred, y_true):.4f}")

# ── Huber Loss (equals Smooth L1 when δ = 1) ──
# Quadratic for small errors, linear for large errors
# Best of both: smooth near min + robust to outliers
# δ controls transition point
huber = nn.HuberLoss(delta=1.0)
print(f"Huber: {huber(y_pred, y_true):.4f}")

# When to use which:
# MSE: when outliers are meaningful (they represent important cases)
# MAE: when dataset has real outliers that shouldn't dominate
# Huber: general-purpose robust regression

# Clinical example: predicting INR value
# Outlier INR=9 may be a real dangerous case → use MSE, not MAE
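
As a cross-check on the outlier discussion, the three losses can be computed by hand in plain Python on the same numbers used above. The helper names and the injected error of 10.0 are assumptions made for illustration. The Huber helper follows the usual piecewise definition with mean reduction.

```python
def mse(errors):
    return sum(e * e for e in errors) / len(errors)

def mae(errors):
    return sum(abs(e) for e in errors) / len(errors)

def huber(errors, delta=1.0):
    """Quadratic for |e| < delta, linear beyond it (mean reduction)."""
    total = 0.0
    for e in errors:
        a = abs(e)
        total += 0.5 * e * e if a < delta else delta * (a - 0.5 * delta)
    return total / len(errors)

y_pred = [2.5, 0.0, 2.0, 8.0]
y_true = [3.0, -0.5, 2.0, 7.0]
errors = [p - t for p, t in zip(y_pred, y_true)]   # [-0.5, 0.5, 0.0, 1.0]

clean = (mse(errors), mae(errors), huber(errors))
print(clean)   # (0.375, 0.5, 0.1875)

# Hypothetical outlier: replace the last error with 10.0
outlier_errors = errors[:3] + [10.0]
dirty = (mse(outlier_errors), mae(outlier_errors), huber(outlier_errors))
print(dirty)   # (25.125, 2.75, 2.4375)
```

One large error multiplies MSE by 67 but MAE by only 5.5. Huber, like MAE, grows linearly in the size of the outlier rather than quadratically.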

Binary Classification Loss

Python
# ── Binary Cross-Entropy (BCE) ──
# L = -[y·log(σ(z)) + (1-y)·log(1-σ(z))]
# where z is the raw logit, σ is sigmoid

# BCEWithLogitsLoss: takes raw logits (numerically stable)
bce_logits = nn.BCEWithLogitsLoss()

# BCELoss: takes probabilities (apply sigmoid first)
bce_prob = nn.BCELoss()

# Simulated: 4 patients, binary readmission label
logits = torch.tensor([1.5, -0.5, 2.0, -1.0])     # raw network outputs
labels = torch.tensor([1.0, 0.0, 1.0, 0.0])         # ground truth

loss1 = bce_logits(logits, labels)
probs = torch.sigmoid(logits)
loss2 = bce_prob(probs, labels)

print(f"BCEWithLogitsLoss: {loss1:.4f}")
print(f"BCELoss (sigmoid first): {loss2:.4f}")  # should match

# Prefer BCEWithLogitsLoss: it works on raw logits and is numerically more stable
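
The stability difference can be illustrated without PyTorch. The sketch below compares a naive sigmoid-then-log formula with the standard log-sum-exp rearrangement `max(z, 0) - z*y + log(1 + exp(-|z|))`. The helper names `bce_naive` and `bce_stable` are hypothetical, and this is a sketch of the general idea, not a claim about PyTorch's exact internals.

```python
import math

def bce_naive(z: float, y: float) -> float:
    """Apply sigmoid first, then take logs: breaks once sigmoid rounds to 0 or 1."""
    p = 1.0 / (1.0 + math.exp(-z))
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

def bce_stable(z: float, y: float) -> float:
    """Same quantity computed directly from the logit."""
    return max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))

# For moderate logits the two formulas agree
logits = [1.5, -0.5, 2.0, -1.0]
labels = [1.0, 0.0, 1.0, 0.0]
naive = [bce_naive(z, y) for z, y in zip(logits, labels)]
stable = [bce_stable(z, y) for z, y in zip(logits, labels)]
print(max(abs(a - b) for a, b in zip(naive, stable)))  # tiny rounding difference

# A confident wrong prediction: sigmoid(40) rounds to exactly 1.0 in float64
try:
    bce_naive(40.0, 0.0)
except ValueError as exc:
    print("naive formula failed:", exc)   # log(0) is undefined
print(bce_stable(40.0, 0.0))              # finite, approximately 40
```

Confidently wrong predictions are exactly the examples that should produce the largest gradients, so losing them to overflow or clamping is costly.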

Class Imbalance: Weighted BCE and Focal Loss

Python
# ── Weighted BCE: penalise minority class errors more ──
# In clinical data: readmission rate ~10% → about 9 negatives per positive
pos_weight = torch.tensor([9.0])  # n_negative / n_positive: positives count 9× more
weighted_bce = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

logits = torch.tensor([1.5, -0.5, 2.0, -1.0])
labels = torch.tensor([1.0, 0.0, 1.0, 0.0])
print(f"Weighted BCE: {weighted_bce(logits, labels):.4f}")

# ── Focal Loss: down-weight easy examples ──
# FL(p_t) = -(1 - p_t)^γ · log(p_t)
# γ (gamma) controls focus: γ=0 recovers (α-weighted) BCE, γ=2 is the commonly used default
# Down-weights easy negatives, focusing on hard/rare positives

def focal_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    gamma: float = 2.0,
    alpha: float = 0.25,   # weight for positive class
) -> torch.Tensor:
    p = torch.sigmoid(logits)
    p_t = p * labels + (1 - p) * (1 - labels)           # prob of true class
    alpha_t = alpha * labels + (1 - alpha) * (1 - labels)
    
    ce_loss = -torch.log(p_t + 1e-8)
    focal_weight = (1 - p_t) ** gamma
    loss = alpha_t * focal_weight * ce_loss
    return loss.mean()

fl = focal_loss(logits, labels, gamma=2.0, alpha=0.75)  # α=0.75 up-weights the rare positive class
print(f"Focal loss (γ=2): {fl:.4f}")
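
Two numbers help build intuition for these imbalance tools: where a `pos_weight` value comes from, and how strongly the focal factor `(1 - p_t)^γ` separates easy from hard examples. The sketch below uses plain Python. The class counts and the probabilities 0.9 and 0.1 are assumed for illustration, and the α factor is left out to isolate the effect of γ.

```python
import math

def pos_weight_from_counts(n_pos: int, n_neg: int) -> float:
    """Common heuristic: ratio of negative to positive examples."""
    return n_neg / n_pos

def focal_term(p_t: float, gamma: float = 2.0) -> float:
    """Per-example focal loss without alpha: (1 - p_t)^gamma * -log(p_t)."""
    return (1.0 - p_t) ** gamma * -math.log(p_t)

# 10% positives, e.g. 100 readmitted out of 1000 patients
pw = pos_weight_from_counts(n_pos=100, n_neg=900)
print(pw)  # 9.0

easy, hard = 0.9, 0.1   # probability assigned to the true class
ce_ratio = -math.log(hard) / -math.log(easy)      # hard vs easy under plain CE
fl_ratio = focal_term(hard) / focal_term(easy)    # hard vs easy under focal, gamma=2
print(ce_ratio, fl_ratio)
```

With γ=2 the hard example outweighs the easy one by an extra factor of (0.9/0.1)² = 81 compared with plain cross-entropy, which is what lets many easy negatives contribute little to the total loss.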

Multi-Class Loss

Python
# ── Cross-Entropy Loss ──
# L = -Σ_c y_c · log(softmax(z)_c)
# For a 5-class problem (e.g., clinical severity levels)

criterion = nn.CrossEntropyLoss()

# Takes raw logits (do NOT apply softmax first; the loss applies log-softmax internally)
logits = torch.randn(16, 5)     # (batch=16, n_classes=5)
labels = torch.randint(0, 5, (16,))    # class indices, not one-hot

loss = criterion(logits, labels)
print(f"Cross-entropy loss: {loss:.4f}")

# ── Label Smoothing ──
# Instead of hard target [0, 0, 1, 0, 0], use [ε/K, ε/K, 1-ε+ε/K, ε/K, ε/K]
# Discourages overconfidence and often improves calibration
smooth_criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
smooth_loss = smooth_criterion(logits, labels)
print(f"Label-smoothed CE loss: {smooth_loss:.4f}")

# ── NLL Loss (when using log_softmax) ──
log_probs = torch.log_softmax(logits, dim=-1)
nll = nn.NLLLoss()(log_probs, labels)   # equivalent to CrossEntropyLoss
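
The formulas in the comments above can be reproduced by hand. This plain-Python sketch computes a log-softmax with the log-sum-exp trick and then the cross-entropy against either a hard or a smoothed target. The logits are made up, and the smoothed target follows the `[ε/K, ..., 1-ε+ε/K, ...]` form quoted above.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax via log-sum-exp."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - lse for z in logits]

def cross_entropy(logits, target, eps=0.0):
    """CE against a target with (1 - eps) on the true class plus eps / K on every class."""
    k = len(logits)
    log_p = log_softmax(logits)
    q = [eps / k + (1.0 - eps) * (1.0 if c == target else 0.0) for c in range(k)]
    return -sum(q_c * lp for q_c, lp in zip(q, log_p))

logits = [2.0, 1.0, 0.1, -1.0, 0.5]   # illustrative 5-class logits
hard = cross_entropy(logits, target=0)
smooth = cross_entropy(logits, target=0, eps=0.1)
print(hard, smooth)

# Smoothed CE = (1 - eps) * hard CE + eps * average of -log p over all classes
uniform_part = -sum(log_softmax(logits)) / len(logits)
print(0.9 * hard + 0.1 * uniform_part)   # matches smooth up to rounding
```

Because the smoothed target always keeps some mass on the other classes, the loss can no longer be driven to zero by pushing one logit to infinity, which is the mechanism behind the reduced overconfidence.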

Loss for Clinical Calibration

Python
import torch

# Calibration: does P(readmitted|score=0.8) really mean 80% readmission rate?
# Poor calibration → clinical decisions are based on misleading probabilities

def expected_calibration_error(
    probs: torch.Tensor,
    labels: torch.Tensor,
    n_bins: int = 10,
) -> float:
    """ECE measures miscalibration. Lower is better (0 = perfect calibration)."""
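    # detach/cpu so this also works for tensors that track gradients or live on a GPU
    probs_np = probs.detach().cpu().numpy()
    labels_np = labels.detach().cpu().numpy()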
    
    bins = torch.linspace(0, 1, n_bins + 1)
    ece = 0.0
    n = len(probs_np)
    
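    for i in range(n_bins):
        lo, hi = bins[i].item(), bins[i + 1].item()
        # Close the last bin on the right so that probs == 1.0 are counted
        upper = (probs_np <= hi) if i == n_bins - 1 else (probs_np < hi)
        mask = (probs_np >= lo) & upper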
        if mask.sum() == 0:
            continue
        bin_conf = probs_np[mask].mean()
        bin_acc  = labels_np[mask].mean()
        bin_size = mask.sum() / n
        ece += bin_size * abs(bin_acc - bin_conf)
    
    return float(ece)

# Simulate
probs  = torch.rand(500)
labels = (probs + 0.1 * torch.randn(500)).clamp(0, 1).round()
ece = expected_calibration_error(probs, labels)
print(f"ECE: {ece:.4f}")  # lower is better; < 0.05 is a rule-of-thumb target, not a formal clinical standard
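
A tiny example with a known answer makes the ECE definition easier to verify. The sketch below is a plain-Python version of the same binned computation, run on four made-up predictions with two bins. The function name `ece_by_hand` and the data are assumptions for illustration.

```python
def ece_by_hand(probs, labels, n_bins=10):
    """Weighted average of |observed positive rate - mean predicted probability| per bin."""
    n = len(probs)
    conf_sum = [0.0] * n_bins
    pos_sum = [0.0] * n_bins
    count = [0] * n_bins
    for p, y in zip(probs, labels):
        b = min(int(p * n_bins), n_bins - 1)   # p == 1.0 falls into the last bin
        conf_sum[b] += p
        pos_sum[b] += y
        count[b] += 1
    total = 0.0
    for b in range(n_bins):
        if count[b] == 0:
            continue   # empty bins contribute nothing
        gap = abs(pos_sum[b] / count[b] - conf_sum[b] / count[b])
        total += (count[b] / n) * gap
    return total

# Low bin: predicted 0.1, observed 0.0 -> gap 0.1
# High bin: predicted 0.9, observed 0.5 -> gap 0.4
probs = [0.1, 0.1, 0.9, 0.9]
labels = [0, 0, 1, 0]
print(ece_by_hand(probs, labels, n_bins=2))   # 0.5 * 0.1 + 0.5 * 0.4 = 0.25

# Perfectly calibrated extreme predictions give zero
print(ece_by_hand([0.0, 0.0, 1.0, 1.0], [0, 0, 1, 1], n_bins=2))   # 0.0
```

ECE depends on the number of bins and on how many samples land in each, so values are only comparable when computed with the same binning.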

Choosing the Right Loss

Task                            | Loss                   | PyTorch class
--------------------------------|------------------------|--------------------------------
Regression (normal noise)       | MSE                    | nn.MSELoss()
Regression (robust to outliers) | Huber                  | nn.HuberLoss()
Regression (outlier-heavy data) | MAE                    | nn.L1Loss()
Binary classification           | BCE with logits        | nn.BCEWithLogitsLoss()
Binary + class imbalance        | Weighted BCE           | nn.BCEWithLogitsLoss(pos_weight=w)
Binary + extreme imbalance      | Focal loss             | custom
Multi-class classification      | Cross-entropy          | nn.CrossEntropyLoss()
Multi-class + overconfidence    | Label-smoothed CE      | nn.CrossEntropyLoss(label_smoothing=ε)
Multi-label classification      | BCE per label          | nn.BCEWithLogitsLoss()
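
The multi-label row is the only one without an example above, so here is a minimal sketch. It assumes a batch of two samples with three independent labels each (the label meanings and values are made up). Each label gets its own sigmoid and its own binary cross-entropy term, and `nn.BCEWithLogitsLoss` averages over all of them.

```python
import torch
import torch.nn as nn

# Shape (batch=2, n_labels=3): each column is an independent yes/no label
logits = torch.tensor([[2.0, -1.0, 0.5],
                       [-0.5, 1.5, -2.0]])
targets = torch.tensor([[1.0, 0.0, 1.0],     # multi-hot float targets, same shape as logits
                        [0.0, 1.0, 0.0]])

criterion = nn.BCEWithLogitsLoss()           # mean over every (sample, label) pair
loss = criterion(logits, targets)

# Per-label loss, useful for spotting which label the model struggles with
per_label = nn.BCEWithLogitsLoss(reduction="none")(logits, targets).mean(dim=0)

# Independent 0.5 threshold per label; several labels can be active at once
preds = (torch.sigmoid(logits) > 0.5).int()

print(f"Multi-label BCE: {loss:.4f}")
print(per_label)
print(preds)
```

Softmax with cross-entropy would be the wrong choice here because it forces the labels to compete for a single unit of probability mass.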

Interview Answer

"Loss function choice maps to distributional assumptions: MSE for Gaussian regression, BCE for Bernoulli binary classification, cross-entropy for categorical multi-class. Prefer BCEWithLogitsLoss (raw logits) over BCELoss (probabilities) because it is numerically more stable. For class imbalance (common in clinical data: 5–15% positive readmission rates), either weight the positive class via pos_weight or use focal loss, which dynamically down-weights easy examples. Label smoothing (ε=0.1) discourages overconfidence in multi-class tasks and often improves calibration, which matters in clinical AI where probability outputs drive decision thresholds. Always check that predicted probabilities are calibrated (use ECE or reliability diagrams) before clinical deployment, since a model with a high AUC can still have severely miscalibrated probabilities."
