Learnixo

LLMs Deep Dive · Lesson 9 of 24

Supervised Fine-Tuning (SFT)

What SFT Does

A pretrained base model is a token completion engine: given any prefix, it continues. It has no concept of being helpful, following instructions, or refusing dangerous requests. SFT (Supervised Fine-Tuning) teaches the model to behave as an assistant by training on examples of the desired interaction pattern.

The objective is identical to pretraining — cross-entropy on next-token prediction — but:

  1. Training data consists of structured conversations (instruction + response)
  2. Loss is computed only on assistant tokens, not on the user prompt
  3. The dataset is small (10K-500K examples vs trillions of tokens in pretraining)

What SFT teaches:

  • Response format (answer directly, use markdown, structure reasoning)
  • Following instructions (do what the user asks)
  • Conversation turns (how to interpret multi-turn context)

What SFT does NOT teach:

  • New factual knowledge (that comes from pretraining)
  • Value alignment (that requires RLHF or DPO)
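  • Reliable safety refusals (SFT only learns the refusals its data contains; robust refusal behavior needs explicit adversarial examples and usually preference training as well)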

Data Format: Chat Templates

Chat models expect conversations rendered in exactly the format they were trained on. A chat template wraps each message with role markers and special tokens:

Python
# Llama-3 chat template format (illustrative; in practice render conversations
# with tokenizer.apply_chat_template rather than filling this string by hand)
LLAMA3_TEMPLATE = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>

{system_message}<|eot_id|><|start_header_id|>user<|end_header_id|>

{user_message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

{assistant_message}<|eot_id|>"""

# ChatML format (used by Qwen, OpenHermes and many other open models;
# Mistral's instruct models use their own [INST] ... [/INST] format instead)
CHATML_TEMPLATE = """<|im_start|>system
{system_message}<|im_end|>
<|im_start|>user
{user_message}<|im_end|>
<|im_start|>assistant
{assistant_message}<|im_end|>"""

Using tokenizer.apply_chat_template() from HuggingFace handles this automatically:

Python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

# One complete training example: system prompt, user turn, target assistant reply.
conversation = [
    dict(role="system", content="You are a clinical pharmacist."),
    dict(role="user", content="What is the interaction between warfarin and ibuprofen?"),
    dict(role="assistant", content="Warfarin and ibuprofen have a major interaction. NSAIDs like ibuprofen..."),
]

# Render (without tokenizing) using the model's own chat template.
# add_generation_prompt stays False for training because the assistant reply is
# already present; set it to True at inference time so the template ends with
# the assistant header for the model to complete.
formatted = tokenizer.apply_chat_template(
    conversation, tokenize=False, add_generation_prompt=False
)
print(formatted)

Loss Masking: Train on Assistant Only

The model should learn to generate assistant responses, not repeat user prompts:

Python
import torch
from transformers import AutoTokenizer

def _chat_template_ids(tokenizer, messages: list[dict], add_generation_prompt: bool) -> torch.Tensor:
    """Return the 1-D tensor of token ids the chat template produces for ``messages``."""
    return tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=add_generation_prompt,
        return_tensors="pt",
    )[0]


def tokenize_with_masking(
    conversation: list[dict],
    tokenizer,
    max_length: int = 2048,
) -> dict:
    """
    Tokenize a conversation and mask every non-assistant token with -100.

    -100 is the default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``, so
    only assistant tokens contribute to the loss. EVERY assistant turn is
    trained on (not only the last one), together with whatever the template
    emits right after the message (e.g. Llama-3's ``<|eot_id|>``), so the model
    also learns when to stop. System/user messages and assistant header tokens
    stay masked.

    Assumes the chat template is prefix-stable: rendering the first ``i``
    messages yields a token prefix of rendering the whole conversation. This
    holds for the Llama-3 and ChatML layouts shown above; verify it for other
    templates.

    Args:
        conversation: List of ``{"role": ..., "content": ...}`` messages.
        tokenizer: Tokenizer exposing ``apply_chat_template``.
        max_length: Outputs are truncated to at most this many tokens.

    Returns:
        Dict with 1-D ``input_ids``, ``attention_mask`` and ``labels`` tensors.
    """
    full = _chat_template_ids(tokenizer, conversation, add_generation_prompt=False)

    # Start fully masked and unmask each assistant span. Masking "everything up
    # to the latest non-assistant turn" instead would also hide the earlier
    # assistant replies of a multi-turn conversation.
    labels = torch.full_like(full, -100)

    for i, message in enumerate(conversation):
        if message["role"] != "assistant":
            continue
        # The span starts right after this turn's assistant header...
        start = len(_chat_template_ids(tokenizer, conversation[:i], add_generation_prompt=True))
        # ...and ends after the template's end-of-turn tokens for this message.
        end = len(_chat_template_ids(tokenizer, conversation[: i + 1], add_generation_prompt=False))
        labels[start:end] = full[start:end]

    # Truncate to max_length
    input_ids = full[:max_length]
    labels = labels[:max_length]

    attention_mask = torch.ones_like(input_ids)

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }

SFTTrainer: High-Level API

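TRL's SFTTrainer handles chat template application, tokenization, optional packing, and the training loop. Restricting the loss to assistant tokens is not automatic for conversational datasets: you must enable it explicitly, and the option to use depends on your TRL version.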

Python
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTConfig, SFTTrainer

# Load base model
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
# Padding needs a pad token; reuse EOS (the base tokenizer may not define one).
tokenizer.pad_token = tokenizer.eos_token

# Load instruction dataset
dataset = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft[:10000]")

# SFTConfig handles formatting and training.
# NOTE: for conversational ("messages") datasets the loss is not restricted to
# assistant tokens unless you enable it explicitly; the option name depends on
# your TRL version. Check the SFTTrainer docs for the release you have installed.
sft_config = SFTConfig(
    output_dir="./llama3-sft",
    max_seq_length=2048,          # newer TRL releases call this `max_length`
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,  # effective batch = 4 x 4 = 16 per device
    learning_rate=2e-5,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    logging_steps=10,
    save_steps=500,
    bf16=True,
    # packing=True packs multiple short conversations into one sequence
    packing=False,
)

trainer = SFTTrainer(
    model=model,
    args=sft_config,
    train_dataset=dataset,
    processing_class=tokenizer,   # was `tokenizer=` before TRL 0.12
)

trainer.train()
trainer.save_model("./llama3-sft-final")

Parameter-Efficient Fine-Tuning: LoRA

Training all 8B parameters requires massive memory. LoRA trains only low-rank adapter matrices instead:

Python
from peft import LoraConfig, get_peft_model, TaskType

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,                    # Rank of the low-rank matrices (typically 8-64)
    lora_alpha=32,           # Scaling: the update is multiplied by lora_alpha / r (commonly alpha = 2x rank)
    target_modules=[         # Which weight matrices to apply LoRA to (attention + MLP projections)
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    lora_dropout=0.05,
    bias="none",
)

# `model` is the base model loaded in the SFTTrainer section above.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# Expected for Llama-3-8B ("all params" includes the adapters):
# trainable params: 41,943,040 || all params: 8,072,204,288 || trainable%: 0.5196

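LoRA mechanics: For a frozen weight matrix W (d×k), instead of learning ΔW directly, learn two small matrices: A (d×r) and B (r×k), where r is much smaller than d or k. The adapted weight is W + (α/r)·AB, where α is `lora_alpha`. One of the two factors is initialized to zero, so training starts exactly from the base model.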

With r=16 and d=k=4096, a full ΔW would have 4096×4096 ≈ 16.8M parameters. A and B together have only 4096×16 + 16×4096 = 131,072 (≈131K) parameters, about 0.8% as many.


QLoRA: LoRA with 4-bit Quantized Base

Fine-tune a 70B model on a single 80GB A100 GPU:

Python
import torch
from peft import get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",        # NormalFloat4 — designed for normally distributed weights
    bnb_4bit_compute_dtype=torch.bfloat16,  # Dtype used for the matmuls after dequantization
    bnb_4bit_use_double_quant=True,   # Quantize the quantization constants
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-70B",
    quantization_config=bnb_config,
    device_map="auto",
)

# PEFT's recommended preparation step for k-bit training (enables gradient
# checkpointing by default, which keeps activation memory down).
model = prepare_model_for_kbit_training(model)

# Apply LoRA on top of the 4-bit quantized model (`lora_config` from the LoRA section)
model = get_peft_model(model, lora_config)

# Memory: 70B × 0.5 bytes/param (4-bit) ≈ 35GB, plus quantization constants,
# LoRA adapters and activations ≈ 40GB+
# Fits on a single 80GB A100

QLoRA flow:

  1. Load base model in NF4 (4-bit)
  2. Add LoRA adapters (float16/bfloat16)
  3. Forward pass: dequantize the 4-bit weights on the fly to the compute dtype (bfloat16) and compute. Backward pass: compute gradients only for the adapter weights.
  4. Adapter weights stay in high precision; base model stays frozen in 4-bit

SFT Dataset Quality

SFT quality is heavily data-driven. Scaling data quality matters more than scaling quantity:

LIMA finding (2023): 1,000 carefully curated examples of diverse, high-quality instructions produced a model competitive with models trained on 50,000+ examples. Data quality dominates data quantity in SFT.

Quality criteria for SFT examples:

  • Diverse instruction types (factual, creative, coding, reasoning, refusals)
  • Clear, correct responses that directly address the instruction
  • Appropriate length (not too short, not padded)
  • Consistent persona and formatting
  • Adversarial examples with correct refusals

Python
def score_sft_example(example: dict) -> float:
    """
    Heuristic quality score in [0.0, 1.0] for an SFT example.

    Starts at 1.0 and subtracts a fixed penalty for each red flag: a very
    short response, a response that is very long relative to the instruction,
    a sycophantic opener, or hedging/placeholder markers.

    Args:
        example: Dict with string fields "instruction" and "response".

    Returns:
        The score, floored at 0.0.
    """
    score = 1.0
    response = example["response"]
    instruction = example["instruction"]

    # Response length should be proportional to instruction complexity
    instruction_words = len(instruction.split())
    response_words = len(response.split())

    if response_words < 5:
        score -= 0.5  # Too short — likely not useful

    # Treat an empty instruction as one word; otherwise 0 * 50 == 0 would flag
    # every non-empty response as verbose.
    if response_words > max(instruction_words, 1) * 50:
        score -= 0.2  # Suspiciously verbose

    # Responses starting with "Sure!" or "Certainly!" indicate sycophantic training data
    if response.startswith(("Sure!", "Certainly!", "Of course!", "Great question!")):
        score -= 0.3

    # Check for hallucination / hedging markers
    if "[CITATION NEEDED]" in response or "I believe" in response[:50]:
        score -= 0.1

    return max(0.0, score)

# Keep examples whose heuristic quality score clears an absolute threshold
def filter_sft_dataset(examples: list[dict], threshold: float = 0.7) -> list[dict]:
    """
    Return the examples whose ``score_sft_example`` score is >= ``threshold``.

    The cutoff is absolute, not a percentile: how much of the dataset survives
    depends on how many red flags the data triggers. Input order is preserved.

    Args:
        examples: SFT examples with "instruction" and "response" fields.
        threshold: Minimum score (inclusive) an example needs to be kept.

    Returns:
        A new list containing the kept examples.
    """
    return [ex for ex in examples if score_sft_example(ex) >= threshold]

Evaluating SFT Results

Python
def evaluate_sft_model(
    model,
    tokenizer,
    eval_prompts: list[dict],
    *,
    max_new_tokens: int = 512,
    temperature: float = 0.1,
    do_sample: bool = True,
) -> list[dict]:
    """
    Run a set of evaluation prompts and collect the model's responses.

    Each prompt is sent as a single user turn, rendered with the tokenizer's
    chat template plus the assistant generation header. Calls ``model.eval()``,
    so the model is left in eval mode.

    Args:
        model: Causal LM exposing ``eval()``, ``device`` and ``generate()``.
        tokenizer: Tokenizer with a chat template.
        eval_prompts: Dicts with an "instruction" key and an optional
            "reference" key holding the expected answer.
        max_new_tokens: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``.
        do_sample: Forwarded to ``model.generate``; pass ``False`` for
            deterministic (greedy) evaluation.

    Returns:
        One dict per prompt with "instruction", "expected" (None when no
        reference was given) and "generated" keys.
    """
    model.eval()
    results = []

    for item in eval_prompts:
        messages = [{"role": "user", "content": item["instruction"]}]
        input_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        # The rendered template already contains the BOS token (e.g. Llama-3's
        # <|begin_of_text|>); letting the tokenizer add special tokens again
        # would prepend a second BOS.
        inputs = tokenizer(
            input_text,
            return_tensors="pt",
            add_special_tokens=False,
        ).to(model.device)

        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=do_sample,
            )

        # Decode only the newly generated tokens, not the echoed prompt.
        prompt_len = inputs["input_ids"].shape[1]
        response = tokenizer.decode(
            output[0][prompt_len:],
            skip_special_tokens=True,
        )

        results.append({
            "instruction": item["instruction"],
            "expected": item.get("reference"),
            "generated": response,
        })

    return results