Learnixo
Back to blog
AI Systems · Intermediate

Learning Rate

The most important hyperparameter — how to choose it, the learning rate range test, warmup strategies, and what happens when it's too high or too low.

Asma Hafeez Khan · May 22, 2026 · 5 min read
Deep Learning · Learning Rate · Hyperparameter · Training · Interview
Share: 𝕏

What the Learning Rate Controls

$W \leftarrow W - \alpha \cdot \nabla L(W)$

α (learning rate) = step size along the negative gradient

Too small (α = 1e-6):
  Steps are tiny → thousands of epochs to converge
  May get stuck in small basins or on plateaus
  Safe but impractically slow

Too large (α = 1.0):
  Steps overshoot the minimum
  Loss oscillates or diverges (increases each epoch)
  Model fails to train

Just right (α ≈ 1e-3 for Adam, 1e-2 for SGD):
  Loss decreases steadily
  Converges in reasonable time
  Finding this range is the hard part — it is the hardest hyperparameter to get right

Learning Rate Range Test (LR Finder)

Python
import torch
import torch.nn as nn
import numpy as np
from copy import deepcopy

def lr_range_test(
    model: nn.Module,
    loader,
    criterion: nn.Module,
    start_lr: float = 1e-7,
    end_lr: float = 10.0,
    n_steps: int = 100,
    smoothing: float = 0.05,
) -> tuple[list[float], list[float]]:
    """
    Run a learning-rate range test ("LR finder") on a copy of ``model``.

    The learning rate is increased exponentially from ``start_lr`` to
    ``end_lr`` over ``n_steps`` SGD steps while a smoothed training loss is
    recorded. Plot ``losses`` against ``lrs`` (log x-axis): a good learning
    rate lies before the point where the loss bottoms out and starts rising.

    Args:
        model: Model to probe. It is deep-copied, so its weights are untouched.
        loader: Iterable of ``(X, y)`` batches; restarted when exhausted.
        criterion: Loss called as ``criterion(model(X).squeeze(), y)``.
        start_lr: Learning rate used for the first step.
        end_lr: Learning rate used for the last step (unless stopped early).
        n_steps: Maximum number of optimisation steps.
        smoothing: Weight of the newest loss in the exponential moving
            average; must be in (0, 1].

    Returns:
        ``(lrs, losses)``: the lr used at each step and the bias-corrected
        smoothed loss after that step. The test stops early once the smoothed
        loss exceeds 4x its best value so far.

    Raises:
        ValueError: If ``smoothing`` is not in (0, 1].
    """
    if not 0.0 < smoothing <= 1.0:
        raise ValueError(f"smoothing must be in (0, 1], got {smoothing}")
    if n_steps <= 0:
        return [], []

    model_copy = deepcopy(model)
    optimizer = torch.optim.SGD(model_copy.parameters(), lr=start_lr)

    # Exponential lr schedule: n_steps - 1 multiplications take start_lr
    # exactly to end_lr on the final step.
    lr_mult = (end_lr / start_lr) ** (1 / max(1, n_steps - 1))
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=lr_mult)

    lrs, losses = [], []
    best_loss = float("inf")
    avg_loss = 0.0

    data_iter = iter(loader)

    for step in range(n_steps):
        try:
            X, y = next(data_iter)
        except StopIteration:
            data_iter = iter(loader)
            X, y = next(data_iter)

        # The lr in effect for this update (the scheduler advances at loop end).
        current_lr = optimizer.param_groups[0]["lr"]

        optimizer.zero_grad()
        loss = criterion(model_copy(X).squeeze(), y)
        loss.backward()
        optimizer.step()

        # Exponential moving average to reduce batch noise. The average is
        # initialised at 0, so divide by (1 - (1 - smoothing)^(step + 1)) to
        # remove the bias toward 0; otherwise the first entries are tiny and
        # argmin below would always pick them.
        avg_loss = smoothing * loss.item() + (1 - smoothing) * avg_loss
        smoothed = avg_loss / (1 - (1 - smoothing) ** (step + 1))

        lrs.append(current_lr)
        losses.append(smoothed)

        # Compare smoothed values so a single noisy batch doesn't stop the test.
        best_loss = min(best_loss, smoothed)
        if smoothed > 4 * best_loss:
            print(f"Stopping early: loss diverged at lr={current_lr:.2e}")
            break

        scheduler.step()

    # Heuristic: one order of magnitude (10x) below the lr at minimum loss,
    # never below the smallest lr actually tried.
    min_idx = int(np.argmin(losses))
    optimal_lr = max(lrs[0], lrs[min_idx] / 10)
    print(f"Suggested lr: {optimal_lr:.2e}")

    return lrs, losses

Common Starting Points

Python
import torch.nn as nn

def get_default_lr(optimizer_type: str, architecture: str) -> float:
    """
    Look up a conventional starting learning rate for an optimizer/architecture pair.

    Matching is case-insensitive; any pair not in the table falls back to 1e-3.
    Treat the result as a first guess and refine it with an LR range test.
    """
    key = (optimizer_type.lower(), architecture.lower())
    table: dict[tuple[str, str], float] = {
        ("adam", "mlp"): 3e-4,            # "Karpathy's constant"
        ("adam", "cnn"): 1e-3,
        ("adam", "transformer"): 1e-4,    # 3e-4 is also common with warmup
        ("adamw", "transformer"): 3e-4,   # GPT-style setups
        ("adamw", "fine-tuning"): 1e-5,   # fine-tuning pre-trained models
        ("sgd", "cnn"): 1e-2,
        ("sgd", "resnet"): 1e-1,          # usually paired with cosine decay
    }
    fallback = 1e-3
    return table[key] if key in table else fallback

# Print the suggested starting lr for a few optimizer/architecture pairs.
# Each list entry is a plain (optimizer, architecture) pair; the lr comes from
# get_default_lr, not from the list.
for opt, arch in [
    ("adam",  "transformer"), ("adamw", "fine-tuning"),
    ("sgd",   "resnet"),      ("adam",  "mlp"),
]:
    print(f"{opt:6s} + {arch:15s}: lr = {get_default_lr(opt, arch):.0e}")

Diagnosing Learning Rate Problems

Python
import torch

def diagnose_training(
    train_losses: list[float],
    val_losses: list[float],
    n_epochs: int,
) -> str:
    """
    Diagnose a likely learning-rate problem from per-epoch loss curves.

    Checks run in order and the first match is returned: divergence,
    stalled progress, oscillation, then overfitting.

    Args:
        train_losses: Training loss per epoch, oldest first.
        val_losses: Validation loss per epoch; may be empty to skip the
            overfitting check.
        n_epochs: Currently unused; kept for API compatibility.

    Returns:
        A one-line diagnosis prefixed with a tag such as ``DIVERGING`` or
        ``OK``, or ``"Not enough data"`` for fewer than 3 epochs.
    """
    import numpy as np  # keeps this snippet runnable on its own

    if len(train_losses) < 3:
        return "Not enough data"

    # Compare the start of training with the end using up to 3 epochs per
    # window, but never let the windows overlap: with 3-5 epochs an overlap
    # made early == late and every short run looked "STUCK".
    window = min(3, len(train_losses) // 2)
    early_loss = np.mean(train_losses[:window])
    late_loss = np.mean(train_losses[-window:])

    # Divergence: loss at the end is more than double the loss at the start.
    if late_loss > early_loss * 2:
        return "DIVERGING: learning rate too high — try dividing by 10"

    # Slow convergence: less than 1% relative improvement.
    relative_improvement = (early_loss - late_loss) / (early_loss + 1e-8)
    if relative_improvement < 0.01:
        return "STUCK: learning rate too low or plateau — try lr×10 or add momentum"

    # Oscillation: large scatter of the last 10 epochs around their linear
    # trend. Detrending stops a steady, fast decrease from being mistaken for
    # oscillation (a raw std would flag it).
    if len(train_losses) >= 10:
        recent = np.asarray(train_losses[-10:], dtype=float)
        epochs = np.arange(len(recent))
        trend = np.polyval(np.polyfit(epochs, recent, 1), epochs)
        if np.std(recent - trend) / (np.mean(recent) + 1e-8) > 0.3:
            return "OSCILLATING: learning rate slightly too high — try dividing by 3"

    # Overfitting: latest validation loss is >10% above its best value.
    if val_losses and val_losses[-1] > min(val_losses) * 1.1:
        return "OVERFITTING: model fit well, consider regularisation or early stopping"

    return "OK: training appears healthy"

# Usage: training loss keeps falling while validation loss bottoms out
# (1.01 at epoch 6) and then climbs back up.
train_losses = [1.2, 1.1, 1.0, 0.95, 0.9, 0.88, 0.86, 0.85, 0.85, 0.84]
val_losses   = [1.3, 1.2, 1.1, 1.05, 1.02, 1.01, 1.02, 1.05, 1.1, 1.15]
diagnosis = diagnose_training(
    train_losses=train_losses,
    val_losses=val_losses,
    n_epochs=len(train_losses),
)
print("Diagnosis: " + diagnosis)

Warmup

Python
import torch.optim as optim

def get_linear_warmup_scheduler(
    optimizer: optim.Optimizer,
    warmup_steps: int,
    total_steps: int,
) -> optim.lr_scheduler.LambdaLR:
    """
    Build a LambdaLR that ramps the lr up linearly, then decays it linearly.

    The multiplier applied to each param group's base lr is
    ``step / warmup_steps`` during the first ``warmup_steps`` steps (starting
    at 0). It then falls linearly from 1 to 0 between ``warmup_steps`` and
    ``total_steps`` and stays at 0 afterwards. Standard for Transformers.
    """
    # Denominators are clamped to 1 so a zero-length phase can't divide by zero.
    ramp_len = max(1, warmup_steps)
    decay_len = max(1, total_steps - warmup_steps)

    def lr_lambda(current_step: int) -> float:
        if current_step >= warmup_steps:
            fraction_left = 1.0 - (current_step - warmup_steps) / decay_len
            return max(0.0, fraction_left)
        return current_step / ramp_len

    return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

import torch.nn as nn  # this snippet uses nn; keep it runnable on its own

model = nn.Linear(10, 1)
optimizer = optim.AdamW(model.parameters(), lr=3e-4)

# 1000 warmup steps, 10000 total
scheduler = get_linear_warmup_scheduler(optimizer, warmup_steps=1000, total_steps=10000)

# Training loop
for step in range(10000):
    # Report the lr *before* stepping: that is the lr applied at `step`.
    # Reading it after scheduler.step() would show the next step's lr
    # (e.g. 3e-7 at step 0 instead of the 0 actually used).
    if step in (0, 100, 999, 1000, 5000):
        lr = optimizer.param_groups[0]["lr"]
        print(f"Step {step:5d}: lr = {lr:.2e}")

    optimizer.zero_grad()
    # ... forward, backward ...
    optimizer.step()
    scheduler.step()   # call after optimizer.step()

Interview Answer

"The learning rate is the single most impactful hyperparameter — it controls the step size in weight space. Too small: training is impractically slow. Too large: loss diverges. The learning rate range test (gradually increase lr while watching loss) systematically finds the right scale; the optimal lr is typically one order of magnitude below the point where loss bottoms out and starts rising. Defaults: 3e-4 for Adam/AdamW with MLPs, 1e-4 to 3e-4 for transformers, 1e-2 to 1e-1 for SGD on CNNs. Warmup is important for transformers — starting with a large lr causes early gradient explosion before the model has sensible weights; linearly ramping from 0 over the first 1–5% of training avoids this. Learning rate is almost always paired with a scheduler (cosine annealing or linear decay) to fine-tune convergence after the large early steps."

Enjoyed this article?

Explore the AI Systems learning path for more.

Found this helpful?

Share:๐•

Leave a comment

Have a question, correction, or just found this helpful? Leave a note below.