Learnixo

Deep Learning for AI Interviews · Lesson 48 of 56

CNN in Production: Latency, Size, and Edge Deploy

Clinical Image Pipeline

DICOM file → Preprocessing → Model → Output → Clinical Decision

Preprocessing steps:
  1. Load DICOM: extract pixel array + metadata (window level, window width)
  2. Apply rescale slope/intercept, then windowing: clamp intensities (Hounsfield units for CT) to the clinically relevant range
  3. Normalise to [0, 1] or [-1, 1]
  4. Resize to model input size (224×224 or 512×512)
  5. Convert grayscale to pseudo-RGB if needed (repeat channel)
  6. Apply exactly the transforms used at training time (e.g. centre crop, mean/std normalisation) so inference inputs match the training distribution

HIPAA/GDPR compliance:
  - Strip PHI before any cloud processing
  - Audit logging of all predictions
  - Anonymise identifying DICOM tags (PatientName, PatientID, PatientBirthDate, AccessionNumber, etc.) and remove private tags, which can also carry PHI

DICOM Preprocessing

Python
import numpy as np
import torch
import pydicom
from pathlib import Path

def _first_dicom_value(value) -> float:
    """Return *value* as a float, taking its first entry if it is multi-valued.

    Attributes such as WindowCenter / WindowWidth may hold several values
    (pydicom exposes these as a ``MultiValue`` sequence, which is neither a
    ``list`` nor a ``tuple``); calling ``float()`` on the sequence itself
    raises ``TypeError``, so the first value must be extracted first.
    """
    from collections.abc import Sequence

    if isinstance(value, Sequence) and not isinstance(value, (str, bytes)):
        value = value[0]
    return float(value)


def load_dicom_xray(
    path: str,
    target_size: tuple = (224, 224),
    window_centre: float | None = None,
    window_width: float | None = None,
) -> torch.Tensor:
    """
    Load a chest X-ray DICOM and return a normalised tensor.

    Args:
        path: Path to the DICOM file.
        target_size: Output (H, W) passed to torchvision's ``resize``.
        window_centre: Window centre in rescaled units. Defaults to the file's
            WindowCenter tag (first value if multi-valued), else the pixel mean.
        window_width: Window width in rescaled units. Defaults to the file's
            WindowWidth tag (first value if multi-valued), else 4 x pixel std.

    Returns:
        (1, H, W) float32 tensor in [0, 1]. MONOCHROME1 images (where the
        lowest stored value is displayed white) are inverted so that, as for
        MONOCHROME2, higher output values mean brighter pixels.

    Raises:
        ValueError: if the resolved window width is negative.
    """
    import torchvision.transforms.functional as TF

    ds = pydicom.dcmread(path)

    # Get pixel data
    pixels = ds.pixel_array.astype(np.float32)

    # Apply rescale slope/intercept (maps stored values to modality units,
    # e.g. Hounsfield units for CT)
    slope = float(getattr(ds, "RescaleSlope", 1.0))
    intercept = float(getattr(ds, "RescaleIntercept", 0.0))
    pixels = pixels * slope + intercept

    # Resolve the window. Multi-valued tags are reduced to their first value
    # BEFORE float conversion; statistical fall-backs are computed only when
    # the tag is absent or empty.
    if window_centre is None:
        raw = getattr(ds, "WindowCenter", None)
        if raw is None or raw == "":
            window_centre = float(pixels.mean())
        else:
            window_centre = _first_dicom_value(raw)
    if window_width is None:
        raw = getattr(ds, "WindowWidth", None)
        if raw is None or raw == "":
            window_width = float(pixels.std() * 4)
        else:
            window_width = _first_dicom_value(raw)
    if window_width < 0:
        raise ValueError(f"window_width must be non-negative, got {window_width}")

    low = window_centre - window_width / 2
    high = window_centre + window_width / 2
    pixels = np.clip(pixels, low, high)

    # Normalise to [0, 1]; the epsilon avoids 0/0 for a zero-width window.
    pixels = (pixels - low) / (window_width + 1e-8)

    # MONOCHROME1 displays low values as white; invert so brightness is
    # consistent with MONOCHROME2 images the model is typically trained on.
    photometric = str(getattr(ds, "PhotometricInterpretation", "")).strip().upper()
    if photometric == "MONOCHROME1":
        pixels = 1.0 - pixels

    # Resize
    tensor = torch.from_numpy(pixels).unsqueeze(0)   # (1, H, W)
    tensor = TF.resize(tensor, list(target_size))

    return tensor.float()

def anonymise_dicom_metadata(ds: "pydicom.Dataset") -> "pydicom.Dataset":
    """Remove PHI from DICOM metadata, modifying *ds* in place.

    Identifying text attributes that are present are overwritten with
    ``"ANONYMISED"``. Birth date/time attributes are blanked instead, because
    ``"ANONYMISED"`` is not a valid DA/TM value and would make the dataset
    non-conformant. Private tags (which vendors may use for free-text
    identifiers) are removed when *ds* provides ``remove_private_tags()``.
    Attributes that are not present on *ds* are left absent.

    Returns:
        The same *ds* object, to allow call chaining.
    """
    PHI_TAGS = [
        "PatientName", "PatientID", "OtherPatientIDs", "OtherPatientNames",
        "PatientBirthName", "PatientMotherBirthName", "PatientAddress",
        "PatientTelephoneNumbers", "PatientSex",
        "InstitutionName", "InstitutionAddress",
        "ReferringPhysicianName", "PerformingPhysicianName",
        "PhysiciansOfRecord", "OperatorsName",
        "StudyDescription", "StudyID", "AccessionNumber",
    ]
    # Date/time value representations cannot hold arbitrary text: blank them.
    PHI_DATE_TAGS = ["PatientBirthDate", "PatientBirthTime"]

    for tag in PHI_TAGS:
        if hasattr(ds, tag):
            setattr(ds, tag, "ANONYMISED")
    for tag in PHI_DATE_TAGS:
        if hasattr(ds, tag):
            setattr(ds, tag, "")

    # pydicom.Dataset exposes remove_private_tags(); duck-typed so plain
    # attribute containers are still accepted.
    remove_private_tags = getattr(ds, "remove_private_tags", None)
    if callable(remove_private_tags):
        remove_private_tags()
    return ds

Grad-CAM: Gradient-Weighted Class Activation Map

Python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GradCAM:
    """
    Visualises which regions of an image most influence the prediction.
    Uses gradients of the predicted class flowing into the last conv layer.
    """
    
    def __init__(self, model: nn.Module, target_layer: nn.Module):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        
        # Register hooks
        target_layer.register_forward_hook(self._save_activation)
        target_layer.register_backward_hook(self._save_gradient)
    
    def _save_activation(self, module, input, output) -> None:
        self.activations = output.detach()
    
    def _save_gradient(self, module, grad_input, grad_output) -> None:
        self.gradients = grad_output[0].detach()
    
    def generate(
        self,
        x: torch.Tensor,
        class_idx: int = None,
    ) -> tuple[torch.Tensor, int]:
        """Returns: (cam, predicted_class)"""
        self.model.eval()
        
        # Forward pass
        logits = self.model(x)
        
        if class_idx is None:
            class_idx = logits.argmax(dim=-1).item()
        
        # Backward pass for the target class
        self.model.zero_grad()
        logits[0, class_idx].backward()
        
        # Weight activations by global-average-pooled gradients
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)  # (B, C, 1, 1)
        cam = (weights * self.activations).sum(dim=1, keepdim=True)  # (B, 1, H, W)
        cam = F.relu(cam)   # only positive contributions
        
        # Normalise to [0, 1]
        cam_min = cam.view(cam.size(0), -1).min(dim=-1)[0].view(-1, 1, 1, 1)
        cam_max = cam.view(cam.size(0), -1).max(dim=-1)[0].view(-1, 1, 1, 1)
        cam = (cam - cam_min) / (cam_max - cam_min + 1e-8)
        
        return cam.squeeze(), class_idx

# Usage with a ResNet
import torchvision.models as models

# weights=None replaces the deprecated pretrained=False (torchvision >= 0.13).
model = models.resnet50(weights=None)
model.fc = nn.Linear(2048, 2)

# Target the output of the last bottleneck block (after the residual add and
# ReLU) — the usual Grad-CAM layer for ResNets.
target_layer = model.layer4[-1]
gradcam = GradCAM(model, target_layer)

X = torch.randn(1, 3, 224, 224)
cam, pred_class = gradcam.generate(X)
print(f"Grad-CAM: {cam.shape}, predicted class: {pred_class}")

# ResNet50's layer4 has an overall stride of 32, so a 224x224 input gives a
# (7, 7) cam. Upsample it to the input resolution for overlay on the image.
cam_full = F.interpolate(
    cam[None, None], size=X.shape[-2:], mode="bilinear", align_corners=False
)[0, 0]
print(f"Upsampled Grad-CAM: {cam_full.shape}")

Bias Detection and Subgroup Analysis

Python
import torch
import numpy as np
from sklearn.metrics import roc_auc_score

def evaluate_subgroup_performance(
    model: torch.nn.Module,
    data_by_subgroup: dict[str, tuple],  # {subgroup: (X, y)}
    device: torch.device,
    *,
    auc_gap_threshold: float = 0.05,
) -> dict[str, dict]:
    """
    Evaluate model performance per demographic subgroup.
    Essential before clinical deployment to detect bias.

    Args:
        model: Binary classifier; its output is flattened to one logit per
            sample and passed through a sigmoid (assumes a single-logit head —
            a 2-logit head would need softmax instead).
        data_by_subgroup: ``{subgroup: (X, y)}``. ``X`` is a tensor moved to
            ``device``; ``y`` holds 0/1 labels as a tensor (any device) or an
            array-like.
        device: Device the model runs on.
        auc_gap_threshold: Print a warning for any subgroup whose AUC is more
            than this far below the unweighted mean subgroup AUC.

    Returns:
        ``{subgroup: {"n", "prevalence", "auc"}}``. ``prevalence`` is NaN for
        an empty subgroup; ``auc`` is NaN when the subgroup has only one class.
    """
    model.eval()
    results = {}

    for subgroup, (X, y) in data_by_subgroup.items():
        if isinstance(y, torch.Tensor):
            y_np = y.detach().cpu().numpy().reshape(-1)
        else:
            y_np = np.asarray(y).reshape(-1)
        n_total = len(y_np)
        n_positive = float(y_np.sum())

        if n_total == 0:
            # Avoid 0/0: nothing to evaluate for an empty subgroup.
            prevalence = float("nan")
            auc = float("nan")
        else:
            prevalence = n_positive / n_total
            if n_positive == 0 or n_positive == n_total:
                auc = float("nan")   # AUC is undefined with a single class
            else:
                with torch.no_grad():
                    # reshape(-1) rather than squeeze(): keeps a 1-D array even
                    # for a single sample instead of collapsing the batch dim.
                    logits = model(X.to(device)).reshape(-1)
                    probs = torch.sigmoid(logits).cpu().numpy()
                auc = float(roc_auc_score(y_np, probs))

        results[subgroup] = {
            "n": n_total,
            "prevalence": prevalence,
            "auc": auc,
        }
        print(f"{subgroup:20s}: n={n_total:>5}, prevalence={prevalence:.1%}, AUC={auc:.4f}")

    # Flag subgroups whose AUC falls well below the (unweighted) subgroup mean
    aucs = {k: v["auc"] for k, v in results.items() if not np.isnan(v["auc"])}
    if aucs:
        mean_auc = np.mean(list(aucs.values()))
        for subgroup, auc in aucs.items():
            if mean_auc - auc > auc_gap_threshold:
                print(f"WARNING: {subgroup} AUC is {mean_auc - auc:.3f} below mean — investigate bias")

    return results

# Clinical subgroups to always check:
# - Age group (paediatric, adult, geriatric)
# - Sex at birth
# - Ethnicity (if available and appropriately consented)
# - Scanner manufacturer / hospital site
# - Image acquisition parameters (kV, mAs for X-ray)

Model Monitoring Post-Deployment

Python
import torch
import numpy as np
from collections import deque

class ClinicalModelMonitor:
    """Monitor predictions for drift and anomalies post-deployment."""
    
    def __init__(self, window_size: int = 1000, alert_threshold: float = 0.1):
        self.window_size = window_size
        self.alert_threshold = alert_threshold
        self.recent_probs = deque(maxlen=window_size)
        self.baseline_mean = None
        self.baseline_std  = None
    
    def calibrate(self, baseline_probs: list[float]) -> None:
        """Set baseline statistics from validation set."""
        self.baseline_mean = np.mean(baseline_probs)
        self.baseline_std  = np.std(baseline_probs)
        print(f"Baseline: mean={self.baseline_mean:.4f}, std={self.baseline_std:.4f}")
    
    def record_prediction(self, prob: float) -> dict | None:
        """Record a prediction and check for distribution shift."""
        self.recent_probs.append(prob)
        
        if len(self.recent_probs) < 100:
            return None   # not enough data
        
        current_mean = np.mean(self.recent_probs)
        drift = abs(current_mean - self.baseline_mean)
        
        if drift > self.alert_threshold:
            return {
                "alert": "DISTRIBUTION SHIFT",
                "baseline_mean": self.baseline_mean,
                "current_mean": current_mean,
                "drift": drift,
            }
        return None

monitor = ClinicalModelMonitor(window_size=500, alert_threshold=0.05)
monitor.calibrate(baseline_probs=[0.15] * 1000)   # 15% baseline positive rate

# Simulate drift: 100 stable predictions, then a sustained jump to 0.25.
# The window mean is (15 + 0.25*k) / (100 + k) after k shifted predictions; it
# only exceeds 0.15 + 0.05 once k > 100 (k == 100 lands exactly on the
# threshold), so feed enough shifted predictions for the alert to fire.
for _ in range(100):
    monitor.record_prediction(0.15)   # stable
for n_shifted in range(1, 301):
    alert = monitor.record_prediction(0.25)   # shift!
    if alert:
        print(f"ALERT after {n_shifted} shifted predictions: {alert}")
        break
else:
    print("No drift alert raised")

Interview Answer

"Deploying CNN-based clinical AI requires several additional considerations beyond model accuracy. Preprocessing: DICOM pixel data must be rescaled (slope/intercept) and windowed — clamped to a clinically relevant intensity range, in Hounsfield units for CT — before normalisation; PHI must be stripped for HIPAA/GDPR compliance. Explainability: Grad-CAM shows which image regions drove the prediction by weighting the last conv layer's activation maps by the spatially averaged gradients of the predicted class score — essential for clinician trust and regulatory review. Bias detection: always evaluate AUC per demographic subgroup (age, sex, scanner type, institution) before deployment — a subgroup whose AUC is more than ~0.05 below the others warrants investigation. Post-deployment monitoring: track the distribution of predicted probabilities over time; a sustained shift in the mean prediction can indicate data drift (e.g., patient population change, scanner software update). Regulatory: the EU AI Act and Medical Device Regulation, UK MHRA guidance, and FDA guidance require documented validation on representative populations, performance bounds, and incident reporting mechanisms."