LLMs Deep Dive · Lesson 11 of 24
DPO: Direct Preference Optimization
The Alignment Problem
A pretrained LLM is trained only to predict likely next tokens. Nothing in that objective makes it follow instructions, prefer truthful answers, or refuse harmful requests. Alignment fine-tuning changes this.
RLHF (Reinforcement Learning from Human Feedback) was the dominant approach:
RLHF pipeline:
1. Supervised fine-tuning (SFT): teach the model to follow instructions
2. Reward model: train a model to score completions by human preference
3. RL optimisation: use PPO to maximise reward while staying close to the SFT model
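Stage 2 is commonly trained with a pairwise (Bradley-Terry) loss that pushes the score of the preferred completion above the score of the dispreferred one. The following is a minimal sketch of that loss on plain floats, not a full reward-model trainer; the function names and the example scores are hypothetical.

```python
import math


def log_sigmoid(z: float) -> float:
    """Numerically stable log(sigmoid(z))."""
    # For z >= 0: -log(1 + exp(-z)); for z < 0: z - log(1 + exp(z)).
    return -math.log1p(math.exp(-z)) if z >= 0 else z - math.log1p(math.exp(z))


def reward_model_pair_loss(score_preferred: float, score_dispreferred: float) -> float:
    """Pairwise loss: -log sigmoid(r_preferred - r_dispreferred)."""
    return -log_sigmoid(score_preferred - score_dispreferred)


# Hypothetical scalar scores from a reward model for one labelled pair.
well_ranked = reward_model_pair_loss(2.0, -1.0)  # preferred scored higher
mis_ranked = reward_model_pair_loss(-1.0, 2.0)   # preferred scored lower
print(f"well ranked: {well_ranked:.4f}, mis-ranked: {mis_ranked:.4f}")
```

The loss is small when the pair is ranked correctly and grows roughly linearly with the score gap when it is ranked the wrong way round.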
Problems with RLHF:
Requires a separate reward model
PPO is unstable — sensitive to hyperparameters
Three separate training stages — expensive
Reward hacking: model finds ways to maximise reward without being helpful
DPO: The Insight
DPO (Rafailov et al., 2023) shows that the RLHF objective can be optimised directly without a reward model or RL:
RLHF maximises: E[r(x, y)] - β · KL[π_θ || π_ref]
r = reward, π_θ = current policy, π_ref = reference (SFT) policy
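This objective has a closed-form optimal policy:
π*(y|x) = π_ref(y|x) · exp(r(x, y) / β) / Z(x), where Z(x) normalises over all responses y
Rearranging for the reward: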
r*(x, y) = β · log(π*(y|x) / π_ref(y|x)) + β · log Z(x)
This means: the reward can be expressed in terms of the policy itself.
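The identity can be checked numerically on a toy problem. The sketch below assumes a single prompt with three candidate responses; the probabilities, rewards, and β value are made up for illustration. It builds the optimal policy as π_ref(y) · exp(r(y)/β) / Z, recovers the reward from it, and compares the objective value at π* with the value at π_ref.

```python
import math

beta = 0.5
# Hypothetical toy setup: three candidate responses to one prompt.
pi_ref = [0.5, 0.3, 0.2]   # reference policy probabilities
reward = [1.0, 0.0, -1.0]  # reward of each response

# Closed-form optimum: pi*(y) = pi_ref(y) * exp(r(y) / beta) / Z
unnormalised = [p * math.exp(r / beta) for p, r in zip(pi_ref, reward)]
Z = sum(unnormalised)
pi_star = [u / Z for u in unnormalised]


def objective(pi):
    """E_pi[r] - beta * KL(pi || pi_ref) for a discrete policy."""
    expected_reward = sum(p * r for p, r in zip(pi, reward))
    kl = sum(p * math.log(p / q) for p, q in zip(pi, pi_ref) if p > 0)
    return expected_reward - beta * kl


# Recover the reward from the optimal policy: beta*log(pi*/pi_ref) + beta*log Z.
recovered = [
    beta * math.log(ps / pr) + beta * math.log(Z)
    for ps, pr in zip(pi_star, pi_ref)
]

print("pi*:", [round(p, 4) for p in pi_star])
print("recovered reward:", [round(r, 6) for r in recovered])
print("objective at pi_ref:", objective(pi_ref))
print("objective at pi*:", objective(pi_star), "beta*log Z:", beta * math.log(Z))
```

In this discrete setting the objective at π* equals β · log Z, and the recovered rewards match the original ones up to floating-point error.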
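Substituting this reward into the Bradley-Terry preference model, P(y_w ≻ y_l | x) = σ(r(x, y_w) - r(x, y_l)), cancels the β · log Z(x) terms because they depend only on x. Minimising the negative log-likelihood of the preference pairs then gives: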
DPO loss = -E_{(x, y_w, y_l)} log σ(β · log(π_θ(y_w|x)/π_ref(y_w|x))
- β · log(π_θ(y_l|x)/π_ref(y_l|x)))
where y_w = preferred response, y_l = dispreferred response
DPO Training Data Format
Preference pairs:
Prompt x: "Explain what Warfarin does."
Preferred (y_w): "Warfarin is an anticoagulant that prevents blood clots
by inhibiting Vitamin K epoxide reductase..."
Dispreferred (y_l): "Warfarin is a rat poison used as a blood thinner."
(technically true but inappropriate framing)
Human annotators (or GPT-4) compare pairs and label which response is preferred. These pairs are the entire training signal for DPO; no reward model is needed.
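One way to store such pairs is one JSON record per line. The sketch below is an assumed layout, not a required format: the `PreferencePair` class and the field names `prompt`, `chosen`, and `rejected` are hypothetical, picked to mirror the chosen/rejected naming used in this lesson's loss function.

```python
import json
from dataclasses import dataclass, asdict


@dataclass(frozen=True)
class PreferencePair:
    prompt: str    # x
    chosen: str    # y_w, the preferred response
    rejected: str  # y_l, the dispreferred response


def to_jsonl(pairs):
    """Serialise pairs as one JSON object per line."""
    return "\n".join(json.dumps(asdict(p)) for p in pairs)


def from_jsonl(text):
    """Parse JSONL text back into PreferencePair records."""
    return [PreferencePair(**json.loads(line)) for line in text.splitlines() if line.strip()]


pairs = [
    PreferencePair(
        prompt="Explain what Warfarin does.",
        chosen="Warfarin is an anticoagulant that prevents blood clots "
               "by inhibiting Vitamin K epoxide reductase...",
        rejected="Warfarin is a rat poison used as a blood thinner.",
    )
]
encoded = to_jsonl(pairs)
print(encoded)
```

Keeping the prompt separate from both responses makes it straightforward to score only the response tokens later.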
DPO Implementation
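As a warm-up before the batched PyTorch version, here is a scalar sketch of the loss for a single preference pair using only the standard library. The function name `dpo_pair_loss` and the log-probability values are hypothetical; the inputs stand for summed sequence log probs.

```python
import math


def dpo_pair_loss(policy_chosen_logp, policy_rejected_logp,
                  ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """DPO loss for a single preference pair, using plain floats."""
    chosen_reward = beta * (policy_chosen_logp - ref_chosen_logp)
    rejected_reward = beta * (policy_rejected_logp - ref_rejected_logp)
    margin = chosen_reward - rejected_reward
    # -log(sigmoid(margin)) == log(1 + exp(-margin)), computed stably for either sign.
    return max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))


# Hypothetical reference log probs: the reference prefers the rejected response.
ref_w, ref_l = -12.0, -10.0

at_init = dpo_pair_loss(ref_w, ref_l, ref_w, ref_l)  # policy == reference
improved = dpo_pair_loss(-9.0, -14.0, ref_w, ref_l)  # y_w up, y_l down
worse = dpo_pair_loss(-14.0, -9.0, ref_w, ref_l)     # y_w down, y_l up
print(at_init, improved, worse)
```

When the policy equals the reference, the margin is zero and the loss is log 2 regardless of the absolute log probs. Only movement relative to the reference changes the loss.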
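import torch
import torch.nn.functional as F


def dpo_loss(
    policy_chosen_logps: torch.Tensor,    # log π_θ(y_w | x), shape (batch,)
    policy_rejected_logps: torch.Tensor,  # log π_θ(y_l | x)
    ref_chosen_logps: torch.Tensor,       # log π_ref(y_w | x)
    ref_rejected_logps: torch.Tensor,     # log π_ref(y_l | x)
    beta: float = 0.1,
) -> torch.Tensor:
    # Implicit rewards: β-scaled log-ratio of policy to reference.
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)
    # Negative log-sigmoid of the reward margin, averaged over the batch.
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()


# Computing sequence log probs from logits.
# Assumes logits and labels are already aligned (logits[:, t] scores labels[:, t])
# and that every position is a response token; real training code should mask
# prompt and padding positions before summing.
def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    log_probs = F.log_softmax(logits, dim=-1)
    # gather at each position the log prob of the actual token, then sum over the sequence
    return log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1).sum(-1)
RLHF vs DPO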
| Property | RLHF (PPO) | DPO |
|----------|-----------|-----|
| Reward model | Required (separate model) | Not needed |
| Training stages | 3 (SFT → RM → PPO) | 2 (SFT → DPO) |
| Stability | Sensitive to hyperparameters | More stable |
| Compute | Higher (PPO sampling overhead) | Lower |
| Quality | Often better (more flexible) | Competitive, simpler |
| Used by | InstructGPT, Llama 2 Chat | Zephyr, Llama 3 Instruct |
Beta Parameter
β controls the KL divergence constraint — how far the model can move from π_ref:
β → 0: almost no constraint — model ignores the reference policy
Risks: overfitting to the preference pairs (the DPO analogue of reward hacking), forgetting instruction-following from SFT
β → ∞: very tight constraint — model barely moves from π_ref
Risks: barely learns preference, ignores training signal
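The two limits can be made concrete with a small calculation. In the sketch below, "margin" means the difference of log-ratios, [log π_θ(y_w|x) - log π_ref(y_w|x)] - [log π_θ(y_l|x) - log π_ref(y_l|x)], measured in nats. The helper names and the 90% target are illustrative assumptions.

```python
import math


def required_margin(target_pref_prob: float, beta: float) -> float:
    """Log-ratio margin needed so that sigmoid(beta * margin) == target_pref_prob."""
    return math.log(target_pref_prob / (1.0 - target_pref_prob)) / beta


def loss_at_margin(margin: float, beta: float) -> float:
    """DPO loss -log(sigmoid(beta * margin)) for one pair."""
    z = beta * margin
    return max(-z, 0.0) + math.log1p(math.exp(-abs(z)))


for beta in (0.01, 0.1, 0.5, 5.0):
    m = required_margin(0.9, beta)
    print(f"beta={beta:<5} margin for 90% implied preference: {m:8.2f} nats, "
          f"loss at a fixed 5-nat margin: {loss_at_margin(5.0, beta):.4f}")
```

The required margin scales as 1/β: with a small β the policy must move far from π_ref before the loss drops, while with a large β a tiny shift already satisfies the loss, so the policy stays close to the reference.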
Typical β: 0.1–0.5 (Llama 3 uses β=0.1)
Interview Answer
"DPO (Direct Preference Optimization) aligns LLMs from preference pairs without a separate reward model or reinforcement learning. It rests on the observation that the KL-regularised RLHF objective has a closed-form optimal policy, so the reward can be written purely in terms of log probabilities under the policy and a frozen reference model. The DPO loss rewards the policy for raising the probability of preferred responses and lowering the probability of dispreferred ones, relative to that reference. This replaces RLHF's three-stage pipeline (SFT → reward model → PPO) with a more stable two-stage process (SFT → DPO), and is now widely used in open-source alignment (Zephyr, Llama 3 Instruct)."