Transfer Learning with CNNs
Fine-tuning pre-trained ImageNet models for medical imaging — freezing strategies, learning rate schedules, and when to use transfer learning vs training from scratch.
Why Transfer Learning Works
ImageNet pre-trained CNNs learn:
Layer 1–2: Edge detectors, colour blobs (universal to all images)
Layer 3–4: Textures, shapes (broadly transferable)
Layer 5+: ImageNet-specific features (dogs, cars, furniture)
Medical imaging benefits because:
- Early layers are universally useful (edges in X-rays, textures in histology)
- Training data for medical tasks is often limited
- Pre-training provides strong weight initialisation
- Fine-tuning needs far less training time and labelled data than training from scratch
Results: ImageNet-pretrained ResNet50 on chest X-ray (5K images)
From scratch: AUC ≈ 0.72
Fine-tuned: AUC ≈ 0.87
Full CheXpert: AUC ≈ 0.93
Three Fine-Tuning Strategies
import torch
import torch.nn as nn
import torchvision.models as models

# ── Strategy 1: Feature Extraction (freeze all backbone) ──
# Use when: very small dataset (<1K), or features are highly similar
def setup_feature_extraction(n_classes: int = 2) -> nn.Module:
    # weights=... replaces the deprecated pretrained=True (torchvision >= 0.13)
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
    # Freeze ALL backbone weights
    for param in model.parameters():
        param.requires_grad = False
    # Replace classification head (new modules are trainable by default)
    in_features = model.fc.in_features  # 2048 for ResNet50
    model.fc = nn.Sequential(
        nn.Linear(in_features, 256),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(256, n_classes),
    )
    # Only head trains
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable params: {trainable:,} (head only)")
    return model

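One detail the feature-extraction setup leaves open: setting `requires_grad = False` stops gradient updates, but BatchNorm layers still update their running mean and variance whenever the model is in train mode. The sketch below, using a toy backbone as a stand-in for ResNet50 and a hypothetical helper `freeze_batchnorm_stats`, shows one way to keep the pretrained statistics fixed as well.

```python
import torch
import torch.nn as nn


def freeze_batchnorm_stats(model: nn.Module) -> None:
    """Put every BatchNorm layer in eval mode so running mean/var stop updating."""
    for module in model.modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            module.eval()


# Toy stand-in for a frozen backbone plus a new trainable head (hypothetical shapes)
backbone = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 2))
for p in backbone.parameters():
    p.requires_grad = False
model = nn.Sequential(backbone, head)

model.train()                     # train mode would let BatchNorm update its statistics
freeze_batchnorm_stats(backbone)  # call again after every model.train()

before = backbone[1].running_mean.clone()
model(torch.randn(4, 1, 16, 16))
after = backbone[1].running_mean
print("running_mean unchanged:", torch.equal(before, after))
```

Because `model.train()` flips every submodule back to train mode, the helper has to be re-applied at the start of each training epoch.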
# ── Strategy 2: Fine-Tuning Last N Layers ──
# Use when: moderate dataset (1K–50K), domain gap is moderate
def setup_partial_finetune(n_classes: int = 2, unfreeze_layers: int = 2) -> nn.Module:
    # weights=... replaces the deprecated pretrained=True (torchvision >= 0.13)
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
    # Freeze all first
    for param in model.parameters():
        param.requires_grad = False
    # Unfreeze last N layer groups
    layer_groups = [model.layer1, model.layer2, model.layer3, model.layer4]
    # Slice from an explicit start index: layer_groups[-0:] would select every stage
    start = max(len(layer_groups) - unfreeze_layers, 0)
    for layer in layer_groups[start:]:
        for param in layer.parameters():
            param.requires_grad = True
    # Replace head
    model.fc = nn.Linear(model.fc.in_features, n_classes)
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable params: {trainable:,} (last {unfreeze_layers} stages + head)")
    return model

# ── Strategy 3: Full Fine-Tuning (discriminative learning rates) ──
# Use when: large dataset (>50K) or significant domain shift
def setup_full_finetune(n_classes: int = 2) -> tuple[nn.Module, torch.optim.Optimizer]:
    # weights=... replaces the deprecated pretrained=True (torchvision >= 0.13)
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
    # Replace head
    model.fc = nn.Linear(model.fc.in_features, n_classes)
    # The stem (conv1 + bn1) needs its own group; parameters that are not
    # passed to the optimizer are never updated
    stem_params = list(model.conv1.parameters()) + list(model.bn1.parameters())
    # Discriminative learning rates: earlier layers get smaller lr
    optimizer = torch.optim.AdamW([
        {"params": stem_params, "lr": 1e-5},
        {"params": model.layer1.parameters(), "lr": 1e-5},
        {"params": model.layer2.parameters(), "lr": 1e-5},
        {"params": model.layer3.parameters(), "lr": 1e-4},
        {"params": model.layer4.parameters(), "lr": 1e-4},
        {"params": model.fc.parameters(), "lr": 1e-3},
    ], lr=1e-4, weight_decay=1e-4)
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable params: {trainable:,} (full model, varied lr)")
    return model, optimizer
Full Fine-Tuning Pipeline
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as T
from torch.utils.data import DataLoader

def build_medical_cnn(n_classes: int = 2, pretrained: bool = True) -> nn.Module:
    """ResNet50 adapted for grayscale medical images."""
    # weights=... replaces the deprecated pretrained flag (torchvision >= 0.13)
    weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
    model = models.resnet50(weights=weights)
    # For grayscale input: replace first conv to accept 1 channel
    # Keep pretrained weights by averaging across RGB channels
    old_conv = model.conv1
    new_conv = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    with torch.no_grad():
        new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
    model.conv1 = new_conv
    # New classification head
    model.fc = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(model.fc.in_features, n_classes),
    )
    return model

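To see what averaging the RGB weights does numerically, the following sketch compares a 3-channel conv applied to a grayscale image copied into R, G and B against 1-channel convs built from the same weights. A randomly initialised conv stands in for the pretrained stem, and `to_grayscale_conv` is a hypothetical helper, not part of the pipeline above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for a pretrained RGB stem (random weights; same shape as ResNet50's conv1)
rgb_conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)


def to_grayscale_conv(conv: nn.Conv2d, reduce: str = "mean") -> nn.Conv2d:
    """Collapse the input-channel dimension of an RGB conv to a single channel."""
    gray = nn.Conv2d(1, conv.out_channels, conv.kernel_size, conv.stride, conv.padding, bias=False)
    with torch.no_grad():
        w = conv.weight.sum(dim=1, keepdim=True)
        gray.weight.copy_(w / conv.in_channels if reduce == "mean" else w)
    return gray


x_gray = torch.randn(2, 1, 64, 64)
x_rgb = x_gray.repeat(1, 3, 1, 1)  # same image copied into R, G and B

with torch.no_grad():
    ref = rgb_conv(x_rgb)
    out_mean = to_grayscale_conv(rgb_conv, "mean")(x_gray)
    out_sum = to_grayscale_conv(rgb_conv, "sum")(x_gray)

print(torch.allclose(out_sum, ref, atol=1e-4))       # summed weights match the RGB stem
print(torch.allclose(out_mean * 3, ref, atol=1e-4))  # averaged weights give 1/3 of it
```

Averaging therefore scales the stem's activations by 1/3 relative to feeding a replicated grayscale image through the original conv. When the following BatchNorm is fine-tuned in train mode this is largely renormalised away; if BatchNorm statistics are kept frozen, summing the weights may be the safer choice.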
def train_medical_cnn(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    n_epochs: int = 30,
    warmup_epochs: int = 5,
) -> nn.Module:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    # Separate param groups: lower lr for the backbone, higher lr for the new head
    backbone_params = [p for name, p in model.named_parameters() if not name.startswith("fc.")]
    head_params = list(model.fc.parameters())
    optimizer = torch.optim.AdamW([
        {"params": backbone_params, "lr": 1e-4, "weight_decay": 1e-4},
        {"params": head_params, "lr": 1e-3, "weight_decay": 1e-4},
    ])
    criterion = nn.CrossEntropyLoss()
    # Linear warmup for warmup_epochs, then cosine decay over the remaining epochs
    cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=max(n_epochs - warmup_epochs, 1)
    )
    if warmup_epochs > 0:
        warmup = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=0.1, total_iters=warmup_epochs
        )
        scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer, schedulers=[warmup, cosine], milestones=[warmup_epochs]
        )
    else:
        scheduler = cosine
    best_val_loss = float("inf")
    best_state = None
    for epoch in range(n_epochs):
        # Training
        model.train()
        train_loss = 0.0
        for X, y in train_loader:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(X), y)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            train_loss += loss.item()
        # Validation
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for X, y in val_loader:
                X, y = X.to(device), y.to(device)
                out = model(X)
                val_loss += criterion(out, y).item()
                correct += (out.argmax(1) == y).sum().item()
                total += y.size(0)
        val_loss /= len(val_loader)
        scheduler.step()
        # Keep a copy of the weights with the lowest validation loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        print(f"Epoch {epoch+1:3d}: train={train_loss/len(train_loader):.4f}, "
              f"val={val_loss:.4f}, acc={correct/total:.3f}")
    # Return the best checkpoint rather than the weights from the last epoch
    if best_state is not None:
        model.load_state_dict(best_state)
    return model
Data Augmentation for Medical Imaging
import torchvision.transforms as T
# Medical imaging augmentations: more conservative than for natural images
# (pathological features are subtle; aggressive augmentation may destroy them)
train_transform = T.Compose([
    T.Grayscale(num_output_channels=1),           # match the 1-channel conv1 of build_medical_cnn
    T.Resize(256),
    T.RandomCrop(224),
    # Horizontal flips are common for chest X-ray, but they mirror the heart to the
    # right side; drop this if laterality matters for the task
    T.RandomHorizontalFlip(0.5),
    T.RandomRotation(degrees=10),                 # small rotation only
    T.ColorJitter(brightness=0.1, contrast=0.1),  # subtle lighting variation
    T.ToTensor(),
    # ImageNet red-channel stats as a stand-in; dataset-specific mean/std are preferable
    T.Normalize(mean=[0.485], std=[0.229]),
])
val_transform = T.Compose([
    T.Grayscale(num_output_channels=1),
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485], std=[0.229]),
])
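The `Normalize` values above come from ImageNet rather than from the medical dataset itself. As a hedged sketch, the mean and standard deviation of a single-channel training set could be estimated like this; `channel_mean_std` is a hypothetical helper, and the random tensors merely stand in for images loaded with `ToTensor()` and no normalisation.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def channel_mean_std(loader: DataLoader) -> tuple[float, float]:
    """Mean and (population) std over all pixels of single-channel images."""
    n_pixels = 0
    total = 0.0
    total_sq = 0.0
    for images, _ in loader:
        images = images.double()  # accumulate in float64 to limit rounding error
        n_pixels += images.numel()
        total += images.sum().item()
        total_sq += (images ** 2).sum().item()
    mean = total / n_pixels
    std = (total_sq / n_pixels - mean ** 2) ** 0.5
    return mean, std


# Synthetic stand-in for a training set that was loaded with ToTensor() only
fake_images = torch.rand(32, 1, 64, 64)
fake_labels = torch.zeros(32, dtype=torch.long)
loader = DataLoader(TensorDataset(fake_images, fake_labels), batch_size=8)

mean, std = channel_mean_std(loader)
print(f"mean={mean:.3f}, std={std:.3f}")
```

The resulting pair would replace `mean=[0.485], std=[0.229]` in both transforms; compute it on the training split only so that no validation statistics leak into preprocessing.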
# Do NOT use:
# - RandomVerticalFlip (inverts anatomy)
# - Strong colour jitter (alters intensity-based diagnoses)
# - Extreme rotations (>15°): may be clinically implausible
Transfer Learning Decision
Dataset size | Domain similarity | Strategy
-------------|-------------------|-----------------------------------------------------------
< 1K         | Similar           | Feature extraction (freeze backbone)
< 1K         | Different         | Feature extraction, preferably on earlier-layer features
1K – 50K     | Similar           | Fine-tune last 1–2 stages + head
1K – 50K     | Different         | Fine-tune more layers, lower backbone lr
> 50K        | Any               | Full fine-tune with discriminative lrs
> 500K       | Very different    | Consider training from scratch or domain-specific pretrain
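The table can also be read as a small lookup function. The sketch below is one possible encoding, not a rule from any library: `choose_strategy` is a hypothetical name, and collapsing "Similar / Different / Very different" into a single boolean is a simplifying assumption.

```python
def choose_strategy(n_images: int, domain_similar: bool) -> str:
    """Map dataset size and domain similarity to a row of the decision table."""
    if n_images < 1_000:
        return "feature extraction"
    if n_images <= 50_000:
        if domain_similar:
            return "fine-tune last 1-2 stages + head"
        return "fine-tune more layers, lower backbone lr"
    if n_images > 500_000 and not domain_similar:
        return "consider training from scratch or domain-specific pretraining"
    return "full fine-tune with discriminative lrs"


for n, similar in [(800, True), (5_000, False), (120_000, True), (2_000_000, False)]:
    print(f"{n:>9,} images, similar={similar}: {choose_strategy(n, similar)}")
```

The thresholds are rules of thumb taken from the table; in practice the choice should be confirmed with a validation experiment on the target data.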
Medical pre-training resources:
- CheXpert: chest X-ray (224K images)
- PathologyNet: histopathology
- MedCLIP: medical image-text pairs
- BioViL: clinical report + X-ray CLIP
Using medical domain pre-training often beats ImageNet pre-training
for clinical tasks, because the domain gap is smaller.
Interview Answer
"Transfer learning reuses knowledge from a large pre-training task (usually ImageNet) to improve performance on a smaller target task. CNNs are well suited to it because early layers learn universal features (edges, textures) that transfer across domains, including to medical imaging. There are three fine-tuning strategies: (1) feature extraction: freeze the backbone and train only the new head, for very small datasets (under 1K images); (2) partial fine-tuning: unfreeze the last 1–2 stages plus the head, for moderate datasets (1K–50K); (3) full fine-tuning with discriminative learning rates: a smaller lr for early layers, which already encode useful features, and a larger lr for later layers and the head, for large datasets (over 50K). For grayscale medical images, adapt the first conv layer by averaging the pretrained RGB weights across the channel dimension (1/3 of each). Key insight: even with a domain gap (natural images vs X-rays), ImageNet pretraining usually outperforms random initialisation when labelled data is limited, because the feature hierarchy is broadly useful."