Deep Learning for AI Interviews · Lesson 52 of 56
Early Stopping: Simple but Effective
Why Early Stopping Works
As training progresses:
Phase 1: Both train and val loss decrease (underfitting → good fit)
Phase 2: Train loss continues decreasing, val loss plateaus
Phase 3: Train loss still decreasing, val loss increases (overfitting begins)
Early stopping: stop at the best val loss point (end of Phase 2 / start of Phase 3).
Restoring best weights: save the model weights at the best validation point and reload them when training stops.
Stopping early acts like constraining the model to a lower-complexity function —
the optimiser hasn't had time to memorise training noise.
Early stopping is the most universally applicable regularisation technique:
✓ Works for all architectures
✓ Requires almost no hyperparameter tuning (only patience and, optionally, min_delta)
✓ Computational benefit: stops when already optimal
Early Stopping Implementation
import torch
import torch.nn as nn
import copy
from typing import Optional
class EarlyStopping:
    """
    Stops training when the validation score stops improving for `patience` epochs.
    Saves and restores best model weights.

    Call the instance once per epoch with the latest validation score and the
    model; it returns True once training should stop.

    Args:
        patience: Number of consecutive non-improving calls tolerated before
            stopping.
        min_delta: Minimum change, in the direction given by `mode`, that
            counts as an improvement.
        restore_best_weights: If True, snapshot the model's state dict on each
            improvement and load it back into the model when stopping.
        mode: "min" when lower scores are better (loss), "max" when higher
            scores are better (AUC/accuracy).
        verbose: If True, print a progress line on every call and on restore.

    Raises:
        ValueError: If `mode` is neither "min" nor "max".
    """

    _VALID_MODES = ("min", "max")

    def __init__(
        self,
        patience: int = 10,
        min_delta: float = 1e-4,
        restore_best_weights: bool = True,
        mode: str = "min",  # "min" for loss, "max" for AUC/accuracy
        verbose: bool = True,
    ):
        # Reject typos such as "minimum" up front; previously any value other
        # than "min" was silently treated as "max".
        if mode not in self._VALID_MODES:
            raise ValueError(
                f"mode must be one of {self._VALID_MODES}, got {mode!r}"
            )
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.mode = mode
        self.verbose = verbose
        self.best_score: Optional[float] = None
        self.best_weights = None  # deep copy of the state dict at the best score
        self.counter = 0  # consecutive calls without improvement
        self.stop = False

    def _is_improvement(self, score: float) -> bool:
        """Return True if `score` beats the best score so far by more than `min_delta`."""
        import math

        # NaN compares False against everything; if it were ever accepted as
        # the best score, no later score could count as an improvement.
        if math.isnan(score):
            return False
        if self.best_score is None:
            return True
        if self.mode == "min":
            return score < self.best_score - self.min_delta
        return score > self.best_score + self.min_delta

    def __call__(self, score: float, model: nn.Module) -> bool:
        """Returns True if training should stop."""
        if self._is_improvement(score):
            self.best_score = score
            self.counter = 0
            if self.restore_best_weights:
                # Deep copy so later optimiser steps cannot mutate the snapshot.
                self.best_weights = copy.deepcopy(model.state_dict())
            if self.verbose:
                print(f" ✓ Val score improved to {score:.4f}")
        else:
            self.counter += 1
            if self.verbose:
                print(f" · No improvement ({self.counter}/{self.patience})")
            if self.counter >= self.patience:
                self.stop = True
                # Compare against None explicitly: a state dict is a mapping,
                # so an empty one would be falsy.
                if self.restore_best_weights and self.best_weights is not None:
                    model.load_state_dict(self.best_weights)
                    if self.verbose:
                        print(f" ↩ Restored best weights (score: {self.best_score:.4f})")
        return self.stop
# Usage
early_stop = EarlyStopping(mode="min", patience=10, min_delta=1e-4)
model = nn.Sequential(
    nn.Linear(20, 64),
    nn.ReLU(),
    nn.Linear(64, 1),
)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
criterion = nn.BCEWithLogitsLoss()

# Simulated training
for epoch in range(100):
    # ... training ...
    # Synthetic validation loss: a linear decay plus a penalty term that only
    # becomes non-zero after epoch 20, with the total floored at 0.2.
    linear_decay = 1.0 - 0.05 * epoch
    late_penalty = 0.001 * max(0, epoch - 20) ** 1.5
    val_loss = max(0.2, linear_decay + late_penalty)
    if not early_stop(val_loss, model):
        continue
    print(f"Early stopping at epoch {epoch+1}")
    break

# Patience Selection
patience = 5: Aggressive — stops quickly; risk of stopping during a temporary plateau
patience = 10: Moderate — good default for most problems
patience = 20: Patient — allows longer plateaus; better for noisy training or LR schedules
patience = 50: Conservative — mostly for transformers with long warmup phases
Rule of thumb:
patience = max(10, n_epochs × 0.1)
Also consider:
- With learning rate schedulers: set patience > scheduler plateau wait
(ReduceLROnPlateau patience=5, early stopping patience=20)
- With cosine annealing: early stopping is less useful — the schedule already
controls training length; just run for the full schedule
- min_delta: typically 1e-4 to 1e-3; smaller values are more sensitive to tiny improvements
Combining with Learning Rate Scheduling
import torch
import torch.nn as nn
import copy
def train_with_early_stopping_and_scheduler(
    model: nn.Module,
    train_loader,
    val_loader,
    max_epochs: int = 200,
    patience: int = 15,
    lr: float = 3e-4,
) -> tuple[nn.Module, dict]:
    """Train a binary classifier with ReduceLROnPlateau and early stopping.

    Each epoch runs one pass over `train_loader` (AdamW, gradient-norm
    clipping at 1.0, BCEWithLogitsLoss) and one pass over `val_loader`. The
    mean validation loss drives both the LR scheduler and the early-stopping
    check.

    Args:
        model: Network to train; it is moved to CUDA when available.
            Assumed to emit one logit per sample with shape (batch, 1) —
            TODO confirm against callers.
        train_loader: Iterable of (X, y) batches that supports `len()`.
            Targets are assumed to be float tensors of shape (batch,).
        val_loader: Same contract as `train_loader`, for validation.
        max_epochs: Upper bound on the number of epochs.
        patience: Early-stopping patience, in epochs.
        lr: Initial learning rate for AdamW.

    Returns:
        `(model, history)`, where `history` has per-epoch lists under the keys
        "train_loss", "val_loss" and "lr". The returned model carries the
        weights from the best validation epoch, whether training ended by
        early stopping or by reaching `max_epochs`.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    criterion = nn.BCEWithLogitsLoss()
    # ReduceLROnPlateau: halves lr (down to min_lr) once val loss has gone
    # without improvement for longer than its own patience of 5 epochs.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5, min_lr=1e-6
    )
    early_stop = EarlyStopping(patience=patience, mode="min")
    history = {"train_loss": [], "val_loss": [], "lr": []}
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    for epoch in range(max_epochs):
        # Training
        model.train()
        train_loss = 0.0
        for X, y in train_loader:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            # squeeze(-1) drops only the trailing logit dimension. A bare
            # squeeze() would also drop the batch dimension when a batch has
            # a single sample, making the logits a scalar that no longer
            # matches the target shape.
            loss = criterion(model(X).squeeze(-1), y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            train_loss += loss.item()
        train_loss /= len(train_loader)

        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for X, y in val_loader:
                X, y = X.to(device), y.to(device)
                val_loss += criterion(model(X).squeeze(-1), y).item()
        val_loss /= len(val_loader)

        # Record the LR that was in effect during this epoch, i.e. before the
        # scheduler gets a chance to lower it.
        current_lr = optimizer.param_groups[0]["lr"]
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["lr"].append(current_lr)

        # Scheduler step (based on val loss)
        scheduler.step(val_loss)

        if epoch % 5 == 0:
            print(f"Epoch {epoch:3d}: train={train_loss:.4f}, val={val_loss:.4f}, lr={current_lr:.2e}")

        # Early stopping check
        if early_stop(val_loss, model):
            print(f"Early stopping triggered at epoch {epoch+1}")
            break

    # The stopper only reloads the best weights when it fires. If the epoch
    # budget ran out first, restore them here so the caller never receives
    # the (possibly overfit) last-epoch weights.
    if not early_stop.stop and early_stop.best_weights is not None:
        model.load_state_dict(early_stop.best_weights)

    return model, history

# What Early Stopping Monitors
import torch
class MultiMetricEarlyStopping:
    """Early stopping based on any monitored metric.

    Call `step` once per epoch with the latest value of the monitored metric.
    The model's state dict is snapshotted on every improvement and loaded back
    into the model when the patience budget is exhausted.

    Args:
        patience: Number of consecutive non-improving steps tolerated before
            stopping.
        mode: "min" if lower metric values are better, "max" if higher values
            are better.
        min_delta: Minimum change, in the direction given by `mode`, that
            counts as an improvement.
        verbose: If True, print a line whenever the metric improves.

    Raises:
        ValueError: If `mode` is neither "min" nor "max".
    """

    def __init__(
        self,
        patience: int = 10,
        mode: str = "min",
        min_delta: float = 1e-4,
        verbose: bool = True,
    ):
        # An unrecognised mode used to match neither comparison branch, so
        # nothing after the first step could ever count as an improvement.
        if mode not in ("min", "max"):
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
        self.patience = patience
        self.mode = mode
        self.min_delta = min_delta
        self.verbose = verbose
        self.best = None  # best metric value seen so far
        self.counter = 0  # consecutive steps without improvement
        self.best_weights = None  # deep copy of the state dict at `best`

    def step(
        self,
        metric: float,
        model: torch.nn.Module,
        metric_name: str = "val_loss",
    ) -> bool:
        """Record one metric value; return True if training should stop.

        Args:
            metric: Latest value of the monitored metric.
            model: Model whose weights are snapshotted and restored.
            metric_name: Label used only in the progress message.
        """
        import copy
        import math

        if math.isnan(metric):
            # NaN must never become `best`: every later comparison against it
            # would be False and no real improvement could be recorded.
            improved = False
        elif self.best is None:
            improved = True
        elif self.mode == "min":
            improved = metric < self.best - self.min_delta
        else:
            improved = metric > self.best + self.min_delta

        if improved:
            self.best = metric
            self.counter = 0
            # Deep copy so later optimiser steps cannot mutate the snapshot.
            self.best_weights = copy.deepcopy(model.state_dict())
            if self.verbose:
                print(f" {metric_name}={metric:.4f} improved")
        else:
            self.counter += 1

        if self.counter >= self.patience:
            # No snapshot exists if every metric so far was NaN.
            if self.best_weights is not None:
                model.load_state_dict(self.best_weights)
            return True
        return False
# One stopper per monitored quantity: AUC is maximised, loss is minimised.
auc_stopper = MultiMetricEarlyStopping(mode="max", patience=10)
loss_stopper = MultiMetricEarlyStopping(mode="min", patience=10)
# In clinical AI, monitoring AUC (a ranking metric) is often better than monitoring loss,
# since the goal is discrimination (ranking sick vs healthy), not calibration.
Interview Answer
"Early stopping halts training when validation loss (or metric) stops improving, then restores the weights from the best validation epoch. It prevents the model from memorising training noise — training too long allows the optimiser to fit the training set exactly, causing validation performance to degrade. Implementation: track best validation metric, increment a patience counter when no improvement occurs, and stop when the counter reaches the threshold (10–20 epochs is typical). Best weights must be explicitly saved and restored — PyTorch does not do this automatically. Early stopping is the most broadly applicable regularisation technique: it works without architecture changes, requires no hyperparameter search beyond patience, and saves computation by stopping when already optimal. In combination with ReduceLROnPlateau: set patience for early stopping (20) higher than for LR reduction (5–10), so the learning rate is reduced first before giving up. For transformers with warmup + cosine decay schedules, early stopping is less critical — the schedule already controls training length."