Learning Rate
The most important hyperparameter — how to choose it, the learning rate range test, warmup strategies, and what happens when it's too high or too low.
What the Learning Rate Controls
$W \leftarrow W - \alpha \cdot \nabla L(W)$
α (the learning rate) = step size along the negative gradient
Too small (α = 1e-6):
Steps are tiny — thousands of epochs to converge
May get stuck in small basins or on plateaus
Safe but impractically slow
Too large (ฮฑ = 1.0):
Steps overshoot the minimum
Loss oscillates or diverges (increases each epoch)
Model fails to train
Just right (α ≈ 1e-3 for Adam, 1e-2 for SGD):
Loss decreases steadily
Converges in reasonable time
The hardest hyperparameter to get right
Learning Rate Range Test (LR Finder)
import torch
import torch.nn as nn
import numpy as np
from copy import deepcopy
def lr_range_test(
    model: nn.Module,
    loader,
    criterion: nn.Module,
    start_lr: float = 1e-7,
    end_lr: float = 10.0,
    n_steps: int = 100,
    smoothing: float = 0.05,
) -> tuple[list[float], list[float]]:
    """
    Sweep the learning rate exponentially from start_lr to end_lr and record the loss.

    A deep copy of ``model`` is trained with plain SGD for up to ``n_steps``
    mini-batches; the caller's model is not modified. After every step the
    learning rate is multiplied by a constant factor, so the first step uses
    ``start_lr`` and the last uses ``end_lr``. Plot the returned losses against
    the learning rates on a log x-axis: a good lr lies just before the loss
    starts rising. A suggested lr is also printed.

    Args:
        model: Network to probe (deep-copied, never modified).
        loader: Re-iterable source of ``(X, y)`` batches; restarted when exhausted.
        criterion: Loss module, called as ``criterion(model(X).squeeze(), y)``.
        start_lr: Learning rate used at the first step.
        end_lr: Learning rate used at the last step.
        n_steps: Maximum number of optimisation steps.
        smoothing: Weight of the newest loss in the exponential moving
            average, in (0, 1].

    Returns:
        ``(lrs, losses)``: the lr used at each step and the bias-corrected,
        smoothed loss after that step. The sweep stops early once the smoothed
        loss is non-finite or exceeds 4x the best smoothed loss seen so far.

    Raises:
        ValueError: if ``smoothing`` is not in (0, 1].
    """
    if not 0.0 < smoothing <= 1.0:
        raise ValueError(f"smoothing must be in (0, 1], got {smoothing}")
    if n_steps <= 0:
        return [], []

    model_copy = deepcopy(model)
    optimizer = torch.optim.SGD(model_copy.parameters(), lr=start_lr)
    # Exponential lr schedule: n_steps - 1 multiplications take start_lr to end_lr.
    lr_mult = (end_lr / start_lr) ** (1 / max(1, n_steps - 1))
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=lr_mult)
    lrs, losses = [], []
    best_loss = float("inf")
    avg_loss = 0.0
    data_iter = iter(loader)
    for step in range(n_steps):
        try:
            X, y = next(data_iter)
        except StopIteration:
            data_iter = iter(loader)
            X, y = next(data_iter)
        optimizer.zero_grad()
        loss = criterion(model_copy(X).squeeze(), y)
        loss.backward()
        optimizer.step()
        current_lr = optimizer.param_groups[0]["lr"]
        current_loss = loss.item()
        # Exponential smoothing reduces mini-batch noise. The average starts
        # at 0, so dividing by 1 - (1 - smoothing)^t removes the bias toward
        # zero; without it the first values are tiny and argmin picks them.
        avg_loss = smoothing * current_loss + (1 - smoothing) * avg_loss
        smoothed_loss = avg_loss / (1 - (1 - smoothing) ** (step + 1))
        lrs.append(current_lr)
        losses.append(smoothed_loss)
        # NaN/inf never compare greater than anything, so test finiteness explicitly.
        if not np.isfinite(smoothed_loss) or smoothed_loss > 4 * best_loss:
            print(f"Stopping early: loss diverged at lr={current_lr:.2e}")
            break
        best_loss = min(best_loss, smoothed_loss)
        scheduler.step()
    # Suggested lr ≈ one order of magnitude below the lr with the lowest
    # smoothed loss. Only the final entry can be non-finite (the loop breaks on it).
    n_valid = len(losses) if np.isfinite(losses[-1]) else len(losses) - 1
    if n_valid > 0:
        min_idx = int(np.argmin(losses[:n_valid]))
        optimal_lr = lrs[min_idx] / 10
        print(f"Suggested lr: {optimal_lr:.2e}")
    return lrs, losses
# Common Starting Points
import torch.nn as nn
def get_default_lr(optimizer_type: str, architecture: str) -> float:
    """
    Return a conventional starting learning rate for an optimizer/architecture pair.

    The lookup is case-insensitive; pairs not listed fall back to 1e-3.
    Treat the result as a first guess and refine it with an LR range test.
    """
    key = (optimizer_type.lower(), architecture.lower())
    # (optimizer, architecture) -> lr
    table = {
        ("adam", "mlp"): 3e-4,  # Karpathy's constant
        ("adam", "cnn"): 1e-3,
        ("adam", "transformer"): 1e-4,  # or 3e-4 with warmup
        ("adamw", "transformer"): 3e-4,  # GPT-style
        ("adamw", "fine-tuning"): 1e-5,  # fine-tuning pre-trained models
        ("sgd", "cnn"): 1e-2,
        ("sgd", "resnet"): 1e-1,  # with cosine decay
    }
    if key in table:
        return table[key]
    return 1e-3
# Print the suggested starting lr for a few common (optimizer, architecture)
# pairs. Each list item is a 2-tuple, so it unpacks directly into opt, arch.
for opt, arch in [
    ("adam", "transformer"), ("adamw", "fine-tuning"),
    ("sgd", "resnet"), ("adam", "mlp"),
]:
    print(f"{opt:6s} + {arch:15s}: lr = {get_default_lr(opt, arch):.0e}")
# Diagnosing Learning Rate Problems
import torch
def diagnose_training(
    train_losses: list[float],
    val_losses: list[float],
    n_epochs: int,
) -> str:
    """
    Guess the most likely learning-rate problem from per-epoch loss curves.

    Checks run in order and the first match wins:
      1. non-finite train loss, or late loss > 2x early loss -> DIVERGING
      2. relative improvement from early to late loss < 1%    -> STUCK
      3. std/mean of the last 10 train losses > 0.3          -> OSCILLATING
      4. last val loss > 1.1x its minimum                     -> OVERFITTING

    Args:
        train_losses: Training loss per epoch, oldest first.
        val_losses: Validation loss per epoch; may be empty.
        n_epochs: Number of epochs trained (not used by the checks; kept for
            API compatibility).

    Returns:
        A one-line diagnosis prefixed with an upper-case label, or
        "Not enough data" when fewer than 3 train losses are given.
    """
    if len(train_losses) < 3:
        return "Not enough data"
    diverging = "DIVERGING: learning rate too high — try dividing by 10"
    # NaN/inf never compare greater than anything, so without this check a
    # blown-up run would fall through every test and be reported as healthy.
    if not np.all(np.isfinite(train_losses)):
        return diverging
    # Compare the mean of the first and last `window` epochs. Cap the window at
    # half the history so the two windows never overlap (with 3 epochs and a
    # fixed window of 3 they would be identical and always report STUCK).
    window = min(3, len(train_losses) // 2)
    early_loss = np.mean(train_losses[:window])
    late_loss = np.mean(train_losses[-window:])
    # Check for divergence: loss increasing after the first few epochs
    if late_loss > early_loss * 2:
        return diverging
    # Check for slow convergence: loss barely changing
    relative_improvement = (early_loss - late_loss) / (early_loss + 1e-8)
    if relative_improvement < 0.01:
        return "STUCK: learning rate too low or plateau — try lr×10 or add momentum"
    # Check for oscillation: high coefficient of variation in recent losses
    if len(train_losses) > 10:
        recent_std = np.std(train_losses[-10:])
        recent_mean = np.mean(train_losses[-10:])
        if recent_std / (recent_mean + 1e-8) > 0.3:
            return "OSCILLATING: learning rate slightly too high — try dividing by 3"
    # Check for overfitting: validation loss has risen >10% above its best
    if val_losses and val_losses[-1] > min(val_losses) * 1.1:
        return "OVERFITTING: model fit well, consider regularisation or early stopping"
    return "OK: training appears healthy"
# Usage
train_losses = [1.2, 1.1, 1.0, 0.95, 0.9, 0.88, 0.86, 0.85, 0.85, 0.84]
val_losses = [1.3, 1.2, 1.1, 1.05, 1.02, 1.01, 1.02, 1.05, 1.1, 1.15]
# Validation loss bottoms out at epoch 6 and then climbs while training loss
# keeps falling, so this example is diagnosed as overfitting.
diagnosis = diagnose_training(train_losses, val_losses, n_epochs=len(train_losses))
print(f"Diagnosis: {diagnosis}")
# Warmup
import torch.optim as optim
def get_linear_warmup_scheduler(
    optimizer: optim.Optimizer,
    warmup_steps: int,
    total_steps: int,
) -> optim.lr_scheduler.LambdaLR:
    """
    Build a LambdaLR with linear warmup followed by linear decay.

    The multiplier applied to each param group's base lr at step ``t`` is
      * t / warmup_steps, while t < warmup_steps (so it starts at 0), then
      * 1 - (t - warmup_steps) / (total_steps - warmup_steps), floored at 0.
    Both denominators are floored at 1, so zero-length phases never divide
    by zero. This warmup-then-decay shape is the usual Transformer recipe.
    """
    warmup_span = float(max(1, warmup_steps))
    decay_span = float(max(1, total_steps - warmup_steps))

    def multiplier(step: int) -> float:
        if step >= warmup_steps:
            remaining = 1.0 - float(step - warmup_steps) / decay_span
            return remaining if remaining > 0.0 else 0.0
        return float(step) / warmup_span

    return optim.lr_scheduler.LambdaLR(optimizer, multiplier)
model = nn.Linear(10, 1)
optimizer = optim.AdamW(model.parameters(), lr=3e-4)
# 1000 warmup steps, 10000 total
scheduler = get_linear_warmup_scheduler(optimizer, warmup_steps=1000, total_steps=10000)
# Training loop
for step in range(10000):
    optimizer.zero_grad()
    # ... forward, backward ...
    optimizer.step()
    # Log before scheduler.step(): this is the lr optimizer.step() just used.
    # Reading it afterwards would report the next step's lr under this label.
    if step in (0, 100, 999, 1000, 5000):
        lr = optimizer.param_groups[0]["lr"]
        print(f"Step {step:5d}: lr = {lr:.2e}")
    scheduler.step()  # call after optimizer.step()
# Interview Answer
"The learning rate is the single most impactful hyperparameter — it controls the step size in weight space. Too small: training is impractically slow. Too large: loss diverges. The learning rate range test (gradually increase lr while watching loss) systematically finds the right scale; a good lr is typically about one order of magnitude below the point where loss starts rising. Defaults: 3e-4 for Adam/AdamW with MLPs, 1e-4 to 3e-4 for transformers, 1e-2 to 1e-1 for SGD on CNNs. Warmup is important for transformers — starting with a large lr causes early gradient explosion before the model has sensible weights; linearly ramping from 0 over the first 1–5% of training avoids this. Learning rate is almost always paired with a scheduler (cosine annealing or linear decay) to fine-tune convergence after the large early steps."
Found this helpful?
Leave a comment
Have a question, correction, or just found this helpful? Leave a note below.