Learnixo

Fine-Tuning LLMs · Lesson 4 of 16

Catastrophic Forgetting: What It Is and How to Avoid It

Catastrophic Forgetting

Catastrophic forgetting is the phenomenon where a neural network trained on a new task loses its ability to perform previously learned tasks. In LLM fine-tuning, a severe case means your carefully fine-tuned drug-information model can no longer write a grammatical sentence, follow multi-step instructions, or answer basic reasoning questions.


What It Is

When you run full fine-tuning on a narrow dataset, the gradient updates push every trainable weight toward patterns in your training data. If that data is a small, homogeneous set (e.g., 500 Q&A pairs about medications), the weights can drift far from the pre-trained values that encoded broad language understanding.

The result: a model that is excellent at your narrow task but broken everywhere else.

Python
# Conceptual demonstration: catastrophic forgetting in a small network
import torch
import torch.nn as nn
import torch.optim as optim

# Tiny network representing a "language model" with two capabilities
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.shared = nn.Linear(10, 20)   # shared representation
        self.head_a = nn.Linear(20, 1)    # Task A: general reasoning
        self.head_b = nn.Linear(20, 1)    # Task B: domain task

    def forward(self, x, task="a"):
        features = torch.relu(self.shared(x))
        if task == "a":
            return self.head_a(features)
        return self.head_b(features)

model = TinyModel()
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Phase 1: Pre-train on Task A (general language)
task_a_data = [(torch.randn(10), torch.tensor([1.0])) for _ in range(100)]
for x, y in task_a_data:
    loss = nn.MSELoss()(model(x, task="a"), y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

# Evaluate Task A performance before fine-tuning
with torch.no_grad():
    task_a_loss_before = sum(
        nn.MSELoss()(model(x, task="a"), y).item()
        for x, y in task_a_data[:20]
    ) / 20
print(f"Task A loss before fine-tuning: {task_a_loss_before:.4f}")

# Phase 2: Fine-tune aggressively on Task B (narrow domain)
# High learning rate, many epochs: the conditions that produce catastrophic forgetting
task_b_data = [(torch.randn(10) * 0.01, torch.tensor([5.0])) for _ in range(50)]
# Note: Task B has a VERY different distribution (small inputs, large targets)

optimizer_ft = optim.SGD(model.parameters(), lr=0.5)  # much higher LR
for epoch in range(50):
    for x, y in task_b_data:
        loss = nn.MSELoss()(model(x, task="b"), y)
        loss.backward()
        optimizer_ft.step()
        optimizer_ft.zero_grad()

# Evaluate Task A performance AFTER fine-tuning on Task B
with torch.no_grad():
    task_a_loss_after = sum(
        nn.MSELoss()(model(x, task="a"), y).item()
        for x, y in task_a_data[:20]
    ) / 20
print(f"Task A loss after fine-tuning: {task_a_loss_after:.4f}")
print(f"Forgetting ratio: {task_a_loss_after / task_a_loss_before:.1f}x worse")
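# The exact ratio varies from run to run because no random seed is set;
# a ratio well above 1x means Task A got worse (catastrophic forgetting).
# With a learning rate this high, the run may even diverge and print nan.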

Why It Happens

The mechanism is gradient descent. Each gradient step moves parameters toward minimizing loss on the current batch of data. The optimizer has no memory of what the original weights were "good for" — it simply follows the gradient downhill on your new loss surface.

Python
# Visualizing why gradients overwrite old knowledge
import torch

# Suppose the optimal weight for Task A (general language) is approximately 1.0
# Suppose the optimal weight for Task B (drug domain) is approximately -2.0
# A single shared weight can't be both

w = torch.tensor(1.0, requires_grad=True)  # pre-trained value (Task A optimal)

def task_a_loss(w):
    return (w - 1.0) ** 2   # loss is 0 when w = 1.0 (Task A optimum)

def task_b_loss(w):
    return (w + 2.0) ** 2   # loss is 0 when w = -2.0 (Task B optimum)

print(f"Initial w: {w.item():.2f}")
print(f"Task A loss at w=1.0: {task_a_loss(w).item():.2f}")  # 0.0 — perfect
print(f"Task B loss at w=1.0: {task_b_loss(w).item():.2f}")  # 9.0 — terrible

# Fine-tune on Task B with high learning rate (10 steps)
optimizer = torch.optim.SGD([w], lr=0.5)
for step in range(10):
    loss = task_b_loss(w)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

print(f"\nAfter fine-tuning on Task B:")
print(f"w: {w.item():.2f}")
print(f"Task A loss: {task_a_loss(w).item():.2f}")   # High — forgot Task A
print(f"Task B loss: {task_b_loss(w).item():.2f}")   # Low — learned Task B
# w moved from 1.0 toward -2.0, breaking Task A performance
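
The single-weight example also shows why step size matters. For this quadratic loss, each SGD step multiplies the distance to the Task B optimum by (1 - 2 * lr), so the learning rate and the number of steps together decide how far w drifts from its pre-trained value. The following is a minimal sketch in plain Python under the same toy assumptions; the helper name `drift_after` is made up for this illustration.

```python
def drift_after(lr: float, steps: int, w0: float = 1.0) -> tuple[float, float]:
    """Run plain SGD on task_b_loss(w) = (w + 2)^2 and return (w, Task A loss)."""
    w = w0
    for _ in range(steps):
        grad = 2.0 * (w + 2.0)  # derivative of (w + 2)^2 with respect to w
        w -= lr * grad
    task_a_loss = (w - 1.0) ** 2  # how far w has moved from the Task A optimum
    return w, task_a_loss


for lr in (0.5, 0.1, 0.01, 0.001):
    w, loss_a = drift_after(lr, steps=10)
    print(f"lr={lr:<6} w after 10 steps: {w:+.3f}   Task A loss: {loss_a:.3f}")
```

With lr=0.5 the weight lands on the Task B optimum in a single step, while the smallest learning rates leave it close to 1.0 after the same ten steps, at the cost of learning Task B more slowly.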

How to Detect Catastrophic Forgetting

You need to benchmark the model on general tasks both before and after fine-tuning. If general performance drops significantly, forgetting has occurred.

Python
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import json

def evaluate_general_qa(model_name_or_path: str, test_questions: list[dict]) -> dict:
    """
    Evaluate a model on general-purpose questions to detect forgetting.

    test_questions format:
    [{"question": "...", "correct_answer": "...", "category": "..."}]
    """
    generator = pipeline(
        "text-generation",
        model=model_name_or_path,
        device_map="auto",
        max_new_tokens=50,
    )

    results = {"correct": 0, "total": 0, "by_category": {}}

    for item in test_questions:
        prompt = f"Question: {item['question']}\nAnswer:"
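        # Greedy decoding: do_sample=False is enough. temperature=0.0 is not a valid
        # sampling temperature and is ignored (with a warning) when sampling is off.
        # return_full_text=False returns only the continuation, not the prompt.
        output = generator(prompt, do_sample=False, return_full_text=False)[0]["generated_text"]
        answer = output.strip().lower()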

        is_correct = item["correct_answer"].lower() in answer
        results["correct"] += int(is_correct)
        results["total"] += 1

        cat = item["category"]
        if cat not in results["by_category"]:
            results["by_category"][cat] = {"correct": 0, "total": 0}
        results["by_category"][cat]["correct"] += int(is_correct)
        results["by_category"][cat]["total"] += 1

    results["overall_accuracy"] = results["correct"] / results["total"]
    for cat in results["by_category"]:
        d = results["by_category"][cat]
        d["accuracy"] = d["correct"] / d["total"]

    return results

# General knowledge test suite (abbreviated)
general_qa_tests = [
    {"question": "What is the capital of France?", "correct_answer": "Paris", "category": "geography"},
    {"question": "What is 15 multiplied by 7?", "correct_answer": "105", "category": "math"},
    {"question": "Who wrote Hamlet?", "correct_answer": "Shakespeare", "category": "literature"},
    {"question": "What gas do plants absorb from the air?", "correct_answer": "carbon dioxide", "category": "science"},
    {"question": "What year did World War II end?", "correct_answer": "1945", "category": "history"},
]

# Run before and after fine-tuning
# before_results = evaluate_general_qa("meta-llama/Llama-3.2-3B-Instruct", general_qa_tests)
# after_results  = evaluate_general_qa("./my-fine-tuned-model", general_qa_tests)

# Compare
def compare_forgetting(before: dict, after: dict, threshold: float = 0.10) -> dict:
    """Identify if fine-tuning caused significant forgetting."""
    drop = before["overall_accuracy"] - after["overall_accuracy"]
    forgot = drop > threshold

    category_drops = {}
    for cat in before["by_category"]:
        if cat in after["by_category"]:
            cat_drop = (
                before["by_category"][cat]["accuracy"] -
                after["by_category"][cat]["accuracy"]
            )
            category_drops[cat] = round(cat_drop, 3)

    return {
        "overall_accuracy_before": round(before["overall_accuracy"], 3),
        "overall_accuracy_after": round(after["overall_accuracy"], 3),
        "accuracy_drop": round(drop, 3),
        "catastrophic_forgetting_detected": forgot,
        "category_drops": category_drops,
        "recommendation": (
            "Reduce learning rate and epochs, or switch to LoRA" if forgot
            else "Fine-tuning preserved general capabilities"
        )
    }
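
An overall accuracy difference hides which questions actually changed. One way to look closer, sketched here under the assumption that you record a per-question correct/incorrect flag in the same order for both runs (the evaluation function above would need a small change to return those flags), is to count the answers that flipped. The helper name `paired_flips` is hypothetical.

```python
def paired_flips(before: list[bool], after: list[bool]) -> dict:
    """Compare per-question correctness from two runs over the same questions."""
    if len(before) != len(after):
        raise ValueError("before and after must cover the same questions")
    # Indices that were correct before fine-tuning and wrong afterwards
    regressions = [i for i, (b, a) in enumerate(zip(before, after)) if b and not a]
    # Indices that were wrong before fine-tuning and correct afterwards
    improvements = [i for i, (b, a) in enumerate(zip(before, after)) if a and not b]
    n = len(before)
    return {
        "n_questions": n,
        "regressions": regressions,
        "improvements": improvements,
        "net_accuracy_drop": (len(regressions) - len(improvements)) / n if n else 0.0,
    }


# Made-up flags for a five-question suite, for illustration only
before_flags = [True, True, True, False, True]
after_flags = [True, False, True, True, False]
report = paired_flips(before_flags, after_flags)
print(report)
```

With only five questions, a single flipped answer moves accuracy by 20 percentage points, so a 0.10 threshold only becomes meaningful with a considerably larger test suite.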

Mitigation 1: Use LoRA

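LoRA is the most effective mitigation covered here for protecting the base model. Because the original weights remain frozen, they cannot be overwritten; only the small adapter matrices change. The adapted model's behavior can still shift on general tasks, so the before/after evaluation still applies.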

Python
from peft import get_peft_model, LoraConfig, TaskType
from transformers import AutoModelForCausalLM
import torch

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3B-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

# The key property: base weights are frozen
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.1,
    bias="none",
)

peft_model = get_peft_model(model, lora_config)

# Verify: base model weights are frozen
for name, param in peft_model.named_parameters():
    if "lora" not in name:
        assert not param.requires_grad, f"Base weight {name} should be frozen!"

# LoRA cannot overwrite the base weights because they are never updated.
# The model learns: output = frozen_W(x) + adapter(x)
# Removing (or disabling) the unmerged adapter restores the original model exactly.
# The adapted model can still behave differently on general tasks, so keep evaluating it.
print("All base weights confirmed frozen. The base model cannot be overwritten by training.")
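
To make the "frozen base plus adapter" idea concrete without downloading a model, here is a minimal sketch of a LoRA-style layer in plain PyTorch. It is an illustration of the idea, not the PEFT implementation; the class name `LoRALinear`, the `enabled` switch, and all sizes and hyperparameters are assumptions chosen for the demo.

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """A frozen linear layer plus a trainable low-rank update (illustrative only)."""

    def __init__(self, base: nn.Linear, r: int = 4, alpha: float = 8.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)  # freeze the "pre-trained" weights
        self.lora_A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        # Zero init for B means the adapter starts as a no-op
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, r))
        self.scaling = alpha / r
        self.enabled = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.base(x)
        if self.enabled:
            out = out + self.scaling * (x @ self.lora_A.T @ self.lora_B.T)
        return out


torch.manual_seed(0)
base = nn.Linear(8, 1)
original_weight = base.weight.detach().clone()
layer = LoRALinear(base)

x = torch.randn(32, 8)
y_before = layer(x).detach().clone()

# Train only the adapter parameters toward a new target
target = torch.full((32, 1), 3.0)
optimizer = torch.optim.SGD([p for p in layer.parameters() if p.requires_grad], lr=0.05)
for _ in range(300):
    loss = nn.functional.mse_loss(layer(x), target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

base_unchanged = torch.equal(base.weight, original_weight)
adapted_shift = (layer(x) - y_before).abs().mean().item()

layer.enabled = False  # switch the adapter off
restored = torch.allclose(layer(x), y_before)

print(f"Base weight unchanged: {base_unchanged}")
print(f"Mean output shift with adapter on: {adapted_shift:.4f}")
print(f"Original outputs restored with adapter off: {restored}")
```

The base weight is identical after training and switching the adapter off recovers the original outputs, while the outputs with the adapter on have moved. That last point is why the adapted model still deserves a general-capability check.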

Mitigation 2: Lower Learning Rate and Fewer Epochs

Python
from transformers import TrainingArguments

# Aggressive settings: HIGH forgetting risk
risky_args = TrainingArguments(
    output_dir="./risky",
    num_train_epochs=10,
    learning_rate=5e-5,         # too high for fine-tuning large models
    per_device_train_batch_size=1,
    warmup_ratio=0.0,           # no warmup: sharp weight shifts early
)

# Safe settings for domain fine-tuning
safe_args = TrainingArguments(
    output_dir="./safe",
    num_train_epochs=3,         # fewer epochs
    learning_rate=2e-5,         # lower LR
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    warmup_ratio=0.05,          # warmup: gradual weight shift at start
    lr_scheduler_type="cosine", # cosine decay: LR drops toward end
    weight_decay=0.01,          # regularization
)

# Rule of thumb for learning rates:
# Full fine-tuning: 1e-5 to 2e-5
# LoRA:             1e-4 to 3e-4
# QLoRA:            1e-4 to 2e-4
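
To see what warmup and cosine decay do to the step size, the schedule can be approximated in a few lines of plain Python. This is a sketch of a linear-warmup, cosine-decay curve under assumed values (100 optimizer steps, peak LR 2e-5, 5% warmup), not the library's scheduler code; the function name `lr_at_step` is made up here.

```python
import math


def lr_at_step(step: int, total_steps: int, peak_lr: float = 2e-5,
               warmup_ratio: float = 0.05) -> float:
    """Learning rate at a given optimizer step: linear warmup, then cosine decay."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        # Ramp linearly from 0 up to the peak learning rate
        return peak_lr * step / max(1, warmup_steps)
    # Fraction of the post-warmup phase that has elapsed, from 0.0 to 1.0
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


total_steps = 100  # assumed number of optimizer steps, for illustration
schedule = [lr_at_step(s, total_steps) for s in range(total_steps + 1)]
for s in (0, 2, 5, 25, 50, 75, 100):
    print(f"step {s:>3}: lr = {schedule[s]:.2e}")
```

The first updates, when the loss on new data is largest, are taken with a small step, and the step size shrinks again toward the end of training. Both effects limit how far the weights move from their pre-trained values.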

Mitigation 3: Replay Data

Mix a small fraction of general-purpose data into your fine-tuning dataset. The model continues to see general text, which prevents the weights from drifting too far toward your domain.

Python
from datasets import Dataset, concatenate_datasets
import random

def create_replay_dataset(
    domain_data: list[dict],
    general_data: list[dict],
    replay_fraction: float = 0.1,
    seed: int = 42,
) -> Dataset:
    """
    Mix domain-specific data with a sample of general-purpose data.
    The number of replay examples is replay_fraction times the size of the
    domain set, so 0.1 adds general examples equal to 10% of the domain data
    (about 9% of the combined dataset).
    The aim is to reduce forgetting at a small cost in domain performance;
    confirm both with a before/after evaluation.
    """
    rng = random.Random(seed)  # local RNG: does not reseed the global random module

    num_replay = int(len(domain_data) * replay_fraction)
    replay_sample = rng.sample(general_data, min(num_replay, len(general_data)))

    combined = domain_data + replay_sample
    rng.shuffle(combined)

    return Dataset.from_list(combined)

# Example usage
drug_qa_data = [
    {"text": "Q: What is aspirin?\nA: Aspirin is an NSAID used for pain, fever..."},
    # ... hundreds more drug Q&A pairs
]

# General text from a public dataset (e.g., OpenHermes, Alpaca, etc.)
general_data = [
    {"text": "Q: What is the capital of Germany?\nA: Berlin."},
    {"text": "Q: Explain recursion.\nA: Recursion is when a function calls itself..."},
    # ... sample from general instruction-following data
]

mixed_dataset = create_replay_dataset(
    domain_data=drug_qa_data,
    general_data=general_data,
    replay_fraction=0.15  # general examples equal to 15% of the domain set size
)

print(f"Total training examples: {len(mixed_dataset)}")
print(f"Domain examples: {len(drug_qa_data)}")
print(f"Replay examples: {len(mixed_dataset) - len(drug_qa_data)}")
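
The function above sizes the replay sample relative to the domain set. If you would rather have general examples make up a given share of the combined dataset, the count works out to share / (1 - share) times the number of domain examples. A small sketch of that arithmetic follows; the helper name `replay_count_for_share` and the 500-example domain set are assumptions for illustration.

```python
def replay_count_for_share(num_domain: int, share: float) -> int:
    """Number of general examples needed so they form `share` of the combined set."""
    if not 0.0 <= share < 1.0:
        raise ValueError("share must be in [0, 1)")
    # Solve n / (num_domain + n) = share for n, then round to a whole example
    return round(num_domain * share / (1.0 - share))


num_domain = 500  # assumed domain set size
for share in (0.05, 0.10, 0.15):
    n = replay_count_for_share(num_domain, share)
    naive = int(num_domain * share)  # what sizing relative to the domain set gives
    print(
        f"target share {share:.0%}: {n} replay examples "
        f"(actual share {n / (num_domain + n):.1%}); "
        f"domain-relative sizing gives {naive} ({naive / (num_domain + naive):.1%})"
    )
```

The difference between the two conventions is small at these fractions, but it grows as the replay share increases, so it helps to state which one a reported fraction refers to.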

Mitigation 4: Elastic Weight Consolidation (EWC)

EWC is a research-grade technique: it adds a regularization term to the loss that penalizes large changes to the weights that mattered most for previous tasks.

Python
import torch
import torch.nn as nn
from copy import deepcopy

class EWCRegularizer:
    """
    Elastic Weight Consolidation: penalizes changing weights
    that were important for the pre-trained task.

    Note: Expensive to compute — requires a forward/backward pass
    over the pre-training task to estimate Fisher information.
    In practice, LoRA is simpler and achieves similar results.
    """

    def __init__(self, model: nn.Module, importance: float = 5000.0):
        self.importance = importance
        self.original_params = {}
        self.fisher = {}

        # Store original parameter values
        for name, param in model.named_parameters():
            self.original_params[name] = param.data.clone()
            self.fisher[name] = torch.zeros_like(param.data)

    def update_fisher(self, model: nn.Module, dataloader, num_batches: int = 100):
        """Estimate diagonal Fisher information from pre-training data."""
        model.eval()
        batches_seen = 0
        for i, batch in enumerate(dataloader):
            if i >= num_batches:
                break
            model.zero_grad()  # clear gradients left over from the previous batch
            output = model(**batch)
            loss = output.loss
            loss.backward()

            for name, param in model.named_parameters():
                if param.grad is not None:
                    self.fisher[name] += param.grad.detach() ** 2
            batches_seen += 1

        # Normalize by the number of batches actually processed
        for name in self.fisher:
            self.fisher[name] /= max(batches_seen, 1)
        model.zero_grad()  # do not leak these gradients into training

    def ewc_loss(self, model: nn.Module) -> torch.Tensor:
        """Compute EWC penalty term."""
        # Accumulate on the model's device to avoid mixing CPU and GPU tensors
        device = next(model.parameters()).device
        penalty = torch.zeros((), device=device)
        for name, param in model.named_parameters():
            if name in self.fisher:
                diff = param - self.original_params[name].to(param.device)
                term = (self.fisher[name].to(param.device) * diff ** 2).sum()
                penalty = penalty + term.to(device)
        return self.importance * penalty

# In training loop:
# total_loss = task_loss + ewc.ewc_loss(model)
# total_loss.backward()
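
Written out, the penalty that `ewc_loss` computes has the form below, where θ*_i is the stored pre-trained value of weight i (`original_params`), F_i is its diagonal Fisher estimate (`fisher`), and λ is `importance`. The original EWC formulation uses a coefficient of λ/2; the code above omits the 1/2, which only rescales λ. The second part applies the same penalty to the single-weight example from earlier in this lesson, assuming F = 1 for that weight, as a worked illustration.

```latex
% Total loss during fine-tuning with the EWC penalty
\mathcal{L}(\theta) = \mathcal{L}_{\text{task}}(\theta)
  + \lambda \sum_i F_i \,\bigl(\theta_i - \theta_i^{*}\bigr)^2

% One-weight example: L_task(w) = (w + 2)^2, pre-trained value w^* = 1, F = 1
\frac{d}{dw}\Bigl[(w + 2)^2 + \lambda\,(w - 1)^2\Bigr]
  = 2(w + 2) + 2\lambda\,(w - 1) = 0
\quad\Longrightarrow\quad
w = \frac{\lambda - 2}{1 + \lambda}
```

With λ = 0 the minimizer is the Task B optimum w = -2, with λ = 1 it is w = -0.5, and as λ grows it approaches the pre-trained value w = 1. The penalty therefore trades domain fit against staying close to the original weights.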

Detection Checklist

Run this checklist before deploying a fine-tuned model:

Python
FORGETTING_CHECKLIST = [
    {
        "test": "Basic instruction following",
        "prompt": "List three countries in Europe.",
        "check": lambda r: len([w for w in r.split() if w.istitle()]) >= 3,
        "severity": "critical",
    },
    {
        "test": "Multi-step reasoning",
        "prompt": "If I have 12 apples and give away 5, then buy 3 more, how many do I have?",
        "check": lambda r: "10" in r,
        "severity": "critical",
    },
    {
        "test": "Language coherence",
        "prompt": "Write one sentence about the weather.",
        "check": lambda r: len(r.split()) >= 5 and r.strip().endswith("."),  # strip(): ignore trailing whitespace/newlines
        "severity": "high",
    },
    {
        "test": "Format compliance",
        "prompt": "Write a JSON object with keys 'name' and 'age'.",
        "check": lambda r: "{" in r and "name" in r and "age" in r,
        "severity": "medium",
    },
]

def run_forgetting_checklist(generate_fn) -> dict:
    """Run forgetting detection checklist. Pass a function that takes a prompt and returns text."""
    results = {"passed": 0, "failed": 0, "failures": []}
    for item in FORGETTING_CHECKLIST:
        response = generate_fn(item["prompt"])
        passed = item["check"](response)
        if passed:
            results["passed"] += 1
        else:
            results["failed"] += 1
            results["failures"].append({
                "test": item["test"],
                "severity": item["severity"],
                "response": response[:200],
            })
    results["score"] = results["passed"] / len(FORGETTING_CHECKLIST)
    return results
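
The checklist result can then feed a simple go/no-go decision. The sketch below assumes a result dictionary shaped like the one returned above and treats failures of "critical" or "high" severity as blocking; the function name `deployment_gate`, that policy, and the sample result are all assumptions for illustration.

```python
def deployment_gate(results: dict, blocking: tuple[str, ...] = ("critical", "high")) -> dict:
    """Decide whether checklist results allow deployment (assumed policy)."""
    blocking_failures = [
        f for f in results.get("failures", []) if f["severity"] in blocking
    ]
    return {
        "deploy": not blocking_failures,
        "blocking_tests": [f["test"] for f in blocking_failures],
        "score": results.get("score"),
    }


# Hand-written example result, for illustration only
example_results = {
    "passed": 2,
    "failed": 2,
    "failures": [
        {"test": "Multi-step reasoning", "severity": "critical", "response": "..."},
        {"test": "Format compliance", "severity": "medium", "response": "..."},
    ],
    "score": 0.5,
}
decision = deployment_gate(example_results)
print(decision)
```

Separating blocking from non-blocking severities keeps a single formatting miss from stopping a release while still halting on lost reasoning or instruction following.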

Summary

Catastrophic forgetting is a predictable, measurable, and preventable problem. The hierarchy of mitigations, from most to least recommended:

  1. Use LoRA: base weights stay frozen, so gradient updates cannot overwrite them, and removing the adapter recovers the original model
  2. Lower learning rate: weight changes are smaller, so there is less drift from the pre-trained state
  3. Fewer epochs: fewer updates mean less opportunity to overwrite original weights
  4. Replay data: the model keeps seeing general text during fine-tuning
  5. EWC regularization: penalizes changes to important weights

For many projects, LoRA with a conservative learning rate is enough, but the adapted model can still behave differently from the base model. Evaluate general capabilities before and after, and you will have quantitative evidence that your fine-tuned model is specialized without being broken.