Learnixo
AI Systems · Intermediate

The Loss Landscape

Local minima, saddle points, flat regions, and sharp vs flat minima — visualising what gradient descent traverses and why it matters for generalisation.

Asma Hafeez Khan · May 22, 2026 · 6 min read
Tags: Deep Learning, Loss Landscape, Optimisation, Saddle Points, Generalisation, Interview

The Landscape Metaphor

The loss is a surface over high-dimensional weight space.
Position = one setting of all the weights; height = loss value.
Training = descending into a low valley (not necessarily the lowest one).
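A toy version of the metaphor, as a minimal numpy sketch. The two-parameter loss f(w1, w2) = w1² + (w2² - 1)² is a hypothetical stand-in (not taken from a real network), chosen because it has minima at (0, ±1) and a saddle point at (0, 0). Gradient descent walks downhill into a valley; started exactly on the line w2 = 0, it slides into the saddle and stalls there, while a little gradient noise (standing in for SGD noise) is enough to push it off.

```python
import numpy as np

def toy_loss(w: np.ndarray) -> float:
    """Hypothetical 2-parameter loss: minima at (0, +1) and (0, -1), saddle at (0, 0)."""
    return float(w[0] ** 2 + (w[1] ** 2 - 1.0) ** 2)

def toy_grad(w: np.ndarray) -> np.ndarray:
    """Analytic gradient of toy_loss."""
    return np.array([2.0 * w[0], 4.0 * w[1] * (w[1] ** 2 - 1.0)])

def gradient_descent(w0, lr: float = 0.05, steps: int = 500,
                     noise: float = 0.0, seed: int = 0) -> np.ndarray:
    """Plain gradient descent; optional Gaussian noise on the gradient mimics SGD noise."""
    rng = np.random.default_rng(seed)
    w = np.array(w0, dtype=float)
    for _ in range(steps):
        g = toy_grad(w) + noise * rng.standard_normal(2)
        w = w - lr * g
    return w

w_valley = gradient_descent([1.0, 0.5])             # descends into the valley at (0, +1)
w_saddle = gradient_descent([1.0, 0.0])             # w2-gradient is exactly 0: stalls at the saddle
w_escape = gradient_descent([1.0, 0.0], noise=0.1)  # small noise pushes it off the saddle

for name, w in [("valley", w_valley), ("saddle", w_saddle), ("escape", w_escape)]:
    print(name, w, toy_loss(w))
```

The stalled run ends with a loss of about 1 (the height of the saddle), even though lower valleys exist on either side.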

Features of the landscape:

Global minimum:    The absolute lowest point of the training loss; theoretically best.
                   For neural networks: rarely reached exactly, and minimising training loss all the way can overfit.

Local minimum:     A point lower than every nearby point.
                   In high dimensions, high-loss local minima appear rare; most critical points are saddle points instead.

Saddle point:      Curves up along some directions and down along others.
                   Gradient = 0 but not a minimum.
                   Many points where training appears "stuck" are saddle points.

Plateau:           Nearly-zero gradient over a wide region.
                   Training slows dramatically. Momentum helps escape.

Sharp minimum:     Steep walls around the valley (high curvature).
                   Small weight perturbation → large loss increase.
                   Empirically correlated with poorer generalisation.

Flat minimum:      Gentle slopes around the valley (low curvature).
                   Robust to weight perturbations.
                   Empirically correlated with better generalisation and, in some studies,
                   with robustness to distribution shift.
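These labels can be read off the Hessian (the matrix of second derivatives) at a point where the gradient is zero: its eigenvalues are the curvatures along the principal directions. Below is a minimal numpy sketch of that second-order test, which also uses the largest eigenvalue as one common sharpness proxy. The tolerance and the example Hessians are illustrative assumptions; for instance, the Hessian of w1² + (w2² - 1)² at the origin is diag(2, -4), a saddle.

```python
import numpy as np

def classify_critical_point(hessian: np.ndarray, tol: float = 1e-8) -> str:
    """Second-order test at a point where the gradient is zero.

    tol (an assumed value) treats tiny eigenvalues as zero curvature.
    """
    eig = np.linalg.eigvalsh(hessian)  # symmetric matrix -> real eigenvalues, ascending
    has_pos = np.any(eig > tol)
    has_neg = np.any(eig < -tol)
    if has_pos and has_neg:
        return "saddle point"          # curves up along some directions, down along others
    if np.all(eig > tol):
        return "local minimum"         # every curvature positive
    if np.all(eig < -tol):
        return "local maximum"
    return "degenerate"                # near-zero curvature: plateau-like, test inconclusive

def sharpness_top_eigenvalue(hessian: np.ndarray) -> float:
    """Sharpness proxy: the largest Hessian eigenvalue (steepest curvature)."""
    return float(np.linalg.eigvalsh(hessian)[-1])

# Hessians of hypothetical quadratic losses at w = 0
sharp_min = np.diag([50.0, 40.0])   # steep walls
flat_min = np.diag([0.5, 0.4])      # gentle slopes
saddle = np.diag([2.0, -4.0])       # Hessian of w1**2 + (w2**2 - 1)**2 at (0, 0)

for name, H in [("sharp", sharp_min), ("flat", flat_min), ("saddle", saddle)]:
    print(name, classify_critical_point(H), sharpness_top_eigenvalue(H))
```

For real networks the full Hessian is far too large to form; in practice its top eigenvalue is usually estimated with Hessian-vector products and power iteration instead.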

Visualising a 2D Slice

Python
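import numpy as np
import torch
import torch.nn as nn

def compute_loss_surface(
    model: nn.Module,
    loader,
    criterion: nn.Module,
    resolution: int = 40,
    range_val: float = 1.0,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Compute the loss on a 2D plane through the current weights, spanned by
    two random orthonormal directions in weight space.
    range_val is an absolute L2 distance along each unit-norm direction.
    PCA directions (from the training trajectory) often give a more
    meaningful slice than random ones.
    Cost: resolution**2 full passes over the loader.
    """
    was_training = model.training
    model.eval()  # no dropout, fixed BatchNorm statistics while probing the surface

    # Save current weights as one flat vector (torch.cat returns a new tensor)
    params = list(model.parameters())
    original_weights = torch.cat([p.detach().reshape(-1) for p in params])
    n_params = original_weights.numel()
    device, dtype = original_weights.device, original_weights.dtype

    # Two random directions on the same device/dtype as the weights
    d1 = torch.randn(n_params, device=device, dtype=dtype)
    d2 = torch.randn(n_params, device=device, dtype=dtype)
    # Gram-Schmidt: remove d2's component along d1, then normalise both
    d2 = d2 - (d2 @ d1) / (d1 @ d1) * d1
    d1 = d1 / d1.norm()
    d2 = d2 / d2.norm()

    alphas = np.linspace(-range_val, range_val, resolution)
    betas = np.linspace(-range_val, range_val, resolution)
    Z = np.zeros((resolution, resolution))

    @torch.no_grad()
    def set_weights(flat_w: torch.Tensor) -> None:
        # Copy consecutive slices of the flat vector back into each parameter
        offset = 0
        for p in params:
            n = p.numel()
            p.copy_(flat_w[offset:offset + n].view_as(p))
            offset += n

    try:
        for i, alpha in enumerate(alphas):
            for j, beta in enumerate(betas):
                # float() avoids mixing NumPy scalars with torch tensors
                set_weights(original_weights + float(alpha) * d1 + float(beta) * d2)

                loss_val = 0.0
                n_batches = 0
                with torch.no_grad():
                    for X, y in loader:
                        # .squeeze() assumes a single-output regression model
                        loss_val += criterion(model(X).squeeze(), y).item()
                        n_batches += 1
                Z[i, j] = loss_val / max(n_batches, 1)
    finally:
        # Restore original weights and training mode even if evaluation fails
        set_weights(original_weights)
        model.train(was_training)
    return alphas, betas, Z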
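The docstring above notes that PCA directions often give a more meaningful slice than random ones. A minimal numpy sketch of one way to get them, assuming you saved flattened weight snapshots during training (the `snapshots` array and its layout are hypothetical): subtract the final weights from each earlier snapshot and take the top two right singular vectors of that displacement matrix. Whether to mean-centre the displacements first is a design choice; this sketch does not.

```python
import numpy as np

def pca_directions(snapshots: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Top-2 principal directions of a training trajectory.

    snapshots: shape (T, n_params), one flattened weight vector per saved
    checkpoint (hypothetical; e.g. once per epoch), last row = final weights.
    Needs T >= 3. Returns two orthonormal vectors of length n_params.
    """
    final = snapshots[-1]
    deltas = snapshots[:-1] - final   # displacement of each checkpoint from the end point
    # Rows of vt are orthonormal directions, ordered by singular value
    _, _, vt = np.linalg.svd(deltas, full_matrices=False)
    return vt[0], vt[1]

# Synthetic trajectory that mostly moves along the first two coordinates
rng = np.random.default_rng(0)
n_params, T = 50, 20
t = np.linspace(0.0, 1.0, T)
e0, e1 = np.eye(n_params)[0], np.eye(n_params)[1]
w_final = rng.standard_normal(n_params)
snapshots = (w_final
             + np.outer(3.0 * (1.0 - t), e0)
             + np.outer(np.sin(np.pi * t), e1)
             + 1e-3 * rng.standard_normal((T, n_params)))

d1, d2 = pca_directions(snapshots)
print(np.round(d1[:3], 3), np.round(d2[:3], 3))
```

To use such directions in `compute_loss_surface`, they would need converting with `torch.from_numpy(d).to(device, dtype)` and substituting for the random `d1`, `d2`; that wiring is left as an assumption here.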

Sharp vs Flat Minima

Python
import torch
import torch.nn as nn

def sharpness_measure(
    model: nn.Module,
    loader,
    criterion: nn.Module,
    perturbation_scale: float = 0.01,
    n_perturbations: int = 20,
) -> float:
    """
    Estimates sharpness: average loss increase under random Gaussian weight
    perturbations with standard deviation perturbation_scale (absolute, not
    relative to the size of each weight).
    Low sharpness → flat minimum → tends to generalise better.
    """
    was_training = model.training
    model.eval()  # deterministic evaluation: no dropout, fixed BatchNorm stats

    @torch.no_grad()
    def mean_loss() -> float:
        total, n_batches = 0.0, 0
        for X, y in loader:
            # .squeeze() assumes a single-output regression model
            total += criterion(model(X).squeeze(), y).item()
            n_batches += 1
        return total / max(n_batches, 1)

    base_loss = mean_loss()
    original_params = [p.detach().clone() for p in model.parameters()]

    perturbed_losses = []
    for _ in range(n_perturbations):
        with torch.no_grad():
            # Add a random Gaussian perturbation to every parameter
            for p in model.parameters():
                p.add_(perturbation_scale * torch.randn_like(p))
        perturbed_losses.append(mean_loss())

        # Restore the original weights before the next perturbation
        with torch.no_grad():
            for p, orig in zip(model.parameters(), original_params):
                p.copy_(orig)

    model.train(was_training)
    avg_perturbed = sum(perturbed_losses) / len(perturbed_losses)
    return avg_perturbed - base_loss  # positive = sharp; near-zero = flat

# Factors that push toward sharper or flatter minima (empirical tendencies):
# 1. Larger batch size → tends toward sharper minima (less gradient noise)
# 2. Smaller batch size → tends toward flatter minima (noisier gradient)
# 3. SAM (Sharpness-Aware Minimisation) explicitly seeks flat minima
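Two toy checks of the ideas in this block, both on hypothetical problems. First, at the minimum of a quadratic loss L(w) = ½ wᵀHw, Gaussian noise ε ~ N(0, σ²I) raises the loss by σ²/2 · tr(H) on average, so a random-perturbation sharpness estimate like the one above is effectively measuring average curvature. Second, the variance of a mini-batch gradient shrinks roughly as 1/B with batch size B (sampling with replacement), which is the "noisier gradient" behind the batch-size comments. The data, curvatures, and σ are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

# --- 1. Random-perturbation sharpness at a quadratic minimum ---
def perturbation_sharpness(hess_diag: np.ndarray, sigma: float, n_samples: int = 20_000) -> float:
    """Monte Carlo estimate of E[L(w + eps)] - L(w) at the minimum w = 0 of
    L(w) = 0.5 * sum(h_i * w_i**2), with eps ~ N(0, sigma^2 I)."""
    eps = sigma * rng.standard_normal((n_samples, hess_diag.size))
    return float(np.mean((0.5 * eps ** 2) @ hess_diag))

sigma = 0.01
sharp_h = np.full(10, 10.0)   # sharp minimum: curvature 10 in every direction
flat_h = np.full(10, 0.1)     # flat minimum: curvature 0.1 in every direction
for name, h in [("sharp", sharp_h), ("flat", flat_h)]:
    estimate = perturbation_sharpness(h, sigma)
    theory = 0.5 * sigma ** 2 * h.sum()   # sigma^2 / 2 * trace(H)
    print(f"{name}: estimate={estimate:.3e}  theory={theory:.3e}")

# --- 2. Mini-batch gradient noise vs batch size ---
x = rng.standard_normal(10_000)
y = 2.0 * x + 0.5 * rng.standard_normal(10_000)
per_example_grad = -(y * x)   # d/dw of 0.5 * (w*x - y)**2, evaluated at w = 0

def minibatch_grad_variance(batch_size: int, n_trials: int = 2_000) -> float:
    """Variance of the mini-batch mean gradient, sampling indices with replacement."""
    idx = rng.integers(0, x.size, size=(n_trials, batch_size))
    return float(per_example_grad[idx].mean(axis=1).var())

for b in (1, 8, 64):
    print(f"batch={b:3d}  gradient variance={minibatch_grad_variance(b):.4f}")
```

The quadratic case is the idealised picture; on a real network, the random-perturbation estimate also depends on how weights are scaled, which is why `perturbation_scale` matters.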

SAM: Sharpness-Aware Minimisation

Python
import torch

class SAM(torch.optim.Optimizer):
    """
    Sharpness-Aware Minimisation.
    Two forward-backward passes per step:
      1. At weights w, compute the gradient g and move to w + e_w,
         where e_w = rho * g / ||g|| (first-order worst case in the rho-ball)
      2. Compute the gradient again at the perturbed point w + e_w
      3. Restore w and step the base optimiser with that second gradient
    """

    def __init__(self, params, base_optimizer, rho: float = 0.05, **kwargs):
        if rho < 0.0:
            raise ValueError(f"rho must be non-negative, got {rho}")
        defaults = {"rho": rho, **kwargs}
        super().__init__(params, defaults)
        # Build the base optimiser over the same parameter groups
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad: bool = False) -> None:
        """Move weights to the approximate worst-case point in the rho-ball."""
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)
            for p in group["params"]:
                if p.grad is None:
                    continue
                e_w = p.grad * scale
                p.add_(e_w)                 # perturb: w → w + e_w
                self.state[p]["e_w"] = e_w
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad: bool = False) -> None:
        """Restore weights and apply the base optimiser step."""
        for group in self.param_groups:
            for p in group["params"]:
                # Undo exactly the perturbations applied in first_step
                e_w = self.state[p].pop("e_w", None)
                if e_w is not None:
                    p.sub_(e_w)             # restore: w + e_w → w
        self.base_optimizer.step()          # uses gradients computed at w + e_w
        if zero_grad:
            self.zero_grad()

    def _grad_norm(self) -> torch.Tensor:
        # Global L2 norm over all parameter gradients
        norms = [
            p.grad.norm(p=2)
            for group in self.param_groups
            for p in group["params"]
            if p.grad is not None
        ]
        return torch.stack(norms).norm(p=2)

# Usage (double forward-backward per step)
# sam = SAM(model.parameters(), torch.optim.SGD, rho=0.05, lr=0.01)
# 
# loss = criterion(model(X), y)
# loss.backward()
# sam.first_step(zero_grad=True)
# 
# criterion(model(X), y).backward()  # second forward pass at perturbed weights
# sam.second_step(zero_grad=True)
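The update the SAM class performs can be written out without the optimiser bookkeeping. Below is a minimal numpy sketch of the same two-gradient step applied to a hypothetical least-squares problem; `sam_update`, the data, `lr`, and `rho` are illustrative assumptions, and the `1e-12` guard mirrors the division in `first_step`.

```python
import numpy as np

def sam_update(w: np.ndarray, grad_fn, lr: float = 0.1, rho: float = 0.05) -> np.ndarray:
    """One SAM step, written functionally (a sketch, not the SAM class above)."""
    g = grad_fn(w)                                # 1. gradient at the current weights
    e_w = rho * g / (np.linalg.norm(g) + 1e-12)   # 2. first-order worst-case perturbation
    g_sam = grad_fn(w + e_w)                      # 3. gradient at the perturbed weights
    return w - lr * g_sam                         # 4. step taken from the ORIGINAL weights

# Hypothetical least-squares problem: y = X @ w_true + noise
rng = np.random.default_rng(0)
X = rng.standard_normal((200, 3))
w_true = np.array([1.0, -2.0, 0.5])
y = X @ w_true + 0.1 * rng.standard_normal(200)

def mse(w: np.ndarray) -> float:
    return float(np.mean((X @ w - y) ** 2))

def mse_grad(w: np.ndarray) -> np.ndarray:
    return 2.0 * X.T @ (X @ w - y) / len(y)

w = np.zeros(3)
initial_loss = mse(w)
for _ in range(200):
    w = sam_update(w, mse_grad)
final_loss = mse(w)
print(initial_loss, final_loss, w)
```

Each step costs two gradient evaluations, which is the same doubling of compute as the two forward-backward passes in the class usage.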

Why Deep Networks Don't Get Stuck in Local Minima

Classical intuition: many local minima → gradient descent gets stuck.
Modern understanding: the landscape of deep networks is more forgiving.

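1. High-dimensional spaces:
   At a critical point (gradient = 0), a strict local minimum needs
   ALL curvature eigenvalues to be positive (the Hessian is positive definite).
   Rough heuristic: if each of N eigenvalue signs were an independent coin flip,
   the chance that all N are positive would be 2^-N.
   With millions of parameters this is extremely unlikely, especially at high loss:
   most critical points there are saddle points, not local minima.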

2. Similar loss at local minima:
   Empirically: local minima found by SGD tend to have similar loss values.
   Even if you escape to a "better" minimum, the improvement is small.

3. Over-parameterisation helps:
   Networks with more parameters than training examples can often fit the data
   exactly, and the weight configurations that do so form huge families
   (permuting the n hidden units of a single layer alone gives n! equivalent networks).
   Gradient descent only needs to find one of these many solutions that fit well.

4. SGD noise:
   Mini-batch gradient noise helps escape saddle points
   and is thought to bias training toward flatter regions of the landscape.
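A quick numerical illustration of point 1, under a loud assumption: random symmetric matrices stand in for Hessians at critical points (real network Hessians are not random matrices, so this only illustrates the counting intuition). The sketch estimates how often every eigenvalue is positive as the dimension n grows.

```python
import numpy as np

def frac_all_positive(n: int, trials: int = 4000, seed: int = 0) -> float:
    """Fraction of random symmetric n x n matrices with every eigenvalue > 0.

    Each matrix is (A + A^T) / 2 with Gaussian A: a stand-in for a Hessian.
    """
    rng = np.random.default_rng(seed)
    a = rng.standard_normal((trials, n, n))
    h = (a + np.swapaxes(a, 1, 2)) / 2.0   # batch of symmetric matrices
    eig = np.linalg.eigvalsh(h)            # batched eigenvalues, shape (trials, n)
    return float(np.mean(np.all(eig > 0, axis=1)))

for n in (1, 2, 4, 8, 16):
    print(f"n={n:2d}  fraction with all curvatures positive: {frac_all_positive(n):.4f}")
```

For n = 1 the fraction is about one half; it collapses toward zero within a few dozen dimensions, far below the size of any real network.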

Interview Answer

"The loss landscape for neural networks is a high-dimensional surface over weight space, with loss as height. Most critical points, where the gradient is zero, are saddle points rather than true local minima: a local minimum needs every Hessian eigenvalue to be positive, which is unlikely in millions of dimensions, so gradient descent rarely gets permanently stuck. Plateaus (regions of near-zero gradient) slow training, and momentum helps carry the optimiser across them. The practically important distinction is between sharp and flat minima. Sharp minima have high curvature, so a small weight perturbation causes a large loss increase, and they tend to generalise worse; flat minima are robust to perturbations and correlate with better test performance. Mini-batch SGD's gradient noise biases training toward flat minima, which is one reason it often generalises better than full-batch gradient descent. SAM targets flat minima explicitly: it computes the gradient at an approximate worst-case perturbation within a small radius rho and applies that gradient at the original weights."
