Learnixo
AI Systems · advanced

Interview: Confusion Matrix Deep Dive

Interview walk-through: analyze a confusion matrix for a drug safety classifier, interpret error patterns, select the right threshold, and explain the clinical implications of each error type.

Asma Hafeez Khan · May 16, 2026 · 5 min read
Machine Learning · Interview · Confusion Matrix · Clinical AI · Evaluation

The Scenario

You built a classifier to detect potential drug-drug interactions (DDIs) in patient medication lists. The model outputs a binary prediction: interaction detected or not. On your validation set of 2000 medication pairs, you get this confusion matrix. Walk me through your analysis.

                 | Predicted No DDI | Predicted DDI |
Actual No DDI    |       1760       |       90      |
Actual DDI       |         60       |       90      |

Step 1: Extract the Numbers

Python
import numpy as np
from sklearn.metrics import confusion_matrix

# Reconstruct from the table
tn, fp = 1760, 90
fn, tp = 60, 90

total = tn + fp + fn + tp
positive_count = fn + tp
negative_count = tn + fp

print(f"Total: {total}")
print(f"Actual DDI (positive): {positive_count} ({positive_count/total:.1%})")
print(f"Actual no DDI (negative): {negative_count} ({negative_count/total:.1%})")
print(f"\nTN={tn}, FP={fp}, FN={fn}, TP={tp}")
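scikit-learn's `confusion_matrix` uses the same layout as the table (rows are actual classes, columns are predicted, label 0 before label 1), so `.ravel()` returns the counts in TN, FP, FN, TP order. A quick way to confirm that, assuming label 1 means "DDI", is to rebuild label arrays from the table's counts and let sklearn recount them. The arrays are synthetic reconstructions from the four counts, not the original validation data.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Counts taken from the table (label 1 = DDI, label 0 = no DDI)
tn, fp, fn, tp = 1760, 90, 60, 90

# Rebuild one (actual, predicted) pair for every case in each cell
y_true = np.concatenate([np.zeros(tn + fp, dtype=int), np.ones(fn + tp, dtype=int)])
y_pred = np.concatenate([
    np.zeros(tn, dtype=int), np.ones(fp, dtype=int),   # actual no DDI
    np.zeros(fn, dtype=int), np.ones(tp, dtype=int),   # actual DDI
])

cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
print(cm)                      # rows = actual, columns = predicted
counts = tuple(int(x) for x in cm.ravel())
print(counts)                  # ordered as (tn, fp, fn, tp)
```

Passing `labels=[0, 1]` pins the row and column order explicitly instead of relying on the sorted order of whatever labels happen to appear.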

Step 2: Compute All Metrics

Python
accuracy    = (tp + tn) / total
precision   = tp / (tp + fp)
recall      = tp / (tp + fn)
specificity = tn / (tn + fp)
f1          = 2 * precision * recall / (precision + recall)
fpr         = fp / (fp + tn)
fnr         = fn / (fn + tp)
npv         = tn / (tn + fn)

print(f"Accuracy:     {accuracy:.3f}  ({accuracy:.1%})")
print(f"Precision:    {precision:.3f}  ({precision:.1%}) — when we flag, 50% are real DDIs")
print(f"Recall:       {recall:.3f}  ({recall:.1%}) — caught 60% of real DDIs")
print(f"Specificity:  {specificity:.3f}  ({specificity:.1%}) — cleared 95.1% of non-DDIs correctly")
print(f"F1:           {f1:.3f}")
print(f"FPR:          {fpr:.3f}  ({fpr:.1%}) — 4.9% of non-DDIs incorrectly flagged")
print(f"FNR:          {fnr:.3f}  ({fnr:.1%}) — missed 40% of real DDIs")
print(f"NPV:          {npv:.3f}  ({npv:.1%}) — 96.7% chance a 'clean' pair is truly safe")
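The hand-written formulas can be cross-checked against scikit-learn's scorers. This sketch rebuilds label arrays from the table's counts (label 1 = DDI). Specificity and NPV have no dedicated scorer, so they are computed as the recall and precision of the negative class via `pos_label=0`.

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

tn, fp, fn, tp = 1760, 90, 60, 90

# Synthetic label arrays that reproduce the table's four cells
y_true = np.array([0] * (tn + fp) + [1] * (fn + tp))
y_pred = np.array([0] * tn + [1] * fp + [0] * fn + [1] * tp)

metrics = {
    "accuracy":    accuracy_score(y_true, y_pred),
    "precision":   precision_score(y_true, y_pred),               # TP / (TP + FP)
    "recall":      recall_score(y_true, y_pred),                  # TP / (TP + FN)
    "specificity": recall_score(y_true, y_pred, pos_label=0),     # TN / (TN + FP)
    "npv":         precision_score(y_true, y_pred, pos_label=0),  # TN / (TN + FN)
    "f1":          f1_score(y_true, y_pred),
}
for name, value in metrics.items():
    print(f"{name:<12} {value:.3f}")
```

Swapping `pos_label` is a compact way to get the negative-class view: the "recall" of the no-DDI class is exactly specificity, and its "precision" is exactly NPV.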

Step 3: Interpret the Error Pattern

Python
# Key observations:
print("=== Error Analysis ===\n")

print("1. CLASS IMBALANCE:")
print(f"   Positive rate: {positive_count/total:.1%} (DDI class)")
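print(f"   About 1 DDI pair for every {negative_count/positive_count:.0f} non-DDI pairs: heavily imbalanced")
print(f"   A model that never flags DDIs would get {negative_count/total:.1%} accuracy,")
print(f"   the same as this model's {accuracy:.1%}, so accuracy alone cannot tell them apart")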

print("\n2. FALSE NEGATIVES (FN=60):")
print(f"   {fn} of {positive_count} real drug-drug interactions ({fnr:.0%}) are missed")
print("   These medication pairs interact but raise no alert")
print(f"   Clinical risk: potential adverse drug reactions go undetected")

print("\n3. FALSE POSITIVES (FP=90):")
print(f"   90 medication pairs incorrectly flagged as interactions")
print(f"   {fp} of {fp + tp} alerts ({fp/(fp + tp):.0%}) are false alarms, which drives alert fatigue")
print(f"   Clinical risk: physicians may start ignoring alerts")

print("\n4. PRECISION = 50%:")
print(f"   For every 2 alerts, 1 is a real interaction and 1 is a false alarm")
print("   Already borderline: lowering the threshold to raise recall will push precision down further")
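The imbalance point is easiest to see by scoring the trivial "never flag" baseline next to the model, using only the table's counts. Balanced accuracy (the mean of recall and specificity) is included as one imbalance-aware alternative to plain accuracy; it is an extra metric not used elsewhere in this walkthrough, and `summarize` is a hypothetical helper name.

```python
def summarize(tn, fp, fn, tp):
    """Accuracy, recall, specificity and balanced accuracy from raw counts."""
    total = tn + fp + fn + tp
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    specificity = tn / (tn + fp) if (tn + fp) else 0.0
    return {
        "accuracy": (tp + tn) / total,
        "recall": recall,
        "specificity": specificity,
        "balanced_accuracy": (recall + specificity) / 2,
    }


# The model from the table
model_stats = summarize(tn=1760, fp=90, fn=60, tp=90)
# "Never flag": every pair predicted No DDI, so all 150 real DDIs become FNs
baseline_stats = summarize(tn=1850, fp=0, fn=150, tp=0)

for name in model_stats:
    print(f"{name:<18} model={model_stats[name]:.3f}  never-flag={baseline_stats[name]:.3f}")
```

Both score 0.925 accuracy. Only recall and balanced accuracy separate a model that catches 60% of interactions from one that never fires.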

Step 4: What Would You Change?

Python
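# The 60% recall is too low for a safety-critical system
# We need to catch more DDIs, even if we generate more false alarms

# Option A: Lower the classification threshold
import numpy as np
from sklearn.metrics import confusion_matrix, precision_recall_curve

# Assumes `model` is the fitted classifier and X_val, y_val are the
# validation features and labels (not defined in this snippet)
y_proba = model.predict_proba(X_val)[:, 1]
precisions, recalls, thresholds = precision_recall_curve(y_val, y_proba)
# precisions/recalls carry one extra final point (precision=1, recall=0) with
# no threshold; drop it so all three arrays line up index by index
precisions, recalls = precisions[:-1], recalls[:-1]

print("Threshold tuning to improve recall:")
print(f"{'Threshold':>10}  {'Recall':>8}  {'Precision':>10}  {'FP':>6}  {'FN':>6}")
print("-" * 48)

# Show every 20th threshold to keep the table short
for t, p, r in zip(thresholds[::20], precisions[::20], recalls[::20]):
    y_pred_t = (y_proba >= t).astype(int)
    # labels=[0, 1] keeps a 2x2 matrix even if one class is never predicted
    _, fp_t, fn_t, _ = confusion_matrix(y_val, y_pred_t, labels=[0, 1]).ravel()
    print(f"{t:>10.3f}  {r:>8.3f}  {p:>10.3f}  {fp_t:>6}  {fn_t:>6}")

# Target: recall >= 0.85, accept lower precision.
# thresholds ascend and recall never rises as the threshold rises, so the
# first match would be the LOWEST threshold (flagging almost everything).
# Take the last match instead: the highest threshold that still meets the target.
meets_target = np.flatnonzero(recalls >= 0.85)
if meets_target.size > 0:
    i = meets_target[-1]
    print(f"\nRecommended threshold for recall>=0.85: {thresholds[i]:.3f}")
    print(f"  Recall: {recalls[i]:.3f}  Precision: {precisions[i]:.3f}"
          f"  (1 in {1 / precisions[i]:.1f} alerts is real)")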
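Because `model`, `X_val` and `y_val` are not defined in the Step 4 snippet, here is a self-contained sketch of the selection rule (keep the highest threshold whose recall still meets the target) on synthetic scores. The class balance mirrors the scenario (150 DDI pairs out of 2000), but the scores are randomly drawn from two arbitrary beta distributions, so the threshold it prints says nothing about the real model. `highest_threshold_for_recall` is a hypothetical helper name.

```python
import numpy as np
from sklearn.metrics import precision_recall_curve


def highest_threshold_for_recall(y_true, scores, target_recall):
    """Return (threshold, precision, recall) for the highest threshold whose
    recall still meets target_recall, or None if no threshold does."""
    precisions, recalls, thresholds = precision_recall_curve(y_true, scores)
    # Drop the trailing (precision=1, recall=0) point that has no threshold
    precisions, recalls = precisions[:-1], recalls[:-1]
    ok = np.flatnonzero(recalls >= target_recall)
    if ok.size == 0:
        return None
    i = ok[-1]  # thresholds ascend and recall never rises: last hit = highest threshold
    return float(thresholds[i]), float(precisions[i]), float(recalls[i])


# Synthetic validation data: 1850 non-DDI and 150 DDI pairs (hypothetical scores)
rng = np.random.default_rng(0)
y_true = np.concatenate([np.zeros(1850, dtype=int), np.ones(150, dtype=int)])
scores = np.concatenate([rng.beta(2, 5, size=1850), rng.beta(5, 2, size=150)])

t, p, r = highest_threshold_for_recall(y_true, scores, target_recall=0.85)
print(f"threshold={t:.3f}  precision={p:.3f}  recall={r:.3f}")
```

Choosing the highest qualifying threshold matters: every lower threshold also meets the recall target, but each one adds false positives without being required to.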

Step 5: Clinical Implications

Python
# Communicate results in clinical terms

total_patients_per_day = 200   # example: 200 new patients per day
# Simplification: one screened medication pair per patient, so the pair-level
# validation DDI rate (7.5%) is used as the per-patient rate

# Keep expected counts as floats: truncating with int() would make the
# recall=0.85 scenario report 80% of real DDIs caught instead of 85%
daily_true_ddi = total_patients_per_day * 0.075   # 7.5% have a DDI

# With current model (recall=0.60, precision=0.50):
daily_caught       = daily_true_ddi * 0.60
daily_missed       = daily_true_ddi - daily_caught
daily_false_alarms = daily_caught * (1 - 0.50) / 0.50   # FP = TP * (1 - precision) / precision

print("=== Clinical Impact at Current Threshold ===")
print(f"Patients per day:          {total_patients_per_day}")
print(f"Expected DDI patients:     {daily_true_ddi:.1f}")
print(f"Correctly flagged:         {daily_caught:.1f} ({daily_caught/daily_true_ddi:.0%} of real DDIs)")
print(f"Missed (no alert):         {daily_missed:.1f} (patients with undetected DDI)")
print(f"False alarms:              {daily_false_alarms:.1f} (physicians review and dismiss)")

print("\n=== After Lowering Threshold (recall=0.85) ===")
daily_caught_new = daily_true_ddi * 0.85
daily_missed_new = daily_true_ddi - daily_caught_new
daily_false_new  = daily_caught_new * (1 - 0.30) / 0.30   # assuming precision drops to 0.30
extra_caught = daily_caught_new - daily_caught
extra_false  = daily_false_new - daily_false_alarms
print(f"Correctly flagged:         {daily_caught_new:.1f} ({daily_caught_new/daily_true_ddi:.0%} of real DDIs)")
print(f"Missed:                    {daily_missed_new:.1f}")
print(f"False alarms:              {daily_false_new:.1f} (about {daily_false_new/daily_false_alarms:.1f}x the current level)")
print(f"Clinical decision: are {extra_false:.1f} extra false alarms/day "
      f"({extra_false/extra_caught:.1f} per additional DDI) worth {extra_caught:.1f} more DDIs caught?")
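The Step 5 arithmetic can be wrapped in a small helper so the clinical team can try other volumes, prevalences, and operating points. `daily_alert_load` and `DailyLoad` are hypothetical names, and the inputs are the walkthrough's illustrative numbers (200 patients/day, 7.5% DDI rate); the 0.30 precision after retuning is the same assumption used above, not a measured value.

```python
from dataclasses import dataclass


@dataclass
class DailyLoad:
    true_ddi: float      # expected real interactions per day
    caught: float        # true positives (alerts on real DDIs)
    missed: float        # false negatives (real DDIs with no alert)
    false_alarms: float  # false positives (alerts on safe pairs)


def daily_alert_load(volume, prevalence, recall, precision):
    """Expected daily counts at one operating point.

    TP = volume * prevalence * recall, and rearranging
    precision = TP / (TP + FP) gives FP = TP * (1 - precision) / precision.
    """
    true_ddi = volume * prevalence
    caught = true_ddi * recall
    false_alarms = caught * (1 - precision) / precision
    return DailyLoad(true_ddi, caught, true_ddi - caught, false_alarms)


current = daily_alert_load(200, 0.075, recall=0.60, precision=0.50)
retuned = daily_alert_load(200, 0.075, recall=0.85, precision=0.30)  # assumed precision

extra_caught = retuned.caught - current.caught
extra_false = retuned.false_alarms - current.false_alarms
print(current)
print(retuned)
print(f"{extra_false:.2f} extra false alarms buy {extra_caught:.2f} extra catches "
      f"({extra_false / extra_caught:.1f} false alarms per additional DDI caught)")
```

With these numbers the marginal price is roughly 5.5 extra false alarms for each additional DDI caught, which is often an easier figure for a clinical team to judge than two absolute totals.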

What Interviewers Want to Hear

  1. Compute all metrics — don't just state recall or accuracy; derive precision, specificity, NPV, and FNR
  2. Identify the dominant error — FN=60 (40% miss rate) is the primary concern for safety
  3. Connect to clinical cost: each missed DDI is a potential adverse drug reaction that no alert warned anyone about
  4. Propose a fix — lower the threshold to increase recall; quantify the tradeoff
  5. Address alert fatigue — precision=50% is already borderline; lowering threshold makes it worse
  6. Suggest alternatives — risk tiering, clinical override interface, filtering by interaction severity
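Risk tiering replaces the single threshold with bands, so high-confidence predictions can interrupt the prescriber while borderline ones appear as passive notes. A minimal sketch, assuming two hypothetical cut-offs (`HARD_STOP` and `ADVISORY`) that would in practice be chosen on validation data together with the clinical team; the tier names and the helper `assign_tiers` are illustrative too.

```python
import numpy as np

# Hypothetical cut-offs: tune on validation data with the clinical team
HARD_STOP = 0.70   # interruptive alert: prescriber must acknowledge
ADVISORY = 0.20    # passive, non-interruptive note

TIER_NAMES = np.array(["no alert", "advisory", "hard stop"])


def assign_tiers(probabilities, hard_stop=HARD_STOP, advisory=ADVISORY):
    """Map DDI probabilities to tier labels (a score exactly at a cut-off goes up a tier)."""
    if not advisory < hard_stop:
        raise ValueError("advisory cut-off must be below hard_stop")
    # With increasing bins, np.digitize returns 0 below advisory,
    # 1 in [advisory, hard_stop), and 2 at or above hard_stop
    idx = np.digitize(probabilities, [advisory, hard_stop])
    return TIER_NAMES[idx]


probs = np.array([0.05, 0.20, 0.45, 0.70, 0.93])
print(assign_tiers(probs))
```

If `advisory` is set to the lowered, high-recall threshold and `hard_stop` near the original one, the extra alerts gained by lowering the threshold land in the passive band, which is how tiering aims to keep recall up without adding interruptive alerts.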

One-line answer: "60% recall means 40% of real drug-drug interactions are missed — unacceptable for a safety system. I'd lower the threshold until recall reaches 0.85+, accept the precision drop, and work with the clinical team to determine whether the resulting false alarm rate is operationally sustainable. If alert fatigue is a concern, I'd implement risk tiers instead of a binary flag."
