Deep Learning for AI Interviews · Lesson 27 of 56
The Loss Landscape and Local Minima
The Landscape Metaphor
The loss is a surface over high-dimensional weight space.
Height = loss value.
Training = finding the lowest valley.
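To make the metaphor concrete, here is a minimal sketch on a one-dimensional toy loss with two valleys. The function `toy_loss`, its coefficients, the learning rate, and the starting points are all assumptions chosen for illustration, not taken from a real network. It suggests that plain gradient descent settles in whichever valley it starts near, which is not necessarily the lowest one.

```python
def toy_loss(w: float) -> float:
    """Hypothetical 1D loss with a deep valley near w = -1.3 and a shallow one near w = 1.1."""
    return w**4 - 3 * w**2 + w


def toy_grad(w: float) -> float:
    """Analytic derivative of toy_loss."""
    return 4 * w**3 - 6 * w + 1


def gradient_descent(w0: float, lr: float = 0.01, steps: int = 500) -> float:
    """Run plain gradient descent from w0 and return the final weight."""
    w = w0
    for _ in range(steps):
        w -= lr * toy_grad(w)
    return w


w_left = gradient_descent(-2.0)   # starts on the left slope
w_right = gradient_descent(2.0)   # starts on the right slope
print(f"from -2.0: w = {w_left:.3f}, loss = {toy_loss(w_left):.3f}")
print(f"from +2.0: w = {w_right:.3f}, loss = {toy_loss(w_right):.3f}")
```

Both runs end at a point with zero gradient, but only the left-hand run reaches the lower of the two valleys.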
Features of the landscape:
Global minimum: The absolute lowest point; theoretically best.
  For neural networks: likely unreachable exactly, and fitting it may overfit.
Local minimum: A point lower than all of its immediate neighbours.
  In high dimensions: rare (critical points are usually saddle points instead).
Saddle point: A minimum along some directions, a maximum along others.
  Gradient = 0 but not a minimum.
  Most "stuck" points in deep learning are saddle points.
Plateau: Nearly-zero gradient over a wide region.
  Training slows dramatically. Momentum helps escape.
Sharp minimum: Steep walls around the valley.
  Small weight perturbation → large loss increase.
  Tends to correlate with poor generalisation.
Flat minimum: Gentle slopes around the valley.
  Robust to weight perturbations.
  Tends to correlate with better generalisation and robustness to distribution shift.
Visualising a 2D Slice
import numpy as np
import torch
import torch.nn as nn


def compute_loss_surface(
    model: nn.Module,
    loader,
    criterion: nn.Module,
    resolution: int = 40,
    range_val: float = 1.0,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Compute the loss on a 2D grid spanned by two random directions in weight space.
    Z[i, j] is the mean batch loss at weights + alphas[i] * d1 + betas[j] * d2.
    PCA-based directions give a more meaningful slice.
    """
    # Flatten the current weights into one vector (torch.cat makes a copy)
    weights_flat = torch.cat([p.data.reshape(-1) for p in model.parameters()])
    n_params = weights_flat.numel()
    # Two random directions on the same device and dtype as the weights
    d1 = torch.randn(n_params, device=weights_flat.device, dtype=weights_flat.dtype)
    d2 = torch.randn(n_params, device=weights_flat.device, dtype=weights_flat.dtype)
    # Gram-Schmidt: remove the d1 component from d2, then normalise both
    d2 = d2 - (d2 @ d1) / (d1 @ d1) * d1
    d1 = d1 / d1.norm()
    d2 = d2 / d2.norm()
    alphas = np.linspace(-range_val, range_val, resolution)
    betas = np.linspace(-range_val, range_val, resolution)
    Z = np.zeros((resolution, resolution))

    def set_weights(flat_w: torch.Tensor) -> None:
        # Copy slices of the flat vector back into each parameter tensor
        offset = 0
        for p in model.parameters():
            n = p.numel()
            p.data.copy_(flat_w[offset:offset + n].view(p.shape))
            offset += n

    original_weights = weights_flat.clone()
    was_training = model.training
    model.eval()  # keep dropout and batch-norm statistics fixed while evaluating
    for i, alpha in enumerate(alphas):
        for j, beta in enumerate(betas):
            perturbed = original_weights + float(alpha) * d1 + float(beta) * d2
            set_weights(perturbed)
            loss_val = 0.0
            n_batches = 0
            with torch.no_grad():
                for X, y in loader:
                    loss_val += criterion(model(X).squeeze(), y).item()
                    n_batches += 1
            Z[i, j] = loss_val / max(n_batches, 1)
    # Restore the original weights and training mode
    set_weights(original_weights)
    model.train(was_training)
    return alphas, betas, Z
Sharp vs Flat Minima
import torch
import torch.nn as nn
def sharpness_measure(
    model: nn.Module,
    loader,
    criterion: nn.Module,
    perturbation_scale: float = 0.01,
    n_perturbations: int = 20,
) -> float:
    """
    Estimates sharpness: average loss increase under random weight perturbations.
    Low sharpness → flat minimum, which tends to go with better generalisation.
    """
    # Baseline: mean batch loss at the unperturbed weights
    base_loss = 0.0
    n_batches = 0
    with torch.no_grad():
        for X, y in loader:
            base_loss += criterion(model(X).squeeze(), y).item()
            n_batches += 1
    base_loss /= max(n_batches, 1)
    perturbed_losses = []
    original_params = [p.data.clone() for p in model.parameters()]
    for _ in range(n_perturbations):
        # Add an independent Gaussian perturbation to every parameter
        with torch.no_grad():
            for p in model.parameters():
                p.data += perturbation_scale * torch.randn_like(p.data)
        # Mean batch loss at the perturbed weights
        pert_loss = 0.0
        n_b = 0
        with torch.no_grad():
            for X, y in loader:
                pert_loss += criterion(model(X).squeeze(), y).item()
                n_b += 1
        perturbed_losses.append(pert_loss / max(n_b, 1))
        # Restore the original weights before the next perturbation
        with torch.no_grad():
            for p, orig in zip(model.parameters(), original_params):
                p.data.copy_(orig)
    avg_perturbed = sum(perturbed_losses) / len(perturbed_losses)
    return avg_perturbed - base_loss  # positive = sharp; near-zero = flat
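For intuition about the number this function returns, consider an exactly quadratic loss L(w) = 0.5 * w^T H w at its minimum w = 0. Under Gaussian perturbations with standard deviation sigma, the expected loss increase is (sigma^2 / 2) * tr(H), so the measure grows with total curvature. The sketch below checks this by Monte Carlo on two hypothetical Hessians (`H_sharp`, `H_flat`); the matrices, `sigma`, and the sample count are illustrative assumptions, and real network losses are only approximately quadratic near a minimum.

```python
import numpy as np


def expected_increase(H: np.ndarray, sigma: float, n_samples: int, rng) -> float:
    """Monte Carlo estimate of E[L(sigma * eps)] - L(0) for L(w) = 0.5 * w^T H w."""
    eps = rng.standard_normal((n_samples, H.shape[0]))
    w = sigma * eps                                   # perturbed weights around w* = 0
    losses = 0.5 * np.einsum("ni,ij,nj->n", w, H, w)  # quadratic loss per sample
    return float(losses.mean())                       # L(0) = 0, so the mean is the increase


rng = np.random.default_rng(0)
sigma = 0.01
H_sharp = np.diag([100.0, 100.0])  # hypothetical high-curvature minimum
H_flat = np.diag([1.0, 1.0])       # hypothetical low-curvature minimum

mc_sharp = expected_increase(H_sharp, sigma, 20_000, rng)
mc_flat = expected_increase(H_flat, sigma, 20_000, rng)
theory_sharp = 0.5 * sigma**2 * np.trace(H_sharp)
theory_flat = 0.5 * sigma**2 * np.trace(H_flat)

print(f"sharp: Monte Carlo {mc_sharp:.2e}, theory {theory_sharp:.2e}")
print(f"flat:  Monte Carlo {mc_flat:.2e}, theory {theory_flat:.2e}")
```

The same perturbation scale produces a loss increase roughly proportional to the trace of the Hessian, which is why the measure separates sharp from flat minima.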
# Batch size and flatness (empirical tendencies):
# 1. Larger batch size → sharper minima (less gradient noise)
# 2. Smaller batch size → flatter minima (noisier gradient)
# 3. SAM (Sharpness-Aware Minimisation) explicitly seeks flat minima
SAM: Sharpness-Aware Minimisation
import torch


class SAM(torch.optim.Optimizer):
    """
    Sharpness-Aware Minimisation.
    Each step needs two forward-backward passes:
    1. Perturb weights in the worst-case (steepest-ascent) direction within a rho-ball
    2. Compute the gradient at that perturbed point
    3. Restore the original weights and step the base optimiser with this gradient
    """

    def __init__(self, params, base_optimizer, rho: float = 0.05, **kwargs):
        defaults = {"rho": rho, **kwargs}
        super().__init__(params, defaults)
        # The base optimiser shares the same param_groups, so both see the same tensors
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad: bool = False) -> None:
        """Move weights to the approximate worst-case point in the rho-ball."""
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)
            for p in group["params"]:
                if p.grad is None:
                    continue
                e_w = p.grad * scale
                p.add_(e_w)  # perturb along the normalised gradient
                self.state[p]["e_w"] = e_w  # remember the offset so it can be undone
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad: bool = False) -> None:
        """Restore weights and apply base optimiser step."""
        for group in self.param_groups:
            for p in group["params"]:
                # Undo the perturbation for every parameter that first_step moved
                e_w = self.state[p].pop("e_w", None)
                if e_w is None:
                    continue
                p.sub_(e_w)  # restore
        # The gradients now held by the parameters were computed at the perturbed point
        self.base_optimizer.step()
        if zero_grad:
            self.zero_grad()

    def _grad_norm(self) -> torch.Tensor:
        # Global L2 norm over all parameter gradients
        norms = [
            p.grad.norm(p=2)
            for group in self.param_groups
            for p in group["params"]
            if p.grad is not None
        ]
        return torch.stack(norms).norm(p=2)
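The update that `first_step` and `second_step` implement together can also be written out by hand. The following is a minimal sketch on a hypothetical quadratic loss with an analytic gradient, using plain NumPy rather than the class above; `A`, `rho`, `lr`, and the step count are illustrative assumptions.

```python
import numpy as np

A = np.diag([10.0, 1.0])  # hypothetical curvature matrix: L(w) = 0.5 * w^T A w


def loss(w: np.ndarray) -> float:
    return float(0.5 * w @ A @ w)


def grad(w: np.ndarray) -> np.ndarray:
    return A @ w


def sam_step(w: np.ndarray, rho: float = 0.05, lr: float = 0.05):
    """One SAM update; returns the new weights and the perturbation used."""
    g = grad(w)                                  # pass 1: gradient at the current weights
    eps = rho * g / (np.linalg.norm(g) + 1e-12)  # ascent offset of length rho along the gradient
    g_sam = grad(w + eps)                        # pass 2: gradient at the perturbed weights
    return w - lr * g_sam, eps                   # descend from the ORIGINAL weights


w = np.array([1.0, 1.0])
initial_loss = loss(w)
_, first_eps = sam_step(w)
for _ in range(200):
    w, _ = sam_step(w)
final_loss = loss(w)
print(f"initial loss = {initial_loss:.4f}, final loss = {final_loss:.6f}")
```

Because the perturbation always has length `rho`, the iterate in this toy setting tends to hover in a small neighbourhood of the minimum rather than converging to it exactly.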
# Usage (double forward-backward per step)
# sam = SAM(model.parameters(), torch.optim.SGD, rho=0.05, lr=0.01)
#
# loss = criterion(model(X), y)
# loss.backward()
# sam.first_step(zero_grad=True)
#
# criterion(model(X), y).backward() # second forward pass at perturbed weights
# sam.second_step(zero_grad=True)
Why Deep Networks Don't Get Stuck in Local Minima
Classical intuition: many local minima → gradient descent gets stuck.
Modern understanding: the landscape of deep networks is more forgiving.
1. High-dimensional spaces:
   For a critical point (gradient = 0) to be a local minimum,
   ALL curvature eigenvalues must be positive (the Hessian is positive definite).
   In millions of dimensions, this is extremely unlikely.
   Most critical points are saddle points, not local minima.
2. Similar loss at local minima:
   Empirically: local minima found by SGD tend to have similar loss values.
   Even if you escape to a "better" minimum, the improvement is usually small.
3. Over-parameterisation helps:
   Neural networks with more parameters than data points can fit the data
   through a very large number of equivalent weight configurations
   (for example, permuting the hidden units of a layer leaves the function unchanged).
   Gradient descent finds one of many solutions that fit well.
4. SGD noise:
   Stochastic gradient noise helps escape saddle points
   and biases toward flatter regions of the landscape.
Interview Answer
"The loss landscape for neural networks is a high-dimensional surface over weight space. Critical points where the gradient is zero include saddle points (gradient=0, but curvature differs across dimensions) rather than true local minima — in high dimensions, all eigenvalues of the Hessian being positive is unlikely, so gradient descent rarely gets permanently stuck. Plateau regions (near-zero gradient) slow training and are escaped with momentum. The practically important distinction is between sharp and flat minima: sharp minima have high curvature and generalise poorly — a small weight perturbation causes large loss increases; flat minima are robust to perturbations and correlate with better test performance. Mini-batch SGD's noise biases toward flat minima, which is one reason it often generalises better than full-batch gradient descent. SAM explicitly optimises for flat minima by taking gradient steps at the worst-case perturbation."