Learnixo
Back to blog
AI Systems · Intermediate

Implement LLM Response Caching

Build a semantic cache for LLM responses. Cache exact matches with a hash, and semantically similar queries with embedding similarity to reduce API costs and latency.

Asma Hafeez Khan · May 16, 2026 · 5 min read
Tags: Live Coding, Caching, LLM, Python
Share:𝕏

Why Cache LLM Responses?

LLM API calls are expensive (cost) and slow (latency). Users often ask similar questions. Caching lets you serve repeated queries instantly from cache rather than calling the API.

Two caching strategies:

  • Exact match: Hash the prompt, check cache — O(1) lookup
  • Semantic match: Find cached responses for similar (not identical) queries using embedding similarity

Exact Match Cache (LRU)

The simplest cache — store prompt hash → response:

Python
import hashlib
from collections import OrderedDict
from typing import Optional

class LRUCache:
    """Exact-match cache that evicts the least recently used entry.

    Prompts are stored under their SHA-256 hex digest. Both a ``get`` hit and a
    ``set`` mark the entry as most recently used; when a ``set`` pushes the
    entry count above ``max_size``, the least recently used entry is dropped.
    """

    def __init__(self, max_size: int = 1000):
        self.max_size = max_size
        # Ordering tracks recency: the first item is the least recently used.
        self._cache: OrderedDict[str, str] = OrderedDict()

    def _hash_key(self, prompt: str) -> str:
        """Return the SHA-256 hex digest of ``prompt`` (UTF-8 encoded)."""
        digest = hashlib.sha256(prompt.encode())
        return digest.hexdigest()

    def get(self, prompt: str) -> Optional[str]:
        """Return the response cached for ``prompt``, or None on a miss."""
        key = self._hash_key(prompt)
        try:
            response = self._cache[key]
        except KeyError:
            return None
        self._cache.move_to_end(key)  # a hit refreshes recency
        return response

    def set(self, prompt: str, response: str):
        """Store ``response`` for ``prompt``; evict the LRU entry if over capacity."""
        key = self._hash_key(prompt)
        store = self._cache
        # Removing then re-inserting places the key at the most-recently-used end.
        store.pop(key, None)
        store[key] = response
        if len(store) > self.max_size:
            store.popitem(last=False)

    def size(self) -> int:
        """Return the number of cached entries."""
        return len(self._cache)

# Test
cache = LRUCache(max_size=3)
cache.set("What is metformin?", "Metformin is a biguanide antidiabetic drug...")
cache.set("What is warfarin?", "Warfarin is an anticoagulant...")

print(cache.get("What is metformin?"))  # Cache hit → prints the stored response
print(cache.get("Tell me about aspirin"))  # Cache miss → prints None
# Exact matching hashes the raw text, so a case change is also a miss → None.
print(cache.get("what is metformin?"))

Semantic Cache

Match queries based on meaning, not exact text:

Python
import numpy as np
from dataclasses import dataclass

@dataclass
class CacheEntry:
    """One semantic-cache record: a prompt, its response, and the prompt's embedding."""

    prompt: str  # original prompt text
    response: str  # LLM response returned when this entry matches
    embedding: np.ndarray  # prompt embedding used for similarity search
    hit_count: int = 0  # number of lookups this entry has served

class SemanticCache:
    """Cache LLM responses using semantic similarity for lookup."""

    def __init__(
        self,
        similarity_threshold: float = 0.92,
        max_entries: int = 500,
    ):
        self.threshold = similarity_threshold
        self.max_entries = max_entries
        self.entries: list[CacheEntry] = []
        self._embeddings_matrix: np.ndarray | None = None

    def _rebuild_matrix(self):
        """Rebuild the embedding matrix for efficient search."""
        if self.entries:
            embeddings = np.stack([e.embedding for e in self.entries])
            norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
            self._embeddings_matrix = embeddings / (norms + 1e-10)

    def get(self, query: str, query_embedding: np.ndarray) -> str | None:
        """Find a cached response with high similarity to the query."""
        if not self.entries or self._embeddings_matrix is None:
            return None

        # Normalize query embedding
        query_norm = query_embedding / (np.linalg.norm(query_embedding) + 1e-10)

        # Cosine similarities against all cached entries
        similarities = self._embeddings_matrix @ query_norm  # (n,)
        best_idx = int(np.argmax(similarities))
        best_score = float(similarities[best_idx])

        if best_score >= self.threshold:
            self.entries[best_idx].hit_count += 1
            return self.entries[best_idx].response

        return None

    def set(self, prompt: str, response: str, embedding: np.ndarray):
        """Cache a prompt-response pair with its embedding."""
        if len(self.entries) >= self.max_entries:
            # Evict the entry with the fewest hits
            min_hits_idx = min(range(len(self.entries)), key=lambda i: self.entries[i].hit_count)
            self.entries.pop(min_hits_idx)

        self.entries.append(CacheEntry(
            prompt=prompt,
            response=response,
            embedding=embedding,
        ))
        self._rebuild_matrix()

# Simulated embedding function
def embed(text: str, dim: int = 128) -> np.ndarray:
    """Placeholder — replace with actual embedding API call.

    Returns a pseudo-random float32 vector of length ``dim`` seeded from
    ``text``. Equal texts always map to the same vector, including across
    processes, because the seed is derived from SHA-256 rather than the
    built-in ``hash()``. ``hash()`` is salted per process for ``str``
    (PYTHONHASHSEED). Different texts get unrelated vectors, so this stand-in
    only yields exact-text hits in a semantic cache, never paraphrase hits.
    """
    import hashlib

    digest = hashlib.sha256(text.encode("utf-8")).digest()
    seed = int.from_bytes(digest[:4], "big")  # 32-bit seed, valid for RandomState
    rng = np.random.RandomState(seed)
    return rng.randn(dim).astype(np.float32)

# Test
cache = SemanticCache(similarity_threshold=0.95)

# Cache a response
q1 = "What is the mechanism of action of metformin?"
emb1 = embed(q1)
cache.set(q1, "Metformin activates AMPK, reducing hepatic glucose production...", emb1)

# Exact same query → identical embedding → HIT
result = cache.get(q1, embed(q1))
print(f"Exact query: {'HIT' if result is not None else 'MISS'}")

# Very different query → unrelated embedding → MISS
# (The placeholder embed() gives unrelated vectors for ANY two different
# strings, so paraphrases also miss; use a real embedding model to see
# semantic hits.)
q2 = "What are the side effects of warfarin?"
result = cache.get(q2, embed(q2))
print(f"Different query: {'HIT' if result is not None else 'MISS'}")

Full Caching Pipeline

Combining the exact and semantic caches:

Python
import time
from openai import OpenAI

# Module-level OpenAI client shared by LLMCacheLayer._embed and LLMCacheLayer._call_llm.
# It is constructed at import time; presumably credentials come from the environment — confirm for your deployment.
client = OpenAI()

class LLMCacheLayer:
    """Two-level cache: exact match first, then semantic."""

    def __init__(
        self,
        exact_cache_size: int = 10_000,
        semantic_threshold: float = 0.93,
    ):
        self.exact_cache = LRUCache(max_size=exact_cache_size)
        self.semantic_cache = SemanticCache(similarity_threshold=semantic_threshold)
        self.stats = {"hits_exact": 0, "hits_semantic": 0, "misses": 0, "api_calls": 0}

    def _embed(self, text: str) -> np.ndarray:
        """Embed using OpenAI (replace with your embedding model)."""
        resp = client.embeddings.create(
            input=text,
            model="text-embedding-3-small",
        )
        return np.array(resp.data[0].embedding)

    def _call_llm(self, prompt: str, system: str) -> str:
        self.stats["api_calls"] += 1
        resp = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": prompt},
            ],
            temperature=0.1,
        )
        return resp.choices[0].message.content

    def query(self, prompt: str, system: str = "") -> dict:
        start = time.time()
        full_key = f"{system}||{prompt}"

        # Level 1: Exact match
        exact_result = self.exact_cache.get(full_key)
        if exact_result:
            self.stats["hits_exact"] += 1
            return {"response": exact_result, "source": "exact_cache", "latency_ms": (time.time()-start)*1000}

        # Level 2: Semantic match
        embedding = self._embed(prompt)
        semantic_result = self.semantic_cache.get(prompt, embedding)
        if semantic_result:
            self.stats["hits_semantic"] += 1
            self.exact_cache.set(full_key, semantic_result)  # Promote to exact cache
            return {"response": semantic_result, "source": "semantic_cache", "latency_ms": (time.time()-start)*1000}

        # Level 3: API call
        self.stats["misses"] += 1
        response = self._call_llm(prompt, system)

        # Cache for future use
        self.exact_cache.set(full_key, response)
        self.semantic_cache.set(prompt, response, embedding)

        return {"response": response, "source": "api", "latency_ms": (time.time()-start)*1000}

    def cache_stats(self) -> dict:
        total = sum(self.stats.values()) - self.stats["api_calls"]
        if total == 0:
            return self.stats
        hit_rate = (self.stats["hits_exact"] + self.stats["hits_semantic"]) / max(1, total)
        return {**self.stats, "hit_rate": hit_rate}

TTL and Invalidation

For medical content, cached responses can become outdated as clinical guidelines change. Add a TTL (time to live) so entries expire automatically:

Python
import time
from dataclasses import dataclass, field

@dataclass
class TTLCacheEntry:
    """A cached value paired with the moment it stops being valid."""

    value: str  # the cached payload
    expires_at: float  # deadline on the time.monotonic() clock

class TTLCache:
    def __init__(self, ttl_seconds: float = 86400):  # 24 hours default
        self.ttl = ttl_seconds
        self._store: dict[str, TTLCacheEntry] = {}

    def get(self, key: str) -> str | None:
        entry = self._store.get(key)
        if entry is None:
            return None
        if time.monotonic() > entry.expires_at:
            del self._store[key]
            return None
        return entry.value

    def set(self, key: str, value: str):
        self._store[key] = TTLCacheEntry(
            value=value,
            expires_at=time.monotonic() + self.ttl,
        )

    def invalidate(self, key: str):
        self._store.pop(key, None)

    def clear_expired(self):
        now = time.monotonic()
        expired = [k for k, v in self._store.items() if now > v.expires_at]
        for k in expired:
            del self._store[k]

For pharmaceutical content, a TTL of 30 days (2,592,000 seconds) is a reasonable default: drug information is stable, but updates do happen. Pair the TTL with explicit invalidation, triggered manually or by a webhook, for when a drug's labeling changes.

Enjoyed this article?

Explore the AI Systems learning path for more.

Found this helpful?

Share:𝕏

Leave a comment

Have a question, correction, or just found this helpful? Leave a note below.