Deep Learning for AI Interviews · Lesson 20 of 56
SGD vs Mini-Batch vs Full-Batch Gradient Descent
Three Variants
Batch Gradient Descent (Full GD):
Compute gradient over ALL training samples.
One weight update per epoch.
Exact gradient — deterministic, stable.
Impractical for large datasets: every update needs a full pass over the data, which often can't fit in memory at once.
Stochastic Gradient Descent (SGD, batch_size=1):
Compute gradient from ONE random sample.
Update weights after every sample.
Very noisy gradient — high variance updates.
Fast per update but erratic path; rarely used in practice.
Mini-batch SGD (the standard):
Compute gradient from a small batch (32–256 samples).
Update weights after every batch.
Best of both worlds: GPU parallelism + regularising noise.
In PyTorch, torch.optim.SGD performs mini-batch SGD whenever batch_size > 1; the optimizer only sees the gradient, never the batch size.
Side-by-Side Comparison
import torch
import torch.nn as nn
import numpy as np

torch.manual_seed(0)  # reproducible data and initial weights

def make_dataset(n: int = 2000):
    X = torch.randn(n, 10)
    w_true = torch.randn(10)
    # Binary labels in {0., 1.} from a noisy linear decision rule
    y = (X @ w_true + 0.1 * torch.randn(n)).sign().clamp(min=0)
    return X, y

X, y = make_dataset(2000)

def train_variant(
    X: torch.Tensor,
    y: torch.Tensor,
    batch_size: int,
    lr: float = 0.01,
    n_epochs: int = 20,
    name: str = "",
) -> list[float]:
    model = nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()
    epoch_losses = []
    for epoch in range(n_epochs):
        # Shuffle
        perm = torch.randperm(len(X))
        X_shuf, y_shuf = X[perm], y[perm]
        batch_losses = []
        for i in range(0, len(X), batch_size):
            X_batch = X_shuf[i:i + batch_size]
            y_batch = y_shuf[i:i + batch_size]
            optimizer.zero_grad()
            # squeeze(-1) keeps the batch dimension, so batch_size=1 also works
            loss = criterion(model(X_batch).squeeze(-1), y_batch)
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())
        epoch_losses.append(float(np.mean(batch_losses)))
    return epoch_losses

# Batch GD: use full dataset as one batch
full_losses = train_variant(X, y, batch_size=2000, name="Batch GD")
# Mini-batch: standard batch size
minib_losses = train_variant(X, y, batch_size=64, name="Mini-batch (64)")
# Stochastic: one sample at a time
stoch_losses = train_variant(X, y, batch_size=1, name="SGD (1)")

for name, losses in [("Batch GD", full_losses), ("Mini-batch", minib_losses), ("SGD-1", stoch_losses)]:
    print(f"{name:12s}: final loss = {losses[-1]:.4f}, std across epochs = {np.std(losses):.4f}")
The Noise That Helps
SGD noise acts as implicit regularisation:
1. Escapes sharp local minima
Exact gradient descent can get stuck in a sharp valley.
Noisy updates can "bounce" out toward flatter minima.
2. Flat minima generalise better
Sharp minimum: tiny weight perturbation → large loss spike
Flat minimum: robust to perturbations → robust to distribution shift
3. Saddle point escape
Pure GD can stall at saddle points (gradient ≈ 0).
Noise breaks symmetry and escapes.
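Point 2 can be made concrete with a toy calculation. The sketch below assumes a one-dimensional quadratic loss, 0.5 * curvature * w**2, around each minimum. The function name loss_increase and the curvature values 100 and 1 are illustrative choices, not taken from the lesson.

```python
def loss_increase(curvature: float, delta: float) -> float:
    """Loss rise when moving delta away from the minimum of 0.5 * curvature * w**2."""
    return 0.5 * curvature * delta ** 2

delta = 0.1  # the same weight perturbation is applied at both minima (assumed value)
sharp = loss_increase(curvature=100.0, delta=delta)  # high curvature = sharp minimum
flat = loss_increase(curvature=1.0, delta=delta)     # low curvature = flat minimum

print(f"sharp minimum: loss rises by {sharp:.4f}")
print(f"flat minimum:  loss rises by {flat:.4f}")
```

Under this quadratic approximation the loss increase is proportional to the curvature, so the same small shift in the weights costs far more at the sharp minimum.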
Controlled by: batch_size (smaller = more noise)
batch_size=32: noisier gradients, tends toward flatter minima, often generalises better
batch_size=512: cleaner gradients, better hardware throughput, can generalise worse unless lr is retuned
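The link between batch size and gradient noise can be checked empirically. The sketch below is a rough experiment under assumed settings: a synthetic logistic-regression problem, weights held fixed at zero, and 200 random batches per batch size. The helper batch_grad is hypothetical. It measures how far each mini-batch gradient lands from the full-data gradient.

```python
import torch

torch.manual_seed(0)
n, d = 2000, 10
X = torch.randn(n, d)
y = (X @ torch.randn(d) > 0).float()
w = torch.zeros(d)  # fixed point at which every gradient is evaluated

def batch_grad(idx: torch.Tensor) -> torch.Tensor:
    """Gradient of the mean logistic loss over the samples in idx, at weights w."""
    p = torch.sigmoid(X[idx] @ w)
    return X[idx].T @ (p - y[idx]) / len(idx)

full_grad = batch_grad(torch.arange(n))

noise = {}
for batch_size in [1, 32, 512]:
    errs = []
    for _ in range(200):
        idx = torch.randperm(n)[:batch_size]  # random batch without replacement
        errs.append((batch_grad(idx) - full_grad).norm().item())
    noise[batch_size] = sum(errs) / len(errs)
    print(f"batch_size={batch_size:4d}: mean distance to full gradient = {noise[batch_size]:.3f}")
```

The distance should shrink as the batch grows, roughly in proportion to one over the square root of the batch size when the batch is small relative to the dataset.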
Rule of thumb: scale lr linearly with batch size (the linear scaling rule). It is a heuristic and tends to break down at very large batch sizes.
from torch.utils.data import DataLoader, TensorDataset
def linear_scaling_rule(base_lr: float, base_batch: int, new_batch: int) -> float:
    """Scale learning rate proportionally when changing batch size."""
    return base_lr * (new_batch / base_batch)

# Base: lr=0.01 at batch_size=64
base_lr, base_batch = 0.01, 64
for batch_size in [32, 64, 128, 256, 512]:
    lr = linear_scaling_rule(base_lr, base_batch, batch_size)
    print(f"batch_size={batch_size:4d}: lr={lr:.5f}")
PyTorch DataLoader for Mini-Batch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(X, y)

# Standard mini-batch setup
train_loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,    # shuffle every epoch
    drop_last=True,  # discard last partial batch (helps with BatchNorm)
    num_workers=0,   # 0 for in-process (safe on Windows)
)

def train_epoch(
    model: nn.Module,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
) -> float:
    model.train()
    total_loss = 0.0
    for X_batch, y_batch in loader:
        optimizer.zero_grad()
        # squeeze(-1) drops only the output dimension, so a batch of size 1 keeps its shape
        loss = criterion(model(X_batch).squeeze(-1), y_batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)  # mean loss per batch

model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.BCEWithLogitsLoss()

for epoch in range(5):
    avg_loss = train_epoch(model, train_loader, optimizer, criterion)
    print(f"Epoch {epoch+1}: loss = {avg_loss:.4f}")
Batch Size Trade-offs
Batch size | Gradient quality | GPU efficiency | Generalisation | Memory
-----------|-----------------|----------------|----------------|-------
1 | Very noisy | Poor | Can be good | Minimal
32 | Moderate noise | Moderate | Good | Low
128 | Low noise | Good | Moderate | Medium
512 | Very clean | Excellent | May overfit | High
Full data | Exact | Varies | Risk of sharp minima | Very high
Practical defaults:
Computer vision: 32–256 (images are large)
NLP/tabular: 32–512
LLM fine-tuning: 4–32 (model is huge; effective batch via gradient accumulation)
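The "effective batch via gradient accumulation" idea relies on one fact: with a mean-reduced loss, summing the gradients of k equal-sized micro-batches, each scaled by 1/k, reproduces the gradient of the combined batch. The sketch below checks this numerically under assumed settings (random data, a single linear layer, 8 micro-batches of 64). The variable names are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
X = torch.randn(512, 10)
y = (torch.randn(512) > 0).float()
model = nn.Linear(10, 1)
criterion = nn.BCEWithLogitsLoss()  # default reduction="mean"

# Reference: gradient from one large batch of 512 samples.
model.zero_grad()
criterion(model(X).squeeze(-1), y).backward()
large_batch_grad = model.weight.grad.clone()

# Accumulated: 8 micro-batches of 64, each loss divided by the number of steps.
accumulation_steps = 8
model.zero_grad()
for X_micro, y_micro in zip(X.chunk(accumulation_steps), y.chunk(accumulation_steps)):
    loss = criterion(model(X_micro).squeeze(-1), y_micro) / accumulation_steps
    loss.backward()  # .grad tensors sum across backward() calls
accumulated_grad = model.weight.grad.clone()

max_diff = (large_batch_grad - accumulated_grad).abs().max().item()
print(f"max abs difference: {max_diff:.2e}")
```

The two gradients should agree up to floating-point error. The match is exact in principle only when the micro-batches have equal size and the model has no batch-dependent layers such as BatchNorm, whose statistics are computed per micro-batch.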
Gradient accumulation to simulate large batch with limited GPU memory:
accumulation_steps = 8  # simulate batch_size = 64 × 8 = 512
optimizer.zero_grad()
for i, (X_batch, y_batch) in enumerate(train_loader):
    loss = criterion(model(X_batch).squeeze(-1), y_batch)
    loss = loss / accumulation_steps  # scale so the summed gradients average correctly
    loss.backward()  # gradients add up in .grad until zero_grad()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
# Note: if len(train_loader) is not a multiple of accumulation_steps,
# the gradients from the trailing batches are never applied.
Interview Answer
"Three gradient descent variants: (1) full-batch GD, which gives the exact gradient but is impractical for large datasets and prone to sharp minima; (2) SGD (batch_size=1), which makes very noisy updates and is rarely used; (3) mini-batch SGD, which computes the gradient on 32–512 samples and is the standard in deep learning. Mini-batch wins because it uses GPU parallelism efficiently and its noise implicitly regularises by favouring flat minima over sharp ones. The linear scaling rule says: when you double the batch size, double the learning rate to keep the training dynamics roughly equivalent. For very large models with limited GPU memory, gradient accumulation simulates large batches by accumulating gradients over N forward/backward passes before stepping. The noise in mini-batch SGD helps escape saddle points and sharp local minima; it's a feature, not a bug."