
How to Select a Classification Threshold

Systematic methods for selecting the classification threshold: F1-optimal, recall-constrained, precision-constrained, cost-sensitive, and Youden's J — with clinical examples and validation procedure.

Asma Hafeez Khan · May 16, 2026 · 6 min read

The Decision Framework

Before selecting a threshold, answer:

1. What is the cost of a false negative?
   → High (missed diagnosis, safety risk): target high recall
   → Low (manageable): optimize F1 or Youden's J

2. What is the cost of a false positive?
   → High (alert fatigue, invasive follow-up): target high precision
   → Low (cheap to follow up): optimize F1 or Youden's J

3. Are both errors equally costly?
   → Optimize F1 or Youden's J

4. Is there a hard clinical constraint?
   → "We must catch at least 90% of sepsis cases" → recall constraint
   → "Precision must be at least 50% or alerts will be ignored" → precision constraint
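
The four questions can be read as a small decision procedure. The sketch below is one hypothetical way to encode them. The function name `suggest_method`, its arguments, and the rule for the case where both error types are costly (falling back to the cost-sensitive method) are illustrative assumptions, not part of any library.

```python
from typing import Optional


def suggest_method(
    fn_cost_high: bool,
    fp_cost_high: bool,
    hard_constraint: Optional[str] = None,  # "recall", "precision", or None
) -> str:
    """Map answers to the framework questions onto a threshold-selection method."""
    # Question 4: a hard clinical constraint overrides everything else.
    if hard_constraint == "recall":
        return "recall-constrained"
    if hard_constraint == "precision":
        return "precision-constrained"
    # Assumption (not stated in the framework): if both errors are costly,
    # weigh them explicitly with the cost-sensitive method.
    if fn_cost_high and fp_cost_high:
        return "cost-sensitive"
    # Questions 1 and 2: protect against whichever error is expensive.
    if fn_cost_high:
        return "recall-constrained"
    if fp_cost_high:
        return "precision-constrained"
    # Question 3: no strong asymmetry.
    return "F1 or Youden's J"


print(suggest_method(fn_cost_high=True, fp_cost_high=False))
print(suggest_method(fn_cost_high=False, fp_cost_high=False))
print(suggest_method(False, False, hard_constraint="precision"))
```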

Method 1: Maximize F1

Python
import numpy as np
from sklearn.metrics import precision_recall_curve

y_proba = model.predict_proba(X_val)[:, 1]

precisions, recalls, thresholds = precision_recall_curve(y_val, y_proba)

# F1 at each threshold (precisions/recalls have one extra element at the end)
f1_scores = 2 * precisions[:-1] * recalls[:-1] / (precisions[:-1] + recalls[:-1] + 1e-9)

best_idx       = np.argmax(f1_scores)
best_threshold = thresholds[best_idx]

print(f"F1-optimal threshold: {best_threshold:.3f}")
print(f"  Precision: {precisions[best_idx]:.3f}")
print(f"  Recall:    {recalls[best_idx]:.3f}")
print(f"  F1:        {f1_scores[best_idx]:.3f}")
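
To see what the sweep is doing, it can help to compute precision, recall, and F1 by hand on a tiny example. The sketch below uses plain Python and made-up labels and scores. The helper `prf_at` is hypothetical and only mirrors the `score >= threshold` rule used throughout this article.

```python
def prf_at(y_true, scores, t):
    """Precision, recall, and F1 when predicting positive for score >= t."""
    tp = sum(1 for y, s in zip(y_true, scores) if s >= t and y == 1)
    fp = sum(1 for y, s in zip(y_true, scores) if s >= t and y == 0)
    fn = sum(1 for y, s in zip(y_true, scores) if s < t and y == 1)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# Made-up labels and scores, for illustration only
y_true = [0, 0, 1, 0, 1, 1]
scores = [0.10, 0.30, 0.35, 0.60, 0.70, 0.90]

# Evaluate every distinct score as a candidate threshold
rows = [(t, *prf_at(y_true, scores, t)) for t in sorted(set(scores))]
for t, p, r, f1 in rows:
    print(f"t={t:.2f}  precision={p:.3f}  recall={r:.3f}  F1={f1:.3f}")

# Pick the candidate with the highest F1
best_t, best_p, best_r, best_f1 = max(rows, key=lambda row: row[3])
print(f"F1-optimal threshold: {best_t:.2f} (F1={best_f1:.3f})")
```

On this toy data the best F1 lands at the threshold 0.35, where all three positives are caught at the cost of one false positive.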

Method 2: Recall-Constrained (Clinical Safety)

Python
def find_threshold_for_recall(y_val, y_proba, target_recall: float) -> dict:
    """
    Find the highest threshold (best precision) that still achieves target_recall.
    """
    precisions, recalls, thresholds = precision_recall_curve(y_val, y_proba)

    # Sweep from high to low threshold (low to high recall)
    for t, p, r in zip(thresholds[::-1], precisions[-2::-1], recalls[-2::-1]):
        if r >= target_recall:
            return {
                "threshold": t,
                "precision": p,
                "recall":    r,
                # 1/p = alerts raised per true positive (this describes precision, not a false alarm rate)
                "alert_yield": f"1 in {1/p:.0f} alerts is real" if p > 0 else "N/A",
            }
    return {"threshold": 0.0, "precision": precisions[0], "recall": recalls[0], "note": "target not achievable"}

# Sepsis model: must achieve recall >= 0.90
result = find_threshold_for_recall(y_val, y_proba, target_recall=0.90)
print("Recall-constrained threshold:")
for k, v in result.items():
    print(f"  {k}: {v}")
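
The reversed slices in the loop are easy to misread. `precision_recall_curve` returns one more precision/recall value than thresholds, and the final pair (precision 1, recall 0) has no threshold. The sketch below uses plain lists with made-up values shaped like that output to show that `[-2::-1]` skips the final pair and keeps each threshold aligned with its own precision and recall.

```python
# Made-up values shaped like precision_recall_curve output:
# len(precisions) == len(recalls) == len(thresholds) + 1
thresholds = [0.2, 0.5, 0.8]
precisions = [0.50, 0.67, 1.00, 1.00]  # last entry has no threshold
recalls    = [1.00, 0.90, 0.40, 0.00]  # last entry has no threshold

# Reverse thresholds; start the other two at index -2 to drop the extra entry
aligned = list(zip(thresholds[::-1], precisions[-2::-1], recalls[-2::-1]))
for t, p, r in aligned:
    print(f"threshold={t}  precision={p}  recall={r}")

# First entry (highest threshold) whose recall meets the target
target_recall = 0.90
chosen = next((row for row in aligned if row[2] >= target_recall), None)
print("chosen:", chosen)
```

Because the sweep runs from the highest threshold downward, the first row that meets the recall target is also the one with the highest threshold.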

Method 3: Precision-Constrained (Alert Fatigue)

Python
def find_threshold_for_precision(y_val, y_proba, target_precision: float) -> dict:
    """
    Find the lowest threshold (best recall) that achieves target_precision.
    """
    precisions, recalls, thresholds = precision_recall_curve(y_val, y_proba)

    for t, p, r in zip(thresholds, precisions, recalls):
        if p >= target_precision:
            return {
                "threshold": t,
                "precision": p,
                "recall":    r,
                "missed_fraction": f"{1-r:.0%} of real positives missed",
            }
    return {"note": "target precision not achievable at any threshold"}

# Drug alert system: precision must be at least 40% (acceptable alert burden)
result = find_threshold_for_precision(y_val, y_proba, target_precision=0.40)
print("Precision-constrained threshold:")
for k, v in result.items():
    print(f"  {k}: {v}")
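
Unlike recall, precision does not have to rise steadily as the threshold increases. That is why the function scans from the lowest threshold and returns the first hit instead of assuming a single crossing point. A minimal sketch with made-up labels and scores (the helper `precision_recall_at` is hypothetical):

```python
def precision_recall_at(y_true, scores, t):
    """Precision and recall when predicting positive for score >= t."""
    tp = sum(1 for y, s in zip(y_true, scores) if s >= t and y == 1)
    fp = sum(1 for y, s in zip(y_true, scores) if s >= t and y == 0)
    fn = sum(1 for y, s in zip(y_true, scores) if s < t and y == 1)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall


# Made-up labels and scores, for illustration only
y_true = [0, 0, 1, 0, 1, 1]
scores = [0.10, 0.30, 0.35, 0.60, 0.70, 0.90]

curve = [(t, *precision_recall_at(y_true, scores, t)) for t in sorted(set(scores))]
for t, p, r in curve:
    print(f"t={t:.2f}  precision={p:.3f}  recall={r:.3f}")

# Thresholds meeting the precision target are not one contiguous range
target_precision = 0.70
hits = [t for t, p, _ in curve if p >= target_precision]
lowest_hit = hits[0]
print("thresholds meeting target:", hits)
print("lowest (best recall):", lowest_hit)
```

Here the threshold 0.60 misses the target even though the lower threshold 0.35 meets it, so the lowest qualifying threshold is the one that preserves the most recall.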

Method 4: Youden's J (Balanced Default)

Python
from sklearn.metrics import roc_curve
import numpy as np

fpr, tpr, thresholds_roc = roc_curve(y_val, y_proba)

# Youden's J = sensitivity + specificity - 1 = TPR - FPR
j_scores   = tpr - fpr
best_idx   = np.argmax(j_scores)
best_j_threshold = thresholds_roc[best_idx]

print(f"Youden's J threshold: {best_j_threshold:.3f}")
print(f"  Sensitivity (recall): {tpr[best_idx]:.3f}")
print(f"  Specificity:          {1 - fpr[best_idx]:.3f}")
print(f"  Youden's J:           {j_scores[best_idx]:.3f}")
# Best for: no strong preference between FP and FN; want a balanced operating point
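
The identity in the comment above (J = sensitivity + specificity - 1 = TPR - FPR) can be checked by hand at a single operating point. The confusion-matrix counts below are made up for illustration.

```python
# Made-up counts at one threshold
tp, fn = 80, 20    # 100 actual positives
tn, fp = 850, 50   # 900 actual negatives

sensitivity = tp / (tp + fn)   # true positive rate
specificity = tn / (tn + fp)   # true negative rate
fpr = fp / (fp + tn)           # false positive rate = 1 - specificity

j_from_definition = sensitivity + specificity - 1
j_from_roc = sensitivity - fpr

print(f"Sensitivity: {sensitivity:.3f}")
print(f"Specificity: {specificity:.3f}")
print(f"J (definition): {j_from_definition:.3f}")
print(f"J (TPR - FPR):  {j_from_roc:.3f}")
```

Both expressions give the same value because specificity is 1 - FPR, which is why the ROC arrays are enough to compute J at every threshold.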

Method 5: Cost-Sensitive Threshold

Python
def cost_optimal_threshold(
    y_val: np.ndarray,
    y_proba: np.ndarray,
    cost_fn: float,   # cost per false negative
    cost_fp: float,   # cost per false positive
) -> dict:
    """
    Find the threshold minimizing expected total cost.
    """
    from sklearn.metrics import confusion_matrix

    best_cost = float("inf")
    best_t    = 0.5

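    # Scan nearly the full probability range: with strongly asymmetric costs
    # the optimum can fall below 0.05 or above 0.95.
    for t in np.linspace(0.01, 0.99, 99):
        y_pred_t = (y_proba >= t).astype(int)
        # labels=[0, 1] keeps the matrix 2x2 even if only one class is predicted or present
        tn, fp, fn, tp = confusion_matrix(y_val, y_pred_t, labels=[0, 1]).ravel()
        total_cost = fn * cost_fn + fp * cost_fp
        if total_cost < best_cost:
            best_cost = total_cost
            best_t    = float(t)

    y_pred_best = (y_proba >= best_t).astype(int)
    tn, fp, fn, tp = confusion_matrix(y_val, y_pred_best, labels=[0, 1]).ravel()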
    return {
        "threshold":   best_t,
        "total_cost":  best_cost,
        "fn":          fn,
        "fp":          fp,
        "recall":      tp / (tp + fn) if (tp + fn) > 0 else 0,
        "precision":   tp / (tp + fp) if (tp + fp) > 0 else 0,
    }

# Readmission model cost structure
result = cost_optimal_threshold(
    y_val, y_proba,
    cost_fn=5000,   # missed readmission: readmission prevention opportunity lost
    cost_fp=200,    # false alarm: unnecessary discharge planning resources
)
print("Cost-optimal threshold:")
for k, v in result.items():
    print(f"  {k}: {v}")
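
A useful sanity check on the searched threshold is the closed-form break-even point. Assuming the model's probabilities are well calibrated and correct decisions cost nothing (both are assumptions, not guarantees), flagging a case with probability p costs (1 - p) × cost_fp in expectation and not flagging it costs p × cost_fn, so the two are equal at p = cost_fp / (cost_fp + cost_fn). The helper name `analytic_cost_threshold` is hypothetical.

```python
def analytic_cost_threshold(cost_fn: float, cost_fp: float) -> float:
    """
    Break-even probability between flagging and not flagging a case.
    Assumes calibrated probabilities and zero cost for correct decisions.
    """
    return cost_fp / (cost_fp + cost_fn)


# Same cost structure as the readmission example
t_star = analytic_cost_threshold(cost_fn=5000, cost_fp=200)
print(f"Analytic threshold: {t_star:.4f}")
```

With these costs the break-even point is about 0.038, so any grid search needs to extend at least that low to be able to find it. If the empirical optimum on the validation set sits far from this value, that can be a hint that the probabilities are poorly calibrated.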

The Correct Validation Procedure

Python
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

# 3-way split (70/15/15), stratified so each split keeps the same positive rate
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.30, stratify=y, random_state=42)
X_val,   X_test, y_val,   y_test = train_test_split(X_temp, y_temp, test_size=0.50, stratify=y_temp, random_state=42)

# Step 1: Train model
model.fit(X_train, y_train)

# Step 2: Find threshold on validation set
y_val_proba = model.predict_proba(X_val)[:, 1]
threshold_result = find_threshold_for_recall(y_val, y_val_proba, target_recall=0.85)
chosen_threshold = threshold_result["threshold"]
print(f"Chosen threshold: {chosen_threshold:.3f} (from validation set)")

# Step 3: Evaluate on test set with chosen threshold
y_test_proba = model.predict_proba(X_test)[:, 1]
y_test_pred  = (y_test_proba >= chosen_threshold).astype(int)

print("\nTest set performance at chosen threshold:")
print(classification_report(y_test, y_test_pred, target_names=["negative", "positive"]))

Threshold Summary Table

| Method | Use When | Primary Metric | Tradeoff |
|---|---|---|---|
| Maximize F1 | Equal FN/FP cost, imbalanced classes | F1 | Balanced |
| Recall-constrained | FN is dangerous (sepsis, cancer) | Recall ≥ target | Accept lower precision |
| Precision-constrained | FP causes alert fatigue | Precision ≥ target | Accept lower recall |
| Youden's J | No cost preference, balanced default | Sensitivity + Specificity | Balanced |
| Cost-sensitive | Known dollar/clinical cost per error | Total cost | Weighted |


Interview Answer Template

Q: How do you select the classification threshold in a clinical model?

The threshold choice is driven by the clinical cost structure. I start by asking which error is more dangerous: a missed positive (false negative) or a false alarm (false positive). For sepsis detection, false negatives are catastrophic, so I use the precision-recall curve on the validation set to find the highest threshold that still achieves recall ≥ 0.90, then verify that the resulting precision is operationally tolerable. For drug alert systems where alert fatigue is a risk, I find the lowest threshold that keeps precision above a clinical minimum (e.g., 40%), then measure the recall it yields. If there's no strong cost asymmetry, I use Youden's J (maximizing sensitivity + specificity) or F1. The critical procedural rule: tune the threshold on the validation set and evaluate on the test set only once. Selecting or re-tuning the threshold on the test set inflates performance estimates, because the threshold is a hyperparameter and must be treated as one.
