Supervised Fine-Tuning (SFT) for LLMs
Turn a pretrained base model into an instruction-following assistant using SFT. Covers data formats, loss masking, LoRA adapters, SFTTrainer, and quality signals.
What SFT Does
A pretrained base model is a token completion engine: given any prefix, it continues. It has no concept of being helpful, following instructions, or refusing dangerous requests. SFT (Supervised Fine-Tuning) teaches the model to behave as an assistant by training on examples of the desired interaction pattern.
The objective is identical to pretraining — cross-entropy on next-token prediction — but:
- Training data consists of structured conversations (instruction + response)
- Loss is computed only on assistant tokens, not on the user prompt
- The dataset is small (10K-500K examples vs trillions of tokens in pretraining)
What SFT teaches:
- Response format (answer directly, use markdown, structure reasoning)
- Following instructions (do what the user asks)
- Conversation turns (how to interpret multi-turn context)
What SFT does NOT teach:
- New factual knowledge (that comes from pretraining)
- Value alignment (that requires RLHF or DPO)
- Safety refusals by default (the model only learns to refuse if the SFT data explicitly includes adversarial prompts paired with correct refusals)
Data Format: Chat Templates
Models require a consistent tokenization of conversations. The template wraps each message with special tokens:
# Llama-3 chat template format. Every <|end_header_id|> is followed by a blank
# line ("\n\n") before the message content, as Meta's template renders it;
# dropping a newline yields token sequences the model never saw in training.
LLAMA3_TEMPLATE = (
    "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
    "{system_message}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
    "{user_message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    "{assistant_message}<|eot_id|>"
)
# ChatML format (used by Qwen, OpenHermes and many other open models; Mistral's
# instruct models use their own [INST] ... [/INST] format instead).
# Turns are separated by a single newline.
CHATML_TEMPLATE = (
    "<|im_start|>system\n{system_message}<|im_end|>\n"
    "<|im_start|>user\n{user_message}<|im_end|>\n"
    "<|im_start|>assistant\n{assistant_message}<|im_end|>"
)
# Using tokenizer.apply_chat_template() from HuggingFace handles this automatically:
from transformers import AutoTokenizer
# The Instruct checkpoint's tokenizer carries the chat template used below.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

conversation = [
    dict(role="system", content="You are a clinical pharmacist."),
    dict(role="user", content="What is the interaction between warfarin and ibuprofen?"),
    dict(role="assistant", content="Warfarin and ibuprofen have a major interaction. NSAIDs like ibuprofen..."),
]

# tokenize=False -> the rendered prompt comes back as a string, with the
# template's special tokens inserted. add_generation_prompt stays False here
# because the assistant turn is already present (training); at inference time
# set it to True so the assistant header is appended for the model to continue.
formatted = tokenizer.apply_chat_template(
    conversation,
    add_generation_prompt=False,
    tokenize=False,
)
print(formatted)
Loss Masking: Train on Assistant Only
The model should learn to generate assistant responses, not repeat user prompts:
import torch
from transformers import AutoTokenizer
def tokenize_with_masking(
    conversation: list[dict],
    tokenizer,
    max_length: int = 2048,
) -> dict:
    """
    Tokenize a conversation and mask every non-assistant token with -100.

    -100 is the default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``, so
    masked positions contribute nothing to the loss. Every assistant turn is
    kept in the labels (not only the last one). That includes whatever the chat
    template appends right after the assistant content (e.g. Llama-3's
    <|eot_id|>), so the model also learns when to stop. System/user messages
    and all role headers, including the assistant header, are masked.

    Turn boundaries are found by re-rendering prefixes of the conversation with
    the same chat template. This assumes the template is prefix-stable:
    rendering ``conversation[:i]`` yields a token prefix of the full rendering.
    NOTE(review): verify this for templates that rewrite earlier turns.

    Args:
        conversation: Messages as ``{"role": ..., "content": ...}`` dicts.
        tokenizer: Object exposing ``apply_chat_template`` (e.g. a HuggingFace
            tokenizer with a chat template).
        max_length: Maximum number of tokens kept; longer sequences are cut
            from the right.

    Returns:
        Dict with 1-D ``input_ids``, ``attention_mask`` and ``labels`` tensors
        of equal length.
    """

    def _render(messages: list[dict], add_generation_prompt: bool) -> torch.Tensor:
        # Tokenize with the chat template and drop the batch dimension.
        return tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=add_generation_prompt,
            return_tensors="pt",
        )[0]

    full = _render(conversation, add_generation_prompt=False)

    # Start fully masked, then copy back the tokens of each assistant turn.
    # (Masking "everything before the latest prefix" would also hide earlier
    # assistant turns in multi-turn conversations.)
    labels = torch.full_like(full, -100)
    for i, message in enumerate(conversation):
        if message["role"] != "assistant":
            continue
        # The turn's content starts right after the assistant header that
        # add_generation_prompt=True appends to the preceding messages ...
        start = len(_render(conversation[:i], add_generation_prompt=True))
        # ... and ends after the tokens that close this message.
        end = len(_render(conversation[: i + 1], add_generation_prompt=False))
        labels[start:end] = full[start:end]

    # Truncate to max_length
    input_ids = full[:max_length]
    labels = labels[:max_length]
    attention_mask = torch.ones_like(input_ids)
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }
# SFTTrainer: High-Level API
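TRL's SFTTrainer handles chat template application and the training loop. Loss masking is a separate setting: for conversational datasets, prompt tokens are trained on by default unless you opt in to assistant-only loss (e.g. with TRL's DataCollatorForCompletionOnlyLM, or `assistant_only_loss=True` in newer TRL releases):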
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer, SFTConfig, DataCollatorForCompletionOnlyLM
# Load base model
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
# NOTE(review): the completion-only collator below also sets labels to -100 at
# pad positions, so the pad token must not be the end-of-turn token the chat
# template emits (<|eot_id|> for Llama-3), or the model never learns to stop.
# Confirm tokenizer.eos_token differs from <|eot_id|> for this checkpoint.
tokenizer.pad_token = tokenizer.eos_token
# NOTE(review): SFTTrainer renders the "messages" column with
# tokenizer.chat_template; confirm the base-model tokenizer defines one
# (otherwise copy the template from the -Instruct tokenizer).

# Load instruction dataset
dataset = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft[:10000]")

# SFTTrainer does not mask the prompt tokens of conversational datasets by
# itself. DataCollatorForCompletionOnlyLM keeps labels only for the spans that
# follow `response_template` (the assistant header) and masks from each
# `instruction_template` (the user header) onward, so the loss comes from the
# assistant turns. Both strings must match the Llama-3 chat template exactly.
completion_collator = DataCollatorForCompletionOnlyLM(
    response_template="<|start_header_id|>assistant<|end_header_id|>\n\n",
    instruction_template="<|start_header_id|>user<|end_header_id|>\n\n",
    tokenizer=tokenizer,
)

# SFTConfig handles formatting and training
sft_config = SFTConfig(
    output_dir="./llama3-sft",
    max_seq_length=2048,
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-5,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    logging_steps=10,
    save_steps=500,
    bf16=True,
    # packing=True would pack several short conversations into one sequence,
    # but TRL does not allow packing together with DataCollatorForCompletionOnlyLM.
    packing=False,
)
trainer = SFTTrainer(
    model=model,
    args=sft_config,
    train_dataset=dataset,
    processing_class=tokenizer,  # this argument was called `tokenizer=` in TRL < 0.12
    data_collator=completion_collator,
)
trainer.train()
trainer.save_model("./llama3-sft-final")
Parameter-Efficient Fine-Tuning: LoRA
Training all 8B parameters requires massive memory. LoRA trains only low-rank adapter matrices instead:
from peft import LoraConfig, get_peft_model, TaskType
lora_config = LoraConfig(
    r=16,  # adapter rank; 8-64 is the usual range
    lora_alpha=32,  # scaling numerator, commonly set to 2x the rank
    lora_dropout=0.05,
    bias="none",  # bias terms are not trained
    task_type=TaskType.CAUSAL_LM,
    # Attach adapters to both the attention and the MLP projections.
    target_modules=[
        # self-attention projections
        "q_proj", "k_proj", "v_proj", "o_proj",
        # feed-forward (MLP) projections
        "gate_proj", "up_proj", "down_proj",
    ],
)
# Wrap the model: the base weights are frozen and only the adapters train.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# trainable params: 41,943,040 || all params: 8,030,261,248 || trainable%: 0.52%
LoRA mechanics: For a weight matrix W (d×k), instead of learning ΔW directly, learn two small matrices: A (d×r) and B (r×k), where r is much smaller than d or k. The adapted weight is W + (α/r)·AB, where α is lora_alpha; with r=16 and lora_alpha=32 the update is scaled by 2. Only A and B are trained, and W stays frozen.
With r=16 and d=4096, k=4096: a full ΔW would have 4096×4096 ≈ 16.8M parameters, while A and B together have only 4096×16 + 16×4096 = 131,072 (≈131K) parameters, under 1% of ΔW.
QLoRA: LoRA with 4-bit Quantized Base
Fine-tune a 70B model on a single 80GB A100 GPU:
import torch
from transformers import BitsAndBytesConfig
# prepare_model_for_kbit_training is not among the peft names imported earlier.
from peft import prepare_model_for_kbit_training

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",  # NormalFloat4 — better for weights
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,  # Quantize the quantization constants
)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-70B",
    quantization_config=bnb_config,
    # Load the modules that stay unquantized (embeddings, norms, lm_head) in
    # the same dtype as the 4-bit compute path instead of a library default.
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
# Standard QLoRA preparation: freezes the quantized base, upcasts the
# numerically sensitive layers (layer norms, lm_head) to fp32 and enables
# gradient checkpointing, which keeps activation memory manageable at 70B.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
# Apply LoRA on top of the 4-bit quantized model
model = get_peft_model(model, lora_config)
# Memory: 70B × 0.5 bytes/param (4-bit) ≈ 35GB of base weights; quantization
# constants, the unquantized layers and the LoRA adapters push the total to
# roughly 40-50GB before activations and adapter optimizer state.
# Fits on a single 80GB A100
QLoRA flow:
- Load base model in NF4 (4-bit)
- Add LoRA adapters (float16/bfloat16)
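- Forward pass: dequantize the 4-bit weights to the bfloat16 compute dtype on the fly → compute; backward pass: gradients flow through the frozen base weights, but only the adapter weights are updated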
- Adapter weights stay in high precision; base model stays frozen in 4-bit
SFT Dataset Quality
SFT quality is heavily data-driven. Scaling data quality matters more than scaling quantity:
LIMA finding (2023): 1,000 carefully curated examples of diverse, high-quality instructions produced a model competitive with models trained on 50,000+ examples. Data quality dominates data quantity in SFT.
Quality criteria for SFT examples:
- Diverse instruction types (factual, creative, coding, reasoning, refusals)
- Clear, correct responses that directly address the instruction
- Appropriate length (not too short, not padded)
- Consistent persona and formatting
- Adversarial examples with correct refusals
def score_sft_example(example: dict) -> float:
    """Heuristic quality score for an SFT example.

    Starts at 1.0 and subtracts a fixed penalty for each red flag found:
    a very short response, a response far longer than the instruction,
    a sycophantic opener, or a hallucination marker.

    Args:
        example: Dict with string ``"instruction"`` and ``"response"`` fields.

    Returns:
        Score in [0.0, 1.0]; higher is better. Negative totals are clamped to 0.0.
    """
    score = 1.0
    response = example["response"]
    instruction = example["instruction"]
    # Response length should be proportional to instruction complexity.
    # Floor the instruction at one word so an empty instruction does not
    # flag every non-empty response as "verbose".
    instruction_words = max(len(instruction.split()), 1)
    response_words = len(response.split())
    if response_words < 5:
        score -= 0.5  # Too short — likely not useful
    if response_words > instruction_words * 50:
        score -= 0.2  # Suspiciously verbose
    # Leading whitespace/newlines are common in scraped data; ignore them so
    # they cannot hide the opener from the check below.
    opening = response.lstrip()
    # Responses starting with "Sure!" or "Certainly!" indicate sycophantic training data
    if opening.startswith(("Sure!", "Certainly!", "Of course!", "Great question!")):
        score -= 0.3
    # Check for hallucination markers
    if "[CITATION NEEDED]" in response or "I believe" in opening[:50]:
        score -= 0.1
    return max(0.0, score)
# Keep only examples whose heuristic quality score clears a fixed threshold.
def filter_sft_dataset(
    examples: list[dict],
    threshold: float = 0.7,
    score_fn=None,
) -> list[dict]:
    """Return the examples whose quality score is at least ``threshold``.

    This is an absolute cut-off, not a percentile: the fraction of data kept
    depends entirely on how the scores are distributed. Input order is
    preserved.

    Args:
        examples: SFT examples to filter.
        threshold: Minimum score (inclusive) required to keep an example.
        score_fn: Optional callable mapping an example to a float score
            (e.g. a reward-model scorer). Defaults to ``score_sft_example``.

    Returns:
        A new list containing the kept examples.
    """
    if score_fn is None:
        score_fn = score_sft_example
    return [ex for ex in examples if score_fn(ex) >= threshold]
# Evaluating SFT Results
def evaluate_sft_model(model, tokenizer, eval_prompts: list[dict]) -> list[dict]:
    """Generate a response for each evaluation prompt and collect the results.

    Each prompt is sent as a single user turn (no system message), rendered
    with the tokenizer's chat template plus the assistant header. Only the
    newly generated tokens are decoded, without special tokens. Generation
    samples at temperature 0.1 with at most 512 new tokens, so outputs are
    near-greedy but not strictly deterministic.

    Args:
        model: Causal LM exposing ``eval()``, ``device`` and ``generate()``.
        tokenizer: Tokenizer exposing ``apply_chat_template``, ``__call__`` and
            ``decode``.
        eval_prompts: Dicts with an ``"instruction"`` key and an optional
            ``"reference"`` answer.

    Returns:
        One dict per prompt with ``"instruction"``, ``"expected"`` (the
        reference, or None) and ``"generated"`` keys.
    """
    model.eval()
    results = []
    for item in eval_prompts:
        messages = [{"role": "user", "content": item["instruction"]}]
        input_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        # The rendered template already contains its special tokens (e.g. the
        # BOS token); letting the tokenizer add them again would duplicate BOS.
        inputs = tokenizer(
            input_text,
            return_tensors="pt",
            add_special_tokens=False,
        ).to(model.device)
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=512,
                temperature=0.1,
                do_sample=True,
            )
        # generate() returns prompt + continuation; decode only the continuation.
        prompt_len = inputs["input_ids"].shape[1]
        response = tokenizer.decode(
            output[0][prompt_len:],
            skip_special_tokens=True,
        )
        results.append({
            "instruction": item["instruction"],
            "expected": item.get("reference"),
            "generated": response,
        })
    return results
Leave a comment
Have a question, correction, or just found this helpful? Leave a note below.