Learnixo
AI Systems · Intermediate

Optimisers: Interview Q&A

Six key interview questions on gradient descent, Adam, SGD, learning rate scheduling, and choosing optimisers for clinical AI systems.

Asma Hafeez Khan · May 22, 2026 · 6 min read
Deep Learning · Optimisers · Adam · SGD · Learning Rate · Interview

Q1: What is gradient descent and how does it work?

Answer: Gradient descent minimises a loss function by iteratively moving the weights in the direction opposite to the gradient. The update rule is $W \leftarrow W - \alpha \nabla_W L(W)$, where $\alpha$ is the learning rate. The gradient points in the direction of steepest increase in loss, so stepping against it descends the loss surface. In practice, mini-batch gradient descent is used: the gradient is estimated from a batch of typically 32–512 samples rather than the full dataset, which balances gradient noise (a mild regulariser) against efficiency (batches parallelise well on GPUs).

Python
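import torch
import torch.nn as nn

# Minimal illustrative setup: a linear model on random data
model = nn.Linear(10, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
X, y = torch.randn(64, 10), torch.randn(64, 1)

# The four lines that implement one gradient-descent step in PyTorch
optimizer.zero_grad()          # clear accumulated gradients
loss = criterion(model(X), y)  # forward pass + loss
loss.backward()                # backward pass: compute dL/dW for all params
optimizer.step()               # W ← W - lr * grad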
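To make the update rule and mini-batching concrete without a framework, here is a minimal NumPy sketch of mini-batch gradient descent on least-squares linear regression. The toy data, the `minibatch_gd` name and the hyperparameters (lr=0.1, batch size 32, 50 epochs) are illustrative assumptions, not a recommended recipe.

```python
import numpy as np

def minibatch_gd(X, y, lr=0.1, batch_size=32, epochs=50, seed=0):
    """Mini-batch gradient descent for least-squares linear regression.

    Loss L(w) = mean((X @ w - y)**2); its gradient is 2/n * X.T @ (X @ w - y).
    """
    rng = np.random.default_rng(seed)
    n, d = X.shape
    w = np.zeros(d)
    for _ in range(epochs):
        idx = rng.permutation(n)  # reshuffle the data every epoch
        for start in range(0, n, batch_size):
            b = idx[start:start + batch_size]
            Xb, yb = X[b], y[b]
            grad = 2.0 / len(b) * Xb.T @ (Xb @ w - yb)  # gradient estimated from one mini-batch
            w -= lr * grad                               # W <- W - lr * grad
    return w

# Toy data: y is a noisy linear function of X with known weights
rng = np.random.default_rng(42)
X = rng.normal(size=(512, 3))
true_w = np.array([2.0, -1.0, 0.5])
y = X @ true_w + 0.01 * rng.normal(size=512)

w = minibatch_gd(X, y)
print(np.round(w, 2))
```

Because this toy loss is convex and well conditioned, the recovered weights should land close to `true_w`. On a real network the loop structure is the same, but autograd computes the gradient instead of a closed-form expression.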

Q2: What is the difference between SGD and Adam?

Answer: SGD updates all weights with the same global learning rate (optionally with momentum to smooth updates). Adam maintains a per-weight adaptive step size: for each parameter it tracks the first moment (an exponential moving average of the gradient) and the second moment (an exponential moving average of the squared gradient), and divides the first by the square root of the second. Weights that consistently receive large gradients therefore get smaller effective learning rates, while weights with small or infrequent gradients (such as embeddings of rare tokens) get relatively larger ones. This usually makes Adam converge faster with less tuning. The trade-off: SGD with momentum sometimes reaches solutions that generalise slightly better given enough training, which is one reason the original ResNet work trained with SGD; most NLP and transformer work uses AdamW.

Python
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 1)  # placeholder model for illustration

# SGD: one global lr for all weights (plus momentum)
optimizer_sgd = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)

# AdamW: per-weight adaptive lr + decoupled weight decay (applied directly to the weights)
optimizer_adamw = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)

# Adam: weight_decay is added to the gradient as an L2 term, so it is rescaled
# by the adaptive denominator and regularises unevenly across weights
optimizer_adam = optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-4)
# Prefer AdamW over Adam when using weight decay
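The difference between one global learning rate and a per-weight adaptive one can be shown with a single update step. This is a from-scratch NumPy sketch, not PyTorch's implementation. The `sgd_momentum_step` and `adam_step` helpers are hypothetical, and the learning rates (0.01 for SGD, 1e-3 for Adam) are illustrative. It feeds both optimisers two weights whose gradients differ by four orders of magnitude.

```python
import numpy as np

def sgd_momentum_step(w, g, buf, lr=0.01, momentum=0.9):
    """One SGD+momentum step: the same global lr scales every weight's update."""
    buf = momentum * buf + g  # momentum buffer (equals g on the first step)
    return w - lr * buf, buf

def adam_step(w, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam step: each weight's update is normalised by sqrt(second moment)."""
    m = beta1 * m + (1 - beta1) * g        # first moment: EMA of gradients
    v = beta2 * v + (1 - beta2) * g ** 2   # second moment: EMA of squared gradients
    m_hat = m / (1 - beta1 ** t)           # bias correction
    v_hat = v / (1 - beta2 ** t)
    return w - lr * m_hat / (np.sqrt(v_hat) + eps), m, v

# Two weights whose gradients differ by four orders of magnitude
g = np.array([100.0, 0.01])
w0 = np.zeros(2)

w_sgd, _ = sgd_momentum_step(w0, g, buf=np.zeros(2))
w_adam, _, _ = adam_step(w0, g, m=np.zeros(2), v=np.zeros(2), t=1)

print("SGD update: ", w_sgd - w0)    # proportional to the raw gradient
print("Adam update:", w_adam - w0)   # about lr in magnitude for both weights
```

On the first step, SGD's update is proportional to the raw gradient. Adam's update is roughly `lr` in magnitude for both weights, because each gradient is divided by its own running scale. Later steps depend on how consistent each weight's gradients are over time.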

Q3: Why does Adam use bias correction?

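Answer: Adam initialises both moment estimates at zero, $m_0 = 0$ and $v_0 = 0$, so both are biased towards zero early in training. At step $t = 1$ the first moment is $m_1 = (1-\beta_1)\,g_1$; with $\beta_1 = 0.9$, $m_1 = 0.1\,g_1$, far smaller than the true gradient. The second moment is biased even more: with $\beta_2 = 0.999$, $v_1 = 0.001\,g_1^2$. Because the update divides $m$ by $\sqrt{v}$, the uncorrected ratio at $t = 1$ is $0.1/\sqrt{0.001} \approx 3.16$ times the sign of the gradient, so early steps would be badly mis-scaled (too large with the default betas). Bias correction divides each estimate by its own factor, $\hat{m}_t = m_t / (1-\beta_1^t)$ and $\hat{v}_t = v_t / (1-\beta_2^t)$. At $t = 1$ these factors are 0.1 and 0.001, which restores $\hat{m}_1 = g_1$ and $\hat{v}_1 = g_1^2$. After many steps $\beta^t \to 0$, so the correction becomes negligible.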

Python
import numpy as np

beta1, beta2 = 0.9, 0.999

# Simulate first few steps with and without bias correction
g = np.array([1.0, -0.5, 0.8])   # constant gradient for illustration
m, v = np.zeros(3), np.zeros(3)

for t in range(1, 6):
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g**2
    
    m_hat = m / (1 - beta1**t)   # bias corrected
    v_hat = v / (1 - beta2**t)
    
    print(f"t={t}: m[0]={m[0]:.4f}, m_hat[0]={m_hat[0]:.4f}, correction_factor={1/(1-beta1**t):.2f}")
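What ultimately matters is the step Adam takes, which is proportional to $m/\sqrt{v}$. This short NumPy sketch uses a constant gradient of 1.0, an illustrative assumption, and the default betas. It compares that ratio with and without bias correction over the first five steps.

```python
import numpy as np

beta1, beta2, eps = 0.9, 0.999, 1e-8
g = 1.0            # constant gradient, purely for illustration
m = v = 0.0
uncorrected, corrected = [], []

for t in range(1, 6):
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g ** 2
    # Adam's step before multiplying by lr, without bias correction
    uncorrected.append(m / (np.sqrt(v) + eps))
    # ... and with bias correction
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    corrected.append(m_hat / (np.sqrt(v_hat) + eps))

for t, (u, c) in enumerate(zip(uncorrected, corrected), start=1):
    print(f"t={t}: uncorrected m/sqrt(v) = {u:.3f}, corrected = {c:.3f}")
```

With correction, the ratio stays at about 1, the value a steady unit gradient should produce. Without it, the ratio starts near 3.16 and keeps growing over these early steps, because $v$ is biased towards zero more strongly than $m$.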

Q4: When would you use SGD over Adam?

Answer: Four cases favour SGD with momentum over Adam:

  1. Computer vision from scratch: when training ResNet-style CNNs for image classification, well-tuned SGD+momentum is often reported to give 0.5–1% better final accuracy than Adam
  2. Small datasets: Adam's adaptivity can overfit when there is little data, while SGD's noisier, less targeted updates act as an implicit regulariser
  3. Known good hyperparameters: if the learning rate and schedule are already tuned for the architecture, SGD's simplicity avoids Adam's suspected bias towards sharper minima
  4. Training time to spare: SGD can need several times more epochs (3–5× is a rough rule of thumb) but may end at a better solution

Python
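import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 1)  # placeholder; substitute your ResNet or transformer

# ResNet-style training recipe (ImageNet)
optimizer_vision = optim.SGD(
    model.parameters(),
    lr=0.1,          # high initial lr, typical for SGD at batch size 256
    momentum=0.9,
    weight_decay=1e-4,
    nesterov=True,   # Nesterov momentum: gradient taken at the look-ahead point
)
# Divide lr by 10 at epochs 30, 60 and 90; call scheduler_vision.step() once per epoch
scheduler_vision = optim.lr_scheduler.MultiStepLR(
    optimizer_vision, milestones=[30, 60, 90], gamma=0.1
)

# Typical transformer / NLP starting point (BERT-style; exact values vary by model)
optimizer_nlp = optim.AdamW(
    model.parameters(),
    lr=3e-4,
    betas=(0.9, 0.999),
    weight_decay=0.01,
)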
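The step-decay schedule in the vision recipe can be made concrete with a small pure-Python function. `multistep_lr` is a hypothetical helper written to follow the documented MultiStepLR rule: the base lr is multiplied by `gamma` once for every milestone reached. It assumes `scheduler.step()` is called once per epoch and epochs are counted from 0.

```python
from bisect import bisect_right

def multistep_lr(epoch: int, base_lr: float = 0.1,
                 milestones=(30, 60, 90), gamma: float = 0.1) -> float:
    """Learning rate in effect during `epoch` under step decay."""
    # bisect_right counts how many milestones are <= epoch
    passed = bisect_right(sorted(milestones), epoch)
    return base_lr * gamma ** passed

for epoch in (0, 29, 30, 59, 60, 89):
    print(epoch, multistep_lr(epoch))
```

The learning rate stays flat within each phase and drops by 10× at each milestone. This is why SGD recipes typically show sudden staircase drops in the training loss at those epochs.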

Q5: How do you tune the learning rate for a new clinical AI model?

Answer: A three-step process:

  1. Learning rate range test: train for ~100 steps while increasing the lr exponentially from 1e-7 to 10, and plot loss against lr on a log scale. The loss first falls, reaches a minimum, then diverges. Pick an lr well before the divergence point, typically 1/10 to 1/3 of the lr at which the loss was lowest.
  2. Start with heuristics: AdamW with lr=3e-4 for tabular/MLP models, 1e-4 for transformers, and 1e-5 for fine-tuning a pre-trained model. For clinical data, err on the conservative side, because noisy labels and class imbalance make training more sensitive to the lr.
  3. Use warmup + cosine decay: a warmup over the first 5–10% of steps prevents early divergence, and cosine decay lets the model settle as the lr approaches zero. Monitor training and validation loss; if it diverges or oscillates, divide the lr by 3.

Python
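import math

import torch.nn as nn
import torch.optim as optim

def build_clinical_optimiser(
    model: nn.Module,
    lr: float = 3e-4,
    total_steps: int = 5000,
    warmup_fraction: float = 0.1,
) -> tuple[optim.Optimizer, optim.lr_scheduler.LambdaLR]:
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)

    warmup_steps = max(1, int(total_steps * warmup_fraction))

    def lr_lambda(step: int) -> float:
        # Multiplier on the base lr; call scheduler.step() once per optimiser step
        if step < warmup_steps:
            return (step + 1) / warmup_steps  # linear warmup, never exactly 0
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        progress = min(progress, 1.0)          # hold at the minimum past total_steps
        return 0.5 * (1 + math.cos(math.pi * progress))  # cosine decay from 1 to 0

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    return optimizer, scheduler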
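Step 1, the range test, can be sketched on a toy problem. This is a minimal NumPy sketch under stated assumptions. Least-squares linear regression stands in for a real model, and the loss is evaluated on the full toy dataset after each step. The stop-at-4×-best-loss rule and the `lr_range_test` / `suggest_lr` names are hypothetical choices of this example, not part of any library.

```python
import numpy as np

def lr_range_test(X, y, lr_min=1e-7, lr_max=10.0, num_steps=100,
                  batch_size=32, diverge_factor=4.0, seed=0):
    """Take one mini-batch step per lr, with lr growing geometrically.

    Returns a list of (lr, loss_after_step) pairs; stops early once the loss
    exceeds diverge_factor times the best loss seen so far.
    """
    rng = np.random.default_rng(seed)
    n, d = X.shape
    w = np.zeros(d)
    history, best = [], np.inf
    for lr in np.geomspace(lr_min, lr_max, num_steps):
        b = rng.choice(n, size=batch_size, replace=False)
        Xb, yb = X[b], y[b]
        w = w - lr * (2.0 / batch_size) * (Xb.T @ (Xb @ w - yb))  # one GD step at this lr
        loss = float(np.mean((X @ w - y) ** 2))                  # loss after the step
        history.append((lr, loss))
        best = min(best, loss)
        if loss > diverge_factor * best:  # training has clearly diverged
            break
    return history

def suggest_lr(history, divisor=10.0):
    # Heuristic: one-tenth of the lr at which the recorded loss was lowest
    lr_at_min, _ = min(history, key=lambda h: h[1])
    return lr_at_min / divisor

rng = np.random.default_rng(0)
X = rng.normal(size=(512, 3))
y = X @ np.array([2.0, -1.0, 0.5]) + 0.01 * rng.normal(size=512)

history = lr_range_test(X, y)
print(f"steps run: {len(history)}, suggested lr: {suggest_lr(history):.2e}")
```

On a real model you would plot `history` rather than trust a single number. The divisor of 10 sits at the conservative end of the 1/10 to 1/3 range, which suits noisy clinical data.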

Q6: What causes training loss to spike mid-training, and how do you fix it?

Answer: Mid-training loss spikes typically have four causes:

  1. Learning rate too high for late training: use a scheduler; a decaying schedule such as cosine annealing reduces this risk by lowering the lr as training progresses
  2. Exploding gradients: apply gradient clipping, e.g. torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
  3. Bad batch: a batch with extreme outliers produces a very large gradient. With mini-batch training the spike is usually transient; if it recurs, check data preprocessing for unnormalised features
  4. BatchNorm left in eval mode: calling model.eval() inside the training loop (e.g. for a validation pass) without switching back makes BatchNorm use its frozen running statistics and disables Dropout; call model.train() before each training step

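Clipping by global norm (item 2) is simple enough to write out by hand. The NumPy sketch below is a hypothetical helper that roughly mirrors what `torch.nn.utils.clip_grad_norm_` does with its default L2 norm. It computes one norm over all parameters' gradients and, if that norm exceeds `max_norm`, rescales every gradient by the same factor so the update direction is preserved. The small `eps` term is an assumption of this sketch to avoid division by zero.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm=1.0, eps=1e-6):
    """Rescale a list of gradient arrays so their combined L2 norm is at most max_norm.

    Returns the clipped gradients and the total norm measured before clipping.
    """
    total_norm = float(np.sqrt(sum(np.sum(g ** 2) for g in grads)))
    scale = min(1.0, max_norm / (total_norm + eps))  # never scale gradients up
    return [g * scale for g in grads], total_norm

# Two "parameters" whose gradients have a combined norm of 5
grads = [np.array([3.0, 0.0]), np.array([[0.0, 4.0]])]
clipped, norm = clip_by_global_norm(grads, max_norm=1.0)
print(norm, clipped)
```

Returning the pre-clipping norm matters for monitoring: it is the number worth logging, because it shows how large the gradient was before clipping hid it.
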
Python
import torch
import torch.nn as nn

def safe_train_step(
    model: nn.Module,
    X: torch.Tensor,
    y: torch.Tensor,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    max_grad_norm: float = 1.0,
) -> dict:
    model.train()  # ensure BatchNorm and Dropout are in training mode
    optimizer.zero_grad()
    
    loss = criterion(model(X).squeeze(-1), y)  # assumes outputs of shape (N, 1) and targets of shape (N,)
    loss.backward()
    
    # Gradient clipping before step
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    
    optimizer.step()
    
    return {
        "loss": loss.item(),
        "grad_norm": grad_norm.item(),
        "clipped": grad_norm.item() > max_grad_norm,
    }

# Log grad_norm: a sudden jump alongside a loss spike points to exploding gradients as the cause
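Logging loss and grad_norm is only useful if someone notices a spike. One lightweight option is a rolling-median detector. `SpikeMonitor` is a hypothetical helper, and its window of 50 steps and factor of 3 are arbitrary illustrative defaults, not established thresholds. It flags any step whose value exceeds a multiple of the median of the preceding window.

```python
from collections import deque
from statistics import median

class SpikeMonitor:
    """Flags a value that exceeds `factor` times the median of the previous `window` values."""

    def __init__(self, window: int = 50, factor: float = 3.0):
        self.window = window
        self.factor = factor
        self.history = deque(maxlen=window)

    def update(self, value: float) -> bool:
        # Only judge once a full window of history exists; the median resists past spikes
        is_spike = (
            len(self.history) == self.window
            and value > self.factor * median(self.history)
        )
        self.history.append(value)
        return is_spike

# Synthetic loss curve with a single spike at step 60
monitor = SpikeMonitor(window=50, factor=3.0)
losses = [1.0] * 60 + [10.0] + [1.0] * 10
flags = [monitor.update(x) for x in losses]
print([i for i, f in enumerate(flags) if f])
```

In a training loop, one monitor could watch `loss` and another `grad_norm` from `safe_train_step`. If both flag the same step, exploding gradients are the likely cause. If only the loss flags, a bad batch is the more likely explanation.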

Interview Answer

"Gradient descent variants: full-batch (exact gradient but slow), stochastic, one sample at a time (very noisy, rarely used alone), and mini-batch (the standard). Adam is the default for most tasks thanks to per-weight adaptive learning rates and fast convergence; SGD+momentum remains competitive for image classification given enough training time. Key Adam details: bias correction matters in the early steps, and you should use AdamW (decoupled weight decay) rather than Adam when regularising. The learning rate is usually the most critical hyperparameter: use the range test, start around 3e-4 for Adam/AdamW, and always pair it with a scheduler. For transformers, linear warmup + cosine decay is a standard choice. In clinical AI, conservative learning rates, gradient clipping (norm ≤ 1.0), and monitoring validation loss for spikes are non-negotiable."
