Deep Learning for AI Interviews · Lesson 51 of 56
Data Augmentation Strategies for Deep Learning
Why Augmentation Works
Data augmentation synthetically increases dataset size by creating transformed copies.
Each augmented example should preserve the semantic label (a rotated X-ray still shows the same condition).
Effect on training:
1. More training examples → harder to memorise, forces generalisation
2. Invariance: the model learns that predictions should be stable under transforms
3. Regularisation: roughly equivalent to penalising predictions that change under the augmentations
Rule of thumb:
If you can see the label is preserved after the transform, it's safe.
If the transform could change the clinical finding, don't use it.
Standard Image Augmentation
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
# ── Safe for chest X-ray ──
# Conservative pipeline: small geometric changes and mild intensity changes only,
# so the radiological finding stays recognisable after every step.
xray_train_transform = T.Compose(
    [
        # geometry
        T.Resize(size=256),
        T.RandomResizedCrop(size=224, scale=(0.8, 1.0)),  # mild zoom: crop keeps 80–100% of the area
        T.RandomHorizontalFlip(p=0.5),  # common for CXR, but anatomy is NOT mirror-symmetric (heart on the left) — drop if labels depend on laterality
        T.RandomRotation(degrees=10),  # small rotation only (±10°)
        # intensity
        T.ColorJitter(brightness=0.1, contrast=0.1),  # mild brightness/contrast change
        # tensor conversion + normalisation (single mean/std value — presumably grayscale input, confirm)
        T.ToTensor(),
        T.Normalize(mean=[0.5], std=[0.5]),
    ]
)
# ── More aggressive: histopathology (rotation-invariant) ──
# Tissue has no canonical orientation, so flips and right-angle rotations are
# label-preserving. Rotation is restricted to exact multiples of 90°: on the
# square 224×224 crop these keep every pixel inside the frame. By contrast,
# RandomRotation(degrees=90) samples a continuous angle in [-90°, 90°] and
# fills the exposed corners with black.
pathology_train_transform = T.Compose([
    T.Resize(256),
    T.RandomCrop(224),
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),  # valid: tissue is orientation-invariant
    # Pick one of 0°, 90°, 180°, 270° at random; degrees=(a, a) pins the angle to exactly a.
    T.RandomChoice([T.RandomRotation(degrees=(a, a)) for a in (0, 90, 180, 270)]),
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# ── Natural image (ImageNet-style) ──
# Stronger photometric jitter than the medical pipelines, plus random erasing.
imagenet_train_transform = T.Compose(
    [
        T.RandomResizedCrop(size=224),
        T.RandomHorizontalFlip(p=0.5),
        T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        # Kept after ToTensor/Normalize: RandomErasing works on tensor images.
        T.RandomErasing(p=0.5),
    ]
)
# Validation: NO random augmentation — only deterministic resize + centre crop.
# Normalises with ImageNet statistics, matching the 3-channel pathology and
# ImageNet-style pipelines. The X-ray pipeline uses mean=[0.5], std=[0.5] and
# needs its own validation transform with those values.
val_transform = T.Compose(
    [
        T.Resize(size=256),
        T.CenterCrop(size=224),  # same 224×224 input size as the training crops
        T.ToTensor(),
        T.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# ── Augmenting Clinical Time-Series (ECG) ──
import torch
import numpy as np
def augment_ecg(
    signal: torch.Tensor,  # (C, T) — channels × timesteps
    p: float = 0.5,
) -> torch.Tensor:
    """
    Randomly augment an ECG signal of shape (C, T).

    Each transform below is applied independently with probability ``p``:

    1. Gaussian noise with std = 1% of the signal's overall std
       (simulates measurement noise).
    2. An additive low-frequency sinusoid of amplitude 0.05, in the signal's
       own units (simulates baseline wander from electrode movement).
    3. Time scaling by a factor in [0.9, 1.1], followed by crop / zero-pad
       back to the original length T.

    The input tensor is not modified, and the output always has the same shape
    as the input.

    Caveat: time scaling changes the apparent heart rate and interval durations
    by up to ±10%. It is only label-preserving for tasks whose labels do not
    hinge on rate or interval thresholds.
    """
    signal = signal.clone()
    orig_len = signal.shape[-1]

    # 1. Gaussian noise (simulate measurement noise)
    if torch.rand(1).item() < p:
        noise_level = 0.01 * signal.std()
        signal += torch.randn_like(signal) * noise_level

    # 2. Baseline wander (simulate electrode movement).
    # t spans [0, 1] over the window, so `freq` is in cycles per window, not Hz
    # (the sampling rate is unknown here). t is created on the signal's
    # device/dtype so the in-place add also works for GPU / float64 input.
    if torch.rand(1).item() < p:
        t = torch.linspace(0, 1, orig_len, device=signal.device, dtype=signal.dtype)
        freq = torch.rand(1).item() * 0.5  # low frequency wander
        wander = 0.05 * torch.sin(2 * np.pi * freq * t)
        signal += wander.unsqueeze(0)

    # 3. Time scaling (small stretch/compression)
    if torch.rand(1).item() < p:
        scale = 0.9 + 0.2 * torch.rand(1).item()  # [0.9, 1.1]
        new_len = max(1, int(orig_len * scale))  # interpolate rejects size 0
        signal = torch.nn.functional.interpolate(
            signal.unsqueeze(0), size=new_len, mode="linear", align_corners=False
        ).squeeze(0)
        # Crop or zero-pad back to the ORIGINAL length of this signal.
        if signal.shape[-1] >= orig_len:
            signal = signal[:, :orig_len]
        else:
            signal = torch.nn.functional.pad(signal, (0, orig_len - signal.shape[-1]))

    # Random channel permutation is deliberately omitted: 12-lead ECG leads have
    # specific meaning. It is only safe for unlabelled channel orders
    # (e.g. accelerometer axes).
    return signal
# Augmentation for tabular clinical data
def augment_tabular(
    X: torch.Tensor,  # (n_features,)
    y: float,
    noise_std: float = 0.02,
    dropout_prob: float = 0.05,
    augment_mask: "torch.Tensor | None" = None,
) -> tuple[torch.Tensor, float]:
    """Add mild Gaussian noise and random feature dropout to one tabular sample.

    Args:
        X: feature vector. It is not modified in place.
        y: label, returned unchanged.
        noise_std: absolute std of the additive noise. It is the same for every
            feature, so it is only meaningful if features share a common scale
            (e.g. were standardised upstream — assumption, confirm).
        dropout_prob: probability that each feature is set to 0 to simulate a
            missing value.
        augment_mask: optional boolean tensor broadcastable to X. True marks
            features that may be perturbed; False features are returned
            untouched (use this for constrained features such as age ≥ 0 or
            INR ≥ 0.5). None (default) perturbs every feature.

    Returns:
        (X_aug, y)
    """
    # 1. Gaussian noise (absolute std, NOT scaled per feature)
    X_aug = X + torch.randn_like(X) * noise_std

    # 2. Feature dropout (simulate missing values). torch.where rather than
    # multiplying by the mask, so NaN/inf entries are also cleared (NaN * 0 is NaN).
    keep = torch.rand_like(X) > dropout_prob
    X_aug = torch.where(keep, X_aug, torch.zeros_like(X_aug))

    # 3. Restore protected features exactly.
    if augment_mask is not None:
        X_aug = torch.where(augment_mask, X_aug, X)
    return X_aug, y
# Note: tabular augmentation is delicate — noise must be clinically plausible.
# Do NOT augment features with natural constraints (age ≥ 0, INR ≥ 0.5).

# ── Mixup ──
import torch
import torch.nn as nn
def mixup_data(
    X: torch.Tensor,
    y: torch.Tensor,
    alpha: float = 0.4,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, float]:
    """
    Mixup: blend each example in the batch with a randomly chosen partner.

        λ ~ Beta(alpha, alpha)        (λ = 1, i.e. no mixing, when alpha <= 0)
        X_mixed = λ·X + (1-λ)·X[perm]

    The labels are NOT blended here. The two label sets and λ are returned so
    the loss can be mixed instead:

        loss = λ·L(pred, y_a) + (1-λ)·L(pred, y_b)

    For losses linear in the target (cross-entropy, BCE), this equals training
    on the soft label λ·y_a + (1-λ)·y_b.

    Returns:
        (X_mixed, y_a, y_b, lam) with y_a = y and y_b = y[perm].
    """
    lam = (
        torch.distributions.Beta(alpha, alpha).sample().item() if alpha > 0 else 1.0
    )
    perm = torch.randperm(X.shape[0])  # partner index for every example
    X_mixed = lam * X + (1 - lam) * X[perm]
    return X_mixed, y, y[perm], lam
def mixup_criterion(
    criterion: nn.Module,
    pred: torch.Tensor,
    y_a: torch.Tensor,
    y_b: torch.Tensor,
    lam: float,
) -> torch.Tensor:
    """Return the Mixup loss: lam · criterion(pred, y_a) + (1 − lam) · criterion(pred, y_b).

    Works for any criterion. For losses linear in the target (BCE,
    cross-entropy), it matches training on the blended soft label.
    """
    loss_a = criterion(pred, y_a)  # loss against the first example's labels
    loss_b = criterion(pred, y_b)  # loss against the partner's labels
    return lam * loss_a + (1 - lam) * loss_b
# Training with Mixup
model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 1))
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
criterion = nn.BCEWithLogitsLoss()
X = torch.randn(64, 20)
y = torch.randint(0, 2, (64,)).float()
# Mixup training step
X_mix, ya, yb, lam = mixup_data(X, y, alpha=0.4)
optimizer.zero_grad()
# squeeze(-1), not squeeze(): maps (B, 1) -> (B,) even when B == 1, so pred
# always matches the (B,) target shape that BCEWithLogitsLoss requires.
pred = model(X_mix).squeeze(-1)
loss = mixup_criterion(criterion, pred, ya, yb, lam)
loss.backward()
optimizer.step()
print(f"Mixup loss: {loss.item():.4f}, λ={lam:.3f}")

# ── CutMix ──
import torch
import numpy as np
def cutmix_data(
    X: torch.Tensor,  # (B, C, H, W) images
    y: torch.Tensor,  # (B,) labels
    alpha: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, float]:
    """
    CutMix: paste a rectangular patch from a partner image into each image.

    Preserves local spatial structure, unlike Mixup's whole-image blending.
    Like ``mixup_data``, ``alpha <= 0`` disables mixing (λ = 1, empty box).

    Returns:
        (X_cut, y_a, y_b, lam), where y_a = y, y_b = y[perm], and lam is the
        fraction of each image that still comes from the original. Combine
        them with the Mixup loss: lam·L(pred, y_a) + (1-lam)·L(pred, y_b).
        The input X is not modified.
    """
    lam = torch.distributions.Beta(alpha, alpha).sample().item() if alpha > 0 else 1.0
    B, _, H, W = X.shape
    idx = torch.randperm(B)

    # Box side scales with sqrt(1 - λ), so its area is ≈ (1 - λ)·H·W.
    cut_ratio = (1.0 - lam) ** 0.5
    cut_w = int(W * cut_ratio)
    cut_h = int(H * cut_ratio)
    # Box centre drawn from torch's RNG (not NumPy's), so torch.manual_seed
    # alone makes the whole function reproducible.
    cx = int(torch.randint(W, (1,)).item())
    cy = int(torch.randint(H, (1,)).item())
    x1, x2 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, W)
    y1, y2 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, H)

    X_cut = X.clone()
    X_cut[:, :, y1:y2, x1:x2] = X[idx, :, y1:y2, x1:x2]

    # The box may be clipped at the border, so recompute λ from the area that
    # was actually pasted. The loss weights then match the real pixel mix.
    lam = 1.0 - (x2 - x1) * (y2 - y1) / (W * H)
    return X_cut, y, y[idx], lam
# CutMix uses the same loss as Mixup: mixup_criterion(criterion, pred, y_a, y_b, lam).

# ── Test-Time Augmentation (TTA) ──
import torch
import torch.nn as nn
def predict_with_tta(
    model: nn.Module,
    X: torch.Tensor,  # single image (C, H, W)
    n_augmentations: int = 8,
    *,
    noise_std: float = 0.01,
    hflip: bool = True,
) -> torch.Tensor:
    """
    Average sigmoid probabilities over several augmented views of one image.

    View i (for i in range(n_augmentations)) is built as follows:
      * i == 0: the original image, always included;
      * odd i: horizontally flipped (if ``hflip``). Set ``hflip=False`` when
        the label depends on laterality;
      * i >= 2: additionally perturbed with Gaussian noise of std ``noise_std``.

    All views are run through the model as one batch under ``torch.no_grad()``
    in eval mode. The model's previous train/eval mode is restored afterwards.

    Assumes a classifier that outputs logits (a sigmoid is applied). Spatial
    outputs such as segmentation maps would need flipped views un-flipped
    before averaging.

    Returns:
        Mean probability over the views, with size-1 dimensions squeezed out.

    Raises:
        ValueError: if n_augmentations < 1.
    """
    if n_augmentations < 1:
        raise ValueError(f"n_augmentations must be >= 1, got {n_augmentations}")

    views = []
    for i in range(n_augmentations):
        view = X
        if hflip and i % 2 == 1:
            view = torch.flip(view, dims=[-1])  # horizontal flip (last axis = width)
        if i >= 2 and noise_std > 0:
            view = view + noise_std * torch.randn_like(view)
        views.append(view)
    batch = torch.stack(views)  # (n_augmentations, C, H, W)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            probs = torch.sigmoid(model(batch))
    finally:
        model.train(was_training)  # do not leave the caller's model in eval mode

    # Average over the view axis
    return probs.mean(dim=0).squeeze()

# ── Interview Answer ──
"Data augmentation creates label-preserving transforms of training examples, effectively increasing dataset size and forcing models to learn invariances rather than memorising exact inputs. For images: random crops, flips, rotations, and colour jitter are standard. For clinical imaging: be conservative — horizontal flip is safe for chest X-rays (L/R symmetric), but avoid vertical flip (inverts anatomy) or strong colour jitter (alters density values). For ECG/time-series: Gaussian noise, baseline wander, and mild time scaling are valid. Advanced techniques: Mixup interpolates two examples and their labels (λ·x_a + (1-λ)·x_b, soft label), preventing confident predictions on interpolated regions; CutMix pastes a spatial patch from one image into another, preserving spatial structure better than Mixup. Test-time augmentation (TTA) averages predictions over multiple augmented versions of the test image, reducing variance at inference — especially valuable in clinical deployment where a single noisy prediction could affect a clinical decision."