Learnixo

Live Coding Interview Prep · Lesson 4 of 16

Heaps for Top-K Retrieval Problems

Why Heaps Beat Sorting for Top-K

In retrieval systems, you never want to sort all n results just to return the top k. If you have 1 million documents and want top 10, sorting gives you O(n log n) — you've done a lot of work for results you'll discard. A heap gives you O(n log k): you process every document exactly once, maintaining only k candidates.

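For k=10 and n=1,000,000: log2(n) ≈ 20 and log2(k) ≈ 3.3. Each heap operation therefore costs roughly 6x less than a full sort's per-element cost (20 / 3.3 ≈ 6), and you only ever hold k elements in memory, not n. In practice the gap is often larger: once the heap is full, most documents are rejected after a single comparison with the heap's minimum.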
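To sanity-check the log(n) vs. log(k) comparison, here is a rough comparison-count model. It assumes n·log2(n) comparisons for a full sort and n·log2(k) for a size-k heap; both are upper-bound estimates that ignore constant factors, not measured timings.

```python
import math


def approx_comparisons(n: int, k: int) -> tuple[float, float]:
    """Rough comparison-count models: full sort vs. size-k heap (constants ignored)."""
    sort_cost = n * math.log2(n)  # O(n log n): sort every element
    heap_cost = n * math.log2(k)  # O(n log k): at most one O(log k) heap op per element
    return sort_cost, heap_cost


sort_cost, heap_cost = approx_comparisons(n=1_000_000, k=10)
# The ratio is log2(n) / log2(k), about 6 for these values
print(f"sort ~{sort_cost:,.0f}, heap ~{heap_cost:,.0f}, ratio ~{sort_cost / heap_cost:.1f}x")
```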

This is the go-to pattern for any "top-k with a score" problem in AI interviews.

Heap Fundamentals

A heap is a complete binary tree where every parent is less than or equal to its children (min-heap) or greater than or equal to them (max-heap).
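In practice a heap is stored as a flat list rather than linked nodes: the element at index i has its children at 2*i + 1 and 2*i + 2 and its parent at (i - 1) // 2, which is the layout heapq uses. The helper below, `is_min_heap`, is a hypothetical checker (not part of heapq) that verifies the min-heap property on such a list.

```python
import heapq


def is_min_heap(a: list) -> bool:
    """Return True if every parent is <= its children (array-backed min-heap)."""
    n = len(a)
    for i in range(n):
        left, right = 2 * i + 1, 2 * i + 2  # child indices of node i
        if left < n and a[i] > a[left]:
            return False
        if right < n and a[i] > a[right]:
            return False
    return True


data = [5, 1, 3, 2, 4]
print(is_min_heap(data))  # False: 5 sits above smaller children
heapq.heapify(data)
print(is_min_heap(data))  # True after heapify
```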

Python
import heapq

# Python's heapq is a MIN-HEAP by default
# The smallest element is always at index 0

# Push and pop — O(log n)
h = []
heapq.heappush(h, 5)
heapq.heappush(h, 1)
heapq.heappush(h, 3)
print(h[0])  # 1 — minimum is always at root

smallest = heapq.heappop(h)  # removes and returns 1
print(smallest, h[0])  # 1, 3

# Build heap from list — O(n) using heapify
data = [5, 1, 3, 2, 4]
heapq.heapify(data)  # in-place, O(n)
print(data[0])  # 1

# Push then pop in a single call: faster than heappush followed by heappop
result = heapq.heappushpop(data, 0)  # pushes 0, pops minimum
print(result)  # 0 (smaller than the current min, so the heap is unchanged)

result = heapq.heapreplace(data, 100)  # pops minimum first, then pushes 100
print(result)  # 1 (the previous minimum)
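
To see where the O(log n) bounds come from, here is a minimal hand-rolled min-heap, a sketch for interview practice rather than heapq's actual source. `heap_push` and `heap_pop` are hypothetical names. Push appends at the end and sifts up toward the root; pop moves the last element to the root and sifts down. Each walk follows a single path between the root and a leaf, and a complete binary tree with n nodes has about log2(n) levels.

```python
def heap_push(heap: list, item) -> None:
    """Append item, then sift it up until its parent is <= it."""
    heap.append(item)
    i = len(heap) - 1
    while i > 0:
        parent = (i - 1) // 2
        if heap[parent] <= heap[i]:
            break
        heap[parent], heap[i] = heap[i], heap[parent]
        i = parent


def heap_pop(heap: list):
    """Remove and return the minimum; raises IndexError on an empty heap."""
    last = heap.pop()  # IndexError if empty, like heapq.heappop
    if not heap:
        return last  # that was the only element
    smallest, heap[0] = heap[0], last  # move the last leaf to the root
    i, n = 0, len(heap)
    while True:
        left, right = 2 * i + 1, 2 * i + 2
        child = i
        if left < n and heap[left] < heap[child]:
            child = left
        if right < n and heap[right] < heap[child]:
            child = right
        if child == i:
            break  # both children are >= heap[i]
        heap[i], heap[child] = heap[child], heap[i]
        i = child
    return smallest


h = []
for x in [5, 1, 3, 2, 4]:
    heap_push(h, x)
print([heap_pop(h) for _ in range(5)])  # [1, 2, 3, 4, 5]
```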

Simulating a Max-Heap

Python's heapq module is built around a min-heap. The standard trick for a max-heap is to negate the (numeric) values.

Python
# MAX-HEAP trick: store negated values
max_heap = []
scores = [0.82, 0.91, 0.74, 0.88, 0.65]

for score in scores:
    heapq.heappush(max_heap, -score)  # negate!

# Largest score is at the root (most negative = most positive original)
largest = -heapq.heappop(max_heap)
print(f"Largest score: {largest}")  # 0.91

# For (score, value) pairs, negate only the score
candidates = [
    (0.82, "doc_a"),
    (0.91, "doc_b"),
    (0.74, "doc_c"),
]

max_heap_pairs = []
for score, doc_id in candidates:
    heapq.heappush(max_heap_pairs, (-score, doc_id))

score, doc_id = heapq.heappop(max_heap_pairs)
print(f"Most similar: {doc_id} with score {-score:.2f}")  # doc_b with 0.91
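
Negation only works for numeric keys. When the priority is something you cannot negate, such as a string, one common alternative is a small wrapper whose `<` comparison is reversed. `MaxItem` below is a hypothetical helper, not part of heapq; it relies on the fact that heapq orders items using only `<`.

```python
import heapq
from dataclasses import dataclass


@dataclass
class MaxItem:
    """Wraps a value so that heapq's min-heap pops the LARGEST value first."""
    value: str

    def __lt__(self, other: "MaxItem") -> bool:
        return self.value > other.value  # reversed comparison


names = ["doc_c", "doc_a", "doc_e", "doc_b"]
heap = [MaxItem(n) for n in names]
heapq.heapify(heap)
print(heapq.heappop(heap).value)  # doc_e (largest string first)
```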

Shortcuts: nlargest and nsmallest

For one-shot "give me top k" without maintaining a running heap:

Python
import heapq

scores = {
    "doc_a": 0.82,
    "doc_b": 0.91,
    "doc_c": 0.74,
    "doc_d": 0.88,
    "doc_e": 0.65,
    "doc_f": 0.95,
    "doc_g": 0.71,
}

# Top 3 by score value
top3 = heapq.nlargest(3, scores.items(), key=lambda x: x[1])
print("Top 3:", top3)
# [('doc_f', 0.95), ('doc_b', 0.91), ('doc_d', 0.88)]

# Bottom 2 by score (least similar)
bottom2 = heapq.nsmallest(2, scores.items(), key=lambda x: x[1])
print("Bottom 2:", bottom2)

# When to use nlargest vs sorted:
# - k much smaller than n: nlargest is O(n log k), sorted is O(n log n), so use nlargest
# - k close to n: just sort the full list; sorted() is O(n log n) but with a smaller constant
# - k == 1: use max() / min(), a single O(n) pass
# Rough heuristic (benchmark on your own data): if k < n // 10, use nlargest/nsmallest
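
The heapq documentation describes `nlargest(k, iterable, key=...)` as equivalent to `sorted(iterable, key=key, reverse=True)[:k]` and recommends `max()`/`min()` when k is 1. A quick check on the same scores (redefined here so the snippet runs on its own):

```python
import heapq
from operator import itemgetter

scores = {
    "doc_a": 0.82, "doc_b": 0.91, "doc_c": 0.74, "doc_d": 0.88,
    "doc_e": 0.65, "doc_f": 0.95, "doc_g": 0.71,
}
by_score = itemgetter(1)  # key: the score in each (doc_id, score) pair

for k in range(1, len(scores) + 1):
    via_heap = heapq.nlargest(k, scores.items(), key=by_score)
    via_sort = sorted(scores.items(), key=by_score, reverse=True)[:k]
    assert via_heap == via_sort  # same items in the same order for every k

# k == 1: max() is a single O(n) pass with no heap at all
print(max(scores.items(), key=by_score))  # ('doc_f', 0.95)
```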

Complete Solution: Top-K Similar Documents Using Heap

This is the full implementation you'd write in an AI engineering interview.

Python
import heapq
import math
from dataclasses import dataclass

@dataclass
class Document:
    doc_id: str
    embedding: list[float]
    text: str


def dot_product(a: list[float], b: list[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def magnitude(v: list[float]) -> float:
    return math.sqrt(sum(x * x for x in v))


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """
    Cosine similarity between two vectors.
    Returns value in [-1, 1]. Higher = more similar.
    Returns 0.0 for zero vectors to avoid division by zero.
    """
    mag_a = magnitude(a)
    mag_b = magnitude(b)
    if mag_a == 0.0 or mag_b == 0.0:
        return 0.0
    return dot_product(a, b) / (mag_a * mag_b)


def top_k_similar_documents(
    query_embedding: list[float],
    documents: list[Document],
    k: int,
) -> list[tuple[float, Document]]:
    """
    Find the k documents most similar to the query embedding.
    
    Algorithm: min-heap of size k
    - We maintain a min-heap of the k best (score, doc) pairs seen so far
    - For each document, compute similarity
    - If heap has fewer than k elements, push
    - If similarity > heap minimum, pop minimum and push new doc
    - At the end, pop all k elements (they come out in ascending order, so reverse)
    
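    Time: O(n*d + n log k), where d is the embedding dimension: each
          similarity costs O(d) and each heap op costs O(log k)
    Space: O(k), since the heap only ever holds k elements
    
    vs. sort-all approach:
    Time: O(n*d + n log n)
    Space: O(n)
    
    For k=10, n=100,000 the ranking step needs roughly 5x fewer comparisons
    (log2(100,000) / log2(10) is about 16.6 / 3.3). The O(n*d) similarity
    work is identical in both approaches and often dominates wall-clock time.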
    """
    if k <= 0:
        return []
    
    # Min-heap: (score, tie_breaker, document)
    # We use a tie_breaker integer to avoid comparing Document objects
    heap: list[tuple[float, int, Document]] = []
    
    for idx, doc in enumerate(documents):
        score = cosine_similarity(query_embedding, doc.embedding)
        
        if len(heap) < k:
            # Heap not full yet: always push
            heapq.heappush(heap, (score, idx, doc))
        elif score > heap[0][0]:
            # New score beats the current minimum: replace it
            heapq.heapreplace(heap, (score, idx, doc))
        # Otherwise: score is not in top k, discard
    
    # Extract results in descending order (most similar first)
    results = []
    while heap:
        score, _, doc = heapq.heappop(heap)
        results.append((score, doc))
    
    results.reverse()  # heappop gives ascending order, we want descending
    return results


# Demonstration
import random
random.seed(42)

def random_embedding(dim: int = 8) -> list[float]:
    v = [random.gauss(0, 1) for _ in range(dim)]
    mag = magnitude(v)
    return [x / mag for x in v]  # normalize to unit sphere

# Create 20 test documents with 8-dim embeddings
docs = [
    Document(
        doc_id=f"doc_{i:02d}",
        embedding=random_embedding(8),
        text=f"This is document number {i}.",
    )
    for i in range(20)
]

# Query: a random vector
query = random_embedding(8)

# Get top 5
top5 = top_k_similar_documents(query, docs, k=5)
print("Top 5 similar documents:")
for rank, (score, doc) in enumerate(top5, 1):
    print(f"  #{rank}: {doc.doc_id} — similarity = {score:.4f}")

# Verify: compare against brute-force sort
all_scored = sorted(
    [(cosine_similarity(query, d.embedding), d) for d in docs],
    key=lambda x: -x[0],
)
print("\nGround truth (brute-force sort):")
for rank, (score, doc) in enumerate(all_scored[:5], 1):
    print(f"  #{rank}: {doc.doc_id} — similarity = {score:.4f}")

heap_ids = [doc.doc_id for _, doc in top5]
sort_ids = [doc.doc_id for _, doc in all_scored[:5]]
assert heap_ids == sort_ids, f"Mismatch: {heap_ids} vs {sort_ids}"
print("\nResults match! Heap and brute-force agree on top 5.")
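
Two refinements are worth mentioning in an interview. First, the query's magnitude is the same for every document, so it can be computed once instead of n times. Second, `heapq.nlargest` with a `key` implements the same bounded-heap idea (CPython keeps a heap of size k for general inputs), so the loop can collapse to one call. This is a sketch: `top_k_nlargest` is a hypothetical name, and `Document` is redefined so the snippet runs standalone.

```python
import heapq
import math
from dataclasses import dataclass


@dataclass
class Document:
    doc_id: str
    embedding: list[float]
    text: str


def top_k_nlargest(
    query: list[float],
    documents: list[Document],
    k: int,
) -> list[tuple[float, Document]]:
    """Top-k documents by cosine similarity via heapq.nlargest (most similar first)."""
    q_norm = math.sqrt(sum(x * x for x in query))  # computed once, not per document

    def score(doc: Document) -> float:
        d_norm = math.sqrt(sum(x * x for x in doc.embedding))
        if q_norm == 0.0 or d_norm == 0.0:
            return 0.0  # same zero-vector convention as cosine_similarity
        return sum(a * b for a, b in zip(query, doc.embedding)) / (q_norm * d_norm)

    # Lazy (score, doc) pairs; key=... means only scores are compared, never Documents
    scored = ((score(doc), doc) for doc in documents)
    return heapq.nlargest(k, scored, key=lambda pair: pair[0])


docs = [
    Document("a", [1.0, 0.0], "x-axis"),
    Document("b", [0.0, 1.0], "y-axis"),
    Document("c", [1.0, 1.0], "diagonal"),
]
for s, d in top_k_nlargest([1.0, 0.0], docs, k=2):
    print(d.doc_id, round(s, 3))
# a 1.0
# c 0.707
```

Both versions are O(n*d + n log k); the explicit loop is often what interviewers want to see because it makes the heap visible.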

Handling Streaming Results with a Heap

In production, you might not have all documents in memory at once. A heap is the right structure for streaming top-k.

Python
from collections.abc import Iterator

def stream_documents(n: int, dim: int = 8) -> Iterator[Document]:
    """Simulate a streaming document iterator (e.g., from a database cursor)."""
    for i in range(n):
        yield Document(
            doc_id=f"stream_doc_{i:04d}",
            embedding=random_embedding(dim),
            text=f"Streamed document {i}.",
        )


def top_k_streaming(
    query_embedding: list[float],
    document_stream: Iterator[Document],
    k: int,
) -> list[tuple[float, Document]]:
    """
    Top-k from a streaming source — never loads all documents into memory.
    
    This works because heaps only need to hold k elements at a time,
    regardless of how many documents flow through the stream.
    
    Time: O(n log k)
    Space: O(k) — only k documents in memory at any point
    """
    heap: list[tuple[float, int, Document]] = []
    
    for idx, doc in enumerate(document_stream):
        score = cosine_similarity(query_embedding, doc.embedding)
        
        if len(heap) < k:
            heapq.heappush(heap, (score, idx, doc))
        elif score > heap[0][0]:
            heapq.heapreplace(heap, (score, idx, doc))
    
    results = sorted(heap, key=lambda x: -x[0])
    return [(score, doc) for score, _, doc in results]


# Test streaming
query = random_embedding(8)
stream = stream_documents(n=10_000, dim=8)
top10 = top_k_streaming(query, stream, k=10)
print("\nTop 10 from stream of 10,000 documents:")
for rank, (score, doc) in enumerate(top10, 1):
    print(f"  #{rank}: {doc.doc_id} — score {score:.4f}")
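
When results arrive in batches across separate calls (for example, pages from a paginated API), it can be convenient to wrap the same size-k heap in a small class that you feed incrementally. `TopK` below is a hypothetical helper, not a library class; it stores (score, seq, item) so the items themselves are never compared.

```python
import heapq
import itertools
from typing import Generic, TypeVar

T = TypeVar("T")


class TopK(Generic[T]):
    """Keeps the k highest-scoring items seen so far, in O(k) memory."""

    def __init__(self, k: int) -> None:
        self.k = k
        self._heap: list[tuple[float, int, T]] = []  # min-heap on score
        self._seq = itertools.count()  # tie-breaker so items are never compared

    def add(self, score: float, item: T) -> None:
        if self.k <= 0:
            return
        entry = (score, next(self._seq), item)
        if len(self._heap) < self.k:
            heapq.heappush(self._heap, entry)
        elif score > self._heap[0][0]:
            heapq.heapreplace(self._heap, entry)

    def results(self) -> list[tuple[float, T]]:
        """Best first; does not consume the accumulator."""
        return [(s, item) for s, _, item in sorted(self._heap, reverse=True)]


top: TopK[str] = TopK(3)
for page in ([(0.4, "a"), (0.9, "b")], [(0.7, "c"), (0.1, "d")], [(0.8, "e")]):
    for score, doc_id in page:
        top.add(score, doc_id)
print(top.results())  # [(0.9, 'b'), (0.8, 'e'), (0.7, 'c')]
```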

Merge K Sorted Lists (Bonus Heap Problem)

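This pattern appears in RAG systems when you retrieve from multiple vector indices. It assumes the scores from every index are on the same scale (for example, all cosine similarities from the same embedding model); otherwise, normalize them before merging.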

Python
def merge_k_sorted_result_lists(
    result_lists: list[list[tuple[float, str]]],
) -> list[tuple[float, str]]:
    """
    Merge k sorted lists of (score, doc_id) into one sorted list (descending).
    
    Application: you query k different vector indices and want to merge results.
    
    Algorithm: use a max-heap initialized with the first element from each list.
    Each pop gives the global maximum; then push the next element from that list.
    
    Time: O(n log k), where n = total elements across all lists and k = number of lists
    Space: O(k) for the heap, plus O(n) for the merged output
    """
    # Min-heap with negated scores for max-heap behavior
    # Elements: (-score, list_index, element_index, doc_id)
    heap = []
    
    # Initialize with first element from each non-empty list
    for list_idx, result_list in enumerate(result_lists):
        if result_list:
            score, doc_id = result_list[0]
            heapq.heappush(heap, (-score, list_idx, 0, doc_id))
    
    merged = []
    
    while heap:
        neg_score, list_idx, elem_idx, doc_id = heapq.heappop(heap)
        merged.append((-neg_score, doc_id))
        
        # Push the next element from the same list
        next_elem_idx = elem_idx + 1
        if next_elem_idx < len(result_lists[list_idx]):
            next_score, next_doc_id = result_lists[list_idx][next_elem_idx]
            heapq.heappush(heap, (-next_score, list_idx, next_elem_idx, next_doc_id))
    
    return merged


# Test: merge results from 3 vector indices
list_a = [(0.95, "doc_a1"), (0.80, "doc_a2"), (0.70, "doc_a3")]
list_b = [(0.92, "doc_b1"), (0.85, "doc_b2"), (0.60, "doc_b3")]
list_c = [(0.88, "doc_c1"), (0.75, "doc_c2"), (0.55, "doc_c3")]

merged = merge_k_sorted_result_lists([list_a, list_b, list_c])
print("\nMerged results from 3 indices:")
for rank, (score, doc_id) in enumerate(merged, 1):
    print(f"  #{rank}: {doc_id} — {score:.2f}")
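
The standard library also provides this merge: `heapq.merge(*iterables, key=None, reverse=False)` lazily merges already-sorted inputs, and with `reverse=True` it expects each input sorted in descending order, which matches these result lists. Because it is lazy, `itertools.islice` can stop after the global top k without merging everything. The lists are redefined so the snippet runs on its own.

```python
import heapq
from itertools import islice
from operator import itemgetter

list_a = [(0.95, "doc_a1"), (0.80, "doc_a2"), (0.70, "doc_a3")]
list_b = [(0.92, "doc_b1"), (0.85, "doc_b2"), (0.60, "doc_b3")]
list_c = [(0.88, "doc_c1"), (0.75, "doc_c2"), (0.55, "doc_c3")]

# Inputs are sorted descending by score, so ask merge for descending output
merged = heapq.merge(list_a, list_b, list_c, key=itemgetter(0), reverse=True)

# Lazy: only as much merging as needed to produce 4 results
top4 = list(islice(merged, 4))
print(top4)
# [(0.95, 'doc_a1'), (0.92, 'doc_b1'), (0.88, 'doc_c1'), (0.85, 'doc_b2')]
```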

Complexity Summary

| Operation | Min-Heap | Notes |
|-----------|----------|-------|
| Push | O(log n) | Sift up |
| Pop minimum | O(log n) | Sift down |
| Peek minimum | O(1) | heap[0] |
| Build from list | O(n) | heapify |
| Top-k from n | O(n log k) | Maintain heap of size k |
| Merge k lists (n total) | O(n log k) | Heap holds one head element per list |

Interview Answer Template

When given a "find top k" problem, say:

"I'll use a min-heap of size k. For each element, if the heap isn't full yet, I push it. If it's full and the new element beats the heap minimum, I replace the minimum. This is O(n log k) time and O(k) space, which beats sorting all n elements when k is much smaller than n. After processing all elements, I pop the heap to get results in ascending order and reverse them."

Then write the code. This response shows you know the pattern, can articulate the complexity, and understand when it's the right choice.