Learnixo
Back to blog
AI Systems · Intermediate

RAG Caching: Semantic and Exact-Match Strategies

Reduce latency and cost in RAG systems with semantic caching, exact-match Redis caching, TTL strategies, GPTCache, and cache invalidation patterns.

Asma Hafeez Khan · May 16, 2026 · 9 min read
RAG · Caching · Redis · GPTCache · Semantic Cache · Performance
Share:𝕏

Why Cache in RAG?

RAG pipelines have multiple expensive operations:

  • Embedding generation: 1-5ms per query with a local model, typically tens of milliseconds through a hosted API (small, but it adds up)
  • Vector search: 10-50ms per query
  • LLM generation: 500ms-5s per response (most expensive)
  • Reranking: 50-200ms per query

Caching attacks the expensive parts. The right caching strategy depends on your query patterns.


Layer 1: Exact Query Cache

Cache exact string matches — fastest, highest precision:

Python
import redis
import hashlib
import json
from typing import Optional
from dataclasses import dataclass

@dataclass()
class CachedResponse:
    """Everything kept for one cached RAG answer.

    Stored as JSON, so every field should hold a JSON-compatible value.
    """

    query: str                   # question text as it was asked
    response: str                # generated answer text
    retrieved_docs: list[dict]   # documents that were passed to the LLM as context
    embedding_model: str         # name of the model that embedded the query
    llm_model: str               # name of the model that generated the answer
    timestamp: float             # creation time, seconds since the epoch
    ttl_seconds: int             # lifetime requested when the entry was stored


class ExactQueryCache:
    """Redis-backed exact query cache for RAG responses.

    Each entry is the JSON form of a response object. It is stored under
    ``key_prefix + sha256(query, context)`` with its own TTL. A hit requires
    the same (whitespace-stripped) query text and the same context string.
    """

    def __init__(
        self,
        redis_url: str = "redis://localhost:6379",
        default_ttl: int = 3600,    # 1 hour
        key_prefix: str = "rag:exact:",
    ):
        # decode_responses=True makes keys and values come back as str.
        self.client = redis.from_url(redis_url, decode_responses=True)
        self.default_ttl = default_ttl
        self.key_prefix = key_prefix

    def _make_key(self, query: str, context: str = "") -> str:
        """Generate a stable cache key from query + context.

        The pair is JSON-encoded rather than joined as ``"query|context"``.
        With a plain separator, ("a|b", "") and ("a", "b|") hash to the same
        key, so one context could be served another context's answer.
        Whitespace around the query is ignored. Changing this encoding orphans
        existing entries; they simply miss until their TTL expires.
        """
        payload = json.dumps([query.strip(), context])
        return self.key_prefix + hashlib.sha256(payload.encode("utf-8")).hexdigest()

    def get(self, query: str, context: str = "") -> Optional["CachedResponse"]:
        """Return the cached response for (query, context), or None on a miss."""
        data = self.client.get(self._make_key(query, context))
        if data is None:
            return None
        return CachedResponse(**json.loads(data))

    def set(
        self,
        query: str,
        response: "CachedResponse",
        context: str = "",
        ttl: Optional[int] = None,
    ) -> None:
        """Cache ``response`` for (query, context) for ``ttl`` seconds.

        A falsy ``ttl`` (None or 0) falls back to ``default_ttl``; Redis
        rejects non-positive expiry times anyway.
        """
        key = self._make_key(query, context)
        self.client.setex(key, ttl or self.default_ttl, json.dumps(vars(response)))

    def invalidate(self, query: str, context: str = "") -> bool:
        """Remove a specific cached response; True if an entry was deleted."""
        return bool(self.client.delete(self._make_key(query, context)))

    def _iter_keys(self):
        """Iterate over every key under this cache's prefix.

        Uses incremental SCAN instead of KEYS, which blocks the Redis server
        while it walks the whole keyspace.
        """
        return self.client.scan_iter(match=f"{self.key_prefix}*", count=500)

    def invalidate_pattern(self, pattern: str) -> int:
        """Remove every entry whose stored query or response mentions ``pattern``.

        Keys are SHA-256 digests, so globbing ``pattern`` against key names can
        never find plain-text terms such as a drug name. Instead, each entry's
        JSON payload is inspected. Matching is a case-insensitive substring
        test; an empty ``pattern`` matches (and removes) every entry.

        Returns the number of keys deleted.
        """
        needle = pattern.lower()
        doomed = []
        for key in self._iter_keys():
            raw = self.client.get(key)
            if raw is None:
                continue  # expired between SCAN and GET
            try:
                entry = json.loads(raw)
            except ValueError:
                continue  # not written by set(); leave it alone
            if not isinstance(entry, dict):
                continue
            query_text = str(entry.get("query") or "").lower()
            response_text = str(entry.get("response") or "").lower()
            if needle in query_text or needle in response_text:
                doomed.append(key)
        return self.client.delete(*doomed) if doomed else 0

    def get_stats(self) -> dict:
        """Return cache statistics: live key count under the prefix."""
        return {
            "total_cached": sum(1 for _ in self._iter_keys()),
            "prefix": self.key_prefix,
        }

Layer 2: Semantic Cache

Cache based on meaning — similar questions hit the same cache entry:

Python
import numpy as np
from openai import OpenAI

oai_client = OpenAI()


def embed_query(query: str) -> np.ndarray:
    """Return the embedding of ``query`` as a NumPy array.

    Calls the module-level ``oai_client`` with ``text-embedding-3-small``.
    Similarity comparisons are only meaningful between vectors produced by
    the same model.
    """
    result = oai_client.embeddings.create(
        input=[query],
        model="text-embedding-3-small",
    )
    vector = result.data[0].embedding
    return np.array(vector)


class SemanticCache:
    """
    In-memory semantic cache keyed by embedding similarity.

    Stores (query, embedding, response) triples. It returns the response of
    the most cosine-similar stored query when that similarity reaches
    ``similarity_threshold``. Capacity is bounded by ``max_size``; when full,
    the oldest entry is evicted first (FIFO). A ``max_size`` of 0 or less
    disables storage.
    """

    def __init__(
        self,
        similarity_threshold: float = 0.95,
        max_size: int = 10000,
    ):
        self.similarity_threshold = similarity_threshold
        self.max_size = max_size

        # Parallel lists: index i of each list describes the same entry.
        self.embeddings: list[np.ndarray] = []
        self.responses: list[dict] = []
        self.queries: list[str] = []

    def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
        """Cosine similarity of two vectors; 0.0 if either has zero norm."""
        denom = np.linalg.norm(a) * np.linalg.norm(b)
        if denom == 0:
            return 0.0
        return float(np.dot(a, b) / denom)

    def get(self, query_embedding: np.ndarray) -> Optional[dict]:
        """Find a cached response for a semantically similar query.

        Returns the stored response dict extended with ``cache_hit``,
        ``similarity`` and ``original_query``. Returns None when the cache is
        empty or the best similarity is below the threshold.
        """
        if not self.embeddings:
            return None

        emb_matrix = np.stack(self.embeddings).astype(float, copy=False)
        query = np.asarray(query_embedding, dtype=float)
        dots = emb_matrix @ query
        denom = np.linalg.norm(emb_matrix, axis=1) * np.linalg.norm(query)
        # A zero-norm vector would give NaN, and np.argmax treats NaN as the
        # maximum, so one degenerate entry would mask every real match.
        # Such pairs are scored 0.0 instead.
        similarities = np.divide(dots, denom, out=np.zeros_like(dots), where=denom > 0)

        best_idx = int(np.argmax(similarities))
        best_sim = float(similarities[best_idx])
        if best_sim < self.similarity_threshold:
            return None

        return {
            **self.responses[best_idx],
            "cache_hit": True,
            "similarity": best_sim,
            "original_query": self.queries[best_idx],
        }

    def set(
        self,
        query: str,
        query_embedding: np.ndarray,
        response: dict,
    ) -> None:
        """Add a response to the semantic cache, evicting the oldest if full."""
        if self.max_size <= 0:
            return  # caching disabled; popping from an empty list would raise

        # `while` rather than `if` also trims correctly if max_size was lowered.
        while len(self.embeddings) >= self.max_size:
            self.embeddings.pop(0)
            self.responses.pop(0)
            self.queries.pop(0)

        self.embeddings.append(query_embedding)
        self.responses.append(response)
        self.queries.append(query)

Redis-Backed Semantic Cache

Scale semantic caching with Redis for persistence and multi-instance support:

Python
import pickle

class RedisSemCache:
    """
    Redis-backed semantic cache.

    Each entry is two keys sharing an id derived from the query text:
    ``{prefix}emb:{id}`` holds the pickled query embedding and
    ``{prefix}resp:{id}`` holds the JSON response; both expire after ``ttl``.
    Lookup is a brute-force cosine-similarity scan over all stored
    embeddings, which only suits small caches. For production at scale, use
    Redis with vector similarity search (Redis Stack).
    """

    def __init__(
        self,
        redis_url: str = "redis://localhost:6379",
        similarity_threshold: float = 0.93,
        ttl: int = 7200,  # 2 hours
        key_prefix: str = "rag:sem:",
    ):
        # No decode_responses: keys and values come back as bytes.
        self.client = redis.from_url(redis_url)
        self.threshold = similarity_threshold
        self.ttl = ttl
        self.prefix = key_prefix
        self._index_key = f"{key_prefix}index"  # not used by get()/set()

    def _serialize_embedding(self, emb: np.ndarray) -> bytes:
        return pickle.dumps(emb)

    def _deserialize_embedding(self, data: bytes) -> np.ndarray:
        # SECURITY(review): pickle.loads runs arbitrary code embedded in the
        # payload. This is only safe if nothing untrusted can write to this
        # Redis instance; consider np.ndarray.tobytes()/np.frombuffer() instead.
        return pickle.loads(data)

    @staticmethod
    def _as_str(key) -> str:
        """Redis returns bytes keys here; normalize to str."""
        return key.decode() if isinstance(key, bytes) else key

    def get(self, query_embedding: np.ndarray) -> Optional[dict]:
        """Search Redis for a semantically similar cached query.

        Returns the cached response dict plus ``similarity``, or None when no
        stored embedding reaches the threshold. Entries whose embedding has a
        different shape (e.g. from another embedding model) or a zero norm
        are skipped.
        """
        query = np.asarray(query_embedding, dtype=float)
        query_norm = float(np.linalg.norm(query))
        if query_norm == 0:
            return None

        emb_prefix = f"{self.prefix}emb:"
        # SCAN walks the keyspace incrementally; KEYS would block the server.
        emb_keys = list(self.client.scan_iter(match=f"{emb_prefix}*", count=500))
        if not emb_keys:
            return None
        # One MGET round-trip instead of one GET per key.
        blobs = self.client.mget(emb_keys)

        best_sim = float("-inf")
        best_id = None
        for key, blob in zip(emb_keys, blobs):
            if blob is None:
                continue  # expired between SCAN and MGET
            cached = np.asarray(self._deserialize_embedding(blob), dtype=float)
            if cached.shape != query.shape:
                continue
            cached_norm = float(np.linalg.norm(cached))
            if cached_norm == 0:
                continue
            sim = float(np.dot(query, cached) / (query_norm * cached_norm))
            if sim > best_sim:
                best_sim = sim
                # The response key shares the id suffix of the embedding key.
                best_id = self._as_str(key)[len(emb_prefix):]

        if best_id is None or best_sim < self.threshold:
            return None

        resp_bytes = self.client.get(f"{self.prefix}resp:{best_id}")
        if not resp_bytes:
            return None
        return {**json.loads(resp_bytes), "similarity": best_sim}

    def set(self, query: str, query_embedding: np.ndarray, response: dict) -> None:
        """Store the embedding + response pair, both expiring after ``ttl``."""
        cache_id = hashlib.sha256(query.encode()).hexdigest()[:16]

        # A pipeline (MULTI/EXEC by default) writes both keys together, so a
        # reader never sees an embedding without its response.
        pipe = self.client.pipeline()
        pipe.setex(
            f"{self.prefix}emb:{cache_id}",
            self.ttl,
            self._serialize_embedding(query_embedding),
        )
        pipe.setex(f"{self.prefix}resp:{cache_id}", self.ttl, json.dumps(response))
        pipe.execute()

GPTCache Integration

GPTCache provides a full-featured semantic cache with multiple backends:

Python
# pip install gptcache

from gptcache import cache
from gptcache.adapter import openai as cached_openai
from gptcache.embedding import Onnx
from gptcache.manager import CacheBase, VectorBase, get_data_manager
from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation


def init_gptcache():
    """Configure the global GPTCache instance and return it.

    Embeddings come from GPTCache's local ONNX model. Cached answers are kept
    in SQLite, and their vectors in a Faiss index sized to the ONNX model's
    output dimension. Candidate matches are scored with
    SearchDistanceEvaluation.
    """
    encoder = Onnx()  # local ONNX model, no remote embedding call
    storage = get_data_manager(
        CacheBase("sqlite"),
        VectorBase("faiss", dimension=encoder.dimension),
    )
    cache.init(
        embedding_func=encoder.to_embeddings,
        data_manager=storage,
        similarity_evaluation=SearchDistanceEvaluation(),
    )
    return cache


# After init, use cached_openai drop-in instead of openai
def rag_with_gptcache(query: str, context: str) -> str:
    """RAG using GPTCache — identical API to OpenAI but with caching.

    Call ``init_gptcache()`` first. Requests similar to earlier ones are
    answered from the cache instead of the model.

    NOTE(review): ``gptcache.adapter.openai`` mirrors the legacy (pre-1.0)
    ``openai.ChatCompletion`` interface, whereas other snippets here use the
    1.x ``OpenAI()`` client. Confirm the installed versions are compatible.
    """
    response = cached_openai.ChatCompletion.create(
        model="gpt-4o",
        messages=[
            {
                "role": "system",
                "content": "Answer the question using only the provided context.",
            },
            {
                "role": "user",
                "content": f"Context:\n{context}\n\nQuestion: {query}",
            },
        ],
        temperature=0,
    )
    # Cache hits come back as plain dicts built by GPTCache, which have no
    # `.choices` attribute. Subscript access works both for those and for
    # live responses (dict-like in the legacy SDK).
    return response["choices"][0]["message"]["content"]

TTL Strategy by Query Type

Different queries deserve different cache lifetimes:

Python
from enum import Enum

class QueryType(Enum):
    """Coarse query categories used to choose a cache TTL."""

    DRUG_INTERACTION = "drug_interaction"
    DOSING = "dosing"
    MECHANISM = "mechanism"
    CLINICAL_TRIAL = "clinical_trial"
    NEWS = "news"
    DEFINITION = "definition"


def classify_query_type(query: str) -> QueryType:
    """Classify query to assign appropriate cache TTL.

    Rules are checked in order and the first match wins. Freshness markers
    ("latest", "new", "recent", "current") are checked first. Otherwise a
    question like "What are the current dosing guidelines?" would get the
    1-day DOSING TTL instead of the short NEWS TTL.

    Keywords match only at the start of a word. So "atrial" is not a
    "trial", "among" does not mention "mg", and "renewal" is not "new".
    Stems still match ("interacts", "doses", "trials", "500mg").
    """
    import re

    text = query.lower()

    def mentions(*keywords: str) -> bool:
        # (?<![a-z]) anchors a keyword at a word start; a digit may precede
        # it so that "500mg" still counts as mentioning "mg".
        return any(re.search(r"(?<![a-z])" + re.escape(k), text) for k in keywords)

    if mentions("latest", "new", "recent", "current"):
        return QueryType.NEWS
    if mentions("interact", "combination", "together with", "contraindicated"):
        return QueryType.DRUG_INTERACTION
    if mentions("dose", "overdose", "dosing", "how much", "mg", "frequency"):
        return QueryType.DOSING
    if mentions("mechanism", "how does", "pathway", "moa"):
        return QueryType.MECHANISM
    if mentions("trial", "study", "research", "evidence", "efficacy"):
        return QueryType.CLINICAL_TRIAL
    return QueryType.DEFINITION


# TTL in seconds by query type
QUERY_TTL = {
    QueryType.DEFINITION: 86400 * 7,        # 7 days — definitions rarely change
    QueryType.MECHANISM: 86400 * 7,         # 7 days
    QueryType.DRUG_INTERACTION: 86400,      # 1 day — guidelines update occasionally
    QueryType.DOSING: 86400,                # 1 day
    QueryType.CLINICAL_TRIAL: 3600 * 6,     # 6 hours — evidence evolves
    QueryType.NEWS: 1800,                   # 30 minutes — recent events
}


def get_ttl_for_query(query: str) -> int:
    """Return the cache TTL in seconds for ``query``, based on its type."""
    return QUERY_TTL[classify_query_type(query)]

Full Caching Pipeline

Python
import time

class CachedRAGPipeline:
    """RAG pipeline with multi-layer caching.

    Lookup order:
    1. The exact-match cache, keyed by query + ``user_id``.
    2. The semantic cache.
    3. Full retrieval + generation.

    Fresh answers are written to both caches with a TTL from
    ``get_ttl_for_query``. The semantic cache has no TTL or user scoping of
    its own, so both are stored on each semantic entry and enforced here.
    """

    def __init__(
        self,
        retriever,
        llm_client,
        redis_url: str = "redis://localhost:6379",
        semantic_threshold: float = 0.93,
    ):
        self.retriever = retriever
        self.llm = llm_client
        self.exact_cache = ExactQueryCache(redis_url)
        self.semantic_cache = SemanticCache(similarity_threshold=semantic_threshold)
        self.stats = {"exact_hits": 0, "semantic_hits": 0, "misses": 0, "total": 0}

    @staticmethod
    def _elapsed_ms(since: float) -> float:
        """Milliseconds elapsed since a ``time.perf_counter()`` reading."""
        return (time.perf_counter() - since) * 1000

    @staticmethod
    def _is_expired(entry: dict, now: float) -> bool:
        """True if a semantic entry carries an ``expires_at`` that has passed."""
        expires_at = entry.get("expires_at")
        return expires_at is not None and now >= expires_at

    def _purge_expired_semantic(self) -> None:
        """Drop expired semantic entries so they stop shadowing fresh ones.

        SemanticCache returns only its single best match. An expired entry
        that keeps winning would therefore block every hit on that topic.
        """
        cache = self.semantic_cache
        now = time.time()
        keep = [i for i, entry in enumerate(cache.responses) if not self._is_expired(entry, now)]
        if len(keep) == len(cache.responses):
            return
        cache.embeddings[:] = [cache.embeddings[i] for i in keep]
        cache.responses[:] = [cache.responses[i] for i in keep]
        cache.queries[:] = [cache.queries[i] for i in keep]

    def query(self, query: str, user_id: str = "") -> dict:
        """Query with caching. Returns response and cache status.

        The result has ``response``, ``cache`` ("exact", "semantic" or
        "miss") and ``latency_ms`` (measured end to end). Semantic hits add
        ``similarity`` and ``original_query``; misses add a per-stage
        ``breakdown``.
        """
        self.stats["total"] += 1
        start = time.perf_counter()  # monotonic clock for durations

        # Layer 1: exact match, scoped to this user
        cached = self.exact_cache.get(query, context=user_id)
        if cached:
            self.stats["exact_hits"] += 1
            return {
                "response": cached.response,
                "cache": "exact",
                "latency_ms": self._elapsed_ms(start),
            }

        # Embed query for semantic cache + retrieval
        embed_start = time.perf_counter()
        query_embedding = embed_query(query)
        embed_ms = self._elapsed_ms(embed_start)

        # Layer 2: semantic cache. Entries are honoured only while unexpired
        # and only for the user they were generated for, matching the
        # exact layer's user scoping. Entries without a user_id belong to
        # the anonymous user "".
        sem_cached = self.semantic_cache.get(query_embedding)
        if sem_cached is not None:
            if self._is_expired(sem_cached, time.time()):
                self._purge_expired_semantic()
            elif sem_cached.get("user_id", "") == user_id:
                self.stats["semantic_hits"] += 1
                return {
                    "response": sem_cached["response"],
                    "cache": "semantic",
                    "similarity": sem_cached["similarity"],
                    "original_query": sem_cached["original_query"],
                    "latency_ms": self._elapsed_ms(start),
                }

        # Cache miss — full RAG pipeline
        self.stats["misses"] += 1

        retrieve_start = time.perf_counter()
        docs = self.retriever.retrieve(query_embedding, top_k=5)
        retrieve_ms = self._elapsed_ms(retrieve_start)

        context = "\n\n".join(d["content"] for d in docs)

        llm_start = time.perf_counter()
        response = self.llm.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": "Answer using only the provided context."},
                {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"},
            ],
            temperature=0,
        ).choices[0].message.content
        llm_ms = self._elapsed_ms(llm_start)

        # Store in both caches with the same TTL
        ttl = get_ttl_for_query(query)
        now = time.time()  # wall clock: stored timestamps/expiry are absolute
        response_obj = CachedResponse(
            query=query,
            response=response,
            retrieved_docs=docs,
            embedding_model="text-embedding-3-small",
            llm_model="gpt-4o",
            timestamp=now,
            ttl_seconds=ttl,
        )
        self.exact_cache.set(query, response_obj, context=user_id, ttl=ttl)
        self.semantic_cache.set(
            query,
            query_embedding,
            {"response": response, "user_id": user_id, "expires_at": now + ttl},
        )

        return {
            "response": response,
            "cache": "miss",
            "latency_ms": embed_ms + retrieve_ms + llm_ms,
            "breakdown": {
                "embed_ms": embed_ms,
                "retrieve_ms": retrieve_ms,
                "llm_ms": llm_ms,
            },
        }

    def get_hit_rate(self) -> dict:
        """Return cache hit rate statistics (all rates are 0.0 before any query)."""
        total = self.stats["total"]
        denom = total or 1  # avoid ZeroDivisionError; still report the real total
        hits = self.stats["exact_hits"] + self.stats["semantic_hits"]
        return {
            "total_queries": total,
            "exact_hit_rate": self.stats["exact_hits"] / denom,
            "semantic_hit_rate": self.stats["semantic_hits"] / denom,
            "miss_rate": self.stats["misses"] / denom,
            "overall_hit_rate": hits / denom,
        }

Cache Invalidation

Python
def invalidate_on_knowledge_update(
    exact_cache: ExactQueryCache,
    topic: str,
    affected_drug: Optional[str] = None,
) -> dict:
    """
    Drop cached answers that may be stale after new knowledge is ingested.

    Meant to be called when new drug information lands in the vector store.
    Only the exact-match cache is touched. When ``affected_drug`` is given,
    ``exact_cache.invalidate_pattern`` is called with its lowercased name.
    The resulting event is printed and returned.

    Returns:
        dict with ``event``, ``topic``, ``affected_drug``,
        ``entries_invalidated`` and ``timestamp`` keys.
    """
    removed = exact_cache.invalidate_pattern(affected_drug.lower()) if affected_drug else 0

    event = {
        "event": "cache_invalidation",
        "topic": topic,
        "affected_drug": affected_drug,
        "entries_invalidated": removed,
        "timestamp": time.time(),
    }
    print(f"Cache invalidated: {event}")
    return event


def scheduled_cache_warmup(
    pipeline: CachedRAGPipeline,
    common_queries: list[str],
) -> dict:
    """
    Run each of ``common_queries`` through ``pipeline`` to populate its caches.

    Meant to be scheduled after knowledge-base updates. A failing query is
    reported on stdout and counted, and the remaining queries still run.
    NOTE: a query that is already cached is answered from the cache. Stale
    entries are therefore refreshed only if they were invalidated first.

    Returns:
        ``{"warmed": <successes>, "failed": <failures>}``
    """
    warmed = 0
    failed = 0
    for query in common_queries:
        try:
            pipeline.query(query)
        except Exception as e:  # best effort: one bad query must not stop the batch
            print(f"Warmup failed for '{query}': {e}")
            failed += 1
        else:
            warmed += 1
    return {"warmed": warmed, "failed": failed}


# Frequently asked drug questions, used to pre-warm the cache after a drug
# database update.
COMMON_DRUG_QUERIES = [
    # adverse effects
    "What are the common side effects of metformin?",
    # interactions
    "What drugs interact with warfarin?",
    # mechanism of action
    "What is the mechanism of action of statins?",
    "How does aspirin work?",
    # indications
    "What are ACE inhibitors used for?",
]

When Caching Hurts

Caching is wrong for:

| Scenario | Why caching fails | Solution |
|---|---|---|
| Personalized queries | The same question has a different answer per patient | Include user context in the cache key |
| Time-sensitive info | "Current guidelines" change over time | Use a short TTL or bypass the cache |
| Low-repetition workloads | Unique research queries never repeat | Disable the semantic cache |
| High-precision domains | A slight rephrasing can change the correct answer | Set a high similarity threshold (above 0.97) |
| Streaming responses | The cache stores only the complete response | Stream from the LLM, then cache for subsequent requests |

Python
def should_cache(query: str, user_context: dict) -> bool:
    """Decide if this query is worth caching.

    Returns False for queries that look patient-specific or that ask about
    breaking news, and True otherwise. ``user_context`` is accepted for
    interface compatibility but is not consulted.
    """
    import re

    text = query.lower()

    # Never cache patient-specific queries. Phrases match as substrings. The
    # short identifiers "mrn" and "dob" must not sit inside a longer word, so
    # "dobutamine" stays cacheable while "MRN123456" or "DOB: 1960" do not.
    if any(phrase in text for phrase in ("my patient", "patient id")):
        return False
    if re.search(r"(?<![a-z])(?:mrn|dob)(?![a-z])", text):
        return False

    # Never cache queries about breaking news
    if any(w in text for w in ("today", "this week", "just announced", "breaking")):
        return False

    # Cache everything else
    return True

Enjoyed this article?

Explore the AI Systems learning path for more.

Found this helpful?

Share:𝕏

Leave a comment

Have a question, correction, or just found this helpful? Leave a note below.