Learnixo
AI Systems · Advanced

Time and Space Complexity for AI Engineers

Big-O complexity for AI engineering interviews: understand O(1) through O(n²), analyze Python data structures, and apply complexity reasoning to embedding search, RAG pipelines, and LLM cost estimation.

Asma Hafeez Khan · May 16, 2026 · 7 min read
Tags: Python, Complexity, Big-O, Interview, Algorithm, Machine Learning, Performance

What is Big-O?

Big-O describes how runtime or memory scales as the input size n grows; formally, it is an asymptotic upper bound on that growth. It ignores constant factors and lower-order terms and keeps only the dominant one: 3n² + 5n + 100 is O(n²).

O(1) < O(log n) < O(n) < O(n log n) < O(n²) < O(2^n) < O(n!)
constant < logarithmic < linear < linearithmic < quadratic < exponential < factorial
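To see why the ordering matters, the sketch below tabulates rough step counts for each class at a few input sizes. It ignores constants and assumes base-2 logarithms; `growth_table` is an illustrative helper, not a library function.

```python
import math


def growth_table(sizes: list[int]) -> dict[str, list[float]]:
    """Rough step counts per complexity class (constants ignored, log base 2)."""
    classes = {
        "O(1)": lambda n: 1,
        "O(log n)": lambda n: math.log2(n),
        "O(n)": lambda n: n,
        "O(n log n)": lambda n: n * math.log2(n),
        "O(n^2)": lambda n: n**2,
        "O(2^n)": lambda n: 2**n,
    }
    return {name: [f(n) for n in sizes] for name, f in classes.items()}


for name, counts in growth_table([10, 20, 40]).items():
    print(f"{name:<11}", ", ".join(f"{c:,.0f}" for c in counts))
```

Going from n = 20 to n = 40 doubles the linear count but multiplies the exponential count by about a million (2^20).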

Common Complexities

O(1): Constant Time

Runtime does not change regardless of input size.

Python
# Dict/set lookup: O(1) average
config = {"model": "gpt-4o", "temperature": 0.0}
model = config["model"]   # O(1)

# List access by index: O(1)
embeddings = [...]
first = embeddings[0]   # O(1)

# Set membership: O(1) average
known_drugs = {"warfarin", "aspirin", "metformin"}
is_known = "warfarin" in known_drugs   # O(1)

# List append: O(1) amortized
results = []
results.append("doc_1")   # O(1) amortized
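Why is `append` only *amortized* O(1)? CPython over-allocates a list's backing array, so most appends write into spare capacity and only occasional appends trigger a resize and copy. The sketch below counts those resizes. It assumes CPython, where `sys.getsizeof` on a list reflects its allocated capacity (other interpreters may report differently); `count_list_resizes` is a hypothetical helper.

```python
import sys


def count_list_resizes(n_appends: int) -> int:
    """Count how often a list's allocated size changes over n appends.

    CPython detail (assumed): sys.getsizeof reports the over-allocated
    capacity, so the value only changes when the list is resized.
    """
    items: list[int] = []
    resizes = 0
    last_size = sys.getsizeof(items)
    for i in range(n_appends):
        items.append(i)
        size = sys.getsizeof(items)
        if size != last_size:  # capacity grew: this append paid for a copy
            resizes += 1
            last_size = size
    return resizes


print(count_list_resizes(100_000))  # far fewer resizes than appends
```

Because capacity grows geometrically, the total copying across n appends stays O(n), which averages out to O(1) per append, even though the single append that triggers a resize is O(n) on its own.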

O(log n): Logarithmic Time

The remaining search space is halved at each step. Typical of binary search and balanced search trees.

Python
import bisect


def binary_search(sorted_arr: list[float], target: float) -> int:
    """O(log n): halves the search space each step. Returns the index or -1."""
    lo, hi = 0, len(sorted_arr) - 1

    while lo <= hi:
        mid = (lo + hi) // 2
        if sorted_arr[mid] == target:
            return mid
        elif sorted_arr[mid] < target:
            lo = mid + 1
        else:
            hi = mid - 1

    return -1

# For n=1,000,000: at most 20 iterations (log₂(10^6) ≈ 19.9)
sorted_scores = sorted([0.9, 0.4, 0.85, 0.72, 0.61])
idx = binary_search(sorted_scores, 0.85)   # O(log n)

# bisect module: O(log n) to find the insertion point
pos = bisect.bisect_left(sorted_scores, 0.85)
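To check the "about 20 steps" claim directly, a variant of `binary_search` can count its loop iterations. The function name is illustrative; the worst case for n elements is floor(log₂ n) + 1 iterations, which is 20 for n = 1,000,000.

```python
def binary_search_steps(sorted_arr: list[int], target: int) -> tuple[int, int]:
    """Binary search that also reports how many loop iterations it used."""
    lo, hi = 0, len(sorted_arr) - 1
    steps = 0
    while lo <= hi:
        steps += 1
        mid = (lo + hi) // 2
        if sorted_arr[mid] == target:
            return mid, steps
        elif sorted_arr[mid] < target:
            lo = mid + 1
        else:
            hi = mid - 1
    return -1, steps


data = list(range(1_000_000))
print(binary_search_steps(data, 999_999))   # found at the last index
print(binary_search_steps(data, -5))        # a missing target also stops within 20 steps
```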

O(n): Linear Time

Work grows in direct proportion to the input: typically a single pass that touches every element once.

Python
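import heapq
import math


def find_max_score(scores: list[float]) -> float:
    """O(n): must check every element."""
    max_score = float("-inf")
    for score in scores:
        if score > max_score:
            max_score = score
    return max_score


def count_tokens(texts: list[str]) -> int:
    """O(total characters): split() scans every character.
    Counts whitespace-separated words, a rough proxy for LLM tokens."""
    return sum(len(text.split()) for text in texts)


def cosine_sim(a: list[float], b: list[float]) -> float:
    """O(dim) cosine similarity (helper not shown in the original; assumed here)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


# Linear scan of an embedding index: O(n_docs * dim)
def brute_force_search(query_emb, doc_embs, k: int = 5):
    scores = [cosine_sim(query_emb, d) for d in doc_embs]  # O(n * dim)
    # heapq.nlargest is O(n log k); sorting every score would add O(n log n)
    return heapq.nlargest(k, enumerate(scores), key=lambda x: x[1])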
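In practice the same linear scan is usually vectorized with NumPy. A sketch, assuming embeddings arrive as a 2-D float array with one row per document; `brute_force_search_np` is a hypothetical name, and the selection step uses `np.argpartition` so only the k winners get sorted.

```python
import numpy as np


def brute_force_search_np(query: np.ndarray, docs: np.ndarray, k: int = 5) -> np.ndarray:
    """Cosine top-k over all docs: O(n * dim) to score, O(n) to select.

    query has shape (dim,), docs has shape (n_docs, dim).
    Returns the indices of the k most similar docs, best first.
    """
    q = query / np.linalg.norm(query)
    d = docs / np.linalg.norm(docs, axis=1, keepdims=True)
    scores = d @ q                                   # n dot products: O(n * dim)
    k = min(k, len(scores))
    if k <= 0:
        return np.array([], dtype=int)
    top = np.argpartition(scores, -k)[-k:]           # O(n) selection, unordered
    return top[np.argsort(scores[top])[::-1]]        # order only the k winners


rng = np.random.default_rng(0)
docs = rng.normal(size=(10_000, 64))                 # illustrative random "embeddings"
query = docs[42] + 0.01 * rng.normal(size=64)        # a slightly perturbed copy of doc 42
print(brute_force_search_np(query, docs, k=3))       # doc 42 should come first
```

The asymptotic cost is the same as the pure-Python loop; the gain is that the O(n * dim) work runs in compiled code.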

O(n log n): Linearithmic Time

Typical of efficient comparison sorts (merge sort, heapsort, Python's Timsort) and many divide-and-conquer algorithms.

Python
# Python's sort: O(n log n) (Timsort)
scores = [0.45, 0.92, 0.78, 0.61, 0.89]
scores.sort()   # O(n log n)

# Merge sort is O(n log n); see the recursion article
# Heap-based top-k: O(n log k)
import heapq
top_5 = heapq.nlargest(5, scores)   # O(n log 5) ≈ O(n)

# np.argsort: O(n log n)
import numpy as np
arr = np.array(scores)
ranked = np.argsort(arr)[::-1]   # O(n log n)
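`heapq.nlargest` gets its O(n log k) bound by keeping only the k best items seen so far. A hand-rolled sketch of that idea with a size-k min-heap (`top_k` is an illustrative name, not the library's internals):

```python
import heapq


def top_k(scores: list[float], k: int) -> list[float]:
    """Top-k via a size-k min-heap: O(n log k) time, O(k) extra space."""
    if k <= 0:
        return []
    heap: list[float] = []
    for s in scores:
        if len(heap) < k:
            heapq.heappush(heap, s)          # O(log k)
        elif s > heap[0]:                    # heap[0] is the smallest score kept
            heapq.heapreplace(heap, s)       # pop smallest, push new: O(log k)
    return sorted(heap, reverse=True)        # O(k log k) final ordering


scores = [0.45, 0.92, 0.78, 0.61, 0.89, 0.12, 0.95]
print(top_k(scores, 3))  # [0.95, 0.92, 0.89]
```

Each element costs at most O(log k) heap work, so for a small fixed k the whole pass is effectively linear.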

O(n²): Quadratic Time

Nested loops over the same data. Doubling n roughly quadruples the work, so runtime becomes a problem quickly as n grows.

Python
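import math
import numpy as np


def cosine_sim(a, b) -> float:
    """O(dim) cosine similarity (helper not shown in the original; assumed here)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


# Pairwise similarity matrix: O(n² * dim)
def pairwise_similarities(embeddings: list) -> list[list[float]]:
    n = len(embeddings)
    matrix = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i, n):   # Triangle trick halves the work but is still O(n²)
            sim = cosine_sim(embeddings[i], embeddings[j])
            matrix[i][j] = matrix[j][i] = sim
    return matrix

# For n=10,000 documents: ~5 * 10^7 pairs (10^8 without the triangle trick), too slow in pure Python
# Use a matrix multiply instead: still O(n² * dim), but it runs in optimized C/BLAS code (NumPy)
embeddings = np.random.default_rng(0).normal(size=(1_000, 384))   # illustrative random vectors
embs = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)   # L2-normalize each row
sim_matrix = embs @ embs.T   # Dot products of unit vectors = cosine similarities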


# Bubble sort: O(n²); never use in production
def bubble_sort(arr):
    n = len(arr)
    for i in range(n):
        for j in range(n - i - 1):
            if arr[j] > arr[j + 1]:
                arr[j], arr[j + 1] = arr[j + 1], arr[j]
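The triangle optimization in `pairwise_similarities` halves the work but keeps it quadratic: it computes n(n+1)/2 similarities instead of n². A quick arithmetic check (`pair_counts` is an illustrative helper):

```python
def pair_counts(n: int) -> tuple[int, int]:
    """Similarity computations for an n x n matrix: full vs upper triangle."""
    full = n * n                  # every (i, j) pair
    triangle = n * (n + 1) // 2   # j >= i only, diagonal included
    return full, triangle


for n in (1_000, 10_000, 100_000):
    full, tri = pair_counts(n)
    print(f"n={n:>7,}: full={full:,}  triangle={tri:,}")
```

Growing n tenfold still grows both counts roughly a hundredfold, which is the signature of O(n²).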

O(2^n): Exponential Time

Typically from generating all subsets or brute-force search.

Python
def power_set(items: list) -> list[list]:
    """2^n subsets, unavoidably; building them costs O(n * 2^n) in total."""
    if not items:
        return [[]]
    first, rest = items[0], items[1:]
    rest_subsets = power_set(rest)
    return rest_subsets + [[first] + s for s in rest_subsets]

# For n=20: about 1 million subsets (2^20 = 1,048,576), still manageable
# For n=40: about 1.1 trillion subsets (2^40), not feasible
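An equivalent iterative approach treats each integer from 0 to 2^n - 1 as a bitmask, where bit i set means item i is in the subset. This sketch (the function name is illustrative) makes the 2^n count explicit in the loop bound.

```python
def power_set_bitmask(items: list) -> list[list]:
    """Iterative power set: each of the 2^n bitmasks selects one subset.

    Time O(n * 2^n): 2^n masks, up to n bit checks per mask.
    """
    n = len(items)
    subsets = []
    for mask in range(1 << n):                       # 2^n masks
        subsets.append([items[i] for i in range(n) if mask & (1 << i)])
    return subsets


print(power_set_bitmask(["a", "b", "c"]))
# [[], ['a'], ['b'], ['a', 'b'], ['c'], ['a', 'c'], ['b', 'c'], ['a', 'b', 'c']]
```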

Python Data Structure Complexities

Python
# LIST
lst = [1, 2, 3, 4, 5]

lst[2]           # O(1): random access
lst.append(6)    # O(1) amortized
lst.pop()        # O(1): from end
lst.pop(0)       # O(n): from front (shifts everything)
lst.insert(0, 0) # O(n): shift on insert
6 in lst         # O(n): linear scan
lst.remove(3)    # O(n): find + shift

# DICT
d = {"a": 1}
d["a"]          # O(1) average
d["b"] = 2      # O(1) average
del d["a"]      # O(1) average
"a" in d        # O(1) average

# SET
s = {1, 2, 3}
s.add(4)        # O(1) average
s.remove(2)     # O(1) average
3 in s          # O(1) average

# DEQUE (from collections)
from collections import deque
dq = deque([1, 2, 3])
dq.appendleft(0)  # O(1): fast front insert
dq.popleft()      # O(1): fast front remove
# Use deque for queues; list.pop(0) is O(n)
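That last comment is easy to verify. A rough timing sketch that drains the same FIFO queue both ways; absolute numbers depend on the machine, and the helper names are illustrative.

```python
import time
from collections import deque


def drain_list(n: int) -> tuple[list[int], float]:
    """Dequeue n items from the front of a list: O(n) per pop(0), O(n²) total."""
    queue = list(range(n))
    out = []
    start = time.perf_counter()
    while queue:
        out.append(queue.pop(0))
    return out, time.perf_counter() - start


def drain_deque(n: int) -> tuple[list[int], float]:
    """Dequeue n items with popleft: O(1) each, O(n) total."""
    queue = deque(range(n))
    out = []
    start = time.perf_counter()
    while queue:
        out.append(queue.popleft())
    return out, time.perf_counter() - start


list_out, list_time = drain_list(20_000)
deque_out, deque_time = drain_deque(20_000)
assert list_out == deque_out          # same FIFO order either way
print(f"list.pop(0):     {list_time * 1000:.1f}ms")
print(f"deque.popleft(): {deque_time * 1000:.1f}ms")
```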

Analyzing AI Pipeline Complexity

RAG Pipeline

Query → Embed → Search Index → Rerank → Generate

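| Step | Time complexity | Notes |
|---|---|---|
| Embed query | O(L) | L = query length in tokens; attention adds an L² term, negligible for short queries |
| Brute-force search | O(n × dim) | n = number of docs, dim = embedding dimension |
| FAISS ANN search | ~O(√n × dim) | Approximate; the real cost depends on index type and parameters (e.g. IVF nlist/nprobe) |
| Rerank top-k | O(k × L_doc) | k ≪ n; L_doc = document length |
| LLM generate | O(context_len²) | Self-attention is quadratic in context length |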
Python
# Token cost estimation: O(1) per call, O(n) for a batch
def estimate_cost(n_docs: int, tokens_per_doc: int, price_per_1k: float) -> float:
    total_tokens = n_docs * tokens_per_doc   # O(1): one multiplication, no loop
    return (total_tokens / 1000) * price_per_1k

print(f"${estimate_cost(100, 500, 0.002):.2f}")   # $0.10
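When each document has its own token count, the total must be accumulated, which is where the O(n) for a batch comes from. A sketch with illustrative numbers; the function name and the price are placeholders, not a real model's rate.

```python
def estimate_batch_cost(token_counts: list[int], price_per_1k: float) -> float:
    """O(n): one pass to total per-document token counts (n = number of docs)."""
    total_tokens = sum(token_counts)
    return (total_tokens / 1000) * price_per_1k


# Hypothetical per-document token counts and price
print(f"${estimate_batch_cost([500, 1200, 300], 0.002):.4f}")  # $0.0040
```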

Chunking Documents

Python
def chunk_documents(docs: list[str], chunk_size: int) -> list[str]:
    """
    Time: O(total_chars); every character is visited by split() and join()
    Space: O(total_chars); output is proportional to input
    """
    chunks = []
    for doc in docs:   # O(n_docs)
        words = doc.split()
        for i in range(0, len(words), chunk_size):   # O(n_words / chunk_size)
            chunks.append(" ".join(words[i:i + chunk_size]))
    return chunks
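RAG pipelines often add overlap between chunks so a passage is not cut mid-thought. Overlap changes the constant, not the class: with stride = chunk_size - overlap, each word lands in roughly chunk_size / stride chunks. A sketch assuming word-level windows; `chunk_with_overlap` is a hypothetical helper, not a specific library's splitter.

```python
def chunk_with_overlap(text: str, chunk_size: int, overlap: int) -> list[str]:
    """Sliding-window chunking over words.

    stride = chunk_size - overlap, so each word is copied into about
    chunk_size / stride chunks: O(n_words * chunk_size / stride) time and space.
    """
    if not 0 <= overlap < chunk_size:
        raise ValueError("need 0 <= overlap < chunk_size")
    words = text.split()
    stride = chunk_size - overlap
    chunks = []
    for i in range(0, len(words), stride):
        chunks.append(" ".join(words[i:i + chunk_size]))
        if i + chunk_size >= len(words):   # this window already reaches the end
            break
    return chunks


print(chunk_with_overlap("a b c d e f g", chunk_size=4, overlap=2))
# ['a b c d', 'c d e f', 'e f g']
```

With overlap = chunk_size / 2 the output holds about twice as many words as the input, which is still O(total_words).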

Top-K with np.argpartition vs np.argsort

Python
import numpy as np
import time

n_docs = 100_000
scores = np.random.randn(n_docs)
k = 10

# argsort: O(n log n); sorts everything
start = time.perf_counter()
top_k_sort = np.argsort(scores)[-k:][::-1]
argsort_time = time.perf_counter() - start

# argpartition: O(n); puts the k largest in the last k slots, but not in sorted order
start = time.perf_counter()
top_k_part = np.argpartition(scores, -k)[-k:]
top_k_part = top_k_part[np.argsort(scores[top_k_part])[::-1]]   # Sort just top-k: O(k log k)
partition_time = time.perf_counter() - start

print(f"argsort:       {argsort_time * 1000:.2f}ms")
print(f"argpartition:  {partition_time * 1000:.2f}ms")
# argpartition is typically 2-5x faster for large n, small k

Interview: Identify the Complexity

Python
# Q: What is the time complexity of this function?
def find_common(a: list[str], b: list[str]) -> list[str]:
    result = []
    for x in a:          # O(n)
        if x in b:       # O(m): list membership scan!
            result.append(x)
    return result
# A: O(n * m), which is quadratic when the two lists are about the same size

# Optimized: O(n + m) using a set
def find_common_fast(a: list[str], b: list[str]) -> list[str]:
    b_set = set(b)                  # O(m) to build
    return [x for x in a if x in b_set]   # O(n): O(1) average per lookup
# A: O(n + m) β€” linear


# Q: What about this?
def nested_loop(matrix: list[list[int]]) -> int:
    total = 0
    for row in matrix:       # O(n) rows
        for val in row:      # O(m) cols
            total += val
    return total
# A: O(n * m) = O(total_elements): linear in the input size, which is n² elements for an n × n matrix
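In an interview you can back up a complexity claim empirically with a doubling test: if doubling n roughly doubles the runtime, the code behaves linearly; if it roughly quadruples it, quadratically. A rough sketch using compact restatements of the two functions above; timings are noisy, especially for fast calls, so treat the ratios as hints rather than proof. `time_call` is an illustrative helper.

```python
import random
import string
import time


def find_common(a: list[str], b: list[str]) -> list[str]:
    return [x for x in a if x in b]          # O(n * m): list scan per lookup


def find_common_fast(a: list[str], b: list[str]) -> list[str]:
    b_set = set(b)
    return [x for x in a if x in b_set]      # O(n + m)


def time_call(fn, n: int) -> float:
    """Time fn on two random lists of n six-letter strings."""
    rng = random.Random(0)
    a = ["".join(rng.choices(string.ascii_lowercase, k=6)) for _ in range(n)]
    b = ["".join(rng.choices(string.ascii_lowercase, k=6)) for _ in range(n)]
    start = time.perf_counter()
    fn(a, b)
    return time.perf_counter() - start


# Doubling test: time(2n) / time(n) is about 2 for O(n) and about 4 for O(n²)
for fn in (find_common, find_common_fast):
    ratio = time_call(fn, 4_000) / time_call(fn, 2_000)
    print(f"{fn.__name__}: time(2n)/time(n) = {ratio:.1f}")
```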

Complexity Quick Reference

| Operation | Data Structure | Time |
|---|---|---|
| Index access | list | O(1) |
| Append to end | list | O(1) amortized |
| Insert at front | list | O(n) |
| Remove from front | list | O(n) |
| Membership test | list | O(n) |
| Membership test | set / dict | O(1) avg |
| Insert / delete | set / dict | O(1) avg |
| Sort | list | O(n log n) |
| Binary search | sorted list | O(log n) |
| Heap push/pop | heapq | O(log n) |
| argpartition top-k | NumPy | O(n) |
| argsort | NumPy | O(n log n) |
| Matrix multiply | NumPy (BLAS) | O(n² * dim) vectorized |
| Attention (Transformer) | n/a | O(context_len²) |
