
CNNs in Real-World Clinical AI

Deploying CNN-based medical image models — DICOM preprocessing, clinical validation, bias detection, explainability with Grad-CAM, and regulatory considerations.

Asma Hafeez Khan | May 22, 2026 | 6 min read
Tags: Deep Learning, CNN, Clinical AI, DICOM, Grad-CAM, Deployment, Interview

Clinical Image Pipeline

DICOM file → Preprocessing → Model → Output → Clinical Decision

Preprocessing steps:
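  1. Load DICOM: extract pixel array + metadata (window centre/level, window width, rescale slope/intercept)
  2. Apply windowing: clamp intensities to the clinically relevant range (Hounsfield units for CT; stored or rescaled values for X-ray)
  3. Normalise to [0, 1] or [-1, 1]
  4. Resize to model input size (e.g. 224×224 or 512×512)
  5. Convert grayscale to pseudo-RGB if the model expects 3 channels (repeat the channel)
  6. Apply the remaining inference-time transforms exactly as in training (e.g. centre crop, mean/std normalisation)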

HIPAA/GDPR compliance:
  - Strip PHI before any cloud processing
  - Audit logging of all predictions
  - Anonymise DICOM tags: PatientName, PatientID, DOB, etc.
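
The audit-logging bullet above can be sketched as follows. This is a minimal illustration using only the standard library, not a compliance-approved design: the function name `audit_record`, the field names, and the example key and model version are all hypothetical. It replaces the patient identifier with a keyed hash (HMAC) so the raw ID never reaches the log. Note that a keyed hash is pseudonymisation rather than anonymisation, so the log may still count as personal data.

```python
import hashlib
import hmac
import json
from datetime import datetime, timezone


def audit_record(patient_id: str, secret_key: str, model_version: str,
                 predicted_class: int, probability: float) -> str:
    """Build one JSON audit line; the raw patient ID never appears in it."""
    # Keyed hash: the same patient always maps to the same pseudonym,
    # but the mapping cannot be rebuilt without the secret key
    pseudonym = hmac.new(
        secret_key.encode("utf-8"), patient_id.encode("utf-8"), hashlib.sha256
    ).hexdigest()[:16]
    record = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "patient_pseudonym": pseudonym,
        "model_version": model_version,
        "predicted_class": predicted_class,
        "probability": round(probability, 4),
    }
    return json.dumps(record, sort_keys=True)


# Hypothetical values for illustration only
line = audit_record("MRN-000123", secret_key="site-secret", model_version="cxr-v1",
                    predicted_class=1, probability=0.8731)
print(line)
```

In practice each line would be appended to write-once storage, and the key would live in a secrets manager rather than in code.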

DICOM Preprocessing

Python
import numpy as np
import torch
import pydicom
import torchvision.transforms.functional as TF
from pydicom.multival import MultiValue


def _first_value(value) -> float:
    """WindowCenter/WindowWidth may hold several values; use the first."""
    if isinstance(value, (list, tuple, MultiValue)):
        value = value[0]
    return float(value)


def load_dicom_xray(
    path: str,
    target_size: tuple[int, int] = (224, 224),
    window_centre: float | None = None,
    window_width: float | None = None,
) -> torch.Tensor:
    """
    Load a chest X-ray DICOM and return a normalised tensor.
    Returns: (1, H, W) float32 tensor in [0, 1]
    """
    ds = pydicom.dcmread(path)

    # Get pixel data
    pixels = ds.pixel_array.astype(np.float32)

    # Apply rescale slope/intercept (yields HU for CT; X-ray files often
    # omit these tags, in which case the defaults leave values unchanged)
    slope     = float(getattr(ds, "RescaleSlope", 1.0))
    intercept = float(getattr(ds, "RescaleIntercept", 0.0))
    pixels    = pixels * slope + intercept

    # Use the window stored in the file unless the caller overrides it;
    # fall back to image statistics when the tags are absent
    if window_centre is None:
        window_centre = _first_value(getattr(ds, "WindowCenter", pixels.mean()))
    if window_width is None:
        window_width = _first_value(getattr(ds, "WindowWidth", pixels.std() * 4))

    low  = window_centre - window_width / 2
    high = window_centre + window_width / 2
    pixels = np.clip(pixels, low, high)

    # Normalise to [0, 1]
    pixels = (pixels - low) / (window_width + 1e-8)

    # MONOCHROME1 displays low values as white; invert so that high
    # values are bright in every image the model sees
    if getattr(ds, "PhotometricInterpretation", "") == "MONOCHROME1":
        pixels = 1.0 - pixels

    # Resize
    tensor = torch.from_numpy(pixels).unsqueeze(0)   # (1, H, W)
    tensor = TF.resize(tensor, list(target_size), antialias=True)

    return tensor.float()

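def anonymise_dicom_metadata(ds: "pydicom.Dataset") -> "pydicom.Dataset":
    """
    Blank common PHI tags in DICOM metadata.
    The tag list is illustrative, not exhaustive; production
    de-identification should follow a full profile such as DICOM PS3.15 Annex E.
    """
    PHI_TAGS = [
        "PatientName", "PatientID", "PatientBirthDate",
        "PatientSex", "InstitutionName", "ReferringPhysicianName",
        "StudyDescription", "OperatorsName", "StudyID",
    ]
    for tag in PHI_TAGS:
        if hasattr(ds, tag):
            # An empty value is valid for each of these tags, whereas a
            # placeholder like "ANONYMISED" is not a valid date for PatientBirthDate
            setattr(ds, tag, "")
    # Private tags can carry vendor-specific identifiers
    ds.remove_private_tags()
    return ds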
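
Steps 5 and 6 of the pipeline are not covered by the loader above. A minimal sketch follows, assuming the downstream model is a torchvision-style backbone that expects 3-channel input standardised with the ImageNet channel statistics. That is an assumption about the training setup, not something the article states, and `to_model_input` is a hypothetical helper name.

```python
import torch

# ImageNet channel statistics, commonly used with torchvision pretrained backbones.
# Assumption: the model was trained with this normalisation.
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
IMAGENET_STD  = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)


def to_model_input(gray: torch.Tensor) -> torch.Tensor:
    """
    Convert a (1, H, W) tensor in [0, 1] to a (1, 3, H, W) batch:
    repeat the channel to pseudo-RGB, then standardise per channel.
    """
    if gray.ndim != 3 or gray.shape[0] != 1:
        raise ValueError(f"expected (1, H, W), got {tuple(gray.shape)}")
    rgb = gray.repeat(3, 1, 1)                    # (3, H, W), identical channels
    rgb = (rgb - IMAGENET_MEAN) / IMAGENET_STD    # per-channel standardisation
    return rgb.unsqueeze(0)                       # add the batch dimension


# Random stand-in for the tensor returned by load_dicom_xray
x = to_model_input(torch.rand(1, 224, 224))
print(x.shape)   # torch.Size([1, 3, 224, 224])
```

Whatever normalisation is chosen, it has to match the one used during training; a mismatch shifts the input distribution without raising any error.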

Grad-CAM: Gradient-Weighted Class Activation Mapping

Python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GradCAM:
    """
    Visualises which regions of an image most influence the prediction.
    Uses gradients of the predicted class flowing into the last conv layer.
    """
    
    def __init__(self, model: nn.Module, target_layer: nn.Module):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        
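        # Register hooks (register_backward_hook is deprecated and can
        # report incorrect gradients; use the full backward hook instead)
        target_layer.register_forward_hook(self._save_activation)
        target_layer.register_full_backward_hook(self._save_gradient)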
    
    def _save_activation(self, module, input, output) -> None:
        self.activations = output.detach()
    
    def _save_gradient(self, module, grad_input, grad_output) -> None:
        self.gradients = grad_output[0].detach()
    
    def generate(
        self,
        x: torch.Tensor,
        class_idx: int | None = None,
    ) -> tuple[torch.Tensor, int]:
        """Returns: (cam, predicted_class). Expects a batch containing one image."""
        self.model.eval()
        
        # Forward pass
        logits = self.model(x)
        
        if class_idx is None:
            class_idx = logits.argmax(dim=-1).item()
        
        # Backward pass for the target class
        self.model.zero_grad()
        logits[0, class_idx].backward()
        
        # Weight activations by global-average-pooled gradients
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)  # (B, C, 1, 1)
        cam = (weights * self.activations).sum(dim=1, keepdim=True)  # (B, 1, H, W)
        cam = F.relu(cam)   # only positive contributions
        
        # Normalise to [0, 1]
        cam_min = cam.view(cam.size(0), -1).min(dim=-1)[0].view(-1, 1, 1, 1)
        cam_max = cam.view(cam.size(0), -1).max(dim=-1)[0].view(-1, 1, 1, 1)
        cam = (cam - cam_min) / (cam_max - cam_min + 1e-8)
        
        return cam.squeeze(), class_idx

# Usage with a ResNet
import torchvision.models as models

model = models.resnet50(weights=None)   # `pretrained=` is deprecated in torchvision
model.fc = nn.Linear(2048, 2)

# Target the last conv layer (layer4[-1].conv3 for ResNet50)
target_layer = model.layer4[-1].conv3
gradcam = GradCAM(model, target_layer)

X = torch.randn(1, 3, 224, 224)
cam, pred_class = gradcam.generate(X)
print(f"Grad-CAM: {cam.shape}, predicted class: {pred_class}")
# cam is (7, 7) for a 224×224 input; upsample to (224, 224) for overlay on the original image
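
The upsampling and overlay step mentioned in the last comment can be sketched as below. This is one simple way to do it, not the only one: `upsample_cam` and `overlay` are hypothetical helper names, the inputs are random stand-ins so the snippet runs without a model, and the "blend into the red channel" colouring is a deliberately basic substitute for a proper colour map.

```python
import torch
import torch.nn.functional as F


def upsample_cam(cam: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
    """Bilinearly resize a 2-D CAM of shape (h, w) to `size`, clamped to [0, 1]."""
    cam4d = cam[None, None]    # (1, 1, h, w), the layout F.interpolate expects
    up = F.interpolate(cam4d, size=size, mode="bilinear", align_corners=False)
    return up[0, 0].clamp(0.0, 1.0)


def overlay(image: torch.Tensor, heatmap: torch.Tensor, alpha: float = 0.4) -> torch.Tensor:
    """Blend a (H, W) grayscale image in [0, 1] with a (H, W) heatmap into (3, H, W) RGB."""
    rgb = image.expand(3, -1, -1).clone()               # gray copied to all three channels
    rgb[0] = (1 - alpha) * rgb[0] + alpha * heatmap     # push high-CAM regions towards red
    return rgb


coarse_cam = torch.rand(7, 7)      # stand-in for the CAM returned by GradCAM.generate
image = torch.rand(224, 224)       # stand-in for a preprocessed X-ray in [0, 1]
heatmap = upsample_cam(coarse_cam, (224, 224))
blended = overlay(image, heatmap)
print(heatmap.shape, blended.shape)
```

Because the CAM is computed on a coarse grid, the upsampled heatmap localises regions only approximately; it should be read as a rough attention map rather than a segmentation.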

Bias Detection and Subgroup Analysis

Python
import torch
import numpy as np
from sklearn.metrics import roc_auc_score

def evaluate_subgroup_performance(
    model: torch.nn.Module,
    data_by_subgroup: dict[str, tuple],  # {subgroup: (X, y)}
    device: torch.device,
) -> dict[str, dict]:
    """
    Evaluate model performance per demographic subgroup.
    Essential before clinical deployment to detect bias.
    """
    model.eval()
    results = {}
    
    for subgroup, (X, y) in data_by_subgroup.items():
        X = X.to(device)
        
        with torch.no_grad():
            logits = model(X).reshape(-1)   # assumes one logit per image (binary output)
            probs  = torch.sigmoid(logits).cpu().numpy()
        
        y_np = y.numpy()
        n_positive = y_np.sum()
        n_total    = len(y_np)
        
        if n_positive == 0 or n_positive == n_total:
            auc = float("nan")
        else:
            auc = roc_auc_score(y_np, probs)
        
        results[subgroup] = {
            "n": n_total,
            "prevalence": float(n_positive / n_total),
            "auc": auc,
        }
        print(f"{subgroup:20s}: n={n_total:>5}, prevalence={n_positive/n_total:.1%}, AUC={auc:.4f}")
    
    # Flag subgroups whose AUC falls more than 0.05 below the mean
    # (a heuristic threshold, not a statistical significance test)
    aucs = {k: v["auc"] for k, v in results.items() if not np.isnan(v["auc"])}
    if aucs:
        mean_auc = np.mean(list(aucs.values()))
        for subgroup, auc in aucs.items():
            if mean_auc - auc > 0.05:
                print(f"WARNING: {subgroup} AUC is {mean_auc - auc:.3f} below mean — investigate bias")
    
    return results

# Clinical subgroups to always check:
# - Age group (paediatric, adult, geriatric)
# - Sex at birth
# - Ethnicity (if available and appropriately consented)
# - Scanner manufacturer / hospital site
# - Image acquisition parameters (kV, mAs for X-ray)
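
Small subgroups give noisy AUC estimates, so a fixed 0.05 gap can be triggered by chance. One way to judge this is a percentile bootstrap confidence interval per subgroup. The sketch below is an addition rather than part of the original evaluation code; `bootstrap_auc_ci` is a hypothetical helper and the data is synthetic.

```python
import numpy as np
from sklearn.metrics import roc_auc_score


def bootstrap_auc_ci(y_true, y_score, n_boot: int = 1000,
                     alpha: float = 0.05, seed: int = 0):
    """Percentile bootstrap CI for AUC. Returns (auc, lower, upper)."""
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)
    rng = np.random.default_rng(seed)
    n = len(y_true)
    aucs = []
    for _ in range(n_boot):
        idx = rng.integers(0, n, size=n)       # resample cases with replacement
        if y_true[idx].min() == y_true[idx].max():
            continue                           # AUC is undefined without both classes
        aucs.append(roc_auc_score(y_true[idx], y_score[idx]))
    lower, upper = np.quantile(aucs, [alpha / 2, 1 - alpha / 2])
    return float(roc_auc_score(y_true, y_score)), float(lower), float(upper)


# Synthetic stand-in data: 200 cases, scores loosely correlated with the label
rng = np.random.default_rng(42)
y = rng.integers(0, 2, size=200)
scores = 0.3 * y + rng.random(200)
auc, ci_low, ci_high = bootstrap_auc_ci(y, scores, n_boot=500)
print(f"AUC={auc:.3f}, 95% CI=({ci_low:.3f}, {ci_high:.3f})")
```

If two subgroups' intervals overlap heavily, the observed gap may reflect sample size rather than bias, and collecting more cases for the smaller group is the more useful next step.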

Model Monitoring Post-Deployment

Python
import torch
import numpy as np
from collections import deque

class ClinicalModelMonitor:
    """Monitor predictions for drift and anomalies post-deployment."""
    
    def __init__(self, window_size: int = 1000, alert_threshold: float = 0.1):
        self.window_size = window_size
        self.alert_threshold = alert_threshold
        self.recent_probs = deque(maxlen=window_size)
        self.baseline_mean = None
        self.baseline_std  = None
    
    def calibrate(self, baseline_probs: list[float]) -> None:
        """Set baseline statistics from validation set."""
        self.baseline_mean = np.mean(baseline_probs)
        self.baseline_std  = np.std(baseline_probs)
        print(f"Baseline: mean={self.baseline_mean:.4f}, std={self.baseline_std:.4f}")
    
    def record_prediction(self, prob: float) -> dict | None:
        """Record a prediction and check for distribution shift."""
        self.recent_probs.append(prob)
        
        if self.baseline_mean is None:
            raise RuntimeError("call calibrate() before recording predictions")
        if len(self.recent_probs) < 100:
            return None   # not enough data
        
        current_mean = np.mean(self.recent_probs)
        drift = abs(current_mean - self.baseline_mean)
        
        if drift > self.alert_threshold:
            return {
                "alert": "DISTRIBUTION SHIFT",
                "baseline_mean": self.baseline_mean,
                "current_mean": current_mean,
                "drift": drift,
            }
        return None

monitor = ClinicalModelMonitor(window_size=500, alert_threshold=0.05)
monitor.calibrate(baseline_probs=[0.15] * 1000)   # baseline mean predicted probability of 0.15

# Simulate drift
for _ in range(100):
    alert = monitor.record_prediction(0.15)   # stable
# The alert needs the windowed mean to move by more than 0.05, which takes
# just over 100 shifted predictions here, so simulate 200
for _ in range(200):
    alert = monitor.record_prediction(0.25)   # shift!
    if alert:
        print(f"ALERT: {alert}")
        break
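
A mean-based check misses shifts that change the shape of the score distribution while leaving its mean roughly unchanged. A complementary option, sketched here as an addition to the monitor above, is the population stability index (PSI) computed over binned probabilities. The function name, bin count, and synthetic Beta-distributed scores are illustrative assumptions.

```python
import numpy as np


def population_stability_index(baseline, recent, n_bins: int = 10,
                               eps: float = 1e-6) -> float:
    """PSI between two samples of probabilities, using equal-width bins on [0, 1]."""
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    base_counts, _ = np.histogram(baseline, bins=edges)
    recent_counts, _ = np.histogram(recent, bins=edges)
    # Convert counts to proportions; eps avoids log(0) for empty bins
    p = base_counts / base_counts.sum() + eps
    q = recent_counts / recent_counts.sum() + eps
    return float(np.sum((q - p) * np.log(q / p)))


rng = np.random.default_rng(0)
baseline = rng.beta(2, 8, size=5000)   # synthetic scores with mean around 0.2
same     = rng.beta(2, 8, size=5000)   # fresh sample from the same distribution
shifted  = rng.beta(4, 6, size=5000)   # synthetic shift, mean around 0.4

psi_same = population_stability_index(baseline, same)
psi_shifted = population_stability_index(baseline, shifted)
print(f"PSI (no shift): {psi_same:.4f}")
print(f"PSI (shifted):  {psi_shifted:.4f}")
```

Rules of thumb such as 0.1 and 0.25 are often quoted as PSI alert levels, but any threshold is better set from the variation observed on your own validation data.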

Interview Answer

"Deploying CNN-based clinical AI requires several considerations beyond model accuracy. Preprocessing: DICOM images need windowing (clamping intensities to a clinically relevant range, in Hounsfield units for CT) before pixel normalisation, and PHI must be stripped for HIPAA/GDPR compliance. Explainability: Grad-CAM visualises which image regions drove the prediction by weighting the last conv layer's activations by the pooled gradients of the predicted class score, which supports clinician trust and regulatory review. Bias detection: always evaluate AUC per subgroup (age, sex, scanner type, institution) before deployment; a gap of more than 0.05 AUC between subgroups warrants investigation. Post-deployment monitoring: track the distribution of predicted probabilities over time; a shift in the mean prediction can indicate data drift (e.g., a patient population change or a scanner software update). Regulatory: frameworks such as the EU AI Act, UK MHRA guidance, and FDA guidance expect documented validation on representative populations, stated performance bounds, and incident reporting mechanisms."
