Recursion Interview Problems
Recursion problems common in AI engineering interviews: tree traversal, memoized Fibonacci, power sets, merge sort, JSON traversal, and recursive RAG tree summarization.
Recursion Fundamentals
Every recursive function needs:
- Base case — the condition that stops recursion
- Recursive case — a call with a smaller/simpler input that moves toward the base case
def factorial(n: int) -> int:
    """Return n! = n * (n-1) * ... * 1 for a non-negative integer n.

    Args:
        n: The value whose factorial is computed; must be >= 0.

    Returns:
        The factorial of n (0! and 1! are both 1).

    Raises:
        ValueError: If n is negative, where the factorial is undefined.
    """
    if n < 0:
        # Without this guard the base case below would silently return 1.
        raise ValueError("factorial() is not defined for negative values")
    if n <= 1:  # Base case
        return 1
    return n * factorial(n - 1)  # Recursive case
print(factorial(5)) # 120
# Call stack visualization:
# factorial(5) → 5 * factorial(4)
# factorial(4) → 4 * factorial(3)
# factorial(3) → 3 * factorial(2)
# factorial(2) → 2 * factorial(1)
# factorial(1) → 1 ← base case, unwinds here
Problem 1: Fibonacci with Memoization
Problem: Compute the nth Fibonacci number. Naive recursion is O(2^n) — optimize with memoization.
from functools import lru_cache
# Naive: O(2^n) — redundant computation
def fib_naive(n: int) -> int:
    """Return the nth Fibonacci number by plain double recursion.

    Every call above the base case spawns two more calls, so the same
    subproblems are recomputed many times (exponential running time).
    Values of n <= 1 are returned unchanged.
    """
    if n > 1:
        return fib_naive(n - 1) + fib_naive(n - 2)
    return n
# Memoized: O(n) — cache results
@lru_cache(maxsize=None)
def fib(n: int) -> int:
    """Return the nth Fibonacci number, memoized via lru_cache.

    Each distinct n is computed once and then served from the cache,
    giving O(n) time and O(n) space. Values of n <= 1 are returned as-is.
    """
    return n if n <= 1 else fib(n - 1) + fib(n - 2)
print(fib(50)) # 12586269025 — instant with memoization
# Manual memo dict (interview alternative)
def fib_memo(n: int, memo: dict[int, int] | None = None) -> int:
if memo is None:
memo = {}
if n in memo:
return memo[n]
if n <= 1:
return n
memo[n] = fib_memo(n - 1, memo) + fib_memo(n - 2, memo)
return memo[n]
# Iterative: O(n) time, O(1) space — best for large n
def fib_iterative(n: int) -> int:
    """Return the nth Fibonacci number with a loop instead of recursion.

    Only the two most recent values are kept, so this runs in O(n) time
    and O(1) extra space and is not bounded by the recursion limit.
    Values of n <= 1 are returned unchanged.
    """
    if n <= 1:
        return n
    prev, curr = 0, 1
    for _ in range(2, n + 1):
        # Slide the (fib(i-1), fib(i)) window forward by one position.
        prev, curr = curr, prev + curr
    return curr
# Problem 2: Power Set (All Subsets)
Problem: Given a list of unique elements, return all possible subsets (the power set).
def power_set(items: list) -> list[list]:
    """Return every subset of items (2^n subsets for n elements).

    The subsets of the tail are computed recursively; each of them then
    appears twice in the result, once without and once with the head
    element. Subsets that omit the head come first.

    AI context: generate all feature combinations for ablation studies.
    """
    if items:
        head, *tail = items
        without_head = power_set(tail)
        # Every tail subset either leaves the head out or puts it in front.
        return without_head + [[head, *subset] for subset in without_head]
    return [[]]  # Base case: the empty list has exactly one subset
features = ["embeddings", "reranking", "hyde"]
subsets = power_set(features)
print(len(subsets)) # 8 — 2^3
for s in subsets:
print(s)
# []
# ['hyde']
# ['reranking']
# ['reranking', 'hyde']
# ['embeddings']
# ...
Problem 3: Merge Sort
Problem: Implement merge sort. A classic divide-and-conquer algorithm.
def merge_sort(arr: list[float]) -> list[float]:
    """Return a new list with the elements of arr in ascending order.

    Merge sort: O(n log n) time, O(n) space.
    Stable sort — preserves relative order of equal elements.

    The input list is never modified and never returned as-is, so the
    caller can mutate the result without affecting arr.
    """
    if len(arr) <= 1:
        # Base case: copy so the result does not alias the caller's list.
        return list(arr)
    mid = len(arr) // 2
    left = merge_sort(arr[:mid])
    right = merge_sort(arr[mid:])
    return merge(left, right)
def merge(left: list[float], right: list[float]) -> list[float]:
    """Combine two already-sorted lists into one sorted list.

    Walks both inputs with one cursor each, always taking the smaller
    front element; on ties the element from left is taken first.
    """
    merged: list[float] = []
    li = ri = 0
    n_left, n_right = len(left), len(right)
    while li < n_left and ri < n_right:
        take_left = left[li] <= right[ri]  # <= means ties go to the left input
        if take_left:
            merged.append(left[li])
            li += 1
        else:
            merged.append(right[ri])
            ri += 1
    # The loop ends when one side is exhausted; append whatever remains.
    merged += left[li:]
    merged += right[ri:]
    return merged
# AI context: sorting retrieval results from multiple sources
scores = [0.45, 0.92, 0.78, 0.61, 0.89, 0.33, 0.75]
sorted_scores = merge_sort(scores)
print(sorted_scores)
# [0.33, 0.45, 0.61, 0.75, 0.78, 0.89, 0.92]
Problem 4: Flatten Nested JSON
Problem: Given a nested dict (e.g., an LLM API response), flatten all keys with dot notation.
def flatten_json(obj: dict, prefix: str = "", sep: str = ".") -> dict:
    """Flatten a nested dict into a single level with joined keys.

    {"usage": {"tokens": 100}} → {"usage.tokens": 100}

    Args:
        obj: The (possibly nested) dict to flatten.
        prefix: Key path accumulated so far; used by the recursive calls.
        sep: Separator placed between key segments.

    Returns:
        A new dict with one entry per non-dict value. Only dicts are
        descended into: lists and other values are stored unchanged, and
        an empty nested dict is kept as a value so its key is not lost.

    AI context: extracting metrics from LLM API response payloads.
    """
    result = {}
    for key, value in obj.items():
        full_key = f"{prefix}{sep}{key}" if prefix else key
        if isinstance(value, dict) and value:
            result.update(flatten_json(value, prefix=full_key, sep=sep))
        else:
            # Scalars, lists, and empty dicts are stored as-is; recursing
            # into an empty dict would drop the key entirely.
            result[full_key] = value
    return result
llm_response = {
"id": "chatcmpl-abc123",
"model": "gpt-4o",
"usage": {
"prompt_tokens": 120,
"completion_tokens": 80,
"total_tokens": 200,
},
"choices": [
{
"message": {"role": "assistant", "content": "Warfarin 5mg daily..."},
"finish_reason": "stop",
}
],
}
flat = flatten_json(llm_response)
print(flat["usage.prompt_tokens"]) # 120
print(flat["usage.total_tokens"]) # 200
# Note: list values (like "choices") are kept as-is since we only recurse into dicts
Problem 5: Recursive Tree Summarization
Problem: Summarize a document tree recursively — leaf nodes hold raw text, internal nodes hold summaries of their children.
from dataclasses import dataclass, field
@dataclass
class DocumentNode:
    """One node of a document tree: a title, optional text, and child nodes."""

    # Display name of the section or document.
    title: str
    # Raw text held by the node; defaults to empty.
    content: str = ""
    # Sub-sections; each instance gets its own fresh list.
    children: list["DocumentNode"] = field(default_factory=list)

    @property
    def is_leaf(self) -> bool:
        """Whether this node has no children."""
        return not self.children
def summarize_tree(
    node: DocumentNode,
    summarize_fn,  # Callable[[str], str] — e.g., LLM call
    depth: int = 0,
) -> str:
    """Summarize a document tree bottom-up and return the summary for node.

    - Leaf: summarize_fn is applied to the node's raw content.
    - Internal: every child is summarized first (in order), the child
      summaries are joined with blank lines, and summarize_fn is applied
      to the joined text.

    A progress line is printed per node, indented by depth. This mirrors
    the bottom-up summarization used in RAPTOR-style retrieval trees.
    """
    pad = " " * depth
    if not node.is_leaf:
        # Children must be summarized before their parent.
        parts = []
        for child in node.children:
            parts.append(summarize_tree(child, summarize_fn, depth + 1))
        print(f"{pad}[NODE] Summarizing '{node.title}' from {len(node.children)} children")
        return summarize_fn("\n\n".join(parts))
    print(f"{pad}[LEAF] Summarizing '{node.title}'")
    return summarize_fn(node.content)
# Mock LLM summarizer
def mock_summarize(text: str) -> str:
    """Stand-in for an LLM call: keep the first eight words of text.

    The kept words are joined by single spaces, prefixed with
    "[Summary] " and followed by "...".
    """
    head = " ".join(text.split()[:8])
    return f"[Summary] {head}..."
# Build a sample document tree
tree = DocumentNode(
title="Clinical Guidelines",
children=[
DocumentNode(
title="Anticoagulation",
children=[
DocumentNode("Warfarin", content="Warfarin 5mg daily, monitor INR weekly..."),
DocumentNode("Heparin", content="Heparin IV infusion, APTT 60-100 seconds..."),
],
),
DocumentNode(
title="Antidiabetics",
children=[
DocumentNode("Metformin", content="Metformin 500mg BID with meals, renal dose adjust..."),
],
),
],
)
root_summary = summarize_tree(tree, mock_summarize)
print("\nFinal summary:", root_summary)
Problem 6: Count Ways to Climb Stairs
Problem: You can climb 1 or 2 steps at a time. In how many ways can you reach the nth step?
from functools import lru_cache
@lru_cache(maxsize=None)
def climb_stairs(n: int) -> int:
    """Count ways to climb n stairs (1 or 2 steps at a time).

    This is just Fibonacci shifted by one.
    O(n) time, O(n) space with memoization.

    Returns:
        The number of distinct step sequences that land exactly on stair n.
        There is one way to climb zero stairs (take no steps) and no way
        to reach a negative stair.
    """
    if n < 0:
        return 0  # Overshot the top
    if n <= 1:
        return 1  # Zero or one stair: exactly one way
    return climb_stairs(n - 1) + climb_stairs(n - 2)
for i in range(1, 8):
print(f"n={i}: {climb_stairs(i)} ways")
# n=1: 1 ways
# n=2: 2 ways
# n=3: 3 ways
# n=4: 5 ways
# n=5: 8 ways
# Generalized: can climb 1, 2, or 3 steps
@lru_cache(maxsize=None)
def climb_k_steps(n: int, steps: tuple[int, ...] = (1, 2, 3)) -> int:
    """Count ways to climb n stairs with variable step sizes.

    Args:
        n: Number of stairs remaining.
        steps: Allowed step sizes. Must be a tuple (the arguments are
            cache keys and therefore hashable) of positive integers.

    Returns:
        The number of ordered step sequences that sum to exactly n.

    Raises:
        ValueError: If any step size is zero or negative; such a step
            never moves toward the base case and would recurse forever.
    """
    if any(step <= 0 for step in steps):
        raise ValueError("step sizes must be positive integers")
    if n == 0:
        return 1  # Reached the top
    if n < 0:
        return 0  # Overshot
    return sum(climb_k_steps(n - step, steps) for step in steps)
# Recursion Pitfalls and Fixes
import sys
# Pitfall 1: Python's default recursion limit is 1000
def deep_recursion(n: int) -> int:
    """Count down from n to 0 recursively, adding 1 per level.

    Uses one stack frame per unit of n, so a large n exceeds Python's
    recursion limit and raises RecursionError.
    """
    if n != 0:
        return 1 + deep_recursion(n - 1)
    return 0
# deep_recursion(2000) # RecursionError!
# Fix: increase limit (use sparingly) or convert to iteration
sys.setrecursionlimit(10_000)
# Fix: convert to iteration (a plain loop suffices here; branching recursion needs an explicit stack)
def deep_iterative(n: int) -> int:
    """Loop-based counterpart of the recursive countdown.

    Adds 1 for each unit of n using constant stack space, so it is not
    subject to the recursion limit. Returns 0 when n is not positive.
    """
    count = 0
    for _ in range(n):
        count += 1
    return count
# Pitfall 2: missing base case → infinite recursion
def bad_recursion(n: int) -> int:
    """Deliberately broken example: recursion with no base case.

    Every call recurses on n - 1 unconditionally, so any call ends in
    RecursionError once the interpreter's frame limit is reached.
    """
    smaller = n - 1
    return bad_recursion(smaller)  # RecursionError — no base case!
# Pitfall 3: mutable default argument in recursion — the safe version below uses a None sentinel instead
def collect(n: int, result: list | None = None) -> list:
if result is None:
result = [] # Safe: create a new list each time
if n == 0:
return result
result.append(n)
return collect(n - 1, result)Recursion vs Iteration
| Aspect | Recursion | Iteration |
|---|---|---|
| Readability | High for tree/divide-and-conquer | High for loops |
| Stack space | O(depth) — can overflow | O(1) typically |
| Python limit | 1000 frames (configurable) | None |
| Memoization | Easy with @lru_cache | Manual caching |
| When to use | Trees, DAGs, power sets, backtracking | Fibonacci-like, flat sequences |
| When to avoid | Deep nesting (more than ~500 levels) | Never — it is always a valid option |
Found this helpful?
Leave a comment
Have a question, correction, or just found this helpful? Leave a note below.