CNNs in Real-World Clinical AI
Deploying CNN-based medical image models — DICOM preprocessing, clinical validation, bias detection, explainability with Grad-CAM, and regulatory considerations.
Clinical Image Pipeline
DICOM file → Preprocessing → Model → Output → Clinical Decision
Preprocessing steps:
1. Load DICOM: extract pixel array + metadata (window level, window width)
2. Apply windowing: clamp intensities (Hounsfield units for CT) to the clinically relevant window
3. Normalise to [0, 1] or [-1, 1]
4. Resize to model input size (224×224 or 512×512)
5. Convert grayscale to pseudo-RGB if needed (repeat channel)
6. Apply inference-time transforms (resize, centre crop, normalise)
HIPAA/GDPR compliance:
- Strip PHI before any cloud processing
- Audit logging of all predictions
- Anonymise DICOM tags: PatientName, PatientID, PatientBirthDate (DOB), etc., and remove private (vendor) tags
DICOM Preprocessing
import numpy as np
import torch
import pydicom
from pathlib import Path
def _first_dicom_value(value) -> float:
    """Return ``value`` as a float, using the first entry of a multi-valued attribute.

    WindowCenter / WindowWidth may store several presets. pydicom then returns a
    MultiValue, which is neither a list nor a tuple, so ``float()`` raises
    TypeError on it; lists/tuples passed by callers behave the same way.
    """
    try:
        return float(value)
    except TypeError:
        return float(value[0])


def load_dicom_xray(
    path: str,
    target_size: tuple = (224, 224),
    window_centre: float = None,
    window_width: float = None,
) -> torch.Tensor:
    """
    Load a chest X-ray DICOM and return a windowed, normalised tensor.

    Steps: modality rescale (RescaleSlope / RescaleIntercept) -> windowing ->
    normalise to [0, 1] -> invert MONOCHROME1 images so higher values are
    always brighter -> resize.

    Args:
        path: path to the DICOM file.
        target_size: output (H, W) passed to torchvision's ``resize``.
        window_centre: window centre in rescaled units. Defaults to the file's
            WindowCenter (first preset if several), else the image mean.
            A list/tuple uses its first entry.
        window_width: window width in rescaled units. Defaults to the file's
            WindowWidth (first preset if several), else 4x the image std.
            A list/tuple uses its first entry.

    Returns: (1, H, W) float32 tensor in [0, 1]
    """
    import torchvision.transforms.functional as TF

    ds = pydicom.dcmread(path)
    pixels = ds.pixel_array.astype(np.float32)

    # Modality rescale: stored values -> output units (HU for CT images).
    slope = float(getattr(ds, "RescaleSlope", 1.0))
    intercept = float(getattr(ds, "RescaleIntercept", 0.0))
    pixels = pixels * slope + intercept

    # Resolve the window: explicit argument > DICOM header > image statistics.
    # Multi-valued entries are reduced to their first preset *before* float().
    if window_centre is None:
        window_centre = getattr(ds, "WindowCenter", None)
        if window_centre is None:
            window_centre = pixels.mean()
    window_centre = _first_dicom_value(window_centre)
    if window_width is None:
        window_width = getattr(ds, "WindowWidth", None)
        if window_width is None:
            window_width = pixels.std() * 4
    window_width = _first_dicom_value(window_width)

    low = window_centre - window_width / 2
    high = window_centre + window_width / 2
    pixels = np.clip(pixels, low, high)
    # Normalise to [0, 1]; the epsilon guards a zero-width window (constant image).
    pixels = (pixels - low) / (window_width + 1e-8)

    # MONOCHROME1 means the minimum value is displayed as white; invert so the
    # output always follows the MONOCHROME2 convention (higher = brighter).
    if getattr(ds, "PhotometricInterpretation", "") == "MONOCHROME1":
        pixels = 1.0 - pixels

    tensor = torch.from_numpy(pixels).unsqueeze(0)  # (1, H, W)
    tensor = TF.resize(tensor, list(target_size))
    return tensor.float()
def anonymise_dicom_metadata(ds: "pydicom.Dataset") -> "pydicom.Dataset":
"""Remove PHI from DICOM metadata."""
PHI_TAGS = [
"PatientName", "PatientID", "PatientBirthDate",
"PatientSex", "InstitutionName", "ReferringPhysicianName",
"StudyDescription", "OperatorsName", "StudyID",
]
for tag in PHI_TAGS:
if hasattr(ds, tag):
setattr(ds, tag, "ANONYMISED")
return dsGrad-CAM: Gradient-Weighted Class Activation Map
import torch
import torch.nn as nn
import torch.nn.functional as F
class GradCAM:
    """
    Gradient-weighted Class Activation Mapping (Grad-CAM).

    Visualises which regions of an image most influence a prediction: the
    target layer's activations are weighted by the spatially averaged
    gradients of the chosen class logit, summed over channels and passed
    through a ReLU.

    A forward hook is registered on ``target_layer`` at construction and stays
    active until ``remove()`` is called (or the ``with`` block exits).
    """

    def __init__(self, model: nn.Module, target_layer: nn.Module):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None    # d(class logit)/d(target_layer output), set during backward
        self.activations = None  # detached target_layer output, set during forward
        # Gradients are captured with a tensor hook attached from the forward hook
        # rather than Module.register_backward_hook, which PyTorch deprecates
        # (its grad_input is unreliable for modules made of several ops).
        self._handles = [target_layer.register_forward_hook(self._save_activation)]

    def __enter__(self) -> "GradCAM":
        return self

    def __exit__(self, *exc_info) -> None:
        self.remove()

    def remove(self) -> None:
        """Detach the hook from ``target_layer`` (safe to call repeatedly)."""
        for handle in self._handles:
            handle.remove()
        self._handles.clear()

    def _save_activation(self, module, input, output) -> None:
        self.activations = output.detach()
        # Outputs produced without gradient tracking cannot take a hook.
        if output.requires_grad:
            output.register_hook(self._save_gradient)

    def _save_gradient(self, grad: torch.Tensor) -> None:
        self.gradients = grad.detach()

    def generate(
        self,
        x: torch.Tensor,
        class_idx: int = None,
    ) -> tuple[torch.Tensor, int]:
        """Compute the Grad-CAM map for the first sample of ``x``.

        Args:
            x: model input batch; the backward pass uses sample 0 only.
            class_idx: class to explain; defaults to sample 0's argmax.

        Returns:
            (cam, class_idx): ``cam`` is min-max scaled to [0, 1] per sample and
            squeezed, so a batch of one gives a map at the target layer's
            spatial resolution.

        Raises:
            RuntimeError: if the target layer produced no activation/gradient
                (e.g. it is not used in the model's forward pass).
        """
        self.model.eval()
        # Reset so a stale map from an earlier call can never be reused.
        self.gradients = None
        self.activations = None

        # Grad-CAM needs autograd even when called inside torch.no_grad().
        with torch.enable_grad():
            logits = self.model(x)
            if class_idx is None:
                class_idx = int(logits[0].argmax().item())
            self.model.zero_grad()
            logits[0, class_idx].backward()

        if self.gradients is None or self.activations is None:
            raise RuntimeError(
                "target_layer received no activation/gradient; "
                "is it part of the model's forward pass?"
            )

        # Weight activations by global-average-pooled gradients
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)  # (B, C, 1, 1)
        cam = (weights * self.activations).sum(dim=1, keepdim=True)  # (B, 1, H, W)
        cam = F.relu(cam)  # only positive contributions

        # Normalise each map to [0, 1]; epsilon avoids 0/0 for an all-zero map.
        flat = cam.view(cam.size(0), -1)
        cam_min = flat.min(dim=-1)[0].view(-1, 1, 1, 1)
        cam_max = flat.max(dim=-1)[0].view(-1, 1, 1, 1)
        cam = (cam - cam_min) / (cam_max - cam_min + 1e-8)
        return cam.squeeze(), class_idx
# Usage with a ResNet
import torchvision.models as models
model = models.resnet50(pretrained=False)
model.fc = nn.Linear(2048, 2)
# Target the last conv layer (layer4[-1].conv3 for ResNet50)
target_layer = model.layer4[-1].conv3
gradcam = GradCAM(model, target_layer)
X = torch.randn(1, 3, 224, 224)
cam, pred_class = gradcam.generate(X)
print(f"Grad-CAM: {cam.shape}, predicted class: {pred_class}")
# cam is (14, 14) — upsample to (224, 224) for overlay on original imageBias Detection and Subgroup Analysis
import torch
import numpy as np
from sklearn.metrics import roc_auc_score
def evaluate_subgroup_performance(
    model: torch.nn.Module,
    data_by_subgroup: dict[str, tuple],  # {subgroup: (X, y)}
    device: torch.device,
    auc_gap_threshold: float = 0.05,
) -> dict[str, dict]:
    """
    Evaluate model performance per demographic subgroup.
    Essential before clinical deployment to detect bias.

    Args:
        model: binary classifier producing one logit per sample, shape (N,) or (N, 1).
        data_by_subgroup: maps a subgroup name to ``(X, y)`` with 0/1 labels ``y``.
        device: device the inputs are moved to for inference.
        auc_gap_threshold: a warning is printed for any subgroup whose AUC is
            more than this far below the mean AUC of the evaluable subgroups.

    Returns:
        ``{subgroup: {"n", "prevalence", "auc"}}``; AUC is NaN when a subgroup
        has only one label class (AUC undefined), prevalence is NaN when empty.

    Raises:
        ValueError: if the model output is not one logit per sample.
    """
    model.eval()
    results = {}
    for subgroup, (X, y) in data_by_subgroup.items():
        with torch.no_grad():
            logits = model(X.to(device))
        # Drop only a trailing class axis of size 1, so a single-sample batch
        # stays 1-D (a bare .squeeze() would collapse it to a 0-d scalar).
        if logits.ndim == 2 and logits.shape[-1] == 1:
            logits = logits.squeeze(-1)
        if logits.ndim != 1:
            raise ValueError(
                f"expected one logit per sample for subgroup {subgroup!r}, "
                f"got output shape {tuple(logits.shape)}"
            )
        probs = torch.sigmoid(logits).cpu().numpy()
        # Labels may live on the GPU; bring them to host memory as a flat array.
        y_np = torch.as_tensor(y).detach().cpu().numpy().reshape(-1)

        n_total = len(y_np)
        n_positive = float(y_np.sum())
        prevalence = n_positive / n_total if n_total else float("nan")
        if n_positive == 0 or n_positive == n_total:
            auc = float("nan")  # AUC is undefined with a single label class
        else:
            auc = float(roc_auc_score(y_np, probs))
        results[subgroup] = {
            "n": n_total,
            "prevalence": prevalence,
            "auc": auc,
        }
        print(f"{subgroup:20s}: n={n_total:>5}, prevalence={prevalence:.1%}, AUC={auc:.4f}")

    # Flag subgroups with significantly lower AUC
    aucs = {name: r["auc"] for name, r in results.items() if not np.isnan(r["auc"])}
    if aucs:
        mean_auc = float(np.mean(list(aucs.values())))
        for subgroup, auc in aucs.items():
            gap = mean_auc - auc
            if gap > auc_gap_threshold:
                print(f"WARNING: {subgroup} AUC is {gap:.3f} below mean — investigate bias")
    return results
# Clinical subgroups to always check:
# - Age group (paediatric, adult, geriatric)
# - Sex at birth
# - Ethnicity (if available and appropriately consented)
# - Scanner manufacturer / hospital site
# - Image acquisition parameters (kV, mAs for X-ray)
Model Monitoring Post-Deployment
import torch
import numpy as np
from collections import deque
class ClinicalModelMonitor:
"""Monitor predictions for drift and anomalies post-deployment."""
def __init__(self, window_size: int = 1000, alert_threshold: float = 0.1):
self.window_size = window_size
self.alert_threshold = alert_threshold
self.recent_probs = deque(maxlen=window_size)
self.baseline_mean = None
self.baseline_std = None
def calibrate(self, baseline_probs: list[float]) -> None:
"""Set baseline statistics from validation set."""
self.baseline_mean = np.mean(baseline_probs)
self.baseline_std = np.std(baseline_probs)
print(f"Baseline: mean={self.baseline_mean:.4f}, std={self.baseline_std:.4f}")
def record_prediction(self, prob: float) -> dict | None:
"""Record a prediction and check for distribution shift."""
self.recent_probs.append(prob)
if len(self.recent_probs) < 100:
return None # not enough data
current_mean = np.mean(self.recent_probs)
drift = abs(current_mean - self.baseline_mean)
if drift > self.alert_threshold:
return {
"alert": "DISTRIBUTION SHIFT",
"baseline_mean": self.baseline_mean,
"current_mean": current_mean,
"drift": drift,
}
return None
monitor = ClinicalModelMonitor(window_size=500, alert_threshold=0.05)
monitor.calibrate(baseline_probs=[0.15] * 1000) # 15% baseline positive rate
# Simulate drift
for _ in range(100):
alert = monitor.record_prediction(0.15) # stable
for _ in range(100):
alert = monitor.record_prediction(0.25) # shift!
if alert:
print(f"ALERT: {alert}")
breakInterview Answer
"Deploying CNN-based clinical AI requires several additional considerations beyond model accuracy. Preprocessing: DICOM files need rescaling and windowing (clamping intensities — Hounsfield units for CT — to a clinically relevant window) before pixel normalisation; PHI must be stripped for HIPAA/GDPR compliance. Explainability: Grad-CAM visualises which image regions drove the prediction by weighting the last convolutional feature maps by the spatially averaged gradients of the predicted class score — essential for clinician trust and regulatory review. Bias detection: always evaluate AUC per demographic subgroup (age, sex, scanner type, institution) before deployment — a subgroup whose AUC is more than ~0.05 below the others warrants investigation. Post-deployment monitoring: track the distribution of predicted probabilities over time; a shift in the mean prediction indicates possible data drift (e.g., patient population change, scanner software update). Regulatory: the EU AI Act, UK MHRA guidance and FDA guidance require documented validation on representative populations, performance bounds, and incident reporting mechanisms."
Found this helpful?
Leave a comment
Have a question, correction, or just found this helpful? Leave a note below.