DPO: Direct Preference Optimization
How DPO achieves alignment without reinforcement learning. Covers the mathematical derivation from RLHF, the DPO loss, dataset construction, and when DPO outperforms PPO.
The Core Insight
RLHF with PPO is complex: it requires training a separate reward model, keeping four models in memory during training (policy, reference, reward, and value), and tuning sensitive RL hyperparameters. DPO (Rafailov et al., 2023) showed that the same optimization problem can be solved directly from preference pairs, with no explicit reward model and no PPO loop.
The key insight: the optimal policy under the RLHF objective has a closed-form relationship to the reference policy. This means we can parameterize the reward function implicitly through the policy itself and optimize directly on preference data.
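A small numeric check can make this concrete. The sketch below assumes a toy setting with three candidate responses to a single prompt; the reference probabilities, rewards, and beta are made-up values chosen for illustration. It builds the closed-form optimum π*(y|x) ∝ π_ref(y|x) · exp(r(x, y)/β) and confirms that β · log(π*/π_ref) recovers reward differences without the normalizer ever being needed.

```python
import math


def optimal_policy(ref_probs, rewards, beta):
    """Closed-form optimum of the KL-regularized objective on a finite response set."""
    weights = [p * math.exp(r / beta) for p, r in zip(ref_probs, rewards)]
    z = sum(weights)  # normalizing constant Z(x)
    return [w / z for w in weights]


def implicit_reward(policy_probs, ref_probs, beta):
    """beta * log(pi / pi_ref): equals the true reward minus the constant beta * log Z(x)."""
    return [beta * math.log(p / q) for p, q in zip(policy_probs, ref_probs)]


# Illustrative values only (hypothetical toy problem).
ref_probs = [0.5, 0.3, 0.2]
rewards = [1.0, 0.0, -1.0]
beta = 0.5

pi_star = optimal_policy(ref_probs, rewards, beta)
r_hat = implicit_reward(pi_star, ref_probs, beta)

# Differences of implicit rewards match differences of true rewards,
# because the shared beta * log Z(x) offset cancels in the subtraction.
true_gap = rewards[0] - rewards[2]
implicit_gap = r_hat[0] - r_hat[2]
print(f"true gap = {true_gap:.4f}, implicit gap = {implicit_gap:.4f}")
```

Only reward differences enter the Bradley-Terry model, which is why the unknown offset does no harm.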
Mathematical Derivation
The RLHF objective (maximize reward, penalize KL from reference):

    max_π  E_{x~D, y~π(·|x)}[ r(x, y) ] - β · KL( π(y|x) || π_ref(y|x) )

The optimal solution to this KL-constrained optimization is:

    π*(y|x) = π_ref(y|x) · exp(r(x, y)/β) / Z(x)

where Z(x) = Σ_y π_ref(y|x) · exp(r(x, y)/β) is a normalizing constant.
Rearranging to express the reward in terms of the policy:

    r(x, y) = β · log(π*(y|x) / π_ref(y|x)) + β · log Z(x)

Substituting into the Bradley-Terry preference model (which the reward model uses):

    P(y_w ≻ y_l | x) = σ(r(x, y_w) - r(x, y_l))
                     = σ(β · log(π*(y_w|x)/π_ref(y_w|x)) - β · log(π*(y_l|x)/π_ref(y_l|x)))

Note that Z(x) cancels out, so we never need to compute it. Replacing π* with the trainable policy π_θ and taking the negative log-likelihood of the observed preferences gives the DPO loss:

    L_DPO(π_θ) = -E_{(x, y_w, y_l)~D}[ log σ(β · log(π_θ(y_w|x)/π_ref(y_w|x)) - β · log(π_θ(y_l|x)/π_ref(y_l|x))) ]

DPO Implementation
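Before the model-based implementation, it can help to see the loss on plain numbers. The following is a minimal sketch that assumes the four summed sequence log-probabilities are already available; the function name `dpo_loss_from_logps` and the sample values are hypothetical.

```python
import math


def log_sigmoid(z: float) -> float:
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def dpo_loss_from_logps(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    """DPO loss for one preference pair, given summed sequence log-probs."""
    chosen_logratio = policy_chosen - ref_chosen
    rejected_logratio = policy_rejected - ref_rejected
    return -log_sigmoid(beta * (chosen_logratio - rejected_logratio))


# Policy identical to the reference: the margin is 0, so the loss is log(2).
loss_at_init = dpo_loss_from_logps(-12.0, -15.0, -12.0, -15.0)

# Policy has raised the chosen response and lowered the rejected one
# relative to the reference, so the loss drops below log(2).
loss_improved = dpo_loss_from_logps(-10.0, -18.0, -12.0, -15.0)

print(f"at init: {loss_at_init:.4f}, after improvement: {loss_improved:.4f}")
```

A loss of log(2) at the first step is a useful sanity check: it means the policy and reference start out identical, as they should when both are loaded from the same SFT checkpoint.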
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer


def compute_log_probs(
    model,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    labels: torch.Tensor,
) -> torch.Tensor:
    """
    Compute the sum of log probabilities of the response tokens.
    Only the response tokens (not prompt) are summed: prompt and padding
    positions are expected to carry the label -100.
    Gradient tracking is left to the caller, so wrap the call in
    torch.no_grad() for the reference model or during evaluation.
    """
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    logits = outputs.logits  # (B, T, V)
    # Shift for next-token prediction
    shift_logits = logits[:, :-1, :].contiguous()  # (B, T-1, V)
    shift_labels = labels[:, 1:].contiguous()  # (B, T-1)
    shift_mask = (shift_labels != -100)  # True where we compute loss
    # Per-token log probabilities
    log_probs = F.log_softmax(shift_logits, dim=-1)  # (B, T-1, V)
    token_log_probs = log_probs.gather(
        -1, shift_labels.clamp(min=0).unsqueeze(-1)
    ).squeeze(-1)  # (B, T-1)
    # Sum over response tokens only (mask ignores prompt)
    sequence_log_probs = (token_log_probs * shift_mask).sum(dim=-1)  # (B,)
    return sequence_log_probs


def dpo_loss(
    policy_model,
    reference_model,
    batch: dict,
    beta: float = 0.1,
) -> tuple[torch.Tensor, dict]:
    """
    Compute DPO loss for a batch of preference pairs.
    batch contains:
    - chosen_input_ids, chosen_attention_mask, chosen_labels
    - rejected_input_ids, rejected_attention_mask, rejected_labels
    """
    # Log probs under current policy
    policy_chosen_logps = compute_log_probs(
        policy_model,
        batch["chosen_input_ids"],
        batch["chosen_attention_mask"],
        batch["chosen_labels"],
    )
    policy_rejected_logps = compute_log_probs(
        policy_model,
        batch["rejected_input_ids"],
        batch["rejected_attention_mask"],
        batch["rejected_labels"],
    )
    # Log probs under frozen reference model (no gradients needed)
    with torch.no_grad():
        ref_chosen_logps = compute_log_probs(
            reference_model,
            batch["chosen_input_ids"],
            batch["chosen_attention_mask"],
            batch["chosen_labels"],
        )
        ref_rejected_logps = compute_log_probs(
            reference_model,
            batch["rejected_input_ids"],
            batch["rejected_attention_mask"],
            batch["rejected_labels"],
        )
    # DPO objective: chosen log ratio should exceed rejected log ratio
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)
    loss = -F.logsigmoid(chosen_rewards - rejected_rewards).mean()
    # Metrics for monitoring
    metrics = {
        "loss": loss.item(),
        "chosen_rewards": chosen_rewards.mean().item(),
        "rejected_rewards": rejected_rewards.mean().item(),
        "reward_margin": (chosen_rewards - rejected_rewards).mean().item(),
        "accuracy": (chosen_rewards > rejected_rewards).float().mean().item(),
    }
    return loss, metrics

DPO with TRL
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

# Policy model (start from SFT model)
policy_model = AutoModelForCausalLM.from_pretrained(
    "path/to/sft-model",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
# Reference model (frozen copy of SFT model)
ref_model = AutoModelForCausalLM.from_pretrained(
    "path/to/sft-model",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("path/to/sft-model")

# Preference dataset format expected by DPOTrainer:
# {"prompt": str, "chosen": str, "rejected": str}
# Note: Anthropic/hh-rlhf stores the full dialogue in "chosen" and "rejected"
# and has no separate "prompt" column, so it may need to be mapped into this
# format first, depending on the TRL version.
# Load without split= so that both the "train" and "test" splits are available.
dataset = load_dataset("Anthropic/hh-rlhf")

dpo_config = DPOConfig(
    beta=0.1,  # Temperature: controls deviation from reference
    max_length=1024,
    max_prompt_length=512,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=5e-7,  # Very low LR: DPO is sensitive to overfitting
    num_train_epochs=1,
    bf16=True,
    output_dir="./llama3-dpo",
    logging_steps=10,
    evaluation_strategy="steps",  # named eval_strategy in newer transformers releases
    eval_steps=100,
)
trainer = DPOTrainer(
    model=policy_model,
    ref_model=ref_model,
    args=dpo_config,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    tokenizer=tokenizer,  # newer TRL releases take processing_class=tokenizer instead
)
trainer.train()

Dataset Construction for DPO
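Some public datasets need reshaping before they match the prompt/chosen/rejected layout. In Anthropic/hh-rlhf, for example, each record holds two full dialogues that share a prefix and differ in the final assistant turn. The following is a minimal sketch of splitting such a record, assuming the "\n\nHuman:" and "\n\nAssistant:" turn markers used by that dataset; the helper name `split_hh_record` and the sample record are hypothetical.

```python
ASSISTANT_MARKER = "\n\nAssistant:"


def split_hh_record(record: dict) -> dict:
    """Split full chosen/rejected dialogues into a shared prompt plus final responses."""
    chosen, rejected = record["chosen"], record["rejected"]
    # The prompt is everything up to and including the last assistant marker.
    cut = chosen.rfind(ASSISTANT_MARKER)
    if cut == -1:
        raise ValueError("no assistant turn found in 'chosen'")
    prompt = chosen[: cut + len(ASSISTANT_MARKER)]
    if not rejected.startswith(prompt):
        raise ValueError("chosen and rejected do not share the same prompt")
    return {
        "prompt": prompt,
        "chosen": chosen[len(prompt):],
        "rejected": rejected[len(prompt):],
    }


# Made-up record in the same layout, for illustration only.
example = {
    "chosen": "\n\nHuman: What is 2 + 2?\n\nAssistant: 2 + 2 equals 4.",
    "rejected": "\n\nHuman: What is 2 + 2?\n\nAssistant: I am not sure.",
}
pair = split_hh_record(example)
print(pair)
```

Raising on records where the two dialogues do not share a prompt is a deliberate choice here: silently keeping such pairs would feed the trainer comparisons between responses to different contexts.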
Converting raw preference data to DPO format:
import torch


def prepare_dpo_dataset(raw_pairs: list[dict], tokenizer) -> list[dict]:
    """
    Convert raw preference pairs to DPO training format.
    raw_pairs: list of {prompt, chosen_response, rejected_response}
    """
    formatted = []
    for pair in raw_pairs:
        # Format prompt as a user message (using chat template)
        prompt_messages = [{"role": "user", "content": pair["prompt"]}]
        prompt = tokenizer.apply_chat_template(
            prompt_messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        formatted.append({
            "prompt": prompt,
            "chosen": pair["chosen_response"],
            "rejected": pair["rejected_response"],
        })
    return formatted


def augment_with_synthetic_pairs(
    sft_model,
    tokenizer,
    prompts: list[str],
    reward_model,
    n_samples: int = 8,
) -> list[dict]:
    """
    Generate synthetic preference pairs using best-of-N sampling.
    Sample N responses, score all with reward model, use best as chosen
    and worst as rejected.
    reward_model is assumed to be a callable that maps a text string
    to a scalar score.
    """
    pairs = []
    for prompt in prompts:
        # Generate N responses
        inputs = tokenizer(prompt, return_tensors="pt").to(sft_model.device)
        responses = []
        for _ in range(n_samples):
            with torch.no_grad():
                output = sft_model.generate(
                    **inputs,
                    max_new_tokens=256,
                    temperature=0.9,
                    do_sample=True,
                )
            # Decode only the newly generated tokens (strip the prompt)
            response_text = tokenizer.decode(
                output[0][inputs["input_ids"].shape[1]:],
                skip_special_tokens=True,
            )
            score = reward_model(prompt + response_text)
            responses.append((score, response_text))
        # Sort by reward score
        responses.sort(key=lambda x: x[0], reverse=True)
        pairs.append({
            "prompt": prompt,
            "chosen": responses[0][1],  # Highest scored
            "rejected": responses[-1][1],  # Lowest scored
        })
    return pairs

Monitoring DPO Training
Key metrics to watch during DPO training:
def evaluate_dpo_model(
    policy_model,
    ref_model,
    tokenizer,
    eval_dataset,
    beta: float = 0.1,
) -> dict:
    """
    Compute DPO evaluation metrics.
    eval_dataset is expected to yield tokenized batches in the format
    dpo_loss takes (for example, a DataLoader).
    """
    total_loss = 0.0
    total_accuracy = 0.0
    total_margin = 0.0
    n_batches = 0
    policy_model.eval()
    for batch in eval_dataset:
        with torch.no_grad():
            loss, metrics = dpo_loss(policy_model, ref_model, batch, beta=beta)
        total_loss += metrics["loss"]
        total_accuracy += metrics["accuracy"]
        total_margin += metrics["reward_margin"]
        n_batches += 1
    n_batches = max(n_batches, 1)  # avoid division by zero on an empty dataset
    return {
        "eval_loss": total_loss / n_batches,
        "eval_accuracy": total_accuracy / n_batches,  # Should be > 0.5 and growing
        "reward_margin": total_margin / n_batches,  # Larger = more confident
    }

# What to watch:
# - eval_accuracy climbing toward ~0.80-0.90 is healthy
# - eval_accuracy plateauing at 0.55-0.60 suggests the preference data is too noisy or ambiguous, or beta is poorly tuned
# - chosen_rewards growing while rejected_rewards fall: good
# - Both rewards falling steeply: the policy is drifting away from the reference, so the KL constraint is too weak (increase beta or lower the learning rate)

DPO vs PPO: When to Choose
Choose DPO when:
- You have a high-quality preference dataset (10K+ pairs)
- You want simpler training (no RL, no reward model)
- Training stability is important
- You're fine-tuning a model for a specific domain
Choose PPO when:
- You need online preference learning (generating and rating in real time)
- The task requires complex multi-step reasoning where RL exploration helps
- You have a reliable reward function (e.g., code execution success, math verification)
- You want the model to improve beyond human demonstration quality
Practical considerations:
- DPO requires the reference model in memory during training (adds memory cost)
- DPO is sensitive to the quality of the SFT initialization — start from a strong SFT model
- DPO beta (typically 0.05-0.5) controls how much deviation from SFT is allowed; higher beta = stay closer to SFT
- DPO with low-quality preference data underperforms PPO — GIGO applies
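The role of beta can be illustrated on a toy problem. The sketch below uses the closed-form optimum π* ∝ π_ref · exp(r/β) from the derivation on a hypothetical three-response distribution with made-up rewards, and reports KL(π* || π_ref) for a few beta values. The divergence shrinks as beta grows, which is the sense in which a higher beta keeps the policy closer to the SFT model.

```python
import math


def optimal_policy(ref_probs, rewards, beta):
    """Closed-form optimum pi* proportional to pi_ref * exp(r / beta)."""
    weights = [p * math.exp(r / beta) for p, r in zip(ref_probs, rewards)]
    z = sum(weights)
    return [w / z for w in weights]


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions over the same support."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


# Illustrative values only (hypothetical toy problem).
ref_probs = [0.5, 0.3, 0.2]
rewards = [1.0, 0.0, -1.0]

kl_by_beta = {}
for beta in (0.05, 0.1, 0.5, 2.0):
    pi_star = optimal_policy(ref_probs, rewards, beta)
    kl_by_beta[beta] = kl_divergence(pi_star, ref_probs)
    print(f"beta={beta:<4} KL(pi* || pi_ref) = {kl_by_beta[beta]:.4f}")
```

This only describes the optimum of the idealized objective; in real training the learning rate and number of steps also determine how far the policy actually moves.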
Variants: IPO, KTO, ORPO
| Method | Key Change | Advantage |
|---|---|---|
| DPO | Original derivation from RLHF | No reward model needed |
| IPO (Identity PO) | Regularization term prevents overfitting | More robust with less data |
| KTO | Uses only good/bad labels (no pairs) | Works without paired preferences |
| ORPO | Combines SFT loss + odds ratio penalty | Single training stage, no ref model |
ORPO (Odds Ratio Preference Optimization) is increasingly popular because it eliminates the need for a reference model entirely, making training cheaper:
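def orpo_loss(
    policy_model,
    chosen_input_ids, chosen_attention_mask, chosen_labels,
    rejected_input_ids, rejected_attention_mask, rejected_labels,
    lambda_orpo: float = 0.1,
) -> torch.Tensor:
    """
    ORPO loss = SFT loss on chosen + odds ratio penalty.
    No reference model needed.
    """
    # Response-token counts, used to length-normalize the summed log probs.
    # ORPO defines sequence likelihood as the per-token average, which also
    # keeps exp(logp) away from underflow on long responses.
    chosen_len = (chosen_labels[:, 1:] != -100).sum(dim=-1).clamp(min=1)
    rejected_len = (rejected_labels[:, 1:] != -100).sum(dim=-1).clamp(min=1)
    # Standard SFT loss (maximize log prob of chosen)
    chosen_logps = compute_log_probs(
        policy_model, chosen_input_ids, chosen_attention_mask, chosen_labels
    ) / chosen_len
    sft_loss = -chosen_logps.mean()
    # Odds ratio: log(p(chosen) / (1 - p(chosen))) - log(p(rejected) / (1 - p(rejected)))
    rejected_logps = compute_log_probs(
        policy_model, rejected_input_ids, rejected_attention_mask, rejected_labels
    ) / rejected_len
    log_odds_chosen = chosen_logps - torch.log1p(-chosen_logps.exp())
    log_odds_rejected = rejected_logps - torch.log1p(-rejected_logps.exp())
    ratio = F.logsigmoid(log_odds_chosen - log_odds_rejected)
    or_loss = -ratio.mean()
    return sft_loss + lambda_orpo * or_loss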
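To get a feel for the odds-ratio term, it can be evaluated on plain numbers. The following is a minimal sketch, assuming per-token average log-probabilities for the chosen and rejected responses; the helper names and sample values are hypothetical.

```python
import math


def log_odds(avg_logp: float) -> float:
    """log(p / (1 - p)) for p = exp(avg_logp), with avg_logp < 0."""
    return avg_logp - math.log1p(-math.exp(avg_logp))


def odds_ratio_penalty(chosen_avg_logp: float, rejected_avg_logp: float) -> float:
    """-log(sigmoid(gap)), where gap is the log odds ratio of chosen over rejected."""
    gap = log_odds(chosen_avg_logp) - log_odds(rejected_avg_logp)
    return math.log1p(math.exp(-gap))  # equals -log(sigmoid(gap))


# Chosen response more likely than rejected: small penalty.
favored = odds_ratio_penalty(-0.5, -2.0)
# Rejected response more likely than chosen: large penalty.
unfavored = odds_ratio_penalty(-2.0, -0.5)

print(f"favored={favored:.4f} unfavored={unfavored:.4f}")
```

The penalty equals log(2) when both responses are equally likely, so values below that indicate the policy already prefers the chosen response.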