Learnixo
AI Systems · Intermediate

Pretraining: How LLMs Learn from Raw Text

The next-token prediction objective, training data curation, curriculum design, and what a model actually learns during pretraining on trillions of tokens.

Asma Hafeez Khan · May 16, 2026 · 6 min read
Transformers · Pretraining · Training · LLM

The Core Objective: Next-Token Prediction

Language model pretraining uses the self-supervised next-token prediction objective: given tokens $t_1, t_2, \ldots, t_{n-1}$, predict $t_n$. The loss is the cross-entropy over the vocabulary, summed over positions (in practice, usually averaged):

$$\mathcal{L} = -\sum_{i=1}^{n} \log P(t_i \mid t_1, \ldots, t_{i-1})$$

No human labels are required. The labels are the next tokens in the text itself, which is why pretraining can scale to trillions of tokens: the training signal is freely available in any text corpus.

Python
import torch
import torch.nn as nn

def compute_language_modeling_loss(
    logits: torch.Tensor,  # (batch, seq_len, vocab_size)
    labels: torch.Tensor,  # (batch, seq_len): the input token IDs themselves; shifted below
) -> torch.Tensor:
    """Standard causal language modeling loss."""
    # Shift: the logit at position i predicts token i+1
    # Input:   [t1, t2, t3, t4]
    # Targets: [t2, t3, t4]  (the last position has no next token, so its logit is dropped)

    shift_logits = logits[:, :-1, :].contiguous()  # (batch, seq_len-1, vocab)
    shift_labels = labels[:, 1:].contiguous()       # (batch, seq_len-1)

    loss_fn = nn.CrossEntropyLoss(ignore_index=-100)  # -100 = don't count this position
    return loss_fn(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
    )
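To connect the code to the formula, here is a minimal sketch that computes the same mean negative log-likelihood by hand (log-softmax, then pick out the log-probability of each true next token) and checks it against PyTorch's built-in `F.cross_entropy`. The tensor sizes and random inputs are arbitrary assumptions chosen for illustration.

```python
import torch
import torch.nn.functional as F

def manual_next_token_nll(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    """Mean of -log P(t_i | t_1..t_{i-1}), computed explicitly."""
    # Position i predicts token i+1: drop the last logit and the first token.
    log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)  # (batch, seq_len-1, vocab)
    targets = input_ids[:, 1:]                            # (batch, seq_len-1)
    # Select the log-probability the model assigned to each actual next token.
    token_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    return -token_log_probs.mean()

torch.manual_seed(0)
batch, seq_len, vocab = 2, 6, 11
logits = torch.randn(batch, seq_len, vocab)
input_ids = torch.randint(0, vocab, (batch, seq_len))

manual = manual_next_token_nll(logits, input_ids)
builtin = F.cross_entropy(
    logits[:, :-1, :].reshape(-1, vocab),
    input_ids[:, 1:].reshape(-1),
)
print(f"manual={manual.item():.6f}  builtin={builtin.item():.6f}")

# Uniform (all-zero) logits give a loss of ln(vocab): the "knows nothing" baseline.
uniform = manual_next_token_nll(torch.zeros(batch, seq_len, vocab), input_ids)
print(f"uniform={uniform.item():.4f}  ln(vocab)={torch.log(torch.tensor(float(vocab))).item():.4f}")
```

The uniform case is a handy reference point: a freshly initialized model that spreads probability roughly evenly should start near ln(vocab_size) and fall from there.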

What the Model Learns

Minimizing next-token prediction loss forces the model to implicitly learn:

  • Syntax: Correct grammar is more predictable than incorrect grammar
  • Semantics: Semantically related words follow each other more predictably
  • World knowledge: "Paris is the capital of ___" → "France" has high probability
  • Reasoning patterns: "If X then Y. X is true. Therefore ___" → "Y"
  • Code: Code structure and API usage patterns
  • Instruction following (partially): Documents like Stack Overflow Q&As include instruction-response patterns
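As a toy illustration of the world-knowledge point, the sketch below swaps the transformer for a count-based n-gram model, an assumption made purely for simplicity. For a fixed context length, normalized next-word counts are the distribution that minimizes cross-entropy on the training text, so a fact that appears in the corpus becomes the prediction once the context is long enough to contain it. The four-sentence corpus is made up.

```python
from collections import Counter, defaultdict

def fit_ngram(corpus: list[str], context_len: int) -> dict[tuple[str, ...], Counter]:
    """Count which word follows each context of `context_len` words."""
    counts: dict[tuple[str, ...], Counter] = defaultdict(Counter)
    for sentence in corpus:
        words = sentence.split()
        for i in range(context_len, len(words)):
            counts[tuple(words[i - context_len:i])][words[i]] += 1
    return counts

def next_word_distribution(counts, context: tuple[str, ...]) -> dict[str, float]:
    """Normalized counts: the maximum-likelihood next-word distribution for this context."""
    following = counts.get(context, Counter())
    total = sum(following.values())
    return {word: n / total for word, n in following.items()}

corpus = [
    "Paris is the capital of France .",
    "Paris is the capital of France and its largest city .",
    "Rome is the capital of Italy .",
    "Berlin is the capital of Germany .",
]

# With enough context, the corpus "fact" becomes the prediction.
long_ctx = fit_ngram(corpus, context_len=5)
print(next_word_distribution(long_ctx, ("Paris", "is", "the", "capital", "of")))
# {'France': 1.0}

# With too little context, the model can only spread probability across capitals it has seen.
short_ctx = fit_ngram(corpus, context_len=2)
print(next_word_distribution(short_ctx, ("capital", "of")))
# {'France': 0.5, 'Italy': 0.25, 'Germany': 0.25}
```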

This is the foundation: everything from factual recall to multi-step reasoning comes from optimizing a single prediction objective over diverse data.


Training Data Curation

The quality and diversity of pretraining data have an outsized effect on the resulting model. Modern LLM training data pipelines typically draw on two broad pools:

Web data (60–80% of tokens):

  • Common Crawl: Petabytes of web text, requires heavy filtering
  • C4 (Colossal Clean Crawled Corpus): Cleaned Common Crawl subset used by T5
  • Filtering: Remove duplicates, adult content, low-quality pages, personal data

Curated sources (20–40% of tokens):

  • Books (Project Gutenberg, Books3): Long-form reasoning, diverse vocabulary
  • Wikipedia: Factual, encyclopedic knowledge
  • Code (GitHub, Stack Overflow): Programming ability
  • Academic papers (arXiv, PubMed): Scientific knowledge

Data quality pipeline:

Python
def quality_filter(document: str) -> bool:
    """Filter low-quality web documents."""
    # Minimum length
    if len(document.split()) < 50:
        return False

    # Bullet-point fraction: mostly-bullet pages are likely navigation menus, not prose
    lines = document.split("\n")
    bullet_lines = sum(1 for line in lines if line.strip().startswith(("•", "-", "*", "·")))
    if bullet_lines / max(len(lines), 1) > 0.9:
        return False

    # Language detection (if targeting English), e.g. with fasttext or langdetect:
    # if detect_language(document) != "en": return False   (detect_language is a placeholder)

    # Perplexity filter: very low perplexity = boilerplate, very high = garbage.
    # Filter documents that score outside [10, 1000] on a small reference LM.

    # Deduplication is a corpus-level step (typically MinHash + LSH at scale),
    # not a per-document check, so it is handled separately.
    return True

def compute_data_mixture(total_tokens: int) -> dict:
    """Illustrative data mixture; the weights are example values, not a published recipe."""
    return {
        "web_filtered": int(total_tokens * 0.65),
        "books": int(total_tokens * 0.15),
        "code": int(total_tokens * 0.10),
        "wikipedia": int(total_tokens * 0.05),
        "academic": int(total_tokens * 0.05),
    }

# Illustrative mixture applied to a LLaMA-3-scale budget (~15T tokens); not Meta's actual mix:
mixture = compute_data_mixture(15_000_000_000_000)
for source, tokens in mixture.items():
    print(f"{source}: {tokens/1e12:.2f}T tokens")
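The perplexity check in `quality_filter` is left as a comment. A minimal sketch of it, assuming you already have per-token natural-log probabilities for each document from some small reference LM (the `token_logprobs` input and the [10, 1000] band are illustrative assumptions, not tuned values): perplexity is the exponential of the mean negative log-probability.

```python
import math

def perplexity(token_logprobs: list[float]) -> float:
    """exp(mean negative log-probability) over a document's tokens."""
    if not token_logprobs:
        return float("inf")
    return math.exp(-sum(token_logprobs) / len(token_logprobs))

def passes_perplexity_filter(
    token_logprobs: list[float],
    low: float = 10.0,     # below this: suspiciously predictable (boilerplate, repeats)
    high: float = 1000.0,  # above this: likely garbage (spam, encoding debris)
) -> bool:
    return low <= perplexity(token_logprobs) <= high

# Hypothetical log-probs: every token has probability 1/100 ...
prose = [math.log(1 / 100)] * 200
# ... almost fully predictable text (p = 0.9 per token) ...
boilerplate = [math.log(0.9)] * 200
# ... and text the reference LM finds very surprising (p = 1/5000 per token)
garbage = [math.log(1 / 5000)] * 200

for name, lp in [("prose", prose), ("boilerplate", boilerplate), ("garbage", garbage)]:
    print(name, round(perplexity(lp), 2), passes_perplexity_filter(lp))
```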

Deduplication: Critical for Quality

Near-duplicate documents in training data cause memorization of exact strings and degrade generalization. MinHash deduplication is standard:

Python
from datasketch import MinHash, MinHashLSH

def create_minhash(text: str, num_perm: int = 128) -> MinHash:
    """Create a MinHash signature for document deduplication."""
    m = MinHash(num_perm=num_perm)
    # Shingles: overlapping n-grams of characters
    for i in range(len(text) - 4):
        m.update(text[i:i+5].encode("utf-8"))
    return m

def deduplicate_corpus(documents: list[str], threshold: float = 0.8) -> list[str]:
    """Remove near-duplicate documents using MinHash LSH."""
    lsh = MinHashLSH(threshold=threshold, num_perm=128)
    unique_docs = []

    for i, doc in enumerate(documents):
        mh = create_minhash(doc)
        # Check if a similar document is already in the index
        if not lsh.query(mh):
            lsh.insert(f"doc_{i}", mh)
            unique_docs.append(doc)

    return unique_docs

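MinHash approximates the Jaccard similarity between documents' shingle sets, and LSH finds likely near-duplicate pairs without comparing every pair of documents. Large pipelines often deduplicate at more than one level; LLaMA-3's, for example, combines URL-, document-, and line-level deduplication. Without deduplication, models are more prone to memorizing and regurgitating near-verbatim training text.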
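To see what the datasketch calls approximate, here is a dependency-free sketch: exact Jaccard similarity on 5-character shingles, plus a toy MinHash estimate of it. Simulating each "permutation" with a seeded BLAKE2 hash is an assumption made for simplicity; datasketch uses its own hashing scheme. The probability that two sets share the same minimum hash equals their Jaccard similarity, so the fraction of matching signature slots estimates it.

```python
import hashlib

def shingles(text: str, k: int = 5) -> set[str]:
    """Overlapping k-character shingles, as in create_minhash."""
    return {text[i:i + k] for i in range(len(text) - k + 1)}

def jaccard(a: set[str], b: set[str]) -> float:
    """Exact Jaccard similarity |A ∩ B| / |A ∪ B|."""
    union = a | b
    return len(a & b) / len(union) if union else 1.0

def minhash_signature(items: set[str], num_perm: int = 128) -> list[int]:
    """For each simulated permutation (a seeded hash), keep the minimum hash over the set."""
    return [
        min(
            int.from_bytes(
                hashlib.blake2b(f"{seed}:{s}".encode("utf-8"), digest_size=8).digest(), "big"
            )
            for s in items
        )
        for seed in range(num_perm)
    ]

def estimated_jaccard(sig_a: list[int], sig_b: list[int]) -> float:
    """Fraction of signature slots where the two minimums agree."""
    return sum(x == y for x, y in zip(sig_a, sig_b)) / len(sig_a)

a = "The quick brown fox jumps over the lazy dog near the old river bank."
b = "The quick brown fox jumped over the lazy dog near the old river bank."
c = "Gradient checkpointing trades extra compute for lower activation memory."

sa, sb, sc = (shingles(t) for t in (a, b, c))
sig_a, sig_b, sig_c = (minhash_signature(s) for s in (sa, sb, sc))
print(f"a vs b: exact {jaccard(sa, sb):.2f}, MinHash {estimated_jaccard(sig_a, sig_b):.2f}")
print(f"a vs c: exact {jaccard(sa, sc):.2f}, MinHash {estimated_jaccard(sig_a, sig_c):.2f}")
```

With a threshold of 0.8, as in `deduplicate_corpus`, a pair like `a` and `b` sits right around the cutoff, which is why the threshold choice matters in practice.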


Training Curriculum and Data Ordering

Naive: randomly shuffle all 15T tokens and train in one pass.

Better: control what the model sees and when:

Curriculum learning: Start with simpler, cleaner data (books, Wikipedia), then add noisier web data. There is some evidence that this speeds up convergence early in training.
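One simple way to implement this kind of curriculum is to make the sampling weights a function of the training step. The sketch below linearly moves from a "clean-first" mixture to the full mixture over an initial ramp; the source names, weights, and the 10% ramp length are hypothetical values for illustration, not a published recipe. The resulting weights could be passed to a weighted sampler such as `weighted_data_mixer`.

```python
def curriculum_weights(
    step: int,
    total_steps: int,
    start: dict[str, float],
    end: dict[str, float],
    ramp_fraction: float = 0.1,
) -> dict[str, float]:
    """Linearly move sampling weights from `start` to `end` over the first
    `ramp_fraction` of training, then hold them at `end`. Returns normalized weights."""
    ramp_steps = max(1, int(total_steps * ramp_fraction))
    alpha = min(step / ramp_steps, 1.0)  # 0.0 at step 0, 1.0 once the ramp is over
    sources = start.keys() | end.keys()   # a source missing from one mixture has weight 0 there
    mixed = {s: (1 - alpha) * start.get(s, 0.0) + alpha * end.get(s, 0.0) for s in sources}
    total = sum(mixed.values())
    return {s: w / total for s, w in mixed.items()}

# Hypothetical mixtures: begin with clean sources only, end at the full blend
clean_first = {"books": 0.5, "wikipedia": 0.5}
full_blend = {"web_filtered": 0.65, "books": 0.15, "code": 0.10, "wikipedia": 0.05, "academic": 0.05}

for step in (0, 50_000, 100_000, 500_000):
    w = curriculum_weights(step, total_steps=1_000_000, start=clean_first, end=full_blend)
    print(step, {k: round(v, 3) for k, v in sorted(w.items())})
```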

Data mixing strategies:

Python
import random

def weighted_data_mixer(
    sources: dict[str, list[str]],  # source name → list of documents
    weights: dict[str, float],       # source name → sampling weight (need not sum to 1)
    batch_size: int,
) -> list[str]:
    """Sample a batch from multiple data sources with given weights."""
    names = list(weights)
    # random.choices normalizes the weights itself, so every draw picks a source
    # (a hand-rolled cumulative loop can miss when rounding leaves the sum below 1).
    chosen = random.choices(names, weights=[weights[n] for n in names], k=batch_size)
    # Then draw one document (with replacement) from each chosen source.
    return [random.choice(sources[name]) for name in chosen]

Epoch management: Modern LLM pretraining typically makes at most about one pass over most of its data, so most of the trillions of tokens are seen only once; smaller high-quality sources are sometimes upsampled and repeated. This differs from fine-tuning in CV and NLP, where 3–10 epochs are common.
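Whether a source is seen less than once or several times falls out of simple arithmetic: effective epochs = (mixture weight × total training tokens) / unique tokens available in that source. The sketch below applies the illustrative 65/15/10/5/5 mixture from `compute_data_mixture` to a 15T-token budget; the per-source corpus sizes are made-up numbers, chosen only to show how a small but heavily weighted source gets repeated.

```python
def effective_epochs(
    weights: dict[str, float],           # fraction of training tokens drawn from each source
    available_tokens: dict[str, float],  # unique tokens in each source
    total_training_tokens: float,
) -> dict[str, float]:
    """How many times, on average, each source's tokens are seen during training."""
    return {
        source: weights[source] * total_training_tokens / available_tokens[source]
        for source in weights
    }

weights = {"web_filtered": 0.65, "books": 0.15, "code": 0.10, "wikipedia": 0.05, "academic": 0.05}
# Hypothetical unique-token counts per source (illustrative, not measured)
available = {"web_filtered": 30e12, "books": 1.0e12, "code": 1.5e12, "wikipedia": 0.25e12, "academic": 0.5e12}

for source, epochs in effective_epochs(weights, available, total_training_tokens=15e12).items():
    print(f"{source}: {epochs:.2f} epochs")
```

Under these assumed sizes, web text is seen about a third of a time while Wikipedia is repeated about three times.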


Distributed Training at Scale

Training a 7B-parameter model on 15T tokens requires distributed training. Key techniques:

Data parallelism: Each GPU processes a different batch; gradients are averaged across GPUs with all-reduce.

Tensor parallelism: Split individual weight matrices across GPUs (Megatron-LM style). Required for very large models.

Pipeline parallelism: Split model layers across GPUs; micro-batches flow through the GPUs stage by stage, in assembly-line fashion.
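The key fact behind data parallelism is that, with equal-sized shards and a mean loss, averaging each worker's gradient gives exactly the full-batch gradient. The sketch below checks this in a single process by simulating two "GPUs" with `Tensor.chunk`; in a real run the averaging is done by an all-reduce (sum, then divide by world size) inside DDP or FSDP. The tiny linear model is only an assumption for illustration.

```python
import torch

torch.manual_seed(0)
w = torch.randn(4, requires_grad=True)    # shared (replicated) parameters
x, y = torch.randn(8, 4), torch.randn(8)  # one global batch of 8 examples

def grad_of_mean_loss(xb: torch.Tensor, yb: torch.Tensor) -> torch.Tensor:
    """Gradient of the mean squared error on a (sub-)batch with respect to w."""
    w.grad = None
    loss = ((xb @ w - yb) ** 2).mean()
    loss.backward()
    return w.grad.clone()

full_batch_grad = grad_of_mean_loss(x, y)

# Simulate 2 workers, each holding half of the batch.
worker_grads = [grad_of_mean_loss(xs, ys) for xs, ys in zip(x.chunk(2), y.chunk(2))]
# All-reduce (sum) followed by division by world size = mean of worker gradients.
averaged_grad = torch.stack(worker_grads).mean(dim=0)

print(torch.allclose(full_batch_grad, averaged_grad, atol=1e-6))
```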

Python
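# PyTorch FSDP (Fully Sharded Data Parallel): a common choice for 7B-70B training
import functools
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Assumes torch.distributed is initialized (e.g. via torchrun) and that `model` and its
# decoder-layer class `TransformerBlock` (placeholder name) are already defined.
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy, transformer_layer_cls={TransformerBlock}
)
mixed_precision_policy = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16)
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # Shard params, grads, optimizer state
    auto_wrap_policy=auto_wrap_policy,              # One FSDP unit per transformer block
    mixed_precision=mixed_precision_policy,
)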
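To see why sharding parameters, gradients, and optimizer state matters at this scale, a back-of-the-envelope estimate helps. The sketch assumes the common mixed-precision Adam accounting popularized by the ZeRO paper (2 bytes of half-precision weights, 2 bytes of half-precision gradients, and 12 bytes of FP32 master weights plus Adam moments, i.e. 16 bytes per parameter) and ignores activations, buffers, and fragmentation, so real usage is higher.

```python
def model_state_gb_per_gpu(num_params: int, num_gpus: int, bytes_per_param: int = 16) -> float:
    """Model-state memory per GPU when FULL_SHARD splits it evenly across GPUs.
    Excludes activations and any temporary buffers."""
    return num_params * bytes_per_param / num_gpus / 1e9

params_7b = 7_000_000_000
print(f"unsharded, 1 GPU: {model_state_gb_per_gpu(params_7b, 1):.1f} GB")
for gpus in (8, 64):
    print(f"FULL_SHARD over {gpus} GPUs: {model_state_gb_per_gpu(params_7b, gpus):.2f} GB per GPU")
```

At 112 GB, the model state alone exceeds a single 80 GB GPU before any activations are stored; sharding over 8 GPUs brings it to 14 GB each.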

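Mixed precision: Train in BF16 rather than FP16 for stability. BF16 keeps FP32's 8-bit exponent, and therefore its dynamic range, at the cost of mantissa precision, so large activations and gradients don't overflow the way they can in FP16 (which usually needs loss scaling).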
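The range difference is easy to check with `torch.finfo`: FP16 tops out at 65504, while BF16's maximum is close to FP32's. The sketch below also casts a value just above FP16's limit to both formats; the value 70,000 is arbitrary.

```python
import torch

for dtype in (torch.float16, torch.bfloat16, torch.float32):
    info = torch.finfo(dtype)
    # max: largest finite value; eps: gap between 1.0 and the next representable number
    print(f"{str(dtype):15} max={info.max:.4g}  eps={info.eps:.3g}")

x = torch.tensor(70_000.0)
as_fp16 = x.to(torch.float16)   # overflows FP16's range and becomes inf
as_bf16 = x.to(torch.bfloat16)  # stays finite, but is rounded to BF16's coarse precision
print(as_fp16.item(), as_bf16.item())
```

The trade-off shows up in `eps`: BF16's larger gap between representable numbers is why optimizer state and master weights are often kept in FP32.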


Learning Rate Schedule

Standard pretraining uses a cosine decay schedule with linear warmup:

Python
import math

def cosine_lr_schedule(
    step: int,
    warmup_steps: int,
    total_steps: int,
    max_lr: float,
    min_lr: float,
) -> float:
    """Cosine decay with linear warmup."""
    if step < warmup_steps:
        return max_lr * step / warmup_steps

    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
    return min_lr + (max_lr - min_lr) * cosine_decay

# Typical hyperparameters for 7B model pretraining:
max_lr = 3e-4
min_lr = 3e-5  # 10% of max
warmup_steps = 2000
total_steps = 3_750_000  # 15T tokens / ~4M tokens per batch
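The step count follows directly from the token budget and the batch size measured in tokens. A minimal sketch, assuming a batch of 1,024 sequences of 4,096 tokens each (about 4.2M tokens per step; these values are assumptions, not taken from a specific model card):

```python
def steps_for_token_budget(total_tokens: int, sequences_per_batch: int, seq_len: int) -> int:
    """Optimizer steps needed to consume `total_tokens`, rounded up."""
    tokens_per_step = sequences_per_batch * seq_len
    return -(-total_tokens // tokens_per_step)  # integer ceiling division

tokens_per_step = 1024 * 4096
steps = steps_for_token_budget(15_000_000_000_000, sequences_per_batch=1024, seq_len=4096)
warmup_steps = 2000
print(f"{tokens_per_step:,} tokens/step -> {steps:,} steps; warmup is {warmup_steps / steps:.3%} of training")
```

Note how small the warmup is relative to the whole run: the cosine decay covers almost all of training.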

What Pretraining Does NOT Teach

Pretraining on raw text teaches the model to predict tokens, not to be helpful. A pretrained base model:

  • Will continue any text in whatever style the prompt looks like
  • Will not reliably follow instructions (it predicts the next token rather than executing commands)
  • Will not reliably refuse harmful requests
  • May generate the most statistically likely completion, which isn't always the correct answer

This is why pretraining is followed by supervised fine-tuning (SFT) and alignment (RLHF/DPO) to produce a usable assistant.
