Learnixo

Machine Learning Foundations · Lesson 50 of 70

Interview: Confusion Matrix Scenario

The Scenario

You built a classifier to detect potential drug-drug interactions (DDIs) in patient medication lists. The model outputs a binary prediction: interaction detected or not. On your validation set of 2000 medication pairs, you get this confusion matrix. Walk me through your analysis.

                 | Predicted No DDI | Predicted DDI |
Actual No DDI    |    1760 (TN)     |    90 (FP)    |
Actual DDI       |      60 (FN)     |    90 (TP)    |
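To make the layout concrete, here is a minimal pure-Python sketch (no scikit-learn needed) that rebuilds label vectors matching the table and counts the four cells. It assumes the encoding 1 = DDI and 0 = no DDI. Rows are actual labels and columns are predicted labels, which is the same orientation scikit-learn's `confusion_matrix` uses for labels `[0, 1]`.

```python
from collections import Counter

# Assumed encoding: 1 = DDI (positive class), 0 = no DDI
# Rebuild (actual, predicted) pairs that reproduce the table above
pairs = (
    [(0, 0)] * 1760    # actual no DDI, predicted no DDI -> TN
    + [(0, 1)] * 90    # actual no DDI, predicted DDI    -> FP
    + [(1, 0)] * 60    # actual DDI,    predicted no DDI -> FN
    + [(1, 1)] * 90    # actual DDI,    predicted DDI    -> TP
)
y_true = [actual for actual, _ in pairs]
y_pred = [pred for _, pred in pairs]

counts = Counter(zip(y_true, y_pred))
# Row index = actual label, column index = predicted label
matrix = [[counts[(actual, pred)] for pred in (0, 1)] for actual in (0, 1)]
tn, fp = matrix[0]
fn, tp = matrix[1]
print(matrix)          # [[1760, 90], [60, 90]]
print(tn, fp, fn, tp)  # 1760 90 60 90
```

Reading the matrix row by row is what makes the four-way unpacking in Step 1 line up with the table.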

Step 1: Extract the Numbers

Python
import numpy as np
from sklearn.metrics import confusion_matrix

# Reconstruct from the table (sklearn's confusion_matrix(...).ravel() uses the same order: tn, fp, fn, tp)
tn, fp = 1760, 90
fn, tp = 60, 90

total = tn + fp + fn + tp
positive_count = fn + tp
negative_count = tn + fp

print(f"Total: {total}")
print(f"Actual DDI (positive): {positive_count} ({positive_count/total:.1%})")
print(f"Actual no DDI (negative): {negative_count} ({negative_count/total:.1%})")
print(f"\nTN={tn}, FP={fp}, FN={fn}, TP={tp}")

Step 2: Compute All Metrics

Python
accuracy    = (tp + tn) / total
precision   = tp / (tp + fp)
recall      = tp / (tp + fn)
specificity = tn / (tn + fp)
f1          = 2 * precision * recall / (precision + recall)
fpr         = fp / (fp + tn)
fnr         = fn / (fn + tp)
npv         = tn / (tn + fn)

print(f"Accuracy:     {accuracy:.3f}  ({accuracy:.1%})")
print(f"Precision:    {precision:.3f}  ({precision:.1%}) — when we flag, 50% are real DDIs")
print(f"Recall:       {recall:.3f}  ({recall:.1%}) — caught 60% of real DDIs")
print(f"Specificity:  {specificity:.3f}  ({specificity:.1%}) — cleared 95.1% of non-DDIs correctly")
print(f"F1:           {f1:.3f}")
print(f"FPR:          {fpr:.3f}  ({fpr:.1%}) — 4.9% of non-DDIs incorrectly flagged")
print(f"FNR:          {fnr:.3f}  ({fnr:.1%}) — missed 40% of real DDIs")
print(f"NPV:          {npv:.3f}  ({npv:.1%}) — 96.7% chance a 'clean' pair is truly safe")
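As a worked check of the F1 line, F1 is the harmonic mean of precision and recall, and it can equivalently be computed straight from the counts; both routes give the same value:

```latex
F_1 = \frac{2 \cdot \text{precision} \cdot \text{recall}}{\text{precision} + \text{recall}}
    = \frac{2 \cdot 0.50 \cdot 0.60}{0.50 + 0.60}
    = \frac{0.60}{1.10} \approx 0.545

F_1 = \frac{2\,TP}{2\,TP + FP + FN}
    = \frac{2 \cdot 90}{2 \cdot 90 + 90 + 60}
    = \frac{180}{330} \approx 0.545
```

Because neither form uses TN, F1 is unaffected by the 1,760 true negatives that prop up accuracy.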

Step 3: Interpret the Error Pattern

Python
# Key observations:
print("=== Error Analysis ===\n")

print("1. CLASS IMBALANCE:")
print(f"   Positive rate: {positive_count/total:.1%} (DDI class), roughly {negative_count/positive_count:.0f}:1 negatives to positives")
print(f"   A model that never flags DDIs would get {negative_count/total:.1%} accuracy,")
print(f"   the same as this model's {accuracy:.1%}: accuracy alone shows no benefit here")

print("\n2. FALSE NEGATIVES (FN=60):")
print(f"   40% of real drug-drug interactions are missed")
print(f"   These patients have undetected interactions in their medication lists")
print(f"   Clinical risk: potential adverse drug reactions go undetected")

print("\n3. FALSE POSITIVES (FP=90):")
print(f"   90 medication pairs incorrectly flagged as interactions")
print(f"   = Alert fatigue: 50% of all alerts are false alarms")
print(f"   Clinical risk: physicians may start ignoring alerts")

print("\n4. PRECISION = 50%:")
print(f"   For every 2 alerts, 1 is a real interaction and 1 is a false alarm")
print(f"   This is on the edge of clinical acceptability")
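Points 1 and 4 are connected: with only 7.5% positives, a 4.9% false-positive rate yields as many false alarms as true hits. Bayes' rule makes this explicit, since precision = recall × prevalence / (recall × prevalence + FPR × (1 − prevalence)). The sketch below recomputes the 50% precision from the three rates; the helper name `precision_from_rates` is our own, not a library function, and the extra prevalences in the loop are hypothetical.

```python
def precision_from_rates(recall: float, fpr: float, prevalence: float) -> float:
    """Precision (PPV) via Bayes' rule from recall, false-positive rate, and prevalence."""
    true_alerts = recall * prevalence        # P(flagged and real DDI)
    false_alerts = fpr * (1 - prevalence)    # P(flagged and no DDI)
    return true_alerts / (true_alerts + false_alerts)

# Rates from the validation confusion matrix
recall = 90 / 150        # TP / (TP + FN)
fpr = 90 / 1850          # FP / (FP + TN)
prevalence = 150 / 2000  # share of pairs that are real DDIs

print(f"{precision_from_rates(recall, fpr, prevalence):.3f}")  # 0.500

# Same recall and FPR at other (hypothetical) prevalences
for prev in (0.02, 0.075, 0.20):
    print(f"prevalence {prev:.1%}: precision {precision_from_rates(recall, fpr, prev):.3f}")
```

This is why a 95.1% specificity does not translate into high precision here: the small FPR is applied to a pool of negatives about twelve times larger than the pool of positives.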

Step 4: What Would You Change?

Python
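# The 60% recall is too low for a safety-critical system:
# we need to catch more DDIs, even if we generate more false alarms.

# Option A: Lower the classification threshold
import numpy as np
from sklearn.metrics import confusion_matrix, precision_recall_curve

# Assumes a fitted classifier `model` with predict_proba, plus validation data X_val, y_val
y_proba = model.predict_proba(X_val)[:, 1]
precisions, recalls, thresholds = precision_recall_curve(y_val, y_proba)
# precisions[i] and recalls[i] belong to thresholds[i]; the final (precision=1, recall=0)
# point has no threshold, so zip() below simply drops it.

print("Threshold tuning to improve recall:")
print(f"{'Threshold':>10}  {'Recall':>8}  {'Precision':>10}  {'FP':>6}  {'FN':>6}")
print("-" * 48)

for t, p, r in zip(thresholds[::20], precisions[::20], recalls[::20]):
    y_pred_t = (y_proba >= t).astype(int)
    tn_t, fp_t, fn_t, tp_t = confusion_matrix(y_val, y_pred_t).ravel()
    print(f"{t:>10.3f}  {r:>8.3f}  {p:>10.3f}  {fp_t:>6}  {fn_t:>6}")

# Target: recall >= 0.85, accepting lower precision. Thresholds are ascending and recall
# can only fall as the threshold rises, so take the HIGHEST threshold that still meets
# the target (the lowest threshold also qualifies, but with far more false alarms).
target_recall = 0.85
best = np.flatnonzero(recalls[:-1] >= target_recall)[-1]
t, p, r = thresholds[best], precisions[best], recalls[best]
print(f"\nRecommended threshold for recall >= {target_recall}: {t:.3f}")
print(f"  Recall: {r:.3f}  Precision: {p:.3f}  (about 1 in {1/p:.0f} alerts is real)")
# The threshold is tuned on validation data; confirm recall on a held-out test set.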
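To see what the threshold search is doing without scikit-learn, the selection rule can be written by hand: scan candidate thresholds in ascending order, compute recall at each, and keep the highest threshold whose recall still meets the target. This is a minimal sketch; `pick_threshold_for_recall` is a hypothetical helper, and the toy scores are illustrative values, not output from the DDI model.

```python
def pick_threshold_for_recall(y_true, scores, target_recall):
    """Return (threshold, recall, precision) for the highest threshold with recall >= target.

    A prediction is positive when score >= threshold, matching the `y_proba >= t`
    rule used above. Assumes y_true contains at least one positive (1).
    """
    n_pos = sum(y_true)
    best = None
    for t in sorted(set(scores)):  # ascending, so recall is non-increasing
        tp = sum(1 for y, s in zip(y_true, scores) if s >= t and y == 1)
        fp = sum(1 for y, s in zip(y_true, scores) if s >= t and y == 0)
        recall = tp / n_pos
        if recall < target_recall:
            break  # higher thresholds can only lower recall further
        best = (t, recall, tp / (tp + fp))
    return best

# Illustrative toy data (hypothetical labels and scores)
y_true = [0, 0, 1, 0, 1, 0, 1, 1, 0, 1]
scores = [0.05, 0.10, 0.20, 0.30, 0.40, 0.45, 0.60, 0.70, 0.80, 0.90]
print(pick_threshold_for_recall(y_true, scores, 0.80))  # (0.4, 0.8, 0.6666666666666666)
```

On the toy data a target of 0.80 selects 0.40 rather than 0.05: both meet the target, but 0.40 keeps precision at 0.67 instead of 0.50.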

Step 5: Clinical Implications

Python
# Communicate results in clinical terms
# Simplifying assumption: each patient contributes one screened medication pair,
# so the 7.5% validation prevalence applies per patient.

total_patients_per_day = 200   # example: 200 new patients per day

# With current model (recall=0.60, precision=0.50):
daily_true_ddi     = round(total_patients_per_day * 0.075)   # 7.5% have a DDI
daily_caught       = round(daily_true_ddi * 0.60)
daily_missed       = daily_true_ddi - daily_caught
daily_false_alarms = round(daily_caught * (1 - 0.50) / 0.50)  # FP = TP * (1 - precision) / precision

print("=== Clinical Impact at Current Threshold ===")
print(f"Patients per day:          {total_patients_per_day}")
print(f"Expected DDI patients:     {daily_true_ddi}")
print(f"Correctly flagged:         {daily_caught} ({daily_caught/daily_true_ddi:.0%} of real DDIs)")
print(f"Missed (no alert):         {daily_missed} (patients with an undetected DDI)")
print(f"False alarms:              {daily_false_alarms} (physicians review and dismiss)")

print("\n=== After Lowering Threshold (recall=0.85) ===")
assumed_precision = 0.30   # assumption: read the real value off the PR curve at the chosen threshold
daily_caught_new = round(daily_true_ddi * 0.85)
daily_missed_new = daily_true_ddi - daily_caught_new
daily_false_new  = round(daily_caught_new * (1 - assumed_precision) / assumed_precision)
print(f"Correctly flagged:         {daily_caught_new} ({daily_caught_new/daily_true_ddi:.0%} of real DDIs)")
print(f"Missed:                    {daily_missed_new}")
print(f"False alarms:              {daily_false_new} (about {daily_false_new/daily_false_alarms:.1f}x the current load)")
print(f"Clinical decision: are {daily_false_new - daily_false_alarms} extra false alarms/day "
      f"worth {daily_caught_new - daily_caught} more DDIs caught?")
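The Step 5 projection derives false alarms from precision. An alternative sketch, under the same simplifying assumption (one screened pair per patient at the 7.5% validation prevalence), projects expected counts from recall and the false-positive rate instead. Those two rates are conditioned on the true class, so they carry over if prevalence shifts between validation and deployment, whereas precision does not. The helper name `expected_daily_counts` is hypothetical.

```python
def expected_daily_counts(n_per_day: float, prevalence: float, recall: float, fpr: float) -> dict:
    """Expected daily alert counts from class-conditional rates (no rounding)."""
    positives = n_per_day * prevalence      # screened pairs with a real DDI
    negatives = n_per_day - positives       # screened pairs with no DDI
    caught = positives * recall             # true positives
    return {
        "caught": caught,
        "missed": positives - caught,       # false negatives
        "false_alarms": negatives * fpr,    # false positives
    }

# Current-model rates from the validation matrix; 200 per day as in Step 5
current = expected_daily_counts(200, 0.075, recall=90 / 150, fpr=90 / 1850)
print({k: round(v, 1) for k, v in current.items()})
# {'caught': 9.0, 'missed': 6.0, 'false_alarms': 9.0}
```

For a lowered threshold, the same call works with the recall and FPR measured on validation data at that threshold, rather than an assumed precision.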

What Interviewers Want to Hear

  1. Compute all metrics: don't stop at recall or accuracy; derive precision, specificity, NPV, and FNR, and notice that 92.5% accuracy is exactly what a model that never flags a DDI would score
  2. Identify the dominant error: FN=60 (a 40% miss rate) is the primary safety concern
  3. Connect to clinical cost: each missed DDI is a potential adverse drug reaction that triggers no alert; the harm is concrete
  4. Propose a fix: lower the threshold to increase recall, and quantify the tradeoff
  5. Address alert fatigue: precision of 50% is already borderline, and lowering the threshold makes it worse
  6. Suggest alternatives: risk tiering, a clinical override interface, and filtering by interaction severity
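The risk tiering in item 6 can be sketched as a mapping from the model's predicted probability to alert levels instead of a single yes/no flag. The cutoffs (0.7 and 0.3) and the tier names below are hypothetical placeholders; real values would be chosen on validation data together with the clinical team, possibly alongside the interaction's known severity.

```python
from typing import Literal

Tier = Literal["interruptive alert", "passive warning", "no alert"]

# Hypothetical cutoffs: tune on validation data with clinical input
HIGH_CUTOFF = 0.7
LOW_CUTOFF = 0.3

def risk_tier(probability: float) -> Tier:
    """Map a predicted DDI probability to an alert tier instead of a binary flag."""
    if probability >= HIGH_CUTOFF:
        return "interruptive alert"  # highest-confidence DDIs interrupt the workflow
    if probability >= LOW_CUTOFF:
        return "passive warning"     # shown to the prescriber, but not blocking
    return "no alert"

for p in (0.85, 0.45, 0.10):
    print(p, risk_tier(p))
```

The design idea is that the low cutoff needed for high recall feeds only the non-blocking tier, while interruptive alerts are reserved for high-probability pairs, which limits alert fatigue.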

One-line answer: "60% recall means 40% of real drug-drug interactions are missed — unacceptable for a safety system. I'd lower the threshold until recall reaches 0.85+, accept the precision drop, and work with the clinical team to determine whether the resulting false alarm rate is operationally sustainable. If alert fatigue is a concern, I'd implement risk tiers instead of a binary flag."