Interview: Confusion Matrix Deep Dive
Interview walk-through: analyze a confusion matrix for a drug safety classifier, interpret error patterns, select the right threshold, and explain the clinical implications of each error type.
The Scenario
You built a classifier to detect potential drug-drug interactions (DDIs) in patient medication lists. The model outputs a binary prediction: interaction detected or not. On your validation set of 2000 medication pairs, you get this confusion matrix. Walk me through your analysis.
                 Predicted No DDI   Predicted DDI
Actual No DDI |        1760       |        90     |
Actual DDI    |          60       |        90     |
Step 1: Extract the Numbers
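Before reading numbers off a table, it is worth confirming its orientation. scikit-learn's `confusion_matrix` puts actual classes on rows and predicted classes on columns, ordered by sorted label (0 = no DDI comes first), which is why `.ravel()` returns `tn, fp, fn, tp`. A minimal sketch that rebuilds hypothetical label arrays matching the table's four cells and checks that layout:

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical label arrays rebuilt from the table (0 = no DDI, 1 = DDI).
# Keys are (actual, predicted); values are the cell counts.
counts = {(0, 0): 1760, (0, 1): 90, (1, 0): 60, (1, 1): 90}
y_true = np.concatenate([np.full(n, a) for (a, _), n in counts.items()])
y_pred = np.concatenate([np.full(n, p) for (_, p), n in counts.items()])

cm = confusion_matrix(y_true, y_pred)  # rows = actual, columns = predicted
tn, fp, fn, tp = cm.ravel()            # row-major flattening of the 2x2 matrix
print(cm)
print(f"TN={tn}, FP={fp}, FN={fn}, TP={tp}")
```

If the cells printed here did not match the table, the table would be transposed relative to scikit-learn's convention and FP/FN would be swapped.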
import numpy as np
from sklearn.metrics import confusion_matrix
# Reconstruct from the table
tn, fp = 1760, 90
fn, tp = 60, 90
total = tn + fp + fn + tp
positive_count = fn + tp
negative_count = tn + fp
print(f"Total: {total}")
print(f"Actual DDI (positive): {positive_count} ({positive_count/total:.1%})")
print(f"Actual no DDI (negative): {negative_count} ({negative_count/total:.1%})")
print(f"\nTN={tn}, FP={fp}, FN={fn}, TP={tp}")
Step 2: Compute All Metrics
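The hand-written formulas in this step can be cross-checked against scikit-learn's metric functions. Specificity and NPV are simply recall and precision of the negative class, so passing `pos_label=0` recovers them. A sketch using hypothetical label arrays rebuilt from the table's counts:

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

# Hypothetical labels matching the table: 1760 TN, 90 FP, 60 FN, 90 TP
y_true = np.array([0] * 1760 + [0] * 90 + [1] * 60 + [1] * 90)
y_pred = np.array([0] * 1760 + [1] * 90 + [0] * 60 + [1] * 90)

metrics = {
    "accuracy": accuracy_score(y_true, y_pred),
    "precision": precision_score(y_true, y_pred),              # TP / (TP + FP)
    "recall": recall_score(y_true, y_pred),                    # TP / (TP + FN)
    "specificity": recall_score(y_true, y_pred, pos_label=0),  # TN / (TN + FP)
    "npv": precision_score(y_true, y_pred, pos_label=0),       # TN / (TN + FN)
    "f1": f1_score(y_true, y_pred),                            # harmonic mean of P and R
}
for name, value in metrics.items():
    print(f"{name:>12}: {value:.3f}")
```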
accuracy = (tp + tn) / total
precision = tp / (tp + fp)
recall = tp / (tp + fn)
specificity = tn / (tn + fp)
f1 = 2 * precision * recall / (precision + recall)
fpr = fp / (fp + tn)
fnr = fn / (fn + tp)
npv = tn / (tn + fn)
print(f"Accuracy:    {accuracy:.3f}  ({accuracy:.1%})")
print(f"Precision:   {precision:.3f}  ({precision:.1%} of flagged pairs are real DDIs)")
print(f"Recall:      {recall:.3f}  ({recall:.1%} of real DDIs are caught)")
print(f"Specificity: {specificity:.3f}  ({specificity:.1%} of non-DDI pairs correctly cleared)")
print(f"F1:          {f1:.3f}")
print(f"FPR:         {fpr:.3f}  ({fpr:.1%} of non-DDI pairs incorrectly flagged)")
print(f"FNR:         {fnr:.3f}  ({fnr:.1%} of real DDIs missed)")
print(f"NPV:         {npv:.3f}  ({npv:.1%} of cleared pairs truly have no DDI)")
Step 3: Interpret the Error Pattern
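To make the imbalance point concrete, the sketch below scores a trivial "never flag" model with the same counts-based formulas. The baseline is a hypothetical comparison, not part of the original analysis. It ties the classifier on accuracy while catching no DDIs at all, which is why recall and precision carry the real information here.

```python
def summarize(tn, fp, fn, tp):
    """Return accuracy, recall and alert count from confusion-matrix cells."""
    total = tn + fp + fn + tp
    return {
        "accuracy": (tp + tn) / total,
        "recall": tp / (tp + fn),
        "alerts": tp + fp,  # every predicted DDI becomes an alert
    }

# Classifier from the table vs. a hypothetical model that never flags a DDI:
# it labels all 2000 pairs "no DDI", so all 150 real DDIs become false negatives.
classifier = summarize(tn=1760, fp=90, fn=60, tp=90)
never_flag = summarize(tn=1850, fp=0, fn=150, tp=0)

for name, stats in [("classifier", classifier), ("never flag", never_flag)]:
    print(f"{name:>10}: accuracy={stats['accuracy']:.3f} "
          f"recall={stats['recall']:.2f} alerts={stats['alerts']}")
```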
# Key observations:
print("=== Error Analysis ===\n")
print("1. CLASS IMBALANCE:")
print(f"   Positive (DDI) rate: {positive_count/total:.1%}, so the classes are highly imbalanced")
print(f"   A model that never flags a DDI would score {negative_count/total:.1%} accuracy,")
print(f"   the same as this model ({accuracy:.1%}), so accuracy alone says nothing here")
print(f"\n2. FALSE NEGATIVES (FN={fn}):")
print(f"   {fnr:.0%} of real drug-drug interactions are missed")
print("   These medication pairs carry an interaction that no one is warned about")
print("   Clinical risk: potential adverse drug reactions go unflagged")
print(f"\n3. FALSE POSITIVES (FP={fp}):")
print(f"   {fp} medication pairs incorrectly flagged as interactions")
print(f"   {1 - precision:.0%} of all alerts are false alarms, which drives alert fatigue")
print("   Clinical risk: physicians may start ignoring alerts, including real ones")
print(f"\n4. PRECISION = {precision:.0%}:")
print("   For every 2 alerts, 1 is a real interaction and 1 is a false alarm")
print("   This is arguably borderline for clinical use")
Step 4: What Would You Change?
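The tuning code in this step needs a fitted `model` and validation data that are not shown. For a self-contained sketch of the same idea, the block below draws synthetic scores (an assumption: Beta-distributed scores with 150 positives out of 2000, echoing the scenario's 7.5% rate) and picks the most precise threshold that still meets a recall target. Anything it prints comes from made-up data and says nothing about the real classifier; `pick_threshold` is a hypothetical helper, not a scikit-learn function.

```python
import numpy as np
from sklearn.metrics import precision_recall_curve

def pick_threshold(y_true, scores, target_recall):
    """Return (threshold, precision, recall) with the best precision among
    thresholds whose recall is at least target_recall."""
    precisions, recalls, thresholds = precision_recall_curve(y_true, scores)
    # The final (precision=1, recall=0) point has no threshold, so drop it.
    ok = np.flatnonzero(recalls[:-1] >= target_recall)
    best = ok[np.argmax(precisions[ok])]
    return thresholds[best], precisions[best], recalls[best]

# Synthetic validation data (hypothetical): positives tend to score higher.
rng = np.random.default_rng(0)
y_val = np.concatenate([np.ones(150, dtype=int), np.zeros(1850, dtype=int)])
scores = np.concatenate([rng.beta(5, 2, size=150), rng.beta(2, 5, size=1850)])

for target in (0.60, 0.85, 0.95):
    t, p, r = pick_threshold(y_val, scores, target)
    print(f"target recall {target:.2f}: threshold={t:.3f} "
          f"precision={p:.3f} recall={r:.3f}")
```

Selecting by best precision among qualifying thresholds, rather than stopping at the first qualifying one, matters because thresholds come back in ascending order and the lowest one always has recall 1.0.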
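# The 60% recall is too low for a safety-critical system:
# we need to catch more DDIs, even at the cost of more false alarms.
# Option A: lower the classification threshold
import numpy as np
from sklearn.metrics import confusion_matrix, precision_recall_curve
# Assumes a fitted probabilistic classifier `model` and validation data
# `X_val`, `y_val` from earlier in the workflow (not shown here).
y_proba = model.predict_proba(X_val)[:, 1]
# precisions/recalls have one more entry than thresholds; index i of each
# describes the rule "flag if score >= thresholds[i]".
precisions, recalls, thresholds = precision_recall_curve(y_val, y_proba)
print("Threshold tuning to improve recall:")
print(f"{'Threshold':>10} {'Recall':>8} {'Precision':>10} {'FP':>6} {'FN':>6}")
print("-" * 44)
for t, p, r in zip(thresholds[::20], precisions[::20], recalls[::20]):
    y_pred_t = (y_proba >= t).astype(int)
    tn_t, fp_t, fn_t, tp_t = confusion_matrix(y_val, y_pred_t).ravel()
    print(f"{t:>10.3f} {r:>8.3f} {p:>10.3f} {fp_t:>6} {fn_t:>6}")
# Target: recall >= 0.85, accepting lower precision.
# Thresholds are ascending and recall falls as the threshold rises, so stopping
# at the first threshold that meets the target would always pick the lowest one.
# Instead, among all thresholds that meet the target, take the most precise.
target_recall = 0.85
ok = np.flatnonzero(recalls[:-1] >= target_recall)
best = ok[np.argmax(precisions[ok])]
print(f"\nRecommended threshold for recall >= {target_recall}: {thresholds[best]:.3f}")
print(f"  Recall: {recalls[best]:.3f}, Precision: {precisions[best]:.3f} "
      f"(about 1 in {1 / precisions[best]:.1f} alerts is real)")
Step 5: Clinical Implications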
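The false-alarm formula used in this step follows directly from the definition of precision $p$. Writing it out with $N$ patients per day, prevalence $\pi$ and recall $r$ also gives the total alert volume, which is the number clinicians actually experience:

```latex
p = \frac{TP}{TP + FP}
\;\Longrightarrow\;
TP + FP = \frac{TP}{p}
\;\Longrightarrow\;
FP = TP \cdot \frac{1 - p}{p},
\qquad
TP = N \pi r,
\qquad
\text{alerts/day} = TP + FP = \frac{N \pi r}{p}
```

With $N = 200$ and $\pi = 0.075$: at $r = 0.60,\ p = 0.50$ this gives $TP = 9$, $FP = 9$, 18 alerts per day; at $r = 0.85,\ p = 0.30$ it gives $TP = 12.75$, $FP = 29.75$, 42.5 alerts per day.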
# Communicate results in clinical terms
patients_per_day = 200  # example volume
prevalence = 0.075      # assumption: the validation DDI rate (7.5%) carries over to daily patients
# FP = TP * (1 - precision) / precision, which follows from precision = TP / (TP + FP).
# Expected counts stay as floats; truncating with int() would make the
# recall = 0.85 scenario below report only 80%.
daily_true_ddi = patients_per_day * prevalence

# Current model (recall = 0.60, precision = 0.50)
daily_caught = daily_true_ddi * 0.60
daily_missed = daily_true_ddi - daily_caught
daily_false_alarms = daily_caught * (1 - 0.50) / 0.50
print("=== Clinical Impact at Current Threshold ===")
print(f"Patients per day:      {patients_per_day}")
print(f"Expected DDI patients: {daily_true_ddi:.1f}")
print(f"Correctly flagged:     {daily_caught:.1f} ({daily_caught/daily_true_ddi:.0%} of real DDIs)")
print(f"Missed (no alert):     {daily_missed:.1f} patients with an undetected DDI")
print(f"False alarms:          {daily_false_alarms:.1f} alerts physicians review and dismiss")

# After lowering the threshold (recall = 0.85), assuming precision drops to 0.30
print("\n=== After Lowering Threshold (recall=0.85) ===")
daily_caught_new = daily_true_ddi * 0.85
daily_missed_new = daily_true_ddi - daily_caught_new
daily_false_new = daily_caught_new * (1 - 0.30) / 0.30
print(f"Correctly flagged: {daily_caught_new:.1f} ({daily_caught_new/daily_true_ddi:.0%} of real DDIs)")
print(f"Missed:            {daily_missed_new:.1f}")
print(f"False alarms:      {daily_false_new:.1f} (about {daily_false_new/daily_false_alarms:.1f}x the current volume)")
print(f"Clinical decision: are {daily_false_new - daily_false_alarms:.1f} extra false alarms/day "
      f"worth {daily_caught_new - daily_caught:.1f} more DDIs caught?")
What Interviewers Want to Hear
- Compute all metrics: don't just state recall or accuracy; derive precision, specificity, NPV, and FNR
- Identify the dominant error: FN=60 (a 40% miss rate) is the primary safety concern
- Connect to clinical cost: each missed DDI is an interaction no one is warned about, which can lead to an adverse drug reaction
- Propose a fix: lower the threshold to increase recall, and quantify the tradeoff in extra false alarms
- Address alert fatigue: precision of 50% is already borderline, and lowering the threshold makes it worse
- Suggest alternatives: risk tiering, a clinical override interface, filtering by interaction severity
One-line answer: "60% recall means 40% of real drug-drug interactions are missed — unacceptable for a safety system. I'd lower the threshold until recall reaches 0.85+, accept the precision drop, and work with the clinical team to determine whether the resulting false alarm rate is operationally sustainable. If alert fatigue is a concern, I'd implement risk tiers instead of a binary flag."
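The risk-tier alternative can be sketched as two cut-offs on the predicted probability instead of one. The tier names and the cut-offs 0.3 and 0.7 below are hypothetical placeholders; in practice they would be chosen on validation data (for example, the lower one to meet the recall target and the upper one for high precision) together with the clinical team.

```python
import numpy as np

# Hypothetical cut-offs; these would be tuned on validation data with clinicians.
REVIEW_CUTOFF = 0.3     # at or above: low-priority "review" notice
INTERRUPT_CUTOFF = 0.7  # at or above: interruptive alert

def assign_tiers(scores, review_cutoff=REVIEW_CUTOFF, interrupt_cutoff=INTERRUPT_CUTOFF):
    """Map predicted DDI probabilities to 'none', 'review' or 'interrupt'."""
    if not review_cutoff < interrupt_cutoff:
        raise ValueError("review_cutoff must be below interrupt_cutoff")
    scores = np.asarray(scores, dtype=float)
    # np.digitize returns 0 below the first edge, 1 between the edges,
    # and 2 at or above the second edge (right=False by default).
    idx = np.digitize(scores, [review_cutoff, interrupt_cutoff])
    return np.array(["none", "review", "interrupt"])[idx]

print(assign_tiers([0.05, 0.3, 0.55, 0.7, 0.92]))
```

Only the "interrupt" tier then needs to meet a strict precision bar, which is one way to keep overall recall high without every flag becoming an interruptive alert.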