Heaps for Top-K Retrieval
Min-heap and max-heap operations for AI systems: top-k most similar embeddings without sorting all results, heapq module, and complete implementations for vector search.
Why Heaps Beat Sorting for Top-K
In retrieval systems, you never want to sort all n results just to return the top k. If you have 1 million documents and want top 10, sorting gives you O(n log n) — you've done a lot of work for results you'll discard. A heap gives you O(n log k): you process every document exactly once, maintaining only k candidates.
For k=10 and n=1,000,000: log2(n) ≈ 20 and log2(k) ≈ 3.3. Sorting spends about 20 comparisons per element, while each heap operation costs at most about 3.3, so the heap does roughly 6x less work per element. It also only ever keeps k elements, not n.
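To make the per-element argument concrete, here is a small sketch that counts `<` comparisons for both approaches on random scores. `Counted` and `heap_top_k` are illustrative names, not part of heapq. The sketch relies on the documented fact that both heapq and `sorted()` compare items only with `<`. On random input, most elements are rejected by a single comparison against the current k-th best, so the heap's count usually lands well below the n log k worst-case bound.

```python
import heapq
import random


class Counted:
    """Wraps a float and counts every < comparison made on it (illustrative helper)."""

    comparisons = 0

    def __init__(self, value: float) -> None:
        self.value = value

    def __lt__(self, other: "Counted") -> bool:
        Counted.comparisons += 1
        return self.value < other.value


def heap_top_k(items: list[Counted], k: int) -> list[Counted]:
    """Bounded min-heap of size k; returns the k largest, best first."""
    heap: list[Counted] = []
    for item in items:
        if len(heap) < k:
            heapq.heappush(heap, item)
        elif heap[0] < item:  # beats the current k-th best
            heapq.heapreplace(heap, item)
    return sorted(heap, reverse=True)


random.seed(0)
n, k = 100_000, 10
items = [Counted(random.random()) for _ in range(n)]

Counted.comparisons = 0
top = heap_top_k(items, k)
heap_cmps = Counted.comparisons

Counted.comparisons = 0
full = sorted(items, reverse=True)[:k]
sort_cmps = Counted.comparisons

print(f"heap: {heap_cmps:,} comparisons, full sort: {sort_cmps:,}")
print([c.value for c in top] == [c.value for c in full])  # True
```

Exact counts depend on the seed. What matters is the gap: roughly n comparisons for the heap versus roughly n log2 n for the full sort.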
This is the go-to pattern for any "top-k with a score" problem in AI interviews.
Heap Fundamentals
A heap is a complete binary tree in which every parent is less than or equal to its children (min-heap) or greater than or equal to them (max-heap).
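heapq does not build tree nodes. It stores the tree level by level in a plain list, so the children of index i sit at 2*i + 1 and 2*i + 2, and the parent of i sits at (i - 1) // 2. The checker below tests that invariant directly, which can help when debugging a hand-written heap. `is_min_heap` and `parent` are hypothetical helpers, not heapq functions.

```python
import heapq
import random


def is_min_heap(a: list) -> bool:
    """Return True if every parent is <= both of its children (heapq's invariant)."""
    n = len(a)
    for i in range(n):
        for child in (2 * i + 1, 2 * i + 2):  # children of index i
            if child < n and a[child] < a[i]:
                return False
    return True


def parent(i: int) -> int:
    """Index of the parent of node i (for i > 0)."""
    return (i - 1) // 2


random.seed(7)
data = [random.randint(0, 99) for _ in range(15)]
heapq.heapify(data)

print(is_min_heap(data))  # True: heapify establishes the invariant
print(all(data[parent(i)] <= data[i] for i in range(1, len(data))))  # True
print(is_min_heap([3, 1, 2]))  # False: parent 3 is larger than its child 1
```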
```python
import heapq

# Python's heapq is a MIN-HEAP by default
# The smallest element is always at index 0

# Push and pop: O(log n)
h = []
heapq.heappush(h, 5)
heapq.heappush(h, 1)
heapq.heappush(h, 3)
print(h[0])  # 1 (the minimum is always at the root)

smallest = heapq.heappop(h)  # removes and returns 1
print(smallest, h[0])  # 1 3

# Build a heap from an existing list: O(n) using heapify
data = [5, 1, 3, 2, 4]
heapq.heapify(data)  # in-place, O(n)
print(data[0])  # 1

# Combined push + pop in one call: more efficient than two separate calls
result = heapq.heappushpop(data, 0)  # pushes 0, then pops the minimum
print(result)  # 0 (smaller than every element, so it comes straight back)

result = heapq.heapreplace(data, 100)  # pops the minimum FIRST, then pushes 100
print(result)  # 1 (the old minimum; 100 stays in the heap)
```

Simulating a Max-Heap
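Before Python 3.14, heapq only exposed a min-heap publicly (3.14 adds functions such as `heapq.heappush_max` and `heapq.heappop_max`). The standard trick for a max-heap is to negate values.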
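Negation needs a numeric priority. When the key cannot be negated, for example if you want the alphabetically largest `doc_id` first, a common alternative is a small wrapper with an inverted `<`. `MaxItem` is a hypothetical name used for illustration, not a heapq class.

```python
import heapq


class MaxItem:
    """Wraps a key so that heapq's min-heap pops the LARGEST key first."""

    __slots__ = ("key", "payload")

    def __init__(self, key, payload=None):
        self.key = key
        self.payload = payload

    def __lt__(self, other: "MaxItem") -> bool:
        # Inverted comparison: a larger key is treated as "smaller",
        # so it rises to the root of the min-heap
        return self.key > other.key


heap: list[MaxItem] = []
for doc_id in ["doc_c", "doc_a", "doc_e", "doc_b"]:
    heapq.heappush(heap, MaxItem(doc_id, payload=f"text of {doc_id}"))

order = [heapq.heappop(heap).key for _ in range(len(heap))]
print(order)  # ['doc_e', 'doc_c', 'doc_b', 'doc_a']
```

The same wrapper also works for floats. For plain numbers, though, negation is simpler and avoids a Python-level `__lt__` call on every comparison.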
```python
# MAX-HEAP trick: store negated values
max_heap = []
scores = [0.82, 0.91, 0.74, 0.88, 0.65]
for score in scores:
    heapq.heappush(max_heap, -score)  # negate!

# The largest score is at the root: the most negative stored value
# corresponds to the largest original score
largest = -heapq.heappop(max_heap)
print(f"Largest score: {largest}")  # 0.91

# For (score, value) pairs, negate only the score
candidates = [
    (0.82, "doc_a"),
    (0.91, "doc_b"),
    (0.74, "doc_c"),
]
max_heap_pairs = []
for score, doc_id in candidates:
    heapq.heappush(max_heap_pairs, (-score, doc_id))

score, doc_id = heapq.heappop(max_heap_pairs)
print(f"Most similar: {doc_id} with score {-score:.2f}")  # doc_b with 0.91
```

Shortcuts: nlargest and nsmallest
For one-shot "give me top k" without maintaining a running heap:
```python
import heapq

scores = {
    "doc_a": 0.82,
    "doc_b": 0.91,
    "doc_c": 0.74,
    "doc_d": 0.88,
    "doc_e": 0.65,
    "doc_f": 0.95,
    "doc_g": 0.71,
}

# Top 3 by score value
top3 = heapq.nlargest(3, scores.items(), key=lambda x: x[1])
print("Top 3:", top3)
# [('doc_f', 0.95), ('doc_b', 0.91), ('doc_d', 0.88)]

# Bottom 2 by score (least similar)
bottom2 = heapq.nsmallest(2, scores.items(), key=lambda x: x[1])
print("Bottom 2:", bottom2)
# [('doc_e', 0.65), ('doc_g', 0.71)]

# When to use nlargest vs sorted:
# - k << n: nlargest is O(n log k), sorted is O(n log n), so use nlargest
# - k close to n: sort the full list; O(n log n) but with a smaller constant factor
# - k == 1: the built-in max() / min() is simpler and faster
# Rough rule of thumb (a heuristic, not a hard threshold):
# if k < n // 10, use nlargest/nsmallest
```

Complete Solution: Top-K Similar Documents Using Heap
This is the full implementation you'd write in an AI engineering interview.
```python
import heapq
import math
from dataclasses import dataclass


@dataclass
class Document:
    doc_id: str
    embedding: list[float]
    text: str


def dot_product(a: list[float], b: list[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def magnitude(v: list[float]) -> float:
    return math.sqrt(sum(x * x for x in v))


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """
    Cosine similarity between two vectors.
    Returns value in [-1, 1]. Higher = more similar.
    Returns 0.0 for zero vectors to avoid division by zero.
    """
    mag_a = magnitude(a)
    mag_b = magnitude(b)
    if mag_a == 0.0 or mag_b == 0.0:
        return 0.0
    return dot_product(a, b) / (mag_a * mag_b)


def top_k_similar_documents(
    query_embedding: list[float],
    documents: list[Document],
    k: int,
) -> list[tuple[float, Document]]:
    """
    Find the k documents most similar to the query embedding.

    Algorithm: min-heap of size k
    - Maintain a min-heap of the k best (score, doc) pairs seen so far
    - For each document, compute similarity
    - If the heap has fewer than k elements, push
    - If similarity > heap minimum, replace the minimum with the new doc
    - At the end, pop all k elements (they come out in ascending order, so reverse)

    Time: O(n log k) for selection, on top of O(n * d) to score n docs of dimension d
    Space: O(k), since the heap only ever holds k elements

    vs. sort-all approach:
    Time: O(n log n) for selection (same O(n * d) scoring cost)
    Space: O(n)

    For k=10, n=100,000: roughly 5x fewer comparisons in the selection step
    (log2(10) ≈ 3.3 vs log2(100,000) ≈ 17). In pure Python the O(n * d)
    similarity computation usually dominates total runtime.
    """
    if k <= 0:
        return []

    # Min-heap entries: (score, tie_breaker, document)
    # The unique tie_breaker integer means Python never compares Document objects
    # when two scores are equal (a plain @dataclass does not support <)
    heap: list[tuple[float, int, Document]] = []

    for idx, doc in enumerate(documents):
        score = cosine_similarity(query_embedding, doc.embedding)
        if len(heap) < k:
            # Heap not full yet: always push
            heapq.heappush(heap, (score, idx, doc))
        elif score > heap[0][0]:
            # New score beats the current minimum: replace it
            heapq.heapreplace(heap, (score, idx, doc))
        # Otherwise the score is not in the top k, so discard it

    # Extract results in descending order (most similar first)
    results = []
    while heap:
        score, _, doc = heapq.heappop(heap)
        results.append((score, doc))
    results.reverse()  # heappop yields ascending order; we want descending
    return results


# Demonstration
import random

random.seed(42)


def random_embedding(dim: int = 8) -> list[float]:
    v = [random.gauss(0, 1) for _ in range(dim)]
    mag = magnitude(v)
    return [x / mag for x in v]  # normalize to unit length


# Create 20 test documents with 8-dim embeddings
docs = [
    Document(
        doc_id=f"doc_{i:02d}",
        embedding=random_embedding(8),
        text=f"This is document number {i}.",
    )
    for i in range(20)
]

# Query: a random vector
query = random_embedding(8)

# Get top 5
top5 = top_k_similar_documents(query, docs, k=5)
print("Top 5 similar documents:")
for rank, (score, doc) in enumerate(top5, 1):
    print(f"  #{rank}: {doc.doc_id} — similarity = {score:.4f}")

# Verify: compare against brute-force sort
all_scored = sorted(
    [(cosine_similarity(query, d.embedding), d) for d in docs],
    key=lambda x: -x[0],
)
print("\nGround truth (brute-force sort):")
for rank, (score, doc) in enumerate(all_scored[:5], 1):
    print(f"  #{rank}: {doc.doc_id} — similarity = {score:.4f}")

heap_ids = [doc.doc_id for _, doc in top5]
sort_ids = [doc.doc_id for _, doc in all_scored[:5]]
assert heap_ids == sort_ids, f"Mismatch: {heap_ids} vs {sort_ids}"
print("\nResults match! Heap and brute-force agree on top 5.")
```

Handling Streaming Results with a Heap
In production, you might not have all documents in memory at once. A heap is the right structure for streaming top-k.
```python
from collections.abc import Iterator


def stream_documents(n: int, dim: int = 8) -> Iterator[Document]:
    """Simulate a streaming document iterator (e.g., from a database cursor)."""
    for i in range(n):
        yield Document(
            doc_id=f"stream_doc_{i:04d}",
            embedding=random_embedding(dim),
            text=f"Streamed document {i}.",
        )


def top_k_streaming(
    query_embedding: list[float],
    document_stream: Iterator[Document],
    k: int,
) -> list[tuple[float, Document]]:
    """
    Top-k from a streaming source; never loads all documents into memory.

    This works because the heap only needs to hold k elements at a time,
    regardless of how many documents flow through the stream.

    Time: O(n log k)
    Space: O(k); at most k documents (plus the one being scored) in memory
    """
    if k <= 0:
        return []  # guard: with k == 0, heap[0] below would raise IndexError

    heap: list[tuple[float, int, Document]] = []
    for idx, doc in enumerate(document_stream):
        score = cosine_similarity(query_embedding, doc.embedding)
        if len(heap) < k:
            heapq.heappush(heap, (score, idx, doc))
        elif score > heap[0][0]:
            heapq.heapreplace(heap, (score, idx, doc))

    results = sorted(heap, key=lambda x: -x[0])
    return [(score, doc) for score, _, doc in results]


# Test streaming
query = random_embedding(8)
stream = stream_documents(n=10_000, dim=8)
top10 = top_k_streaming(query, stream, k=10)
print("\nTop 10 from stream of 10,000 documents:")
for rank, (score, doc) in enumerate(top10, 1):
    print(f"  #{rank}: {doc.doc_id} — score {score:.4f}")
```

Merge K Sorted Lists (Bonus Heap Problem)
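This pattern appears in RAG systems when you retrieve from multiple vector indices. Merging by raw score only makes sense when the indices produce comparable scores, for example when they use the same embedding model and the same similarity metric.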
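The standard library already ships a lazy k-way merge, `heapq.merge(*iterables, key=None, reverse=False)`. With `reverse=True` it expects every input to be sorted in descending order, which matches ranked retrieval results. A minimal sketch, assuming each index returns `(score, doc_id)` tuples already sorted by score, descending:

```python
import heapq
from itertools import islice

results_a = [(0.95, "doc_a1"), (0.80, "doc_a2"), (0.70, "doc_a3")]
results_b = [(0.92, "doc_b1"), (0.85, "doc_b2"), (0.60, "doc_b3")]
results_c = [(0.88, "doc_c1"), (0.75, "doc_c2"), (0.55, "doc_c3")]

# reverse=True: inputs are sorted descending, and so is the output.
# key=... compares on score only; ties fall back to input order, not doc_id.
merged_iter = heapq.merge(
    results_a, results_b, results_c, key=lambda pair: pair[0], reverse=True
)

# merge is lazy: take just the global top 4 without merging everything
top4 = list(islice(merged_iter, 4))
print(top4)
# [(0.95, 'doc_a1'), (0.92, 'doc_b1'), (0.88, 'doc_c1'), (0.85, 'doc_b2')]
```

Because `heapq.merge` returns an iterator, you can stop after the first k merged results without materializing the rest. The hand-written version that follows shows the same mechanics explicitly.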
```python
def merge_k_sorted_result_lists(
    result_lists: list[list[tuple[float, str]]],
) -> list[tuple[float, str]]:
    """
    Merge k sorted lists of (score, doc_id) into one sorted list (descending).
    Each input list must already be sorted by score in descending order.
    Application: you query k different vector indices and want to merge results.
    Algorithm: use a max-heap initialized with the first element from each list.
    Each pop gives the global maximum; then push the next element from that list.
    Time: O(n log k) where n = total elements across all lists
    Space: O(k) for the heap (plus O(n) for the merged output)
    """
    # Min-heap with negated scores for max-heap behavior
    # Elements: (-score, list_index, element_index, doc_id)
    heap = []

    # Initialize with the first element from each non-empty list
    for list_idx, result_list in enumerate(result_lists):
        if result_list:
            score, doc_id = result_list[0]
            heapq.heappush(heap, (-score, list_idx, 0, doc_id))

    merged = []
    while heap:
        neg_score, list_idx, elem_idx, doc_id = heapq.heappop(heap)
        merged.append((-neg_score, doc_id))
        # Push the next element from the same list
        next_elem_idx = elem_idx + 1
        if next_elem_idx < len(result_lists[list_idx]):
            next_score, next_doc_id = result_lists[list_idx][next_elem_idx]
            heapq.heappush(heap, (-next_score, list_idx, next_elem_idx, next_doc_id))
    return merged


# Test: merge results from 3 vector indices
list_a = [(0.95, "doc_a1"), (0.80, "doc_a2"), (0.70, "doc_a3")]
list_b = [(0.92, "doc_b1"), (0.85, "doc_b2"), (0.60, "doc_b3")]
list_c = [(0.88, "doc_c1"), (0.75, "doc_c2"), (0.55, "doc_c3")]
merged = merge_k_sorted_result_lists([list_a, list_b, list_c])
print("\nMerged results from 3 indices:")
for rank, (score, doc_id) in enumerate(merged, 1):
    print(f"  #{rank}: {doc_id} — {score:.2f}")
```

Complexity Summary
| Operation | Time (binary min-heap) | Notes |
|-----------|------------------------|-------|
| Push | O(log n) | Sift up |
| Pop minimum | O(log n) | Sift down |
| Peek minimum | O(1) | `heap[0]` |
| Build from list | O(n) | `heapify` |
| Top-k from n | O(n log k) | Maintain a heap of size k; O(k) space |
| Merge k lists (n total) | O(n log k) | Heap holds one head element per list |
Interview Answer Template
When given a "find top k" problem, say:
"I'll use a min-heap of size k. For each element, if the heap isn't full yet, I push it. If it's full and the new element beats the heap minimum, I replace the minimum. This is O(n log k) time and O(k) space, which beats sorting all n elements when k is much smaller than n. After processing all elements, I pop the heap to get results in ascending order and reverse them."
Then write the code. This response shows you know the pattern, can articulate the complexity, and understand when it's the right choice.
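As a compact version of that answer, here is a generic helper that accepts any iterable and a `key` function. `top_k` is a hypothetical name for illustration. The tie-breaking is chosen so that, among equal keys, earlier items rank higher, which should match `heapq.nlargest(k, items, key=key)`; in real code, `nlargest` is the ready-made option.

```python
import heapq
from collections.abc import Callable, Iterable
from typing import Any, TypeVar

T = TypeVar("T")


def top_k(items: Iterable[T], k: int, key: Callable[[T], Any]) -> list[T]:
    """Return the k items with the largest key, best first. O(n log k) time, O(k) space."""
    if k <= 0:
        return []
    heap: list[tuple[Any, int, T]] = []  # min-heap of (key, -arrival_index, item)
    for idx, item in enumerate(items):
        entry = (key(item), -idx, item)  # -idx: on ties, earlier items rank higher
        if len(heap) < k:
            heapq.heappush(heap, entry)  # heap not full: always push
        elif entry > heap[0]:
            heapq.heapreplace(heap, entry)  # beats the current minimum
    # Largest first; (key, -idx) is unique, so `item` itself is never compared
    return [item for _, _, item in sorted(heap, reverse=True)]


scores = {"doc_a": 0.82, "doc_b": 0.91, "doc_c": 0.74, "doc_d": 0.88, "doc_f": 0.95}
print(top_k(scores.items(), 3, key=lambda kv: kv[1]))
# [('doc_f', 0.95), ('doc_b', 0.91), ('doc_d', 0.88)]
```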