Machine Learning Foundations · Lesson 54 of 70
How to Choose the Right Threshold
The Decision Framework
Before selecting a threshold, answer:
1. What is the cost of a false negative?
→ High (missed diagnosis, safety risk): target high recall
→ Low (manageable): optimize F1 or Youden's J
2. What is the cost of a false positive?
→ High (alert fatigue, invasive follow-up): target high precision
→ Low (cheap to follow up): optimize F1 or Youden's J
3. Are both errors equally costly?
→ Optimize F1 or Youden's J
4. Is there a hard clinical constraint?
→ "We must catch at least 90% of sepsis cases" → recall constraint
→ "Precision must be at least 50% or alerts will be ignored" → precision constraint

Method 1: Maximize F1
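The snippets in this lesson assume that a fitted `model` and a validation split (`X_val`, `y_val`) already exist. As a minimal sketch of one possible setup, the block below builds a synthetic imbalanced dataset and fits a logistic regression. Both the dataset and the choice of classifier are assumptions made for illustration, not part of the lesson's own pipeline. It also shows why the F1 computation slices `precisions` and `recalls` with `[:-1]`.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import precision_recall_curve
from sklearn.model_selection import train_test_split

# Synthetic, imbalanced binary problem (roughly 10% positives); illustrative only.
X, y = make_classification(
    n_samples=2000, n_features=10, weights=[0.9, 0.1], random_state=0
)
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=0
)

# Any classifier exposing predict_proba would do; logistic regression is an assumption.
model = LogisticRegression(max_iter=1000).fit(X_train, y_train)
y_proba = model.predict_proba(X_val)[:, 1]

precisions, recalls, thresholds = precision_recall_curve(y_val, y_proba)

# precisions and recalls carry one extra trailing point (precision=1, recall=0)
# that has no matching threshold, so they are one element longer than thresholds.
print(len(precisions), len(recalls), len(thresholds))
```

Element `i` of `precisions[:-1]` and `recalls[:-1]` describes the rule "predict positive when the score is at least `thresholds[i]`", which is what makes index-based lookups such as `thresholds[best_idx]` valid.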
import numpy as np
from sklearn.metrics import precision_recall_curve
y_proba = model.predict_proba(X_val)[:, 1]
precisions, recalls, thresholds = precision_recall_curve(y_val, y_proba)
# F1 at each threshold (precisions/recalls have one extra element at the end)
f1_scores = 2 * precisions[:-1] * recalls[:-1] / (precisions[:-1] + recalls[:-1] + 1e-9)
best_idx = np.argmax(f1_scores)
best_threshold = thresholds[best_idx]
print(f"F1-optimal threshold: {best_threshold:.3f}")
print(f" Precision: {precisions[best_idx]:.3f}")
print(f" Recall: {recalls[best_idx]:.3f}")
print(f"  F1: {f1_scores[best_idx]:.3f}")

Method 2: Recall-Constrained (Clinical Safety)
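Because recall can only stay flat or fall as the threshold rises, a recall constraint reduces to picking the last threshold index whose recall still meets the target. The block below is a vectorized sketch of that idea, followed by a direct re-check of the constraint. The helper name `highest_threshold_for_recall` and the toy scores are hypothetical and exist only for this illustration.

```python
import numpy as np
from sklearn.metrics import precision_recall_curve, recall_score


def highest_threshold_for_recall(y_true, y_score, target_recall):
    """Return (threshold, precision, recall) for the largest threshold
    whose recall is still >= target_recall, or None if none qualifies."""
    precisions, recalls, thresholds = precision_recall_curve(y_true, y_score)
    # recalls[:-1] lines up with thresholds and never increases as the threshold rises
    ok = np.flatnonzero(recalls[:-1] >= target_recall)
    if ok.size == 0:
        return None  # e.g. target_recall > 1
    idx = ok[-1]  # last qualifying index = highest qualifying threshold
    return float(thresholds[idx]), float(precisions[idx]), float(recalls[idx])


# Toy labels and scores, made up for illustration
rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=500)
y_score = np.clip(0.3 * y_true + rng.normal(0.4, 0.2, size=500), 0, 1)

t, p, r = highest_threshold_for_recall(y_true, y_score, target_recall=0.90)

# Re-apply the threshold directly to confirm the recall constraint holds
achieved = recall_score(y_true, (y_score >= t).astype(int))
print(f"threshold={t:.3f} precision={p:.3f} recall={achieved:.3f}")
```

Re-applying the chosen threshold with `recall_score` is a cheap guard against off-by-one mistakes when indexing the curve arrays.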
def find_threshold_for_recall(y_val, y_proba, target_recall: float) -> dict:
    """
    Find the highest threshold (typically the best precision) that still achieves target_recall.
    """
    precisions, recalls, thresholds = precision_recall_curve(y_val, y_proba)
    # Sweep from high to low threshold (low to high recall)
    for t, p, r in zip(thresholds[::-1], precisions[-2::-1], recalls[-2::-1]):
        if r >= target_recall:
            return {
                "threshold": t,
                "precision": p,
                "recall": r,
                # 1/p alerts are raised for every true positive caught
                "alert_yield": f"1 in {1/p:.0f} alerts is real" if p > 0 else "N/A",
            }
    return {"threshold": 0.0, "precision": precisions[0], "recall": recalls[0], "note": "target not achievable"}

# Sepsis model: must achieve recall >= 0.90
result = find_threshold_for_recall(y_val, y_proba, target_recall=0.90)
print("Recall-constrained threshold:")
for k, v in result.items():
    print(f"  {k}: {v}")

Method 3: Precision-Constrained (Alert Fatigue)
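A precision target is easier to negotiate with clinicians when it is expressed as an alert count. The sketch below does that arithmetic: expected true alerts are patients × prevalence × recall, and total alerts are true alerts divided by precision. The helper `alert_burden` and all input numbers are made up for illustration.

```python
def alert_burden(n_patients, prevalence, recall, precision):
    """Translate precision and recall into expected alert counts (hypothetical helper)."""
    true_alerts = n_patients * prevalence * recall   # real positives that get flagged
    total_alerts = true_alerts / precision            # precision = true_alerts / total_alerts
    return {
        "true_alerts": true_alerts,
        "false_alerts": total_alerts - true_alerts,
        "alerts_per_true_positive": 1 / precision,
    }


# Made-up scenario: 1,000 patients, 5% prevalence, recall 0.80, precision 0.40
burden = alert_burden(n_patients=1000, prevalence=0.05, recall=0.80, precision=0.40)
for name, value in burden.items():
    print(f"{name}: {value:.1f}")
```

In this made-up scenario, a precision of 0.40 means about 40 real alerts arrive alongside about 60 false ones, or 2.5 alerts per true positive.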
def find_threshold_for_precision(y_val, y_proba, target_precision: float) -> dict:
    """
    Find the lowest threshold (best recall) that achieves target_precision.
    """
    precisions, recalls, thresholds = precision_recall_curve(y_val, y_proba)
    # thresholds are ascending; zip stops at the shortest input, which skips
    # the extra trailing precision/recall point that has no threshold
    for t, p, r in zip(thresholds, precisions, recalls):
        if p >= target_precision:
            return {
                "threshold": t,
                "precision": p,
                "recall": r,
                "missed_fraction": f"{1-r:.0%} of real positives missed",
            }
    return {"note": "target precision not achievable at any threshold"}

# Drug alert system: precision must be at least 40% (acceptable alert burden)
result = find_threshold_for_precision(y_val, y_proba, target_precision=0.40)
print("Precision-constrained threshold:")
for k, v in result.items():
    print(f"  {k}: {v}")

Method 4: Youden's J (Balanced Default)
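Youden's J is defined as sensitivity + specificity − 1. Since specificity equals 1 − FPR, this is the same quantity as TPR − FPR, which is what the ROC arrays provide directly. The sketch below checks that identity on a four-sample toy example by recomputing J from confusion-matrix counts. The data is made up for illustration.

```python
import numpy as np
from sklearn.metrics import confusion_matrix, roc_curve

# Four-sample toy example, made up for illustration
y_true = np.array([0, 0, 1, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8])

fpr, tpr, thresholds = roc_curve(y_true, y_score)
best_idx = int(np.argmax(tpr - fpr))
t = thresholds[best_idx]

# Recompute J from raw counts at the selected threshold
y_pred = (y_score >= t).astype(int)
tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
sensitivity = tp / (tp + fn)
specificity = tn / (tn + fp)
j_manual = sensitivity + specificity - 1
j_from_roc = tpr[best_idx] - fpr[best_idx]

print(f"threshold={t}, J (counts)={j_manual:.3f}, J (ROC)={j_from_roc:.3f}")
```

Note that the first entry of the thresholds returned by `roc_curve` is a sentinel set above every observed score (the "predict nothing as positive" point). Its J is 0, so `argmax` only selects it when no real threshold does better than chance.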
from sklearn.metrics import roc_curve
import numpy as np
fpr, tpr, thresholds_roc = roc_curve(y_val, y_proba)
# Youden's J = sensitivity + specificity - 1 = TPR - FPR
j_scores = tpr - fpr
best_idx = np.argmax(j_scores)
best_j_threshold = thresholds_roc[best_idx]
print(f"Youden's J threshold: {best_j_threshold:.3f}")
print(f" Sensitivity (recall): {tpr[best_idx]:.3f}")
print(f" Specificity: {1 - fpr[best_idx]:.3f}")
print(f" Youden's J: {j_scores[best_idx]:.3f}")
# Best for: no strong preference between FP and FN; want a balanced operating point

Method 5: Cost-Sensitive Threshold
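When the model's scores are well-calibrated probabilities (an assumption that should be checked, not taken for granted), the cost-minimizing threshold has a closed form. Flagging a case with probability p costs (1 − p) · cost_fp in expectation, and not flagging it costs p · cost_fn, so flagging is cheaper whenever p > cost_fp / (cost_fp + cost_fn). A minimal sketch, with a hypothetical helper name:

```python
def cost_threshold(cost_fn: float, cost_fp: float) -> float:
    """Probability above which flagging is cheaper than not flagging.

    Assumes the model's scores are calibrated probabilities.
    """
    return cost_fp / (cost_fp + cost_fn)


# Example costs: a missed case is 25x as expensive as a false alarm
t_star = cost_threshold(cost_fn=5000, cost_fp=200)
print(f"Theoretical cost-optimal threshold: {t_star:.3f}")  # 0.038
```

This value is a useful cross-check for an empirical search. With a strong cost asymmetry the optimum can sit very close to 0, so the search grid needs to extend low enough to include it.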
from sklearn.metrics import confusion_matrix

def cost_optimal_threshold(
    y_val: np.ndarray,
    y_proba: np.ndarray,
    cost_fn: float,  # cost per false negative
    cost_fp: float,  # cost per false positive
) -> dict:
    """
    Find the threshold minimizing expected total cost.
    """
    best_cost = float("inf")
    best_t = 0.5
    # Search 0.01 to 0.99: with strongly asymmetric costs the optimum can lie below 0.05
    for t in np.linspace(0.01, 0.99, 99):
        y_pred_t = (y_proba >= t).astype(int)
        # labels=[0, 1] keeps the matrix 2x2 even if every prediction falls in one class
        tn, fp, fn, tp = confusion_matrix(y_val, y_pred_t, labels=[0, 1]).ravel()
        total_cost = fn * cost_fn + fp * cost_fp
        if total_cost < best_cost:
            best_cost = total_cost
            best_t = float(t)
    y_pred_best = (y_proba >= best_t).astype(int)
    tn, fp, fn, tp = confusion_matrix(y_val, y_pred_best, labels=[0, 1]).ravel()
    return {
        "threshold": best_t,
        "total_cost": best_cost,
        "fn": fn,
        "fp": fp,
        "recall": tp / (tp + fn) if (tp + fn) > 0 else 0,
        "precision": tp / (tp + fp) if (tp + fp) > 0 else 0,
    }

# Readmission model cost structure
result = cost_optimal_threshold(
    y_val, y_proba,
    cost_fn=5000,  # missed readmission: readmission prevention opportunity lost
    cost_fp=200,   # false alarm: unnecessary discharge planning resources
)
print("Cost-optimal threshold:")
for k, v in result.items():
    print(f"  {k}: {v}")

The Correct Validation Procedure
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
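# 3-way split: 70% train / 15% validation / 15% test, stratified to preserve the class ratio
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.30, stratify=y, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.50, stratify=y_temp, random_state=42)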
# Step 1: Train model
model.fit(X_train, y_train)
# Step 2: Find threshold on validation set
y_val_proba = model.predict_proba(X_val)[:, 1]
threshold_result = find_threshold_for_recall(y_val, y_val_proba, target_recall=0.85)
chosen_threshold = threshold_result["threshold"]
print(f"Chosen threshold: {chosen_threshold:.3f} (from validation set)")
# Step 3: Evaluate on test set with chosen threshold
y_test_proba = model.predict_proba(X_test)[:, 1]
y_test_pred = (y_test_proba >= chosen_threshold).astype(int)
print("\nTest set performance at chosen threshold:")
print(classification_report(y_test, y_test_pred, target_names=["negative", "positive"]))

Threshold Summary Table
| Method | Use When | Primary Metric | Tradeoff |
|---|---|---|---|
| Maximize F1 | Equal FN/FP cost, imbalanced classes | F1 | Balanced |
| Recall-constrained | FN is dangerous (sepsis, cancer) | Recall ≥ target | Accept lower precision |
| Precision-constrained | FP causes alert fatigue | Precision ≥ target | Accept lower recall |
| Youden's J | No cost preference, balanced default | Sensitivity + Specificity | Balanced |
| Cost-sensitive | Known dollar/clinical cost per error | Total cost | Weighted |
Interview Answer Template
Q: How do you select the classification threshold in a clinical model?
The threshold choice is driven by the clinical cost structure. I start by asking which error is more dangerous: a missed positive (false negative) or a false alarm (false positive). For sepsis detection, false negatives are catastrophic, so I use the precision-recall curve on the validation set to find the highest threshold that still achieves recall ≥ 0.90, then verify that the resulting precision is operationally tolerable. For drug alert systems where alert fatigue is a risk, I find the lowest threshold that keeps precision above a clinical minimum (e.g., 40%), then measure the recall that yields. If there's no strong cost asymmetry, I use Youden's J (maximizing sensitivity + specificity) or F1. The critical procedural rule: tune the threshold on the validation set and evaluate on the test set only once. Selecting the threshold on the test set inflates performance estimates; the threshold is a hyperparameter and must be treated as one.