Exploding Gradients
Why gradients can grow exponentially in deep networks, how to detect explosion, and gradient clipping as the standard fix.
The Problem
Exploding gradients: the opposite of vanishing gradients.
When the backward pass repeatedly multiplies by matrices whose largest eigenvalue (more precisely, largest singular value) exceeds 1, gradients can grow exponentially
as they propagate backward through many layers.
Symptoms: loss suddenly spikes to very large values or NaN;
model weights blow up to ±infinity (or become NaN).
Most common in:
- Recurrent networks (RNNs) — the gradient traverses the same weights many times
- Very deep feedforward networks without residual connections
- Large learning rates
- Poor weight initialisation (weights too large)
Vanishing vs Exploding:
|W_max_eigenvalue| < 1: vanishing
|W_max_eigenvalue| > 1: exploding
Detecting Exploding Gradients
import torch
import torch.nn as nn
def detect_exploding_gradients(model: nn.Module, threshold: float = 10.0) -> dict:
    """Flag parameters whose gradient is non-finite or unusually large.

    Call after ``loss.backward()``. Parameters whose ``grad`` is None are
    skipped.

    Args:
        model: Module whose ``named_parameters()`` are inspected.
        threshold: A parameter is flagged when the L2 norm of its gradient is
            strictly greater than this value.

    Returns:
        ``{"has_explosion": bool, "problematic_layers": list}``. Each entry of
        ``problematic_layers`` is a dict with keys ``"name"``,
        ``"grad_norm"``, ``"has_nan"`` and ``"has_inf"``, in parameter order.
    """
    flagged = []
    for param_name, tensor in model.named_parameters():
        grad = tensor.grad
        if grad is None:
            continue
        norm_value = grad.norm().item()
        nan_found = torch.isnan(grad).any().item()
        inf_found = torch.isinf(grad).any().item()
        # A NaN norm compares False against the threshold, so the explicit
        # NaN/Inf flags are what catch non-finite gradients.
        if not (nan_found or inf_found or norm_value > threshold):
            continue
        flagged.append(
            {
                "name": param_name,
                "grad_norm": norm_value,
                "has_nan": nan_found,
                "has_inf": inf_found,
            }
        )
    return {"has_explosion": bool(flagged), "problematic_layers": flagged}
def monitor_training_stability(
    model: nn.Module,
    loader,
    n_steps: int = 20,
    grad_threshold: float = 10.0,
    stop_threshold: float = 1000.0,
) -> None:
    """Run a few unclipped SGD steps and print the total gradient norm of each.

    Each step prints the loss, the global L2 gradient norm and a status:
    ``OK`` when the norm is finite and at most ``grad_threshold``,
    ``EXPLODING`` otherwise. No clipping is applied, so the printed norms are
    the raw gradient magnitudes.

    Monitoring stops early when the loader is exhausted, or BEFORE the
    optimizer step when the gradient norm is non-finite (NaN/Inf) or exceeds
    ``stop_threshold``. A single bad batch therefore never writes NaN/Inf
    into the weights.

    Args:
        model: Module whose output has as many elements as each target batch
            ``y`` (one logit per target; BCEWithLogitsLoss is used).
        loader: Iterable of ``(X, y)`` batches.
        n_steps: Maximum number of optimisation steps.
        grad_threshold: Norm above which a step is reported as ``EXPLODING``.
        stop_threshold: Norm above which monitoring stops (default 1000.0,
            the previously hard-coded value).
    """
    import math

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # no clipping
    criterion = nn.BCEWithLogitsLoss()
    data_iter = iter(loader)
    for step in range(n_steps):
        try:
            X, y = next(data_iter)
        except StopIteration:
            break
        optimizer.zero_grad()
        # reshape (not squeeze) so a batch of one keeps the target's shape.
        loss = criterion(model(X).reshape(y.shape), y.float())
        loss.backward()
        # max_norm=inf: the clip coefficient clamps to 1, so finite gradients
        # are left unchanged and only the returned total norm is used.
        total_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(), float("inf")
        ).item()
        finite = math.isfinite(total_norm)
        status = "OK" if finite and total_norm <= grad_threshold else "EXPLODING"
        print(f"Step {step:3d}: loss={loss.item():.4f}, grad_norm={total_norm:.2f} [{status}]")
        # NaN compares False with everything, so finiteness must be checked
        # explicitly; otherwise optimizer.step() would write NaN into the weights.
        if not finite or total_norm > stop_threshold:
            print(" WARNING: gradient explosion detected, stopping")
            break
        optimizer.step()
# Gradient Clipping
import torch
import torch.nn as nn
# ── Clip by norm (most common) ──
# If ||grad|| > max_norm: scale all gradients so ||grad|| = max_norm
# Preserves gradient direction, just scales down the magnitude
def train_with_grad_clip(
    model: nn.Module,
    loader,
    lr: float = 0.01,
    max_norm: float = 1.0,
    n_epochs: int = 5,
) -> list[float]:
    """Train with Adam and global-norm gradient clipping, reporting clip rate.

    Each batch runs zero_grad -> forward -> backward -> clip -> step. A step
    counts as "clipped" when the total gradient norm before clipping exceeds
    ``max_norm``.

    Args:
        model: Module whose output has as many elements as each target batch
            ``y`` (one logit per target; BCEWithLogitsLoss is used).
        loader: Iterable of ``(X, y)`` batches, re-iterated once per epoch.
        lr: Adam learning rate.
        max_norm: Global L2 norm the gradients are clipped to.
        n_epochs: Number of passes over ``loader``.

    Returns:
        Mean training loss per epoch. An epoch in which ``loader`` yields no
        batches records ``nan`` instead of raising ZeroDivisionError.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()
    epoch_losses = []
    for epoch in range(n_epochs):
        total_loss = 0.0
        clipped_steps = 0
        n_steps = 0
        for X, y in loader:
            optimizer.zero_grad()
            # reshape (not squeeze) so a batch of one keeps the target's shape.
            loss = criterion(model(X).reshape(y.shape), y.float())
            loss.backward()
            # clip_grad_norm_ returns the total norm measured BEFORE clipping,
            # so it does not need to be recomputed by hand.
            total_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), max_norm=max_norm
            ).item()
            if total_norm > max_norm:
                clipped_steps += 1
            optimizer.step()
            total_loss += loss.item()
            n_steps += 1
        avg_loss = total_loss / n_steps if n_steps else float("nan")
        epoch_losses.append(avg_loss)
        print(f"Epoch {epoch+1}: loss={avg_loss:.4f}, clipped={clipped_steps}/{n_steps} steps")
    return epoch_losses
# ── Clip by value (less common) ──
# Clamp each gradient value individually to [-clip_val, clip_val]
# Changes gradient direction — less principled than norm clipping
def clip_by_value_example(model: nn.Module, clip_val: float = 0.1) -> None:
    """Clamp every element of every existing gradient to ``[-clip_val, clip_val]``.

    Operates in place; call after ``backward()`` and before
    ``optimizer.step()``. Parameters whose ``grad`` is None are left
    untouched. Uses PyTorch's built-in helper rather than mutating
    ``param.grad.data`` by hand.
    """
    torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=clip_val)
# Exploding Gradients in RNNs
import torch
import torch.nn as nn
class SimpleRNN(nn.Module):
    """Vanilla RNN — notorious for exploding/vanishing gradients.

    The same recurrent matrix ``W_h`` is applied at every time step. Backprop
    through ``seq_len`` steps therefore multiplies the gradient by ``W_h``
    (and by the tanh derivative) ``seq_len`` times.
    """

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)  # recurrent weights
        self.W_x = nn.Linear(input_size, hidden_size)  # input-to-hidden (with bias)
        self.tanh = nn.Tanh()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the recurrence over ``x`` and return the final hidden state.

        Args:
            x: ``(batch, seq_len, input_size)`` input sequence.

        Returns:
            ``(batch, hidden_size)`` hidden state after the last step; all
            zeros when ``seq_len == 0``.
        """
        batch, seq_len, _ = x.shape
        # new_zeros follows x's device and dtype; torch.zeros would always
        # allocate a default-dtype CPU tensor and fail for GPU/float64 inputs.
        h = x.new_zeros(batch, self.W_h.in_features)
        for t in range(seq_len):
            h = self.tanh(self.W_h(h) + self.W_x(x[:, t, :]))
        return h
# Without clipping, gradient through 100 steps can easily explode
# (rnn is only instantiated here as an example; it is not trained below)
rnn = SimpleRNN(input_size=10, hidden_size=32)
# LSTM gating mainly mitigates *vanishing* gradients; gradients can still
# explode on long sequences, so clipping is still needed.
import torch.nn as nn
lstm = nn.LSTM(input_size=10, hidden_size=32, num_layers=2, batch_first=True)
optimizer = torch.optim.Adam(lstm.parameters(), lr=1e-3)
criterion = nn.MSELoss()
# Training with clipping for RNN/LSTM is standard practice:
# zero_grad -> forward -> backward -> clip -> step
X_seq = torch.randn(32, 100, 10)  # (batch, seq_len=100, input_size)
y_seq = torch.randn(32, 32)  # target for the top layer's final hidden state
optimizer.zero_grad()  # a no-op on this first step, but required once this runs in a loop
output, (h_n, c_n) = lstm(X_seq)
loss = criterion(h_n[-1], y_seq)  # h_n[-1]: final hidden state of the last LSTM layer
loss.backward()
# clip_grad_norm_ returns the total norm computed before clipping
grad_norm = torch.nn.utils.clip_grad_norm_(lstm.parameters(), max_norm=1.0)
print(f"RNN gradient norm before clipping: {grad_norm:.4f}")
optimizer.step()
# Choosing the Clipping Threshold
import torch
import torch.nn as nn
def find_gradient_norm_distribution(
    model: nn.Module,
    loader,
    n_steps: int = 100,
) -> dict:
    """Collect gradient norms to understand typical magnitude before setting a threshold.

    Runs forward/backward on up to ``n_steps`` batches WITHOUT updating the
    weights. It records the global L2 gradient norm of each batch and zeroes
    the gradients again after every batch. Any gradients already present on
    ``model`` (e.g. left over from earlier training) are zeroed first, so
    they do not inflate the first measurement.

    Args:
        model: Module whose output has as many elements as each target batch
            ``y`` (one logit per target; BCEWithLogitsLoss is used).
        loader: Iterable of ``(X, y)`` batches.
        n_steps: Maximum number of batches to measure.

    Returns:
        Dict with keys ``"mean"``, ``"median"``, ``"p95"``, ``"p99"`` and
        ``"max"``. Statistics use the recorded norms below 1e6; NaN and
        already-exploded norms are dropped. Every value is ``nan`` when no norm
        survives (e.g. an empty loader).
    """
    import numpy as np

    criterion = nn.BCEWithLogitsLoss()
    grad_norms = []
    # Discard stale gradients so backward() below does not accumulate onto them.
    model.zero_grad(set_to_none=False)
    for i, (X, y) in enumerate(loader):
        if i >= n_steps:
            break
        # reshape (not squeeze) so a batch of one keeps the target's shape.
        loss = criterion(model(X).reshape(y.shape), y.float())
        loss.backward()
        # max_norm=inf: used only to obtain the total norm.
        total_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(), float("inf")
        ).item()
        grad_norms.append(total_norm)
        model.zero_grad(set_to_none=False)
    # `g < 1e6` is False for NaN and inf as well, so those are excluded too.
    grad_norms = [g for g in grad_norms if g < 1e6]
    if not grad_norms:
        return dict.fromkeys(("mean", "median", "p95", "p99", "max"), float("nan"))
    return {
        "mean": float(np.mean(grad_norms)),
        "median": float(np.median(grad_norms)),
        "p95": float(np.percentile(grad_norms, 95)),
        "p99": float(np.percentile(grad_norms, 99)),
        "max": float(np.max(grad_norms)),
    }
# Rule of thumb: set max_norm to the 95th percentile of observed gradient norms
# Standard default: max_norm = 1.0 for most tasks
# Transformers: max_norm = 1.0
# RNNs: max_norm = 5.0 or 1.0
Interview Answer
"Exploding gradients occur when the gradient signal grows exponentially during backpropagation, causing weights to blow up and loss to diverge (often to NaN). It's most severe in RNNs where the recurrent weight matrix W_h is applied T times during backprop — if its largest eigenvalue exceeds 1 in magnitude, the gradient grows as eigenvalue^T. Detection: gradient norm suddenly jumps, loss becomes NaN. The fix is gradient clipping: compute the total gradient norm across all parameters, and if it exceeds max_norm, scale all gradients down proportionally so the total norm equals max_norm. This preserves the gradient direction — unlike per-value clipping. In PyTorch:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
is called after backward() and before optimizer.step(). Standard thresholds: 1.0 for most networks and Transformers, 5.0 for RNNs/LSTMs. Always log gradient norm during training — a sudden spike before a loss spike confirms explosion."
Found this helpful?
Leave a comment
Have a question, correction, or just found this helpful? Leave a note below.