Learnixo

Deep Learning for AI Interviews · Lesson 51 of 56

Data Augmentation Strategies for Deep Learning

Why Augmentation Works

Data augmentation synthetically increases dataset size by creating transformed copies.
Each augmented example should preserve the semantic label (a rotated X-ray still shows the same condition).

Effect on training:
  1. More training examples → harder to memorise, forces generalisation
  2. Invariance: the model learns that predictions should be stable under transforms
  3. Regularisation: acts like a penalty on predictions that change under the augmentations

Rule of thumb:
  If a domain expert would still assign the same label after the transform, it's safe.
  If the transform could change the clinical finding, don't use it.

Standard Image Augmentation

Python
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF

# ── Safe for chest X-ray ──
# Conservative geometry/intensity changes that should not alter the finding.
xray_train_transform = T.Compose([
    T.Resize(256),
    T.RandomResizedCrop(224, scale=(0.8, 1.0)),   # mild zoom
    # Widely used for CXR, but chest anatomy is NOT truly L/R symmetric (heart
    # on the left): drop this for laterality-dependent labels.
    T.RandomHorizontalFlip(p=0.5),
    T.RandomRotation(degrees=10),                  # small rotation only
    T.ColorJitter(brightness=0.1, contrast=0.1),  # mild intensity
    T.ToTensor(),
    T.Normalize(mean=[0.5], std=[0.5]),            # single-channel stats
])

# ── More aggressive: histopathology (rotation-invariant) ──
pathology_train_transform = T.Compose([
    T.Resize(256),
    T.RandomCrop(224),                 # square crop, so 90° rotations keep the size
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),            # valid: tissue is orientation-invariant
    # Exact 90° steps (0/90/180/270). RandomRotation(degrees=90) would instead
    # sample any angle in [-90, 90] and fill the exposed corners with black.
    T.RandomChoice([T.RandomRotation(degrees=(a, a)) for a in (0, 90, 180, 270)]),
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# ── Natural image (ImageNet-style) ──
imagenet_train_transform = T.Compose([
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    T.RandomErasing(p=0.5),            # operates on tensors, so it goes after ToTensor
])

# Validation: NO random augmentation — only deterministic resizing.
# Normalisation must match the training pipeline being evaluated: these are
# the ImageNet stats used by the pathology/ImageNet transforms above.
val_transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Deterministic counterpart of xray_train_transform (same single-channel stats).
xray_val_transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.5], std=[0.5]),
])

Augmenting Clinical Time-Series (ECG)

Python
import torch
import numpy as np

def augment_ecg(
    signal: torch.Tensor,   # (C, T)  channels × timesteps
    p: float = 0.5,
) -> torch.Tensor:
    """
    Augmentation for ECG signals.

    Each of the three transforms below is applied independently with
    probability ``p``. The input is cloned, so the caller's tensor is never
    modified, and the output always has the same shape as the input.

    Note: time scaling changes the apparent heart rate by up to ±10%; disable
    it (or keep ``p`` low) for rate-dependent labels such as tachycardia or
    bradycardia.

    Args:
        signal: (C, T) tensor of channels × timesteps.
        p: per-transform application probability.

    Returns:
        Augmented (C, T) tensor.
    """
    signal = signal.clone()
    orig_len = signal.shape[-1]   # every transform must preserve this length

    # 1. Gaussian noise (simulate measurement noise), scaled to the signal's spread
    if torch.rand(1).item() < p:
        noise_level = 0.01 * signal.std()
        signal += torch.randn_like(signal) * noise_level

    # 2. Baseline wander (simulate electrode movement)
    if torch.rand(1).item() < p:
        # Build on the signal's device/dtype so the in-place add works on GPU too.
        t = torch.linspace(0, 1, orig_len, dtype=signal.dtype, device=signal.device)
        freq = torch.rand(1).item() * 0.5   # low frequency wander
        wander = 0.05 * torch.sin(2 * np.pi * freq * t)
        signal += wander.unsqueeze(0)

    # 3. Time scaling (small stretch/compression)
    if torch.rand(1).item() < p:
        scale = 0.9 + 0.2 * torch.rand(1).item()   # [0.9, 1.1]
        new_len = max(1, int(orig_len * scale))     # interpolate rejects size 0
        signal = torch.nn.functional.interpolate(
            signal.unsqueeze(0), size=new_len, mode="linear", align_corners=False
        ).squeeze(0)
        # Crop or zero-pad back to the ORIGINAL length (not a fixed sample count).
        if signal.shape[-1] >= orig_len:
            signal = signal[:, :orig_len]
        else:
            signal = torch.nn.functional.pad(signal, (0, orig_len - signal.shape[-1]))

    # 4. Random channel permutation (not for 12-lead ECG — leads have specific meaning)
    # Only safe for unlabelled channel orders (e.g., accelerometer axes)

    return signal

# Augmentation for tabular clinical data
def augment_tabular(
    X: torch.Tensor,      # (n_features,)
    y: float,
    noise_std: float = 0.02,
    dropout_prob: float = 0.05,
) -> tuple[torch.Tensor, float]:
    """Add mild noise and feature dropout for tabular augmentation.

    Args:
        X: feature tensor, e.g. one row of shape (n_features,); all operations
            are elementwise. It is cloned, never modified in place.
        y: label, returned unchanged (the augmentation is label-preserving).
        noise_std: ABSOLUTE standard deviation of the added Gaussian noise.
            It is only "relative to feature spread" if X is standardised. A
            tensor broadcastable to X also works, giving per-feature scales.
        dropout_prob: probability that each feature is zeroed.

    Returns:
        (X_aug, y) with X_aug having the same shape and dtype as X.
    """
    X_aug = X.clone()

    # 1. Gaussian noise (absolute scale; assumes standardised features)
    X_aug += torch.randn_like(X) * noise_std

    # 2. Feature dropout (simulate missing values).
    # masked_fill keeps X's dtype (mask.float() would upcast half precision)
    # and zeroes NaN/inf entries too (NaN * 0 would stay NaN).
    keep = torch.rand_like(X) > dropout_prob
    X_aug = X_aug.masked_fill(~keep, 0)

    return X_aug, y

# Note: tabular augmentation is delicate — noise must be clinically plausible.
# Do NOT add unconstrained noise to features with natural bounds (age ≥ 0, INR ≥ 0.5);
# clip them to their valid range or leave them unaugmented.

Mixup

Python
import torch
import torch.nn as nn

def mixup_data(
    X: torch.Tensor,
    y: torch.Tensor,
    alpha: float = 0.4,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, float]:
    """
    Mixup: interpolate between two training examples and their labels.
    X_mixed = λ·X_a + (1-λ)·X_b
    y_mixed = λ·y_a + (1-λ)·y_b  (soft labels)
    λ ~ Beta(alpha, alpha)

    Labels are not mixed here. Instead (y_a, y_b, λ) are returned, and the
    loss is blended with ``mixup_criterion``.

    Args:
        X: batch of inputs with the batch dimension first.
        y: labels aligned with X along the batch dimension.
        alpha: Beta concentration; ``alpha <= 0`` disables mixing (λ = 1).

    Returns:
        (X_mixed, y_a, y_b, lam) where y_a is ``y`` and y_b is ``y`` in the
        shuffled pairing order.
    """
    lam = torch.distributions.Beta(alpha, alpha).sample().item() if alpha > 0 else 1.0

    # Build the pairing permutation where X lives to avoid a host->device copy.
    idx = torch.randperm(X.shape[0], device=X.device)

    X_mixed = lam * X + (1 - lam) * X[idx]
    y_b = y[idx.to(y.device)]   # y may sit on a different device than X

    return X_mixed, y, y_b, lam

def mixup_criterion(
    criterion: nn.Module,
    pred: torch.Tensor,
    y_a: torch.Tensor,
    y_b: torch.Tensor,
    lam: float,
) -> torch.Tensor:
    """Blend the loss against both label sets using the Mixup weight.

    The result is ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.
    For losses that are linear in the target (cross-entropy, BCE), this equals
    the loss against the soft label λ·y_a + (1-λ)·y_b.
    """
    loss_first = criterion(pred, y_a)
    loss_second = criterion(pred, y_b)
    return lam * loss_first + (1 - lam) * loss_second

# Training with Mixup
model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 1))
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
criterion = nn.BCEWithLogitsLoss()

X = torch.randn(64, 20)
y = torch.randint(0, 2, (64,)).float()

# Mixup training step
model.train()
X_mix, ya, yb, lam = mixup_data(X, y, alpha=0.4)
optimizer.zero_grad()
# squeeze(-1) drops only the output-unit dim: (B, 1) -> (B,). A bare squeeze()
# would also drop the batch dim when B == 1 and break the shape match with y.
pred = model(X_mix).squeeze(-1)
loss = mixup_criterion(criterion, pred, ya, yb, lam)
loss.backward()
optimizer.step()
print(f"Mixup loss: {loss.item():.4f}, λ={lam:.3f}")

CutMix

Python
import torch
import numpy as np

def cutmix_data(
    X: torch.Tensor,   # (B, C, H, W) images
    y: torch.Tensor,   # (B,) labels
    alpha: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, float]:
    """
    CutMix: paste a rectangular patch from one image into another.
    More informative than Mixup for images (preserves spatial structure).

    Args:
        X: (B, C, H, W) image batch; it is cloned, never modified in place.
        y: (B,) labels.
        alpha: Beta concentration for the target mixing ratio. ``alpha <= 0``
            disables mixing (λ = 1, no patch), matching ``mixup_data``.

    Returns:
        (X_cut, y_a, y_b, lam). Here lam is the fraction of each image that
        keeps its own pixels, recomputed from the clipped box, so it can be
        used directly with the Mixup loss.
    """
    lam = torch.distributions.Beta(alpha, alpha).sample().item() if alpha > 0 else 1.0

    B, C, H, W = X.shape
    idx = torch.randperm(B, device=X.device)

    # Sample cut box: its side ratio sqrt(1 - λ) gives an area fraction of 1 - λ.
    cut_ratio = (1.0 - lam) ** 0.5
    cut_w = int(W * cut_ratio)
    cut_h = int(H * cut_ratio)
    # Box centre comes from torch's RNG so torch.manual_seed makes runs reproducible.
    cx = int(torch.randint(W, (1,)).item())
    cy = int(torch.randint(H, (1,)).item())

    x1 = max(cx - cut_w // 2, 0)
    x2 = min(cx + cut_w // 2, W)
    y1 = max(cy - cut_h // 2, 0)
    y2 = min(cy + cut_h // 2, H)

    X_cut = X.clone()
    X_cut[:, :, y1:y2, x1:x2] = X[idx, :, y1:y2, x1:x2]

    # Adjust λ to match actual (possibly border-clipped) patch area
    lam = 1 - (x2 - x1) * (y2 - y1) / (W * H)

    return X_cut, y, y[idx.to(y.device)], lam

# CutMix uses the same loss as Mixup: mixup_criterion(criterion, pred, y_a, y_b, lam)

Test-Time Augmentation (TTA)

Python
import torch
import torch.nn as nn

def predict_with_tta(
    model: nn.Module,
    X: torch.Tensor,   # single image (C, H, W)
    n_augmentations: int = 8,
    noise_std: float = 0.01,
) -> torch.Tensor:
    """
    Apply multiple augmentations at test time and average predictions.
    Improves reliability — important for clinical AI.

    View i is the original image for even i and its horizontal flip for odd i.
    Views from i = 2 onward also get additive Gaussian noise with std
    ``noise_std``. Only use this if the task is invariant to horizontal flips.
    All views are run as one batch under ``torch.no_grad()`` in eval mode.
    The model's previous train/eval mode is restored afterwards.

    Args:
        model: network mapping an (N, C, H, W) batch to logits. It is assumed
            to be a binary/multi-label head where sigmoid is appropriate;
            confirm for your model.
        X: single image (C, H, W), without a batch dimension.
        n_augmentations: number of views to average; must be >= 1.
        noise_std: standard deviation of the additive noise on views i >= 2.

    Returns:
        Mean sigmoid probability over views, squeezed (a 0-d tensor when the
        model emits one logit per image).

    Raises:
        ValueError: if ``n_augmentations < 1``.
    """
    if n_augmentations < 1:
        raise ValueError(f"n_augmentations must be >= 1, got {n_augmentations}")

    views = []
    for i in range(n_augmentations):
        view = X if i % 2 == 0 else torch.flip(X, dims=[-1])   # hflip on width axis
        if i >= 2:
            view = view + noise_std * torch.randn_like(view)
        views.append(view)
    batch = torch.stack(views)   # (n_augmentations, C, H, W)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            probs = torch.sigmoid(model(batch))
    finally:
        model.train(was_training)   # don't leave the caller's model in eval mode

    # Average predictions
    return probs.mean(dim=0).squeeze()

Interview Answer

"Data augmentation creates label-preserving transforms of training examples, effectively increasing dataset size and forcing models to learn invariances rather than memorising exact inputs. For images: random crops, flips, rotations, and colour jitter are standard. For clinical imaging: be conservative — horizontal flip is commonly used for chest X-rays, but the anatomy is not truly L/R symmetric (the heart sits on the left), so skip it when laterality matters; avoid vertical flip (inverts anatomy) and strong colour jitter (alters density values). For ECG/time-series: Gaussian noise, baseline wander, and mild time scaling are valid (time scaling shifts apparent heart rate, so keep it small). Advanced techniques: Mixup interpolates two examples and their labels (λ·x_a + (1-λ)·x_b, soft label), discouraging over-confident predictions between training examples; CutMix pastes a spatial patch from one image into another, preserving local spatial structure better than Mixup. Test-time augmentation (TTA) averages predictions over multiple augmented versions of the test image, reducing variance at inference — especially valuable in clinical deployment where a single noisy prediction could affect a clinical decision."