Learnixo

Deep Learning for AI Interviews · Lesson 52 of 56

Early Stopping: Simple but Effective

Why Early Stopping Works

As training progresses:

Phase 1: Both train and val loss decrease (underfitting → good fit)
Phase 2: Train loss continues decreasing, val loss plateaus
Phase 3: Train loss still decreasing, val loss increases (overfitting begins)

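Early stopping: halt training once val loss has gone "patience" epochs without improving on its best value (the best point is the end of Phase 2 / start of Phase 3).
Restoring best weights: save the model weights at the best validation epoch and reload them when training stops, so the extra patience epochs are discarded.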

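Stopping early acts as a constraint on model complexity: the optimiser hasn't had time
to memorise training noise. For a linear model trained by gradient descent on a quadratic
loss, early stopping is approximately equivalent to L2 regularisation (weight decay), with
the number of steps times the learning rate playing the role of the inverse decay coefficient.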
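The complexity argument can be made precise in a simplified setting. This is the standard textbook analysis, and the setup is an assumption rather than something this lesson derives: approximate the loss by a quadratic around its minimum w* with Hessian H = QΛQᵀ, start gradient descent at w⁽⁰⁾ = 0 with learning rate ε, and compare the result with the minimiser w̃ of the loss plus an L2 penalty (α/2)‖w‖².

```latex
\begin{aligned}
Q^\top w^{(\tau)} &= \bigl[I - (I - \epsilon\Lambda)^{\tau}\bigr]\, Q^\top w^{*}
  && \text{gradient descent after } \tau \text{ steps} \\
Q^\top \tilde{w} &= \bigl[I - (\Lambda + \alpha I)^{-1}\alpha\bigr]\, Q^\top w^{*}
  && \text{minimiser with L2 penalty} \\
(I - \epsilon\Lambda)^{\tau} &= (\Lambda + \alpha I)^{-1}\alpha
  \;\Longrightarrow\; \tau \approx \frac{1}{\epsilon\alpha}
  && \text{when } \epsilon\lambda_i \ll 1,\ \lambda_i \ll \alpha
\end{aligned}
```

So stopping after τ steps behaves roughly like weight decay with α ≈ 1/(τε): fewer steps correspond to a stronger penalty. The match holds only for this quadratic case; for deep networks it is an intuition, not an identity.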

Early stopping is one of the most broadly applicable regularisation techniques:
  ✓ Works with any architecture trained iteratively
  ✓ Needs little tuning (mainly patience, plus min_delta)
  ✓ Saves compute: training ends once validation performance stops improving
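The stop-and-restore rule from the phases above can be traced offline on a recorded validation curve. The helper find_stop_epoch and the curve values are hypothetical, made up so that one curve passes through all three phases. The patience/min_delta logic mirrors the rule used throughout this lesson.

```python
from typing import List, Optional, Tuple

def find_stop_epoch(
    val_losses: List[float], patience: int = 10, min_delta: float = 1e-4
) -> Tuple[int, Optional[int]]:
    """Return (best_epoch, stop_epoch) for a recorded val-loss curve.

    stop_epoch is None if patience never runs out. Epochs are 0-indexed.
    """
    best_epoch, best_loss, counter = 0, float("inf"), 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss - min_delta:      # meaningful improvement: reset the counter
            best_epoch, best_loss, counter = epoch, loss, 0
        else:
            counter += 1
            if counter >= patience:           # stop here, roll back to best_epoch
                return best_epoch, epoch
    return best_epoch, None

# Hypothetical curve: Phase 1 falls, Phase 2 flattens, Phase 3 rises
curve = [1.0, 0.7, 0.5, 0.4, 0.35, 0.34, 0.34, 0.345, 0.36, 0.38, 0.41, 0.45]
print(find_stop_epoch(curve, patience=3))
```

With patience=3 the best loss is at epoch 5 and patience runs out at epoch 8, so training would stop there and reload the epoch-5 weights.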

Early Stopping Implementation

Python
import torch
import torch.nn as nn
import copy
from typing import Optional

class EarlyStopping:
    """
    Stops training when validation loss stops improving for `patience` epochs.
    Saves and restores best model weights.
    """
    
    def __init__(
        self,
        patience: int = 10,
        min_delta: float = 1e-4,
        restore_best_weights: bool = True,
        mode: str = "min",  # "min" for loss, "max" for AUC/accuracy
        verbose: bool = True,
    ):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.mode = mode
        self.verbose = verbose
        
        self.best_score: Optional[float] = None
        self.best_weights = None
        self.counter = 0
        self.stop = False
    
    def _is_improvement(self, score: float) -> bool:
        if self.best_score is None:
            return True
        if self.mode == "min":
            return score < self.best_score - self.min_delta
        else:
            return score > self.best_score + self.min_delta
    
    def __call__(self, score: float, model: nn.Module) -> bool:
        """Returns True if training should stop."""
        if self._is_improvement(score):
            self.best_score = score
            self.counter = 0
            if self.restore_best_weights:
                self.best_weights = copy.deepcopy(model.state_dict())
            if self.verbose:
                print(f"  ✓ Val score improved to {score:.4f}")
        else:
            self.counter += 1
            if self.verbose:
                print(f"  · No improvement ({self.counter}/{self.patience})")
            
            if self.counter >= self.patience:
                self.stop = True
                if self.restore_best_weights and self.best_weights is not None:
                    model.load_state_dict(self.best_weights)
                    if self.verbose:
                        print(f"  ↩ Restored best weights (score: {self.best_score:.4f})")
        
        return self.stop

# Usage
early_stop = EarlyStopping(patience=10, min_delta=1e-4, mode="min")

model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 1))
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
criterion = nn.BCEWithLogitsLoss()

# Simulated training
for epoch in range(100):
    # ... training ...
    
    # Simulated val loss: improves until epoch 30, then worsens (overfitting)
    val_loss = 0.3 + 0.001 * (epoch - 30) ** 2
    
    if early_stop(val_loss, model):
        print(f"Early stopping at epoch {epoch+1}")
        break

Patience Selection

patience = 5:  Aggressive — stops quickly; risk of stopping during a temporary plateau
patience = 10: Moderate — good default for most problems
patience = 20: Patient — allows longer plateaus; better for noisy training or LR schedules
patience = 50: Conservative — mostly for transformers with long warmup phases

Rule of thumb:
  patience = max(10, n_epochs × 0.1)

Also consider:
  - With learning rate schedulers: set patience > scheduler plateau wait
    (ReduceLROnPlateau patience=5, early stopping patience=20)
  - With cosine annealing: early stopping is less useful — the schedule already
    controls training length; just run for the full schedule
  - min_delta: typically 1e-4 to 1e-3; smaller values are more sensitive to tiny improvements
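The rule of thumb and the scheduler advice can be folded into one small helper. This is a sketch: suggest_patience and its n_lr_drops parameter are hypothetical, not part of PyTorch, and the scheduler arithmetic assumes ReduceLROnPlateau's default cooldown of 0.

```python
import math
from typing import Optional

def suggest_patience(
    n_epochs: int,
    scheduler_patience: Optional[int] = None,
    n_lr_drops: int = 2,  # hypothetical knob: LR cuts to allow before stopping
) -> int:
    """Heuristic early-stopping patience; a rule of thumb, not a library API."""
    # Base rule: max(10, 10% of the epoch budget), rounded up
    patience = max(10, math.ceil(n_epochs / 10))
    if scheduler_patience is not None:
        # ReduceLROnPlateau (cooldown=0) cuts the LR after scheduler_patience + 1
        # consecutive bad epochs and then restarts its count, so n_lr_drops cuts take
        # about n_lr_drops * (scheduler_patience + 1) bad epochs. The +1 leaves at least
        # one epoch at the lowest LR before early stopping fires.
        patience = max(patience, n_lr_drops * (scheduler_patience + 1) + 1)
    return patience

for args in [(50, None), (300, None), (100, 5)]:
    print(args, suggest_patience(*args))
```

With scheduler patience 5 and three allowed cuts this gives 19, close to the 20 suggested above. The count is approximate: ReduceLROnPlateau's default improvement threshold is relative (1e-4), while min_delta here is absolute, so the two can disagree on borderline epochs.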

Combining with Learning Rate Scheduling

Python
import torch
import torch.nn as nn
# Requires the EarlyStopping class defined in the previous listing

def train_with_early_stopping_and_scheduler(
    model: nn.Module,
    train_loader,
    val_loader,
    max_epochs: int = 200,
    patience: int = 15,
    lr: float = 3e-4,
) -> tuple[nn.Module, dict]:
    # Move the model to the device before constructing the optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    criterion = nn.BCEWithLogitsLoss()
    
    # ReduceLROnPlateau: halves lr after more than 5 consecutive epochs without val improvement
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5, min_lr=1e-6
    )
    
    # Early-stopping patience (default 15) exceeds the scheduler's 5, so LR cuts happen first
    early_stop = EarlyStopping(patience=patience, mode="min")
    history = {"train_loss": [], "val_loss": [], "lr": []}
    
    for epoch in range(max_epochs):
        # Training
        model.train()
        train_loss = 0.0
        for X, y in train_loader:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            # squeeze(-1) keeps the batch dim even when a batch has one sample
            loss = criterion(model(X).squeeze(-1), y.float())
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            train_loss += loss.item()
        train_loss /= len(train_loader)
        
        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for X, y in val_loader:
                X, y = X.to(device), y.to(device)
                val_loss += criterion(model(X).squeeze(-1), y.float()).item()
        val_loss /= len(val_loader)
        
        current_lr = optimizer.param_groups[0]["lr"]
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["lr"].append(current_lr)
        
        # Scheduler step (based on val loss)
        scheduler.step(val_loss)
        
        if epoch % 5 == 0:
            print(f"Epoch {epoch:3d}: train={train_loss:.4f}, val={val_loss:.4f}, lr={current_lr:.2e}")
        
        # Early stopping check
        if early_stop(val_loss, model):
            print(f"Early stopping triggered at epoch {epoch+1}")
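            break
    
    # If max_epochs ran out before patience did, the last weights may not be the best ones
    if not early_stop.stop and early_stop.best_weights is not None:
        model.load_state_dict(early_stop.best_weights)
    
    return model, history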
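To see why the early-stopping patience should exceed the scheduler's, a minimal sketch can feed ReduceLROnPlateau a flat validation loss (a simulated signal standing in for a stalled model) with the same settings as above and record when the learning rate is cut. The single dummy parameter exists only so the optimizer has something to manage.

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))  # dummy parameter, never receives a gradient
optimizer = torch.optim.AdamW([param], lr=3e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=5, min_lr=1e-6
)

lrs = []
for epoch in range(20):
    optimizer.step()   # no gradients, so this is a no-op; keeps the usual call order
    val_loss = 1.0     # simulated: val loss never improves after epoch 0
    scheduler.step(val_loss)
    lrs.append(optimizer.param_groups[0]["lr"])

# Epochs (0-indexed) at the end of which the LR was cut
cut_epochs = [i for i in range(1, len(lrs)) if lrs[i] < lrs[i - 1]]
print(cut_epochs, lrs[-1])
```

With a flat loss the cuts land at the end of epochs 6, 12 and 18, because each cut needs patience + 1 = 6 consecutive non-improving epochs. An early-stopping patience of 20 therefore allows three halvings before training gives up, whereas a patience of 5 would stop before the first cut.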

What Early Stopping Monitors

Python
import torch

class MultiMetricEarlyStopping:
    """Early stopping based on any monitored metric."""
    
    def __init__(self, patience: int = 10, mode: str = "min"):
        self.patience = patience
        self.mode = mode
        self.best = None
        self.counter = 0
        self.best_weights = None
    
    def step(
        self,
        metric: float,
        model: torch.nn.Module,
        metric_name: str = "val_loss",
    ) -> bool:
        import copy
        
        improved = (
            self.best is None or
            (self.mode == "min" and metric < self.best - 1e-4) or
            (self.mode == "max" and metric > self.best + 1e-4)
        )
        
        if improved:
            self.best = metric
            self.counter = 0
            self.best_weights = copy.deepcopy(model.state_dict())
            print(f"  {metric_name}={metric:.4f} improved")
        else:
            self.counter += 1
            if self.counter >= self.patience:
                model.load_state_dict(self.best_weights)
                return True
        return False

# Monitor val AUC (maximise) instead of val loss (minimise)
auc_stopper = MultiMetricEarlyStopping(patience=10, mode="max")
loss_stopper = MultiMetricEarlyStopping(patience=10, mode="min")

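# In clinical AI, monitoring AUC (a ranking metric) can be preferable to loss when the goal is
# discrimination (ranking sick above healthy) rather than calibration: val loss also penalises
# overconfident probabilities, so it can rise while AUC is flat or still improving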
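One reason loss and AUC can disagree: AUC depends only on how the scores are ranked, while BCE loss also depends on how confident they are. The sketch below scales a fixed set of logits (a made-up toy validation set) to mimic a model growing more confident without changing its ranking. pairwise_auc is a hypothetical helper written out to keep the example dependency-free; scikit-learn's roc_auc_score is the usual choice in practice.

```python
import torch
import torch.nn.functional as F

def pairwise_auc(scores: torch.Tensor, labels: torch.Tensor) -> float:
    """ROC AUC as the fraction of (positive, negative) pairs ranked correctly; ties count half."""
    pos = scores[labels == 1]
    neg = scores[labels == 0]
    diff = pos.unsqueeze(1) - neg.unsqueeze(0)  # every positive vs every negative
    return ((diff > 0).float() + 0.5 * (diff == 0).float()).mean().item()

# Toy validation set: one positive (-0.5) is ranked below one negative (0.5)
labels = torch.tensor([1.0, 1.0, 1.0, 0.0, 0.0, 0.0])
base_logits = torch.tensor([2.0, 1.0, -0.5, 0.5, -1.0, -2.0])

results = []
for scale in [1.0, 3.0, 10.0]:
    logits = scale * base_logits  # more confident predictions, identical ranking
    loss = F.binary_cross_entropy_with_logits(logits, labels).item()
    auc = pairwise_auc(logits, labels)
    results.append((scale, loss, auc))
    print(f"scale={scale:4.1f}  val_loss={loss:.3f}  val_auc={auc:.3f}")
```

Loss climbs as the logits are scaled up, because the misranked pair is penalised more heavily the more confident the model is, while AUC stays at 8/9 since multiplying by a positive constant leaves the ranking unchanged. A loss-based stopper would read this as degradation; an AUC-based one would not.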

Interview Answer

"Early stopping halts training when the validation loss (or another validation metric) stops improving, then restores the weights from the best validation epoch. It limits how much training noise the model can memorise: trained long enough, the optimiser keeps lowering training loss by fitting that noise, and validation performance degrades. Implementation: track the best validation metric, increment a patience counter whenever an epoch fails to improve on it by at least min_delta, and stop when the counter reaches the patience threshold (10–20 epochs is typical). Best weights must be explicitly saved and restored; plain PyTorch does not do this for you. Early stopping is one of the most broadly applicable regularisation techniques: it needs no architecture changes, has few hyperparameters (mainly patience and min_delta), and saves compute by ending training once validation performance stops improving. In combination with ReduceLROnPlateau, set the early-stopping patience (e.g. 20) higher than the LR-reduction patience (5–10), so the learning rate is reduced before training gives up. For transformers with warmup + cosine decay schedules, early stopping is less critical because the schedule already controls training length."