Learnixo
Back to blog
AI Systems · Intermediate

GPU Training in Practice

Setting up GPU training in PyTorch, multi-GPU strategies, monitoring GPU utilisation, and common pitfalls.

Asma Hafeez Khan · May 22, 2026 · 4 min read
Deep Learning · GPU · Training · CUDA · Multi-GPU · Interview
Share:𝕏

GPU Training Boilerplate

Python
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.cuda.amp import autocast, GradScaler

# Device setup: use the default CUDA GPU when one is visible, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Training on: {device}")
if device.type == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Move model to GPU (YourModel stands for your own nn.Module subclass)
model = YourModel().to(device)

# Mixed precision training (often up to ~2× speedup and ~50% less activation memory
# on GPUs with Tensor Cores). GradScaler scales the loss to prevent fp16 gradient
# underflow; fp16 loss scaling only applies to CUDA, so enable it only on a GPU.
scaler = GradScaler(enabled=device.type == "cuda")

optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

def train_step(model, batch, optimizer, scaler, criterion, device=None, max_norm=1.0):
    """Run one mixed-precision optimisation step and return the loss as a float.

    Args:
        model: the network being trained.
        batch: an ``(inputs, targets)`` pair of tensors.
        optimizer: optimizer over ``model``'s parameters.
        scaler: a ``GradScaler``; pass ``GradScaler(enabled=False)`` to train
            without loss scaling.
        criterion: loss function called as ``criterion(outputs, targets)``.
        device: where to move the batch (``torch.device`` or string). Defaults
            to the device of the model's first parameter, or CPU if it has none.
        max_norm: maximum total gradient norm used for clipping.

    Returns:
        The unscaled loss of this step as a Python float.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    elif not isinstance(device, torch.device):
        device = torch.device(device)

    inputs, targets = batch
    inputs = inputs.to(device, non_blocking=True)    # async copy when the source is pinned
    targets = targets.to(device, non_blocking=True)

    optimizer.zero_grad(set_to_none=True)

    # fp16 autocast on CUDA; disabled elsewhere so the step runs in fp32.
    with torch.autocast(device_type=device.type, enabled=device.type == "cuda"):
        outputs = model(inputs)
        loss = criterion(outputs, targets)

    scaler.scale(loss).backward()      # backward on the scaled loss to avoid fp16 underflow
    scaler.unscale_(optimizer)         # unscale gradients so clipping sees their true norm
    nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
    scaler.step(optimizer)             # when enabled, skips the update if grads have inf/NaN
    scaler.update()                    # adjust the scale factor for the next step

    return loss.item()

Multi-GPU: DataParallel (Simple)

Python
# DataParallel: one process splits each batch across the GPUs and gathers results on GPU 0.
# Simple to use but inefficient (GPU 0 is the bottleneck) — prefer DDP (below) when you can.

n_gpus = torch.cuda.device_count()
if n_gpus > 1:
    print(f"Using {n_gpus} GPUs")
    # NOTE: the wrapper prefixes state_dict keys with "module."; save
    # model.module.state_dict() to get checkpoints loadable by the unwrapped model.
    model = nn.DataParallel(model)

model = model.to(device)
# Usage: identical to single GPU — DataParallel handles splitting/gathering

Multi-GPU: DistributedDataParallel (Production)

Python
# DDP: each GPU has its own process, gradients are all-reduced
# Much more efficient than DataParallel

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DistributedSampler

def setup_ddp(rank: int, world_size: int) -> None:
    """Pin this process to GPU ``rank`` and join the NCCL process group.

    ``rank`` is used both as the global rank and as the local GPU index, which
    is only correct on a single node.
    """
    import os

    # The default env:// rendezvous needs these. torchrun sets them;
    # torch.multiprocessing.spawn does not, so fall back to a local master.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    # Select the GPU before creating the NCCL group so communication uses it.
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup_ddp() -> None:
    """Destroy the default process group if one exists; safe to call repeatedly (e.g. in ``finally``)."""
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

def train_ddp(rank: int, world_size: int, dataset, model_class,
              batch_size: int = 32, num_epochs: int = 10):
    """Per-process DDP training entry point (one process per GPU).

    Args:
        rank: this process's rank; also used as its GPU index.
        world_size: total number of processes.
        dataset: map-style dataset shared by all processes.
        model_class: zero-argument callable that builds the model.
        batch_size: per-process batch size (global batch = batch_size * world_size).
        num_epochs: number of passes over the dataset.
    """
    setup_ddp(rank, world_size)
    try:
        device = torch.device(f"cuda:{rank}")
        model = model_class().to(device)
        model = DDP(model, device_ids=[rank])

        # DistributedSampler gives each process a disjoint shard of the dataset
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
        dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, sampler=sampler, pin_memory=True
        )

        optimizer = AdamW(model.parameters(), lr=1e-3)

        for epoch in range(num_epochs):
            sampler.set_epoch(epoch)   # reseed shuffling so each epoch uses a new order
            for batch in dataloader:
                # standard training step (uses model, optimizer, device)
                pass
    finally:
        # Always leave the process group, even if training raised.
        cleanup_ddp()

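# Launch: torchrun --nproc_per_node=4 train.py
#   (torchrun passes rank and world size via the RANK, LOCAL_RANK and WORLD_SIZE
#    environment variables, so your entry point must read them and call train_ddp itself)
# or: torch.multiprocessing.spawn(train_ddp, args=(world_size, dataset, Model), nprocs=world_size)
#   (spawn passes each process's rank as the first argument)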

GPU Memory Debugging

Python
# Track memory usage
def print_gpu_memory(label: str = "", device=None) -> None:
    """Print allocated and reserved CUDA memory in GB (1e9 bytes); does nothing without CUDA.

    ``allocated`` is memory occupied by live tensors; ``reserved`` is everything
    held by PyTorch's caching allocator, including cached free blocks.

    Args:
        label: prefix for the printed line.
        device: CUDA device to query; ``None`` means the current device
            (pass each rank's device when running DDP).
    """
    if not torch.cuda.is_available():
        return
    allocated = torch.cuda.memory_allocated(device) / 1e9
    reserved = torch.cuda.memory_reserved(device) / 1e9
    print(f"{label}: allocated={allocated:.2f}GB, reserved={reserved:.2f}GB")

# Memory profiler: record CPU-side operator events too — memory usage is reported
# per operator, so a CUDA-only trace leaves the operator table nearly empty.
activities = [torch.profiler.ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(torch.profiler.ProfilerActivity.CUDA)

with torch.profiler.profile(activities=activities, profile_memory=True) as prof:
    output = model(inputs)

print(prof.key_averages().table(sort_by="self_cuda_memory_usage", row_limit=10))

# Common OOM fixes
# 1. Reduce batch size

# 2. Enable gradient checkpointing: recompute activations during backward instead
#    of storing them, trading extra compute for memory.
#    gradient_checkpointing_enable() exists on Hugging Face transformers models,
#    not on plain nn.Module, so guard the call.
if hasattr(model, "gradient_checkpointing_enable"):
    model.gradient_checkpointing_enable()
# For a plain nn.Module, checkpoint expensive sub-modules inside forward():
#   from torch.utils.checkpoint import checkpoint
#   x = checkpoint(self.block, x, use_reentrant=False)

# 3. Use mixed precision
with autocast():
    output = model(inputs)  # fp16 activations take half the memory of fp32 ones

# 4. Clear cache between folds/experiments. This only releases cached, unused
#    blocks; memory held by live tensors is freed only after every reference to
#    them is dropped (e.g. `del model, optimizer`).
torch.cuda.empty_cache()

# 5. Use smaller model or LoRA fine-tuning

DataLoader Optimisation

Python
from torch.utils.data import DataLoader

# GPU-oriented DataLoader settings: keep batches ready so the GPU never waits.
dataloader = DataLoader(
    dataset,
    # --- batching ---
    batch_size=64,
    shuffle=True,
    drop_last=True,           # skip a trailing partial batch (tiny batches destabilise BatchNorm)
    # --- parallel loading ---
    num_workers=4,            # worker processes prepare batches while the GPU computes
    prefetch_factor=2,        # each worker keeps 2 batches queued ahead
    persistent_workers=True,  # reuse workers across epochs instead of respawning them
    # --- host→device transfer ---
    pin_memory=True,          # page-locked buffers enable faster, async CPU→GPU copies
)

# Profiling data loading vs compute
import time

data_times = []     # per batch: waiting on the DataLoader + host→device copy (seconds)
compute_times = []  # per batch: forward + backward (seconds)

# Start the clock before each fetch so time spent waiting on DataLoader workers
# counts as data time — that is exactly what num_workers influences.
t_wait_start = time.perf_counter()
for batch in dataloader:
    inputs = batch[0].to(device, non_blocking=True)
    targets = batch[1].to(device, non_blocking=True)
    if device.type == "cuda":
        torch.cuda.synchronize()  # copies are async; wait so they count as data time
    t_data_ready = time.perf_counter()

    with autocast():
        output = model(inputs)
        loss = criterion(output, targets)
    loss.backward()
    if device.type == "cuda":
        torch.cuda.synchronize()  # kernels are async; wait so we time real compute
    t_compute_done = time.perf_counter()

    data_times.append(t_data_ready - t_wait_start)
    compute_times.append(t_compute_done - t_data_ready)
    t_wait_start = time.perf_counter()  # the next fetch from the DataLoader starts now

if data_times:
    print(f"Avg data time:    {sum(data_times)/len(data_times)*1000:.1f}ms")
    print(f"Avg compute time: {sum(compute_times)/len(compute_times)*1000:.1f}ms")
# If data_time > compute_time: GPU is starved — increase num_workers

Monitoring GPU Utilisation

Bash
# Command line monitoring
nvidia-smi                   # one-time snapshot
nvidia-smi -l 1              # refresh every second
nvidia-smi dmon              # continuous per-GPU stats

# What to look for:
#   GPU util: ideally 85–99%. Well below that (e.g. < 50%) means the GPU is waiting —
#             most often on data loading, but CPU-side work or sync points can also cause it
#   Memory used: should be close to max (if not, increase batch size). PyTorch's caching
#                allocator keeps freed memory reserved, so this shows reserved, not in-use, memory
#   Temperature: should be under 85°C
Python
# Python monitoring (during training)
import subprocess
import re

def get_gpu_stats(gpu_index: int = 0) -> dict:
    """Query ``nvidia-smi`` for one GPU's utilisation, memory and temperature.

    Args:
        gpu_index: GPU to query (passed as ``nvidia-smi --id``). Without it,
            nvidia-smi prints one line per GPU and multi-GPU machines break parsing.

    Returns:
        dict with ``gpu_util_pct`` (int, %), ``mem_used_gb`` and ``mem_total_gb``
        (float, GiB — nvidia-smi reports MiB, divided by 1024) and
        ``temperature_c`` (int, °C).

    Raises:
        FileNotFoundError: if ``nvidia-smi`` is not installed.
        subprocess.CalledProcessError: if ``nvidia-smi`` exits with an error
            (e.g. an invalid ``gpu_index`` or no driver loaded).
    """
    result = subprocess.run(
        [
            "nvidia-smi",
            f"--id={gpu_index}",
            "--query-gpu=utilization.gpu,memory.used,memory.total,temperature.gpu",
            "--format=csv,noheader,nounits",
        ],
        capture_output=True,
        text=True,
        check=True,  # surface nvidia-smi failures instead of a confusing unpack error
    )
    first_line = result.stdout.strip().splitlines()[0]
    util, mem_used, mem_total, temp = (field.strip() for field in first_line.split(","))
    return {
        "gpu_util_pct": int(util),
        "mem_used_gb": float(mem_used) / 1024,
        "mem_total_gb": float(mem_total) / 1024,
        "temperature_c": int(temp),
    }

Interview Answer

"GPU training in PyTorch: move the model and data to the GPU via .to(device); use mixed precision training (autocast + GradScaler) for up to ~2× speedup and roughly half the activation memory on Tensor Core GPUs; apply gradient clipping (clip_grad_norm_ with max_norm around 1.0) to prevent exploding gradients. For multi-GPU training, DataParallel is simple but inefficient (GPU 0 is the bottleneck); DistributedDataParallel (DDP) scales close to linearly — each GPU runs its own process and gradients are all-reduced. For GPU utilisation, target 85%+ — below 50% usually means data loading is the bottleneck, fixed by increasing DataLoader num_workers and using pin_memory=True. OOM errors: reduce batch size, enable gradient checkpointing, use mixed precision, or use LoRA for fine-tuning."

Enjoyed this article?

Explore the AI Systems learning path for more.

Found this helpful?

Share:𝕏

Leave a comment

Have a question, correction, or just found this helpful? Leave a note below.