RAG Caching: Semantic and Exact-Match Strategies
Reduce latency and cost in RAG systems with semantic caching, exact-match Redis caching, TTL strategies, GPTCache, and cache invalidation patterns.
Why Cache in RAG?
RAG pipelines have multiple expensive operations:
- Embedding generation: 1-5ms per query (negligible, but adds up)
- Vector search: 10-50ms per query
- LLM generation: 500ms-5s per response (most expensive)
- Reranking: 50-200ms per query
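Caching targets the expensive parts: a cache hit skips these stages entirely, so the biggest savings come from avoiding LLM generation. The right caching strategy depends on your query patterns — in particular, how often users repeat or rephrase the same questions.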
Layer 1: Exact Query Cache
Cache exact string matches — fastest, highest precision:
import redis
import hashlib
import json
from typing import Optional
from dataclasses import dataclass
@dataclass
class CachedResponse:
    """One cached RAG answer together with the metadata needed to audit it.

    The exact-match cache stores instances as JSON (via the instance
    ``__dict__``), so field values should stay JSON-serializable.
    """

    # Query text that produced this answer.
    query: str
    # Generated answer text.
    response: str
    # Documents the retriever returned for the query.
    retrieved_docs: list[dict]
    # Name of the embedding model used for retrieval.
    embedding_model: str
    # Name of the LLM that generated the answer.
    llm_model: str
    # Creation time in seconds (the pipeline in this article uses time.time()).
    timestamp: float
    # Lifetime, in seconds, the entry was cached with.
    ttl_seconds: int
class ExactQueryCache:
    """Redis-backed exact query cache for RAG responses.

    Each entry is a JSON-encoded ``CachedResponse`` stored under
    ``key_prefix + sha256(query, context)`` with a TTL. Keys are opaque
    digests, so the original query text is only recoverable from the stored
    JSON payload.
    """

    def __init__(
        self,
        redis_url: str = "redis://localhost:6379",
        default_ttl: int = 3600,  # 1 hour
        key_prefix: str = "rag:exact:",
    ):
        # decode_responses=True: Redis returns str values, which json.loads accepts.
        self.client = redis.from_url(redis_url, decode_responses=True)
        self.default_ttl = default_ttl
        self.key_prefix = key_prefix

    def _make_key(self, query: str, context: str = "") -> str:
        """Generate a stable, unambiguous cache key from query + context.

        The pair is JSON-encoded before hashing. Joining with a separator
        (``f"{query}|{context}"``) was ambiguous: ("a|b", "c") and
        ("a", "b|c") produced the same key, so a crafted query could read an
        entry cached under another context (e.g. another user's id). Each
        part is stripped so surrounding whitespace does not split otherwise
        identical queries.
        """
        combined = json.dumps([query.strip(), context.strip()])
        return self.key_prefix + hashlib.sha256(combined.encode()).hexdigest()

    def get(self, query: str, context: str = "") -> Optional[CachedResponse]:
        """Return the cached response for (query, context), or None on a miss."""
        data = self.client.get(self._make_key(query, context))
        if data is None:
            return None
        return CachedResponse(**json.loads(data))

    def set(
        self,
        query: str,
        response: CachedResponse,
        context: str = "",
        ttl: Optional[int] = None,
    ) -> None:
        """Cache ``response`` under (query, context).

        ``ttl`` is in seconds; ``None`` (or 0) falls back to ``default_ttl``.
        """
        key = self._make_key(query, context)
        self.client.setex(
            key,
            ttl or self.default_ttl,
            json.dumps(response.__dict__),
        )

    def invalidate(self, query: str, context: str = "") -> bool:
        """Remove the entry for (query, context); return True if one existed."""
        key = self._make_key(query, context)
        return bool(self.client.delete(key))

    def invalidate_pattern(self, pattern: str) -> int:
        """Remove entries whose key or stored query text matches ``pattern``.

        Keys are SHA-256 digests, so a glob on the key alone only ever
        matches hex fragments: a human-readable pattern such as a drug name
        never matched anything. Besides the key glob (kept for backward
        compatibility), this deletes entries whose stored ``query`` contains
        ``pattern`` case-insensitively. An empty pattern matches every entry.

        Iterates with SCAN (not KEYS) so Redis is not blocked, but it still
        reads every entry under ``key_prefix`` — use with care.

        Returns:
            Number of keys deleted.
        """
        needle = pattern.lower()
        doomed = set(self.client.scan_iter(match=f"{self.key_prefix}*{pattern}*"))
        for key in self.client.scan_iter(match=f"{self.key_prefix}*"):
            if key in doomed:
                continue
            raw = self.client.get(key)
            if raw is None:  # expired between SCAN and GET
                continue
            try:
                stored_query = json.loads(raw).get("query", "")
            except (ValueError, AttributeError):
                continue  # not a payload written by set(); leave it alone
            if isinstance(stored_query, str) and needle in stored_query.lower():
                doomed.add(key)
        if not doomed:
            return 0
        return self.client.delete(*doomed)

    def get_stats(self) -> dict:
        """Return cache statistics (entry count under this prefix)."""
        # SCAN instead of KEYS: KEYS blocks the Redis server on large keyspaces.
        total = sum(1 for _ in self.client.scan_iter(match=f"{self.key_prefix}*"))
        return {
            "total_cached": total,
            "prefix": self.key_prefix,
        }
# Layer 2: Semantic Cache
Cache based on meaning — similar questions hit the same cache entry:
import numpy as np
from openai import OpenAI
oai_client = OpenAI()


def embed_query(query: str) -> np.ndarray:
    """Return the ``text-embedding-3-small`` embedding of ``query`` as a NumPy array.

    Issues one embeddings API request per call, with the query as the only input.
    """
    result = oai_client.embeddings.create(
        input=[query],
        model="text-embedding-3-small",
    )
    vector = result.data[0].embedding
    return np.array(vector)
class SemanticCache:
    """
    In-process semantic cache using embedding similarity.

    Stores parallel lists of (query, embedding, response) and returns the
    response whose embedding has the highest cosine similarity to the lookup
    embedding, provided it reaches ``similarity_threshold``. Entries never
    expire; at ``max_size`` the oldest inserted entry is evicted (FIFO — hits
    do not refresh an entry). ``max_size <= 0`` disables storage.
    """

    def __init__(
        self,
        similarity_threshold: float = 0.95,
        max_size: int = 10000,
    ):
        self.similarity_threshold = similarity_threshold
        self.max_size = max_size
        self.embeddings: list[np.ndarray] = []
        self.responses: list[dict] = []
        self.queries: list[str] = []

    def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
        """Cosine similarity of two vectors; 0.0 if either has zero norm."""
        denom = np.linalg.norm(a) * np.linalg.norm(b)
        if denom == 0:
            return 0.0
        return float(np.dot(a, b) / denom)

    def get(self, query_embedding: np.ndarray) -> Optional[dict]:
        """Find a cached response for a semantically similar query.

        Returns the stored response dict extended with ``cache_hit``,
        ``similarity`` and ``original_query``, or None if the cache is empty
        or the best match is below the threshold.
        """
        if not self.embeddings:
            return None
        query_vec = np.asarray(query_embedding)
        emb_matrix = np.stack(self.embeddings)
        dots = emb_matrix @ query_vec
        norms = np.linalg.norm(emb_matrix, axis=1) * np.linalg.norm(query_vec)
        # Zero-norm vectors have no direction: score them 0 rather than
        # dividing by zero (which yields NaN and a RuntimeWarning).
        similarities = np.divide(
            dots, norms, out=np.zeros(len(dots), dtype=float), where=norms > 0
        )
        best_idx = int(np.argmax(similarities))
        best_sim = float(similarities[best_idx])
        if best_sim < self.similarity_threshold:
            return None
        return {
            **self.responses[best_idx],
            "cache_hit": True,
            "similarity": best_sim,
            "original_query": self.queries[best_idx],
        }

    def set(
        self,
        query: str,
        query_embedding: np.ndarray,
        response: dict,
    ) -> None:
        """Add a response to the semantic cache (no-op when max_size <= 0)."""
        if self.max_size <= 0:
            return  # cache disabled; previously crashed popping an empty list
        # `while` also copes with max_size having been lowered after filling.
        while len(self.embeddings) >= self.max_size:
            self.embeddings.pop(0)
            self.responses.pop(0)
            self.queries.pop(0)
        # Store a private copy so later in-place edits to the caller's array
        # cannot silently change what is cached.
        self.embeddings.append(np.array(query_embedding))
        self.responses.append(response)
        self.queries.append(query)
# Redis-Backed Semantic Cache
Scale semantic caching with Redis for persistence and multi-instance support:
import pickle
class RedisSemCache:
    """
    Redis-backed semantic cache.

    Each entry is two plain string keys sharing an id derived from the query
    text: ``{prefix}emb:{id}`` holds the embedding as raw float64 bytes and
    ``{prefix}resp:{id}`` holds the JSON response; both share one TTL.
    Lookup is a brute-force scan over all stored embeddings, so this suits
    small caches only. For production at scale, use Redis with vector
    similarity search (Redis Stack).
    """

    def __init__(
        self,
        redis_url: str = "redis://localhost:6379",
        similarity_threshold: float = 0.93,
        ttl: int = 7200,  # 2 hours
        key_prefix: str = "rag:sem:",
    ):
        # No decode_responses: embeddings are stored as raw bytes.
        self.client = redis.from_url(redis_url)
        self.threshold = similarity_threshold
        self.ttl = ttl
        self.prefix = key_prefix
        self._index_key = f"{key_prefix}index"  # not used by this class

    def _serialize_embedding(self, emb: np.ndarray) -> bytes:
        """Encode an embedding as flat little-endian float64 bytes.

        Raw numeric bytes replace pickle: unpickling values read back from a
        shared Redis instance can execute arbitrary code if anything else can
        write to it.
        """
        return np.asarray(emb, dtype="<f8").ravel().tobytes()

    def _deserialize_embedding(self, data: bytes) -> np.ndarray:
        """Decode bytes from ``_serialize_embedding``.

        Raises ValueError if ``data`` is not a whole number of float64 values
        (e.g. a legacy pickled entry).
        """
        return np.frombuffer(data, dtype="<f8")

    def get(self, query_embedding: np.ndarray) -> Optional[dict]:
        """Search Redis for a semantically similar cached query.

        Returns the stored response dict plus ``similarity``, or None when
        nothing reaches ``self.threshold``. Entries whose embedding cannot be
        decoded or has a different dimension are skipped.
        """
        query_vec = np.asarray(query_embedding, dtype=float).ravel()
        query_norm = np.linalg.norm(query_vec)
        if query_norm == 0:
            return None  # a zero vector has no direction to compare
        emb_prefix = f"{self.prefix}emb:"
        # SCAN instead of KEYS so the Redis server is not blocked.
        emb_keys = list(self.client.scan_iter(match=f"{emb_prefix}*"))
        if not emb_keys:
            return None

        best_sim = 0.0
        best_key = None
        # One MGET round trip instead of one GET per key.
        for key, emb_bytes in zip(emb_keys, self.client.mget(emb_keys)):
            if emb_bytes is None:  # expired between SCAN and MGET
                continue
            try:
                cached_vec = self._deserialize_embedding(emb_bytes)
            except ValueError:
                continue  # not written by this serializer (e.g. legacy pickle)
            if cached_vec.shape != query_vec.shape:
                continue  # different embedding model / dimension
            cached_norm = np.linalg.norm(cached_vec)
            if cached_norm == 0:
                continue
            sim = float(np.dot(query_vec, cached_vec) / (query_norm * cached_norm))
            if sim > best_sim:
                best_sim = sim
                best_key = key

        if best_key is None or best_sim < self.threshold:
            return None
        key_str = best_key.decode() if isinstance(best_key, bytes) else best_key
        # Rebuild the response key from the id after the prefix; a string
        # replace of ":emb:" broke for prefixes not ending in ":".
        cache_id = key_str[len(emb_prefix):]
        resp_bytes = self.client.get(f"{self.prefix}resp:{cache_id}")
        if not resp_bytes:
            return None
        return {**json.loads(resp_bytes), "similarity": best_sim}

    def set(self, query: str, query_embedding: np.ndarray, response: dict) -> None:
        """Store the embedding + response pair under one id, both with ``self.ttl``."""
        cache_id = hashlib.sha256(query.encode()).hexdigest()[:16]
        # Serialize both payloads first so a non-JSON response cannot leave an
        # orphan embedding key behind.
        emb_payload = self._serialize_embedding(query_embedding)
        resp_payload = json.dumps(response)
        # A pipeline (MULTI/EXEC by default) writes both keys together.
        pipe = self.client.pipeline()
        pipe.setex(f"{self.prefix}emb:{cache_id}", self.ttl, emb_payload)
        pipe.setex(f"{self.prefix}resp:{cache_id}", self.ttl, resp_payload)
        pipe.execute()
# GPTCache Integration
GPTCache provides a full-featured semantic cache with multiple backends:
# pip install gptcache
from gptcache import cache
from gptcache.adapter import openai as cached_openai
from gptcache.embedding import Onnx
from gptcache.manager import CacheBase, VectorBase, get_data_manager
from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation
def init_gptcache():
    """Configure the module-level GPTCache ``cache`` object and return it.

    Cached answers are kept in SQLite and their embeddings in a Faiss index
    sized to the local ONNX embedding model; candidate matches are scored
    with ``SearchDistanceEvaluation``.
    """
    # Local ONNX embedding model; its `dimension` sizes the Faiss index.
    encoder = Onnx()
    scalar_store = CacheBase("sqlite")
    vector_store = VectorBase("faiss", dimension=encoder.dimension)
    manager = get_data_manager(scalar_store, vector_store)
    evaluator = SearchDistanceEvaluation()
    cache.init(
        embedding_func=encoder.to_embeddings,
        data_manager=manager,
        similarity_evaluation=evaluator,
    )
    return cache
# After init, use cached_openai drop-in instead of openai
def rag_with_gptcache(query: str, context: str) -> str:
    """Answer ``query`` from ``context`` with gpt-4o, routed through GPTCache.

    ``init_gptcache()`` must have been called first so the adapter has a
    configured cache.

    Args:
        query: the user's question.
        context: retrieved passages, inserted verbatim into the prompt.

    Returns:
        The assistant message text.
    """
    response = cached_openai.ChatCompletion.create(
        model="gpt-4o",
        messages=[
            {
                "role": "system",
                "content": "Answer the question using only the provided context.",
            },
            {
                "role": "user",
                "content": f"Context:\n{context}\n\nQuestion: {query}",
            },
        ],
        temperature=0,
    )
    # Read the answer by key, not attribute: GPTCache's adapter builds
    # cache-hit responses as plain dicts (no `.choices` attribute), so
    # `response.choices` failed exactly when the cache hit. Key access also
    # works on the dict-like SDK object returned on a miss.
    # NOTE(review): `ChatCompletion` is the pre-1.0 `openai` SDK interface;
    # confirm the installed SDK version is compatible with the adapter.
    return response["choices"][0]["message"]["content"]
# TTL Strategy by Query Type
Different queries deserve different cache lifetimes:
from enum import Enum
class QueryType(Enum):
    """Coarse query categories; each maps to its own cache lifetime."""

    DRUG_INTERACTION = "drug_interaction"
    DOSING = "dosing"
    MECHANISM = "mechanism"
    CLINICAL_TRIAL = "clinical_trial"
    NEWS = "news"
    # Fallback category for queries that match no other keyword rule.
    DEFINITION = "definition"
def classify_query_type(query: str) -> QueryType:
    """Assign a QueryType to ``query`` so it can be given a suitable cache TTL.

    Rules are tried in order; the first rule with any keyword occurring in
    the lower-cased query wins, and unmatched queries are DEFINITION.
    Matching is by plain substring, so e.g. "new" also matches "renew".
    """
    text = query.lower()
    rules = (
        (QueryType.DRUG_INTERACTION, ("interact", "combination", "together with", "contraindicated")),
        (QueryType.DOSING, ("dose", "dosing", "how much", "mg", "frequency")),
        (QueryType.MECHANISM, ("mechanism", "how does", "pathway", "moa")),
        (QueryType.CLINICAL_TRIAL, ("trial", "study", "research", "evidence", "efficacy")),
        (QueryType.NEWS, ("latest", "new", "recent", "current")),
    )
    for query_type, keywords in rules:
        if any(keyword in text for keyword in keywords):
            return query_type
    return QueryType.DEFINITION
# Cache lifetime in seconds per query type: stable facts are kept longest,
# fast-moving topics expire soonest.
QUERY_TTL = {
    QueryType.DEFINITION: 7 * 24 * 60 * 60,  # 7 days — definitions rarely change
    QueryType.MECHANISM: 7 * 24 * 60 * 60,  # 7 days
    QueryType.DRUG_INTERACTION: 24 * 60 * 60,  # 1 day — guidelines update occasionally
    QueryType.DOSING: 24 * 60 * 60,  # 1 day
    QueryType.CLINICAL_TRIAL: 6 * 60 * 60,  # 6 hours — evidence evolves
    QueryType.NEWS: 30 * 60,  # 30 minutes — recent events
}
def get_ttl_for_query(query: str) -> int:
    """Look up the cache TTL, in seconds, for ``query`` from its classified type."""
    return QUERY_TTL[classify_query_type(query)]
# Full Caching Pipeline
import time
class CachedRAGPipeline:
    """RAG pipeline with multi-layer caching.

    Lookup order for each ``query`` call:

    1. Exact-match Redis cache, keyed by the query text and ``user_id``.
    2. In-process semantic cache, accepting only entries stored for the
       same ``user_id``.
    3. Full pipeline (embed -> retrieve top 5 -> generate with gpt-4o),
       whose answer is then written to both caches.

    ``self.stats`` counts calls ("total") and how each one was served.
    """

    def __init__(
        self,
        retriever,
        llm_client,
        redis_url: str = "redis://localhost:6379",
        semantic_threshold: float = 0.93,
    ):
        # `retriever` must provide retrieve(embedding, top_k=...) returning
        # dicts with a "content" key; `llm_client` must provide an
        # OpenAI-style chat.completions.create(...).
        self.retriever = retriever
        self.llm = llm_client
        self.exact_cache = ExactQueryCache(redis_url)
        self.semantic_cache = SemanticCache(similarity_threshold=semantic_threshold)
        self.stats = {"exact_hits": 0, "semantic_hits": 0, "misses": 0, "total": 0}

    @staticmethod
    def _elapsed_ms(since: float) -> float:
        """Milliseconds elapsed since a ``time.perf_counter()`` reading."""
        return (time.perf_counter() - since) * 1000

    def query(self, query: str, user_id: str = "") -> dict:
        """Answer ``query`` for ``user_id``, consulting the caches first.

        Returns a dict with "response", "cache" ("exact", "semantic" or
        "miss") and "latency_ms" (wall-clock time of this call). Semantic
        hits add "similarity" and "original_query"; misses add a per-stage
        "breakdown".
        """
        self.stats["total"] += 1
        start = time.perf_counter()

        # Layer 1: exact match, scoped to this user through the key context.
        cached = self.exact_cache.get(query, context=user_id)
        if cached:
            self.stats["exact_hits"] += 1
            return {
                "response": cached.response,
                "cache": "exact",
                # Measured; this used to be a hard-coded 1 ms.
                "latency_ms": self._elapsed_ms(start),
            }

        # One embedding serves both the semantic lookup and retrieval.
        embed_start = time.perf_counter()
        query_embedding = embed_query(query)
        embed_ms = self._elapsed_ms(embed_start)

        # Layer 2: the semantic cache is shared by every user, so only accept
        # a hit stored for this same user_id. Otherwise user B's identical
        # query (an exact miss because of a different user_id) would be served
        # user A's answer at similarity 1.0. Conservative: if the best match
        # belongs to someone else, this is treated as a miss.
        sem_cached = self.semantic_cache.get(query_embedding)
        if sem_cached and sem_cached.get("user_id", "") == user_id:
            self.stats["semantic_hits"] += 1
            return {
                "response": sem_cached["response"],
                "cache": "semantic",
                "similarity": sem_cached["similarity"],
                "original_query": sem_cached["original_query"],
                "latency_ms": self._elapsed_ms(start),
            }

        # Cache miss — full RAG pipeline.
        self.stats["misses"] += 1
        retrieve_start = time.perf_counter()
        docs = self.retriever.retrieve(query_embedding, top_k=5)
        retrieve_ms = self._elapsed_ms(retrieve_start)
        context = "\n\n".join(doc["content"] for doc in docs)

        llm_start = time.perf_counter()
        completion = self.llm.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": "Answer using only the provided context."},
                {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"},
            ],
            temperature=0,
        )
        response = completion.choices[0].message.content
        llm_ms = self._elapsed_ms(llm_start)

        # Store in both caches.
        ttl = get_ttl_for_query(query)
        response_obj = CachedResponse(
            query=query,
            response=response,
            retrieved_docs=docs,
            embedding_model="text-embedding-3-small",
            llm_model="gpt-4o",
            timestamp=time.time(),  # wall-clock timestamp, not a duration
            ttl_seconds=ttl,
        )
        self.exact_cache.set(query, response_obj, context=user_id, ttl=ttl)
        # NOTE(review): SemanticCache.set takes no TTL, so `ttl` is not
        # enforced on this layer; short-lived answers can outlive it here.
        self.semantic_cache.set(
            query, query_embedding, {"response": response, "user_id": user_id}
        )

        return {
            "response": response,
            "cache": "miss",
            "latency_ms": self._elapsed_ms(start),
            "breakdown": {
                "embed_ms": embed_ms,
                "retrieve_ms": retrieve_ms,
                "llm_ms": llm_ms,
            },
        }

    def get_hit_rate(self) -> dict:
        """Return cache hit-rate statistics; all rates are 0 before any query.

        Calls that raise part-way are counted in "total" but may be counted
        as neither hit nor miss, so the rates need not sum to 1.
        """
        total = self.stats["total"]
        # Guard the division only; reporting `total or 1` claimed one query
        # had been made before any had.
        denom = total or 1
        hits = self.stats["exact_hits"] + self.stats["semantic_hits"]
        return {
            "total_queries": total,
            "exact_hit_rate": self.stats["exact_hits"] / denom,
            "semantic_hit_rate": self.stats["semantic_hits"] / denom,
            "miss_rate": self.stats["misses"] / denom,
            "overall_hit_rate": hits / denom,
        }
# Cache Invalidation
def invalidate_on_knowledge_update(
    exact_cache: ExactQueryCache,
    topic: str,
    affected_drug: Optional[str] = None,
) -> dict:
    """
    Invalidate cache entries when underlying knowledge changes.

    Called when new drug information is ingested into the vector store.
    Only ``exact_cache`` is touched.

    Args:
        exact_cache: cache whose ``invalidate_pattern`` is called.
        topic: free-text label recorded in the event; it does not select
            which entries are invalidated.
        affected_drug: drug name to invalidate by (stripped and lower-cased).
            ``None``, empty or whitespace-only skips invalidation, so a blank
            pattern (an empty one globs every key) is never sent.

    Returns:
        The event dict that is printed, including ``entries_invalidated``.
    """
    invalidated = 0
    drug = (affected_drug or "").strip().lower()
    # Exact cache: pattern-based invalidation
    if drug:
        invalidated += exact_cache.invalidate_pattern(drug)
    # Log the invalidation event
    event = {
        "event": "cache_invalidation",
        "topic": topic,
        "affected_drug": affected_drug,
        "entries_invalidated": invalidated,
        "timestamp": time.time(),
    }
    print(f"Cache invalidated: {event}")
    return event
def scheduled_cache_warmup(
    pipeline: CachedRAGPipeline,
    common_queries: list[str],
) -> dict:
    """
    Pre-warm the cache by running each common query through ``pipeline``.

    Run this after knowledge base updates to pre-populate the cache. A query
    that still hits an existing cache entry is answered from that entry and
    nothing is recomputed, so invalidate stale entries first; such queries
    are reported in ``already_cached``.

    Each call also counts toward ``pipeline.stats``.

    Returns:
        {"warmed": successful calls, "failed": calls that raised,
         "already_cached": successful calls served from a cache
         (a subset of "warmed")}
    """
    results = {"warmed": 0, "failed": 0, "already_cached": 0}
    for query in common_queries:
        try:
            outcome = pipeline.query(query)
        except Exception as e:
            # Best-effort batch: report and continue with the next query.
            print(f"Warmup failed for '{query}': {e}")
            results["failed"] += 1
            continue
        results["warmed"] += 1
        if isinstance(outcome, dict) and outcome.get("cache") in ("exact", "semantic"):
            results["already_cached"] += 1
    return results
# Frequently asked questions to replay (e.g. via scheduled_cache_warmup)
# after a drug database update, so early users get cached answers.
COMMON_DRUG_QUERIES = [
    "What are the common side effects of metformin?",
    "What drugs interact with warfarin?",
    "What is the mechanism of action of statins?",
    "How does aspirin work?",
    "What are ACE inhibitors used for?",
]
# When Caching Hurts
Caching is wrong for:
| Scenario | Why caching fails | Solution |
|---|---|---|
| Personalized queries | The same question has a different answer per patient | Include user context in the cache key |
| Time-sensitive info | "Current guidelines" change | Use a short TTL or bypass the cache |
| Low-repetition workloads | Unique research queries never repeat | Disable the semantic cache |
| High-precision domains | A slightly rephrased question may need a different, exact answer | Set a high similarity threshold (above 0.97) |
| Streaming responses | The cache stores only the complete response | Stream from the LLM, then serve the cached copy to subsequent requests |
def should_cache(query: str, user_context: dict) -> bool:
    """Decide if this query is worth caching.

    Returns False for patient-specific queries and for queries about
    breaking news, True otherwise. Matching is case-insensitive.
    ``user_context`` is accepted for API compatibility but not consulted.
    """
    import re

    text = query.lower()
    # Never cache patient-specific queries. Phrases match as substrings (so
    # "my patients" still counts). The abbreviations MRN/DOB must not be glued
    # to other letters: as bare substrings, "dob" matched drug names such as
    # "dobutamine" and wrongly disabled caching. Adjacent digits are allowed
    # so "MRN12345" is still caught.
    if any(phrase in text for phrase in ("my patient", "patient id")):
        return False
    if re.search(r"(?<![a-z])(?:mrn|dob)(?![a-z])", text):
        return False
    # Never cache queries about breaking news
    if any(word in text for word in ("today", "this week", "just announced", "breaking")):
        return False
    # Cache everything else
    return True
# Found this helpful?
Leave a comment
Have a question, correction, or just found this helpful? Leave a note below.