GPU Training in Practice
Setting up GPU training in PyTorch, multi-GPU strategies, monitoring GPU utilisation, and common pitfalls.
GPU Training Boilerplate
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.cuda.amp import autocast, GradScaler
# Device setup: use the GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Training on: {device}")
if device.type == "cuda":
    # Report the name and total memory (decimal GB) of GPU index 0.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Move model to the selected device.
# NOTE: YourModel is a placeholder for your own nn.Module subclass.
model = YourModel().to(device)

# Mixed precision training (up to ~2x speedup and memory savings on supported GPUs).
# The scaler multiplies the loss so small fp16 gradients do not underflow; it is
# only enabled on CUDA so the CPU fallback runs without a "CUDA is not available"
# warning (a disabled scaler makes scale/unscale_/step/update plain pass-throughs).
scaler = GradScaler(enabled=device.type == "cuda")
optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
def train_step(model, batch, optimizer, scaler, criterion):
    """Run one mixed-precision optimisation step and return the loss as a float.

    Args:
        model: Module called as ``model(inputs)``.
        batch: Pair ``(inputs, targets)``; both are moved to the module-level
            ``device`` before the forward pass.
        optimizer: Optimizer whose step is driven through ``scaler``.
        scaler: ``GradScaler`` used to scale the loss and step the optimizer.
        criterion: Callable invoked as ``criterion(outputs, targets)``.

    Returns:
        The unscaled loss value for this batch (``loss.item()``).
    """
    inputs, targets = batch
    # non_blocking=True allows an asynchronous host-to-device copy.
    inputs = inputs.to(device, non_blocking=True)
    targets = targets.to(device, non_blocking=True)

    optimizer.zero_grad()
    with autocast():  # forward pass and loss under automatic mixed precision
        outputs = model(inputs)
        loss = criterion(outputs, targets)

    scaler.scale(loss).backward()  # backward on the scaled loss
    # Gradients must be unscaled before clipping so max_norm applies to their
    # true magnitude rather than the scaled one.
    scaler.unscale_(optimizer)
    nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)  # optimizer step, driven through the scaler
    scaler.update()  # adjust the scale factor for the next iteration
    return loss.item()


# Multi-GPU: DataParallel (Simple)
# DataParallel: split each batch across GPUs, gather results on GPU 0.
# Simple to use but inefficient (GPU 0 is the bottleneck).
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs")
    model = nn.DataParallel(model)
# Moving the (possibly wrapped) model is done unconditionally; it is a no-op
# when the model is already on `device`.
model = model.to(device)
# Usage: identical to single GPU — DataParallel handles splitting/gathering
Multi-GPU: DistributedDataParallel (Production)
# DDP: each GPU has its own process, gradients are all-reduced
# Much more efficient than DataParallel
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DistributedSampler
def setup_ddp(rank: int, world_size: int) -> None:
    """Join the NCCL process group and select GPU ``rank`` for this process.

    Args:
        rank: Index of this process; also used as the CUDA device index, so
            this assumes one process per GPU on a single node.
        world_size: Total number of participating processes.
    """
    import os  # local import: only needed for the rendezvous defaults below

    # The default env:// rendezvous needs MASTER_ADDR / MASTER_PORT. torchrun
    # exports them; torch.multiprocessing.spawn does not, so provide
    # single-node defaults without overriding values that are already set.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")

    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
def cleanup_ddp() -> None:
    """Tear down the default process group if one is initialised.

    Safe to call when no group exists (e.g. setup failed or cleanup already
    ran), so it can be used unconditionally in a ``finally`` block.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
def train_ddp(
    rank: int,
    world_size: int,
    dataset,
    model_class,
    batch_size: int = 32,
    epochs: int = 10,
    lr: float = 1e-3,
):
    """Per-process DDP training entry point (one process per GPU).

    Args:
        rank: Index of this process; used as the CUDA device index.
        world_size: Total number of processes / replicas.
        dataset: Dataset passed to ``DistributedSampler`` and ``DataLoader``.
        model_class: Zero-argument callable that builds the model.
        batch_size: Per-process batch size.
        epochs: Number of passes over this process's shard of the data.
        lr: AdamW learning rate.
    """
    setup_ddp(rank, world_size)
    try:
        device = torch.device(f"cuda:{rank}")
        model = model_class().to(device)
        model = DDP(model, device_ids=[rank])

        # DistributedSampler ensures each GPU gets different data.
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
        dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, sampler=sampler, pin_memory=True
        )
        optimizer = AdamW(model.parameters(), lr=lr)

        for epoch in range(epochs):
            sampler.set_epoch(epoch)  # ensures different shuffling each epoch
            for batch in dataloader:
                # standard training step goes here (placeholder)
                pass
    finally:
        # Always release the process group, even if training raised.
        cleanup_ddp()
# Launch: torchrun --nproc_per_node=4 train.py
# or: torch.multiprocessing.spawn(train_ddp, args=(world_size, dataset, Model), nprocs=world_size)
GPU Memory Debugging
# Track memory usage
def print_gpu_memory(label: str = "", device=None) -> None:
    """Print allocated and reserved CUDA memory in decimal GB.

    Does nothing when CUDA is unavailable.

    Args:
        label: Prefix for the printed line (e.g. "after forward").
        device: Device to query, forwarded to ``torch.cuda.memory_allocated``
            and ``torch.cuda.memory_reserved``; ``None`` means the current
            device.
    """
    if not torch.cuda.is_available():
        return
    allocated = torch.cuda.memory_allocated(device) / 1e9
    reserved = torch.cuda.memory_reserved(device) / 1e9
    print(f"{label}: allocated={allocated:.2f}GB, reserved={reserved:.2f}GB")
# Memory profiler: record CUDA activity and memory usage for one forward pass
# (`model` and `inputs` are expected to exist already), then print the ten
# entries with the highest self CUDA memory usage.
with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CUDA],
    profile_memory=True,
) as prof:
    output = model(inputs)
print(prof.key_averages().table(sort_by="self_cuda_memory_usage", row_limit=10))
# Common OOM fixes
# 1. Reduce batch size
# 2. Enable gradient checkpointing (recomputes activations instead of storing them).
#    NOTE: gradient_checkpointing_enable() is a Hugging Face Transformers model
#    method; for a plain nn.Module use torch.utils.checkpoint instead.
model.gradient_checkpointing_enable()
# 3. Use mixed precision (fp16 activations take roughly half the memory)
with autocast():
    output = model(inputs)
# 4. Clear the caching allocator between folds/experiments
torch.cuda.empty_cache()
# 5. Use smaller model or LoRA fine-tuning
DataLoader Optimisation
from torch.utils.data import DataLoader
# DataLoader settings tuned for GPU training, collected in one place so they
# are easy to tweak before the loader is built.
loader_options = {
    "batch_size": 64,
    "shuffle": True,
    "num_workers": 4,            # parallel CPU worker processes for loading
    "pin_memory": True,          # page-locked host memory for faster CPU→GPU copies
    "prefetch_factor": 2,        # batches prefetched per worker
    "persistent_workers": True,  # keep workers alive between epochs
    "drop_last": True,           # skip the final incomplete batch (helps with BatchNorm)
}
dataloader = DataLoader(dataset, **loader_options)
# Profiling data loading vs compute
import time

data_times = []
compute_times = []
on_cuda = device.type == "cuda"
batch_iter = iter(dataloader)
while True:
    t0 = time.perf_counter()
    # Time the wait on the DataLoader itself: this is where a starved GPU
    # shows up, so the fetch must happen inside the timed region.
    try:
        batch = next(batch_iter)
    except StopIteration:
        break
    inputs = batch[0].to(device, non_blocking=True)
    targets = batch[1].to(device, non_blocking=True)
    if on_cuda:
        # CUDA work is asynchronous; wait for the copies before reading the clock.
        torch.cuda.synchronize()
    t1 = time.perf_counter()

    model.zero_grad(set_to_none=True)  # do not accumulate grads across timed steps
    with autocast():
        output = model(inputs)
        loss = criterion(output, targets)
    loss.backward()
    if on_cuda:
        # Wait for the queued kernels so the measurement covers real compute time.
        torch.cuda.synchronize()
    t2 = time.perf_counter()

    data_times.append(t1 - t0)
    compute_times.append(t2 - t1)

if data_times:  # an empty dataloader would otherwise divide by zero
    print(f"Avg data time: {sum(data_times)/len(data_times)*1000:.1f}ms")
    print(f"Avg compute time: {sum(compute_times)/len(compute_times)*1000:.1f}ms")
# If data_time > compute_time: GPU is starved — increase num_workers
Monitoring GPU Utilisation
# Command line monitoring
nvidia-smi # one-time snapshot
nvidia-smi -l 1 # refresh every second
nvidia-smi dmon # continuous per-GPU stats
# What to look for:
# GPU util: should be 85–99% (if < 50%, data loading is the bottleneck)
# Memory used: should be close to max (if not, increase batch size)
# Temperature: should be under 85°C
# Python monitoring (during training)
import subprocess
import re
def get_gpu_stats(gpu_index: int = 0) -> dict:
    """Query ``nvidia-smi`` and return utilisation stats for one GPU.

    Args:
        gpu_index: Zero-based row of the ``nvidia-smi`` output to report
            (one row is printed per GPU). Defaults to the first GPU.

    Returns:
        Dict with ``gpu_util_pct`` (int), ``mem_used_gb`` and ``mem_total_gb``
        (floats, reported value divided by 1024) and ``temperature_c`` (int).

    Raises:
        ValueError: If ``nvidia-smi`` produced no row for ``gpu_index`` or the
            row does not contain exactly four parseable fields.
        FileNotFoundError: If the ``nvidia-smi`` executable is not on PATH.
    """
    result = subprocess.run(
        [
            "nvidia-smi",
            "--query-gpu=utilization.gpu,memory.used,memory.total,temperature.gpu",
            "--format=csv,noheader,nounits",
        ],
        capture_output=True,
        text=True,
    )
    # One CSV row per GPU; splitting the whole output on ", " breaks as soon
    # as more than one GPU is present, so parse row by row.
    rows = [row for row in result.stdout.splitlines() if row.strip()]
    if not 0 <= gpu_index < len(rows):
        raise ValueError(
            f"nvidia-smi returned {len(rows)} GPU row(s); no row for index "
            f"{gpu_index} (stderr: {result.stderr.strip()!r})"
        )
    fields = [field.strip() for field in rows[gpu_index].split(",")]
    if len(fields) != 4:
        raise ValueError(f"unexpected nvidia-smi output: {rows[gpu_index]!r}")
    util, mem_used, mem_total, temp = fields
    return {
        "gpu_util_pct": int(util),
        "mem_used_gb": float(mem_used) / 1024,
        "mem_total_gb": float(mem_total) / 1024,
        "temperature_c": int(temp),
    }


# Interview Answer
"GPU training in PyTorch: move the model and data to the GPU via .to(device); use mixed precision training (autocast + GradScaler) for up to ~2× speedup and ~2× memory savings; apply gradient clipping (clip_grad_norm_ with max_norm=1.0) to prevent exploding gradients. For multi-GPU training, DataParallel is simple but inefficient (GPU 0 is the bottleneck); DistributedDataParallel (DDP) gives near-linear scaling — each GPU runs its own process and gradients are all-reduced. For GPU utilisation, target 85%+ — below 50% usually means data loading is the bottleneck, fixed by increasing DataLoader num_workers and using pin_memory=True. OOM errors: reduce batch size, enable gradient checkpointing, use mixed precision, or use LoRA for fine-tuning."
Found this helpful?
Leave a comment
Have a question, correction, or just found this helpful? Leave a note below.