Learnixo

Deep Learning for AI Interviews · Lesson 44 of 56

Transfer Learning with Pretrained CNNs

Why Transfer Learning Works

ImageNet-pretrained CNNs learn a feature hierarchy (layer ranges are approximate):
  Layer 1–2: Edge detectors, colour blobs (universal to all images)
  Layer 3–4: Textures, shapes (broadly transferable)
  Layer 5+:  ImageNet-specific features (dogs, cars, furniture)

Medical imaging benefits because:
  - Early layers are universally useful (edges in X-rays, textures in histology)
  - Training data for medical tasks is often limited
  - Pre-training provides strong weight initialisation
  - Dramatically reduces training time and data requirements

Results: ResNet50 on a chest X-ray task (5K images)
  From scratch (random init):        AUC ≈ 0.72
  ImageNet-pretrained, fine-tuned:   AUC ≈ 0.87
  Full CheXpert:                     AUC ≈ 0.93
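The figures above are ROC AUC values. ROC AUC is the probability that a randomly chosen positive case is scored higher than a randomly chosen negative one: 0.5 is chance and 1.0 is perfect. Below is a minimal pure-Python sketch of that pairwise definition, counting ties as half. The toy labels and scores are made up and are not the data behind the results above.

```python
def roc_auc(labels: list[int], scores: list[float]) -> float:
    """AUC = P(score of a random positive > score of a random negative); ties count 0.5."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))

# Toy example: 3 of the 4 positive/negative pairs are ranked correctly.
labels = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
print(roc_auc(labels, scores))  # → 0.75
```

This pairwise form costs O(P·N). For real evaluations, a library routine such as scikit-learn's `roc_auc_score` is the usual choice.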

Three Fine-Tuning Strategies

Python
import torch
import torch.nn as nn
import torchvision.models as models

# ── Strategy 1: Feature Extraction (freeze all backbone) ──
# Use when: very small dataset (<1K), or the target domain is close to ImageNet

def setup_feature_extraction(n_classes: int = 2) -> nn.Module:
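    # torchvision >= 0.13: `weights=` replaces the deprecated `pretrained=True`
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)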
    
    # Freeze ALL backbone weights
    for param in model.parameters():
        param.requires_grad = False
    
    # Replace classification head (unfrozen by default since it's new)
    in_features = model.fc.in_features   # 2048 for ResNet50
    model.fc = nn.Sequential(
        nn.Linear(in_features, 256),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(256, n_classes),
    )
    
    # Only head trains
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable params: {trainable:,} (head only)")
    return model

# ── Strategy 2: Fine-Tuning Last N Layers ──
# Use when: moderate dataset (1K–50K), domain gap is moderate

def setup_partial_finetune(n_classes: int = 2, unfreeze_layers: int = 2) -> nn.Module:
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
    
    # Freeze all first
    for param in model.parameters():
        param.requires_grad = False
    
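    # Unfreeze the last N residual stages (layer1..layer4); N=0 leaves only the head trainable
    layer_groups = [model.layer1, model.layer2, model.layer3, model.layer4]
    if not 0 <= unfreeze_layers <= len(layer_groups):
        raise ValueError(f"unfreeze_layers must be between 0 and {len(layer_groups)}")
    # Slice from the front: layer_groups[-0:] would select every stage
    for layer in layer_groups[len(layer_groups) - unfreeze_layers:]:
        for param in layer.parameters():
            param.requires_grad = True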
    
    # Replace head
    model.fc = nn.Linear(model.fc.in_features, n_classes)
    
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable params: {trainable:,} (last {unfreeze_layers} stages + head)")
    return model

# ── Strategy 3: Full Fine-Tuning (discriminative learning rates) ──
# Use when: large dataset (>50K) or significant domain shift

def setup_full_finetune(n_classes: int = 2) -> tuple[nn.Module, torch.optim.Optimizer]:
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
    
    # Replace head
    model.fc = nn.Linear(model.fc.in_features, n_classes)
    
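    # Discriminative learning rates: earlier layers get smaller lr.
    # The stem (conv1 + bn1) needs a group too; otherwise it is never updated.
    stem_params = list(model.conv1.parameters()) + list(model.bn1.parameters())
    optimizer = torch.optim.AdamW([
        {"params": stem_params,               "lr": 1e-5},
        {"params": model.layer1.parameters(), "lr": 1e-5},
        {"params": model.layer2.parameters(), "lr": 1e-5},
        {"params": model.layer3.parameters(), "lr": 1e-4},
        {"params": model.layer4.parameters(), "lr": 1e-4},
        {"params": model.fc.parameters(),     "lr": 1e-3},
    ], lr=1e-4, weight_decay=1e-4)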
    
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable params: {trainable:,} (full model, varied lr)")
    return model, optimizer

Full Fine-Tuning Pipeline

Python
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as T
from torch.utils.data import DataLoader

def build_medical_cnn(n_classes: int = 2, pretrained: bool = True) -> nn.Module:
    """ResNet50 adapted for grayscale medical images."""
    model = models.resnet50(
        weights=models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
    )
    
    # For grayscale input: replace first conv to accept 1 channel
    # Keep pretrained weights by averaging across RGB channels
    old_conv = model.conv1
    new_conv = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    with torch.no_grad():
        new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
    model.conv1 = new_conv
    
    # New classification head
    model.fc = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(model.fc.in_features, n_classes),
    )
    
    return model

def train_medical_cnn(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    n_epochs: int = 30,
    warmup_epochs: int = 5,
) -> nn.Module:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    
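    # One optimizer, two parameter groups: lower lr for the pretrained
    # backbone, higher lr for the freshly initialised head
    backbone_params = [p for name, p in model.named_parameters() if not name.startswith("fc.")]
    head_params     = list(model.fc.parameters())
    
    optimizer = torch.optim.AdamW([
        {"params": backbone_params, "lr": 1e-4, "weight_decay": 1e-4},
        {"params": head_params,     "lr": 1e-3, "weight_decay": 1e-4},
    ])
    
    criterion = nn.CrossEntropyLoss()
    
    # Linear warmup for `warmup_epochs`, then cosine decay over the remaining epochs
    if warmup_epochs > 0:
        warmup = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=0.1, total_iters=warmup_epochs
        )
        cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=max(1, n_epochs - warmup_epochs)
        )
        scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer, schedulers=[warmup, cosine], milestones=[warmup_epochs]
        )
    else:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
    
    best_val_loss = float("inf")
    best_state = None
    
    for epoch in range(n_epochs):
        # Training
        model.train()
        train_loss = 0.0
        for X, y in train_loader:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(X), y)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            train_loss += loss.item()
        
        # Validation
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for X, y in val_loader:
                X, y = X.to(device), y.to(device)
                out = model(X)
                val_loss += criterion(out, y).item()
                correct += (out.argmax(1) == y).sum().item()
                total += y.size(0)
        
        val_loss /= len(val_loader)
        scheduler.step()
        
        # Keep a copy of the weights with the lowest validation loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        
        print(f"Epoch {epoch+1:3d}: train={train_loss/len(train_loader):.4f}, "
              f"val={val_loss:.4f}, acc={correct/total:.3f}")
    
    # Return the best-validation weights rather than the last epoch's
    if best_state is not None:
        model.load_state_dict(best_state)
    return model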
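`build_medical_cnn` averages the pretrained RGB filters into one channel. A convolution is linear. So if a grayscale image x is copied into all three channels and fed to the original conv, the output equals convolving x with (w_R + w_G + w_B). Summing the filters therefore reproduces that response exactly, while averaging gives one third of it. The minimal check below confirms this arithmetic. It uses a randomly initialised conv as a stand-in for the pretrained `conv1`, and `to_gray_conv` is a hypothetical helper.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Stand-in for ResNet50's pretrained conv1 (64 filters of shape 3 x 7 x 7).
rgb_conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)

def to_gray_conv(conv: nn.Conv2d, reduce: str) -> nn.Conv2d:
    """Build a 1-channel conv from a 3-channel one by summing or averaging filters."""
    gray = nn.Conv2d(1, conv.out_channels, kernel_size=conv.kernel_size,
                     stride=conv.stride, padding=conv.padding, bias=False)
    with torch.no_grad():
        if reduce == "sum":
            w = conv.weight.sum(dim=1, keepdim=True)
        elif reduce == "mean":
            w = conv.weight.mean(dim=1, keepdim=True)
        else:
            raise ValueError("reduce must be 'sum' or 'mean'")
        gray.weight.copy_(w)
    return gray

x_gray = torch.randn(2, 1, 32, 32)
x_rgb = x_gray.repeat(1, 3, 1, 1)   # the same grayscale values in R, G and B

with torch.no_grad():
    ref = rgb_conv(x_rgb)
    out_sum = to_gray_conv(rgb_conv, "sum")(x_gray)
    out_mean = to_gray_conv(rgb_conv, "mean")(x_gray)

print(torch.allclose(out_sum, ref, atol=1e-4))       # True
print(torch.allclose(out_mean * 3, ref, atol=1e-4))  # True
```

Both choices keep the pretrained filter shapes. The 1/3 scale difference interacts with input normalisation and the following BatchNorm, so which works better is an empirical question.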

Data Augmentation for Medical Imaging

Python
import torchvision.transforms as T

# Medical imaging augmentations: more conservative than for natural images
# (pathological features are subtle; aggressive augmentation may destroy them)

train_transform = T.Compose([
    T.Resize(256),
    T.RandomCrop(224),
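    # Horizontal flip is common for chest X-rays, but anatomy is not L/R
    # symmetric (the heart sits on the left): skip it if laterality matters.
    T.RandomHorizontalFlip(0.5),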
    T.RandomRotation(degrees=10),         # small rotation only
    T.ColorJitter(brightness=0.1, contrast=0.1),  # subtle lighting variation
    T.ToTensor(),
    T.Normalize(mean=[0.485], std=[0.229]),  # ImageNet red-channel stats, used as a 1-channel approximation
])

val_transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485], std=[0.229]),
])

# Do NOT use:
#   - RandomVerticalFlip (inverts anatomy)
#   - Strong colour jitter (alters intensity-based diagnoses)
#   - Extreme rotations (>15°), which may be clinically implausible

Transfer Learning Decision

Dataset size  | Domain similarity | Strategy
--------------|-------------------|------------------------------------------
< 1K          | Similar           | Feature extraction (freeze backbone)
< 1K          | Different         | Feature extraction (from earlier, more generic layers)
1K – 50K      | Similar           | Fine-tune last 1–2 stages + head
1K – 50K      | Different         | Fine-tune more layers, lower backbone lr
> 50K         | Any               | Full fine-tune with discriminative lrs
> 500K        | Very different    | Consider training from scratch or domain-specific pretrain
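Encoding the table as a small function makes its boundaries explicit. This is a minimal sketch: `choose_strategy` and its domain labels are hypothetical, and the thresholds are the table's rules of thumb rather than hard cut-offs.

```python
DOMAINS = ("similar", "different", "very_different")

def choose_strategy(n_images: int, domain: str) -> str:
    """Map dataset size and domain similarity to a fine-tuning strategy (rule of thumb)."""
    if domain not in DOMAINS:
        raise ValueError(f"domain must be one of {DOMAINS}")
    if n_images < 1_000:
        return "feature extraction"
    if n_images <= 50_000:
        if domain == "similar":
            return "fine-tune last 1-2 stages + head"
        return "fine-tune more layers, lower backbone lr"
    if n_images > 500_000 and domain == "very_different":
        return "train from scratch or use domain-specific pretraining"
    return "full fine-tune with discriminative lrs"

print(choose_strategy(800, "similar"))             # feature extraction
print(choose_strategy(10_000, "different"))        # fine-tune more layers, lower backbone lr
print(choose_strategy(2_000_000, "very_different"))
```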

Medical pre-training resources:
  - CheXpert: chest X-ray (224K images)
  - PathologyNet: histopathology
  - MedCLIP: medical image-text pairs
  - BioViL: clinical report + X-ray CLIP
  
Medical-domain pre-training often beats ImageNet pre-training on clinical
tasks, because the domain gap is smaller.

Interview Answer

"Transfer learning uses knowledge from a large pre-training task (usually ImageNet) to improve performance on a smaller target task. CNNs are particularly well-suited because early layers learn universal features (edges, textures) that transfer across domains, including medical imaging. There are three fine-tuning strategies. (1) Feature extraction: freeze the backbone and train only the new head; for very small datasets (under 1K). (2) Partial fine-tuning: unfreeze the last 1–2 stages and the head; for moderate datasets. (3) Full fine-tuning with discriminative learning rates: a smaller lr for early layers (they already learned useful features) and a larger lr for later layers and the head; for large datasets. For grayscale medical images, adapt the first conv layer by averaging the pretrained RGB filter weights across the three input channels. Key insight: even with a domain gap (natural images vs X-rays), ImageNet pretraining usually outperforms random initialisation, especially when labelled data is limited, because the feature hierarchy is broadly useful."