
LLM Training Objectives: From Next-Token to Alignment

The full training objective stack for large language models: next-token prediction loss, cross-entropy mechanics, data weighting, and how pretraining creates the base for alignment.

Asma Hafeez Khan · May 16, 2026 · 6 min read

The Core Objective: Next-Token Prediction

Every GPT-style LLM is trained to predict the next token given all previous tokens. Given a sequence of tokens [t₁, t₂, ..., tₙ], the model learns to maximize:

P(t₁, t₂, ..., tₙ) = ∏ᵢ₌₁ⁿ P(tᵢ | t₁, ..., tᵢ₋₁)

This factorization is exact (it's the chain rule of probability) — next-token prediction is not an approximation. It's a complete generative model of text.
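As a minimal sketch of the factorization, assume a toy model whose conditional depends only on the previous token. The bigram table and its probabilities are invented for illustration; a real LLM conditions on the full prefix. The sequence probability is the product of the per-token conditionals, computed here as a sum in log space:

```python
import math

# Hypothetical toy conditionals P(next | previous); "<s>" marks sequence start.
# The numbers are invented for illustration only.
BIGRAM = {
    "<s>": {"the": 0.6, "a": 0.4},
    "the": {"cat": 0.5, "dog": 0.5},
    "a": {"cat": 0.3, "dog": 0.7},
    "cat": {"sat": 1.0},
    "dog": {"sat": 1.0},
}

def sequence_log_prob(tokens: list[str]) -> float:
    """Sum log P(t_i | t_{i-1}) over the sequence (chain rule in log space)."""
    log_prob = 0.0
    prev = "<s>"
    for tok in tokens:
        log_prob += math.log(BIGRAM[prev][tok])
        prev = tok
    return log_prob

# Joint probability of one sequence: 0.6 * 0.5 * 1.0
joint = math.exp(sequence_log_prob(["the", "cat", "sat"]))

# Summing over every sequence the toy model can generate gives 1,
# which is what "complete generative model" means here.
all_sequences = [[a, b, "sat"] for a in ("the", "a") for b in ("cat", "dog")]
total = sum(math.exp(sequence_log_prob(s)) for s in all_sequences)
print(joint, total)
```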

Why this objective works:

  • To predict the next token well, the model must understand syntax, semantics, world knowledge, and reasoning
  • The training signal is dense: every token position contributes a loss term, so a single sequence yields many training signals
  • Data is abundant: any text can serve as training data without manual labels (self-supervision)
  • Prediction is compression: a model that predicts text well must have captured structure in language

Cross-Entropy Loss: The Math

The loss at position i is the negative log probability of the correct token:

ℒᵢ = -log P(tᵢ | t₁, ..., tᵢ₋₁)

Over a sequence of length T:

ℒ = -(1/T) Σᵢ₌₁ᵀ log P(tᵢ | t₁, ..., tᵢ₋₁)

Perplexity is the exponentiated average loss — a more interpretable metric:

Perplexity = exp(ℒ) = exp(-(1/T) Σᵢ₌₁ᵀ log P(tᵢ | t₁, ..., tᵢ₋₁))

A perplexity of 20 means the model is, on average, as uncertain as if choosing uniformly among 20 equally likely tokens.
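A small worked example may make this concrete. The sketch below takes the probability the model assigned to each correct token and applies the definition directly; the helper name `perplexity` and the probability values are illustrative assumptions.

```python
import math

def perplexity(token_probs: list[float]) -> float:
    """exp of the mean negative log-probability assigned to the correct tokens."""
    nll = [-math.log(p) for p in token_probs]
    return math.exp(sum(nll) / len(nll))

# The model always gives the correct token probability 1/20:
# perplexity comes out as 20 (up to floating-point error).
uniform = perplexity([1 / 20] * 8)

# Two confident hits and one bad miss: the single low-probability token
# dominates, because perplexity is the inverse geometric mean of the probabilities.
mixed = perplexity([0.5, 0.5, 0.01])
print(uniform, mixed)
```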

Python
import torch
import torch.nn.functional as F

def compute_lm_loss(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """
    Compute cross-entropy loss for language modeling.
    
    logits: (B, T, vocab_size) — model predictions
    targets: (B, T) — next token indices (input shifted by 1)
    """
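    B, T, V = logits.shape
    # Flatten batch and sequence dimensions for cross_entropy.
    # reshape (not view) because sliced targets such as input_ids[:, 1:]
    # are non-contiguous, and view would raise on them.
    loss = F.cross_entropy(
        logits.reshape(B * T, V),
        targets.reshape(B * T),
        ignore_index=-100,  # Positions labeled -100 (e.g. padding) are skipped
    )
    return loss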

def compute_perplexity(model, data_loader, device) -> float:
    """Compute perplexity over a dataset."""
    model.eval()
    total_loss = 0.0
    total_tokens = 0

    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch["input_ids"].to(device)
            # Shift: targets are input shifted left by 1
            inputs = input_ids[:, :-1]
            targets = input_ids[:, 1:]

            # Assumes the model returns a (logits, extra) tuple; adapt to your model's API
            logits, _ = model(inputs)
            loss = compute_lm_loss(logits, targets)
            n_tokens = (targets != -100).sum().item()
            total_loss += loss.item() * n_tokens
            total_tokens += n_tokens

    avg_loss = total_loss / total_tokens
    return torch.exp(torch.tensor(avg_loss)).item()
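The reason compute_perplexity multiplies each batch loss by its token count can be seen with a few numbers. This is a sketch with made-up batch statistics: averaging per-batch mean losses directly gives a small batch the same weight as a large one, while the token-weighted average reflects the per-token loss over the whole dataset.

```python
import math

# Hypothetical per-batch results: (mean loss over the batch, number of unmasked tokens)
batches = [(2.0, 1000), (4.0, 10)]

# Unweighted mean of batch losses over-counts the tiny batch
naive_avg = sum(loss for loss, _ in batches) / len(batches)

# Token-weighted mean: undo each batch's averaging, then divide by all tokens
total_loss = sum(loss * n for loss, n in batches)
total_tokens = sum(n for _, n in batches)
weighted_avg = total_loss / total_tokens

print(math.exp(naive_avg), math.exp(weighted_avg))
```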

Training Data Construction

How raw text becomes training batches:

Python
import numpy as np
import torch  # used by make_batch below
from datasets import load_dataset
from transformers import AutoTokenizer

def build_training_data(
    texts: list[str],
    tokenizer,
    context_length: int = 2048,
) -> np.ndarray:
    """
    Tokenize and pack texts into fixed-length training sequences.
    Returns an array of token ids with shape (n_sequences, context_length).
    """
    all_tokens = []

    for text in texts:
        tokens = tokenizer.encode(text)
        # Add EOS token between documents
        tokens.append(tokenizer.eos_token_id)
        all_tokens.extend(tokens)

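    # Convert to numpy array. uint16 only holds ids below 65,536, so fall back
    # to uint32 for larger vocabularies to avoid integer overflow.
    dtype = np.uint16 if max(all_tokens, default=0) < 2**16 else np.uint32
    all_tokens = np.array(all_tokens, dtype=dtype)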

    # Shard into context_length chunks (no padding waste)
    n_sequences = len(all_tokens) // context_length
    all_tokens = all_tokens[:n_sequences * context_length]
    all_tokens = all_tokens.reshape(n_sequences, context_length)

    return all_tokens


def make_batch(token_array: np.ndarray, batch_size: int, device: str) -> dict:
    """Sample a random batch from the pretraining corpus."""
    indices = np.random.randint(0, len(token_array), size=batch_size)
    # Cast to int64 first: older PyTorch versions cannot convert uint16 arrays
    x = torch.from_numpy(token_array[indices].astype(np.int64)).to(device)

    # Inputs drop the last token; targets are the same sequence shifted by 1.
    # The labels are already shifted, so the loss must not shift them again.
    inputs = x[:, :-1]    # t₁ to t_{T-1}
    targets = x[:, 1:]    # t₂ to t_T

    return {"input_ids": inputs, "labels": targets}
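To see what packing and shifting produce, here is a dependency-free sketch on a tiny corpus. The function name `pack_and_shift`, the token ids, and the use of 0 as the EOS id are all assumptions made for illustration.

```python
def pack_and_shift(
    docs: list[list[int]], eos_id: int, context_length: int
) -> list[tuple[list[int], list[int]]]:
    """Concatenate documents with EOS separators, chunk, and build shifted pairs."""
    stream: list[int] = []
    for doc in docs:
        stream.extend(doc)
        stream.append(eos_id)

    # Drop the incomplete tail so every chunk has exactly context_length tokens
    n_sequences = len(stream) // context_length
    chunks = [
        stream[i * context_length:(i + 1) * context_length]
        for i in range(n_sequences)
    ]
    # Each chunk yields context_length - 1 (input, target) positions
    return [(chunk[:-1], chunk[1:]) for chunk in chunks]

# Token ids are made up; 0 plays the role of EOS
pairs = pack_and_shift([[5, 6, 7], [8, 9], [10, 11, 12, 13]], eos_id=0, context_length=4)
for inputs, targets in pairs:
    print(inputs, "->", targets)
```

Note that the second chunk contains the end of one document and the start of the next. Packing avoids padding, at the cost of sequences that cross document boundaries.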

Data Mixture and Weighting

Modern LLMs don't train on equal proportions of all available data. The mixture significantly affects capabilities:

Python
DATA_MIXTURE = {
    # GPT-3 approximate sampling mixture (rounded figures, so they sum to 1.01)
    "common_crawl": 0.60,      # Filtered web data (largest but noisiest)
    "webtext": 0.22,           # Curated high-quality web text
    "books": 0.16,             # Two book corpora (long-form reasoning, coherence)
    "wikipedia": 0.03,         # High-quality factual knowledge
}

# Llama 3 approximate mixture, by token category as described in its technical report
LLAMA3_MIXTURE = {
    "general_knowledge": 0.50, # Mostly filtered web text
    "math_reasoning": 0.25,    # Mathematical and reasoning text
    "code": 0.17,              # Source code
    "multilingual": 0.08,      # Non-English text
}

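def weighted_dataset_sampler(datasets: dict, weights: dict):
    """Sample from multiple datasets according to mixture weights."""
    import random

    dataset_list = list(datasets.keys())
    weight_list = [weights[name] for name in dataset_list]

    # Normalize weights
    total = sum(weight_list)
    weight_list = [w / total for w in weight_list]

    # Keep one persistent iterator per dataset; calling iter() on every draw
    # would restart the dataset and return its first example each time.
    # Assumes each dataset is non-empty and can be iterated more than once.
    iterators = {name: iter(datasets[name]) for name in dataset_list}

    while True:
        # Sample dataset according to mixture weights
        dataset_name = random.choices(dataset_list, weights=weight_list, k=1)[0]
        try:
            yield next(iterators[dataset_name])
        except StopIteration:
            # Dataset exhausted: start another pass over it
            iterators[dataset_name] = iter(datasets[dataset_name])
            yield next(iterators[dataset_name])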

Key findings on data mixture:

  • Code data is widely reported to improve reasoning on non-code tasks (including chain-of-thought quality)
  • Math data improves quantitative reasoning
  • Over-representing low-quality web text degrades the coherence of generated text
  • The optimal mixture is discovered empirically through ablations
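Mixture weights act as sampling probabilities, so the share of examples drawn from each source converges to its weight as training proceeds. The sketch below checks this with the standard library; the mixture values and the helper name `sample_sources` are hypothetical and not taken from any specific model.

```python
import random
from collections import Counter

def sample_sources(weights: dict[str, float], n_draws: int, seed: int = 0) -> Counter:
    """Draw source names with probability proportional to their mixture weight."""
    rng = random.Random(seed)
    names = list(weights)
    draws = rng.choices(names, weights=[weights[n] for n in names], k=n_draws)
    return Counter(draws)

# Hypothetical mixture, for illustration only
mixture = {"web": 0.7, "code": 0.2, "math": 0.1}
n_draws = 10_000
counts = sample_sources(mixture, n_draws=n_draws)

for name, weight in mixture.items():
    print(f"{name}: target {weight:.2f}, observed {counts[name] / n_draws:.3f}")
```

An ablation then amounts to changing one weight, retraining a small model, and comparing downstream evaluations.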

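Token Weighting: Masking Non-Assistant Tokens

Not all positions should contribute to the loss. In chat fine-tuning, system and user tokens are typically labeled -100, the value that cross-entropy's ignore_index skips, so only assistant tokens produce gradients: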

Python
def build_labels_for_chat(
    conversation: list[dict],
    tokenizer,
    only_train_on_assistant: bool = True,
) -> tuple[list[int], list[int]]:
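    """
    Build input_ids and labels for chat fine-tuning.
    When only_train_on_assistant=True, mask user/system tokens with -100.

    Labels are aligned with input_ids (not shifted); the one-position shift
    is assumed to happen later, when the loss is computed. The <|role|> tags
    are an illustrative format, not a specific model's chat template.
    """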
    input_ids = []
    labels = []

    for message in conversation:
        role_tokens = tokenizer.encode(f"<|{message['role']}|>\n", add_special_tokens=False)
        content_tokens = tokenizer.encode(message["content"] + "\n", add_special_tokens=False)
        eot = [tokenizer.eos_token_id]

        msg_tokens = role_tokens + content_tokens + eot

        input_ids.extend(msg_tokens)

        if only_train_on_assistant and message["role"] != "assistant":
            # Mask non-assistant tokens — model sees them but doesn't learn from them
            labels.extend([-100] * len(msg_tokens))
        else:
            labels.extend(msg_tokens)

    return input_ids, labels
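The effect of the -100 mask on the loss can be sketched without a model. The example assumes we already know the probability the model assigned to the correct token at each position (the values are invented) and mirrors how a mean-reduced cross-entropy with ignore_index averages only over unmasked positions.

```python
import math

IGNORE_INDEX = -100

def masked_mean_nll(correct_token_probs: list[float], labels: list[int]) -> float:
    """Average -log p over positions whose label is not IGNORE_INDEX."""
    kept = [
        -math.log(p)
        for p, label in zip(correct_token_probs, labels)
        if label != IGNORE_INDEX
    ]
    return sum(kept) / len(kept)

# Hypothetical probabilities assigned to the correct token at 5 positions
probs = [0.01, 0.02, 0.5, 0.25, 0.5]
# The first two positions belong to the user turn and are masked
labels = [IGNORE_INDEX, IGNORE_INDEX, 42, 7, 3]

loss_assistant_only = masked_mean_nll(probs, labels)
loss_all_positions = masked_mean_nll(probs, [1, 1, 42, 7, 3])
print(loss_assistant_only, loss_all_positions)
```

Without the mask, the hard-to-predict user tokens dominate the average, and the model spends capacity learning to imitate users rather than to respond to them.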

The Training Objective Stack

Modern LLMs go through multiple training stages, each with a different objective:

| Stage | Objective | Data | Purpose |
|---|---|---|---|
| Pretraining | Next-token prediction | Trillions of tokens from web/books/code | General language understanding |
| SFT | Next-token (assistant only) | Instruction-response pairs | Learn to follow instructions |
| RLHF/PPO | Reward maximization | Human preference pairs | Align with human values |
| DPO | Preference likelihood | Chosen/rejected pairs | Alignment without RL complexity |

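Each stage builds on the previous: SFT requires a pretrained base, and preference-based alignment applied without prior instruction tuning generally produces poor results. RLHF/PPO and DPO are typically alternatives for the alignment stage rather than consecutive steps.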
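To illustrate what "preference likelihood" means for DPO, here is a sketch of the loss for a single chosen/rejected pair. It assumes the summed log-probabilities of each response under the trained policy and a frozen reference model are already available; the function name, the numeric inputs, and the default beta of 0.1 are illustrative assumptions.

```python
import math

def dpo_loss(
    policy_chosen_logp: float,
    policy_rejected_logp: float,
    ref_chosen_logp: float,
    ref_rejected_logp: float,
    beta: float = 0.1,
) -> float:
    """DPO loss for one preference pair, from summed sequence log-probabilities."""
    # How much the policy has moved each response relative to the reference
    chosen_gain = policy_chosen_logp - ref_chosen_logp
    rejected_gain = policy_rejected_logp - ref_rejected_logp
    margin = beta * (chosen_gain - rejected_gain)
    # -log(sigmoid(margin)) written as log(1 + exp(-margin))
    return math.log1p(math.exp(-margin))

# Policy identical to the reference: margin is 0, so the loss is log(2)
at_init = dpo_loss(-12.0, -15.0, -12.0, -15.0)
# Policy raised the chosen response and lowered the rejected one: lower loss
improved = dpo_loss(-10.0, -18.0, -12.0, -15.0)
print(at_init, improved)
```

No reward model or sampling loop appears here, which is the sense in which DPO avoids the complexity of RL: the loss is an ordinary differentiable function of log-probabilities.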


Loss Curves and Training Diagnostics

Python
import json
import math

import matplotlib.pyplot as plt

def plot_training_metrics(log_file: str) -> None:
    """Plot training and validation loss curves from a JSON-lines training log."""
    train_steps, train_losses = [], []
    val_steps, val_losses = [], []

    with open(log_file) as f:
        for line in f:
            entry = json.loads(line)
            if "train_loss" in entry:
                train_steps.append(entry["step"])
                train_losses.append(entry["train_loss"])
            if "val_loss" in entry:
                val_steps.append(entry["step"])
                val_losses.append(entry["val_loss"])

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    ax1.plot(train_steps, train_losses, label="Train", alpha=0.7)
    ax1.plot(val_steps, val_losses, label="Validation", linewidth=2)
    ax1.set_xlabel("Steps")
    ax1.set_ylabel("Cross-Entropy Loss")
    ax1.set_title("Training Progress")
    ax1.legend()

    # Perplexity (exp of loss)
    val_ppx = [math.exp(loss) for loss in val_losses]
    ax2.plot(val_steps, val_ppx)
    ax2.set_xlabel("Steps")
    ax2.set_ylabel("Perplexity")
    ax2.set_title("Validation Perplexity")

    plt.tight_layout()
    plt.savefig("training_curves.png", dpi=150)

Diagnosing training problems from loss curves:

  • Loss plateaus early: Learning rate too low, or dataset is too small/repetitive
  • Loss spikes and recovers: Usually a transient gradient spike, often from an outlier or badly formatted batch; gradient clipping limits the damage
  • Train loss keeps falling while val loss rises: Overfitting; reduce model size, add dropout, or get more data
  • Loss doesn't decrease: LR too high (exploding gradients), bad initialization, or a bug in data loading
  • Loss decreases then spikes sharply: LR schedule issue, or a corrupted data shard was encountered
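Spikes can also be flagged automatically rather than by eye. The following is a minimal sketch, assuming a simple rule: a step is a spike when its loss exceeds a multiple of the mean over the preceding window. The function name, window size, and ratio are arbitrary choices, and the loss trace is made up.

```python
def find_loss_spikes(losses: list[float], window: int = 5, ratio: float = 1.5) -> list[int]:
    """Return indices where the loss exceeds `ratio` times the mean of the previous `window` values."""
    spikes = []
    for i in range(window, len(losses)):
        baseline = sum(losses[i - window:i]) / window
        if losses[i] > ratio * baseline:
            spikes.append(i)
    return spikes

# Made-up loss trace with one spike at index 7
trace = [3.0, 2.9, 2.8, 2.7, 2.6, 2.5, 2.4, 5.0, 2.4, 2.3]
spike_steps = find_loss_spikes(trace)
print(spike_steps)  # [7]
```

Logging the batch contents at flagged steps helps distinguish a bad data shard from a learning-rate problem.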
