LLM Training Infrastructure
How large language models are trained at scale: distributed training strategies, GPU communication, mixed precision, gradient checkpointing, and fault tolerance.
Why Standard Training Doesn't Scale
A 7B-parameter LLM needs roughly 28GB just to store its parameters in float32. A single 80GB A100 has room for the model, but not for the gradients as well (another 28GB), the optimizer states (Adam: 56GB for momentum and variance), and the activations for a batch. Training 70B or 405B models requires distributing the work across hundreds of GPUs.
Memory breakdown for a 7B parameter model:
- Parameters: 7B × 4 bytes = 28GB (float32) or 14GB (bfloat16)
- Gradients: same size as parameters
- Adam optimizer states: 2 tensors × parameter size (momentum + variance)
- Activations: proportional to batch_size × sequence_length × d_model × n_layers
Total for mixed-precision training: approximately 7B × 16 bytes = 112GB before activations (2 bytes for bfloat16 weights, 2 for bfloat16 gradients, and 12 for the float32 master weights, momentum, and variance), so even a 7B model needs at least two 80GB A100s.
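The arithmetic above can be reproduced with a few lines of Python. This is a minimal sketch: the function name `training_memory_gb` and the component labels are illustrative, it assumes the 16-bytes-per-parameter split for mixed-precision Adam, and it ignores activations entirely.

```python
GB = 1e9  # decimal gigabytes, matching the 7B x 4 bytes = 28GB arithmetic

def training_memory_gb(n_params: float) -> dict:
    """Per-component memory in GB for mixed-precision Adam, excluding activations."""
    bytes_per_param = {
        "bf16_weights": 2,
        "bf16_gradients": 2,
        "fp32_master_weights": 4,
        "fp32_adam_momentum": 4,
        "fp32_adam_variance": 4,
    }
    breakdown = {name: n_params * b / GB for name, b in bytes_per_param.items()}
    breakdown["total"] = sum(breakdown.values())
    return breakdown

# Print the breakdown for a 7B-parameter model
for name, gb in training_memory_gb(7e9).items():
    print(f"{name:>20}: {gb:6.1f} GB")
```

Changing `7e9` to `70e9` shows why larger models cannot be trained without sharding this state across many devices.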
Parallelism Strategies
Data Parallelism (DP)
Each GPU holds a full model copy. Different batches are processed on different GPUs. Gradients are averaged across GPUs (all-reduce) after each backward pass.
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_ddp(rank: int, world_size: int) -> None:
    """Initialize distributed process group."""
    # Assumes MASTER_ADDR and MASTER_PORT are set in the environment (e.g. by torchrun)
    dist.init_process_group(
        backend="nccl",  # NVIDIA Collective Communications Library
        rank=rank,
        world_size=world_size,
    )
    # Using rank as the GPU index assumes a single node
    torch.cuda.set_device(rank)

def train_with_ddp(rank: int, world_size: int, model, train_dataset, num_epochs: int = 1) -> None:
    setup_ddp(rank, world_size)
    model = model.to(rank)
    model = DDP(model, device_ids=[rank])
    sampler = torch.utils.data.DistributedSampler(
        train_dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True,
    )
    loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, sampler=sampler)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)  # Ensures different shuffle each epoch
        for batch in loader:
            # Move the batch (assumed to be a dict of tensors) to this rank's GPU
            batch = {k: v.to(rank) for k, v in batch.items()}
            optimizer.zero_grad()
            loss = model(**batch).loss
            loss.backward()
            # DDP automatically all-reduces gradients across GPUs
            optimizer.step()
    dist.destroy_process_group()

Limitation of pure DP: the model must fit on a single GPU. For 70B+ models, this is impossible.
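To see why averaging gradients across replicas reproduces a single large-batch step, here is a toy check in plain Python. It is only a sketch: the one-parameter model, the data, and the helper `all_reduce_mean` (a stand-in for the averaging all-reduce, not a real NCCL call) are all made up for illustration, and each simulated rank gets an equal-sized shard.

```python
def mse_grad(w: float, xs: list[float], ys: list[float]) -> float:
    """d/dw of mean((w * x - y)^2) over the given samples."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def all_reduce_mean(values: list[float]) -> float:
    """Stand-in for an averaging all-reduce: every rank ends up with the mean."""
    return sum(values) / len(values)

w = 0.5
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0]

world_size = 4
shard = len(xs) // world_size

# Each simulated rank computes a gradient on its own shard of the batch
local_grads = [
    mse_grad(w, xs[r * shard:(r + 1) * shard], ys[r * shard:(r + 1) * shard])
    for r in range(world_size)
]

ddp_grad = all_reduce_mean(local_grads)   # what every rank sees after all-reduce
full_grad = mse_grad(w, xs, ys)           # single-device gradient on the whole batch
print(ddp_grad, full_grad)
```

The two values agree because the mean over equal-sized shards of per-shard means equals the mean over the whole batch.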
Tensor Parallelism (TP)
Split individual weight matrices across GPUs. Each GPU computes part of the matrix multiply:
Linear(d_model, 4*d_model) splits into:
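GPU 0: Linear(d_model, d_model) → first quarter of output
GPU 1: Linear(d_model, d_model) → second quarter
GPU 2: Linear(d_model, d_model) → third quarter
GPU 3: Linear(d_model, d_model) → fourth quarter
Because this split is along the output dimension, each GPU holds a slice of the result, and the slices are combined with an all-gather. When a column-parallel layer feeds a row-parallel one, as in Megatron-LM's MLP blocks, a single all-reduce after the second matmul merges the partial sums instead.
Used in Megatron-LM for training GPT-3 scale models. Requires a fast GPU interconnect (NVLink) because synchronization happens within each forward pass.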
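As a sanity check on the column split described above, the following plain-Python sketch verifies that concatenating each GPU's output slice gives the same result as the unsplit matmul. The helpers `matmul`, `split_columns`, and `all_gather_concat` are hypothetical stand-ins written for this example (nested lists instead of tensors, concatenation instead of a real collective).

```python
def matmul(x, w):
    """x: [batch][in], w: [in][out] -> [batch][out]."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*w)] for row in x]

def split_columns(w, tp_size):
    """Split the output (column) dimension of w into tp_size equal shards."""
    out_features = len(w[0])
    assert out_features % tp_size == 0, "output dim must divide evenly"
    step = out_features // tp_size
    return [[row[r * step:(r + 1) * step] for row in w] for r in range(tp_size)]

def all_gather_concat(partials):
    """Stand-in for all-gather: concatenate each rank's slice along the last dim."""
    batch = len(partials[0])
    return [[v for part in partials for v in part[b]] for b in range(batch)]

x = [[1.0, 2.0], [3.0, 4.0]]                      # batch=2, d_model=2
w = [[1.0, 0.0, 2.0, -1.0, 0.5, 3.0, 1.0, 2.0],   # d_model=2 -> 4*d_model=8
     [0.0, 1.0, 1.0, 2.0, -2.0, 0.0, 4.0, 1.0]]

shards = split_columns(w, tp_size=4)              # one [2][2] weight shard per "GPU"
partials = [matmul(x, shard) for shard in shards] # each "GPU" computes its slice
tp_out = all_gather_concat(partials)
full_out = matmul(x, w)
print(tp_out == full_out)
```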
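# Megatron-LM column-parallel linear layer (conceptual)
import torch
import torch.nn as nn
import torch.nn.functional as F

class ColumnParallelLinear(nn.Module):
    """Split output dimension across tensor parallel group."""
    def __init__(self, in_features: int, out_features: int, tp_rank: int, tp_size: int):
        super().__init__()
        assert out_features % tp_size == 0, "out_features must be divisible by tp_size"
        self.tp_rank = tp_rank  # Identifies which column slice this GPU owns
        # Each GPU only holds out_features/tp_size columns
        self.local_out = out_features // tp_size
        self.weight = nn.Parameter(torch.empty(self.local_out, in_features))
        nn.init.xavier_uniform_(self.weight)  # torch.empty leaves the weights uninitialized

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Each GPU produces a partial output of shape [..., local_out].
        # Caller is responsible for all-gather across TP group
        return F.linear(x, self.weight)

Pipeline Parallelism (PP)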
Split transformer layers across GPUs. GPU 0 runs layers 1-12, GPU 1 runs layers 13-24, etc. Uses micro-batches to overlap computation:
Without pipeline: GPU 0 computes → waits → GPU 1 computes → waits → ...
("bubble" time wasted)
With micro-batches (GPipe):
GPU 0: process micro-batch 1 → pass to GPU 1
GPU 0: process micro-batch 2 while GPU 1 processes micro-batch 1
GPU 0: process micro-batch 3 while GPU 1 processes micro-batch 2
...
(much less bubble time)
3D Parallelism (used for models like GPT-3, Megatron-Turing NLG): combine all three strategies simultaneously. A cluster of 512 GPUs might use: 8-way tensor parallelism × 8-way pipeline parallelism × 8-way data parallelism.
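One way to picture the 8 × 8 × 8 layout is to map each global rank to a (data, pipeline, tensor) coordinate. The sketch below assumes tensor-parallel ranks are numbered contiguously so that each tensor-parallel group sits on one 8-GPU node; that ordering, and the names `Coord`, `rank_to_coord`, and `tensor_parallel_group`, are illustrative assumptions rather than the layout any particular framework uses.

```python
from typing import NamedTuple

class Coord(NamedTuple):
    dp: int  # data-parallel replica index
    pp: int  # pipeline stage index
    tp: int  # tensor-parallel shard index

TP, PP, DP = 8, 8, 8  # 8 x 8 x 8 = 512 GPUs

def rank_to_coord(rank: int, tp: int = TP, pp: int = PP, dp: int = DP) -> Coord:
    """Map a global rank to its coordinate, with tp varying fastest (assumed ordering)."""
    if not 0 <= rank < tp * pp * dp:
        raise ValueError(f"rank {rank} is outside the {tp * pp * dp}-GPU grid")
    return Coord(dp=rank // (tp * pp), pp=(rank // tp) % pp, tp=rank % tp)

def tensor_parallel_group(rank: int, tp: int = TP) -> list[int]:
    """Ranks that share this rank's weight shards (contiguous under the assumed ordering)."""
    start = (rank // tp) * tp
    return list(range(start, start + tp))

coords = [rank_to_coord(r) for r in range(TP * PP * DP)]
print(rank_to_coord(9))            # second pipeline stage, second tensor shard
print(tensor_parallel_group(9))    # ranks 8 through 15
```

Keeping the tensor-parallel group inside one node matches the earlier point that tensor parallelism needs the fastest interconnect, while pipeline and data parallelism tolerate slower links between nodes.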
ZeRO: Zero Redundancy Optimizer
DeepSpeed's ZeRO partitions optimizer state, gradients, and parameters across data-parallel GPUs:
| ZeRO Stage | What is Sharded | Memory Reduction |
|---|---|---|
| ZeRO-1 | Optimizer states only | up to 4× |
| ZeRO-2 | Optimizer states + gradients | up to 8× |
| ZeRO-3 | Optimizer states + gradients + parameters | scales with GPU count (e.g. 64× on 64 GPUs) |
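The reduction factors in the table follow from which of the 16 bytes per parameter get divided by the number of data-parallel GPUs. The sketch below assumes mixed-precision Adam (2 bytes of weights, 2 of gradients, 12 of optimizer state per parameter) and ignores activations and communication buffers; `zero_memory_gb` is a name invented for this example.

```python
def zero_memory_gb(n_params: float, n_gpus: int, stage: int) -> float:
    """Approximate per-GPU model-state memory in GB for a given ZeRO stage (0 = plain DP)."""
    weights, grads, optim = 2.0, 2.0, 12.0  # bytes per parameter
    if stage >= 1:
        optim /= n_gpus      # ZeRO-1: shard optimizer states
    if stage >= 2:
        grads /= n_gpus      # ZeRO-2: also shard gradients
    if stage >= 3:
        weights /= n_gpus    # ZeRO-3: also shard parameters
    return n_params * (weights + grads + optim) / 1e9

baseline = zero_memory_gb(7e9, 64, stage=0)
for stage in range(4):
    per_gpu = zero_memory_gb(7e9, 64, stage)
    print(f"stage {stage}: {per_gpu:7.2f} GB per GPU ({baseline / per_gpu:.1f}x reduction)")
```

With 64 GPUs the first two stages approach, but do not quite reach, 4× and 8×, because the unsharded components set a floor; only ZeRO-3 scales with the GPU count.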
# DeepSpeed ZeRO-3 configuration
# Note: the "auto" values are filled in by the Hugging Face Trainer integration;
# when calling DeepSpeed directly, replace them with explicit integers
ds_config = {
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",  # Move optimizer states to CPU RAM
            "pin_memory": True,
        },
        "offload_param": {
            "device": "cpu",  # Move parameters to CPU when not in use
            "pin_memory": True,
        },
        "overlap_comm": True,  # Overlap gradient reduction with backward pass
        "contiguous_gradients": True,
        "sub_group_size": 1e9,  # Process parameters in chunks
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
    },
    "bf16": {
        "enabled": True  # Use bfloat16 for compute
    },
    "gradient_clipping": 1.0,
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 8,
}
# ZeRO-3 with CPU offload can train a 70B model on 8× A100 GPUs
# (that would otherwise require 32+ GPUs with ZeRO-1)

Mixed Precision Training
Compute in bfloat16 (or float16) while keeping a float32 master copy of the weights for optimizer updates:
import torch
# Newer PyTorch releases expose the same classes under torch.amp
from torch.cuda.amp import autocast, GradScaler

def train_step_mixed_precision(
    model,
    batch,
    optimizer,
    scaler: GradScaler,
) -> float:
    # Loss scaling only matters for float16; with bfloat16 the scaler can be
    # created as GradScaler(enabled=False), which turns the calls below into no-ops
    optimizer.zero_grad()
    # Forward pass in float16/bfloat16
    with autocast(dtype=torch.bfloat16):
        outputs = model(**batch)
        loss = outputs.loss
    # Scale loss to avoid underflow in float16 gradients
    scaler.scale(loss).backward()
    # Unscale before gradient clipping
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    # Update (the step is skipped if inf/NaN gradients were found)
    scaler.step(optimizer)
    # Adjust the loss scale for the next iteration
    scaler.update()
    return loss.item()

# bfloat16 is preferred over float16 for LLM training:
# - float16: 5-bit exponent (overflow risk), 10-bit mantissa
# - bfloat16: 8-bit exponent (same range as float32), 7-bit mantissa
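The range-versus-precision trade-off in the comments above can be observed with the standard library alone. This sketch round-trips values through IEEE half precision with `struct`, and emulates bfloat16 by keeping only the top 16 bits of the float32 encoding; that truncation is a simplification (real hardware usually rounds to nearest), and both helper names are made up for this example.

```python
import struct

def to_float16(x: float) -> float:
    """Round-trip through IEEE half precision; raises OverflowError if x is out of range."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_bfloat16(x: float) -> float:
    """Emulate bfloat16 by truncating the float32 encoding to its top 16 bits."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

# Range: 70000 exceeds the float16 maximum (65504) but fits easily in bfloat16
try:
    to_float16(70000.0)
    fp16_overflowed = False
except OverflowError:
    fp16_overflowed = True
print("float16 overflow:", fp16_overflowed)
print("bfloat16 value:  ", to_bfloat16(70000.0))

# Precision: float16 keeps more mantissa bits than bfloat16
print("float16(1.001): ", to_float16(1.001))
print("bfloat16(1.001):", to_bfloat16(1.001))
```

The large value overflows float16 but survives in bfloat16 with reduced precision, while the value near 1 keeps more digits in float16 than in bfloat16.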
bfloat16 is more stable because its exponent range matches float32, so it rarely needs loss scaling.
Gradient Checkpointing
Trade compute for memory by recomputing activations during the backward pass:
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedBlock(nn.Module):
    """Transformer block with gradient checkpointing."""
    def __init__(self, config):
        super().__init__()
        # CausalSelfAttention and MLP are assumed to be defined elsewhere in the model code
        self.attn = CausalSelfAttention(config)
        self.ffn = MLP(config)
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # During backward, this block is recomputed rather than storing activations
        return checkpoint(self._forward, x, use_reentrant=False)

# Only each block's input is kept between forward and backward, so activation
# memory drops sharply at the cost of ~30% more compute

Gradient Accumulation
Simulate large batch sizes when GPU memory is limited:
import torch
from torch.cuda.amp import autocast

def train_with_gradient_accumulation(
    model,
    data_loader,
    optimizer,
    accumulation_steps: int = 8,
) -> None:
    """
    Effective batch size = per_gpu_batch_size × accumulation_steps × n_gpus.
    """
    optimizer.zero_grad()
    running_loss = 0.0
    for step, batch in enumerate(data_loader):
        with autocast(dtype=torch.bfloat16):
            loss = model(**batch).loss
        # Divide by accumulation steps so gradient magnitude is consistent
        loss = loss / accumulation_steps
        loss.backward()
        running_loss += loss.item()
        if (step + 1) % accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()
            # running_loss already holds the mean loss over the accumulated micro-batches
            print(f"Step {step + 1}: loss = {running_loss:.4f}")
            running_loss = 0.0

Fault Tolerance and Checkpointing
Training runs lasting weeks must survive hardware failures:
import os
import torch

def save_checkpoint(
    model,
    optimizer,
    lr_scheduler,
    step: int,
    loss: float,
    checkpoint_dir: str,
) -> None:
    """Save training checkpoint for recovery."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    path = os.path.join(checkpoint_dir, f"step_{step:08d}.pt")
    torch.save({
        "step": step,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "lr_scheduler_state_dict": lr_scheduler.state_dict(),
        "loss": loss,
    }, path)
    # Keep only last 3 checkpoints to save disk space
    # (zero-padded step numbers make the lexicographic sort chronological)
    checkpoints = sorted([f for f in os.listdir(checkpoint_dir) if f.endswith(".pt")])
    for old_ckpt in checkpoints[:-3]:
        os.remove(os.path.join(checkpoint_dir, old_ckpt))

def resume_from_checkpoint(
    model,
    optimizer,
    lr_scheduler,
    checkpoint_path: str,
) -> int:
    """Resume training from a checkpoint."""
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(ckpt["model_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    lr_scheduler.load_state_dict(ckpt["lr_scheduler_state_dict"])
    return ckpt["step"]

Production practices:
- Save checkpoints every 500-1000 steps
- Async checkpoint writes (don't block training) using background threads
- Validate checkpoints can be loaded before deleting old ones
- Use cloud storage (S3, GCS) as secondary backup
- Track training state in a metadata file separate from model weights
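Two of the practices above, validating a checkpoint before deleting older ones and keeping a separate metadata file, can be sketched with the standard library. This is an illustration under stated assumptions: `pickle` stands in for `torch.save`/`torch.load`, and the function name `save_atomically`, the `.pkl` suffix, and the `latest.json` file name are all hypothetical choices.

```python
import json
import os
import pickle
import tempfile

def save_atomically(state: dict, checkpoint_dir: str, step: int, keep: int = 3) -> str:
    """Write, validate, publish, then prune. pickle stands in for torch.save here."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    final_path = os.path.join(checkpoint_dir, f"step_{step:08d}.pkl")
    tmp_path = final_path + ".tmp"

    # 1. Write to a temporary file so a crash never leaves a half-written checkpoint
    with open(tmp_path, "wb") as f:
        pickle.dump(state, f)
        f.flush()
        os.fsync(f.fileno())

    # 2. Validate: the file must load and carry the expected step
    with open(tmp_path, "rb") as f:
        loaded = pickle.load(f)
    if loaded.get("step") != step:
        os.remove(tmp_path)
        raise RuntimeError(f"checkpoint for step {step} failed validation")

    # 3. Publish with an atomic rename (same filesystem), then update the metadata file
    os.replace(tmp_path, final_path)
    meta_tmp = os.path.join(checkpoint_dir, "latest.json.tmp")
    with open(meta_tmp, "w") as f:
        json.dump({"latest_step": step, "path": os.path.basename(final_path)}, f)
    os.replace(meta_tmp, os.path.join(checkpoint_dir, "latest.json"))

    # 4. Only now delete older checkpoints, keeping the most recent `keep`
    ckpts = sorted(f for f in os.listdir(checkpoint_dir) if f.endswith(".pkl"))
    for old in ckpts[:-keep]:
        os.remove(os.path.join(checkpoint_dir, old))
    return final_path

# Usage: save five checkpoints into a scratch directory and inspect what remains
with tempfile.TemporaryDirectory() as scratch:
    for s in (100, 200, 300, 400, 500):
        save_atomically({"step": s, "weights": [0.1, 0.2]}, scratch, s)
    remaining = sorted(f for f in os.listdir(scratch) if f.endswith(".pkl"))
    with open(os.path.join(scratch, "latest.json")) as f:
        latest = json.load(f)
    print(remaining, latest)
```

Pruning happens last on purpose: if the write or the validation fails, the previous checkpoints are still on disk to resume from.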