Overfitting in Neural Networks
What overfitting is, how to detect it, why deep networks are prone to it, and the primary techniques to prevent it.
What Overfitting Is
A model overfits when it memorises the training data
instead of learning general patterns.
Typical signs (illustrative numbers):
Training loss: very low (0.01)
Validation loss: much higher (0.35)
Training accuracy: 98%
Validation accuracy: 71%
The gap between training and validation performance is the overfitting signal.
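The gap can be read directly off the example numbers above. A minimal sketch; `generalisation_gap` is a hypothetical helper, and an absolute difference is only one way to express the gap (a ratio works too):

```python
def generalisation_gap(train_metric: float, val_metric: float) -> float:
    """Absolute difference between a training metric and its validation counterpart."""
    return abs(train_metric - val_metric)


# Example figures from the list above
loss_gap = generalisation_gap(0.01, 0.35)
acc_gap = generalisation_gap(0.98, 0.71)
print(f"loss gap: {loss_gap:.2f}, accuracy gap: {acc_gap:.0%}")  # loss gap: 0.34, accuracy gap: 27%
```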
Analogy: a student who memorises past exam questions.
They get 100% on practice papers but fail the real exam.
They learned the answers, not the material.
Why Deep Networks Overfit
A network with millions of parameters can memorise a dataset with thousands of examples.
Extreme case: with enough parameters, a network can memorise even a completely random labelling of its training set;
it effectively becomes a lookup table.
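The lookup-table idea can be demonstrated without a neural network at all. The sketch below swaps in a 1-nearest-neighbour classifier, an extreme memoriser that simply stores every training point, and gives it random inputs with coin-flip labels. The data sizes, seed, and helper names are illustrative assumptions:

```python
import random


def nearest_neighbour_predict(train_x, train_y, x):
    """1-NN: return the label of the closest stored training point (pure lookup)."""
    def sq_dist(i):
        return sum((a - b) ** 2 for a, b in zip(train_x[i], x))
    return train_y[min(range(len(train_x)), key=sq_dist)]


def accuracy(train_x, train_y, xs, ys) -> float:
    hits = sum(nearest_neighbour_predict(train_x, train_y, x) == y for x, y in zip(xs, ys))
    return hits / len(xs)


rng = random.Random(0)


def random_points(n: int, dim: int = 10) -> list[list[float]]:
    return [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(n)]


# Inputs are random noise and labels are coin flips: there is no pattern to learn
train_x, test_x = random_points(200), random_points(200)
train_y = [rng.randint(0, 1) for _ in train_x]
test_y = [rng.randint(0, 1) for _ in test_x]

print("train accuracy:", accuracy(train_x, train_y, train_x, train_y))  # 1.0: each point is its own nearest neighbour
print("test accuracy:", accuracy(train_x, train_y, test_x, test_y))     # roughly 0.5, i.e. chance level
```

Perfect training accuracy alongside chance-level test accuracy is the train/validation gap pushed to its limit: there was nothing general to learn, so everything the model "knows" is memorised.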
Parameter count vs. dataset size:
ResNet-50: 25M parameters
CIFAR-10 training set: 50K examples
Ratio: 500 parameters per training example → high risk of overfitting without regularisation
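The same ratio can be worked out for any fully connected network. A minimal sketch, assuming plain `Linear` layers (weights plus biases) and, optionally, a BatchNorm after each hidden layer; the layer sizes match the `RegularisedMLP` later in this article, and the 1,000-example dataset is hypothetical:

```python
def count_mlp_params(layer_sizes: list[int], batch_norm: bool = False) -> int:
    """Trainable parameters of a fully connected network with the given layer widths.

    Each Linear(n_in, n_out) contributes n_in * n_out weights + n_out biases.
    With batch_norm=True, every hidden layer also gets a BatchNorm with
    2 * n_out trainable parameters (scale and shift).
    """
    total = 0
    n_layers = len(layer_sizes) - 1
    for i, (n_in, n_out) in enumerate(zip(layer_sizes, layer_sizes[1:])):
        total += n_in * n_out + n_out
        if batch_norm and i < n_layers - 1:  # no BatchNorm after the output layer
            total += 2 * n_out
    return total


plain = count_mlp_params([50, 256, 128, 1])
with_bn = count_mlp_params([50, 256, 128, 1], batch_norm=True)
print(plain, with_bn)        # 46081 46849
print(with_bn / 1_000)       # parameters per example for a hypothetical 1,000-example dataset
print(25_000_000 / 50_000)   # 500.0, the ResNet-50 / CIFAR-10 ratio quoted above
```

The BatchNorm count agrees with summing `p.numel()` over `model.parameters()` in PyTorch, which includes BatchNorm's scale and shift but not its running statistics (those are buffers, not parameters).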
Capacity vs. data:
Small dataset + large network = high variance → overfitting
Large dataset + large network = can generalise
Small dataset + small network = may underfit (high bias)
Detecting Overfitting: Training Curves
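Before plotting curves, the capacity-vs-data trade-off and the train/validation comparison behind it can be seen in a toy curve-fitting experiment. This is a stand-in under stated assumptions, polynomial fitting rather than a neural network: ten noisy points from the line y = 2x + 1 are fitted once with a 2-parameter straight line and once with a degree-9 polynomial that passes through every point. The noise level, seed, and helper names are illustrative:

```python
import random


def fit_line(xs: list[float], ys: list[float]):
    """Ordinary least-squares straight line: 2 parameters (low capacity)."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    slope = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys)) / sum(
        (x - mean_x) ** 2 for x in xs
    )
    intercept = mean_y - slope * mean_x
    return lambda x: slope * x + intercept


def fit_interpolating_polynomial(xs: list[float], ys: list[float]):
    """Lagrange polynomial through every training point: as many coefficients as examples."""
    def predict(x: float) -> float:
        total = 0.0
        for i, (xi, yi) in enumerate(zip(xs, ys)):
            term = yi
            for j, xj in enumerate(xs):
                if j != i:
                    term *= (x - xj) / (xi - xj)
            total += term
        return total
    return predict


def mse(model, xs: list[float], ys: list[float]) -> float:
    return sum((model(x) - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


rng = random.Random(42)


def noisy_target(x: float) -> float:
    """True relationship y = 2x + 1 plus Gaussian noise (sd 0.3)."""
    return 2 * x + 1 + rng.gauss(0, 0.3)


train_x = [i / 9 for i in range(10)]       # 10 training points spread over [0, 1]
val_x = [(i + 0.5) / 9 for i in range(9)]  # 9 held-out points between them
train_y = [noisy_target(x) for x in train_x]
val_y = [noisy_target(x) for x in val_x]

models = {
    "line (2 params)": fit_line(train_x, train_y),
    "degree-9 polynomial (10 params)": fit_interpolating_polynomial(train_x, train_y),
}
for name, model in models.items():
    print(f"{name}: train MSE {mse(model, train_x, train_y):.4f}, "
          f"val MSE {mse(model, val_x, val_y):.4f}")
```

The polynomial's training error is zero by construction, since a polynomial with as many coefficients as data points can pass through all of them. On a typical seed its validation error comes out well above the line's, which is exactly the train/validation gap the curves in this section are meant to reveal.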
import matplotlib.pyplot as plt


def plot_training_curves(
    train_losses: list[float],
    val_losses: list[float],
    title: str = "Training vs Validation Loss",
) -> None:
    epochs = range(1, len(train_losses) + 1)
    plt.figure(figsize=(8, 5))
    plt.plot(epochs, train_losses, "b-", label="Training loss")
    plt.plot(epochs, val_losses, "r-", label="Validation loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.show()  # in a script without a display, use plt.savefig(...) instead


# Overfitting pattern:
# - Both losses decrease at first
# - Training loss keeps falling, then flattens out
# - Validation loss reaches a minimum (e.g. around epoch 20), then RISES: overfitting begins
Monitoring Overfitting Automatically
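For a finished run, the simplest automatic check is to locate the validation minimum: if it falls before the last epoch, validation loss stopped improving while training continued. A minimal sketch in plain Python; the function names and loss values are illustrative:

```python
def best_epoch(val_losses: list[float]) -> int:
    """1-based epoch with the lowest validation loss (earliest one on ties)."""
    return min(range(len(val_losses)), key=val_losses.__getitem__) + 1


def past_validation_minimum(val_losses: list[float]) -> bool:
    """True if the run continued past the validation minimum (it is not the final epoch)."""
    return best_epoch(val_losses) < len(val_losses)


train_losses = [1.00, 0.60, 0.40, 0.30, 0.20, 0.15, 0.10]
val_losses = [1.10, 0.70, 0.50, 0.45, 0.48, 0.55, 0.60]

epoch = best_epoch(val_losses)
print(epoch)                                # 4
print(past_validation_minimum(val_losses))  # True
print(f"gap at best epoch: {val_losses[epoch - 1] - train_losses[epoch - 1]:.2f}")  # 0.15
```

The best epoch is also the checkpoint you would restore after early stopping; an online monitor does the same bookkeeping one epoch at a time.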
class OverfittingMonitor:
    """Tracks the train/validation gap and early-stopping patience, one update per epoch."""

    def __init__(self, patience: int = 10, delta: float = 0.001, gap_threshold: float = 0.10):
        self.patience = patience
        self.delta = delta  # minimum drop in val loss that counts as an improvement
        self.best_val_loss = float("inf")
        self.epochs_without_improvement = 0
        # Absolute (val - train) loss gap that flags overfitting. This is in loss units,
        # not a percentage; a sensible value depends on the loss function and task.
        self.overfitting_threshold = gap_threshold

    def update(self, train_loss: float, val_loss: float) -> dict:
        gap = val_loss - train_loss
        improved = val_loss < (self.best_val_loss - self.delta)
        if improved:
            self.best_val_loss = val_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return {
            "train_loss": train_loss,
            "val_loss": val_loss,
            "gap": gap,
            "is_overfitting": gap > self.overfitting_threshold,
            "should_stop": self.epochs_without_improvement >= self.patience,
        }
Causes of Overfitting
1. Too many parameters relative to training data
Fix: reduce model size, get more data
2. Training too long (too many epochs)
Fix: early stopping
3. No regularisation
Fix: dropout, weight decay, batch normalisation
4. Data leakage
Validation or test information bleeds into training, so held-out scores overstate real performance and can mask overfitting
Fix: strict data pipeline review
5. Small, unrepresentative training set
Fix: data augmentation, collect more data, transfer learning
6. Label noise
Some training labels are wrong → the model memorises the wrong labels
Fix: label cleaning, robust loss functions
Primary Overfitting Prevention Techniques
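Two of the techniques in the table below, dropout and weight decay, are small enough to write out by hand. The sketch is plain Python on lists, not how PyTorch implements them internally. It uses inverted dropout, which rescales surviving activations by 1/(1 - p) during training so nothing needs to change at evaluation time (PyTorch's `nn.Dropout` applies the same scaling), and a plain SGD step, for which L2 regularisation and weight decay coincide. Function names are hypothetical:

```python
import random


def inverted_dropout(
    activations: list[float], p: float, training: bool, rng: random.Random
) -> list[float]:
    """Zero each activation with probability p during training and scale survivors by
    1 / (1 - p), keeping the expected value unchanged; pass through at evaluation time."""
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    if not training:
        return list(activations)
    keep = 1.0 - p
    return [a / keep if rng.random() < keep else 0.0 for a in activations]


def sgd_step_with_weight_decay(
    weights: list[float], grads: list[float], lr: float, weight_decay: float
) -> list[float]:
    """One SGD step on loss + (weight_decay / 2) * sum(w ** 2).
    The penalty adds weight_decay * w to each gradient, pulling every weight toward zero."""
    return [w - lr * (g + weight_decay * w) for w, g in zip(weights, grads)]


rng = random.Random(0)
print(inverted_dropout([1.0] * 10, p=0.3, training=True, rng=rng))   # each entry is 0.0 or 1/0.7 ≈ 1.43
print(inverted_dropout([1.0] * 10, p=0.3, training=False, rng=rng))  # unchanged
# With a zero data gradient, each weight shrinks by a factor of (1 - lr * weight_decay)
print(sgd_step_with_weight_decay([1.0, -2.0], [0.0, 0.0], lr=0.1, weight_decay=0.01))
```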
Technique           | How it works                                          | When to use
--------------------|-------------------------------------------------------|-----------------------------
Dropout             | Randomly zero activations during training             | MLPs, Transformers
Weight decay (L2)   | Penalise large weights                                | Almost always
Batch normalisation | Normalise activations; batch noise regularises mildly | CNNs, MLPs
Early stopping      | Stop when val loss stops improving                    | Always (free regularisation)
Data augmentation   | Artificially increase training data                   | Images, time series
Transfer learning   | Pre-trained features → less data needed               | When labelled data is scarce
Reduce model size   | Fewer parameters → less capacity                      | When nothing else works
Quick Implementation: All at Once
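import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset


class RegularisedMLP(nn.Module):
    def __init__(self, d_in: int, d_out: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_in, 256),
            nn.BatchNorm1d(256),  # normalise activations
            nn.ReLU(),
            nn.Dropout(0.3),  # randomly zero 30% of activations (training mode only)
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, d_out),
        )

    def forward(self, x):
        return self.net(x)


# Minimal helpers, assuming a regression task with MSE loss
def train(model, loader, optimizer, loss_fn=nn.MSELoss()):
    model.train()  # dropout active, BatchNorm uses batch statistics
    for x, y in loader:
        optimizer.zero_grad()
        loss_fn(model(x), y).backward()
        optimizer.step()


@torch.no_grad()
def evaluate(model, loader, loss_fn=nn.MSELoss()) -> float:
    model.eval()  # dropout off, BatchNorm uses running statistics
    total, count = 0.0, 0
    for x, y in loader:
        total += loss_fn(model(x), y).item() * len(x)
        count += len(x)
    return total / count


# Placeholder data (random tensors); replace with a real dataset
X, Y = torch.randn(1000, 50), torch.randn(1000, 1)
train_loader = DataLoader(TensorDataset(X[:800], Y[:800]), batch_size=64, shuffle=True, drop_last=True)
val_loader = DataLoader(TensorDataset(X[800:], Y[800:]), batch_size=64)

model = RegularisedMLP(50, 1)
# Weight decay in the optimiser (L2 regularisation: Adam adds it to the gradient,
# whereas torch.optim.AdamW applies decoupled weight decay)
optimizer = Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
# Early stopping
best_val = float("inf")
patience, counter = 10, 0
for epoch in range(200):
    train(model, train_loader, optimizer)
    val_loss = evaluate(model, val_loader)
    if val_loss < best_val - 1e-4:
        best_val = val_loss
        torch.save(model.state_dict(), "best_model.pt")  # checkpoint the best model so far
        counter = 0
    else:
        counter += 1
        if counter >= patience:
            print(f"Early stopping at epoch {epoch}")
            break
model.load_state_dict(torch.load("best_model.pt"))  # restore the best checkpoint
Interview Answer
"Overfitting occurs when a model memorises training data instead of learning general patterns, evidenced by a large gap between training and validation loss. Deep networks are prone to it because they often have far more parameters than training examples, giving them the capacity to memorise. Detection: monitor training vs validation curves; overfitting shows up when validation loss starts rising while training loss keeps falling. Prevention hierarchy: always use early stopping and weight decay (L2 regularisation), since they cost almost nothing. Add dropout for MLPs and Transformers. Use data augmentation if the training set is small. If nothing helps, reduce model capacity (fewer layers/neurons) or collect more labelled data."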