LLM Training Objectives: From Next-Token to Alignment
The full training objective stack for large language models: next-token prediction loss, cross-entropy mechanics, data weighting, and how pretraining creates the base for alignment.
The Core Objective: Next-Token Prediction
Every GPT-style LLM is trained to predict the next token given all previous tokens. Given a sequence of tokens $[t_1, t_2, \ldots, t_n]$, the model learns to maximize:
$$P(t_1, t_2, \ldots, t_n) = \prod_{i=1}^{n} P(t_i \mid t_1, \ldots, t_{i-1})$$
This factorization is exact (it is the chain rule of probability), so next-token prediction is not an approximation. It is a complete generative model of text.
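One way to see that the factorization is exact is to check it numerically on a toy model. The sketch below assumes a hypothetical three-token vocabulary with hand-picked probabilities (none of these numbers come from a real model), and for brevity each conditional looks only at the previous token. It confirms that the product of conditionals defines a proper distribution over all length-3 sequences.

```python
import itertools
import math

# Hypothetical toy model over a 3-token vocabulary (illustrative numbers only).
VOCAB = ["a", "b", "c"]
P_FIRST = {"a": 0.5, "b": 0.3, "c": 0.2}  # P(t_1)
P_NEXT = {                                 # P(t_i | t_{i-1})
    "a": {"a": 0.1, "b": 0.6, "c": 0.3},
    "b": {"a": 0.4, "b": 0.2, "c": 0.4},
    "c": {"a": 0.7, "b": 0.2, "c": 0.1},
}

def sequence_log_prob(tokens):
    """Sum of conditional log-probabilities, i.e. the log of the chain-rule product."""
    log_p = math.log(P_FIRST[tokens[0]])
    for prev, cur in zip(tokens, tokens[1:]):
        log_p += math.log(P_NEXT[prev][cur])
    return log_p

# Summing the joint probability over every possible sequence should give 1.
total = sum(
    math.exp(sequence_log_prob(seq))
    for seq in itertools.product(VOCAB, repeat=3)
)
print(f"Sum over all length-3 sequences: {total:.6f}")
```

Because every conditional distribution sums to 1, the joint does too, which is what makes the product a complete generative model and not a heuristic score.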
Why this objective works:
- To predict the next token well, the model must understand syntax, semantics, world knowledge, and reasoning
- The training signal is dense: every token position contributes a loss term to the gradient
- Data is abundant: any text can serve as self-supervised training data, with no manual labels required
- Compression is intelligence: a model that predicts well has learned structure in language
Cross-Entropy Loss: The Math
The loss at position i is the negative log probability of the correct token:
$$\mathcal{L}_i = -\log P(t_i \mid t_1, \ldots, t_{i-1})$$
Over a sequence of length T:
$$\mathcal{L} = -\frac{1}{T} \sum_{i=1}^{T} \log P(t_i \mid t_1, \ldots, t_{i-1})$$
Perplexity is the exponentiated average loss, a more interpretable metric:
$$\text{Perplexity} = \exp(\mathcal{L}) = \exp\left(-\frac{1}{T} \sum_{i=1}^{T} \log P(t_i \mid \text{context})\right)$$
A perplexity of 20 means the model is, on average, as uncertain as if choosing uniformly among 20 equally likely tokens.
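The "uniform among 20 tokens" reading can be checked directly. This is a minimal sketch in plain Python; the probability lists are made-up inputs standing for the probability a model assigned to each correct token.

```python
import math

def perplexity(token_probs):
    """Exponentiated mean negative log-probability of the correct tokens."""
    nll = [-math.log(p) for p in token_probs]
    return math.exp(sum(nll) / len(nll))

# A model that always assigns 1/20 to the correct token has perplexity 20.
uniform = perplexity([1 / 20] * 8)

# Confident on two tokens, badly wrong on two others (illustrative values).
mixed = perplexity([0.5, 0.5, 0.01, 0.01])

print(f"uniform: {uniform:.2f}, mixed: {mixed:.2f}")
```

Perplexity is the geometric mean of the inverse probabilities, so a few very low-probability tokens can dominate it: the second example works out to the square root of 200, roughly 14.1, despite two confident predictions.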
import torch
import torch.nn.functional as F
def compute_lm_loss(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """
    Compute cross-entropy loss for language modeling.

    logits: (B, T, vocab_size) - model predictions
    targets: (B, T) - next token indices (input shifted by 1)
    """
    B, T, V = logits.shape
    # Flatten batch and sequence dimensions for cross_entropy.
    # reshape (not view) because shifted slices such as input_ids[:, 1:] are non-contiguous.
    loss = F.cross_entropy(
        logits.reshape(B * T, V),
        targets.reshape(B * T),
        ignore_index=-100,  # Positions labeled -100 (e.g. padding) don't contribute to loss
    )
    return loss

def compute_perplexity(model, data_loader, device) -> float:
    """Compute perplexity over a dataset."""
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch["input_ids"].to(device)
            # Shift: targets are input shifted left by 1
            inputs = input_ids[:, :-1]
            targets = input_ids[:, 1:]
            # Assumes the model returns a (logits, extra) tuple
            logits, _ = model(inputs)
            loss = compute_lm_loss(logits, targets)
            # The loss is a mean over non-ignored tokens, so re-weight by token count
            n_tokens = (targets != -100).sum().item()
            total_loss += loss.item() * n_tokens
            total_tokens += n_tokens
    avg_loss = total_loss / total_tokens
    return torch.exp(torch.tensor(avg_loss)).item()

Training Data Construction
How raw text becomes training batches:
import numpy as np
import torch
from datasets import load_dataset
from transformers import AutoTokenizer

def build_training_data(
    texts: list[str],
    tokenizer,
    context_length: int = 2048,
) -> np.ndarray:
    """
    Tokenize and pack texts into fixed-length training sequences.
    Returns an array of token ids with shape (n_sequences, context_length).
    """
    all_tokens = []
    for text in texts:
        tokens = tokenizer.encode(text)
        # Add EOS token between documents
        tokens.append(tokenizer.eos_token_id)
        all_tokens.extend(tokens)
    # Convert to numpy array. uint16 only holds ids up to 65,535 and would
    # silently wrap for larger vocabularies, so uint32 is the safer default.
    all_tokens = np.array(all_tokens, dtype=np.uint32)
    # Shard into context_length chunks (no padding waste); the tail remainder is dropped
    n_sequences = len(all_tokens) // context_length
    all_tokens = all_tokens[:n_sequences * context_length]
    all_tokens = all_tokens.reshape(n_sequences, context_length)
    return all_tokens

def make_batch(token_array: np.ndarray, batch_size: int, device: str) -> dict:
    """Sample a random batch from the pretraining corpus."""
    indices = np.random.randint(0, len(token_array), size=batch_size)
    # Cast to int64 first: PyTorch support for unsigned integer dtypes is limited
    x = torch.from_numpy(token_array[indices].astype(np.int64)).to(device)
    # Input is all tokens but the last; target is the same sequence shifted by 1 position
    inputs = x[:, :-1]   # t_1 to t_{T-1}
    targets = x[:, 1:]   # t_2 to t_T
    return {"input_ids": inputs, "labels": targets}

Data Mixture and Weighting
Modern LLMs don't train on equal proportions of all available data. The mixture significantly affects capabilities:
DATA_MIXTURE = {
    # GPT-3 approximate sampling mixture (rounded figures, so they sum to 1.01)
    "common_crawl": 0.60,  # Filtered web data (largest but noisiest)
    "books": 0.16,         # Books (long-form reasoning, coherence)
    "wikipedia": 0.03,     # High-quality factual knowledge
    "webtext": 0.22,       # Curated high-quality web text
}

# LLaMA-3 approximate mixture (rough token shares as described in the Llama 3 report)
LLAMA3_MIXTURE = {
    "general_knowledge": 0.50,  # Mostly heavily filtered web text
    "math_reasoning": 0.25,     # Mathematical and reasoning text
    "code": 0.17,               # Source code
    "multilingual": 0.08,       # Non-English text
}
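import random

def weighted_dataset_sampler(datasets: dict, weights: dict):
    """Sample from multiple datasets according to mixture weights."""
    dataset_list = list(datasets.keys())
    weight_list = [weights[name] for name in dataset_list]
    # Normalize weights
    total = sum(weight_list)
    weight_list = [w / total for w in weight_list]
    # Keep one persistent iterator per dataset; calling iter() on every draw
    # would restart a re-iterable dataset and return its first example each time
    iterators = {name: iter(datasets[name]) for name in dataset_list}
    while True:
        # Sample dataset according to mixture weights
        dataset_name = random.choices(dataset_list, weights=weight_list, k=1)[0]
        try:
            yield next(iterators[dataset_name])
        except StopIteration:
            # Dataset exhausted: start another pass (assumes it can be iterated again)
            iterators[dataset_name] = iter(datasets[dataset_name])
            yield next(iterators[dataset_name])

Key findings on data mixture: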
- Code data is reported to improve reasoning on non-code tasks (for example, chain-of-thought quality)
- Math data improves quantitative reasoning
- Over-representing low-quality web text hurts coherence
- The optimal mixture is discovered empirically through ablations
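Mixture weights also interact with source size: a small source with a high weight gets repeated many times, while a large one may never be seen in full. The sketch below estimates how many passes each source receives. The function name, source sizes, and token budget are hypothetical values chosen for illustration.

```python
def epochs_per_source(weights, source_tokens, total_training_tokens):
    """Estimate how many passes over each source a mixture implies."""
    total_weight = sum(weights.values())
    return {
        name: (weights[name] / total_weight) * total_training_tokens / source_tokens[name]
        for name in weights
    }

# Hypothetical mixture, source sizes (in tokens), and training budget.
weights = {"web": 0.8, "code": 0.1, "wiki": 0.1}
source_tokens = {"web": 1_000e9, "code": 50e9, "wiki": 5e9}

epochs = epochs_per_source(weights, source_tokens, total_training_tokens=300e9)
for name, n in epochs.items():
    print(f"{name}: {n:.2f} epochs")
```

With these numbers the small "wiki" source is repeated about six times while under a quarter of the web data is ever seen, which is one reason mixtures are tuned through ablations and not set by raw corpus size.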
Token Weighting: Ignoring Special Tokens
Not all positions should contribute equally to the loss:
def build_labels_for_chat(
    conversation: list[dict],
    tokenizer,
    only_train_on_assistant: bool = True,
) -> tuple[list[int], list[int]]:
    """
    Build input_ids and labels for chat fine-tuning.
    When only_train_on_assistant=True, mask user/system tokens with -100.
    Labels are aligned with input_ids (not shifted); shift by one position
    before computing the loss.
    """
    input_ids = []
    labels = []
    for message in conversation:
        role_tokens = tokenizer.encode(f"<|{message['role']}|>\n", add_special_tokens=False)
        content_tokens = tokenizer.encode(message["content"] + "\n", add_special_tokens=False)
        eot = [tokenizer.eos_token_id]
        msg_tokens = role_tokens + content_tokens + eot
        input_ids.extend(msg_tokens)
        if only_train_on_assistant and message["role"] != "assistant":
            # Mask non-assistant tokens: the model sees them but doesn't learn from them
            labels.extend([-100] * len(msg_tokens))
        else:
            labels.extend(msg_tokens)
    return input_ids, labels

The Training Objective Stack
Modern LLMs go through multiple training stages, each with a different objective:
| Stage | Objective | Data | Purpose |
|---|---|---|---|
| Pretraining | Next-token prediction | Trillions of tokens from web/books/code | General language understanding |
| SFT | Next-token (assistant only) | Instruction-response pairs | Learn to follow instructions |
| RLHF/PPO | Reward maximization | Human preference pairs | Align with human values |
| DPO | Preference likelihood | Chosen/rejected pairs | Alignment without RL complexity |
Each stage builds on the previous: you cannot do SFT without a pretrained base, and alignment without instruction tuning produces poor results.
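To make the "preference likelihood" objective in the DPO row concrete, here is a minimal sketch of the commonly used DPO loss for a single chosen/rejected pair. It assumes the four inputs are summed sequence log-probabilities from the policy and from a frozen reference model; the function name and the default `beta=0.1` are illustrative choices, not values from this article.

```python
import math

def dpo_loss(
    policy_chosen_logp: float,
    policy_rejected_logp: float,
    ref_chosen_logp: float,
    ref_rejected_logp: float,
    beta: float = 0.1,  # assumed strength of the pull toward the reference model
) -> float:
    """DPO loss for one preference pair: -log(sigmoid(margin))."""
    # Implicit rewards: how much more likely the policy makes each response
    # compared with the reference model.
    chosen_reward = beta * (policy_chosen_logp - ref_chosen_logp)
    rejected_reward = beta * (policy_rejected_logp - ref_rejected_logp)
    margin = chosen_reward - rejected_reward
    # -log(sigmoid(m)) == log(1 + exp(-m))
    return math.log1p(math.exp(-margin))

# Policy identical to the reference: no preference learned yet, loss is log(2).
baseline = dpo_loss(-12.0, -15.0, -12.0, -15.0)

# Policy has shifted probability toward the chosen response: loss drops.
improved = dpo_loss(-10.0, -18.0, -12.0, -15.0)

print(f"baseline: {baseline:.4f}, improved: {improved:.4f}")
```

Unlike the PPO stage, nothing here needs a separate reward model or sampling loop: the loss is computed directly from log-probabilities of existing chosen/rejected pairs.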
Loss Curves and Training Diagnostics
import json
import math

import matplotlib.pyplot as plt

def plot_training_metrics(log_file: str) -> None:
    """Plot training and validation loss curves from a JSON-lines training log."""
    train_steps, train_losses = [], []
    val_steps, val_losses = [], []
    with open(log_file) as f:
        for line in f:
            entry = json.loads(line)
            if "train_loss" in entry:
                train_steps.append(entry["step"])
                train_losses.append(entry["train_loss"])
            if "val_loss" in entry:
                val_steps.append(entry["step"])
                val_losses.append(entry["val_loss"])
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    ax1.plot(train_steps, train_losses, label="Train", alpha=0.7)
    ax1.plot(val_steps, val_losses, label="Validation", linewidth=2)
    ax1.set_xlabel("Steps")
    ax1.set_ylabel("Cross-Entropy Loss")
    ax1.set_title("Training Progress")
    ax1.legend()
    # Perplexity (exp of loss)
    val_ppx = [math.exp(loss) for loss in val_losses]
    ax2.plot(val_steps, val_ppx)
    ax2.set_xlabel("Steps")
    ax2.set_ylabel("Perplexity")
    ax2.set_title("Validation Perplexity")
    plt.tight_layout()
    plt.savefig("training_curves.png", dpi=150)

Diagnosing training problems from loss curves:
- Loss plateaus early: Learning rate too low, or dataset is too small/repetitive
- Loss spikes and recovers: Gradient spikes from poorly formatted data in the batch
- Train and val loss diverge: Overfitting. Reduce model size, add dropout, or get more data
- Loss doesn't decrease: LR too high (exploding gradients), bad initialization, or bug in data loading
- Loss decreases then spikes sharply: LR schedule issue or a corrupted data shard encountered
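Some of these patterns can be flagged automatically from logged losses. The sketch below is one simple heuristic, not a standard tool: the function names, the 5-step window, and the 1.5x threshold are all assumptions to tune per run.

```python
def find_loss_spikes(losses, window=5, ratio=1.5):
    """Return indices where the loss exceeds `ratio` times the mean of the previous `window` values."""
    spikes = []
    for i in range(window, len(losses)):
        baseline = sum(losses[i - window:i]) / window
        if losses[i] > ratio * baseline:
            spikes.append(i)
    return spikes

def generalization_gap(train_losses, val_losses):
    """Difference between the latest validation and training loss; a growing gap suggests overfitting."""
    return val_losses[-1] - train_losses[-1]

# Illustrative loss trace with a single spike at step index 6.
losses = [3.0, 2.9, 2.8, 2.7, 2.6, 2.5, 6.0, 2.5, 2.4]
spikes = find_loss_spikes(losses)
gap = generalization_gap(train_losses=[2.0, 1.5, 1.1], val_losses=[2.1, 1.9, 2.0])

print(f"spike indices: {spikes}, train/val gap: {gap:.2f}")
```

A flagged index is a pointer to the batches or data shard processed around that step, which is usually the fastest way to tell a data problem from a learning-rate problem.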