Learnixo

Transformer Architecture Q&A · Lesson 16 of 23

Training Objective: Next Token Prediction

The Three Pretraining Objectives

Transformers are pretrained with self-supervised objectives — no human labels required. Three objectives dominate:

1. Causal Language Modelling (CLM)
   Used by: GPT series, LLaMA, Mistral, Falcon
   Architecture: Decoder-only

2. Masked Language Modelling (MLM)
   Used by: BERT, RoBERTa, ClinicalBERT, BioBERT
   Architecture: Encoder-only

3. Sequence-to-Sequence (Seq2Seq)
   Used by: T5, BART, mT5
   Architecture: Encoder-decoder

Causal Language Modelling (CLM)

Predict the next token given all previous tokens:

Corpus: "The patient takes Warfarin 5mg daily"

Training examples (one sequence, all positions scored in parallel; words stand in for subword tokens):
  Input: "The"                        → Target: "patient"
  Input: "The patient"                → Target: "takes"
  Input: "The patient takes"          → Target: "Warfarin"
  Input: "The patient takes Warfarin" → Target: "5mg"
  ...

Loss: cross-entropy at every position, averaged over the T predictions
  L = -1/T Σ_{i=1}^{T} log P(token_i | token_0, ..., token_{i-1})

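Every position contributes to the loss, so a single forward pass scores a prediction at every position of the sequence. This is what makes CLM data-efficient. Nothing is hidden in the input (unlike MLM): the causal attention mask alone stops each position from seeing the token it must predict, and only padding is excluded from the loss.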
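A minimal sketch of how one sequence yields every (prefix, next-token) pair, and how the per-position probabilities combine into the averaged loss above. The helper names are hypothetical, and the probabilities are made-up values standing in for what a real model's softmax would output.

```python
import math

def clm_pairs(tokens):
    """Return the (prefix, target) pairs a causal LM is trained on.

    A sequence of n tokens gives n - 1 predictions: token i is predicted
    from tokens 0..i-1, for every i from 1 to n - 1.
    """
    return [(tokens[:i], tokens[i]) for i in range(1, len(tokens))]

def clm_loss(target_probs):
    """Average negative log-likelihood over all predicted positions.

    target_probs[i] is the probability the model assigned to the correct
    next token at prediction step i.
    """
    return -sum(math.log(p) for p in target_probs) / len(target_probs)

tokens = ["The", "patient", "takes", "Warfarin", "5mg", "daily"]
for prefix, target in clm_pairs(tokens):
    print(" ".join(prefix), "->", target)

# Hypothetical probabilities for the 5 correct next tokens
print(round(clm_loss([0.5, 0.5, 0.5, 0.5, 0.5]), 4))  # 0.6931 (= ln 2)
```

In a real Transformer these pairs are never built separately: the causal attention mask lets one forward pass over the full sequence score all of them at once.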


Masked Language Modelling (MLM)

Randomly select 15% of token positions, corrupt them, and predict the original tokens from the surrounding context:

Original: "The patient takes Warfarin 5mg daily"
Masked:   "The patient [MASK] Warfarin 5mg [MASK]"

  80% of selected tokens → replaced with [MASK]
  10% of selected tokens → replaced with a random token
  10% of selected tokens → kept unchanged (but still predicted)
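The selection and 80/10/10 corruption steps above can be sketched in plain Python. `bert_mask` and its `select_prob` parameter are hypothetical names, and rounding 15% of the length to an integer count is an assumption; a particular implementation may count or cap selections differently.

```python
import random

MASK = "[MASK]"

def bert_mask(tokens, vocab, rng, select_prob=0.15):
    """Select ~15% of positions, then apply the 80/10/10 corruption rule.

    Returns (corrupted, labels). labels[i] holds the original token at a
    selected position and None elsewhere (None = no loss at that position).
    """
    corrupted = list(tokens)
    labels = [None] * len(tokens)
    n_select = max(1, round(select_prob * len(tokens)))  # rounding is an assumption
    for i in rng.sample(range(len(tokens)), n_select):
        labels[i] = tokens[i]                 # every selected position is predicted
        r = rng.random()
        if r < 0.8:
            corrupted[i] = MASK               # 80%: replace with [MASK]
        elif r < 0.9:
            corrupted[i] = rng.choice(vocab)  # 10%: replace with a random token
        # remaining 10%: keep the original token, but it is still predicted
    return corrupted, labels

rng = random.Random(0)
tokens = "the patient takes warfarin 5mg daily and reports no bleeding".split()
corrupted, labels = bert_mask(tokens, vocab=sorted(set(tokens)), rng=rng)
print(corrupted)
print(labels)
```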

Loss: cross-entropy on the selected positions only (including those left unchanged or randomly replaced)
  L = -1/|S| Σ_{i∈S} log P(token_i | corrupted sequence)
  where S is the set of selected positions

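The 80/10/10 split addresses a mismatch: [MASK] never appears in fine-tuning inputs. Because a selected token may also be kept unchanged or swapped for a random one, the model cannot tell which positions will be scored or whether the token it sees is genuine, so it must build a useful contextual representation for every token, not just for [MASK] positions.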


Why MLM Is Bidirectional

CLM: position i sees tokens 0..i-1
     → predicts one direction only
     → useful for generation

MLM: [MASK] at position i sees tokens 0..i-1 AND i+1..T
     → uses BOTH directions
     → useful for understanding

This bidirectional context is largely why BERT-style encoders outperform comparably sized
GPT-style decoders on classification, NER, and extractive QA.
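Which tokens a position can see is set by the attention mask. A minimal plain-Python sketch, assuming the convention True = "may attend"; real libraries differ on whether True means allow or block, so check the one you use. The function names are hypothetical.

```python
def causal_mask(n):
    """mask[i][j] is True if position i may attend to position j (j <= i).

    The output at position i sees tokens 0..i and is trained to predict
    token i + 1, so each token is predicted only from the tokens before it.
    """
    return [[j <= i for j in range(n)] for i in range(n)]

def bidirectional_mask(n):
    """Every position attends to every position (encoder / MLM)."""
    return [[True] * n for _ in range(n)]

for row in causal_mask(4):
    print("".join("1" if allowed else "." for allowed in row))
# 1...
# 11..
# 111.
# 1111
```

MLM encoders use the all-True pattern (apart from padding), which is exactly what lets a [MASK] at position i draw on tokens on both sides.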

Seq2Seq Pretraining (T5)

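T5 uses "span corruption": random contiguous spans of tokens are each replaced by a single unique sentinel token (the T5 paper corrupts 15% of tokens with a mean span length of 3):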

Original: "Thank you for inviting me to your party last week"
Corrupted: "Thank you [X] me to your party [Y] week"

Target: "[X] for inviting [Y] last [Z]"
  (only the dropped spans are generated, each preceded by its sentinel; the final sentinel [Z] marks the end)

Loss: standard seq2seq cross-entropy on the target sequence

The encoder sees the corrupted input; the decoder generates the missing spans. This trains both bidirectional understanding and generation simultaneously.
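The span-corruption example can be reproduced with a short sketch. `span_corrupt` is a hypothetical helper that takes the spans as given instead of sampling them; [X], [Y], [Z] match the text, whereas T5's own tokenizer names its sentinels `<extra_id_0>`, `<extra_id_1>`, and so on.

```python
def span_corrupt(tokens, spans, sentinels):
    """T5-style span corruption for given spans.

    spans: sorted, non-overlapping (start, end) index pairs, end exclusive.
    sentinels: unique placeholder tokens; needs len(spans) + 1 of them,
    because one extra sentinel marks the end of the target.
    Returns (encoder_input, decoder_target).
    """
    assert len(sentinels) > len(spans), "need one more sentinel than spans"
    encoder_input, target = [], []
    pos = 0
    for k, (start, end) in enumerate(spans):
        encoder_input += tokens[pos:start] + [sentinels[k]]  # span -> one sentinel
        target += [sentinels[k]] + tokens[start:end]          # sentinel, then dropped tokens
        pos = end
    encoder_input += tokens[pos:]
    target.append(sentinels[len(spans)])                      # closing sentinel
    return encoder_input, target

tokens = "Thank you for inviting me to your party last week".split()
inp, tgt = span_corrupt(tokens, [(2, 4), (8, 9)], ["[X]", "[Y]", "[Z]"])
print(" ".join(inp))  # Thank you [X] me to your party [Y] week
print(" ".join(tgt))  # [X] for inviting [Y] last [Z]
```

Because each span collapses to one sentinel and only the dropped tokens appear in the target, the decoder's target is much shorter than the original text.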


BART Pretraining

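BART is also an encoder-decoder, pretrained with a wider range of corruption (noising) strategies; the released model combines text infilling and sentence permutation: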

Token masking:        individual token → [MASK]
Token deletion:       token removed entirely (length changes)
Text infilling:       whole span → single [MASK] (span lengths ~ Poisson(λ=3); a 0-length span just inserts a [MASK])
Sentence permutation: shuffle sentence order
Document rotation:    rotate the document to start at a randomly chosen token

BART's decoder always reconstructs the full original text, so pretraining is already a "read a whole document bidirectionally, then generate a whole text" task. Abstractive summarisation has the same shape, which is why BART transfers particularly well to it.
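A minimal plain-Python sketch of the noising functions listed above. The function names and the fixed `start`/`length`/`k` arguments are illustrative assumptions: BART samples infilling span lengths from Poisson(λ=3) and picks the rotation point at random, and "[MASK]" stands in for the tokenizer's mask token.

```python
import random

MASK = "[MASK]"

def token_deletion(tokens, rng, p=0.15):
    """Drop each token with probability p; the model must infer where."""
    return [t for t in tokens if rng.random() >= p]

def text_infilling(tokens, start, length):
    """Replace tokens[start:start+length] with ONE [MASK].

    length may be 0, in which case a [MASK] is inserted and nothing removed,
    so the model cannot read the span length off the mask.
    """
    return tokens[:start] + [MASK] + tokens[start + length:]

def sentence_permutation(sentences, rng):
    """Shuffle the order of sentences (each sentence is a list of tokens)."""
    shuffled = list(sentences)
    rng.shuffle(shuffled)
    return shuffled

def document_rotation(tokens, k):
    """Rotate so the document starts at token k; the model must find the true start."""
    return tokens[k:] + tokens[:k]

tokens = "the patient takes warfarin 5mg daily".split()
print(text_infilling(tokens, 2, 2))  # ['the', 'patient', '[MASK]', '5mg', 'daily']
print(document_rotation(tokens, 3))  # ['warfarin', '5mg', 'daily', 'the', 'patient', 'takes']
```

Whatever the noise, the decoder's target is the original `tokens`, so a single reconstruction loss covers every noise type.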


Loss Functions

All three objectives use cross-entropy loss:

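```python
import torch
import torch.nn.functional as F

def compute_clm_loss(logits, target_ids):
    # logits:     (batch, seq_len, vocab_size), model outputs for the input ids
    # target_ids: (batch, seq_len), the input ids shifted LEFT by one
    #             (target_ids[:, t] == input_ids[:, t + 1]); padding and the
    #             final position (no next token) are set to -100
    return F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),  # reshape also handles non-contiguous tensors
        target_ids.reshape(-1),
        ignore_index=-100,  # positions labelled -100 contribute no loss
    )

def compute_mlm_loss(logits, target_ids, mask):
    # logits:     (batch, seq_len, vocab_size)
    # target_ids: (batch, seq_len), the ORIGINAL (uncorrupted) token ids
    # mask:       (batch, seq_len) bool, True where a token was selected for prediction
    labels = target_ids.masked_fill(~mask, -100)  # unselected positions are ignored
    return F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        labels.reshape(-1),
        ignore_index=-100,
    )
```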

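The -100 value is PyTorch's default ignore_index for F.cross_entropy. Labelling positions with it lets each example in a batch skip a different set of positions (padding in CLM, unselected tokens in MLM) without changing tensor shapes; with the default reduction='mean', the loss is averaged over the non-ignored positions only.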
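To make the -100 behaviour concrete, here is a plain-Python reference for what mean-reduced cross-entropy with an ignore index computes, plus one way to build left-shifted CLM labels. `make_clm_labels`, `cross_entropy_ignore`, and the pad id 0 are hypothetical names and values for illustration; this is not PyTorch's implementation.

```python
import math

IGNORE = -100  # the value PyTorch's cross_entropy ignores by default

def make_clm_labels(input_ids, pad_id):
    """CLM targets: input ids shifted LEFT by one (label[t] = input_ids[t + 1]).

    The last position has no next token, and targets that are padding
    are ignored, so both become IGNORE.
    """
    labels = input_ids[1:] + [IGNORE]
    return [IGNORE if t == pad_id else t for t in labels]

def cross_entropy_ignore(logits, labels):
    """Mean cross-entropy over positions whose label is not IGNORE.

    logits: one list of vocabulary scores per position.
    Ignored positions are left out of both the sum and the count.
    """
    total, count = 0.0, 0
    for row, label in zip(logits, labels):
        if label == IGNORE:
            continue
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))  # stable log-sum-exp
        total += log_z - row[label]  # = -log softmax(row)[label]
        count += 1
    return total / count if count else float("nan")  # nothing to average over

input_ids = [5, 7, 2, 0, 0]  # 0 is a hypothetical pad id
labels = make_clm_labels(input_ids, pad_id=0)
print(labels)  # [7, 2, -100, -100, -100]

uniform = [[0.0] * 10 for _ in input_ids]  # equal scores over a 10-token vocab
print(round(cross_entropy_ignore(uniform, labels), 4))  # 2.3026 (= ln 10)
```

Only the two real targets count here, so the three ignored positions neither add loss nor dilute the average.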


Comparing Objectives

| Property | CLM | MLM | Seq2Seq |
|----------|-----|-----|---------|
| Direction | Left-to-right | Bidirectional | Bidirectional (enc) + Left-to-right (dec) |
| % tokens with loss | 100% | ~15% | 100% (of target) |
| Architecture | Decoder-only | Encoder-only | Encoder-decoder |
| Good for | Generation | Understanding | Input-to-output tasks (translation, summarisation) |
| Examples | GPT, LLaMA | BERT, RoBERTa | T5, BART |
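The "% tokens with loss" row translates into a concrete per-sequence gap. A small sketch, assuming a 512-token sequence and rounding 15% to the nearest integer (real pipelines may round or cap differently); `supervised_positions` is a hypothetical helper.

```python
def supervised_positions(seq_len, mlm_rate=0.15):
    """Positions that receive a loss in one training sequence.

    CLM: every position except the last has a next token to predict.
    MLM: only the selected ~15% of positions (rounded here, an assumption).
    """
    return {"CLM": seq_len - 1, "MLM": round(mlm_rate * seq_len)}

print(supervised_positions(512))  # {'CLM': 511, 'MLM': 77}
```

For the same sequence, CLM scores roughly 6.6 times as many positions as MLM, which is the sense in which CLM extracts more training signal per token processed.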


Interview Answer

"The three main pretraining objectives are: CLM (causal language modelling), which predicts the next token from all previous tokens at every position and is used by GPT/LLaMA; MLM (masked language modelling), which selects 15% of tokens, corrupts them (mostly to [MASK]) and predicts them from bidirectional context, as in BERT; and seq2seq denoising, where the encoder sees corrupted input and the decoder generates the missing spans (T5) or the full original text (BART). All use cross-entropy loss. CLM gets a training signal from every position but is unidirectional; MLM gives bidirectional context but supervises only about 15% of positions; seq2seq trains bidirectional understanding and generation jointly."