Learnixo

Deep Learning for AI Interviews · Lesson 10 of 56

Overfitting in Deep Networks

What Overfitting Is

A model overfits when it memorises the training data
instead of learning general patterns.

Signs:
  Training loss:     very low (0.01)
  Validation loss:   much higher (0.35)
  
  Training accuracy:    98%
  Validation accuracy:  71%

The gap between training and validation performance is the overfitting signal.

Analogy: a student who memorises past exam questions.
  They get 100% on practice papers but fail the real exam.
  They learned the answers, not the material.

Why Deep Networks Overfit

A network with millions of parameters can memorise a dataset with thousands of examples.

Extreme case: given enough parameters, a network can fit any random labelling of a training set.
It becomes a lookup table: perfect on the examples it has seen, chance-level on new ones.
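The following minimal NumPy sketch makes the lookup-table idea concrete. A 1-nearest-neighbour rule stands in for a network that has fully memorised its training set. This is an illustrative substitute: no network is trained here. The labels are pure coin flips, so there is no pattern to learn.

```python
import numpy as np

rng = np.random.default_rng(0)

# Random inputs whose labels carry no signal at all (coin flips)
X_train = rng.normal(size=(200, 10))
y_train = rng.integers(0, 2, size=200)
X_test = rng.normal(size=(1000, 10))
y_test = rng.integers(0, 2, size=1000)

def lookup_predict(X_query: np.ndarray) -> np.ndarray:
    """Return the stored label of the closest training input (1-nearest neighbour)."""
    # Pairwise squared distances, shape (n_query, n_train)
    d2 = ((X_query[:, None, :] - X_train[None, :, :]) ** 2).sum(axis=-1)
    return y_train[d2.argmin(axis=1)]

train_acc = (lookup_predict(X_train) == y_train).mean()
test_acc = (lookup_predict(X_test) == y_test).mean()
print(f"train accuracy: {train_acc:.2f}, test accuracy: {test_acc:.2f}")
```

Training accuracy is exactly 1.0 because every training point is its own nearest neighbour. Test accuracy should hover around 0.5, which is chance level.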

Parameter count vs. dataset size:
  ResNet-50: 25M parameters
  CIFAR-10 training set: 50K examples
  Ratio: 500 parameters per training example → enough capacity to memorise the training set, so regularisation is essential
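The same ratio can be worked out for any dense network by counting parameters layer by layer. The small pure-Python sketch below assumes two things: the 50 → 256 → 128 → 1 MLP with BatchNorm used later in this lesson, and a hypothetical 5,000-example training set.

```python
def linear_params(d_in: int, d_out: int) -> int:
    # nn.Linear(d_in, d_out): weight matrix plus bias vector
    return d_in * d_out + d_out

def batchnorm_params(d: int) -> int:
    # Learnable scale and shift per feature; running mean/var are buffers, not parameters
    return 2 * d

n_params = (
    linear_params(50, 256) + batchnorm_params(256)
    + linear_params(256, 128) + batchnorm_params(128)
    + linear_params(128, 1)
)
n_examples = 5_000  # hypothetical training-set size

print(f"MLP parameters: {n_params}")                            # → MLP parameters: 46849
print(f"Parameters per example: {n_params / n_examples:.1f}")  # → Parameters per example: 9.4
print(f"ResNet-50 on CIFAR-10: {25_000_000 / 50_000:.0f}")     # → ResNet-50 on CIFAR-10: 500
```

In PyTorch the same count comes from `sum(p.numel() for p in model.parameters())`, which should also give 46849 for that MLP.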

Capacity vs. data:
  Small dataset + large network = high variance → overfitting
  Large dataset + large network = can generalise
  Small dataset + small network = may underfit (high bias)
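The capacity-vs-data trade-off shows up even outside deep learning. The NumPy sketch below uses polynomial degree as a stand-in for network capacity. This is an illustrative substitute, not a neural network. It fits polynomials to 10 noisy samples of sin(πx) and evaluates them on a dense grid, which plays the role of unseen data.

```python
import numpy as np

rng = np.random.default_rng(42)

def true_fn(x: np.ndarray) -> np.ndarray:
    return np.sin(np.pi * x)

# Small dataset: 10 noisy training points; a dense grid stands in for unseen data
x_train = np.linspace(-1, 1, 10)
y_train = true_fn(x_train) + rng.normal(scale=0.2, size=x_train.shape)
x_test = np.linspace(-1, 1, 200)
y_test = true_fn(x_test)

def mse(y_hat: np.ndarray, y: np.ndarray) -> float:
    return float(np.mean((y_hat - y) ** 2))

results = {}
# Low, moderate and high capacity; degree 9 has 10 coefficients for 10 points
for degree in (1, 3, 9):
    coeffs = np.polyfit(x_train, y_train, degree)
    results[degree] = (
        mse(np.polyval(coeffs, x_train), y_train),  # training error
        mse(np.polyval(coeffs, x_test), y_test),    # error on unseen inputs
    )
    print(f"degree {degree}: train MSE {results[degree][0]:.4f}, test MSE {results[degree][1]:.4f}")
```

The degree-9 polynomial interpolates all 10 training points, so its training error is essentially zero while its test error is not. Degree 1 typically underfits, with high error on both.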

Detecting Overfitting: Training Curves

Python
import matplotlib.pyplot as plt

def plot_training_curves(
    train_losses: list[float],
    val_losses: list[float],
    title: str = "Training vs Validation Loss",
) -> None:
    epochs = range(1, len(train_losses) + 1)
    plt.figure(figsize=(8, 5))
    plt.plot(epochs, train_losses, "b-", label="Training loss")
    plt.plot(epochs, val_losses, "r-", label="Validation loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.show()  # display the figure (use plt.savefig(...) when running headless)

# Overfitting pattern:
# - Both losses decrease at first
# - Training loss keeps falling
# - Validation loss reaches a minimum, then RISES (overfitting begins)
#
# loss
#  |\  \
#  | \   \                     ____/‾‾‾‾  Validation (rises after epoch 20: overfitting)
#  |  \    \_________________/
#  |   \
#  |    \________________________________  Training
#  └───────────────────────────────────────── epoch
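The curve pattern can also be read off programmatically. The minimal pure-Python sketch below uses hypothetical loss values and a hypothetical helper, `analyse_curves`. The helper finds the epoch with the lowest validation loss, which is the checkpoint worth keeping. It then checks whether training loss kept falling after that epoch while validation loss rose.

```python
def analyse_curves(train_losses: list[float], val_losses: list[float]) -> dict:
    # 0-based index of the lowest validation loss: the best checkpoint
    best = min(range(len(val_losses)), key=val_losses.__getitem__)
    last = len(val_losses) - 1
    # Overfitting signature: since the best epoch, training loss fell further
    # while validation loss went back up
    overfitting = (
        last > best
        and train_losses[last] < train_losses[best]
        and val_losses[last] > val_losses[best]
    )
    return {
        "best_epoch": best + 1,  # 1-based, like the plot's x-axis
        "gap_at_best": val_losses[best] - train_losses[best],
        "overfitting_after_best": overfitting,
    }

# Hypothetical loss curves over eight epochs
train = [1.00, 0.70, 0.50, 0.35, 0.25, 0.18, 0.12, 0.08]
val = [1.05, 0.80, 0.62, 0.55, 0.53, 0.56, 0.61, 0.68]
print(analyse_curves(train, val))
```

This check compares only the best and final epochs, so it tolerates some epoch-to-epoch noise in real curves.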

Monitoring Overfitting Automatically

Python
class OverfittingMonitor:
    def __init__(self, patience: int = 10, delta: float = 0.001):
        self.patience = patience
        self.delta = delta
        self.best_val_loss = float("inf")
        self.epochs_without_improvement = 0
        self.overfitting_threshold = 0.10  # absolute loss gap (val - train), not a percentage
    
    def update(self, train_loss: float, val_loss: float) -> dict:
        gap = val_loss - train_loss
        improved = val_loss < (self.best_val_loss - self.delta)
        
        if improved:
            self.best_val_loss = val_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        
        return {
            "train_loss": train_loss,
            "val_loss": val_loss,
            "gap": gap,
            "is_overfitting": gap > self.overfitting_threshold,
            "should_stop": self.epochs_without_improvement >= self.patience,
        }

Causes of Overfitting

1. Too many parameters relative to training data
   Fix: reduce model size, get more data

2. Training too long (too many epochs)
   Fix: early stopping

3. No regularisation
   Fix: dropout, weight decay, batch normalisation

4. Data leakage
   Test information bleeds into training
   Fix: strict data pipeline review

5. Small, unrepresentative training set
   Fix: data augmentation, collect more data, transfer learning

6. Label noise
   Training labels are wrong → model memorises wrong labels
   Fix: label cleaning, robust loss functions
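Data leakage often hides in preprocessing. The minimal NumPy sketch below assumes standardisation as the preprocessing step and uses synthetic data. The leaky version computes statistics over every row, so validation data influences the features the model trains on. The leak-free version splits first and fits the statistics on the training split only.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(loc=5.0, scale=2.0, size=(1000, 3))  # synthetic features

# Split FIRST, so nothing about the validation rows can influence preprocessing
idx = rng.permutation(len(X))
train_idx, val_idx = idx[:800], idx[800:]
X_train, X_val = X[train_idx], X[val_idx]

# Leaky version (avoid): statistics computed on all rows, validation included
leaky_mean, leaky_std = X.mean(axis=0), X.std(axis=0)

# Leak-free version: statistics come from the training split only...
mean, std = X_train.mean(axis=0), X_train.std(axis=0)
# ...and are then applied unchanged to the validation split
X_train_scaled = (X_train - mean) / std
X_val_scaled = (X_val - mean) / std
```

The same rule applies to any fitted preprocessing step, such as vocabularies, imputation values or feature selection: fit on the training split, then apply to the rest.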

Primary Overfitting Prevention Techniques

Technique           | How it works                                             | When to use
--------------------|----------------------------------------------------------|------------------------------
Dropout             | Randomly zeroes neurons during training                  | MLPs, Transformers
Weight decay (L2)   | Penalises large weights                                  | Almost always
Batch normalisation | Normalises activations per mini-batch (mild regulariser) | CNNs, MLPs
Early stopping      | Stops training when val loss stops improving             | Always; essentially free
Data augmentation   | Artificially increases training-data diversity           | Images, time series
Transfer learning   | Pre-trained features → less labelled data needed         | When labelled data is scarce
Reduce model size   | Fewer parameters → less capacity                         | When nothing else works
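Dropout's "randomly zero neurons" has one detail that interviewers often probe: scaling. The NumPy sketch below implements inverted dropout, the behaviour PyTorch documents for nn.Dropout. During training, surviving activations are scaled by 1 / (1 - p); at evaluation time, the layer is the identity. The `dropout` function itself is a hypothetical helper.

```python
import numpy as np

def dropout(x: np.ndarray, p: float, training: bool, rng: np.random.Generator) -> np.ndarray:
    """Inverted dropout: zero each unit with probability p during training and
    scale survivors by 1 / (1 - p) so the expected activation is unchanged."""
    if not training or p == 0.0:
        return x  # identity at evaluation time
    keep = rng.random(x.shape) >= p  # True with probability 1 - p
    return x * keep / (1.0 - p)

rng = np.random.default_rng(0)
x = np.ones((4, 8))
print(dropout(x, p=0.3, training=True, rng=rng))   # entries are 0 or 1/0.7
print(dropout(x, p=0.3, training=False, rng=rng))  # unchanged
```

Because the scaling happens during training, no correction is needed at inference time.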

Quick Implementation: All at Once

Python
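import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset

class RegularisedMLP(nn.Module):
    def __init__(self, d_in: int, d_out: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_in, 256),
            nn.BatchNorm1d(256),    # normalise activations
            nn.ReLU(),
            nn.Dropout(0.3),        # randomly drop 30% of neurons (training only)
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, d_out),
        )

    def forward(self, x):
        return self.net(x)

loss_fn = nn.MSELoss()  # assumed regression loss for the single output

def train(model, loader, optimizer) -> None:
    model.train()  # dropout on, BatchNorm uses batch statistics
    for x, y in loader:
        optimizer.zero_grad()
        loss_fn(model(x), y).backward()
        optimizer.step()

@torch.no_grad()
def evaluate(model, loader) -> float:
    model.eval()  # dropout off, BatchNorm uses running statistics
    total, n = 0.0, 0
    for x, y in loader:
        total += loss_fn(model(x), y).item() * len(x)
        n += len(x)
    return total / n  # mean validation loss per example

# Placeholder synthetic data: swap in your real train/validation split
X, y = torch.randn(1000, 50), torch.randn(1000, 1)
train_loader = DataLoader(TensorDataset(X[:800], y[:800]), batch_size=64,
                          shuffle=True, drop_last=True)  # avoid size-1 BatchNorm batches
val_loader = DataLoader(TensorDataset(X[800:], y[800:]), batch_size=64)

model = RegularisedMLP(50, 1)

# Weight decay in optimiser. Adam's weight_decay adds an L2 penalty term (wd * w)
# to the gradient; torch.optim.AdamW applies decoupled weight decay instead.
optimizer = Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

# Early stopping: keep the checkpoint with the best validation loss
best_val = float("inf")
patience, counter = 10, 0

for epoch in range(200):
    train(model, train_loader, optimizer)
    val_loss = evaluate(model, val_loader)

    if val_loss < best_val - 1e-4:
        best_val = val_loss
        torch.save(model.state_dict(), "best_model.pt")
        counter = 0
    else:
        counter += 1
        if counter >= patience:
            print(f"Early stopping at epoch {epoch}")
            break

# Restore the best weights, not the last ones
model.load_state_dict(torch.load("best_model.pt"))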
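To see what an L2 penalty does to the weights in isolation, consider the closed-form ridge regression sketch below in NumPy. It uses a linear model rather than the MLP above, because the L2-penalised linear solution has a closed form. The connection to the optimiser is that PyTorch's `Adam(..., weight_decay=wd)` adds `wd * w` to each gradient, which is the gradient of the penalty (wd / 2) * ||w||². The dataset and the `ridge` helper are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(30, 20))            # few examples relative to features
y = X[:, 0] + 0.5 * rng.normal(size=30)  # only the first feature carries signal

def ridge(X: np.ndarray, y: np.ndarray, lam: float) -> np.ndarray:
    """Minimise ||Xw - y||^2 + lam * ||w||^2 via the normal equations."""
    d = X.shape[1]
    return np.linalg.solve(X.T @ X + lam * np.eye(d), X.T @ y)

for lam in (0.0, 1.0, 10.0, 100.0):
    w = ridge(X, y, lam)
    print(f"lambda={lam:>6}: ||w|| = {np.linalg.norm(w):.3f}")
```

The weight norm shrinks as the penalty grows, which is the capacity-limiting effect weight decay relies on.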

Interview Answer

"Overfitting occurs when a model memorises the training data instead of learning general patterns, evidenced by a large gap between training and validation loss. Deep networks are prone to it because they often have far more parameters than training examples, which gives them the capacity to memorise. Detection: monitor training vs validation curves; validation loss starts rising while training loss keeps falling. Prevention hierarchy: always use early stopping and weight decay (L2 regularisation), since they cost almost nothing. Add dropout for MLPs and Transformers. Use data augmentation if the training set is small. If nothing helps, reduce model capacity (fewer layers/neurons) or collect more labelled data."