Learnixo

RAG Systems · Lesson 19 of 24

MMR: Maximal Marginal Relevance for Diversity

The Redundancy Problem

Standard top-k retrieval returns the k chunks most similar to the query, which often means k near-identical chunks:

Query: "Warfarin dosing in elderly patients"

Top-5 cosine similarity:
  1. "Elderly patients (>65) require lower warfarin doses..." (sim=0.94)
  2. "In patients over 65, warfarin dosing should be reduced..." (sim=0.93)
  3. "Warfarin doses in the elderly should be conservative..." (sim=0.91)
  4. "Consider reduced warfarin for older patients..." (sim=0.90)
  5. "Renal function affects warfarin in elderly patients." (sim=0.85)

Chunks 1–4 are near-paraphrases of each other, so sending all four to the LLM wastes context window.
Chunk 5 adds new information (renal function) but scores lower.

MMR selects for both relevance AND diversity.


The MMR Formula

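MMR = argmax_{dᵢ ∈ R \ S} [ λ · Sim(dᵢ, query) - (1-λ) · max_{dⱼ ∈ S} Sim(dᵢ, dⱼ) ]

Where:
  R = candidate pool returned by the initial similarity search
  S = set of already-selected documents (empty at the start, so the first pick is simply the most relevant candidate)
  Sim(dᵢ, query) = cosine similarity to query (relevance)
  max_{dⱼ ∈ S} Sim(dᵢ, dⱼ) = similarity to closest already-selected doc (redundancy)
  λ = trade-off parameter in [0, 1]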

  λ = 1.0: pure relevance (identical to top-k)
  λ = 0.5: equal weight on relevance and diversity (recommended starting point)
  λ = 0.0: pure diversity (maximum spread, ignores query relevance)

MMR is a greedy algorithm: iteratively pick the next document that maximises the formula.
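
To make one greedy step concrete, here is a small worked sketch for the warfarin example. The query similarities come from the list above; the pairwise similarities to chunk 1 are assumed, illustrative values (the lesson does not give them), as is the variable naming.

```python
# Query similarities taken from the warfarin example above.
query_sim = {1: 0.94, 2: 0.93, 3: 0.91, 4: 0.90, 5: 0.85}

# ASSUMED similarities of each chunk to chunk 1 (illustrative, not from the lesson).
sim_to_chunk1 = {2: 0.97, 3: 0.95, 4: 0.94, 5: 0.60}

lam = 0.5
selected = [1]  # the first pick is the most query-similar chunk

# Second pick: score every remaining chunk. With a single selected chunk,
# the max over S is just the similarity to chunk 1.
scores = {
    i: lam * query_sim[i] - (1 - lam) * sim_to_chunk1[i]
    for i in query_sim
    if i not in selected
}
best = max(scores, key=scores.get)

for i, s in sorted(scores.items()):
    print(f"chunk {i}: {s:+.3f}")
print("second pick:", best)
```

Under these assumed values the paraphrases (chunks 2–4) score slightly below zero because their redundancy penalty cancels their relevance, while chunk 5 scores 0.125 and becomes the second pick.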


Implementation

Python
import numpy as np
from numpy.linalg import norm

def cosine_sim(a: np.ndarray, b: np.ndarray) -> float:
    return float(np.dot(a, b) / (norm(a) * norm(b) + 1e-10))

def maximal_marginal_relevance(
    query_embedding: np.ndarray,
    candidate_embeddings: np.ndarray,
    candidates: list[dict],
    top_k: int = 5,
    lambda_param: float = 0.5,
) -> list[dict]:
    """
    candidates: list of dicts with "content", "metadata", etc.
    candidate_embeddings: shape (n, d)
    """
    if len(candidates) == 0:
        return []
    
    # Pre-compute query similarities
    query_sims = np.array([
        cosine_sim(query_embedding, emb)
        for emb in candidate_embeddings
    ])
    
    selected_indices = []
    remaining = list(range(len(candidates)))
    
    for _ in range(min(top_k, len(candidates))):
        if not selected_indices:
            # First selection: pick most similar to query
            best_idx = int(np.argmax(query_sims))
        else:
            # Subsequent selections: MMR trade-off
            selected_embs = candidate_embeddings[selected_indices]
            best_score = -np.inf
            best_idx = -1
            
            for idx in remaining:
                relevance = query_sims[idx]
                redundancy = max(
                    cosine_sim(candidate_embeddings[idx], sel_emb)
                    for sel_emb in selected_embs
                )
                score = lambda_param * relevance - (1 - lambda_param) * redundancy
                if score > best_score:
                    best_score = score
                    best_idx = idx
        
        selected_indices.append(best_idx)
        remaining.remove(best_idx)
    
    return [candidates[i] for i in selected_indices]


# Usage in RAG pipeline
def retrieve_with_mmr(
    query: str,
    collection,
    embedder,
    top_k: int = 5,
    fetch_k: int = 20,      # fetch more, then diversify
    lambda_param: float = 0.5,
) -> list[dict]:
    query_embedding = embedder.encode([query])[0]
    
    # Fetch a large candidate pool
    results = collection.query(
        query_embeddings=[query_embedding.tolist()],
        n_results=fetch_k,
        include=["documents", "metadatas", "embeddings"],
    )
    
    candidates = [
        {"content": doc, "metadata": meta}
        for doc, meta in zip(results["documents"][0], results["metadatas"][0])
    ]
    candidate_embeddings = np.array(results["embeddings"][0])
    
    return maximal_marginal_relevance(
        query_embedding=query_embedding,
        candidate_embeddings=candidate_embeddings,
        candidates=candidates,
        top_k=top_k,
        lambda_param=lambda_param,
    )
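
The loop above recomputes every candidate-to-selected similarity at each step. A possible alternative, sketched below under the assumption that embeddings fit in memory as a NumPy array, normalises the vectors once and keeps a running maximum of each candidate's similarity to the selected set, so each step costs one matrix-vector product. The name `mmr_vectorized` and the choice to return indices instead of dicts are illustrative, not part of the lesson's code.

```python
import numpy as np

def mmr_vectorized(query_embedding, candidate_embeddings, top_k=5, lambda_param=0.5):
    """Return indices of the selected candidates, in selection order."""
    emb = np.asarray(candidate_embeddings, dtype=float)
    if emb.size == 0 or top_k <= 0:
        return []
    q = np.asarray(query_embedding, dtype=float)

    # Normalise once so that dot products are cosine similarities.
    emb = emb / (np.linalg.norm(emb, axis=1, keepdims=True) + 1e-10)
    q = q / (np.linalg.norm(q) + 1e-10)

    relevance = emb @ q                       # shape (n,)
    n = emb.shape[0]
    max_redundancy = np.full(n, -np.inf)      # running max similarity to the selected set
    available = np.ones(n, dtype=bool)
    selected = []

    for _ in range(min(top_k, n)):
        if selected:
            scores = lambda_param * relevance - (1 - lambda_param) * max_redundancy
        else:
            scores = relevance.copy()         # first pick: pure relevance
        scores[~available] = -np.inf          # never re-select a candidate
        best = int(np.argmax(scores))
        selected.append(best)
        available[best] = False
        # Only similarities to the newly selected item need computing.
        max_redundancy = np.maximum(max_redundancy, emb @ emb[best])

    return selected
```

To get the same return shape as `maximal_marginal_relevance`, map the indices back with `[candidates[i] for i in mmr_vectorized(...)]`.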

LangChain MMR

Python
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings

embedding_fn = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
vectorstore = Chroma(persist_directory="./chroma_db", embedding_function=embedding_fn)

# MMR retrieval built into LangChain
retriever = vectorstore.as_retriever(
    search_type="mmr",
    search_kwargs={
        "k": 5,           # final results to return
        "fetch_k": 20,    # candidates to score with MMR
        "lambda_mult": 0.5,  # LangChain uses lambda_mult (= λ above)
    }
)

# invoke() is the current retriever entry point; get_relevant_documents() is deprecated
docs = retriever.invoke("Warfarin dosing in elderly patients")

When to Use MMR

Use MMR when:
  ✓ Knowledge base has many near-duplicate chunks (e.g., multiple versions
    of the same guideline, overlapping chunking with high overlap)
  ✓ The query topic is narrow but requires coverage from multiple angles
    (e.g., "warfarin" → dosing + monitoring + interactions + contraindications)
  ✓ Context window is tight — can't afford redundant chunks

Use standard top-k when:
  ✓ Knowledge base is well-curated with minimal redundancy
  ✓ The query requires the most precise single answer (not multiple perspectives)
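  ✓ Latency is critical (MMR adds O(k × fetch_k) similarity computations if the
    running redundancy max is cached; the simple loop above recomputes them,
    which costs O(k² × fetch_k))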
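
One rough way to judge whether a knowledge base has enough redundancy to justify MMR is to look at how similar the plain top-k results are to each other. The sketch below is a heuristic, not something the lesson prescribes: `mean_pairwise_cosine` is a hypothetical helper, and any cut-off you compare its result against would need tuning on your own data.

```python
import numpy as np

def mean_pairwise_cosine(embeddings) -> float:
    """Average cosine similarity over all distinct pairs of rows."""
    emb = np.asarray(embeddings, dtype=float)
    n = emb.shape[0]
    if n < 2:
        return 0.0  # no pairs to compare
    emb = emb / (np.linalg.norm(emb, axis=1, keepdims=True) + 1e-10)
    sims = emb @ emb.T
    # Upper triangle without the diagonal: each pair is counted once.
    rows, cols = np.triu_indices(n, k=1)
    return float(sims[rows, cols].mean())
```

A value close to 1.0 across many sample queries suggests the top-k results are mostly paraphrases, which is the situation MMR targets; a low value suggests standard top-k is already diverse.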

Lambda tuning:
  λ = 0.7–0.8: mostly relevance, slight diversity
  λ = 0.5:     balanced (recommended starting point)
  λ = 0.3:     strongly diverse — may sacrifice relevance
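
To see how λ changes the selection, the sketch below runs the greedy procedure over precomputed similarities for four hypothetical chunks A–D. All similarity values are assumed toy numbers chosen for illustration, and `mmr_order` is an illustrative helper, not the lesson's function.

```python
# ASSUMED toy similarities for four hypothetical chunks (not from the lesson).
query_sim = {"A": 0.94, "B": 0.93, "C": 0.85, "D": 0.70}
pair_sim = {
    frozenset("AB"): 0.97, frozenset("AC"): 0.60, frozenset("AD"): 0.50,
    frozenset("BC"): 0.62, frozenset("BD"): 0.44, frozenset("CD"): 0.40,
}

def mmr_order(lam: float, top_k: int = 3) -> list[str]:
    """Greedy MMR over the toy similarities; returns chunk labels in pick order."""
    selected = [max(query_sim, key=query_sim.get)]  # first pick: most relevant
    while len(selected) < min(top_k, len(query_sim)):
        remaining = [c for c in query_sim if c not in selected]

        def score(c: str) -> float:
            redundancy = max(pair_sim[frozenset((c, s))] for s in selected)
            return lam * query_sim[c] - (1 - lam) * redundancy

        selected.append(max(remaining, key=score))
    return selected

for lam in (1.0, 0.8, 0.5, 0.3):
    print(f"lambda={lam}: {mmr_order(lam)}")
```

With these assumed numbers, λ = 1.0 returns A, B, C (plain top-k, including the near-duplicate B), λ = 0.8 returns A, C, B, λ = 0.5 returns A, C, D, and λ = 0.3 returns A, D, C, promoting the least relevant chunk D to second place.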

Interview Answer

"Maximal Marginal Relevance iteratively selects documents by maximising a combination of relevance to the query and dissimilarity to already-selected documents. At each step it picks the candidate dᵢ with the highest score λ·Sim(dᵢ, query) - (1-λ)·max_{dⱼ ∈ S} Sim(dᵢ, dⱼ), where S is the set selected so far. This prevents returning five near-identical chunks when the top cosine matches happen to be paraphrases of each other, a common problem when chunking with high overlap. I use MMR when the knowledge base has redundancy or when the query needs multi-faceted coverage. The practical recipe: fetch 20 candidates with cosine search, then MMR-select 5; lambda=0.5 is a good default."