Learnixo
AI Systems | Beginner

Overfitting in Neural Networks

What overfitting is, how to detect it, why deep networks are prone to it, and the primary techniques to prevent it.

Asma Hafeez Khan | May 21, 2026 | 4 min read
Tags: Deep Learning, Overfitting, Generalisation, Regularisation, Interview

What Overfitting Is

A model overfits when it memorises the training data
instead of learning general patterns.

Typical signs (illustrative values):
  Training loss:     very low (0.01)
  Validation loss:   much higher (0.35)
  
  Training accuracy:    98%
  Validation accuracy:  71%

The gap between training and validation performance is the overfitting signal.
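
As a small worked example, the gap can be computed directly from the illustrative numbers above. `generalisation_gap` is a hypothetical helper name used only for this sketch, not a function from any library.

```python
def generalisation_gap(train_metric: float, val_metric: float) -> float:
    """Return the absolute difference between a training and a validation metric."""
    return abs(val_metric - train_metric)


# Values taken from the example above.
loss_gap = generalisation_gap(0.01, 0.35)
acc_gap = generalisation_gap(0.98, 0.71)

print(f"loss gap: {loss_gap:.2f}")          # loss gap: 0.34
print(f"accuracy gap: {acc_gap:.2f}")       # accuracy gap: 0.27
```

How large a gap counts as a problem depends on the task and the loss scale, so any threshold should be treated as a heuristic.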

Analogy: a student who memorises past exam questions.
  They get 100% on practice papers but fail the real exam.
  They learned the answers, not the material.

Why Deep Networks Overfit

A network with millions of parameters can memorise a dataset with thousands of examples.

Extreme case: a network can memorise any random labelling of any dataset
if it has enough parameters — it becomes a lookup table.
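
The lookup-table behaviour can be illustrated without a neural network at all. The sketch below is a stand-in, not a network: it stores every training point and answers with the label of the nearest stored point (1-nearest-neighbour). The labels are coin flips, so there is nothing to learn, yet training accuracy is perfect. All function names here are hypothetical.

```python
import random

random.seed(0)


def make_random_dataset(n: int, dim: int = 5):
    """Random points in the unit cube, each with a coin-flip label."""
    return [([random.random() for _ in range(dim)], random.randint(0, 1)) for _ in range(n)]


def nearest_label(memory, x):
    """Return the label of the stored point closest to x (a pure lookup)."""
    closest = min(memory, key=lambda item: sum((a - b) ** 2 for a, b in zip(item[0], x)))
    return closest[1]


def accuracy(memory, data):
    """Fraction of examples in data whose label matches the lookup result."""
    return sum(nearest_label(memory, x) == y for x, y in data) / len(data)


train_set = make_random_dataset(200)
val_set = make_random_dataset(200)

train_acc = accuracy(train_set, train_set)  # each point finds itself at distance 0
val_acc = accuracy(train_set, val_set)      # labels are noise, so expect roughly chance level
print(f"train accuracy: {train_acc:.2f}, validation accuracy: {val_acc:.2f}")
```

Perfect training accuracy alongside chance-level validation accuracy is the extreme form of the train/validation gap.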

Parameter count vs. dataset size:
  ResNet-50: 25M parameters
  CIFAR-10 training set: 50K examples
  Ratio: 500 parameters per training example → enough capacity to memorise the training set, so overfitting is likely without regularisation

Capacity vs. data:
  Small dataset + large network = high variance → overfitting
  Large dataset + large network = can generalise
  Small dataset + small network = may underfit (high bias)

Detecting Overfitting: Training Curves

Python
import matplotlib.pyplot as plt

def plot_training_curves(
    train_losses: list[float],
    val_losses: list[float],
    title: str = "Training vs Validation Loss",
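) -> None:
    """Plot per-epoch training and validation loss on shared axes."""
    epochs = range(1, len(train_losses) + 1)
    plt.figure(figsize=(8, 5))
    plt.plot(epochs, train_losses, "b-", label="Training loss")
    plt.plot(epochs, val_losses, "r-", label="Validation loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.show()  # display the figure; use plt.savefig(...) instead in headless runs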

# Overfitting pattern:
# - Both losses decrease at first
# - Training loss continues falling
# - Validation loss reaches a minimum, then RISES (overfitting begins)
#
# loss
#   \
#    \            Training
#     \__________________
#
#   \.
#     \.          Validation
#       \.___/‾‾‾‾‾‾‾‾‾‾‾   (rises after epoch 20 → overfitting)
#   └──────────────────────── epoch
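
The turning point in the validation curve can also be located programmatically. This is a minimal sketch that assumes a reasonably smooth curve. Real curves are noisy, so smoothing or a patience window is usually needed. `find_overfitting_onset` is a hypothetical helper name.

```python
def find_overfitting_onset(val_losses: list[float]) -> int:
    """Return the 1-based epoch with the lowest validation loss.

    Epochs after this point are where validation loss is no longer improving.
    """
    if not val_losses:
        raise ValueError("val_losses must not be empty")
    best_index = min(range(len(val_losses)), key=val_losses.__getitem__)
    return best_index + 1


# Made-up validation losses that fall, bottom out, then rise.
val_losses = [0.90, 0.60, 0.45, 0.40, 0.42, 0.47, 0.55]
onset = find_overfitting_onset(val_losses)
print(onset)  # 4
```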

Monitoring Overfitting Automatically

Python
class OverfittingMonitor:
    def __init__(self, patience: int = 10, delta: float = 0.001):
        self.patience = patience
        self.delta = delta
        self.best_val_loss = float("inf")
        self.epochs_without_improvement = 0
        self.overfitting_threshold = 0.10  # absolute (val - train) loss gap treated as an overfitting signal; a heuristic
    
    def update(self, train_loss: float, val_loss: float) -> dict:
        gap = val_loss - train_loss
        improved = val_loss < (self.best_val_loss - self.delta)
        
        if improved:
            self.best_val_loss = val_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        
        return {
            "train_loss": train_loss,
            "val_loss": val_loss,
            "gap": gap,
            "is_overfitting": gap > self.overfitting_threshold,
            "should_stop": self.epochs_without_improvement >= self.patience,
        }
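
The monitor above compares the absolute difference `val_loss - train_loss` against its threshold. A relative gap can be easier to interpret across different loss scales. The sketch below shows one way to compute it. `relative_gap` and its `eps` parameter are assumptions for illustration.

```python
def relative_gap(train_loss: float, val_loss: float, eps: float = 1e-8) -> float:
    """Gap between validation and training loss as a fraction of the training loss.

    eps guards against division by zero when the training loss is (near) zero.
    """
    return (val_loss - train_loss) / max(train_loss, eps)


gap = relative_gap(train_loss=0.20, val_loss=0.23)
print(f"relative gap: {gap:.0%}")  # relative gap: 15%
```

A relative gap grows very large once the training loss approaches zero, so it is best read together with the absolute values.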

Causes of Overfitting

1. Too many parameters relative to training data
   Fix: reduce model size, get more data

2. Training too long (too many epochs)
   Fix: early stopping

3. No regularisation
   Fix: dropout, weight decay, batch normalisation

4. Data leakage
   Test information bleeds into training, so evaluation scores look better than they should and overfitting goes unnoticed
   Fix: strict data pipeline review

5. Small, unrepresentative training set
   Fix: data augmentation, collect more data, transfer learning

6. Label noise
   Training labels are wrong → model memorises wrong labels
   Fix: label cleaning, robust loss functions
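
One concrete piece of a data pipeline review (cause 4) is checking that no test example also appears in the training set. The sketch below only catches exact duplicates after light normalisation, and it assumes text examples. Near-duplicates, and leakage through preprocessing statistics, need other checks. `fingerprint` and `find_leaked_examples` are hypothetical helper names.

```python
import hashlib


def fingerprint(example: str) -> str:
    """Hash a normalised example so exact duplicates map to the same key."""
    return hashlib.sha256(example.strip().lower().encode("utf-8")).hexdigest()


def find_leaked_examples(train_examples: list[str], test_examples: list[str]) -> list[str]:
    """Return test examples whose fingerprint also appears in the training set."""
    train_keys = {fingerprint(e) for e in train_examples}
    return [e for e in test_examples if fingerprint(e) in train_keys]


# Made-up data: the second training sentence reappears in the test set
# with different casing and trailing whitespace.
train_examples = ["the cat sat", "a dog barked", "birds fly south"]
test_examples = ["A dog barked ", "fish swim"]

leaked = find_leaked_examples(train_examples, test_examples)
print(leaked)  # ['A dog barked ']
```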

Primary Overfitting Prevention Techniques

Technique           | How it works                                      | When to use
--------------------|---------------------------------------------------|------------------------------
Dropout             | Randomly zero neurons during training             | MLPs, Transformers
Weight decay (L2)   | Penalise large weights                            | Almost always
Batch normalisation | Normalise activations per batch; mild regulariser | CNNs, MLPs
Early stopping      | Stop when val loss stops improving                | Always (free regularisation)
Data augmentation   | Artificially increase training data               | Images, time series
Transfer learning   | Pre-trained features → less data needed           | When labelled data is scarce
Reduce model size   | Fewer parameters → less capacity                  | When nothing else works
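
To make the dropout row concrete, here is a minimal plain-Python sketch of inverted dropout, the scaling convention that `torch.nn.Dropout` also follows. The function name and signature are illustrative assumptions. In practice the framework's layer should be used.

```python
import random


def dropout(activations: list[float], p: float, training: bool, rng: random.Random) -> list[float]:
    """Inverted dropout: zero each value with probability p, scale survivors by 1 / (1 - p).

    Scaling during training keeps the expected activation unchanged,
    so nothing needs to be rescaled at inference time.
    """
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    if not training or p == 0.0:
        return list(activations)  # inference: pass values through unchanged
    keep_scale = 1.0 / (1.0 - p)
    return [0.0 if rng.random() < p else a * keep_scale for a in activations]


rng = random.Random(0)
hidden = [1.0] * 10
train_out = dropout(hidden, p=0.3, training=True, rng=rng)   # some zeros, survivors scaled up
eval_out = dropout(hidden, p=0.3, training=False, rng=rng)   # identical to the input
print(train_out)
print(eval_out)
```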

Quick Implementation: All at Once

Python
import torch
import torch.nn as nn
from torch.optim import Adam

class RegularisedMLP(nn.Module):
    def __init__(self, d_in: int, d_out: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_in, 256),
            nn.BatchNorm1d(256),    # normalise activations
            nn.ReLU(),
            nn.Dropout(0.3),        # randomly drop 30% of neurons
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, d_out),
        )
    
    def forward(self, x):
        return self.net(x)

model = RegularisedMLP(50, 1)

# Weight decay in optimiser (L2 regularisation)
optimizer = Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

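# Early stopping
# NOTE: train(), evaluate(), train_loader and val_loader are not defined here;
# they are assumed to exist elsewhere. evaluate() should put the model in
# eval mode and return the mean validation loss as a float.
best_val = float("inf")
patience, counter = 10, 0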

for epoch in range(200):
    train(model, train_loader, optimizer)
    val_loss = evaluate(model, val_loader)
    
    if val_loss < best_val - 1e-4:
        best_val = val_loss
        torch.save(model.state_dict(), "best_model.pt")
        counter = 0
    else:
        counter += 1
        if counter >= patience:
            print(f"Early stopping at epoch {epoch}")
            break

model.load_state_dict(torch.load("best_model.pt"))
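
The loop above calls `train` and `evaluate` without defining them. One possible sketch of these helpers follows. It assumes a regression task with mean squared error, and loaders that yield `(inputs, targets)` batches with targets of shape `(batch, 1)`. The loss choice and return values are assumptions, not something the article specifies.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def train(model: nn.Module, loader: DataLoader, optimizer: torch.optim.Optimizer) -> float:
    """Run one training epoch; return the mean training loss per example."""
    model.train()  # dropout active, batch norm uses batch statistics
    loss_fn = nn.MSELoss()  # assumption: regression with a real-valued target
    total_loss, n_examples = 0.0, 0
    for inputs, targets in loader:
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * inputs.size(0)
        n_examples += inputs.size(0)
    return total_loss / n_examples


@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader) -> float:
    """Return the mean validation loss per example, without updating weights."""
    model.eval()  # dropout off, batch norm uses running statistics
    loss_fn = nn.MSELoss()
    total_loss, n_examples = 0.0, 0
    for inputs, targets in loader:
        loss = loss_fn(model(inputs), targets)
        total_loss += loss.item() * inputs.size(0)
        n_examples += inputs.size(0)
    return total_loss / n_examples
```

Switching between `model.train()` and `model.eval()` matters here because both dropout and batch normalisation behave differently in the two modes. Evaluating in training mode would make the validation loss noisy and misleading.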

Interview Answer

"Overfitting occurs when a model memorises training data instead of learning general patterns — evidenced by a large gap between training and validation loss. Deep networks are prone to it because they have more parameters than training examples, giving them the capacity to memorise. Detection: monitor training vs validation curves; the validation loss starts rising while training loss continues falling. Prevention hierarchy: always use early stopping and weight decay (L2 regularisation) — they're free. Add dropout for MLPs and Transformers. Use data augmentation if the training set is small. If nothing helps, reduce model capacity (fewer layers/neurons) or collect more labelled data."
