LoRA Explained
A deep dive into Low-Rank Adaptation — the math behind it, what rank and alpha control, which layers to target, and a full working example with the PEFT library on Llama 3.
LoRA (Low-Rank Adaptation) is the dominant parameter-efficient fine-tuning (PEFT) method for LLMs. It is conceptually elegant, computationally cheap, and produces adapters small enough to share and swap like software plugins. Understanding it thoroughly will make you a better AI engineer.
The Core Idea
A pre-trained weight matrix W has shape (d_out × d_in). Full fine-tuning computes a gradient for every element of W and updates all of them. For a 4096 × 4096 attention projection, that is about 16.8 million parameters per matrix, repeated across all 32 layers.
LoRA's insight: the change to the weight matrix (the delta) does not need to be full rank. The LoRA paper hypothesizes that the update learned during adaptation has a low intrinsic rank, so it can be captured in a low-dimensional subspace. So instead of learning a full (d_out × d_in) update, you learn two small matrices:
- A: shape (r × d_in) — the "down-projection"
- B: shape (d_out × r) — the "up-projection"
The effective weight update is: delta_W = B × A
Where r is the rank, typically 4 to 64. For r=16 and a 4096 × 4096 matrix:
Full update: 4096 × 4096 = 16,777,216 parameters
LoRA update: (4096 × 16) + (16 × 4096) = 131,072 parameters
Compression: 128× fewer parameters to train
The Math
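Before building a full layer, the arithmetic and the factorization above can be checked numerically. A minimal sketch in PyTorch, assuming the 4096 × 4096 shape and r = 16 from the text for the parameter counts; the small 64 × 64 matrices in the second half are arbitrary choices for illustration. It confirms that applying A and then B to an input gives the same result as applying the materialized delta_W = B @ A.

```python
import torch

# Shapes from the text: a 4096 x 4096 projection adapted with rank r = 16
d_out, d_in, r = 4096, 4096, 16

full_params = d_out * d_in           # every element of delta_W
lora_params = r * d_in + d_out * r   # A is (r x d_in), B is (d_out x r)
compression = full_params // lora_params
print(f"full: {full_params:,}  lora: {lora_params:,}  compression: {compression}x")
# full: 16,777,216  lora: 131,072  compression: 128x

# Numerical check on small random matrices (64 x 64 is arbitrary):
# applying A then B equals applying delta_W = B @ A directly
torch.manual_seed(0)
A = torch.randn(r, 64, dtype=torch.float64)   # down-projection, (r x d_in)
B = torch.randn(64, r, dtype=torch.float64)   # up-projection, (d_out x r)
x = torch.randn(8, 64, dtype=torch.float64)   # batch of 8 inputs

delta_W = B @ A                    # materialized (64 x 64) update
via_delta = x @ delta_W.T          # uses the full matrix
via_factors = (x @ A.T) @ B.T      # never builds delta_W: two thin matmuls
print(torch.allclose(via_delta, via_factors))  # True
```

The factored path is what LoRA actually runs during training: it never stores a full d_out × d_in gradient or update.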
import math

import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """
    A linear layer with LoRA adaptation.

    Forward pass computes:
        y = x @ W.T + (x @ A.T @ B.T) * (alpha / r)
    where:
        W is the original frozen weight
        A is shape (r, d_in), randomly initialized
        B is shape (d_out, r), initialized to zeros (so delta starts at 0)
        alpha / r is the scaling factor
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int = 16,
        alpha: float = 32.0,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        self.scaling = alpha / rank  # typical: alpha = 2 * rank, so scaling = 2.0

        # Original weight: frozen (in practice, copied from the pre-trained model)
        self.weight = nn.Parameter(
            torch.empty(out_features, in_features),
            requires_grad=False,
        )
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

        # LoRA matrices: only these are trained
        self.lora_A = nn.Parameter(torch.empty(rank, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))  # zeros: delta starts at 0
        # A initialized with Kaiming uniform, the same scheme nn.Linear uses for its weights
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))

        self.lora_dropout = nn.Dropout(p=dropout) if dropout > 0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Frozen base output
        base = nn.functional.linear(x, self.weight)
        # LoRA delta output: x -> A (down-projection) -> B (up-projection), then scale
        lora_input = self.lora_dropout(x)
        delta = nn.functional.linear(
            nn.functional.linear(lora_input, self.lora_A),
            self.lora_B,
        ) * self.scaling
        return base + delta

    @torch.no_grad()  # merging is a one-off weight edit, not part of training
    def merge_weights(self) -> nn.Linear:
        """
        At inference time, fold LoRA into the base weight.
        Result: the same output as this layer in eval mode (dropout off), up to
        floating-point rounding, with no extra computation per forward pass.
        """
        merged_weight = self.weight + (self.lora_B @ self.lora_A) * self.scaling
        merged = nn.Linear(self.in_features, self.out_features, bias=False)
        merged.weight = nn.Parameter(merged_weight)
        return merged


# Verify the math: a freshly initialized LoRA layer behaves exactly like the frozen base layer
layer = LoRALinear(512, 512, rank=16, alpha=32)
x = torch.randn(4, 512)
with torch.no_grad():
    output = layer(x)
    base_only = nn.functional.linear(x, layer.weight)
# Initially, B is zeros -> delta = 0 -> output == base_only
assert torch.allclose(output, base_only, atol=1e-5), "Initial LoRA output should match the base layer"
print("Verified: initial LoRA output equals base model output")
Rank: The Most Important Hyperparameter
Rank r controls the expressiveness of the LoRA update. Lower rank means fewer trainable parameters and a simpler update; higher rank means more parameters and a more expressive update.
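A quick way to see what rank limits: whatever values A and B take, their product B @ A has matrix rank at most r, so r caps how many independent directions the update can change. A minimal sketch with random factors standing in for trained ones; the 256 × 256 size and the list of ranks are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)
d_out, d_in = 256, 256

ranks_seen = {}
for r in [1, 4, 16, 64]:
    A = torch.randn(r, d_in, dtype=torch.float64)   # down-projection
    B = torch.randn(d_out, r, dtype=torch.float64)  # up-projection
    delta_W = B @ A                                 # (256 x 256) update
    rank = int(torch.linalg.matrix_rank(delta_W))
    ranks_seen[r] = rank
    # Random Gaussian factors are full rank, so the product reaches the cap of r
    print(f"r={r:2d}: rank(B @ A) = {rank} (at most {r}, out of a possible {min(d_out, d_in)})")
```

Even at r = 64, the update spans only a quarter of the directions a full 256 × 256 update could use; that ceiling is the expressiveness tradeoff rank controls.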
def lora_parameter_count(d_model: int, rank: int, num_target_layers: int) -> dict:
    """
    Calculate total trainable parameters for LoRA.

    For each targeted linear layer:
      - lora_A: rank × d_in parameters
      - lora_B: d_out × rank parameters
    Assumes square projections (d_in == d_out == d_model). Under grouped-query
    attention (as in Llama 3), k_proj and v_proj are narrower, so this
    overestimates their cost.
    """
    params_per_layer = rank * d_model + d_model * rank  # A + B
    total_lora_params = params_per_layer * num_target_layers
    # Approximate size of Llama 3 8B (32 decoder layers)
    total_base_params = 8_000_000_000
    return {
        "rank": rank,
        "params_per_targeted_layer": params_per_layer,
        "total_lora_params": total_lora_params,
        "base_model_params": total_base_params,
        "trainable_fraction": f"{total_lora_params / total_base_params:.4%}",
        "lora_memory_mb": round(total_lora_params * 2 / (1024 ** 2), 1),  # 2 bytes/param (bfloat16)
    }


# d_model=4096 (Llama 3 8B), 32 layers × 2 target matrices (q_proj, v_proj) = 64 targeted layers
for rank in [4, 8, 16, 32, 64]:
    result = lora_parameter_count(4096, rank, num_target_layers=64)
    print(
        f"r={rank:2d}: {result['total_lora_params']:>10,} trainable params "
        f"({result['trainable_fraction']}) | {result['lora_memory_mb']} MB"
    )
# r= 4:  2,097,152 trainable params (0.0262%) | 4.0 MB
# r= 8:  4,194,304 trainable params (0.0524%) | 8.0 MB
# r=16:  8,388,608 trainable params (0.1049%) | 16.0 MB
# r=32: 16,777,216 trainable params (0.2097%) | 32.0 MB
# r=64: 33,554,432 trainable params (0.4194%) | 64.0 MB
Guidance for rank selection:
| Task complexity | Recommended rank |
|---|---|
| Simple style transfer (tone, format) | 4–8 |
| Domain knowledge adaptation | 16 |
| Complex reasoning in new domain | 32 |
| Near-full-fine-tune quality needed | 64–128 |
Alpha: The Scaling Factor
Alpha sets the scaling factor alpha / r that multiplies the LoRA branch, and therefore controls the magnitude of the LoRA update relative to the base model output.
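Because the scaling factor multiplies the whole LoRA branch, the adapter's contribution grows linearly with alpha / r for fixed A and B. A small sketch that makes this concrete; the random A and the small nonzero B are stand-ins for trained factors, and the shapes are arbitrary.

```python
import torch

torch.manual_seed(0)
d, r = 128, 16
A = torch.randn(r, d, dtype=torch.float64)
B = torch.randn(d, r, dtype=torch.float64) * 0.01  # small nonzero B, as after some training
x = torch.randn(4, d, dtype=torch.float64)


def lora_delta(alpha: float) -> torch.Tensor:
    """LoRA branch output for this fixed A, B: (x @ A.T @ B.T) * (alpha / r)."""
    return (x @ A.T @ B.T) * (alpha / r)


norms = {alpha: lora_delta(alpha).norm().item() for alpha in (16, 32, 64)}
for alpha, n in norms.items():
    print(f"alpha={alpha:2d} (scale {alpha / r:.1f}): ||delta|| = {n:.4f}")
```

Doubling alpha at a fixed rank doubles the size of the LoRA contribution, which is why a larger alpha / r ratio pushes the model further from its base behavior.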
# The scaling formula: effective_lora_weight = (alpha / rank) * B @ A
# Common convention: set alpha = 2 * rank,
# so the scaling factor is 2.0 regardless of rank
examples = [
    (4, 8, 2.0),
    (8, 16, 2.0),
    (16, 32, 2.0),
    (32, 64, 2.0),
    (16, 16, 1.0),  # alpha == rank → scaling = 1.0 (weaker update)
    (16, 64, 4.0),  # alpha = 4 * rank → scaling = 4.0 (stronger update)
]
for rank, alpha, expected_scale in examples:
    scale = alpha / rank
    print(f"rank={rank:2d}, alpha={alpha:2d} → scale={scale:.1f} (expected {expected_scale:.1f})")

# Intuition (rules of thumb, not guarantees):
# - Higher alpha/rank ratio → larger LoRA contribution → model drifts further from base
# - If your task is very different from pre-training: try a higher alpha
# - If you want subtle adaptation: keep alpha == rank
Which Layers to Target
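Not all layers need LoRA. The original LoRA paper found that, for a fixed parameter budget, adapting the attention query and value projections gave a strong quality-per-parameter tradeoff. Later work such as QLoRA reported that applying LoRA to all linear layers is needed to match full fine-tuning quality, so the right choice depends on how much capacity your task needs.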
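Conceptually, "targeting" a layer means replacing each matching nn.Linear with a LoRA-wrapped version and freezing everything else. PEFT's target_modules matches module names; the sketch below is a simplified, hypothetical stand-in for that mechanism (an exact match on the last component of the module name) applied to a toy two-block model. It is not PEFT's implementation: ToyBlock, LoRAWrapper, and inject_lora are names invented for illustration, and ToyBlock's forward is not real attention.

```python
import math

import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Hypothetical block with Llama-style projection names (not real attention)."""

    def __init__(self, d: int):
        super().__init__()
        self.q_proj = nn.Linear(d, d, bias=False)
        self.k_proj = nn.Linear(d, d, bias=False)
        self.v_proj = nn.Linear(d, d, bias=False)
        self.o_proj = nn.Linear(d, d, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Chains the projections so the block is callable
        return self.o_proj(self.q_proj(x) + self.k_proj(x) + self.v_proj(x))


class LoRAWrapper(nn.Module):
    """Wraps a frozen nn.Linear and adds a trainable (alpha / r) * B @ A branch."""

    def __init__(self, base: nn.Linear, r: int, alpha: float):
        super().__init__()
        self.base = base
        self.lora_A = nn.Parameter(torch.empty(r, base.in_features))
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, r))  # delta starts at 0
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling


def inject_lora(model: nn.Module, target_modules, r: int = 8, alpha: float = 16.0) -> nn.Module:
    """Freeze all weights, then wrap every nn.Linear whose short name is a target."""
    for p in model.parameters():
        p.requires_grad = False
    # Collect first, then replace, so the module tree is not mutated while iterating
    to_wrap = [
        (name, module) for name, module in model.named_modules()
        if isinstance(module, nn.Linear) and name.split(".")[-1] in target_modules
    ]
    for name, module in to_wrap:
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child_name, LoRAWrapper(module, r, alpha))
    return model


torch.manual_seed(0)
model = nn.Sequential(ToyBlock(64), ToyBlock(64))
x = torch.randn(2, 64)
with torch.no_grad():
    before = model(x)

inject_lora(model, target_modules={"q_proj", "v_proj"}, r=8)
with torch.no_grad():
    after = model(x)  # unchanged, because every lora_B starts at zero

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable: {trainable:,} / {total:,}")  # trainable: 4,096 / 36,864
```

Adding "k_proj" or "o_proj" to target_modules in this sketch wraps more layers and raises the trainable count, mirroring the minimal / attention-only / full configurations compared next.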
from transformers import AutoModelForCausalLM
import torch


def inspect_model_linear_layers(model_name: str) -> dict:
    """List all linear layer names and shapes to decide LoRA targets.

    Note: this loads the full model on CPU (about 16 GB for an 8B model in bfloat16).
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )
    linear_layers = {}
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            layer_type = name.split(".")[-1]
            if layer_type not in linear_layers:
                linear_layers[layer_type] = {
                    "count": 0,
                    "shape": (module.out_features, module.in_features),
                }
            linear_layers[layer_type]["count"] += 1
    return linear_layers


# For Llama 3 8B, the per-layer linear projections are (lm_head omitted):
llama_layers = {
    "q_proj": {"count": 32, "shape": (4096, 4096)},
    "k_proj": {"count": 32, "shape": (1024, 4096)},  # GQA: fewer K heads
    "v_proj": {"count": 32, "shape": (1024, 4096)},  # GQA: fewer V heads
    "o_proj": {"count": 32, "shape": (4096, 4096)},
    "gate_proj": {"count": 32, "shape": (14336, 4096)},
    "up_proj": {"count": 32, "shape": (14336, 4096)},
    "down_proj": {"count": 32, "shape": (4096, 14336)},
}

# Common LoRA target configurations:
configurations = {
    "minimal (q+v only)": ["q_proj", "v_proj"],
    "attention only": ["q_proj", "k_proj", "v_proj", "o_proj"],
    "full (attention + FFN)": ["q_proj", "k_proj", "v_proj", "o_proj",
                               "gate_proj", "up_proj", "down_proj"],
}

rank = 16
for config_name, target_modules in configurations.items():
    params = sum(
        llama_layers[m]["count"] * (
            llama_layers[m]["shape"][0] * rank    # lora_B: d_out × r
            + rank * llama_layers[m]["shape"][1]  # lora_A: r × d_in
        )
        for m in target_modules if m in llama_layers
    )
    print(f"{config_name}: {params:>12,} params ({params / 8e9:.4%} of 8B)")
# minimal (q+v only):    6,815,744 params (0.0852% of 8B)
# attention only:   13,631,488 params (0.1704% of 8B)
# full (attention + FFN):   41,943,040 params (0.5243% of 8B)
Full Working Example: Apply LoRA to Llama 3
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)
from peft import LoraConfig, TaskType, get_peft_model
from trl import SFTConfig, SFTTrainer
from datasets import Dataset
import torch

# ─── 1. Load base model ───────────────────────────────────────────────────────
model_name = "meta-llama/Llama-3.2-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"  # important for causal LM training

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model.config.use_cache = False  # the KV cache is unused in training and conflicts with gradient checkpointing

# ─── 2. Define LoRA configuration ────────────────────────────────────────────
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,  # rank
    lora_alpha=32,  # scaling = 32/16 = 2.0
    target_modules=[  # Llama 3 attention projections
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
    ],
    lora_dropout=0.05,
    bias="none",
    use_rslora=False,  # set True for rank-stabilized LoRA
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# For Llama 3.2 3B (28 layers, hidden size 3072, 8 KV heads of dim 128), expect
# 9,175,040 trainable params, roughly 0.28% of about 3.22B total

# ─── 3. Prepare dataset ──────────────────────────────────────────────────────
def format_drug_example(example: dict) -> dict:
    """Apply the Llama 3 chat template to a drug Q&A example."""
    messages = [
        {"role": "system", "content": "You are a pharmaceutical drug information specialist. Provide accurate, evidence-based drug information."},
        {"role": "user", "content": example["question"]},
        {"role": "assistant", "content": example["answer"]},
    ]
    formatted = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False,
    )
    return {"text": formatted}


# Sample data (in production: load from your JSONL file)
raw_data = [
    {
        "question": "What is the mechanism of action of metformin?",
        "answer": (
            "Metformin is a biguanide that works primarily by activating AMP-activated "
            "protein kinase (AMPK) in hepatocytes. This suppresses hepatic gluconeogenesis, "
            "reducing fasting blood glucose. Secondary effects include improved insulin "
            "sensitivity in peripheral tissues and modest reduction in intestinal glucose absorption."
        ),
    },
    {
        "question": "What are the contraindications for metformin?",
        "answer": (
            "Metformin is contraindicated in: (1) severe renal impairment (eGFR below "
            "30 mL/min/1.73m²) due to risk of lactic acidosis; (2) acute or chronic metabolic "
            "acidosis, including diabetic ketoacidosis; (3) known hypersensitivity to metformin. "
            "Starting metformin is not recommended when eGFR is between 30 and 45. It should be "
            "stopped at the time of, or before, iodinated contrast imaging in at-risk patients and "
            "restarted only if renal function is stable when re-checked 48 hours later. Avoid "
            "use in patients with hepatic impairment."
        ),
    },
    # Add hundreds more examples
]

dataset = Dataset.from_list([format_drug_example(ex) for ex in raw_data])
split = dataset.train_test_split(test_size=0.1, seed=42)

# ─── 4. Training arguments ───────────────────────────────────────────────────
# SFTConfig extends TrainingArguments with SFT-specific options
training_args = SFTConfig(
    output_dir="./llama3-drug-lora",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,  # effective batch = 16 per device
    learning_rate=2e-4,  # LoRA typically tolerates a higher LR than full fine-tuning
    warmup_ratio=0.05,
    lr_scheduler_type="cosine",
    bf16=True,
    gradient_checkpointing=True,
    # Non-reentrant checkpointing lets gradients reach the LoRA weights even though
    # the frozen input embeddings do not require grad
    gradient_checkpointing_kwargs={"use_reentrant": False},
    logging_steps=10,
    save_strategy="epoch",
    eval_strategy="epoch",  # named evaluation_strategy in older transformers releases
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    report_to="none",
    dataloader_pin_memory=False,
    dataset_text_field="text",
    max_length=1024,  # named max_seq_length in older TRL releases
    packing=False,
)

# ─── 5. Train ────────────────────────────────────────────────────────────────
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=split["train"],
    eval_dataset=split["test"],
    processing_class=tokenizer,  # named tokenizer= in older TRL releases
)
trainer.train()

# ─── 6. Save LoRA adapter only (small!) ──────────────────────────────────────
model.save_pretrained("./llama3-drug-lora-adapter")
tokenizer.save_pretrained("./llama3-drug-lora-adapter")
# Adapter folder will contain:
#   adapter_config.json (~1 KB)
#   adapter_model.safetensors (9,175,040 params for r=16 on q,k,v,o:
#     ~18 MB if stored in bf16, ~37 MB in fp32, depending on the adapter dtype)
Loading and Using the LoRA Adapter
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load base model
base_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3B-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")

# Load LoRA adapter on top
model = PeftModel.from_pretrained(base_model, "./llama3-drug-lora-adapter")

# Option A: keep as a PEFT model (adapter applied at runtime): skip the next line
# Option B: merge the adapter into the weights for faster inference
model = model.merge_and_unload()  # returns a standard HF model with merged weights


# Generate
def ask_drug_question(question: str) -> str:
    messages = [
        # Same system prompt as in training, so inference matches the fine-tuning format
        {"role": "system", "content": "You are a pharmaceutical drug information specialist. Provide accurate, evidence-based drug information."},
        {"role": "user", "content": question},
    ]
    input_ids = tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=True,
    ).to(model.device)
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            attention_mask=torch.ones_like(input_ids),  # single unpadded sequence
            max_new_tokens=300,
            temperature=0.1,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Drop the prompt tokens and decode only the newly generated answer
    new_tokens = output_ids[0][input_ids.shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)


answer = ask_drug_question("Explain the drug-drug interaction between warfarin and aspirin.")
print(answer)
RSLoRA: An Improvement on Standard LoRA
Standard LoRA scales the adapter output by alpha/r, so with alpha held fixed, doubling the rank halves the scaling factor; keeping the effective scale constant requires adjusting alpha alongside r. RSLoRA (Rank-Stabilized LoRA) scales by alpha/sqrt(r) instead, which its authors argue keeps training stable as rank grows and lets higher ranks actually improve results.
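To see the difference numerically, compare the two scaling rules (alpha / r for standard LoRA, alpha / sqrt(r) for RSLoRA) at a fixed alpha across ranks. This only evaluates the formulas; it says nothing about training dynamics by itself. alpha = 32 is an arbitrary choice that matches the configs that follow.

```python
import math

alpha = 32
scales = {}
for r in [4, 8, 16, 32, 64, 128]:
    standard = alpha / r            # standard LoRA scaling
    rslora = alpha / math.sqrt(r)   # rank-stabilized scaling
    scales[r] = (standard, rslora)
    print(f"r={r:3d}: standard={standard:6.3f}  rslora={rslora:6.3f}")
```

From r = 4 to r = 64 the standard factor drops 16-fold (8.0 to 0.5), while the RSLoRA factor drops only 4-fold (16.0 to 4.0).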
from peft import LoraConfig

# Standard LoRA scaling: alpha / rank
standard_lora = LoraConfig(
    r=16,
    lora_alpha=32,
    use_rslora=False,  # scaling = 32 / 16 = 2.0
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)

# RSLoRA: alpha / sqrt(rank); the gap from standard LoRA widens as rank grows
rslora = LoraConfig(
    r=64,
    lora_alpha=32,
    use_rslora=True,  # scaling = 32 / sqrt(64) = 4.0 (standard LoRA would give 32 / 64 = 0.5)
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)
# With RSLoRA, you can use high ranks without the adapter's contribution shrinking toward zero
Summary
LoRA is elegant because it reduces a high-dimensional weight update to a product of two low-rank matrices. The key parameters:
- r (rank): Start with 16. Increase to 32 or 64 for complex tasks.
- alpha: Set to 2× rank as a default. Adjust if the model over- or under-adapts.
- target_modules: Start with q_proj and v_proj. Add k_proj and o_proj for more capacity. Add FFN layers (gate_proj, up_proj, down_proj) for the most capacity.
- dropout: Use 0.05 to 0.1 for small datasets, 0.0 for large ones.
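The adapter is small (tens of megabytes for the configurations above, for example about 18 MB in bf16 for r=16 on the attention projections of Llama 3.2 3B), loads in seconds, and can be merged into the base model at inference time so it adds no extra computation per token.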
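Merging works because of a simple identity: W x + s·(B A x) = (W + s·B A) x, so after folding the scaled delta into W once, a single matmul reproduces the two-branch output (up to floating-point rounding, with dropout disabled). A minimal numerical check in PyTorch; the shapes are arbitrary and the small random A and B stand in for a trained adapter.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_in, d_out, r, alpha = 256, 128, 16, 32.0
scaling = alpha / r

W = torch.randn(d_out, d_in, dtype=torch.float64)      # frozen base weight
A = torch.randn(r, d_in, dtype=torch.float64) * 0.01   # stand-in for trained lora_A
B = torch.randn(d_out, r, dtype=torch.float64) * 0.01  # stand-in for trained lora_B
x = torch.randn(8, d_in, dtype=torch.float64)

# Unmerged: base branch plus LoRA branch on every forward pass
unmerged = F.linear(x, W) + F.linear(F.linear(x, A), B) * scaling

# Merged: fold the scaled delta into the base weight once, then one matmul
W_merged = W + (B @ A) * scaling
merged = F.linear(x, W_merged)

max_diff = (unmerged - merged).abs().max().item()
print(f"max abs difference: {max_diff:.2e}")
```

In float64 the difference is at rounding level; in bf16 it is larger but still small, which is why merged and unmerged adapters can produce slightly different generations.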