Fine-Tuning LLMs · Lesson 4 of 16
Catastrophic Forgetting: What It Is and How to Avoid It
Catastrophic Forgetting
Catastrophic forgetting is the phenomenon where a neural network trained on a new task loses its ability to perform previously learned tasks. In the context of LLM fine-tuning, it means your carefully fine-tuned drug-information model can no longer write a grammatical sentence, follow multi-step instructions, or answer basic reasoning questions.
What It Is
When you fully fine-tune a language model on a narrow dataset, every gradient update pushes the model's weights toward patterns in your training data. If that data is a small, homogeneous set (e.g., 500 Q&A pairs about medications), the weights can drift far from the pre-trained values that encoded broad language understanding.
The result: a model that is excellent at your narrow task but broken everywhere else.
# Conceptual demonstration: catastrophic forgetting in a small network
import math

import torch
import torch.nn as nn
import torch.optim as optim

# Seed so the printed losses are reproducible; unseeded, the forgetting ratio
# varies widely from run to run.
torch.manual_seed(0)


# Tiny network representing a "language model" with two capabilities
class TinyModel(nn.Module):
    """Two task heads on top of one shared layer.

    Training either head also updates ``shared``, which is exactly how
    fine-tuning on one task can damage the other.
    """

    def __init__(self):
        super().__init__()
        self.shared = nn.Linear(10, 20)  # shared representation
        self.head_a = nn.Linear(20, 1)  # Task A: general reasoning
        self.head_b = nn.Linear(20, 1)  # Task B: domain task

    def forward(self, x, task="a"):
        """Return head A's output when ``task == "a"``, otherwise head B's."""
        features = torch.relu(self.shared(x))
        if task == "a":
            return self.head_a(features)
        return self.head_b(features)


loss_fn = nn.MSELoss()


def mean_task_loss(model, data, task):
    """Average per-example MSE of ``model`` on ``(x, y)`` pairs, without gradients."""
    with torch.no_grad():
        return sum(loss_fn(model(x, task=task), y).item() for x, y in data) / len(data)


model = TinyModel()
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Phase 1: Pre-train on Task A (general language)
task_a_data = [(torch.randn(10), torch.tensor([1.0])) for _ in range(100)]
for x, y in task_a_data:
    loss = loss_fn(model(x, task="a"), y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

# Evaluate Task A performance before fine-tuning
task_a_loss_before = mean_task_loss(model, task_a_data[:20], task="a")
print(f"Task A loss before fine-tuning: {task_a_loss_before:.4f}")

# Phase 2: Fine-tune aggressively on Task B (narrow domain)
# High learning rate, many epochs — catastrophic forgetting conditions
task_b_data = [(torch.randn(10) * 0.01, torch.tensor([5.0])) for _ in range(50)]
# Note: Task B has a VERY different distribution (small inputs, large targets)
optimizer_ft = optim.SGD(model.parameters(), lr=0.5)  # much higher LR
for _ in range(50):
    for x, y in task_b_data:
        loss = loss_fn(model(x, task="b"), y)
        loss.backward()
        optimizer_ft.step()
        optimizer_ft.zero_grad()

# Evaluate Task A performance AFTER fine-tuning on Task B
task_a_loss_after = mean_task_loss(model, task_a_data[:20], task="a")
print(f"Task A loss after fine-tuning: {task_a_loss_after:.4f}")
if math.isfinite(task_a_loss_after):
    print(f"Forgetting ratio: {task_a_loss_after / task_a_loss_before:.1f}x worse")
else:
    # Plain SGD at lr=0.5 can blow up on this problem; a nan/inf loss means the
    # shared layer was destroyed outright rather than gradually overwritten.
    print(
        "Fine-tuning diverged (non-finite loss): Task A is completely destroyed. "
        "Lower the fine-tuning learning rate to watch gradual forgetting instead."
    )
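# Expected: Task A loss rises sharply after this aggressive fine-tuning. The exact ratio depends on the
# random seed, and at lr=0.5 plain SGD can even diverge to nan — catastrophic forgetting in its extreme form.
Why It Happens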
The mechanism is gradient descent. Each gradient step moves parameters toward minimizing loss on the current batch of data. The optimizer has no memory of what the original weights were "good for" — it simply follows the gradient downhill on your new loss surface.
# Visualizing why gradients overwrite old knowledge
import torch

# Two tasks pull one shared weight toward different optima:
#   Task A (general language) is best at w = 1.0
#   Task B (drug domain) is best at w = -2.0
# A single shared weight can't be both.
w = torch.tensor(1.0, requires_grad=True)  # pre-trained value (Task A optimal)


def task_a_loss(w):
    """Quadratic Task A loss; exactly zero at the Task A optimum w = 1.0."""
    return (w - 1.0) ** 2


def task_b_loss(w):
    """Quadratic Task B loss; exactly zero at the Task B optimum w = -2.0."""
    return (w + 2.0) ** 2


print(f"Initial w: {w.item():.2f}")
for label, loss_of in (("Task A", task_a_loss), ("Task B", task_b_loss)):
    # Expected: Task A 0.00 (perfect), Task B 9.00 (terrible)
    print(f"{label} loss at w=1.0: {loss_of(w).item():.2f}")

# Fine-tune on Task B only, with a high learning rate (10 steps). Task A's loss
# is not part of the objective, so nothing holds w near 1.0. For this quadratic,
# lr=0.5 lands exactly on -2.0 after the first step: w - 0.5 * 2 * (w + 2) = -2.
optimizer = torch.optim.SGD([w], lr=0.5)
for step in range(10):
    loss = task_b_loss(w)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

print("\nAfter fine-tuning on Task B:")
print(f"w: {w.item():.2f}")
for label, loss_of in (("Task A", task_a_loss), ("Task B", task_b_loss)):
    # Expected: Task A high (forgotten), Task B low (learned)
    print(f"{label} loss: {loss_of(w).item():.2f}")
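# w moved from 1.0 all the way to -2.0 (at lr=0.5 the first step lands exactly on the Task B optimum),
# so Task A loss went from 0.00 to 9.00 — Task A performance is broken
How to Detect Catastrophic Forgetting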
You need to benchmark the model on general tasks both before and after fine-tuning. If general performance drops significantly, forgetting has occurred.
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import json
def evaluate_general_qa(
    model_name_or_path: str,
    test_questions: list[dict],
    generate_fn=None,
) -> dict:
    """
    Evaluate a model on general-purpose questions to detect forgetting.

    test_questions format:
        [{"question": "...", "correct_answer": "...", "category": "..."}]

    An answer counts as correct when ``correct_answer`` appears
    (case-insensitively) anywhere in the generated answer. Keep reference
    answers short and unambiguous (e.g. "1945", not "It ended in 1945").

    Args:
        model_name_or_path: Hub id or local path given to the
            ``text-generation`` pipeline. Ignored when ``generate_fn`` is given.
        test_questions: Non-empty list of question dicts (format above).
        generate_fn: Optional callable mapping a prompt string to the model's
            answer text (without the prompt). When omitted, a greedy
            ``text-generation`` pipeline for ``model_name_or_path`` is used.
            This lets the same scoring run against any generation backend.

    Returns:
        dict with ``correct``, ``total``, ``by_category`` (per-category
        ``correct``/``total``/``accuracy``) and ``overall_accuracy``.

    Raises:
        ValueError: if ``test_questions`` is empty (checked before any model loads).
    """
    if not test_questions:
        raise ValueError("test_questions must contain at least one question")

    if generate_fn is None:
        generator = pipeline(
            "text-generation",
            model=model_name_or_path,
            device_map="auto",
            max_new_tokens=50,
        )

        def _pipeline_answer(prompt: str) -> str:
            # do_sample=False is greedy decoding, so temperature is irrelevant
            # (passing it only triggers a warning). return_full_text=False
            # returns just the continuation. We therefore never slice the
            # prompt back off, which breaks whenever decoding does not
            # reproduce the prompt character-for-character.
            output = generator(prompt, do_sample=False, return_full_text=False)
            return output[0]["generated_text"]

        generate_fn = _pipeline_answer

    results = {"correct": 0, "total": 0, "by_category": {}}
    for item in test_questions:
        prompt = f"Question: {item['question']}\nAnswer:"
        answer = generate_fn(prompt).strip().lower()
        is_correct = item["correct_answer"].lower() in answer

        results["correct"] += int(is_correct)
        results["total"] += 1
        cat_stats = results["by_category"].setdefault(
            item["category"], {"correct": 0, "total": 0}
        )
        cat_stats["correct"] += int(is_correct)
        cat_stats["total"] += 1

    results["overall_accuracy"] = results["correct"] / results["total"]
    for cat_stats in results["by_category"].values():
        cat_stats["accuracy"] = cat_stats["correct"] / cat_stats["total"]
    return results
# General knowledge test suite (abbreviated): one probe per category, so a
# collapse in any single skill shows up in the per-category breakdown.
_GENERAL_QA_ROWS = [
    # (category, question, correct_answer)
    ("geography", "What is the capital of France?", "Paris"),
    ("math", "What is 15 multiplied by 7?", "105"),
    ("literature", "Who wrote Hamlet?", "Shakespeare"),
    ("science", "What gas do plants absorb from the air?", "carbon dioxide"),
    ("history", "What year did World War II end?", "1945"),
]
general_qa_tests = [
    {"question": question, "correct_answer": answer, "category": category}
    for category, question, answer in _GENERAL_QA_ROWS
]
# Run before and after fine-tuning
# before_results = evaluate_general_qa("meta-llama/Llama-3.2-3B-Instruct", general_qa_tests)
# after_results = evaluate_general_qa("./my-fine-tuned-model", general_qa_tests)
# Compare
def compare_forgetting(before: dict, after: dict, threshold: float = 0.10) -> dict:
    """Identify if fine-tuning caused significant forgetting.

    Args:
        before: Evaluation result for the base model. It needs
            ``overall_accuracy`` and ``by_category[cat]["accuracy"]``, as
            produced by ``evaluate_general_qa``.
        after: The same structure for the fine-tuned model.
        threshold: Absolute drop in overall accuracy (0-1 scale) above which
            forgetting is reported. A drop exactly equal to the threshold is
            not reported.

    Returns:
        dict with rounded before/after accuracies, the overall drop, a
        ``catastrophic_forgetting_detected`` flag, per-category drops
        (only for categories present in both results) and a recommendation.
    """
    drop = before["overall_accuracy"] - after["overall_accuracy"]
    # Small suites give exact-fraction accuracies, but float subtraction does
    # not: 0.8 - 0.7 > 0.1 while 1.0 - 0.9 < 0.1. Round away representation
    # error so equal drops are always classified the same way.
    forgot = round(drop, 9) > threshold

    category_drops = {
        cat: round(stats["accuracy"] - after["by_category"][cat]["accuracy"], 3)
        for cat, stats in before["by_category"].items()
        if cat in after["by_category"]
    }

    return {
        "overall_accuracy_before": round(before["overall_accuracy"], 3),
        "overall_accuracy_after": round(after["overall_accuracy"], 3),
        "accuracy_drop": round(drop, 3),
        "catastrophic_forgetting_detected": forgot,
        "category_drops": category_drops,
        "recommendation": (
            "Reduce learning rate and epochs, or switch to LoRA" if forgot
            else "Fine-tuning preserved general capabilities"
        ),
    }


# Mitigation 1: Use LoRA
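LoRA is usually the simplest and most effective first mitigation. The original weights stay frozen, so they cannot be overwritten: only the small adapter matrices change, and removing the adapter restores the base model exactly. LoRA reduces forgetting rather than eliminating it. While the adapter is active it still changes the model's outputs, so keep evaluating general capabilities.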
from peft import get_peft_model, LoraConfig, TaskType
from transformers import AutoModelForCausalLM
import torch

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3B-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# The key property: base weights are frozen; only the injected low-rank
# adapters (here on the q_proj/v_proj attention projections) are trained.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.1,
    bias="none",
)
peft_model = get_peft_model(model, lora_config)

# Verify: base weights are frozen AND the adapters are actually trainable.
# Raise explicitly rather than `assert`, which is skipped under `python -O`.
unfrozen_base = [
    name
    for name, param in peft_model.named_parameters()
    if "lora" not in name and param.requires_grad
]
if unfrozen_base:
    raise RuntimeError(f"Base weights should be frozen: {unfrozen_base[:5]}")

trainable_lora = sum(
    param.numel()
    for name, param in peft_model.named_parameters()
    if "lora" in name and param.requires_grad
)
if trainable_lora == 0:
    raise RuntimeError("No trainable LoRA parameters found; check target_modules")

# Frozen base weights can never be overwritten by gradient updates, but the
# adapted model computes output = frozen_W(x) + adapter(x). A badly trained
# adapter can therefore still degrade general skills while it is active.
# What LoRA guarantees is reversibility: removing the adapter restores the
# original model exactly.
print(
    f"All base weights confirmed frozen ({trainable_lora:,} trainable LoRA parameters). "
    "The base model is recoverable by removing the adapter; still benchmark "
    "general skills with the adapter active."
)


# Mitigation 2: Lower Learning Rate and Fewer Epochs
from transformers import TrainingArguments

# Aggressive settings — HIGH forgetting risk: many epochs, a comparatively high
# LR, batch size 1 and no warmup, so the very first updates are full-size.
risky_config = dict(
    output_dir="./risky",
    num_train_epochs=10,
    learning_rate=5e-5,  # too high for fine-tuning large models
    per_device_train_batch_size=1,
    warmup_ratio=0.0,  # no warmup — sharp weight shifts early
)
risky_args = TrainingArguments(**risky_config)

# Safe settings for domain fine-tuning: fewer epochs, lower LR, an effective
# batch of 4 x 4 = 16 examples per device, warmup, cosine decay, weight decay.
safe_config = dict(
    output_dir="./safe",
    num_train_epochs=3,  # fewer epochs
    learning_rate=2e-5,  # lower LR
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    warmup_ratio=0.05,  # warmup — gradual weight shift at start
    lr_scheduler_type="cosine",  # cosine decay — LR drops toward end
    weight_decay=0.01,  # regularization
)
safe_args = TrainingArguments(**safe_config)
# Rule of thumb for starting learning rates (tune per model and dataset):
# Full fine-tuning: 1e-5 to 2e-5
# LoRA: 1e-4 to 3e-4
# QLoRA: 1e-4 to 2e-4
Mitigation 3: Replay Data
Mix a small fraction of general-purpose data into your fine-tuning dataset. Because the model keeps seeing general text, its updates are not driven by your domain alone, which limits how far the weights drift toward it.
from datasets import Dataset, concatenate_datasets
import random
def create_replay_dataset(
    domain_data: list[dict],
    general_data: list[dict],
    replay_fraction: float = 0.1,
    seed: int = 42,
) -> Dataset:
    """
    Mix domain-specific data with a random sample of general-purpose data.

    ``replay_fraction`` is measured against the size of ``domain_data``, not
    the mixed total. 0.1 adds one general example per ten domain examples,
    which is about 9% of the mixed dataset (0.1 / 1.1). The replay count is
    ``int(len(domain_data) * replay_fraction)``, so it rounds down: fewer
    than 1 / replay_fraction domain examples yields no replay at all. It is
    also capped at ``len(general_data)``.

    Args:
        domain_data: Domain examples; all are kept.
        general_data: Pool of general examples to sample replay data from.
        replay_fraction: Replay examples per domain example (must be >= 0).
        seed: Seed for sampling and shuffling; same seed -> same dataset.

    Returns:
        A shuffled ``Dataset`` of the domain examples plus the replay sample.
        The input lists are not modified.

    Raises:
        ValueError: if ``replay_fraction`` is negative.
    """
    if replay_fraction < 0:
        raise ValueError(f"replay_fraction must be >= 0, got {replay_fraction}")
    # A private RNG keeps results reproducible without reseeding the global
    # `random` module for the rest of the process. random.Random(seed) yields
    # the same stream as random.seed(seed), so the sampled mix is unchanged.
    rng = random.Random(seed)
    num_replay = int(len(domain_data) * replay_fraction)
    replay_sample = rng.sample(general_data, min(num_replay, len(general_data)))
    combined = domain_data + replay_sample  # new list; callers' lists untouched
    rng.shuffle(combined)
    return Dataset.from_list(combined)
# Example usage
drug_qa_data = [
    {"text": "Q: What is aspirin?\nA: Aspirin is an NSAID used for pain, fever..."},
    # ... hundreds more drug Q&A pairs
]
# General text from a public dataset (e.g., OpenHermes, Alpaca, etc.)
general_data = [
    {"text": "Q: What is the capital of Germany?\nA: Berlin."},
    {"text": "Q: Explain recursion.\nA: Recursion is when a function calls itself..."},
    # ... sample from general instruction-following data
]
mixed_dataset = create_replay_dataset(
    domain_data=drug_qa_data,
    general_data=general_data,
    # 0.15 general examples per domain example (~13% of the mix, 0.15 / 1.15).
    # With the abbreviated lists above, int(1 * 0.15) == 0, so no replay
    # examples are added until the lists are filled in.
    replay_fraction=0.15,
)
print(f"Total training examples: {len(mixed_dataset)}")
print(f"Domain examples: {len(drug_qa_data)}")
print(f"Replay examples: {len(mixed_dataset) - len(drug_qa_data)}")


# Mitigation 4: Elastic Weight Consolidation (EWC)
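EWC (Elastic Weight Consolidation, Kirkpatrick et al., 2017) is a research-grade technique that adds a regularization term to the loss. The term penalizes changes to weights that were important for previous tasks. Each weight's importance is estimated with the diagonal of the Fisher information, computed from gradients on data from the earlier task.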
import torch
import torch.nn as nn
from copy import deepcopy
class EWCRegularizer:
    """
    Elastic Weight Consolidation: penalizes changing weights
    that were important for the pre-trained task.

    Usage:
    1. Construct it with the pre-trained model; this snapshots the current weights.
    2. Call ``update_fisher`` on data representative of the skills to keep.
    3. At every training step, add ``ewc_loss(model)`` to the task loss.

    The penalty is ``importance * sum_i F_i * (theta_i - theta*_i) ** 2``.
    ``F`` is a diagonal empirical Fisher estimate: the mean squared gradient
    of each batch's loss.

    Note: Expensive to compute — requires a forward/backward pass
    over the pre-training task to estimate Fisher information.
    In practice, LoRA is usually the simpler first choice.
    """

    def __init__(self, model: nn.Module, importance: float = 5000.0):
        self.importance = importance
        self.original_params = {}
        self.fisher = {}
        # Snapshot the pre-trained values (theta*) that the penalty anchors to.
        for name, param in model.named_parameters():
            self.original_params[name] = param.detach().clone()
            self.fisher[name] = torch.zeros_like(param)

    def update_fisher(self, model: nn.Module, dataloader, num_batches: int = 100):
        """Estimate Fisher information from pre-training data.

        Replaces any previous estimate and uses at most ``num_batches``
        batches. Each batch is passed as ``model(**batch)``, and the result
        must expose ``.loss`` (e.g. a Hugging Face model called with labels).
        The model's train/eval mode is restored afterwards, and its gradients
        are cleared.

        Raises:
            ValueError: if ``dataloader`` yields no batches.
        """
        was_training = model.training
        model.eval()  # disable dropout so gradients reflect the deterministic model
        fisher = {name: torch.zeros_like(param) for name, param in model.named_parameters()}
        seen = 0
        try:
            for i, batch in enumerate(dataloader):
                if i >= num_batches:
                    break
                # backward() ADDS into .grad. Without clearing, batch k's
                # "gradient" would be the sum over batches 0..k.
                model.zero_grad()
                loss = model(**batch).loss
                loss.backward()
                for name, param in model.named_parameters():
                    if param.grad is not None:
                        fisher[name] += param.grad.detach() ** 2
                seen += 1
        finally:
            model.zero_grad()  # don't leak these gradients into the next optimizer step
            model.train(was_training)

        if seen == 0:
            raise ValueError("dataloader yielded no batches; cannot estimate Fisher information")
        # Normalize by the batches actually used, which may be fewer than num_batches.
        for name in fisher:
            fisher[name] /= seen
        self.fisher = fisher

    def ewc_loss(self, model: nn.Module) -> torch.Tensor:
        """Compute EWC penalty term (differentiable w.r.t. the model's parameters)."""
        penalty = None
        for name, param in model.named_parameters():
            if name not in self.fisher:
                continue
            diff = param - self.original_params[name].to(param.device)
            term = (self.fisher[name].to(param.device) * diff ** 2).sum()
            # Out-of-place accumulation that follows the parameters' device.
            # An in-place add into a CPU tensor breaks once parameters live on GPU.
            penalty = term if penalty is None else penalty + term.to(penalty.device)
        if penalty is None:
            return torch.tensor(0.0)
        return self.importance * penalty
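# Setup (once, before fine-tuning):
# ewc = EWCRegularizer(model)
# ewc.update_fisher(model, general_dataloader)
# In training loop:
# total_loss = task_loss + ewc.ewc_loss(model)
# total_loss.backward()
Detection Checklist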
Run this checklist before deploying a fine-tuned model:
def _mentions_number(text: str, number: str) -> bool:
    """True if ``number`` appears in ``text`` as a whole number (so "100" does not count as "10")."""
    import re

    return re.search(rf"(?<!\d){re.escape(number)}(?!\d)", text) is not None


def _is_json_object_with_keys(text: str, keys) -> bool:
    """True if the span from the first '{' to the last '}' parses as a JSON object containing ``keys``.

    Tolerates prose or Markdown code fences around the object.
    """
    import json

    start, end = text.find("{"), text.rfind("}")
    if start == -1 or end < start:
        return False
    try:
        obj = json.loads(text[start:end + 1])
    except json.JSONDecodeError:
        return False
    return isinstance(obj, dict) and all(key in obj for key in keys)


# Each entry: a prompt, a cheap heuristic check on the raw response text,
# and how serious a failure is. Heuristics are deliberately simple; treat a
# failure as a signal to inspect the response, not as proof of forgetting.
FORGETTING_CHECKLIST = [
    {
        "test": "Basic instruction following",
        "prompt": "List three countries in Europe.",
        "check": lambda r: len([w for w in r.split() if w.istitle()]) >= 3,
        "severity": "critical",
    },
    {
        "test": "Multi-step reasoning",
        "prompt": "If I have 12 apples and give away 5, then buy 3 more, how many do I have?",
        "check": lambda r: _mentions_number(r, "10"),
        "severity": "critical",
    },
    {
        "test": "Language coherence",
        "prompt": "Write one sentence about the weather.",
        # strip(): generations often end with a newline or trailing spaces.
        "check": lambda r: len(r.split()) >= 5 and r.strip().endswith((".", "!", "?")),
        "severity": "high",
    },
    {
        "test": "Format compliance",
        "prompt": "Write a JSON object with keys 'name' and 'age'.",
        "check": lambda r: _is_json_object_with_keys(r, ("name", "age")),
        "severity": "medium",
    },
]
def run_forgetting_checklist(generate_fn, checklist=None) -> dict:
    """Run forgetting detection checklist. Pass a function that takes a prompt and returns text.

    Args:
        generate_fn: Callable mapping a prompt string to the model's response text.
        checklist: List of items with "test", "prompt", "check" (callable on the
            response) and "severity". Defaults to ``FORGETTING_CHECKLIST``.

    Returns:
        dict with:
        - ``passed`` / ``failed``: counts.
        - ``failures``: test name, severity and the first 200 characters of
          each failing response.
        - ``score``: fraction of checks passed.
        - ``critical_failures``: number of failed checks with severity "critical".

    Raises:
        ValueError: if ``checklist`` is empty.
    """
    if checklist is None:
        checklist = FORGETTING_CHECKLIST
    if not checklist:
        raise ValueError("checklist must contain at least one item")

    results = {"passed": 0, "failed": 0, "failures": []}
    for item in checklist:
        response = generate_fn(item["prompt"])
        if item["check"](response):
            results["passed"] += 1
            continue
        results["failed"] += 1
        results["failures"].append({
            "test": item["test"],
            "severity": item["severity"],
            "response": response[:200],
        })

    results["score"] = results["passed"] / len(checklist)
    # A single critical failure should block deployment regardless of score.
    results["critical_failures"] = sum(
        1 for failure in results["failures"] if failure["severity"] == "critical"
    )
    return results


# Summary
Catastrophic forgetting is a predictable, measurable, and preventable problem. The hierarchy of mitigations, from most to least recommended:
- Use LoRA — base weights stay frozen, so gradient updates cannot overwrite them, and removing the adapter restores the original model exactly. The adapted model can still drift, so keep evaluating it.
- Lower learning rate — weight changes are smaller, less drift from pre-trained state
- Fewer epochs — fewer updates, less opportunity to overwrite original weights
- Replay data — the model keeps seeing general text during fine-tuning
- EWC regularization — penalizes changes to important weights
For most projects, LoRA with a modest learning rate is enough; add replay data if your evaluations still show a drop. Evaluate general capabilities before and after, and you will have quantitative evidence that your fine-tuned model is specialized without being broken.