
LLM Training Infrastructure

How large language models are trained at scale: distributed training strategies, GPU communication, mixed precision, gradient checkpointing, and fault tolerance.

Asma Hafeez Khan · May 16, 2026 · 6 min read
Tags: LLM, Training Infrastructure, Distributed Training, GPU, DeepSpeed

Why Standard Training Doesn't Scale

A 7B-parameter LLM needs roughly 28GB just to store its parameters in float32. A single 80GB A100 has room for the weights, but not once you add gradients (another 28GB), optimizer states (Adam: 56GB for momentum and variance), and the activations for a batch. Training 70B or 405B models requires distributing the work across hundreds of GPUs.

Memory breakdown for a 7B parameter model:

  • Parameters: 7B × 4 bytes = 28GB (float32) or 14GB (bfloat16)
  • Gradients: same size as parameters
  • Adam optimizer states: 2 tensors × parameter size (momentum + variance)
  • Activations: proportional to batch_size × sequence_length × d_model × n_layers

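Total for mixed-precision training: approximately 7B × 16 bytes = 112GB before activations (2 bytes each for the bfloat16 weights and gradients, plus 4 bytes each for the float32 master weights, momentum, and variance). Even a 7B model therefore needs at least two 80GB A100s.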
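
The arithmetic above is easy to reproduce. The following is a minimal sketch, not a profiler: `training_memory_gb` is a hypothetical helper, it counts 1GB as 10^9 bytes to match the figures in this article, and it ignores activations, buffers, and allocator overhead.

```python
def training_memory_gb(n_params: float, mixed_precision: bool = True) -> dict:
    """Per-replica memory for weights, gradients, and Adam states (no activations)."""
    if mixed_precision:
        bytes_per_param = {
            "bf16 weights": 2,
            "bf16 gradients": 2,
            "fp32 master weights": 4,
            "fp32 Adam momentum": 4,
            "fp32 Adam variance": 4,
        }
    else:
        bytes_per_param = {
            "fp32 weights": 4,
            "fp32 gradients": 4,
            "fp32 Adam momentum": 4,
            "fp32 Adam variance": 4,
        }
    breakdown = {name: n_params * b / 1e9 for name, b in bytes_per_param.items()}
    breakdown["total"] = sum(breakdown.values())
    return breakdown


mem = training_memory_gb(7e9)
for name, gb in mem.items():
    print(f"{name:>22}: {gb:6.1f} GB")
```

Both the pure float32 and the mixed-precision recipe come out at 16 bytes per parameter. Mixed precision saves compute time and activation memory rather than model-state memory.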


Parallelism Strategies

Data Parallelism (DP)

Each GPU holds a full model copy. Different batches are processed on different GPUs. Gradients are averaged across GPUs (all-reduce) after each backward pass.
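
To make the all-reduce step concrete, here is a toy single-process simulation. It is a sketch only: `all_reduce_mean` is a hypothetical stand-in for the collective, and real implementations such as NCCL exchange chunks between GPUs rather than gathering everything in one place.

```python
def all_reduce_mean(per_rank_grads: list[list[float]]) -> list[list[float]]:
    """Simulate an averaging all-reduce: every rank ends up with the same mean gradient."""
    world_size = len(per_rank_grads)
    n = len(per_rank_grads[0])
    mean = [sum(g[i] for g in per_rank_grads) / world_size for i in range(n)]
    # Each rank receives its own copy of the averaged gradient
    return [list(mean) for _ in range(world_size)]


# Four simulated GPUs, each with a gradient computed from a different batch
grads = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 0.0]]
synced = all_reduce_mean(grads)
print(synced[0])
```

Because every replica applies the same averaged gradient, the model copies stay identical after each optimizer step.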

Python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_ddp(rank: int, world_size: int) -> None:
    """Initialize distributed process group."""
    dist.init_process_group(
        backend="nccl",   # NVIDIA Collective Communications Library
        rank=rank,
        world_size=world_size,
    )
    torch.cuda.set_device(rank)

def train_with_ddp(rank: int, world_size: int, model, train_dataset, num_epochs: int = 1) -> None:
    # Assumes MASTER_ADDR and MASTER_PORT are set in the environment (torchrun does this)
    setup_ddp(rank, world_size)

    model = model.to(rank)
    model = DDP(model, device_ids=[rank])

    sampler = torch.utils.data.DistributedSampler(
        train_dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True,
    )
    loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, sampler=sampler)

    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)  # Ensures different shuffle each epoch
        for batch in loader:
            # Assumes each batch is a dict of tensors, as model(**batch) implies
            batch = {k: v.to(rank) for k, v in batch.items()}
            optimizer.zero_grad()
            loss = model(**batch).loss
            loss.backward()
            # DDP automatically all-reduces gradients across GPUs
            optimizer.step()

    dist.destroy_process_group()

Limitation of pure DP: the whole model, along with its gradients and optimizer states, must fit on a single GPU. For 70B+ models, this is impossible.


Tensor Parallelism (TP)

Split individual weight matrices across GPUs. Each GPU computes part of the matrix multiply:

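Linear(d_model, 4*d_model) splits into:
  GPU 0: Linear(d_model, d_model)    (first quarter of output)
  GPU 1: Linear(d_model, d_model)    (second quarter)
  GPU 2: Linear(d_model, d_model)    (third quarter)
  GPU 3: Linear(d_model, d_model)    (fourth quarter)

With this column split, each GPU holds a slice of the output features, so the full output is recovered by concatenating the slices (an all-gather). Megatron-LM avoids that step inside the MLP by feeding the slices straight into a row-parallel second layer, whose partial sums are then combined with a single all-reduce.

Used in Megatron-LM for training GPT-3 scale models. Requires a fast GPU interconnect (NVLink) because this synchronization happens inside every layer, in both the forward and the backward pass.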

Python
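# Megatron-LM column-parallel linear layer (conceptual)
import torch
import torch.nn as nn
import torch.nn.functional as F

class ColumnParallelLinear(nn.Module):
    """Split output dimension across tensor parallel group."""

    def __init__(self, in_features: int, out_features: int, tp_rank: int, tp_size: int):
        super().__init__()
        assert out_features % tp_size == 0, "out_features must divide evenly across TP ranks"
        self.tp_rank = tp_rank  # which slice of the output this GPU owns
        # Each GPU only holds out_features/tp_size columns
        self.local_out = out_features // tp_size
        self.weight = nn.Parameter(torch.empty(self.local_out, in_features))
        nn.init.xavier_uniform_(self.weight)  # torch.empty leaves memory uninitialized

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Each GPU produces its own slice of the output features.
        # The caller either all-gathers the slices across the TP group
        # or feeds them directly into a row-parallel layer.
        return F.linear(x, self.weight)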
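
To check that sharding does not change the math, the sketch below simulates a column-parallel layer followed by a row-parallel layer on two pretend GPUs, using plain Python lists. The helper names are hypothetical, the weights are stored as (in, out) matrices rather than in `nn.Linear`'s (out, in) layout, and the nonlinearity between the layers is left out for brevity.

```python
def matmul(a, b):
    """Plain-Python matrix product of a (m x k) and b (k x n)."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def split_columns(w, parts):
    """Split a matrix into `parts` equal groups of columns."""
    step = len(w[0]) // parts
    return [[row[p * step:(p + 1) * step] for row in w] for p in range(parts)]


def split_rows(w, parts):
    """Split a matrix into `parts` equal groups of rows."""
    step = len(w) // parts
    return [w[p * step:(p + 1) * step] for p in range(parts)]


x = [[1, 2], [3, 4]]                        # (batch=2, d_model=2)
w1 = [[1, 0, 2, 1], [0, 1, 1, 3]]           # (d_model=2, hidden=4), column-parallel
w2 = [[1, 0], [0, 1], [2, 1], [1, 2]]       # (hidden=4, d_model=2), row-parallel
tp_size = 2

# Each simulated GPU holds one column slice of w1 and the matching row slice of w2
partials = []
for w1_shard, w2_shard in zip(split_columns(w1, tp_size), split_rows(w2, tp_size)):
    hidden_slice = matmul(x, w1_shard)               # local slice of the hidden activations
    partials.append(matmul(hidden_slice, w2_shard))  # partial sum of the final output

# An all-reduce (sum) across the TP group recovers the full result
tp_out = [[sum(p[i][j] for p in partials) for j in range(2)] for i in range(2)]
reference = matmul(matmul(x, w1), w2)
print(tp_out == reference)
```

The hidden activations are never gathered: only one collective is needed per pair of layers, which is why the column-then-row pairing is attractive.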

Pipeline Parallelism (PP)

Split transformer layers across GPUs. GPU 0 runs layers 1-12, GPU 1 runs layers 13-24, and so on. Splitting each batch into micro-batches lets the stages overlap their computation:

Without micro-batches: GPU 0 computes → waits → GPU 1 computes → waits → ...
                       ("bubble" time wasted)

With micro-batches (GPipe):
  GPU 0: process micro-batch 1 → pass to GPU 1
  GPU 0: process micro-batch 2 while GPU 1 processes micro-batch 1
  GPU 0: process micro-batch 3 while GPU 1 processes micro-batch 2
         ...
  (much less bubble time)
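
How much bubble time remains can be estimated with a one-line formula. This is a sketch under simplifying assumptions: every stage takes the same time per micro-batch, communication is free, and the schedule is GPipe-style (all forwards, then all backwards). With p stages and m micro-batches, the idle fraction of the schedule is (p - 1) / (m + p - 1).

```python
def gpipe_bubble_fraction(n_stages: int, n_microbatches: int) -> float:
    """Fraction of the pipeline schedule spent idle, assuming equal-cost stages."""
    return (n_stages - 1) / (n_microbatches + n_stages - 1)


for m in (1, 4, 8, 32):
    print(f"4 stages, {m:2d} micro-batches: {gpipe_bubble_fraction(4, m):.0%} idle")
```

With a single micro-batch, three of the four stages are idle at any moment. The bubble shrinks as the number of micro-batches grows relative to the number of stages.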

3D Parallelism (used for models like GPT-3, Megatron-Turing NLG): combine all three strategies simultaneously. A cluster of 512 GPUs might use 8-way tensor parallelism × 8-way pipeline parallelism × 8-way data parallelism.
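
One way to picture the 8 × 8 × 8 layout is as a mapping from each global rank to three coordinates. The sketch below assumes tensor-parallel ranks vary fastest, so that each TP group of 8 lands on one NVLink-connected node. That ordering is an illustrative assumption, and `rank_to_coords` is a hypothetical helper, not any framework's API.

```python
def rank_to_coords(rank: int, tp: int, pp: int, dp: int) -> tuple[int, int, int]:
    """Map a global rank to (dp_idx, pp_idx, tp_idx), with TP varying fastest."""
    assert 0 <= rank < tp * pp * dp, "rank out of range for this layout"
    tp_idx = rank % tp                # position inside the tensor-parallel group
    pp_idx = (rank // tp) % pp        # pipeline stage
    dp_idx = rank // (tp * pp)        # which model replica
    return dp_idx, pp_idx, tp_idx


for rank in (0, 7, 8, 64, 511):
    print(rank, rank_to_coords(rank, tp=8, pp=8, dp=8))
```

Ranks 0 through 7 share one pipeline stage of one replica and differ only in their tensor-parallel index, which keeps the most communication-heavy group on the fastest links.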


ZeRO: Zero Redundancy Optimizer

DeepSpeed's ZeRO partitions optimizer state, gradients, and parameters across data-parallel GPUs:

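| ZeRO Stage | What is Sharded | Memory Reduction |
|---|---|---|
| ZeRO-1 | Optimizer states only | Up to 4× |
| ZeRO-2 | Optimizer states + gradients | Up to 8× |
| ZeRO-3 | Optimizer states + gradients + parameters | Linear in the number of data-parallel GPUs (e.g. 64× on 64 GPUs) |

Reductions are relative to plain data parallelism with mixed-precision Adam. The 4× and 8× figures are the limits approached as the data-parallel group grows.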
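
The reductions in the table follow from simple per-parameter accounting. The sketch below assumes mixed-precision Adam (2 bytes of bfloat16 weights, 2 bytes of gradients, 12 bytes of float32 optimizer state per parameter) and ignores activations and communication buffers. `zero_memory_gb` is a hypothetical helper, not part of DeepSpeed.

```python
def zero_memory_gb(n_params: float, n_gpus: int, stage: int) -> float:
    """Per-GPU model-state memory in GB for mixed-precision Adam (activations excluded)."""
    weights, grads, optim = 2.0, 2.0, 12.0  # bytes per parameter
    if stage >= 1:
        optim /= n_gpus      # ZeRO-1 shards the optimizer states
    if stage >= 2:
        grads /= n_gpus      # ZeRO-2 also shards the gradients
    if stage >= 3:
        weights /= n_gpus    # ZeRO-3 also shards the parameters
    return n_params * (weights + grads + optim) / 1e9


baseline = zero_memory_gb(7e9, n_gpus=64, stage=0)
for stage in range(4):
    gb = zero_memory_gb(7e9, n_gpus=64, stage=stage)
    print(f"stage {stage}: {gb:7.2f} GB per GPU ({baseline / gb:.1f}x smaller)")
```

On 64 GPUs the first two stages come out slightly under 4× and 8×, because the replicated weights and gradients set a floor. Only stage 3 keeps shrinking as GPUs are added.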

Python
# DeepSpeed ZeRO-3 configuration
ds_config = {
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",          # Move optimizer states to CPU RAM
            "pin_memory": True,
        },
        "offload_param": {
            "device": "cpu",          # Move parameters to CPU when not in use
            "pin_memory": True,
        },
        "overlap_comm": True,         # Overlap gradient reduction with the backward pass
        "contiguous_gradients": True,
        "sub_group_size": 1e9,        # Process parameters in chunks
        # "auto" values are filled in by the Hugging Face Trainer integration;
        # with plain DeepSpeed, replace them with explicit numbers
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
    },
    "bf16": {
        "enabled": True              # Use bfloat16 for compute
    },
    "gradient_clipping": 1.0,
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 8,
}

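# ZeRO-3 with CPU offload can fit a 70B model on a single node of 8× A100 GPUs,
# given enough host RAM for the offloaded states. ZeRO-1 alone cannot do this at
# any GPU count: each GPU would still hold the full bfloat16 weights (about 140GB).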
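
The two batch settings in the config combine with the GPU count to give the global batch size per optimizer step. A minimal sketch of that relationship, with a hypothetical helper name and an assumed 8-GPU node:

```python
def global_batch_size(micro_batch_per_gpu: int, grad_accum_steps: int, n_gpus: int) -> int:
    """Sequences consumed per optimizer step across the whole data-parallel group."""
    return micro_batch_per_gpu * grad_accum_steps * n_gpus


# Values from the config above, on an assumed 8-GPU node
print(global_batch_size(micro_batch_per_gpu=4, grad_accum_steps=8, n_gpus=8))
```

Changing the number of GPUs changes the global batch size unless one of the other two settings is adjusted to compensate.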

Mixed Precision Training

Run the forward and backward passes in bfloat16 (or float16) while keeping float32 weights for the optimizer update:

Python
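import torch
# Newer PyTorch releases also expose these as torch.autocast and torch.amp.GradScaler
from torch.cuda.amp import autocast, GradScaler

def train_step_mixed_precision(
    model,
    batch,
    optimizer,
    scaler: GradScaler,
) -> float:
    optimizer.zero_grad()

    # Forward pass in bfloat16; the weights themselves stay in float32
    with autocast(dtype=torch.bfloat16):
        outputs = model(**batch)
        loss = outputs.loss

    # Loss scaling guards against underflow in float16 gradients.
    # bfloat16 does not need it: pass GradScaler(enabled=False) and the
    # scaler calls below become pass-throughs.
    scaler.scale(loss).backward()

    # Unscale before gradient clipping
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    # step() skips the update if inf/NaN gradients were found;
    # update() then adjusts the scale factor for the next iteration
    scaler.step(optimizer)
    scaler.update()

    return loss.item()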

# bfloat16 is preferred over float16 for LLM training:
# - float16: 5-bit exponent (overflow risk), 10-bit mantissa
# - bfloat16: 8-bit exponent (same range as float32), 7-bit mantissa
# bfloat16 is more stable because its exponent range matches float32
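
The bit counts above translate directly into range and precision. The sketch below derives both from the exponent and mantissa widths, assuming IEEE-style formats with the top exponent code reserved for inf/NaN. `float_format_stats` is a hypothetical helper.

```python
def float_format_stats(exponent_bits: int, mantissa_bits: int) -> tuple[float, float]:
    """Return (largest finite value, machine epsilon) for an IEEE-style binary format."""
    bias = 2 ** (exponent_bits - 1) - 1
    max_finite = (2 - 2.0 ** -mantissa_bits) * 2.0 ** bias
    epsilon = 2.0 ** -mantissa_bits  # gap between 1.0 and the next representable value
    return max_finite, epsilon


for name, e_bits, m_bits in [("float16", 5, 10), ("bfloat16", 8, 7), ("float32", 8, 23)]:
    max_finite, eps = float_format_stats(e_bits, m_bits)
    print(f"{name:>8}: max = {max_finite:.4g}, epsilon = {eps:.3g}")
```

float16 overflows just above 65504, while bfloat16 reaches roughly the same 3.4e38 ceiling as float32. The price is coarser precision: a relative step of 2^-7 instead of 2^-10.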

Gradient Checkpointing

Trade compute for memory by recomputing activations during backward pass:

Python
from torch.utils.checkpoint import checkpoint

class CheckpointedBlock(nn.Module):
    """Transformer block with gradient checkpointing."""

    def __init__(self, config):
        super().__init__()
        self.attn = CausalSelfAttention(config)
        self.ffn = MLP(config)
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # During backward, this block is recomputed rather than storing activations
        return checkpoint(self._forward, x, use_reentrant=False)

# Only each block's input is kept; the activations inside the block are recomputed
# during the backward pass, costing roughly one extra forward pass (~30% more compute)
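
A rough model of the trade-off, offered as a back-of-the-envelope sketch rather than a measurement. It assumes every block is checkpointed, all saved tensors are the same size, and a backward pass costs about twice a forward pass. `tensors_per_block` is a hypothetical count of the intermediate tensors a block would otherwise save.

```python
def checkpointing_tradeoff(n_layers: int, tensors_per_block: int) -> tuple[float, float]:
    """Return (activation memory ratio, compute ratio) versus no checkpointing."""
    baseline_saved = n_layers * tensors_per_block
    # Keep one input per block, plus one block's intermediates while it is recomputed
    checkpointed_saved = n_layers + tensors_per_block
    memory_ratio = checkpointed_saved / baseline_saved
    # Forward (1) + backward (2) + one recomputed forward (1), versus forward + backward
    compute_ratio = (1 + 2 + 1) / (1 + 2)
    return memory_ratio, compute_ratio


memory_ratio, compute_ratio = checkpointing_tradeoff(n_layers=32, tensors_per_block=16)
print(f"activation memory: {memory_ratio:.1%} of baseline, compute: {compute_ratio:.2f}x")
```

Under these assumptions the saving depends mostly on how many tensors each block would have stored, while the compute overhead stays close to one third regardless of depth.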

Gradient Accumulation

Simulate large batch sizes when GPU memory is limited:

Python
def train_with_gradient_accumulation(
    model,
    data_loader,
    optimizer,
    accumulation_steps: int = 8,
) -> None:
    """
    Effective batch size = per_gpu_batch_size × accumulation_steps × n_gpus.
    """
    optimizer.zero_grad()
    running_loss = 0.0

    for step, batch in enumerate(data_loader):
        with autocast(dtype=torch.bfloat16):
            loss = model(**batch).loss
            # Divide by accumulation steps so gradient magnitude is consistent
            loss = loss / accumulation_steps

        loss.backward()
        running_loss += loss.item()

        if (step + 1) % accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()

            # running_loss already sums loss / accumulation_steps, so it is the mean loss
            print(f"Step {step + 1}: loss = {running_loss:.4f}")
            running_loss = 0.0
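
Why divide the loss by `accumulation_steps`? With equal-sized micro-batches, the scaled gradients add up to exactly the gradient of the full batch. The sketch below checks this on a one-parameter least-squares model with a hand-written gradient. The data and the `mse_grad` helper are made up for illustration.

```python
def mse_grad(w: float, xs: list[float], ys: list[float]) -> float:
    """Gradient of mean((w*x - y)^2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


w = 0.5
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0]
accumulation_steps = 4
micro = len(xs) // accumulation_steps  # micro-batch size

accumulated = 0.0
for s in range(accumulation_steps):
    chunk = slice(s * micro, (s + 1) * micro)
    # Same scaling as `loss / accumulation_steps` before backward()
    accumulated += mse_grad(w, xs[chunk], ys[chunk]) / accumulation_steps

full_batch = mse_grad(w, xs, ys)
print(accumulated, full_batch)
```

Without the division, the accumulated gradient would be `accumulation_steps` times too large, which acts like an unintended learning-rate increase.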

Fault Tolerance and Checkpointing

Training runs lasting weeks must survive hardware failures:

Python
import os
import torch

def save_checkpoint(
    model,
    optimizer,
    lr_scheduler,
    step: int,
    loss: float,
    checkpoint_dir: str,
) -> None:
    """Save training checkpoint for recovery."""
    os.makedirs(checkpoint_dir, exist_ok=True)
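    path = os.path.join(checkpoint_dir, f"step_{step:08d}.pt")
    tmp_path = path + ".tmp"

    # Write to a temporary file first so a crash mid-save never leaves a
    # truncated step_*.pt behind
    torch.save({
        "step": step,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "lr_scheduler_state_dict": lr_scheduler.state_dict(),
        "loss": loss,
    }, tmp_path)
    os.replace(tmp_path, path)  # atomic rename on the same filesystem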

    # Keep only last 3 checkpoints to save disk space
    checkpoints = sorted([f for f in os.listdir(checkpoint_dir) if f.endswith(".pt")])
    for old_ckpt in checkpoints[:-3]:
        os.remove(os.path.join(checkpoint_dir, old_ckpt))


def resume_from_checkpoint(
    model,
    optimizer,
    lr_scheduler,
    checkpoint_path: str,
) -> int:
    """Resume training from a checkpoint."""
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(ckpt["model_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    lr_scheduler.load_state_dict(ckpt["lr_scheduler_state_dict"])
    return ckpt["step"]

Production practices:

  • Save checkpoints every 500-1000 steps
  • Write checkpoints asynchronously in a background thread so training is not blocked
  • Validate checkpoints can be loaded before deleting old ones
  • Use cloud storage (S3, GCS) as secondary backup
  • Track training state in a metadata file separate from model weights
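
On restart, a job first has to decide which checkpoint to resume from. The sketch below picks the highest-step file matching the `step_XXXXXXXX.pt` naming used above and skips empty files. The size check is only a cheap stand-in for real validation (actually loading the file), and `latest_checkpoint` is a hypothetical helper.

```python
import os
import re
import tempfile
from typing import Optional

CKPT_PATTERN = re.compile(r"^step_(\d+)\.pt$")


def latest_checkpoint(checkpoint_dir: str) -> Optional[str]:
    """Return the path of the highest-step checkpoint, ignoring empty or unrelated files."""
    best_step, best_path = -1, None
    for name in os.listdir(checkpoint_dir):
        match = CKPT_PATTERN.match(name)
        if not match:
            continue
        path = os.path.join(checkpoint_dir, name)
        step = int(match.group(1))
        if os.path.getsize(path) > 0 and step > best_step:
            best_step, best_path = step, path
    return best_path


# Demo with stand-in files (real checkpoints would be written by torch.save)
with tempfile.TemporaryDirectory() as demo_dir:
    for name, payload in [("step_00000500.pt", b"x"), ("step_00001000.pt", b"x"),
                          ("step_00001500.pt", b""), ("notes.txt", b"x")]:
        with open(os.path.join(demo_dir, name), "wb") as f:
            f.write(payload)
    newest = os.path.basename(latest_checkpoint(demo_dir))
    print(newest)
```

The empty step-1500 file is skipped, so the job resumes from step 1000 instead of crashing on a truncated write.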
