Data Augmentation
Creating training variety through transforms, time-series augmentation for clinical signals, Mixup, CutMix, and test-time augmentation for better inference.
Why Augmentation Works
Data augmentation synthetically increases dataset size by creating transformed copies.
Each augmented example should preserve the semantic label (a rotated X-ray still shows the same condition).
Effect on training:
1. More training examples → harder to memorise, forces generalisation
2. Invariance: the model learns that predictions should be stable under transforms
3. Regularisation: equivalent to penalising predictions that are sensitive to augmentations
Rule of thumb:
If you can see the label is preserved after the transform, it's safe.
If the transform could change the clinical finding, don't use it.
Standard Image Augmentation
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
# ── Chest X-ray: conservative, label-preserving transforms only ──
xray_train_transform = T.Compose(
    [
        T.Resize(size=256),
        # Mild zoom: each crop keeps between 80% and 100% of the image area.
        T.RandomResizedCrop(size=224, scale=(0.8, 1.0)),
        # Mirror left/right half of the time (treated here as L/R symmetric).
        T.RandomHorizontalFlip(p=0.5),
        # Small rotation only: angle drawn from [-10°, +10°].
        T.RandomRotation(degrees=10),
        # Mild intensity perturbation.
        T.ColorJitter(brightness=0.1, contrast=0.1),
        T.ToTensor(),
        # One mean/std entry — presumably single-channel (grayscale) input.
        T.Normalize(mean=[0.5], std=[0.5]),
    ]
)
# ── More aggressive: histopathology (rotation-invariant) ──
pathology_train_transform = T.Compose([
    T.Resize(256),
    T.RandomCrop(224),
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),  # valid: tissue is orientation-invariant
    # Exact quarter-turns (0°, 90°, 180°, 270°). A bare RandomRotation(degrees=90)
    # would instead draw a continuous angle from [-90°, +90°], which interpolates
    # pixels and fills the corners with black. Passing (angle, angle) pins the
    # angle; the crop above is square, so a quarter-turn loses no content.
    T.RandomChoice(
        [T.RandomRotation(degrees=(angle, angle)) for angle in (0, 90, 180, 270)]
    ),
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# ── Natural images (ImageNet-style recipe) ──
imagenet_train_transform = T.Compose(
    [
        T.RandomResizedCrop(size=224),
        T.RandomHorizontalFlip(),
        # Stronger photometric jitter than the clinical pipelines use.
        T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        # Placed after ToTensor/Normalize, so erasing acts on the tensor.
        T.RandomErasing(p=0.5),
    ]
)
# Validation: deterministic resize + centre crop only — no random augmentation
val_transform = T.Compose(
    [
        T.Resize(size=256),
        T.CenterCrop(size=224),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Augmenting Clinical Time-Series (ECG)
import torch
import numpy as np
def augment_ecg(
    signal: torch.Tensor,  # (C, T) — channels × timesteps
    p: float = 0.5,
) -> torch.Tensor:
    """
    Augmentation for ECG signals.

    Each of the three transforms below is applied independently with
    probability ``p``. The input tensor is never modified; a new tensor with
    the same shape as ``signal`` is returned.

    Args:
        signal: 2-D tensor laid out as (channels, timesteps).
        p: Probability of applying each individual transform.

    Returns:
        The augmented signal, always with the same number of timesteps as the
        input (time-scaled signals are cropped or zero-padded on the right).
    """
    original_len = signal.shape[-1]
    signal = signal.clone()
    if original_len == 0:
        # Nothing to perturb, and interpolation cannot handle empty input.
        return signal

    # 1. Gaussian noise (simulate measurement noise), scaled to 1% of the
    #    signal's overall standard deviation.
    if torch.rand(1).item() < p:
        noise_level = 0.01 * signal.std()
        signal += torch.randn_like(signal) * noise_level

    # 2. Baseline wander (simulate electrode movement): a slow sinusoid added
    #    identically to every channel.
    if torch.rand(1).item() < p:
        # Build the time axis on the signal's device/dtype so the in-place add
        # also works for GPU or non-float32 inputs.
        t = torch.linspace(
            0, 1, original_len, device=signal.device, dtype=signal.dtype
        )
        freq = torch.rand(1).item() * 0.5  # low frequency wander
        wander = 0.05 * torch.sin(2 * np.pi * freq * t)
        signal += wander.unsqueeze(0)

    # 3. Time scaling (small stretch/compression)
    if torch.rand(1).item() < p:
        scale = 0.9 + 0.2 * torch.rand(1).item()  # [0.9, 1.1]
        # Never ask the interpolator for a zero-length output.
        new_len = max(1, int(original_len * scale))
        signal = torch.nn.functional.interpolate(
            signal.unsqueeze(0), size=new_len, mode="linear", align_corners=False
        ).squeeze(0)

    # Crop or pad back to the ORIGINAL length (previously hard-coded to 500,
    # which silently changed the length of any other signal).
    current_len = signal.shape[-1]
    if current_len >= original_len:
        signal = signal[..., :original_len]
    else:
        signal = torch.nn.functional.pad(signal, (0, original_len - current_len))

    # 4. Random channel permutation is deliberately NOT applied: 12-lead ECG
    #    leads have specific meaning. It is only safe for unlabelled channel
    #    orders (e.g., accelerometer axes).
    return signal
# Augmentation for tabular clinical data
def augment_tabular(
    X: torch.Tensor,  # (n_features,)
    y: float,
    noise_std: float = 0.02,
    dropout_prob: float = 0.05,
    feature_std: "torch.Tensor | None" = None,
) -> tuple[torch.Tensor, float]:
    """Add mild noise and feature dropout for tabular augmentation.

    Args:
        X: Feature vector; it is not modified.
        y: Label, returned unchanged.
        noise_std: Standard deviation of the additive Gaussian noise. With
            ``feature_std`` left as ``None`` the same absolute value is used
            for every feature, which is only sensible when features share a
            common scale (e.g. standardised inputs).
        dropout_prob: Probability of zeroing each feature independently.
        feature_std: Optional per-feature scale, broadcastable to ``X``. When
            given, the noise on each feature is ``noise_std * feature_std``,
            i.e. proportional to that feature's spread.

    Returns:
        ``(X_aug, y)`` — the augmented copy of ``X`` and the original label.
    """
    # 1. Gaussian noise, optionally scaled per feature
    noise = torch.randn_like(X) * noise_std
    if feature_std is not None:
        noise = noise * feature_std
    X_aug = X + noise

    # 2. Feature dropout (simulate missing values): dropped features become 0
    keep = torch.rand_like(X) > dropout_prob
    X_aug = X_aug * keep.float()
    return X_aug, y
# Note: tabular augmentation is delicate — the added noise must stay clinically plausible.
# Do NOT augment features with natural constraints (age ≥ 0, INR ≥ 0.5).

# Mixup
import torch
import torch.nn as nn
def mixup_data(
    X: torch.Tensor,
    y: torch.Tensor,
    alpha: float = 0.4,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, float]:
    """
    Mixup: blend every example in the batch with a randomly chosen partner.

        X_mixed = λ·X + (1-λ)·X[partner],   λ ~ Beta(alpha, alpha)

    The labels are not blended here. Both label sets are returned together
    with λ so the caller can combine the two losses itself.

    Returns:
        (X_mixed, y_a, y_b, lam) where y_a is the original labels, y_b the
        partners' labels, and lam the mixing coefficient as a Python float.
    """
    # A non-positive alpha switches mixing off (λ = 1 keeps X unchanged).
    lam = (
        torch.distributions.Beta(alpha, alpha).sample().item()
        if alpha > 0
        else 1.0
    )

    # Random pairing: row i is mixed with row partner[i].
    partner = torch.randperm(X.shape[0])
    blended = lam * X + (1 - lam) * X[partner]
    return blended, y, y[partner], lam
def mixup_criterion(
    criterion: nn.Module,
    pred: torch.Tensor,
    y_a: torch.Tensor,
    y_b: torch.Tensor,
    lam: float,
) -> torch.Tensor:
    """Return the λ-weighted sum of the loss against both label sets.

    The loss against ``y_a`` is weighted by ``lam`` and the loss against
    ``y_b`` by ``1 - lam``.
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b
# Training with Mixup: a small binary classifier on random data
model = nn.Sequential(
    nn.Linear(20, 64),
    nn.ReLU(),
    nn.Linear(64, 1),
)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
criterion = nn.BCEWithLogitsLoss()

# Toy batch: 64 examples, 20 features, labels in {0., 1.}
X = torch.randn(64, 20)
y = torch.randint(0, 2, (64,)).float()

# One Mixup training step: mix the batch, then score the predictions
# against both label sets.
X_mix, ya, yb, lam = mixup_data(X, y, alpha=0.4)

optimizer.zero_grad()
pred = model(X_mix).squeeze()
loss = mixup_criterion(criterion, pred, ya, yb, lam)
loss.backward()
optimizer.step()

print(f"Mixup loss: {loss.item():.4f}, λ={lam:.3f}")

# CutMix
import torch
import numpy as np
def cutmix_data(
    X: torch.Tensor,  # (B, C, H, W) images
    y: torch.Tensor,  # (B,) labels
    alpha: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, float]:
    """
    CutMix: paste a rectangular patch from one image into another.
    More informative than Mixup for images (preserves spatial structure).

    One box is sampled per call and applied to the whole batch; each image
    receives that box from a randomly chosen partner image.

    Args:
        X: Image batch; it is not modified.
        y: Labels aligned with the first dimension of ``X``.
        alpha: Beta(alpha, alpha) concentration for the mixing coefficient.
            A non-positive value disables mixing (λ = 1), matching the
            convention of ``mixup_data``.

    Returns:
        (X_cut, y_a, y_b, lam) where y_a is the original labels, y_b the
        partners' labels, and lam the fraction of each image left untouched.
    """
    if alpha > 0:
        lam = torch.distributions.Beta(alpha, alpha).sample().item()
    else:
        lam = 1.0  # no patch is pasted

    B, _, H, W = X.shape
    idx = torch.randperm(B)

    # Sample cut box: its area is (1 - λ) of the image before clipping.
    cut_ratio = max(0.0, 1.0 - lam) ** 0.5
    cut_w = int(W * cut_ratio)
    cut_h = int(H * cut_ratio)

    # Draw the box centre from torch's RNG (not NumPy's) so that
    # torch.manual_seed alone makes the whole function reproducible.
    cx = int(torch.randint(W, (1,)).item())
    cy = int(torch.randint(H, (1,)).item())

    # Clip the box to the image borders.
    x1 = max(cx - cut_w // 2, 0)
    x2 = min(cx + cut_w // 2, W)
    y1 = max(cy - cut_h // 2, 0)
    y2 = min(cy + cut_h // 2, H)

    X_cut = X.clone()
    X_cut[:, :, y1:y2, x1:x2] = X[idx, :, y1:y2, x1:x2]

    # Adjust λ to match actual patch area (clipping and integer rounding
    # make it differ from the sampled value).
    lam = 1 - (x2 - x1) * (y2 - y1) / (W * H)
    return X_cut, y, y[idx], lam
# The CutMix loss is the same as the Mixup loss.

# Test-Time Augmentation (TTA)
import torch
import torch.nn as nn
def predict_with_tta(
    model: nn.Module,
    X: torch.Tensor,  # single image (C, H, W)
    n_augmentations: int = 8,
    noise_std: float = 0.01,
) -> torch.Tensor:
    """
    Apply multiple augmentations at test time and average predictions.
    Improves reliability — important for clinical AI.

    Each augmented view is ``X`` plus Gaussian noise with standard deviation
    ``noise_std``. Mild noise stands in as the example augmentation; unlike a
    crop, it leaves the input shape unchanged.

    Args:
        model: Network returning logits. It is switched to eval mode and
            left in eval mode when the function returns.
        X: A single un-batched input; a batch dimension is added internally.
        n_augmentations: Number of augmented views to average (must be >= 1).
        noise_std: Standard deviation of the noise added to each view.

    Returns:
        The mean over all views of ``sigmoid(model(view))``, with singleton
        dimensions squeezed out.

    Raises:
        ValueError: If ``n_augmentations`` is smaller than 1.
    """
    if n_augmentations < 1:
        raise ValueError(
            f"n_augmentations must be >= 1, got {n_augmentations}"
        )

    model.eval()
    preds = []
    # Inference only: no autograd graph is needed for any of the views.
    with torch.no_grad():
        for _ in range(n_augmentations):
            x_aug = X + noise_std * torch.randn_like(X)
            x_aug = x_aug.unsqueeze(0)  # add batch dim
            preds.append(torch.sigmoid(model(x_aug).squeeze()))

    # Average predictions
    return torch.stack(preds).mean(dim=0)

# Interview Answer
"Data augmentation creates label-preserving transforms of training examples, effectively increasing dataset size and forcing models to learn invariances rather than memorising exact inputs. For images: random crops, flips, rotations, and colour jitter are standard. For clinical imaging: be conservative — horizontal flip is safe for chest X-rays (L/R symmetric), but avoid vertical flip (inverts anatomy) or strong colour jitter (alters density values). For ECG/time-series: Gaussian noise, baseline wander, and mild time scaling are valid. Advanced techniques: Mixup interpolates two examples and their labels (λ·x_a + (1-λ)·x_b, soft label), preventing confident predictions on interpolated regions; CutMix pastes a spatial patch from one image into another, preserving spatial structure better than Mixup. Test-time augmentation (TTA) averages predictions over multiple augmented versions of the test image, reducing variance at inference — especially valuable in clinical deployment where a single noisy prediction could affect a clinical decision."
Found this helpful?
Leave a comment
Have a question, correction, or just found this helpful? Leave a note below.