Learnixo
AI Systems · intermediate

LoRA Explained

A deep dive into Low-Rank Adaptation — the math behind it, what rank and alpha control, which layers to target, and a full working example with the PEFT library on Llama 3.

Asma Hafeez Khan · May 15, 2026 · 10 min read
LoRA · Fine-Tuning · PEFT · LLM · Llama · Attention

LoRA (Low-Rank Adaptation) is the dominant parameter-efficient fine-tuning (PEFT) method for LLMs. It is conceptually elegant, computationally cheap, and produces adapters small enough to share and swap like software plugins. Understanding it thoroughly will make you a better AI engineer.


The Core Idea

A pre-trained weight matrix W has shape (d_out × d_in). Full fine-tuning computes a gradient for every element of W and updates all of them. For a 4096 × 4096 attention projection, that is about 16.8 million parameters per layer, repeated across all 32 layers of a model like Llama 3 8B.

LoRA's key hypothesis: the change to the weight matrix during fine-tuning (the delta) does not need to be full rank; it can be well approximated within a low-dimensional subspace. So instead of learning a full (d_out × d_in) update, you learn two small matrices:

  • A: shape (r × d_in) — the "down-projection"
  • B: shape (d_out × r) — the "up-projection"

The effective weight update is: delta_W = B × A

Where r is the rank, typically 4 to 64. For r=16 and a 4096 × 4096 matrix:

Full update: 4096 × 4096 = 16,777,216 parameters
LoRA update: (4096 × 16) + (16 × 4096) = 131,072 parameters
Compression: 128× fewer parameters to train
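Both claims above can be checked numerically. The sketch below (small illustrative shapes, chosen arbitrarily rather than taken from any model) confirms that a product B × A has rank at most r, then reproduces the 4096 × 4096, r=16 parameter arithmetic. The helper function names are hypothetical.

```python
import torch

# 1) B @ A can never exceed rank r (illustrative small shapes, float64 for a clean rank test)
torch.manual_seed(0)
d_out, d_in, r = 64, 48, 4
B = torch.randn(d_out, r, dtype=torch.float64)   # up-projection (d_out x r)
A = torch.randn(r, d_in, dtype=torch.float64)    # down-projection (r x d_in)
delta_W = B @ A                                   # full-size update (d_out x d_in)
update_rank = torch.linalg.matrix_rank(delta_W).item()
print(tuple(delta_W.shape), update_rank)          # (64, 48) 4

# 2) Parameter counts for one 4096 x 4096 projection at r=16
def full_update_params(d_out: int, d_in: int) -> int:
    """Parameters in a dense update of a (d_out x d_in) weight."""
    return d_out * d_in

def lora_update_params(d_out: int, d_in: int, r: int) -> int:
    """Parameters in A (r x d_in) plus B (d_out x r)."""
    return r * d_in + d_out * r

full = full_update_params(4096, 4096)
lora = lora_update_params(4096, 4096, r=16)
print(f"{full:,} vs {lora:,} -> {full // lora}x fewer")  # 16,777,216 vs 131,072 -> 128x fewer
```

For a square d × d matrix the ratio works out to d / (2r), so halving the rank doubles the savings.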

The Math

Python
import torch
import torch.nn as nn
import math

class LoRALinear(nn.Module):
    """
    A linear layer with LoRA adaptation.

    Forward pass computes:
      y = x @ W.T + (x @ A.T @ B.T) * (alpha / r)

    where:
      W is the original frozen weight
      A is shape (r, d_in)   — randomly initialized
      B is shape (d_out, r)  — initialized to zeros (so delta starts at 0)
      alpha / r is the scaling factor
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int = 16,
        alpha: float = 32.0,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        self.scaling = alpha / rank  # typical: alpha=2*rank so scaling=2.0

        # Original weight: frozen after loading from pre-trained model
        self.weight = nn.Parameter(
            torch.empty(out_features, in_features),
            requires_grad=False,
        )
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

        # LoRA matrices: only these are trained
        self.lora_A = nn.Parameter(
            torch.empty(rank, in_features)
        )
        self.lora_B = nn.Parameter(
            torch.zeros(out_features, rank)   # zeros: delta starts at 0
        )

        # A uses Kaiming-uniform init, the same default nn.Linear applies to its weight
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))

        self.lora_dropout = nn.Dropout(p=dropout) if dropout > 0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Frozen base output
        base = nn.functional.linear(x, self.weight)

        # LoRA delta output
        lora_input = self.lora_dropout(x)
        delta = nn.functional.linear(
            nn.functional.linear(lora_input, self.lora_A),
            self.lora_B
        ) * self.scaling

        return base + delta

    @torch.no_grad()
    def merge_weights(self) -> nn.Linear:
        """
        At inference time, merge LoRA into the base weight.
        Result: the same output as forward() in eval mode (dropout off, up to
        floating-point rounding), with no extra matmuls per call.
        """
        merged_weight = self.weight + (self.lora_B @ self.lora_A) * self.scaling
        merged = nn.Linear(self.in_features, self.out_features, bias=False)
        merged.weight = nn.Parameter(merged_weight, requires_grad=False)
        return merged

# Verify the math: a freshly initialized LoRA layer reproduces the frozen base layer
layer = LoRALinear(512, 512, rank=16, alpha=32)
x = torch.randn(4, 512)

with torch.no_grad():
    output = layer(x)
    base_only = nn.functional.linear(x, layer.weight)
    # B starts as all zeros, so delta = 0 and output == base_only
    assert torch.allclose(output, base_only, atol=1e-5), "Initial LoRA output should match the base layer"
    print("Verified: initial LoRA output equals base model output")
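The class above creates its own randomly initialized weight, whereas in practice LoRA wraps a layer that already holds pre-trained weights. Here is a minimal sketch of that pattern, assuming a hypothetical `LoRAWrapper` (not PEFT's implementation) that freezes an existing `nn.Linear`, including its bias. It initializes A with small Gaussian noise, as the original LoRA paper describes. The sketch also checks the merge claim once B is non-zero.

```python
import torch
import torch.nn as nn

class LoRAWrapper(nn.Module):
    """Hypothetical wrapper: adds a trainable LoRA delta to an existing nn.Linear."""

    def __init__(self, base: nn.Linear, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)                  # freeze pre-trained weight and bias
        self.scaling = alpha / rank
        self.lora_A = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, rank))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling

    @torch.no_grad()
    def merged(self) -> nn.Linear:
        """Return a plain nn.Linear whose weight already includes the LoRA delta."""
        out = nn.Linear(self.base.in_features, self.base.out_features,
                        bias=self.base.bias is not None)
        out.weight.copy_(self.base.weight + (self.lora_B @ self.lora_A) * self.scaling)
        if self.base.bias is not None:
            out.bias.copy_(self.base.bias)
        return out

torch.manual_seed(0)
pretrained = nn.Linear(32, 16)                   # stands in for a loaded projection
wrapped = LoRAWrapper(pretrained, rank=4, alpha=8.0)

with torch.no_grad():
    wrapped.lora_B.normal_()                     # simulate a trained, non-zero B

x = torch.randn(5, 32)
with torch.no_grad():
    same = torch.allclose(wrapped(x), wrapped.merged()(x), atol=1e-5)
trainable = [n for n, p in wrapped.named_parameters() if p.requires_grad]
print(same, trainable)  # True ['lora_A', 'lora_B']
```

Only `lora_A` and `lora_B` reach the optimizer, and the merged layer matches the unmerged one up to floating-point rounding.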

Rank: The Most Important Hyperparameter

Rank r controls the expressiveness of the LoRA update. Lower rank means fewer trainable parameters and a simpler update; higher rank means more parameters and a more expressive update.
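One idealized way to see what "more expressive" means: by the Eckart–Young theorem, the best any rank-r product can do in Frobenius norm is the truncated SVD of the target update. The sketch below assumes a synthetic rank-32 "ideal" update, which stands in for what full fine-tuning might learn. It measures how much of that update each rank can capture. Real training does not compute an SVD; this only bounds what a rank-r B × A could represent.

```python
import torch

torch.manual_seed(0)
d, true_rank = 128, 32

# Synthetic stand-in for a full fine-tuning update with rank 32
target = (torch.randn(d, true_rank, dtype=torch.float64)
          @ torch.randn(true_rank, d, dtype=torch.float64))
U, S, Vh = torch.linalg.svd(target)

def best_rank_r_error(r: int) -> float:
    """Relative Frobenius error of the best rank-r approximation (truncated SVD)."""
    approx = U[:, :r] @ torch.diag(S[:r]) @ Vh[:r, :]
    return (torch.linalg.norm(target - approx) / torch.linalg.norm(target)).item()

errors = {r: best_rank_r_error(r) for r in [4, 8, 16, 32]}
for r, err in errors.items():
    print(f"r={r:2d}: relative error {err:.3f}")
```

The error shrinks as r grows and reaches zero once r matches the target's true rank. Ranks beyond that add parameters without adding expressiveness.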

Python
def lora_parameter_count(d_model: int, rank: int, num_target_layers: int) -> dict:
    """
    Calculate total trainable parameters for LoRA.

    For each targeted linear layer:
      - lora_A: rank × d_in parameters
      - lora_B: d_out × rank parameters

    Assuming square projections (d_in == d_out == d_model).
    """
    params_per_layer = rank * d_model + d_model * rank  # A + B
    total_lora_params = params_per_layer * num_target_layers

    # Example: Llama 3 8B has 32 layers, each with q, k, v, o projections
    total_base_params = 8_000_000_000

    return {
        "rank": rank,
        "params_per_targeted_layer": params_per_layer,
        "total_lora_params": total_lora_params,
        "base_model_params": total_base_params,
        "trainable_fraction": f"{total_lora_params / total_base_params:.4%}",
        "lora_memory_mb": round(total_lora_params * 2 / (1024 ** 2), 1),  # bfloat16
    }

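# d_model=4096 (Llama 3 8B), 32 layers × 2 target matrices (q_proj, v_proj).
# Note: this treats v_proj as square; under GQA it is 1024 × 4096, so these counts are upper bounds.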
for rank in [4, 8, 16, 32, 64]:
    result = lora_parameter_count(4096, rank, num_target_layers=64)
    print(
        f"r={rank:2d}: {result['total_lora_params']:>10,} trainable params "
        f"({result['trainable_fraction']}) | {result['lora_memory_mb']} MB"
    )

# r= 4:     2,097,152 (0.0262%) |  4.0 MB
# r= 8:     4,194,304 (0.0524%) |  8.0 MB
# r=16:     8,388,608 (0.1049%) | 16.0 MB
# r=32:    16,777,216 (0.2097%) | 32.0 MB
# r=64:    33,554,432 (0.4194%) | 64.0 MB
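The function above assumes square projections, which overstates v_proj for Llama 3 8B because grouped-query attention makes it 1024 × 4096. Below is a shape-aware sketch for the same q_proj + v_proj targets. The shapes are Llama 3 8B's; the helper name is hypothetical.

```python
# Llama 3 8B projection shapes as (d_out, d_in); v_proj is narrower because of GQA
SHAPES = {
    "q_proj": (4096, 4096),
    "v_proj": (1024, 4096),  # 8 KV heads x 128 head dim
}
NUM_LAYERS = 32

def lora_params(shape: tuple, rank: int) -> int:
    """lora_A (rank x d_in) plus lora_B (d_out x rank)."""
    d_out, d_in = shape
    return rank * d_in + d_out * rank

results = {}
for rank in [4, 8, 16, 32, 64]:
    shape_aware = NUM_LAYERS * sum(lora_params(s, rank) for s in SHAPES.values())
    square = NUM_LAYERS * 2 * lora_params((4096, 4096), rank)
    results[rank] = (shape_aware, square)
    print(f"r={rank:2d}: {shape_aware:>10,} shape-aware vs {square:>10,} square")
```

At r=16 this gives 6,815,744 trainable parameters rather than 8,388,608, about 19% fewer than the square approximation.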

Rough starting points for rank selection (heuristics; validate on your own eval set):

| Task complexity | Recommended rank |
|---|---|
| Simple style transfer (tone, format) | 4–8 |
| Domain knowledge adaptation | 16 |
| Complex reasoning in a new domain | 32 |
| Close to full fine-tuning quality needed | 64–128 |


Alpha: The Scaling Factor

Alpha controls the magnitude of the LoRA update relative to the base model output.

Python
# The scaling formula: effective_lora_weight = (alpha / rank) * B @ A
# Common convention: set alpha = 2 * rank
# This means the scaling factor is always 2.0, regardless of rank

examples = [
    (4,   8,   2.0),
    (8,  16,   2.0),
    (16, 32,   2.0),
    (32, 64,   2.0),
    (16, 16,   1.0),  # alpha == rank → scaling = 1.0 (weaker update)
    (16, 64,   4.0),  # alpha = 4*rank → scaling = 4.0 (stronger update)
]

for rank, alpha, expected_scale in examples:
    scale = alpha / rank
    print(f"rank={rank:2d}, alpha={alpha:2d} → scale={scale:.1f} (expected {expected_scale:.1f})")

# Intuition:
# - Higher alpha/rank ratio → larger LoRA contribution → model drifts further from base
# - If your task is very different from pre-training: use higher alpha
# - If you want subtle adaptation: keep alpha == rank
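To make "magnitude of the update" concrete, the sketch below uses random stand-ins for trained A and B (an assumption for illustration) and shows that the LoRA contribution scales linearly with alpha. Doubling alpha at fixed rank exactly doubles the delta added to the base output.

```python
import torch

torch.manual_seed(0)
r, d = 16, 64
A = torch.randn(r, d)   # stand-ins for trained LoRA matrices
B = torch.randn(d, r)
x = torch.randn(3, d)

def lora_delta(alpha: float) -> torch.Tensor:
    """LoRA contribution to the layer output: (alpha / r) * x @ A.T @ B.T"""
    return (x @ A.T @ B.T) * (alpha / r)

doubled = torch.allclose(lora_delta(64.0), 2 * lora_delta(32.0))
print(doubled)  # True
```

Because alpha only rescales the delta, the original LoRA paper notes that with Adam, tuning alpha is roughly equivalent to tuning the learning rate. This is why many practitioners fix the alpha/rank ratio and tune the learning rate instead.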

Which Layers to Target

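Not every layer needs LoRA. The original LoRA paper found that, under a fixed parameter budget, adapting the query and value projections together gave a strong quality-per-parameter tradeoff, which is why q_proj + v_proj is a common minimal choice. Later work (notably QLoRA) reports that applying LoRA to all linear layers is often needed to approach full fine-tuning quality.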
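PEFT's `target_modules` accepts short names such as `"q_proj"`. When given a list, it matches module names that equal or end with one of those strings, so a single name applies LoRA in every decoder layer. The sketch below mimics that suffix matching on Llama-style module names. `select_lora_targets` is a hypothetical helper, not a PEFT function.

```python
def select_lora_targets(module_names: list, targets: list) -> list:
    """Hypothetical helper: keep names whose last dotted component is a target."""
    wanted = set(targets)
    return [name for name in module_names if name.split(".")[-1] in wanted]

# Module names in the style Llama checkpoints use in Hugging Face transformers (2 layers shown)
names = [
    f"model.layers.{i}.{block}.{proj}"
    for i in range(2)
    for block, proj in [
        ("self_attn", "q_proj"), ("self_attn", "k_proj"),
        ("self_attn", "v_proj"), ("self_attn", "o_proj"),
        ("mlp", "gate_proj"), ("mlp", "up_proj"), ("mlp", "down_proj"),
    ]
] + ["lm_head"]

selected = select_lora_targets(names, ["q_proj", "v_proj"])
print(selected)
```

The output head (`lm_head`) is untouched unless you list it explicitly.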

Python
from transformers import AutoModelForCausalLM
import torch

def inspect_model_linear_layers(model_name: str) -> dict:
    """List all linear layer names and shapes to decide LoRA targets."""
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="cpu"
    )

    linear_layers = {}
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            layer_type = name.split(".")[-1]
            if layer_type not in linear_layers:
                linear_layers[layer_type] = {
                    "count": 0,
                    "shape": (module.out_features, module.in_features)
                }
            linear_layers[layer_type]["count"] += 1

    return linear_layers

# For Llama 3 8B, the linear layers are (lm_head omitted):
llama_layers = {
    "q_proj":    {"count": 32, "shape": (4096, 4096)},
    "k_proj":    {"count": 32, "shape": (1024, 4096)},  # GQA: fewer K heads
    "v_proj":    {"count": 32, "shape": (1024, 4096)},  # GQA: fewer V heads
    "o_proj":    {"count": 32, "shape": (4096, 4096)},
    "gate_proj": {"count": 32, "shape": (14336, 4096)},
    "up_proj":   {"count": 32, "shape": (14336, 4096)},
    "down_proj": {"count": 32, "shape": (4096, 14336)},
}

# Common LoRA target configurations:
configurations = {
    "minimal (q+v only)": ["q_proj", "v_proj"],
    "attention only": ["q_proj", "k_proj", "v_proj", "o_proj"],
    "full (attention + FFN)": ["q_proj", "k_proj", "v_proj", "o_proj",
                               "gate_proj", "up_proj", "down_proj"],
}

for config_name, target_modules in configurations.items():
    params = sum(
        llama_layers[m]["count"] * (
            llama_layers[m]["shape"][0] * 16 +  # lora_B
            16 * llama_layers[m]["shape"][1]     # lora_A
        )
        for m in target_modules if m in llama_layers
    )
    print(f"{config_name}: {params:>12,} params ({params/8e9:.4%} of 8B)")

# minimal:         6,815,744 (0.0852%)
# attention only: 13,631,488 (0.1704%)
# full:           41,943,040 (0.5243%)

Full Working Example: Apply LoRA to Llama 3

Python
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
)
from peft import LoraConfig, TaskType, get_peft_model
from trl import SFTTrainer
from datasets import Dataset
import torch

# ─── 1. Load base model ───────────────────────────────────────────────────────
model_name = "meta-llama/Llama-3.2-3B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"   # important for causal LM training

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
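model.config.use_cache = False       # KV cache is not used with gradient checkpointing during training
model.enable_input_require_grads()   # lets gradients flow through checkpointed blocks when embeddings are frozen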

# ─── 2. Define LoRA configuration ────────────────────────────────────────────
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,                            # rank
    lora_alpha=32,                   # scaling = 32/16 = 2.0
    target_modules=[                 # Llama 3 attention projections
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
    ],
    lora_dropout=0.05,
    bias="none",
    use_rslora=False,                # set True for rank-stabilized LoRA
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# Expect about 9.2M trainable params (28 layers × 4 projections at r=16),
# roughly 0.28% of the ~3.2B total parameters

# ─── 3. Prepare dataset ──────────────────────────────────────────────────────
def format_drug_example(example: dict) -> dict:
    """Apply Llama 3 chat template to a drug Q&A example."""
    messages = [
        {"role": "system", "content": "You are a pharmaceutical drug information specialist. Provide accurate, evidence-based drug information."},
        {"role": "user", "content": example["question"]},
        {"role": "assistant", "content": example["answer"]},
    ]
    formatted = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False,
    )
    return {"text": formatted}

# Sample data (in production: load from your JSONL file)
raw_data = [
    {
        "question": "What is the mechanism of action of metformin?",
        "answer": (
            "Metformin is a biguanide that works primarily by activating AMP-activated "
            "protein kinase (AMPK) in hepatocytes. This suppresses hepatic gluconeogenesis, "
            "reducing fasting blood glucose. Secondary effects include improved insulin "
            "sensitivity in peripheral tissues and modest reduction in intestinal glucose absorption."
        )
    },
    {
        "question": "What are the contraindications for metformin?",
        "answer": (
            "Metformin is contraindicated in: (1) eGFR below 30 mL/min/1.73m² due to risk of "
            "lactic acidosis; (2) acute or chronic metabolic acidosis; (3) radiological contrast "
            "administration (hold 48 hours peri-procedure); (4) hepatic impairment. Use with caution "
            "when eGFR is between 30 and 45."
        )
    },
    # Add hundreds more examples
]

dataset = Dataset.from_list([format_drug_example(ex) for ex in raw_data])
split = dataset.train_test_split(test_size=0.1, seed=42)

# ─── 4. Training arguments ───────────────────────────────────────────────────
training_args = TrainingArguments(
    output_dir="./llama3-drug-lora",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,    # effective batch = 16
    learning_rate=2e-4,               # higher LR fine for LoRA
    warmup_ratio=0.05,
    lr_scheduler_type="cosine",
    bf16=True,
    gradient_checkpointing=True,
    logging_steps=10,
    save_strategy="epoch",
    eval_strategy="epoch",            # named evaluation_strategy in older transformers releases
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    report_to="none",
    dataloader_pin_memory=False,
)

# ─── 5. Train ────────────────────────────────────────────────────────────────
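# These keyword arguments follow older TRL releases; newer TRL versions move
# dataset_text_field, max_seq_length, and packing into SFTConfig and rename
# tokenizer to processing_class, so check the TRL docs for your installed version.
trainer = SFTTrainer(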
    model=model,
    args=training_args,
    train_dataset=split["train"],
    eval_dataset=split["test"],
    tokenizer=tokenizer,
    dataset_text_field="text",
    max_seq_length=1024,
    packing=False,
)

trainer.train()

# ─── 6. Save LoRA adapter only (small!) ──────────────────────────────────────
model.save_pretrained("./llama3-drug-lora-adapter")
tokenizer.save_pretrained("./llama3-drug-lora-adapter")

# Adapter folder will contain:
# adapter_config.json  (~1 KB)
# adapter_model.safetensors  (~18 MB in bf16 or ~37 MB in fp32 for r=16 on q, k, v, o;
#                             the stored dtype depends on your PEFT version and settings)
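To sanity-check trainable-parameter and file-size estimates for this configuration, here is a minimal arithmetic sketch. It assumes Llama 3.2 3B's published dimensions: 28 decoder layers, hidden size 3072, and 8 KV heads of 128 dims, so k_proj and v_proj output 1024. It ignores the small safetensors header.

```python
# Assumed Llama 3.2 3B dimensions and the r=16 config used above
NUM_LAYERS, HIDDEN, KV_DIM, RANK = 28, 3072, 1024, 16

shapes = {  # (d_out, d_in)
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (KV_DIM, HIDDEN),
    "v_proj": (KV_DIM, HIDDEN),
    "o_proj": (HIDDEN, HIDDEN),
}

# lora_A is (RANK x d_in), lora_B is (d_out x RANK)
per_layer = sum(RANK * d_in + d_out * RANK for d_out, d_in in shapes.values())
total = NUM_LAYERS * per_layer

for dtype, nbytes in [("bf16", 2), ("fp32", 4)]:
    print(f"{total:,} params -> {total * nbytes / 1e6:.1f} MB in {dtype}")
# 9,175,040 params -> 18.4 MB in bf16
# 9,175,040 params -> 36.7 MB in fp32
```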

Loading and Using the LoRA Adapter

Python
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load base model
base_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3B-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")

# Load LoRA adapter on top
model = PeftModel.from_pretrained(base_model, "./llama3-drug-lora-adapter")

# Option A: keep the PEFT model as-is (adapter applied at runtime, can be swapped or disabled);
#           to do this, skip the merge below.
# Option B: merge the adapter into the base weights for faster inference.
model = model.merge_and_unload()  # returns a standard HF model with merged weights

# Generate
def ask_drug_question(question: str) -> str:
    messages = [
        {"role": "system", "content": "You are a pharmaceutical drug information specialist."},
        {"role": "user", "content": question},
    ]
    input_ids = tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=True
    ).to(model.device)

    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=300,
            temperature=0.1,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    new_tokens = output_ids[0][input_ids.shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

answer = ask_drug_question("Explain the drug-drug interaction between warfarin and aspirin.")
print(answer)
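Keeping adapters unmerged (Option A) is what makes the plugin model work: one frozen base weight can serve several adapters, each of which is only its own small A and B. The sketch below is torch-only, with hypothetical adapter names and a toy registry rather than PEFT internals. It shows why switching is cheap: only the delta term changes.

```python
from typing import Optional

import torch

torch.manual_seed(0)
d, r = 32, 4
W = torch.randn(d, d)  # frozen base weight shared by all adapters

# Hypothetical registry: each adapter is only its own (A, B, scaling)
adapters = {
    "drug_qa": (torch.randn(r, d), torch.randn(d, r), 2.0),
    "legal_qa": (torch.randn(r, d), torch.randn(d, r), 2.0),
}

def forward(x: torch.Tensor, adapter: Optional[str] = None) -> torch.Tensor:
    """Base projection plus the selected adapter's delta (None = base model only)."""
    y = x @ W.T
    if adapter is not None:
        A, B, scaling = adapters[adapter]
        y = y + (x @ A.T @ B.T) * scaling
    return y

x = torch.randn(2, d)
base_out = forward(x)
drug_out = forward(x, "drug_qa")
legal_out = forward(x, "legal_qa")
print(torch.equal(base_out, x @ W.T), torch.allclose(drug_out, legal_out))  # True False
```

In PEFT the equivalent workflow goes through `PeftModel` methods such as `load_adapter` and `set_adapter`; check the PEFT documentation for their exact signatures in your version. Once merged with `merge_and_unload()`, the model no longer supports this kind of switching.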

RSLoRA: An Improvement on Standard LoRA

Standard LoRA scales the update by alpha/r, so with alpha held fixed, raising the rank shrinks the scaling factor (alpha=32 gives 2.0 at r=16 but only 0.5 at r=64). The rsLoRA paper argues that this slows learning at higher ranks. RSLoRA (Rank-Stabilized LoRA) uses alpha/sqrt(r) instead, which keeps training more stable across ranks.

Python
from peft import LoraConfig

# Standard LoRA scaling: alpha / rank
standard_lora = LoraConfig(
    r=16, lora_alpha=32,
    use_rslora=False,   # scaling = 32/16 = 2.0
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)

# RSLoRA: alpha / sqrt(rank) → more stable when rank > 16
rslora = LoraConfig(
    r=64, lora_alpha=32,
    use_rslora=True,    # scaling = 32 / sqrt(64) = 4.0 (not 0.5)
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)
# With RSLoRA, you can safely use high ranks without the adapter output becoming tiny
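The difference between the two rules is easiest to see side by side. The sketch below holds alpha fixed at 32 and computes both scaling factors across ranks; the rank values are illustrative choices.

```python
import math

def standard_scaling(alpha: float, rank: int) -> float:
    """Standard LoRA: alpha / r"""
    return alpha / rank

def rslora_scaling(alpha: float, rank: int) -> float:
    """Rank-stabilized LoRA: alpha / sqrt(r)"""
    return alpha / math.sqrt(rank)

alpha = 32
table = {r: (standard_scaling(alpha, r), rslora_scaling(alpha, r)) for r in [4, 16, 64, 256]}
for r, (std, rs) in table.items():
    print(f"r={r:3d}: standard={std:.3f}  rslora={rs:.3f}")
```

From r=4 to r=256 the standard factor drops 64× (8.0 to 0.125), while the RSLoRA factor drops only 8× (16.0 to 2.0).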

Summary

LoRA is elegant because it reduces a high-dimensional weight update to a product of two low-rank matrices. The key parameters:

  • r (rank): Start with 16. Increase to 32 or 64 for complex tasks.
  • alpha: Set to 2× rank as a default. Adjust if the model over- or under-adapts.
  • target_modules: Start with q_proj and v_proj. Add k_proj and o_proj for more capacity. Add FFN layers (gate_proj, up_proj, down_proj) for the most capacity.
  • dropout: Use 0.05 to 0.1 for small datasets, 0.0 for large ones.

The adapter is tiny (from a few megabytes to tens of megabytes for common configurations, versus gigabytes for the base model), loads in seconds, and can be merged into the base model at inference time for zero overhead.
