The Confusion Matrix
Confusion matrix explained: reading TP/TN/FP/FN, computing all derived metrics, multi-class confusion matrices, interpreting class-level errors, and common visualization patterns.
The 2×2 Matrix
For binary classification, a confusion matrix is a 2×2 table of actual vs. predicted labels.
                  | Predicted Negative | Predicted Positive |
Actual Negative   | TN                 | FP                 |
Actual Positive   | FN                 | TP                 |
TN = True Negative: correctly predicted no readmission
TP = True Positive: correctly predicted readmission
FP = False Positive: predicted readmission, but patient was fine (false alarm)
FN = False Negative: missed a real readmission (false reassurance)
Computing and Displaying
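The four cell definitions above can be checked by counting them directly from the labels before handing the work to scikit-learn. This is a minimal pure-Python sketch, assuming labels are encoded as 0 (no readmission) and 1 (readmission). The helper name `count_cells` is hypothetical, and the labels are the same synthetic 100-negative, 20-positive example used in this section.

```python
def count_cells(y_true, y_pred):
    """Count TN, FP, FN, TP for binary labels encoded as 0/1."""
    tn = fp = fn = tp = 0
    for actual, predicted in zip(y_true, y_pred):
        if actual == 0 and predicted == 0:
            tn += 1  # correctly cleared
        elif actual == 0 and predicted == 1:
            fp += 1  # false alarm
        elif actual == 1 and predicted == 0:
            fn += 1  # missed case
        else:
            tp += 1  # actual 1, predicted 1: correctly flagged
    return tn, fp, fn, tp


# Synthetic labels: 100 actual negatives followed by 20 actual positives
y_true = [0] * 100 + [1] * 20
y_pred = [0] * 90 + [1] * 10 + [0] * 5 + [1] * 15

tn, fp, fn, tp = count_cells(y_true, y_pred)
print(f"TN={tn}, FP={fp}, FN={fn}, TP={tp}")
```

The counts should agree with what `confusion_matrix(y_true, y_pred).ravel()` returns for the same labels.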
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import numpy as np
# 30-day readmission model
y_true = np.array([0]*100 + [1]*20) # 100 negatives, 20 positives
y_pred = np.array([0]*90 + [1]*10 + [0]*5 + [1]*15)
# TN=90 FP=10 FN=5 TP=15
cm = confusion_matrix(y_true, y_pred)
print("Confusion matrix:")
print(cm)
# [[90 10]
# [ 5 15]]
# Row 0 (actual negatives): 90 TN, 10 FP
# Row 1 (actual positives): 5 FN, 15 TP
# sklearn can also plot the matrix with class labels
disp = ConfusionMatrixDisplay(cm, display_labels=["no_readmit", "readmit"])
disp.plot()
plt.show()
# Or unpack the four cells directly (ravel order is TN, FP, FN, TP):
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
print(f"TN={tn}, FP={fp}, FN={fn}, TP={tp}")
Deriving All Metrics from the Matrix
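Every metric in this section is a ratio of cell counts, and some denominators can be zero (precision, for instance, is undefined when the model never predicts the positive class). This is a minimal sketch of a guarded version. `safe_div` and `binary_metrics` are hypothetical helper names, and returning NaN for an undefined ratio is one convention among several.

```python
import math


def safe_div(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else math.nan


def binary_metrics(tn, fp, fn, tp):
    """Derive standard metrics from the four confusion-matrix cells."""
    return {
        "accuracy": safe_div(tp + tn, tn + fp + fn + tp),
        "precision": safe_div(tp, tp + fp),
        "recall": safe_div(tp, tp + fn),
        "specificity": safe_div(tn, tn + fp),
        "npv": safe_div(tn, tn + fn),
        # Algebraically equal to 2 * precision * recall / (precision + recall)
        "f1": safe_div(2 * tp, 2 * tp + fp + fn),
    }


# Cell counts from the readmission example
m = binary_metrics(tn=90, fp=10, fn=5, tp=15)
for name, value in m.items():
    print(f"{name:>12}: {value:.3f}")

# Degenerate case: a model that never predicts the positive class
degenerate = binary_metrics(tn=100, fp=0, fn=20, tp=0)
print("precision when nothing is flagged:", degenerate["precision"])
```

In the degenerate case, accuracy still looks respectable while precision is undefined and recall is zero, which is one reason accuracy alone can mislead on imbalanced data.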
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
total = tn + fp + fn + tp
# Core metrics
accuracy = (tp + tn) / total
precision = tp / (tp + fp) # Of flagged, how many were real?
recall = tp / (tp + fn) # Of real positives, how many caught?
specificity = tn / (tn + fp) # Of real negatives, how many correctly passed?
f1 = 2 * precision * recall / (precision + recall)
# Less common but important clinically
npv = tn / (tn + fn) # Negative predictive value
ppv = precision # Positive predictive value (same as precision)
fpr = fp / (fp + tn) # False positive rate = 1 - specificity
fnr = fn / (fn + tp) # False negative rate = 1 - recall (miss rate)
print(f"Accuracy: {accuracy:.3f}")
print(f"Precision: {precision:.3f}")
print(f"Recall: {recall:.3f} (sensitivity)")
print(f"Specificity: {specificity:.3f}")
print(f"F1: {f1:.3f}")
print(f"NPV: {npv:.3f}")
print(f"FPR: {fpr:.3f} (1 - specificity)")
print(f"FNR: {fnr:.3f} (miss rate, 1 - recall)")
Reading the Errors in Context
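One way to compare the two error types is to attach a cost to each. The sketch below is purely illustrative: the unit costs and the alternative operating point are assumed numbers chosen for the example, not measured values, and `total_error_cost` is a hypothetical helper.

```python
# ASSUMED unit costs, for illustration only (not from any real cost study)
COST_FP = 1.0   # hypothetical cost of one unnecessary follow-up
COST_FN = 10.0  # hypothetical cost of one missed readmission


def total_error_cost(fp, fn, cost_fp=COST_FP, cost_fn=COST_FN):
    """Total cost of the errors in a confusion matrix under the given unit costs."""
    return fp * cost_fp + fn * cost_fn


# Error counts from the readmission example: FP=10, FN=5
baseline = total_error_cost(fp=10, fn=5)

# Hypothetical alternative operating point: more false alarms, fewer misses
alternative = total_error_cost(fp=20, fn=2)

print(f"Baseline cost:    {baseline:.1f}")
print(f"Alternative cost: {alternative:.1f}")
```

Under these assumed costs, doubling the false alarms is still cheaper overall if it removes three missed readmissions. With a different cost ratio the conclusion could flip.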
# Each cell has a clinical interpretation
print("\nClinical interpretation:")
print(f"  TN={tn}: Patients correctly told 'you're low risk' → went home, stayed home")
print(f"  TP={tp}: Patients correctly flagged → received discharge planning / follow-up")
print(f"  FP={fp}: Patients incorrectly flagged → unnecessary discharge services")
print("      → Waste of resources, but patient is not harmed")
print(f"  FN={fn}: Patients missed → sent home without support, readmitted within 30 days")
print("      → Preventable readmission: this is the costly error")
# For this application: minimize FN (high recall), accept some FP
Normalizing the Matrix
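Row normalization divides each row by its total, so every row sums to 1 and the entries become per-class rates. The sketch below does this by hand with NumPy, using the cell counts from the readmission example. It also column-normalizes for comparison, which to my understanding corresponds to `normalize="pred"` in scikit-learn.

```python
import numpy as np

# Cell counts from the readmission example: rows are actual, columns are predicted
cm = np.array([[90, 10],
               [5, 15]])

# Row-normalized: each row sums to 1 (rates within each actual class)
row_norm = cm / cm.sum(axis=1, keepdims=True)

# Column-normalized: each column sums to 1 (rates within each predicted class)
col_norm = cm / cm.sum(axis=0, keepdims=True)

print("Row-normalized (diagonal = specificity, recall):")
print(row_norm.round(3))
print("Column-normalized (diagonal = NPV, precision):")
print(col_norm.round(3))
```

The two views answer different questions: the row view asks "of the patients who were actually readmitted, how many did we catch?", while the column view asks "of the patients we flagged, how many were actually readmitted?".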
# Row-normalized: shows rate of errors within each actual class
# Useful for understanding what fraction of each class was correct
cm_normalized = confusion_matrix(y_true, y_pred, normalize="true")
print("Row-normalized confusion matrix:")
print(cm_normalized.round(3))
# [[0.9  0.1 ]    90% of negatives correct, 10% incorrectly flagged (FPR=0.1)
#  [0.25 0.75]]   75% of positives caught (recall=0.75), 25% missed (FNR=0.25)
Multi-Class Confusion Matrix
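A multi-class confusion matrix follows the same rule as the binary one: entry (i, j) counts the samples whose actual class is i and predicted class is j. This is a minimal pure-Python sketch of that counting rule, using the same 15 toy labels as this section's example. `build_confusion_matrix` is a hypothetical helper, and it assumes classes are encoded as integers from 0 to n_classes - 1.

```python
def build_confusion_matrix(y_true, y_pred, n_classes):
    """Return an n_classes x n_classes list of lists; rows = actual, columns = predicted."""
    cm = [[0] * n_classes for _ in range(n_classes)]
    for actual, predicted in zip(y_true, y_pred):
        cm[actual][predicted] += 1
    return cm


# Toy labels: 0=anticoagulant, 1=antidiabetic, 2=antihypertensive, 3=antibiotic
y_true = [0, 1, 2, 3, 0, 1, 2, 3, 0, 0, 1, 2, 3, 3, 1]
y_pred = [0, 1, 2, 2, 0, 1, 3, 3, 1, 0, 1, 2, 3, 2, 1]

cm = build_confusion_matrix(y_true, y_pred, n_classes=4)
for row in cm:
    print(row)
```

Each row sums to the number of samples in that actual class, and the diagonal holds the correct predictions.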
from sklearn.metrics import confusion_matrix, classification_report
# Drug category classification: 4 classes
classes = ["anticoagulant", "antidiabetic", "antihypertensive", "antibiotic"]
y_true = [0, 1, 2, 3, 0, 1, 2, 3, 0, 0, 1, 2, 3, 3, 1]
y_pred = [0, 1, 2, 2, 0, 1, 3, 3, 1, 0, 1, 2, 3, 2, 1]
cm = confusion_matrix(y_true, y_pred)
print("Multi-class confusion matrix:")
print("Predicted →")
print(f"{'':>17}", " ".join(f"{c[:6]:>6}" for c in classes))
for row, name in zip(cm, classes):
    print(f"Actual {name[:12]:>12}: ", " ".join(f"{v:>6}" for v in row))
print("\nClassification report:")
print(classification_report(y_true, y_pred, target_names=classes))
What the Off-Diagonal Entries Reveal
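Off-diagonal counts can also be summarized per class: the rest of a row is what a class loses (lowering recall), and the rest of a column is what a class wrongly absorbs (lowering precision). The NumPy sketch below hard-codes the matrix that the toy drug-category labels produce, and assumes every class appears at least once in both the actual and predicted labels so that no division by zero occurs.

```python
import numpy as np

classes = ["anticoagulant", "antidiabetic", "antihypertensive", "antibiotic"]

# Confusion matrix for the toy drug-category labels (rows = actual, columns = predicted)
cm = np.array([[3, 1, 0, 0],
               [0, 4, 0, 0],
               [0, 0, 2, 1],
               [0, 0, 2, 2]])

correct = np.diag(cm)
recall = correct / cm.sum(axis=1)     # diagonal over row totals
precision = correct / cm.sum(axis=0)  # diagonal over column totals

for name, p, r in zip(classes, precision, recall):
    print(f"{name:>16}: precision={p:.3f}  recall={r:.3f}")
```

On this toy data, antihypertensive and antibiotic have the weakest scores, which matches the off-diagonal counts between those two classes.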
import numpy as np

# In a multi-class matrix, off-diagonal entries show which classes get confused.
# Hypothetical confusion patterns (illustrative, not from the toy data above):
# anticoagulant predicted as antidiabetic: 3 times
#   → Do these classes share features? Check molecule structure
# antihypertensive predicted as antibiotic: 5 times
#   → A systematic error: investigate the feature space
def find_most_confused_pairs(cm: np.ndarray, class_names: list) -> list[tuple]:
    """Return (actual, predicted, count) for off-diagonal entries, sorted by count."""
    errors = []
    for i in range(len(class_names)):
        for j in range(len(class_names)):
            if i != j and cm[i, j] > 0:
                errors.append((class_names[i], class_names[j], cm[i, j]))
    return sorted(errors, key=lambda x: -x[2])

cm_array = np.array(confusion_matrix(y_true, y_pred))
confused = find_most_confused_pairs(cm_array, classes)
print("Most confused class pairs:")
for actual, predicted, count in confused[:5]:
    print(f"  {actual} → predicted as {predicted}: {count} times")
Confusion Matrix for Train vs Val
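When comparing row-normalized matrices from two data splits, it can help to flag only the cells whose rates differ by more than some tolerance. The NumPy sketch below uses two made-up matrices (not real results). The helper name `flag_gaps` and the 0.10 tolerance are assumptions chosen for illustration.

```python
import numpy as np


def flag_gaps(cm_train, cm_val, tol=0.10):
    """Return (diff, flagged) where flagged lists (row, col, train - val) with |diff| > tol."""
    diff = np.asarray(cm_train, dtype=float) - np.asarray(cm_val, dtype=float)
    flagged = [
        (int(i), int(j), float(diff[i, j]))
        for i, j in np.argwhere(np.abs(diff) > tol)
    ]
    return diff, flagged


# Made-up row-normalized matrices, for illustration only
cm_train = np.array([[0.98, 0.02],
                     [0.05, 0.95]])
cm_val = np.array([[0.90, 0.10],
                   [0.30, 0.70]])

diff, flagged = flag_gaps(cm_train, cm_val)
for i, j, gap in flagged:
    print(f"cell ({i}, {j}): train - val = {gap:+.2f}")
```

In this made-up case only the positive-class row is flagged: recall drops from 0.95 on train to 0.70 on validation, while the negative-class row stays within tolerance.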
from sklearn.metrics import confusion_matrix
# A simple but powerful diagnostic: compare confusion matrices on train vs val
# Large discrepancy → overfitting on specific error patterns
def compare_confusion_matrices(model, X_train, y_train, X_val, y_val):
    """Print row-normalized train and val confusion matrices and their difference."""
    cm_train = confusion_matrix(y_train, model.predict(X_train), normalize="true")
    cm_val = confusion_matrix(y_val, model.predict(X_val), normalize="true")
    print("Train confusion (normalized):")
    print(cm_train.round(3))
    print("\nVal confusion (normalized):")
    print(cm_val.round(3))
    print("\nDifference (train - val):")
    diff = cm_train - cm_val
    print(diff.round(3))
    # Large positive on diagonal → model performs better on train than val → overfitting
    return diff
Interview Answer Template
Q: How do you read a confusion matrix?
A confusion matrix shows the four possible outcomes for binary classification: true positives (correctly flagged), true negatives (correctly cleared), false positives (false alarms: predicted positive but actually negative), and false negatives (missed cases: predicted negative but actually positive). From it you can derive every standard metric: precision (TP/(TP+FP)), recall (TP/(TP+FN)), specificity (TN/(TN+FP)), F1, accuracy, and NPV. The most important thing to look at is which error type dominates. In clinical settings, false negatives (missed diagnoses) are typically more costly than false positives (false alarms), so I'd focus on the FN cell and recall. For multi-class problems, the off-diagonal entries reveal which classes are confused with each other, a pattern that often points to a specific feature or preprocessing issue.