Implement LLM Response Caching
Build a semantic cache for LLM responses. Cache exact matches with a hash, and semantically similar queries with embedding similarity to reduce API costs and latency.
Why Cache LLM Responses?
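LLM API calls are expensive (cost) and slow (latency), and users often ask the same or similar questions. Caching lets you serve repeated queries from the cache instead of calling the LLM again. Note that a semantic lookup still needs one embedding call per query, which is typically far cheaper and faster than a chat completion.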
Two caching strategies:
- Exact match: Hash the prompt, check cache — O(1) lookup
- Semantic match: Find cached responses for similar (not identical) queries using embedding similarity
Exact Match Cache (LRU)
The simplest cache — store prompt hash → response:
import hashlib
from collections import OrderedDict
from typing import Optional
class LRUCache:
    """Bounded prompt -> response store that drops the least recently used entry.

    Prompts are not stored directly: each one is reduced to its SHA-256 hex
    digest, and that digest is the dictionary key.
    """

    def __init__(self, max_size: int = 1000):
        self.max_size = max_size
        # Position in the OrderedDict encodes recency: oldest first, newest last.
        self._cache: OrderedDict[str, str] = OrderedDict()

    def _hash_key(self, prompt: str) -> str:
        """Return the SHA-256 hex digest of *prompt*."""
        digest = hashlib.sha256(prompt.encode())
        return digest.hexdigest()

    def get(self, prompt: str) -> Optional[str]:
        """Return the response cached for *prompt*, or ``None`` on a miss.

        A hit also moves the entry to the most-recently-used position.
        """
        digest = self._hash_key(prompt)
        try:
            self._cache.move_to_end(digest)
        except KeyError:
            return None
        return self._cache[digest]

    def set(self, prompt: str, response: str):
        """Store *response* for *prompt* as the most recently used entry.

        When the store grows beyond ``max_size``, the least recently used
        entry is discarded.
        """
        digest = self._hash_key(prompt)
        # Removing any previous entry first guarantees the assignment below
        # lands at the most-recent end.
        self._cache.pop(digest, None)
        self._cache[digest] = response
        if len(self._cache) > self.max_size:
            self._cache.popitem(last=False)

    def size(self) -> int:
        """Return the number of entries currently held."""
        return len(self._cache)
# Demo: store two answers, then look up one cached and one uncached prompt.
cache = LRUCache(max_size=3)
for question, answer in (
    ("What is metformin?", "Metformin is a biguanide antidiabetic drug..."),
    ("What is warfarin?", "Warfarin is an anticoagulant..."),
):
    cache.set(question, answer)

print(cache.get("What is metformin?"))  # Cache hit
print(cache.get("Tell me about aspirin"))  # None — cache miss


# Semantic Cache
Match queries based on meaning, not exact text:
import numpy as np
from dataclasses import dataclass
@dataclass
class CacheEntry:
prompt: str
response: str
embedding: np.ndarray
hit_count: int = 0
class SemanticCache:
"""Cache LLM responses using semantic similarity for lookup."""
def __init__(
self,
similarity_threshold: float = 0.92,
max_entries: int = 500,
):
self.threshold = similarity_threshold
self.max_entries = max_entries
self.entries: list[CacheEntry] = []
self._embeddings_matrix: np.ndarray | None = None
def _rebuild_matrix(self):
"""Rebuild the embedding matrix for efficient search."""
if self.entries:
embeddings = np.stack([e.embedding for e in self.entries])
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
self._embeddings_matrix = embeddings / (norms + 1e-10)
def get(self, query: str, query_embedding: np.ndarray) -> str | None:
"""Find a cached response with high similarity to the query."""
if not self.entries or self._embeddings_matrix is None:
return None
# Normalize query embedding
query_norm = query_embedding / (np.linalg.norm(query_embedding) + 1e-10)
# Cosine similarities against all cached entries
similarities = self._embeddings_matrix @ query_norm # (n,)
best_idx = int(np.argmax(similarities))
best_score = float(similarities[best_idx])
if best_score >= self.threshold:
self.entries[best_idx].hit_count += 1
return self.entries[best_idx].response
return None
def set(self, prompt: str, response: str, embedding: np.ndarray):
"""Cache a prompt-response pair with its embedding."""
if len(self.entries) >= self.max_entries:
# Evict the entry with the fewest hits
min_hits_idx = min(range(len(self.entries)), key=lambda i: self.entries[i].hit_count)
self.entries.pop(min_hits_idx)
self.entries.append(CacheEntry(
prompt=prompt,
response=response,
embedding=embedding,
))
self._rebuild_matrix()
# Simulated embedding function
def embed(text: str) -> np.ndarray:
    """Placeholder — replace with actual embedding API call.

    Returns a pseudo-random float32 vector of length 128 that depends only
    on *text*: the same text gives the same vector in every process.
    """
    import hashlib  # local import so this snippet does not rely on another snippet's imports

    # The built-in hash() of a str is salted per interpreter process, so
    # seeding from it would give different vectors on each run. A SHA-256
    # digest is stable; four bytes fit RandomState's 32-bit seed range.
    digest = hashlib.sha256(text.encode("utf-8")).digest()
    seed = int.from_bytes(digest[:4], "big")
    rng = np.random.RandomState(seed)
    return rng.randn(128).astype(np.float32)
# Demo: one cached answer, probed with the same query and an unrelated one.
cache = SemanticCache(similarity_threshold=0.95)

q1 = "What is the mechanism of action of metformin?"
q2 = "What are the side effects of warfarin?"

# Cache a response under q1's embedding
emb1 = embed(q1)
cache.set(q1, "Metformin activates AMPK, reducing hepatic glucose production...", emb1)

# Same text, same embedding — hit
result = cache.get(q1, embed(q1))
print(f"Exact query: {'HIT' if result else 'MISS'}")

# Unrelated text gets a different embedding — expected miss
result = cache.get(q2, embed(q2))
print(f"Different query: {'HIT' if result else 'MISS'}")


# Full Caching Pipeline
Combining exact and semantic cache:
import time
from openai import OpenAI
# Shared client used by LLMCacheLayer for its embedding and chat-completion calls.
# NOTE(review): constructed at import time; presumably it reads credentials from
# the environment (e.g. OPENAI_API_KEY) — confirm before importing this module
# in environments without a key.
client = OpenAI()
class LLMCacheLayer:
    """Two-level cache: exact match first, then semantic.

    ``query`` checks the exact cache (keyed on system prompt + prompt), then a
    semantic cache scoped to the same system prompt, and only then calls the
    LLM. Counters for each outcome are kept in ``stats``.
    """

    def __init__(
        self,
        exact_cache_size: int = 10_000,
        semantic_threshold: float = 0.93,
    ):
        self.exact_cache = LRUCache(max_size=exact_cache_size)
        # Semantic cache for the default (empty) system prompt.
        self.semantic_cache = SemanticCache(similarity_threshold=semantic_threshold)
        self._semantic_threshold = semantic_threshold
        # Semantic caches for non-empty system prompts, created on first use.
        # A response produced under one system prompt must not be served for
        # another, so each system prompt gets its own cache. This grows with
        # the number of distinct system prompts.
        self._semantic_by_system: dict[str, SemanticCache] = {}
        self.stats = {"hits_exact": 0, "hits_semantic": 0, "misses": 0, "api_calls": 0}

    def _semantic_cache_for(self, system: str) -> "SemanticCache":
        """Return the semantic cache that belongs to *system*."""
        if system == "":
            return self.semantic_cache
        cache = self._semantic_by_system.get(system)
        if cache is None:
            cache = SemanticCache(similarity_threshold=self._semantic_threshold)
            self._semantic_by_system[system] = cache
        return cache

    @staticmethod
    def _result(response: str, source: str, start: float) -> dict:
        """Build the dict returned by ``query``, timing from *start*."""
        return {
            "response": response,
            "source": source,
            "latency_ms": (time.perf_counter() - start) * 1000,
        }

    def _embed(self, text: str) -> np.ndarray:
        """Embed using OpenAI (replace with your embedding model)."""
        resp = client.embeddings.create(
            input=text,
            model="text-embedding-3-small",
        )
        return np.array(resp.data[0].embedding)

    def _call_llm(self, prompt: str, system: str) -> str:
        """Send *system* and *prompt* to the chat model and return its reply.

        Increments ``stats["api_calls"]`` before the request is made.
        """
        self.stats["api_calls"] += 1
        resp = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": prompt},
            ],
            temperature=0.1,
        )
        return resp.choices[0].message.content

    def query(self, prompt: str, system: str = "") -> dict:
        """Answer *prompt*, using the caches before falling back to the LLM.

        Args:
            prompt: The user prompt.
            system: The system prompt; part of the cache identity.

        Returns:
            A dict with keys ``response``, ``source`` (``"exact_cache"``,
            ``"semantic_cache"`` or ``"api"``) and ``latency_ms``.
        """
        # perf_counter is monotonic, so the latency cannot go negative or jump
        # when the wall clock is adjusted.
        start = time.perf_counter()
        full_key = f"{system}||{prompt}"

        # Level 1: Exact match. Compare with None so that a cached empty
        # string still counts as a hit.
        exact_result = self.exact_cache.get(full_key)
        if exact_result is not None:
            self.stats["hits_exact"] += 1
            return self._result(exact_result, "exact_cache", start)

        # Level 2: Semantic match, restricted to this system prompt
        semantic_cache = self._semantic_cache_for(system)
        embedding = self._embed(prompt)
        semantic_result = semantic_cache.get(prompt, embedding)
        if semantic_result is not None:
            self.stats["hits_semantic"] += 1
            self.exact_cache.set(full_key, semantic_result)  # Promote to exact cache
            return self._result(semantic_result, "semantic_cache", start)

        # Level 3: API call
        self.stats["misses"] += 1
        response = self._call_llm(prompt, system)

        # Cache for future use
        self.exact_cache.set(full_key, response)
        semantic_cache.set(prompt, response, embedding)
        return self._result(response, "api", start)

    def cache_stats(self) -> dict:
        """Return a copy of ``stats`` with an added ``hit_rate`` entry.

        ``hit_rate`` is (exact hits + semantic hits) / (hits + misses), and
        ``0.0`` before any query has been recorded. The returned dict is
        always a new object, so callers cannot alter the live counters.
        """
        hits = self.stats["hits_exact"] + self.stats["hits_semantic"]
        total = hits + self.stats["misses"]
        hit_rate = hits / total if total else 0.0
        return {**self.stats, "hit_rate": hit_rate}


# TTL and Invalidation
For medical content, cached responses may become outdated as guidelines change. Add a TTL (time to live) so that stale entries expire automatically:
import time
from dataclasses import dataclass, field
@dataclass
class TTLCacheEntry:
value: str
expires_at: float
class TTLCache:
def __init__(self, ttl_seconds: float = 86400): # 24 hours default
self.ttl = ttl_seconds
self._store: dict[str, TTLCacheEntry] = {}
def get(self, key: str) -> str | None:
entry = self._store.get(key)
if entry is None:
return None
if time.monotonic() > entry.expires_at:
del self._store[key]
return None
return entry.value
def set(self, key: str, value: str):
self._store[key] = TTLCacheEntry(
value=value,
expires_at=time.monotonic() + self.ttl,
)
def invalidate(self, key: str):
self._store.pop(key, None)
def clear_expired(self):
now = time.monotonic()
expired = [k for k, v in self._store.items() if now > v.expires_at]
for k in expired:
del self._store[k]For pharmaceutical content: set TTL to 30 days (drug information is stable but occasional updates happen). Add a webhook or manual invalidation for when a drug's labeling changes.
Found this helpful?
Leave a comment
Have a question, correction, or just found this helpful? Leave a note below.