Optimisers: Interview Q&A
Six key interview questions on gradient descent, Adam, SGD, learning rate scheduling, and choosing optimisers for clinical AI systems.
Q1: What is gradient descent and how does it work?
Answer: Gradient descent minimises a loss function by iteratively moving the weights in the direction opposite to the gradient. The update rule is W ← W − α·∇L(W), where α is the learning rate. The gradient points in the direction of steepest increase in loss, so moving against it descends the loss surface. In practice, mini-batch gradient descent is used: the gradient is estimated from 32–512 samples rather than the full dataset, which balances gradient noise (a mild regulariser) against efficiency (GPU parallelism).
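To make the update rule concrete, here is a minimal NumPy sketch of full-batch gradient descent on a linear least-squares problem. The function name `gradient_descent`, the synthetic data, and the values of `lr` and `steps` are assumptions chosen for illustration, not part of the original answer.

```python
import numpy as np

def gradient_descent(X, y, lr=0.1, steps=200):
    """Minimise the mean squared error of a linear model with plain gradient descent."""
    w = np.zeros(X.shape[1])
    for _ in range(steps):
        residual = X @ w - y                    # prediction error
        grad = 2.0 * X.T @ residual / len(y)    # gradient of the MSE with respect to w
        w = w - lr * grad                       # W <- W - alpha * grad L(W)
    return w

rng = np.random.default_rng(0)
X = rng.normal(size=(256, 3))        # synthetic features (illustrative)
true_w = np.array([2.0, -1.0, 0.5])
y = X @ true_w                       # noiseless targets, so the minimiser is true_w
w_hat = gradient_descent(X, y)
print(np.round(w_hat, 3))
```

Swapping the full `X, y` for a random subset of rows on each iteration would turn this into the mini-batch variant described above.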
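import torch
import torch.nn as nn

# Illustrative setup (assumed for this snippet): a tiny linear model and random data
model = nn.Linear(10, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
X, y = torch.randn(32, 10), torch.randn(32, 1)

# The four lines that implement gradient descent in PyTorch
optimizer.zero_grad()          # clear accumulated gradients
loss = criterion(model(X), y)  # forward pass + loss
loss.backward()                # backward pass: compute dL/dW for all params
optimizer.step()               # W ← W - lr * grad

Q2: What is the difference between SGD and Adam?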
Answer: SGD updates all weights with the same global learning rate (with optional momentum to smooth updates). Adam maintains per-weight adaptive learning rates: it tracks a running estimate of the first moment (the gradient mean) and the second moment (the mean of squared gradients, i.e. the uncentred variance) of each parameter. Weights that consistently receive large gradients get automatically reduced effective learning rates; rarely updated weights (such as embeddings) get larger effective rates. This usually makes Adam converge faster with less tuning. The trade-off: SGD with momentum sometimes reaches slightly better solutions given enough time, which is why ResNet papers train with SGD, while most NLP and transformer work uses AdamW.
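The per-weight scaling can be seen in a small NumPy sketch of the Adam update. The helper `adam_update` and the two gradient values are hypothetical, chosen only to contrast a large-gradient weight with a small-gradient one; this is not PyTorch's implementation.

```python
import numpy as np

def adam_update(grad, m, v, t, lr=0.01, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update; returns the step to subtract and the new moment estimates."""
    m = beta1 * m + (1 - beta1) * grad        # first moment: running mean of gradients
    v = beta2 * v + (1 - beta2) * grad**2     # second moment: running mean of squared gradients
    m_hat = m / (1 - beta1**t)                # bias-corrected estimates
    v_hat = v / (1 - beta2**t)
    step = lr * m_hat / (np.sqrt(v_hat) + eps)
    return step, m, v

grad = np.array([100.0, 0.001])   # two weights whose gradients differ by five orders of magnitude
m, v = np.zeros(2), np.zeros(2)
for t in range(1, 6):             # five steps with a constant gradient
    adam_step, m, v = adam_update(grad, m, v, t)

sgd_step = 0.01 * grad            # plain SGD with the same lr
print("SGD step :", sgd_step)
print("Adam step:", adam_step)
```

With plain SGD the two steps differ by the same five orders of magnitude as the gradients, while both Adam steps come out close to `lr`, because each gradient is divided by its own running magnitude.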
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 1)  # placeholder model, assumed for illustration

# SGD: one lr for all weights
optimizer_sgd = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
# AdamW: per-weight adaptive lr + decoupled weight decay
optimizer_adamw = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
# Adam (less correct): weight decay is added to the gradient as an L2 term,
# so it gets rescaled by the adaptive lr
optimizer_adam = optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-4)
# Prefer AdamW over Adam when using weight decay

Q3: Why does Adam use bias correction?
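Answer: Adam initialises the moment estimates at m₀ = 0 and v₀ = 0. At step t = 1, the first moment estimate is m₁ = (1 − β₁)·g₁. With β₁ = 0.9, m₁ = 0.1·g₁, far smaller than the true gradient. The second moment is shrunk even more: with β₂ = 0.999, v₁ = 0.001·g₁². Without correction, both estimates are biased toward zero and the early steps m/√v are mis-scaled (about 3× too large at t = 1 with these defaults) rather than sitting on the scale set by the learning rate. Bias correction divides m by (1 − β₁^t) and v by (1 − β₂^t); at t = 1 the first factor equals 0.1, restoring the estimate to g₁. After many steps, β^t → 0, so the correction becomes negligible.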
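The first-moment estimate is only half of the update, since the step also divides by the square root of the second moment, which starts at zero as well. The sketch below works through the very first step for a single scalar gradient, with and without correction. The gradient value `g` and the learning rate are arbitrary assumptions.

```python
import numpy as np

beta1, beta2, lr, eps = 0.9, 0.999, 1e-3, 1e-8
g = 0.5                      # a single scalar gradient at t = 1 (arbitrary value)

m = (1 - beta1) * g          # m_1, starting from m_0 = 0
v = (1 - beta2) * g**2       # v_1, starting from v_0 = 0

step_raw = lr * m / (np.sqrt(v) + eps)                # no bias correction
m_hat = m / (1 - beta1**1)
v_hat = v / (1 - beta2**1)
step_corrected = lr * m_hat / (np.sqrt(v_hat) + eps)  # with bias correction

ratio = step_raw / step_corrected
print(f"raw={step_raw:.6f} corrected={step_corrected:.6f} ratio={ratio:.2f}")
```

With correction the first step has magnitude close to `lr`. Without it, the ratio is (1 − β₁)/√(1 − β₂), so the net effect of the two biases depends on how β₁ and β₂ are chosen.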
import numpy as np

beta1, beta2 = 0.9, 0.999

# Simulate first few steps with and without bias correction
g = np.array([1.0, -0.5, 0.8])  # constant gradient for illustration
m, v = np.zeros(3), np.zeros(3)

for t in range(1, 6):
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g**2
    m_hat = m / (1 - beta1**t)  # bias corrected
    v_hat = v / (1 - beta2**t)
    print(f"t={t}: m[0]={m[0]:.4f}, m_hat[0]={m_hat[0]:.4f}, correction_factor={1/(1-beta1**t):.2f}")

Q4: When would you use SGD over Adam?
Answer: Four cases favour SGD with momentum over Adam:
- Computer vision from scratch: when training ResNet/EfficientNet, SGD+momentum often gives 0.5–1% better final accuracy
- Small datasets: Adam's adaptivity can overfit when there is little data; SGD's larger, less-targeted updates regularise implicitly
- Known good hyperparameters: if the learning rate and schedule are already tuned, SGD's simplicity avoids Adam's potential bias toward sharp minima
- Training time to spare: SGD often needs 3–5× more epochs but can find a better solution
import torch.optim as optim

# Both recipes assume `model` is an existing nn.Module

# ResNet training recipe (ImageNet)
optimizer_vision = optim.SGD(
    model.parameters(),
    lr=0.1,              # high initial lr for SGD
    momentum=0.9,
    weight_decay=1e-4,
    nesterov=True,       # Nesterov momentum: look-ahead gradient
)
scheduler_vision = optim.lr_scheduler.MultiStepLR(
    optimizer_vision, milestones=[30, 60, 90], gamma=0.1
)

# Transformer / NLP recipe (BERT, GPT)
optimizer_nlp = optim.AdamW(
    model.parameters(),
    lr=3e-4,
    betas=(0.9, 0.999),
    weight_decay=0.01,
)

Q5: How do you tune the learning rate for a new clinical AI model?
Answer: A three-step process:
- Learning rate range test: train for ~100 steps while exponentially increasing lr from 1e-7 to 10. Plot loss vs lr. Pick an lr safely below the point where the loss diverges, typically 1/10 to 1/3 of the minimum-loss lr.
- Start with heuristics: AdamW with lr=3e-4 for tabular/MLP, 1e-4 for transformers, 1e-5 for fine-tuning a pre-trained model. For clinical data, err on the side of a conservative lr, because noisy labels and class imbalance amplify sensitivity to the learning rate.
- Use warmup + cosine decay: warmup over the first 5–10% of steps prevents early divergence; cosine decay shrinks the step size toward the end so the model settles into a minimum. Monitor validation loss; if it diverges or oscillates, divide lr by 3.
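The range test in the first step can be sketched as follows. This is a simplified illustration under stated assumptions: the helper `lr_range_test` is hypothetical, the data is synthetic, and a single fixed batch is reused on every step, whereas a real run would draw a fresh mini-batch each time. The sweep also modifies the model in place, so in practice you would reinitialise or reload the weights afterwards.

```python
import math
import torch
import torch.nn as nn

def lr_range_test(model, criterion, X, y, lr_start=1e-7, lr_end=10.0, num_steps=100):
    """Exponentially increase lr each step and record (lr, loss) pairs."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr_start)
    growth = (lr_end / lr_start) ** (1.0 / (num_steps - 1))  # per-step multiplier
    history = []
    lr = lr_start
    for _ in range(num_steps):
        for group in optimizer.param_groups:
            group["lr"] = lr                 # set the lr for this step
        optimizer.zero_grad()
        loss = criterion(model(X), y)
        if not math.isfinite(loss.item()):
            break                            # loss diverged: stop the sweep
        loss.backward()
        optimizer.step()
        history.append((lr, loss.item()))
        lr *= growth
    return history

torch.manual_seed(0)
X = torch.randn(256, 8)                      # synthetic stand-in for a real batch
y = X @ torch.randn(8, 1)
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
history = lr_range_test(model, nn.MSELoss(), X, y)

best_lr, _ = min(history, key=lambda pair: pair[1])  # lr at the lowest recorded loss
suggested_lr = best_lr / 10                          # conservative end of the 1/10 to 1/3 rule
print(f"min-loss lr={best_lr:.2e}, suggested lr={suggested_lr:.2e}")
```

Plotting `history` on a log-scaled lr axis is usually more informative than the single minimum, since the raw curve is noisy.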
import math
import torch.optim as optim

def build_clinical_optimiser(
    model,
    lr: float = 3e-4,
    total_steps: int = 5000,
    warmup_fraction: float = 0.1,
) -> tuple:
    """Return (optimizer, scheduler): AdamW with linear warmup then cosine decay."""
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    warmup_steps = int(total_steps * warmup_fraction)

    def lr_lambda(step: int) -> float:
        if step < warmup_steps:
            return step / warmup_steps  # linear warmup from 0 to 1
        # Clamp progress so the lr stays at 0 instead of rising again past total_steps
        progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
        return 0.5 * (1 + math.cos(math.pi * progress))  # cosine decay from 1 to 0

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    return optimizer, scheduler

Q6: What causes training loss to spike mid-training, and how do you fix it?
Answer: Mid-training loss spikes typically have four causes:
- Learning rate too high for late training → Use a scheduler; a decaying schedule such as cosine annealing lowers the lr as training progresses
- Exploding gradients → Apply gradient clipping: torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
- Bad batch → A batch with extreme outliers causes a large gradient. With mini-batch SGD this is transient; if persistent, check data preprocessing for unnormalised features
- Batch norm in eval during training → Calling model.eval() inside the training loop freezes BatchNorm statistics; ensure model.train() is called before training resumes
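The BatchNorm point is easy to verify directly. The sketch below, using a standalone `nn.BatchNorm1d` layer and a synthetic batch (both assumptions for illustration), checks that the running statistics stay fixed in eval mode and change in train mode.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)
x = torch.randn(64, 4) * 5.0 + 3.0   # synthetic batch with mean ~3, std ~5

bn.eval()                             # evaluation mode: running statistics are frozen
before = bn.running_mean.clone()
bn(x)
frozen = torch.equal(bn.running_mean, before)

bn.train()                            # training mode: running statistics are updated
bn(x)
updated = not torch.equal(bn.running_mean, before)

print(f"frozen in eval: {frozen}, updated in train: {updated}")
```

A common way to hit this in practice is calling `model.eval()` for a validation pass and forgetting to switch back with `model.train()` before the next training step.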
import torch
import torch.nn as nn

def safe_train_step(
    model: nn.Module,
    X: torch.Tensor,
    y: torch.Tensor,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    max_grad_norm: float = 1.0,
) -> dict:
    model.train()  # ensure BatchNorm and Dropout are in training mode
    optimizer.zero_grad()
    # squeeze(-1) drops only the trailing output dimension, so a batch of size 1 keeps its batch axis
    loss = criterion(model(X).squeeze(-1), y)
    loss.backward()
    # Gradient clipping before step; returns the total norm measured before clipping
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    return {
        "loss": loss.item(),
        "grad_norm": grad_norm.item(),
        "clipped": grad_norm.item() > max_grad_norm,
    }

# Log grad_norm: a sudden jump usually accompanies a loss spike and helps locate its cause

Interview Answer
"Gradient descent variants: full batch (exact but slow), single-sample SGD (noisy, rarely used alone), mini-batch (standard). Adam wins for most tasks due to per-weight adaptive learning rates and fast convergence; SGD+momentum is competitive for image classification given enough training time. Key Adam details: bias correction matters in early steps; use AdamW (decoupled weight decay), not Adam, when regularising. Learning rate is the most critical hyperparameter: use the range test, start with 3e-4 for Adam/AdamW, and always pair with a scheduler. For transformers, linear warmup + cosine decay is the standard. In clinical AI, conservative learning rates, gradient clipping (norm ≤ 1.0), and monitoring validation loss for spikes are non-negotiable."