Generalisation Techniques in Deep Learning
The full toolkit for improving deep learning generalisation — data augmentation, label smoothing, mixup, weight decay, early stopping, and cross-validation.
The Generalisation Problem
Training performance ≠ Deployment performance
Why they differ:
Training: model sees the same examples repeatedly, gradient descent finds
the configuration that minimises loss on exactly these examples
Deployment: unseen examples from a slightly different distribution
(different hospital, different patient population, different time)
Goal: maximise performance on the deployment distribution,
not just the training distribution.
Data Augmentation
Artificially expand the training set by creating modified copies of existing examples.
import torch
import torchvision.transforms as T
from PIL import Image
# Image augmentation (chest X-ray / medical imaging)
# Channel statistics shared by both pipelines so that training and
# validation inputs are normalised identically.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

train_transform = T.Compose([
    # Resize first, exactly as the validation pipeline does. Without it the
    # random crop is taken at the image's native resolution (a different
    # scale from validation) and fails outright on images too small to
    # contain a 224-pixel crop.
    T.Resize(256),
    # NOTE(review): mirroring changes left/right anatomy; confirm that
    # laterality is irrelevant for the task before keeping this flip.
    T.RandomHorizontalFlip(p=0.5),                   # random horizontal flip
    T.RandomRotation(degrees=10),                    # small rotations
    T.ColorJitter(brightness=0.2, contrast=0.2),     # brightness/contrast
    T.RandomCrop(224, padding=20),                   # crop with padding
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# NO augmentation at validation: deterministic resize + centre crop only.
val_transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
# Time series augmentation (ECG, vitals)
def augment_time_series(
    signal: torch.Tensor,  # shape: (channels, length)
    noise_std: float = 0.02,
    scale_range: tuple = (0.9, 1.1),
    time_shift: int = 10,
) -> torch.Tensor:
    """Apply common ECG augmentations: noise, amplitude scaling, time shift.

    Args:
        signal: Input signal; the last dimension is treated as time.
        noise_std: Standard deviation of the additive Gaussian noise.
        scale_range: ``(low, high)`` bounds of the uniform amplitude factor.
        time_shift: Maximum shift in samples; the actual shift is drawn
            uniformly from ``[-time_shift, time_shift]``.

    Returns:
        A new tensor with the same shape, dtype and device as ``signal``.
    """
    # Additive Gaussian noise, generated with the signal's dtype and device.
    signal = signal + torch.randn_like(signal) * noise_std

    # Amplitude scaling with a single factor for the whole signal. The factor
    # is converted to a Python float so the multiplication works on any
    # device and does not promote low-precision signals to float32.
    scale = torch.empty(1).uniform_(*scale_range).item()
    signal = signal * scale

    # Time shift. torch.roll is circular: samples pushed past one end
    # re-enter at the other.
    shift = torch.randint(-time_shift, time_shift + 1, (1,)).item()
    return torch.roll(signal, shift, dims=-1)


# Label Smoothing
Replace hard labels (0 or 1) with soft labels, reducing overconfidence:
import torch
import torch.nn as nn
import torch.nn.functional as F
class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy loss against label-smoothed targets.

    With ε = ``smoothing`` and K = number of classes (the size of the last
    dimension of ``logits``), the target distribution is

        hard label:   [0,   0,   1,         0,   ...]
        smooth label: [ε/K, ε/K, 1-ε + ε/K, ε/K, ...]

    which gives ``loss = (1 - ε) * NLL(true class) + ε * mean_k(-log p_k)``.
    """

    def __init__(self, smoothing: float = 0.1, n_classes: int = 10):
        """Store the smoothing factor ε.

        Args:
            smoothing: Probability mass ε spread uniformly over all classes.
            n_classes: Kept for interface compatibility; the loss itself
                takes the class count from the last dimension of the logits.
        """
        super().__init__()
        self.smoothing = smoothing
        self.n_classes = n_classes

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the mean label-smoothed cross-entropy.

        Args:
            logits: Unnormalised scores of shape ``(..., K)``.
            targets: int64 class indices of shape ``logits.shape[:-1]``.

        Returns:
            Scalar tensor: the loss averaged over every target position.
        """
        log_probs = F.log_softmax(logits, dim=-1)

        # Standard cross-entropy on the true class. Indexing the last
        # dimension (rather than dim 1) keeps this valid for inputs with any
        # number of leading dimensions, e.g. (batch, time, K).
        nll_loss = -log_probs.gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)

        # Cross-entropy against the uniform distribution over all K classes.
        smooth_loss = -log_probs.mean(dim=-1)

        confidence = 1.0 - self.smoothing
        loss = confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()
# PyTorch built-in: nn.CrossEntropyLoss takes the smoothing factor directly
# through its `label_smoothing` argument, so no custom module is required.
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
# Why it helps:
# Standard CE keeps rewarding ever-larger logits for the true class
# Label smoothing prevents this — the smoothed target is never exactly 1,
# so the model can't be infinitely confident
# Reduces overconfidence and improves calibration
Mixup Training
Linearly interpolate between pairs of training examples and their labels:
def mixup_data(
    x: torch.Tensor,
    y: torch.Tensor,
    alpha: float = 0.2,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, float]:
    """Return mixed inputs, both label sets, and the mixing coefficient λ.

    Each example is blended with another example from the same batch:
    ``mixed_x = λ * x + (1 - λ) * x[perm]``.

    Args:
        x: Batch of inputs; the first dimension is the batch dimension.
        y: Labels aligned with ``x`` along the first dimension.
        alpha: Beta-distribution parameter, λ ~ Beta(α, α). A value <= 0
            disables mixing (λ = 1, inputs returned unchanged).

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` where ``y_a`` are the original labels,
        ``y_b`` the labels of the partner examples, and ``lam`` a Python
        float in [0, 1].
    """
    # Imported here so the function does not rely on a module-level
    # `import numpy as np` being present.
    import numpy as np

    if alpha > 0:
        # float() turns the NumPy scalar into the plain float promised by
        # the return annotation.
        lam = float(np.random.beta(alpha, alpha))  # λ ~ Beta(α, α)
    else:
        # Beta(α, α) is undefined for α <= 0; treat that as "no mixup".
        lam = 1.0

    batch_size = x.size(0)
    index = torch.randperm(batch_size)  # random partner for every example
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam
def mixup_loss(
    criterion: nn.Module,
    pred: torch.Tensor,
    y_a: torch.Tensor,
    y_b: torch.Tensor,
    lam: float,
) -> torch.Tensor:
    """Blend the criterion's loss on both label sets with weights λ and 1-λ.

    Args:
        criterion: Callable loss taking ``(pred, target)``.
        pred: Model outputs for the mixed inputs.
        y_a: Labels of the first example in each mixed pair.
        y_b: Labels of the second example in each mixed pair.
        lam: Mixing coefficient λ applied to the ``y_a`` loss.

    Returns:
        ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.
    """
    loss_first = criterion(pred, y_a)
    loss_second = criterion(pred, y_b)
    return lam * loss_first + (1 - lam) * loss_second
# Training loop with mixup
for inputs, targets in train_loader:
    # Clear the gradients left by the previous batch; backward() adds to
    # existing .grad values, so without this each step would apply the sum
    # of all gradients seen so far.
    optimizer.zero_grad()
    mixed_inputs, y_a, y_b, lam = mixup_data(inputs, targets, alpha=0.2)
    outputs = model(mixed_inputs)
    # Loss against both label sets, weighted by the same λ used to mix inputs.
    loss = mixup_loss(criterion, outputs, y_a, y_b, lam)
    loss.backward()
    optimizer.step()


# Weight Decay (L2 Regularisation)
# L2 penalty: loss_total = loss_task + λ × ‖W‖²_F
# Gradient: ∂loss_total/∂W = ∂loss_task/∂W + 2λW
# SGD update (grad = ∂loss_task/∂W):
#   W ← W - lr × (grad + 2λW) = (1 - 2λ·lr) × W - lr × grad
# → weights decay toward zero each step
optimizer = torch.optim.AdamW(
model.parameters(),
lr=1e-3,
weight_decay=1e-4, # decoupled decay coefficient (typical range: 1e-5 to 1e-2); applied directly to the weights, so not identical to λ in the L2 formula above
)
# AdamW is preferred over Adam + weight_decay:
# Standard Adam adds the L2 term to the gradient before moment estimation,
# so the decay is rescaled by the adaptive per-parameter step size
# AdamW applies weight decay separately — more principled
Early Stopping
class EarlyStopping:
    """Track validation loss and signal when training should stop.

    A step counts as an improvement when the loss drops below the best loss
    seen so far by more than ``min_delta``. After ``patience`` consecutive
    steps without improvement, ``step`` returns True.
    """

    def __init__(
        self,
        patience: int = 10,
        min_delta: float = 1e-4,
        restore_best_weights: bool = True,
    ):
        """Configure the stopping rule.

        Args:
            patience: Non-improving steps tolerated before stopping.
            min_delta: Minimum decrease in loss that counts as improvement.
            restore_best_weights: If True, snapshot the model's state dict
                at every improvement so ``restore`` can reload it.
        """
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best = restore_best_weights
        self.best_loss = float("inf")
        self.best_weights = None
        self.counter = 0

    def step(self, model: nn.Module, val_loss: float) -> bool:
        """Record one validation result; return True if training should stop."""
        threshold = self.best_loss - self.min_delta
        if not (val_loss < threshold):
            # No sufficient improvement: one more strike against patience.
            self.counter += 1
            return self.counter >= self.patience

        # New best: reset the strike counter and optionally snapshot weights.
        self.best_loss = val_loss
        self.counter = 0
        if self.restore_best:
            import copy

            # Deep copy so later optimiser steps cannot mutate the snapshot.
            self.best_weights = copy.deepcopy(model.state_dict())
        return self.counter >= self.patience

    def restore(self, model: nn.Module) -> None:
        """Load the best snapshot into ``model``; do nothing if none was saved."""
        if self.best_weights is None:
            return
        model.load_state_dict(self.best_weights)


# Interview Answer
"Generalisation techniques address the gap between training and deployment performance. Data augmentation (random flips, rotations, noise) forces the model to learn invariant features. Label smoothing softens one-hot targets to prevent overconfident predictions and improves calibration. Mixup trains on convex combinations of example pairs, encouraging smoother decision boundaries. Weight decay (L2 regularisation, AdamW) penalises large weights, preferring simpler solutions. Early stopping halts training when validation loss stops improving, saving the best checkpoint. In clinical ML, generalisation is especially critical because deployment populations differ from training (different hospitals, time periods, demographics) — I always validate on temporally or institutionally held-out data, not just random splits."
Found this helpful?
Leave a comment
Have a question, correction, or just found this helpful? Leave a note below.