Learnixo
Back to blog
AI Systems · advanced

Recursion Interview Problems

Recursion problems common in AI engineering interviews: tree traversal, memoized Fibonacci, power sets, merge sort, JSON traversal, stair climbing, and recursive RAG tree summarization.

Asma Hafeez Khan · May 16, 2026 · 7 min read
Python · Recursion · Interview · Algorithm · Trees · Machine Learning
Share:𝕏

Recursion Fundamentals

Every recursive function needs:

  1. Base case — the condition that stops recursion
  2. Recursive case — a call with a smaller/simpler input that moves toward the base case
Python
def factorial(n: int) -> int:
    """Return n! = n * (n-1) * ... * 1, with 0! == 1.

    Raises:
        ValueError: if n is negative — factorial is undefined there
            (this matches math.factorial).
    """
    if n < 0:
        raise ValueError("factorial() not defined for negative values")
    if n <= 1:        # Base case: 0! == 1! == 1
        return 1
    return n * factorial(n - 1)   # Recursive case: n shrinks toward the base case

print(factorial(5))   # 120

# Call stack visualization:
# factorial(5) → 5 * factorial(4)
# factorial(4) → 4 * factorial(3)
# factorial(3) → 3 * factorial(2)
# factorial(2) → 2 * factorial(1)
# factorial(1) → 1   ← base case; the stack unwinds from here

Problem 1: Fibonacci with Memoization

Problem: Compute the nth Fibonacci number. Naive recursion is O(2^n) — optimize with memoization.

Python
from functools import lru_cache

# Naive: O(2^n) — the same subproblems are recomputed over and over
def fib_naive(n: int) -> int:
    """Return the nth Fibonacci number via plain double recursion (exponential time)."""
    return n if n <= 1 else fib_naive(n - 1) + fib_naive(n - 2)


# Memoized: O(n) — each distinct n is computed once, then served from the cache
@lru_cache(maxsize=None)
def fib(n: int) -> int:
    """Return the nth Fibonacci number with memoization.

    O(n) time, O(n) space (cache entries plus recursion depth).
    """
    if n > 1:
        return fib(n - 1) + fib(n - 2)
    return n


print(fib(50))   # 12586269025 — instant with memoization


# Manual memo dict (interview alternative)
def fib_memo(n: int, memo: dict[int, int] | None = None) -> int:
    if memo is None:
        memo = {}
    if n in memo:
        return memo[n]
    if n <= 1:
        return n
    memo[n] = fib_memo(n - 1, memo) + fib_memo(n - 2, memo)
    return memo[n]


# Iterative: O(n) time, O(1) space — best for large n (no recursion depth to worry about)
def fib_iterative(n: int) -> int:
    """Return the nth Fibonacci number by walking the sequence bottom-up."""
    if n <= 1:
        return n
    a, b = 0, 1
    for _ in range(n):
        a, b = b, a + b
    return a

Problem 2: Power Set (All Subsets)

Problem: Given a list of unique elements, return all possible subsets (the power set).

Python
def power_set(items: list) -> list[list]:
    """
    Return every subset of ``items`` — 2^n of them for n items.

    Ordering: all subsets that omit the first item come first, followed by
    the same subsets with the first item prepended.

    AI context: generate all feature combinations for ablation studies.
    """
    if not items:
        return [[]]   # Base case: the only subset of nothing is the empty set

    head, tail = items[0], items[1:]
    without_head = power_set(tail)   # Solve the smaller problem first
    # Every subset either skips `head` or includes it.
    return without_head + [[head, *subset] for subset in without_head]


features = ["embeddings", "reranking", "hyde"]
subsets = power_set(features)
print(len(subsets))   # 8 — 2^3
for s in subsets:
    print(s)
# []
# ['hyde']
# ['reranking']
# ['reranking', 'hyde']
# ['embeddings']
# ['embeddings', 'hyde']
# ['embeddings', 'reranking']
# ['embeddings', 'reranking', 'hyde']

Problem 3: Merge Sort

Problem: Implement merge sort. A classic divide-and-conquer algorithm.

Python
def merge_sort(arr: list[float]) -> list[float]:
    """
    Merge sort: O(n log n) time, O(n) space.
    Stable sort — preserves relative order of equal elements.

    Always returns a new list and never modifies ``arr``, so callers may
    mutate the result without affecting the input (even for 0 or 1 items).
    """
    if len(arr) <= 1:
        return list(arr)   # Base case: copy so the result never aliases the input

    mid = len(arr) // 2
    left  = merge_sort(arr[:mid])
    right = merge_sort(arr[mid:])
    return merge(left, right)


def merge(left: list[float], right: list[float]) -> list[float]:
    """Combine two already-sorted lists into one sorted list.

    On ties the element from ``left`` wins, which keeps merge sort stable.
    """
    merged: list[float] = []
    li = ri = 0
    n_left, n_right = len(left), len(right)

    # Take the smaller front element until one side is exhausted.
    while li < n_left and ri < n_right:
        if left[li] <= right[ri]:
            nxt, li = left[li], li + 1
        else:
            nxt, ri = right[ri], ri + 1
        merged.append(nxt)

    # At most one side has leftovers, and they are already in order.
    merged += left[li:]
    merged += right[ri:]
    return merged


# AI context: sorting retrieval scores merged from multiple sources.
# merge_sort sorts ascending; use sorted_scores[::-1] for a best-first ranking.
scores = [0.45, 0.92, 0.78, 0.61, 0.89, 0.33, 0.75]
sorted_scores = merge_sort(scores)
print(sorted_scores)
# [0.33, 0.45, 0.61, 0.75, 0.78, 0.89, 0.92]

Problem 4: Flatten Nested JSON

Problem: Given a nested dict (e.g., an LLM API response), flatten all keys with dot notation.

Python
def flatten_json(obj: dict, prefix: str = "", sep: str = ".") -> dict:
    """
    Flatten nested dict into single-level with dot-separated keys.

    {"usage": {"tokens": 100}} → {"usage.tokens": 100}

    Only dicts are recursed into; any other value (lists included) is stored
    unchanged under its flattened key. An empty nested dict is kept as a
    value ({"meta": {}} → {"meta": {}}) instead of its key silently vanishing.

    Args:
        obj: the dict to flatten.
        prefix: key path accumulated so far; "" at the top level.
        sep: separator inserted between path segments.

    Returns:
        A new single-level dict.

    AI context: extracting metrics from LLM API response payloads.
    """
    result = {}

    for key, value in obj.items():
        full_key = f"{prefix}{sep}{key}" if prefix else key

        if isinstance(value, dict) and value:
            result.update(flatten_json(value, prefix=full_key, sep=sep))
        else:
            # Scalars, lists, and empty dicts are leaves.
            result[full_key] = value

    return result


# A typical chat-completion payload: scalars, a nested "usage" dict, and a "choices" list.
llm_response = {
    "id": "chatcmpl-abc123",
    "model": "gpt-4o",
    "usage": {
        "prompt_tokens": 120,
        "completion_tokens": 80,
        "total_tokens": 200,
    },
    "choices": [
        {
            "message": {"role": "assistant", "content": "Warfarin 5mg daily..."},
            "finish_reason": "stop",
        },
    ],
}

flat = flatten_json(llm_response)
# Prints 120, then 200 — one value per line.
print(flat["usage.prompt_tokens"], flat["usage.total_tokens"], sep="\n")
# Note: list values (like "choices") are kept as-is since we only recurse into dicts

Problem 5: Recursive Tree Summarization

Problem: Summarize a document tree recursively — leaf nodes hold raw text, internal nodes hold summaries of their children.

Python
from dataclasses import dataclass, field

@dataclass
class DocumentNode:
    """One node of a document tree.

    Leaves carry raw text in ``content``; internal nodes carry ``children``.
    """

    title: str
    content: str = ""  # raw text; empty by default
    children: list["DocumentNode"] = field(default_factory=list)  # fresh list per node

    @property
    def is_leaf(self) -> bool:
        """True when this node has no children."""
        child_count = len(self.children)
        return child_count == 0


def summarize_tree(
    node: DocumentNode,
    summarize_fn,   # Callable[[str], str] — e.g., an LLM call
    depth: int = 0,
) -> str:
    """
    Recursively summarize a document tree, bottom-up (post-order).

    - Leaf: summarize its raw ``content`` directly.
    - Internal: summarize every child first, then summarize the combined
      child summaries, preceded by the node's own ``content`` if it has any.

    Similar in spirit to RAPTOR's tree of recursive summaries.

    Args:
        node: root of the (sub)tree to summarize.
        summarize_fn: maps a text to its summary string.
        depth: current nesting level; only used to indent progress output.

    Returns:
        The string returned by ``summarize_fn`` for ``node``.
    """
    indent = "  " * depth

    if node.is_leaf:
        print(f"{indent}[LEAF] Summarizing '{node.title}'")
        return summarize_fn(node.content)

    # Children must be summarized before their parent can be.
    child_summaries = [
        summarize_tree(child, summarize_fn, depth + 1)
        for child in node.children
    ]

    # Keep text the internal node holds itself instead of silently dropping it;
    # with empty content this is exactly the child summaries joined together.
    parts = [node.content, *child_summaries] if node.content else child_summaries
    combined = "\n\n".join(parts)
    print(f"{indent}[NODE] Summarizing '{node.title}' from {len(node.children)} children")
    return summarize_fn(combined)


# Mock LLM summarizer
def mock_summarize(text: str) -> str:
    """Stand-in for an LLM: keep the first 8 whitespace-separated words, add an ellipsis."""
    preview = " ".join(text.split()[:8])
    return f"[Summary] {preview}..."


# Build a sample document tree: root → two sections → three leaf documents
tree = DocumentNode(
    "Clinical Guidelines",
    children=[
        DocumentNode(
            "Anticoagulation",
            children=[
                DocumentNode(title="Warfarin", content="Warfarin 5mg daily, monitor INR weekly..."),
                DocumentNode(title="Heparin", content="Heparin IV infusion, APTT 60-100 seconds..."),
            ],
        ),
        DocumentNode(
            "Antidiabetics",
            children=[
                DocumentNode(title="Metformin", content="Metformin 500mg BID with meals, renal dose adjust..."),
            ],
        ),
    ],
)

root_summary = summarize_tree(tree, summarize_fn=mock_summarize)
print("\nFinal summary:", root_summary)

Problem 6: Count Ways to Climb Stairs

Problem: You can climb 1 or 2 steps at a time. In how many ways can you reach the nth step?

Python
from functools import lru_cache

@lru_cache(maxsize=None)
def climb_stairs(n: int) -> int:
    """
    Count ways to climb n stairs (1 or 2 steps at a time).
    This is just Fibonacci shifted by one: climb_stairs(n) == fib(n + 1).
    O(n) time, O(n) space with memoization.

    climb_stairs(0) == 1 (the empty climb), consistent with climb_k_steps;
    negative n has no valid climbs and returns 0.
    """
    if n < 0:
        return 0
    if n <= 1:
        return 1   # n == 0: do nothing; n == 1: a single 1-step
    return climb_stairs(n - 1) + climb_stairs(n - 2)


for i in range(1, 8):
    print(f"n={i}: {climb_stairs(i)} ways")
# n=1: 1 ways
# n=2: 2 ways
# n=3: 3 ways
# n=4: 5 ways
# n=5: 8 ways
# n=6: 13 ways
# n=7: 21 ways


# Generalized: can climb 1, 2, or 3 steps (or any set of positive step sizes)
@lru_cache(maxsize=None)
def climb_k_steps(n: int, steps: tuple[int, ...] = (1, 2, 3)) -> int:
    """Count ways to climb n stairs when each move may be any size in ``steps``.

    Order matters: 1+2 and 2+1 are counted separately. ``steps`` must be a
    tuple (lru_cache requires hashable arguments). With memoization this is
    O(n * len(steps)) time.

    Raises:
        ValueError: if any step size is <= 0 — such a step never moves toward
            the base case, so the recursion would never terminate.
    """
    if any(step <= 0 for step in steps):
        raise ValueError(f"step sizes must be positive, got {steps!r}")
    if n == 0:
        return 1   # Reached the top exactly
    if n < 0:
        return 0   # Overshot

    return sum(climb_k_steps(n - step, steps) for step in steps)

Recursion Pitfalls and Fixes

Python
import sys

# Pitfall 1: every call adds a stack frame, and CPython's default recursion limit is 1000
def deep_recursion(n: int) -> int:
    """Count down to zero recursively; stack depth grows linearly with n."""
    return 0 if n == 0 else 1 + deep_recursion(n - 1)

# deep_recursion(2000)  # RecursionError! 2000 frames exceeds the default limit of 1000

# Fix: raise the limit — use sparingly, or convert to iteration instead.
# The limit is process-wide (it applies to every function, not just this one),
# and setting it too high can crash the interpreter with a real C stack
# overflow rather than a catchable RecursionError.
sys.setrecursionlimit(10_000)

# Fix: convert to iteration — a plain loop needs no extra stack frames
def deep_iterative(n: int) -> int:
    """Loop-based version of deep_recursion: same result for n >= 0, constant stack use."""
    steps = 0
    while n > 0:
        n -= 1
        steps += 1
    return steps


# Pitfall 2: missing base case → the calls never stop (infinite recursion)
def bad_recursion(n: int) -> int:
    """Deliberately broken: with no base case, any int n ends in RecursionError."""
    smaller = n - 1
    return bad_recursion(smaller)

# Pitfall 3: mutable default argument in recursion
def collect(n: int, result: list | None = None) -> list:
    if result is None:
        result = []   # Safe: create a new list each time
    if n == 0:
        return result
    result.append(n)
    return collect(n - 1, result)

Recursion vs Iteration

| Aspect | Recursion | Iteration |
|---|---|---|
| Readability | High for tree/divide-and-conquer | High for loops |
| Stack space | O(depth) — can overflow | O(1) typically |
| Python limit | 1000 frames by default (configurable) | None |
| Memoization | Easy with @lru_cache | Manual caching |
| When to use | Trees, DAGs, power sets, backtracking | Fibonacci-like, flat sequences |
| When to avoid | Deep nesting (more than ~500 levels) | Rarely — it is always a valid option, though tree/backtracking code may need an explicit stack |

Enjoyed this article?

Explore the AI Systems learning path for more.

Found this helpful?

Share:𝕏

Leave a comment

Have a question, correction, or just found this helpful? Leave a note below.