Learnixo

RAG Systems · Lesson 20 of 24

Re-Ranking with Cross-Encoders

The Two-Stage Retrieval Pattern

Embedding-based retrieval (bi-encoder) is fast but imprecise — it embeds query and documents independently. Cross-encoder rerankers compare query and document jointly, capturing much richer relevance signals — but are too slow to rank all documents.

Solution — a two-stage pipeline:

  1. Stage 1 (Recall): Bi-encoder retrieves top-N candidates quickly (N = 50-200)
  2. Stage 2 (Precision): Cross-encoder reranks candidates, keeps top-K (K = 5-10)

The reranker operates on a small set, so its higher compute cost is acceptable.


Cross-Encoder Reranking

A cross-encoder takes (query, document) pairs as input and outputs a relevance score:

Python
from sentence_transformers import CrossEncoder
import numpy as np

def load_cross_encoder(model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2") -> CrossEncoder:
    """
    Build the cross-encoder model used for reranking.

    model_name: model identifier handed to the CrossEncoder constructor.
    The model is always created with max_length=512.
    """
    reranker = CrossEncoder(model_name, max_length=512)
    return reranker


def rerank_with_cross_encoder(
    query: str,
    candidates: list[dict],
    cross_encoder: CrossEncoder,
    top_k: int = 5,
) -> list[dict]:
    """
    Rerank retrieved candidates using a cross-encoder.

    query: the user query paired with every candidate
    candidates: list of dicts with 'content' key
    cross_encoder: model whose predict() receives the (query, content) pairs
        and returns one score per pair
    top_k: maximum number of results to return

    Returns up to top_k shallow copies of the candidates, each with an added
    'rerank_score' float, sorted by that score in descending order. The
    caller's dicts are not modified. Returns [] when there are no candidates
    or top_k is not positive, without calling the model.
    """
    # Nothing to score: skip the model call entirely.
    if not candidates or top_k <= 0:
        return []

    # Create query-document pairs
    pairs = [(query, candidate["content"]) for candidate in candidates]

    # Score all pairs in one predict() call
    scores = cross_encoder.predict(pairs)

    # Attach scores to copies so the retriever's results are not mutated
    scored = [
        {**candidate, "rerank_score": float(score)}
        for candidate, score in zip(candidates, scores)
    ]
    scored.sort(key=lambda item: item["rerank_score"], reverse=True)
    return scored[:top_k]


# Comparison table of cross-encoder reranking models, keyed by model name.
# Each entry holds relative speed/quality labels, a parameter count and a note.
CROSS_ENCODER_MODELS = {
    "cross-encoder/ms-marco-MiniLM-L-6-v2": dict(
        speed="fast",
        quality="good",
        params="22M",
        notes="Best speed/quality tradeoff for general use",
    ),
    "cross-encoder/ms-marco-MiniLM-L-12-v2": dict(
        speed="medium",
        quality="better",
        params="33M",
        notes="2× slower, 5-10% better than L-6",
    ),
    "cross-encoder/ms-marco-electra-base": dict(
        speed="slow",
        quality="best",
        params="110M",
        notes="Highest quality, 5× slower than MiniLM-L-6",
    ),
    "BAAI/bge-reranker-large": dict(
        speed="slow",
        quality="excellent",
        params="560M",
        notes="Top BEIR benchmark scores",
    ),
}


# End-to-end two-stage retrieval: recall first, precision second
def two_stage_retrieval(
    query: str,
    query_embedding: list[float],
    retriever,
    cross_encoder: CrossEncoder,
    stage1_k: int = 50,     # Size of the stage-1 candidate pool
    stage2_k: int = 5,      # Number of results kept after reranking
) -> list[dict]:
    """
    Fetch a candidate pool with the retriever, then rerank it.

    The retriever is queried with query_embedding for stage1_k results;
    those results are passed with the raw query text to
    rerank_with_cross_encoder, whose top stage2_k output is returned.
    """
    shortlist = retriever.retrieve(query_embedding, top_k=stage1_k)

    return rerank_with_cross_encoder(
        query=query,
        candidates=shortlist,
        cross_encoder=cross_encoder,
        top_k=stage2_k,
    )

Cohere Rerank API

Cohere provides a reranking API that doesn't require hosting a model:

Python
import cohere

# Module-level Cohere client used by rerank_with_cohere.
# NOTE(review): the key below is a placeholder string — presumably the real
# key should come from configuration rather than source code; confirm before use.
co = cohere.Client("your-cohere-api-key")

def rerank_with_cohere(
    query: str,
    documents: list[str],
    top_k: int = 5,
    model: str = "rerank-english-v3.0",
) -> list[dict]:
    """
    Rerank documents using Cohere Rerank API.

    query: the search query to rank against
    documents: candidate texts to rerank
    top_k: maximum number of results requested (sent as top_n)
    model: Cohere rerank model name

    Returns one dict per API result with:
      'content'         - the document text returned by the API
      'relevance_score' - the API's relevance score
      'original_rank'   - result.index as reported by the API
                          (presumably the position in `documents`; see the
                          Cohere Rerank reference)

    Returns [] without calling the API when there are no documents or
    top_k is not positive.
    """
    # Avoid a pointless (and likely rejected) request for empty input.
    if not documents or top_k <= 0:
        return []

    response = co.rerank(
        model=model,
        query=query,
        documents=documents,
        top_n=top_k,
        return_documents=True,
    )

    return [
        {
            "content": result.document.text,
            "relevance_score": result.relevance_score,
            "original_rank": result.index,
        }
        for result in response.results
    ]


# Cohere models:
# rerank-english-v3.0: English, high quality
# rerank-multilingual-v3.0: Multiple languages
# rerank-english-v2.0: Older, slightly lower quality

def cohere_two_stage(
    query: str,
    query_embedding: list[float],
    retriever,
    stage1_k: int = 50,
    stage2_k: int = 5,
) -> list[dict]:
    """
    Two-stage retrieval with Cohere reranking.

    query: raw query text sent to the reranker
    query_embedding: query vector passed to retriever.retrieve
    retriever: object with retrieve(embedding, top_k=...) returning dicts
        that have a 'content' key
    stage1_k: number of candidates fetched from the retriever
    stage2_k: number of results kept after reranking

    Returns the reranked candidates in reranker order; each is a copy of the
    original candidate dict (metadata included) plus a 'rerank_score' key.
    Returns [] when the retriever finds nothing.
    """
    # Stage 1: Vector retrieval
    candidates = retriever.retrieve(query_embedding, top_k=stage1_k)
    if not candidates:
        return []
    documents = [c["content"] for c in candidates]

    # Stage 2: Cohere reranking
    reranked = rerank_with_cohere(query, documents, top_k=stage2_k)

    # Merge with original metadata by position, not by text: two candidates
    # with identical content would otherwise collapse onto the same metadata.
    # 'original_rank' is the index into `documents`, which is aligned with
    # `candidates`.
    return [
        {
            **candidates[item["original_rank"]],
            "rerank_score": item["relevance_score"],
        }
        for item in reranked
    ]

LLM-as-Judge Reranking

For the highest quality (at the highest cost), use an LLM to assess relevance:

Python
from openai import OpenAI
import json

# Module-level OpenAI client shared by llm_rerank; constructed with no
# explicit arguments.
client = OpenAI()

# Prompt template for the LLM relevance judge. {question} and {document} are
# filled in via str.format by llm_rerank; the doubled braces on the last line
# are format escapes that render as literal braces around the JSON example.
RELEVANCE_JUDGE_PROMPT = """Rate how relevant this document is for answering the question.

Question: {question}

Document:
{document}

Rate relevance from 0-10:
- 0: Completely irrelevant
- 5: Partially relevant
- 10: Directly and completely addresses the question

Return JSON: {{"score": number, "reason": "one sentence"}}"""

def llm_rerank(
    query: str,
    candidates: list[dict],
    top_k: int = 5,
    judge_model: str = "gpt-4o-mini",
) -> list[dict]:
    """
    Rerank using LLM as relevance judge (expensive but high quality).

    query: the question the documents are judged against
    candidates: list of dicts with 'content' key
    top_k: maximum number of results to return
    judge_model: chat model used as the judge

    One chat completion is issued per candidate, using the first 1000
    characters of its content. Returns up to top_k copies of the candidates
    with 'llm_score' and 'llm_reason' added, sorted by 'llm_score'
    descending. A reply that cannot be parsed, or whose score is not
    numeric, gets the neutral score 5 instead of aborting the whole rerank.
    Returns [] without calling the model when there are no candidates or
    top_k is not positive.
    """
    if not candidates or top_k <= 0:
        return []

    scored = []

    for candidate in candidates:
        response = client.chat.completions.create(
            model=judge_model,
            messages=[
                {
                    "role": "user",
                    "content": RELEVANCE_JUDGE_PROMPT.format(
                        question=query,
                        # Truncate to keep per-candidate prompt cost bounded
                        document=candidate["content"][:1000],
                    ),
                }
            ],
            response_format={"type": "json_object"},
            temperature=0,
        )

        score, reason = _parse_judge_reply(response.choices[0].message.content)
        scored.append({
            **candidate,
            "llm_score": score,
            "llm_reason": reason,
        })

    return sorted(scored, key=lambda x: x["llm_score"], reverse=True)[:top_k]


def _parse_judge_reply(raw, default_score=5):
    """Return (score, reason) from the judge's raw reply, with safe fallbacks."""
    try:
        parsed = json.loads(raw)
    except (TypeError, ValueError):
        # TypeError: reply content was None; ValueError: malformed JSON
        return default_score, ""

    if not isinstance(parsed, dict):
        return default_score, ""

    score = parsed.get("score", default_score)
    if isinstance(score, bool) or not isinstance(score, (int, float)):
        # Judges sometimes return "7" or null; keep scores numeric so the
        # final sort never compares mixed types.
        try:
            score = float(score)
        except (TypeError, ValueError):
            score = default_score

    return score, parsed.get("reason", "")

Lost-in-the-Middle Reordering

After reranking, order documents to combat position bias in LLM attention:

Python
def optimal_document_ordering(reranked_docs: list[dict]) -> list[dict]:
    """
    Arrange ranked documents so the strongest ones sit at the context edges.

    LLMs tend to use the beginning and end of a long context better than
    the middle. The input is expected best-first; documents at even
    positions (ranks 1, 3, 5, ...) fill the output from the front and
    documents at odd positions (ranks 2, 4, ...) fill it from the back.

    Example for five docs ranked 1-5, output order by rank:
        1, 3, 5, 4, 2

    Inputs with two or fewer docs are returned as-is (same list object).
    """
    if len(reranked_docs) <= 2:
        return reranked_docs

    # Ranks 1, 3, 5, ... in order: these occupy the leading slots.
    leading = [doc for pos, doc in enumerate(reranked_docs) if pos % 2 == 0]
    # Ranks 2, 4, ... in order: reversed so rank 2 lands in the last slot.
    trailing = [doc for pos, doc in enumerate(reranked_docs) if pos % 2 == 1]

    return leading + trailing[::-1]


# Retrieval + reranking + position-aware ordering in one call
def rag_retrieve_and_order(
    query: str,
    query_embedding: list[float],
    retriever,
    cross_encoder: CrossEncoder,
    stage1_k: int = 50,
    final_k: int = 5,
) -> list[dict]:
    """
    Retrieve, rerank, and lay out documents for the LLM context.

    Runs two_stage_retrieval (stage1_k candidates, final_k kept) and
    passes its output through optimal_document_ordering.
    """
    top_docs = two_stage_retrieval(
        query=query,
        query_embedding=query_embedding,
        retriever=retriever,
        cross_encoder=cross_encoder,
        stage1_k=stage1_k,
        stage2_k=final_k,
    )
    return optimal_document_ordering(top_docs)

Evaluating Reranking Quality

Python
def evaluate_reranker(
    test_cases: list[dict],
    retriever,
    reranker,
    embedding_fn,
) -> dict:
    """
    Compare retrieval with and without reranking.

    test_cases: dicts with 'query' and 'relevant_doc_ids' keys
    retriever: object with retrieve(embedding, top_k=...) returning dicts
        that have an 'id' key
    reranker: object with rerank(query, candidates, top_k=...) returning
        dicts that have an 'id' key
    embedding_fn: callable mapping a query string to its embedding

    Returns a dict with the mean recall@5 of raw retrieval
    ('recall_at_5_raw'), of retrieve-50-then-rerank-to-5
    ('recall_at_5_reranked'), and their difference ('improvement').
    Cases with no relevant ids are skipped because recall is undefined for
    them; if no case can be scored, all three values are 0.0.
    """
    raw_recalls = []
    reranked_recalls = []

    for case in test_cases:
        relevant_ids = set(case["relevant_doc_ids"])
        if not relevant_ids:
            # Recall would divide by zero; leave this case out of the average.
            continue

        query = case["query"]
        query_emb = embedding_fn(query)

        # Without reranking
        raw_results = retriever.retrieve(query_emb, top_k=5)
        raw_ids = {r["id"] for r in raw_results}
        raw_recalls.append(len(raw_ids & relevant_ids) / len(relevant_ids))

        # With reranking (stage1=50, stage2=5)
        candidates = retriever.retrieve(query_emb, top_k=50)
        reranked = reranker.rerank(query, candidates, top_k=5)
        reranked_ids = {r["id"] for r in reranked}
        reranked_recalls.append(len(reranked_ids & relevant_ids) / len(relevant_ids))

    scored_cases = len(raw_recalls)
    if scored_cases == 0:
        return {
            "recall_at_5_raw": 0.0,
            "recall_at_5_reranked": 0.0,
            "improvement": 0.0,
        }

    raw_total = sum(raw_recalls)
    reranked_total = sum(reranked_recalls)
    return {
        "recall_at_5_raw": raw_total / scored_cases,
        "recall_at_5_reranked": reranked_total / scored_cases,
        "improvement": (reranked_total - raw_total) / scored_cases,
    }

# Typical improvements from reranking:
# General domain: 5-15% recall improvement
# Technical domain: 10-20% (bigger gain when vocabulary is specialized)
# With HyDE: another 5-10% on top of reranking