Learnixo
AI Systems · Intermediate

Generalisation Techniques in Deep Learning

The full toolkit for improving deep learning generalisation — data augmentation, label smoothing, mixup, weight decay, early stopping, and cross-validation.

Asma Hafeez Khan · May 21, 2026 · 5 min read
Deep Learning · Generalisation · Data Augmentation · Label Smoothing · Mixup

The Generalisation Problem

Training performance ≠ Deployment performance

Why they differ:
  Training: the model sees the same examples repeatedly, and gradient descent
            finds the configuration that minimises loss on exactly these examples

  Deployment: unseen examples, often from a distribution that differs from training
              (different hospital, different patient population, different time)

Goal: maximise performance on the deployment distribution,
      not just the training distribution.
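One way to approximate the deployment distribution during development is to hold out whole hospitals or whole time periods instead of randomly sampled rows. A minimal sketch in plain Python, assuming hypothetical record fields `site` and `admitted` and two illustrative helpers, `institution_split` and `temporal_split`:

```python
from datetime import date


def institution_split(records: list[dict], test_sites: set[str]) -> tuple[list[dict], list[dict]]:
    """Hold out entire sites: no hospital contributes to both train and test."""
    train = [r for r in records if r["site"] not in test_sites]
    test = [r for r in records if r["site"] in test_sites]
    return train, test


def temporal_split(records: list[dict], cutoff: date) -> tuple[list[dict], list[dict]]:
    """Train on earlier admissions, evaluate on later ones."""
    train = [r for r in records if r["admitted"] < cutoff]
    test = [r for r in records if r["admitted"] >= cutoff]
    return train, test


# Hypothetical records for illustration only
records = [
    {"id": 1, "site": "hospital_a", "admitted": date(2021, 3, 1)},
    {"id": 2, "site": "hospital_a", "admitted": date(2023, 6, 9)},
    {"id": 3, "site": "hospital_b", "admitted": date(2022, 1, 15)},
    {"id": 4, "site": "hospital_c", "admitted": date(2023, 2, 20)},
]

train, test = institution_split(records, test_sites={"hospital_c"})
print([r["id"] for r in test])   # [4]

train, test = temporal_split(records, cutoff=date(2023, 1, 1))
print([r["id"] for r in test])   # [2, 4]
```

A random row-level split would let every hospital and every year appear on both sides, which tends to overstate how well the model transfers to a new site or a later period.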

Data Augmentation

Artificially expand the training set by creating modified copies of existing examples, using transformations that leave the label unchanged.

Python
import torch
import torchvision.transforms as T
from PIL import Image

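# Image augmentation (chest X-ray / medical imaging)
# Assumes images are loaded as 3-channel RGB, e.g. Image.open(path).convert("RGB"),
# so that the 3-channel ImageNet mean/std below apply
train_transform = T.Compose([
    T.Resize(256),                     # same scale as validation before cropping
    T.RandomHorizontalFlip(p=0.5),     # random horizontal flip
    T.RandomRotation(degrees=10),      # small rotations
    T.ColorJitter(brightness=0.2, contrast=0.2),  # brightness/contrast
    T.RandomCrop(224, padding=20),     # random 224×224 crop with padding (shift jitter)
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet statistics
])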

val_transform = T.Compose([   # NO augmentation at validation
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


# Time series augmentation (ECG, vitals)
def augment_time_series(
    signal: torch.Tensor,     # shape: (channels, length)
    noise_std: float = 0.02,
    scale_range: tuple = (0.9, 1.1),
    time_shift: int = 10,
) -> torch.Tensor:
    """Common ECG augmentation techniques."""
    # Additive Gaussian noise
    signal = signal + torch.randn_like(signal) * noise_std
    
    # Amplitude scaling
    scale = torch.empty(1).uniform_(*scale_range)
    signal = signal * scale
    
    # Random time shift (torch.roll is circular: samples shifted off
    # one end wrap around to the other)
    shift = torch.randint(-time_shift, time_shift + 1, (1,)).item()
    signal = torch.roll(signal, shift, dims=-1)
    
    return signal
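The transforms above rely on augmentation running only on training samples, and freshly on every access. A minimal sketch of that pattern for signals, assuming a hypothetical `SignalDataset` class and an `add_noise` function that stands in for `augment_time_series` (both are illustrative names, not part of any library):

```python
from typing import Callable, Optional

import torch
from torch.utils.data import DataLoader, Dataset


class SignalDataset(Dataset):
    """Hypothetical dataset that applies `augment` only when train=True."""

    def __init__(
        self,
        signals: torch.Tensor,   # shape: (N, channels, length)
        labels: torch.Tensor,    # shape: (N,)
        train: bool,
        augment: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
    ):
        self.signals = signals
        self.labels = labels
        self.train = train
        self.augment = augment

    def __len__(self) -> int:
        return self.signals.size(0)

    def __getitem__(self, idx: int):
        x = self.signals[idx]
        if self.train and self.augment is not None:
            x = self.augment(x)  # new random perturbation on every access
        return x, self.labels[idx]


def add_noise(x: torch.Tensor, std: float = 0.02) -> torch.Tensor:
    """Stand-in for augment_time_series: additive Gaussian noise only."""
    return x + torch.randn_like(x) * std


torch.manual_seed(0)
signals = torch.randn(8, 1, 500)          # 8 synthetic single-lead signals
labels = torch.randint(0, 2, (8,))

train_ds = SignalDataset(signals, labels, train=True, augment=add_noise)
val_ds = SignalDataset(signals, labels, train=False, augment=add_noise)

train_loader = DataLoader(train_ds, batch_size=4, shuffle=True)
xb, yb = next(iter(train_loader))
print(xb.shape)  # torch.Size([4, 1, 500])
```

Because `__getitem__` draws new noise each time, the model sees a differently perturbed copy of each signal every epoch, while `val_ds` always returns the raw signal.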

Label Smoothing

Replace hard labels (0 or 1) with soft labels, reducing overconfidence:

Python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LabelSmoothingCrossEntropy(nn.Module):
    def __init__(self, smoothing: float = 0.1, n_classes: int = 10):
        super().__init__()
        self.smoothing = smoothing
        self.n_classes = n_classes
    
    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # Hard label: [0, 0, 1, 0, ...]
        # Smooth label: [ε/K, ε/K, 1-ε+ε/K, ε/K, ...] where ε=smoothing, K=n_classes
        
        confidence = 1.0 - self.smoothing
        
        log_probs = F.log_softmax(logits, dim=-1)
        
        # Cross-entropy against the one-hot target (true class only)
        nll_loss = -log_probs.gather(dim=-1, index=targets.unsqueeze(1)).squeeze(1)
        
        # Cross-entropy against a uniform distribution over all K classes
        smooth_loss = -log_probs.mean(dim=-1)
        
        # (1 - ε)·CE(one-hot) + ε·CE(uniform) = CE against the smooth label above
        loss = confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()

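# PyTorch built-in (same smoothed target as the class above)
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)

# Why it helps:
# Standard CE keeps pushing the true-class logit further above the others,
#   because the loss only approaches 0 as that gap grows without bound
# With label smoothing the loss is minimised when the predicted probabilities
#   match the smoothed target, so the optimal logit gap is finite
# → reduces overconfidence and often improves calibration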
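To check the smoothed-target formula numerically, the sketch below builds the soft labels explicitly and compares the resulting cross-entropy with PyTorch's built-in `label_smoothing` option. The helper `smoothed_targets` is a hypothetical name used only for this illustration, and the logits are random.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def smoothed_targets(targets: torch.Tensor, n_classes: int, eps: float) -> torch.Tensor:
    """(1 - eps) * one_hot + eps / K: true class gets 1 - eps + eps/K, others eps/K."""
    one_hot = F.one_hot(targets, n_classes).float()
    return (1.0 - eps) * one_hot + eps / n_classes


torch.manual_seed(0)
logits = torch.randn(4, 5)                 # batch of 4, K = 5 classes
targets = torch.tensor([0, 3, 1, 4])
eps = 0.1

soft = smoothed_targets(targets, n_classes=5, eps=eps)
# soft[0]: true class 0.92, every other class 0.02

# Cross-entropy against the soft distribution...
manual = -(soft * F.log_softmax(logits, dim=-1)).sum(dim=-1).mean()
# ...matches PyTorch's built-in label smoothing
builtin = nn.CrossEntropyLoss(label_smoothing=eps)(logits, targets)
print(torch.allclose(manual, builtin))  # True
```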

Mixup Training

Linearly interpolate between pairs of training examples and their labels:

Python
import numpy as np
import torch
import torch.nn as nn


def mixup_data(
    x: torch.Tensor,
    y: torch.Tensor,
    alpha: float = 0.2,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, float]:
    """Returns mixed inputs, label pairs, and the mixing coefficient λ."""
    lam = float(np.random.beta(alpha, alpha))   # λ ~ Beta(α, α)
    batch_size = x.size(0)
    index = torch.randperm(batch_size, device=x.device)   # mixing partner for each example
    
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_loss(
    criterion: nn.Module,
    pred: torch.Tensor,
    y_a: torch.Tensor,
    y_b: torch.Tensor,
    lam: float,
) -> torch.Tensor:
    # Same interpolation weight as the inputs
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)


# Training loop with mixup
# (model, criterion, optimizer and train_loader are assumed to be defined already)
for inputs, targets in train_loader:
    mixed_inputs, y_a, y_b, lam = mixup_data(inputs, targets, alpha=0.2)
    optimizer.zero_grad()                 # clear gradients from the previous step
    outputs = model(mixed_inputs)
    loss = mixup_loss(criterion, outputs, y_a, y_b, lam)
    loss.backward()
    optimizer.step()
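Interpolating the two losses is the same as training against an interpolated label distribution, because cross-entropy is linear in its target. A small sketch that checks this numerically; the logits, labels, and λ = 0.7 are made-up values for illustration:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
n_classes, lam = 3, 0.7
pred = torch.randn(4, n_classes)           # stand-in for model logits on a mixed batch
y_a = torch.tensor([0, 1, 2, 0])
y_b = torch.tensor([2, 2, 0, 1])
criterion = nn.CrossEntropyLoss()

# Form 1: interpolate the two hard-label losses
loss_pair = lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

# Form 2: cross-entropy against the interpolated (soft) label distribution
mixed_target = (lam * F.one_hot(y_a, n_classes).float()
                + (1 - lam) * F.one_hot(y_b, n_classes).float())
loss_soft = -(mixed_target * F.log_softmax(pred, dim=-1)).sum(dim=-1).mean()

print(torch.allclose(loss_pair, loss_soft))  # True
```

This is why mixup is often described as training on convex combinations of both inputs and labels.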

Weight Decay (L2 Regularisation)

Python
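# L2 penalty: loss_total = loss_task + λ × ‖W‖²_F
# Gradient: ∂loss_total/∂W = ∂loss_task/∂W + 2λW
# SGD update: W ← W - lr × (grad + 2λW) = (1 - 2λ·lr) × W - lr × grad
#   (grad = ∂loss_task/∂W)
# → weights shrink toward zero by a constant factor each step
import torch

model = torch.nn.Linear(16, 2)   # placeholder model for illustration; use your own network

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-4,   # decoupled decay: each step also applies W ← (1 - lr·weight_decay) × W
)                        # typical range: 1e-5 to 1e-2 (PyTorch's default is 1e-2)

# AdamW is preferred over Adam + weight_decay:
# Adam's weight_decay adds weight_decay × W to the gradient (classic L2), and that term
#   is then divided by the adaptive second-moment estimate, so the effective decay varies per weight
# AdamW applies the decay directly to the weights, decoupled from the adaptive update
#   → every weight shrinks by the same factor, which is true weight decay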
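The difference between the two optimisers shows up even in a single step. The sketch below gives both a zero task gradient, so only the decay term moves the weights; the weight values, learning rate, and the helper `one_step` are arbitrary choices for illustration.

```python
import torch


def one_step(opt_cls, weight_decay: float, lr: float = 0.1) -> torch.Tensor:
    """Apply one optimiser step to fixed weights whose task gradient is zero."""
    w = torch.nn.Parameter(torch.tensor([1.0, -2.0, 4.0]))
    opt = opt_cls([w], lr=lr, weight_decay=weight_decay)
    w.grad = torch.zeros_like(w)   # zero task gradient: only the decay term acts
    opt.step()
    return w.detach().clone()


adamw_w = one_step(torch.optim.AdamW, weight_decay=1e-2)
adam_w = one_step(torch.optim.Adam, weight_decay=1e-2)
print(adamw_w)  # ≈ [0.999, -1.998, 3.996]: every weight scaled by (1 - lr × weight_decay)
print(adam_w)   # ≈ [0.9, -1.9, 3.9]: every weight moved by ≈ lr, whatever its size
```

AdamW shrinks all weights by the same factor, 0.999 here. With Adam, the L2 term is normalised by its own second-moment estimate, so on this first step each weight moves by roughly `lr` toward zero regardless of its magnitude, which is not the uniform shrinkage that weight decay is meant to provide.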

Early Stopping

Python
import copy

import torch.nn as nn


class EarlyStopping:
    def __init__(
        self,
        patience: int = 10,
        min_delta: float = 1e-4,
        restore_best_weights: bool = True,
    ):
        self.patience = patience          # epochs to wait without improvement
        self.min_delta = min_delta        # minimum decrease that counts as improvement
        self.restore_best = restore_best_weights
        self.best_loss = float("inf")
        self.best_weights = None
        self.counter = 0
    
    def step(self, model: nn.Module, val_loss: float) -> bool:
        """Returns True if training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            if self.restore_best:
                # Deep copy: state_dict() references the live parameter tensors
                self.best_weights = copy.deepcopy(model.state_dict())
        else:
            self.counter += 1
        
        return self.counter >= self.patience
    
    def restore(self, model: nn.Module) -> None:
        """Load the best checkpoint back into the model (call once training stops)."""
        if self.best_weights is not None:
            model.load_state_dict(self.best_weights)
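To see how `patience` and `min_delta` interact, the sketch below replays the same stopping rule over a made-up sequence of validation losses. `early_stopping_trace` is a hypothetical helper written for this illustration, not part of the class above.

```python
from typing import Optional


def early_stopping_trace(
    val_losses: list[float], patience: int = 3, min_delta: float = 1e-3
) -> tuple[Optional[int], Optional[int]]:
    """Replays the EarlyStopping rule; returns (stop_epoch, best_epoch)."""
    best_loss, best_epoch, counter = float("inf"), None, 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss - min_delta:   # improved by more than min_delta
            best_loss, best_epoch, counter = loss, epoch, 0
        else:
            counter += 1
        if counter >= patience:
            return epoch, best_epoch
    return None, best_epoch                # stopping never triggered


losses = [0.90, 0.70, 0.65, 0.66, 0.6499, 0.67, 0.64, 0.68]
print(early_stopping_trace(losses))  # (5, 2)
```

Epoch 4 (0.6499) is lower than the best loss of 0.65 but not by more than `min_delta`, so it still counts as a non-improving epoch; training stops at epoch 5 and the epoch-2 weights would be restored. The later dip to 0.64 is never reached, which is the trade-off to weigh when choosing `patience`.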

Interview Answer

"Generalisation techniques address the gap between training and deployment performance. Data augmentation (random flips, rotations, noise) encourages the model to learn features that are invariant to those transformations. Label smoothing softens one-hot targets to prevent overconfident predictions and often improves calibration. Mixup trains on convex combinations of example pairs and their labels, encouraging smoother decision boundaries. Weight decay (L2 regularisation, or its decoupled form in AdamW) penalises large weights, favouring simpler solutions. Early stopping halts training when validation loss stops improving and restores the best checkpoint. In clinical ML, generalisation is especially critical because deployment populations differ from training ones (different hospitals, time periods, demographics), so I always validate on temporally or institutionally held-out data, not just random splits."
