Scenario: RAG System Is Too Slow at Scale
The Scenario
Your RAG chatbot is functionally correct but painfully slow. Telemetry shows:
- P50 latency: 8 seconds
- P95 latency: 12 seconds
- User drop-off (session abandonment) has increased by 35% since launch
Commonly cited web performance studies suggest that users expect responses within about 2 seconds, and that each additional second of delay can reduce conversions by roughly 7%. A 12-second RAG response is a product-killing problem.
Let's break down where the time goes, then attack each component.
Step 1: Instrument the Pipeline
You cannot optimize what you do not measure. Add timing spans to every stage:
import time
import logging
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Dict

logger = logging.getLogger("rag.latency")

@dataclass
class LatencyProfile:
    query_id: str
    spans: Dict[str, float] = field(default_factory=dict)
    total_ms: float = 0.0

    def record(self, stage: str, duration_ms: float):
        self.spans[stage] = duration_ms

    def report(self):
        logger.info({
            "query_id": self.query_id,
            "total_ms": self.total_ms,
            "breakdown": self.spans,
        })

@contextmanager
def timed_stage(profile: LatencyProfile, stage: str):
    start = time.perf_counter()
    try:
        yield
    finally:
        # Record the span even if the stage raises
        elapsed_ms = (time.perf_counter() - start) * 1000
        profile.record(stage, elapsed_ms)

async def rag_with_profiling(query: str, query_id: str):
    profile = LatencyProfile(query_id=query_id)
    total_start = time.perf_counter()
    # Stage 1: Embed the query
    with timed_stage(profile, "query_embedding"):
        query_embedding = await embed_query(query)
    # Stage 2: Vector search
    with timed_stage(profile, "vector_search"):
        chunks = await vector_store.search(query_embedding, k=5)
    # Stage 3: Parse and format context
    with timed_stage(profile, "context_assembly"):
        context = build_context_string(chunks)
    # Stage 4: LLM completion
    with timed_stage(profile, "llm_completion"):
        response = await call_llm(context, query)
    profile.total_ms = (time.perf_counter() - total_start) * 1000
    profile.report()
    return response

After one day of profiling, a typical breakdown looks like this:
| Stage | Typical Duration | Share of Total |
|---|---|---|
| Query embedding | 400-600 ms | 6% |
| Vector search | 200-400 ms | 4% |
| Context assembly | 50-100 ms | 1% |
| LLM completion | 6,000-9,000 ms | 89% |
The LLM is almost always the bottleneck. Any optimization that avoids calling the LLM entirely is the highest-leverage change.
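The shares in the table can be derived directly from the recorded spans. The following is a minimal sketch, not part of the original pipeline: `stage_shares` and `slowest_stage` are hypothetical helper names, and the sample input uses the midpoints of the ranges in the table rather than measured data.

```python
from typing import Dict


def stage_shares(spans_ms: Dict[str, float]) -> Dict[str, float]:
    """Return each stage's share of the summed span time, as a percentage."""
    total = sum(spans_ms.values())
    if total <= 0:
        return {stage: 0.0 for stage in spans_ms}
    return {stage: round(100 * ms / total, 1) for stage, ms in spans_ms.items()}


def slowest_stage(spans_ms: Dict[str, float]) -> str:
    """Return the name of the stage with the largest recorded duration."""
    return max(spans_ms, key=spans_ms.get)


# Assumed sample: midpoints of the ranges in the table, not real telemetry.
sample_spans = {
    "query_embedding": 500.0,
    "vector_search": 300.0,
    "context_assembly": 75.0,
    "llm_completion": 7500.0,
}

shares = stage_shares(sample_spans)
bottleneck = slowest_stage(sample_spans)
print(bottleneck, shares)
```

Feeding `profile.spans` from many requests through a helper like this (or aggregating in your log pipeline) shows whether the LLM stage dominates for your workload too, rather than relying on a single trace.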
Fix 1: Semantic Cache (Biggest Win)
A semantic cache stores LLM responses and returns cached answers for semantically similar queries — not just exact duplicates. If someone asks "how many vacation days do I get?" and the cache has a response for "what is my annual leave entitlement?", the cache hits and you skip the LLM entirely.
This can eliminate 40-70% of LLM calls in steady-state production.
Architecture
User Query
    │
    ▼
[Embed Query]
    │
    ▼
[Redis: search for similar cached query]
    │
    ├── HIT (cosine similarity > 0.92): return cached response (under 50ms total)
    │
    └── MISS: run full RAG pipeline → store result in Redis → return response

Implementation
import redis
import json
import hashlib
import numpy as np
from openai import AzureOpenAI

# The API key is read from the AZURE_OPENAI_API_KEY environment variable.
client = AzureOpenAI(
    azure_endpoint="https://your-resource.openai.azure.com",
    api_version="2024-02-01",
)
redis_client = redis.Redis(host="your-redis.cache.windows.net", port=6380, ssl=True)

CACHE_SIMILARITY_THRESHOLD = 0.92
CACHE_TTL_SECONDS = 3600  # 1 hour

def embed_text(text: str) -> list[float]:
    return client.embeddings.create(
        model="text-embedding-3-large",
        input=text,
    ).data[0].embedding

def cosine_similarity(a: list[float], b: list[float]) -> float:
    a_np, b_np = np.array(a), np.array(b)
    return float(np.dot(a_np, b_np) / (np.linalg.norm(a_np) * np.linalg.norm(b_np)))

class SemanticCache:
    def __init__(self, redis_client, namespace: str = "rag_cache"):
        self.redis = redis_client
        self.namespace = namespace
        self._cache_index = []  # In production, use Redis Search or a separate vector store

    def _load_index(self):
        """Load all cached embeddings into memory for similarity search."""
        keys = self.redis.keys(f"{self.namespace}:*")
        self._cache_index = []
        for key in keys:
            raw = self.redis.get(key)
            if raw is None:  # entry expired between listing the keys and reading it
                continue
            entry = json.loads(raw)
            self._cache_index.append({
                "key": key.decode(),
                "embedding": entry["query_embedding"],
                "response": entry["response"],
            })

    def get(self, query_embedding: list[float]) -> str | None:
        """
        Search the cache for a similar query.
        Returns the cached response if found, else None.
        """
        self._load_index()  # In production, use an efficient index
        best_score = 0.0
        best_response = None
        for entry in self._cache_index:
            score = cosine_similarity(query_embedding, entry["embedding"])
            if score > best_score:
                best_score = score
                best_response = entry["response"]
        if best_score >= CACHE_SIMILARITY_THRESHOLD:
            return best_response
        return None

    def set(self, query: str, query_embedding: list[float], response: str):
        """Store a query-response pair in the cache."""
        cache_key = f"{self.namespace}:{hashlib.sha256(query.encode()).hexdigest()[:16]}"
        entry = {
            "query": query,
            "query_embedding": query_embedding,
            "response": response,
        }
        self.redis.setex(cache_key, CACHE_TTL_SECONDS, json.dumps(entry))

semantic_cache = SemanticCache(redis_client)

async def cached_rag_query(query: str) -> dict:
    # Note: embed_text and the cache calls are synchronous and block the event loop.
    # In an async service, run them in a worker thread or use async clients.
    # Step 1: Embed query (fast: ~400ms)
    query_embedding = embed_text(query)
    # Step 2: Check semantic cache
    cached_response = semantic_cache.get(query_embedding)
    if cached_response is not None:
        return {
            "response": cached_response,
            "source": "cache",
            "latency_saved_ms": 7000,  # approximate LLM call cost
        }
    # Step 3: Cache miss, so run the full RAG pipeline
    chunks = await vector_store.search(query_embedding, k=5)
    context = build_context_string(chunks)
    response = await call_llm(context, query)
    # Step 4: Store in cache for future queries
    semantic_cache.set(query, query_embedding, response)
    return {
        "response": response,
        "source": "llm",
        "latency_saved_ms": 0,
    }

Production note: The in-memory index scan above is fine for caches under a few thousand entries. For larger caches, use Redis Search with vector indexing, or a dedicated vector store like Qdrant as your cache backend.
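Before moving to a dedicated vector index, an intermediate step is to replace the per-entry Python loop with a single matrix operation. The sketch below assumes the cached embeddings are already held in memory as a list of vectors; `best_cache_match` is a hypothetical helper, not part of the cache class above, and the two-dimensional vectors are toy data for illustration only.

```python
import numpy as np


def best_cache_match(query_embedding, cached_embeddings, threshold=0.92):
    """Return (index, score) of the most similar cached vector.

    The index is None when the cache is empty or the best cosine
    similarity falls below the threshold.
    """
    if len(cached_embeddings) == 0:
        return None, 0.0
    matrix = np.asarray(cached_embeddings, dtype=np.float64)
    query = np.asarray(query_embedding, dtype=np.float64)
    query_norm = np.linalg.norm(query)
    if query_norm == 0:
        return None, 0.0
    row_norms = np.linalg.norm(matrix, axis=1)
    # Avoid division by zero for any all-zero cached vector.
    safe_norms = np.where(row_norms == 0, 1.0, row_norms)
    # One matrix-vector product scores every cached entry at once.
    scores = (matrix @ query) / (safe_norms * query_norm)
    best = int(np.argmax(scores))
    score = float(scores[best])
    if score >= threshold:
        return best, score
    return None, score


# Toy 2-D vectors standing in for real embeddings.
cached = [[1.0, 0.0], [0.0, 1.0], [0.6, 0.8]]
hit_index, hit_score = best_cache_match([1.0, 0.0], cached)
miss_index, miss_score = best_cache_match([1.0, -1.0], cached)
```

This keeps the same threshold semantics as the loop version while scoring all entries in one vectorised call. It still scans every entry, so it only postpones, rather than removes, the need for a proper index.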
Fix 2: Streaming to Reduce Perceived Latency
Streaming does not reduce total time-to-completion, but it dramatically improves perceived responsiveness. Users start reading after the first token arrives (typically 1-2 seconds) rather than waiting for the full 8-second response.
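Because the benefit of streaming is time to first token rather than total time, it helps to measure both. The following is a minimal sketch: `measure_stream` is a hypothetical helper that wraps any async token iterator, and `fake_llm_stream` is a stand-in with made-up delays so the example runs without an LLM.

```python
import asyncio
import time
from typing import AsyncIterator


async def measure_stream(tokens: AsyncIterator[str]) -> dict:
    """Consume a token stream; report time to first token and total time in ms."""
    start = time.perf_counter()
    first_token_ms = None
    count = 0
    async for _ in tokens:
        if first_token_ms is None:
            first_token_ms = (time.perf_counter() - start) * 1000
        count += 1
    total_ms = (time.perf_counter() - start) * 1000
    return {"first_token_ms": first_token_ms, "total_ms": total_ms, "tokens": count}


async def fake_llm_stream(n_tokens: int = 5, first_delay: float = 0.05,
                          per_token_delay: float = 0.01):
    """Stand-in for a real LLM stream (delays are illustrative assumptions)."""
    await asyncio.sleep(first_delay)
    for i in range(n_tokens):
        yield f"tok{i}"
        await asyncio.sleep(per_token_delay)


stats = asyncio.run(measure_stream(fake_llm_stream()))
print(stats)
```

Recording `first_token_ms` alongside the stage spans from Step 1 lets you track perceived latency separately from time-to-completion. The streaming endpoint itself follows.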
import asyncio
import json
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from openai import AsyncAzureOpenAI

app = FastAPI()

# Async client so that reading the stream does not block the event loop
async_client = AsyncAzureOpenAI(
    azure_endpoint="https://your-resource.openai.azure.com",
    api_version="2024-02-01",
)

async def stream_rag_response(query: str):
    """
    Async generator that yields server-sent events as tokens arrive from the LLM.
    """
    # Retrieve context first (non-streamed, fast); embed_text is synchronous,
    # so run it in a worker thread
    query_embedding = await asyncio.to_thread(embed_text, query)
    chunks = await vector_store.search(query_embedding, k=5)
    context = build_context_string(chunks)
    # Stream the LLM response
    stream = await async_client.chat.completions.create(
        model="gpt-4o",  # on Azure, this is your deployment name
        messages=[
            {"role": "system", "content": f"Context:\n{context}"},
            {"role": "user", "content": query},
        ],
        stream=True,
    )
    async for chunk in stream:
        if not chunk.choices:  # some chunks carry no choices; skip them
            continue
        delta = chunk.choices[0].delta.content
        if delta:
            yield f"data: {json.dumps({'token': delta})}\n\n"
    yield "data: [DONE]\n\n"

@app.get("/query/stream")
async def stream_endpoint(query: str):
    return StreamingResponse(
        stream_rag_response(query),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",  # disable nginx buffering
        },
    )

On the frontend, display tokens as they arrive:
async function streamQuery(query: string, onToken: (t: string) => void) {
  const response = await fetch(`/query/stream?query=${encodeURIComponent(query)}`);
  const reader = response.body!.getReader();
  const decoder = new TextDecoder();
  let buffer = "";
  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    // stream: true keeps multi-byte characters intact across reads
    buffer += decoder.decode(value, { stream: true });
    // Events end with a blank line; keep any incomplete event in the buffer
    const events = buffer.split("\n\n");
    buffer = events.pop() ?? "";
    for (const event of events) {
      if (event.startsWith("data: ") && event !== "data: [DONE]") {
        const data = JSON.parse(event.slice(6));
        onToken(data.token);
      }
    }
  }
}

Fix 3: Async Parallel Retrieval
If you query multiple vector stores or knowledge bases, run them in parallel:
import asyncio
from typing import List

async def parallel_retrieval(query: str, k_per_source: int = 3) -> List[dict]:
    """
    Retrieve from multiple sources concurrently instead of sequentially.
    """
    query_embedding = embed_text(query)
    # Launch all retrievals concurrently
    tasks = [
        product_docs_store.search(query_embedding, k=k_per_source),
        policy_store.search(query_embedding, k=k_per_source),
        faq_store.search(query_embedding, k=k_per_source),
    ]
    results = await asyncio.gather(*tasks)
    # Merge and deduplicate
    all_chunks = [chunk for source_results in results for chunk in source_results]
    return deduplicate_and_rerank(all_chunks, query)

Sequential retrieval from three sources: 3 × 300 ms = 900 ms. Parallel: about 300 ms, the time of the slowest source.
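The difference is easy to reproduce with simulated stores. In this sketch, `fake_search` is a hypothetical stand-in that only sleeps for a fixed delay (50 ms here, to keep the demo short), so the timings illustrate the pattern and say nothing about any real vector store.

```python
import asyncio
import time


async def fake_search(source: str, delay_s: float = 0.05) -> list:
    """Simulated vector store search: waits, then returns one chunk."""
    await asyncio.sleep(delay_s)
    return [f"{source}-chunk"]


async def sequential(sources):
    """Query each source one after another."""
    results = []
    for source in sources:
        results.extend(await fake_search(source))
    return results


async def parallel(sources):
    """Query all sources at once; gather preserves the input order."""
    batches = await asyncio.gather(*(fake_search(s) for s in sources))
    return [chunk for batch in batches for chunk in batch]


async def timed(coro):
    """Await a coroutine and return (result, elapsed seconds)."""
    start = time.perf_counter()
    result = await coro
    return result, time.perf_counter() - start


async def main():
    sources = ["product_docs", "policy", "faq"]
    seq_result, seq_s = await timed(sequential(sources))
    par_result, par_s = await timed(parallel(sources))
    return seq_result, seq_s, par_result, par_s


seq_result, seq_s, par_result, par_s = asyncio.run(main())
print(f"sequential: {seq_s * 1000:.0f} ms, parallel: {par_s * 1000:.0f} ms")
```

Both variants return the same chunks in the same order; only the wall-clock time differs, with the parallel version bounded by the slowest source rather than the sum.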
Fix 4: Reduce Context Size
More context tokens = slower LLM response. Trim aggressively:
from typing import List
import tiktoken

def trim_context(chunks: List[str], max_tokens: int = 1500) -> str:
    """
    Include chunks, in ranked order, until the token budget is reached.
    Many RAG tasks work well with about 1,500 context tokens; tune the
    budget against answer quality.
    """
    enc = tiktoken.encoding_for_model("gpt-4o")
    context_parts = []
    used_tokens = 0
    for chunk in chunks:
        chunk_tokens = len(enc.encode(chunk))
        if used_tokens + chunk_tokens > max_tokens:
            break
        context_parts.append(chunk)
        used_tokens += chunk_tokens
    return "\n\n".join(context_parts)

Cutting context from 4,000 to 1,500 tokens can reduce LLM latency by 30-40%.
Latency Budget After Optimization
| Stage | Before | After |
|---|---|---|
| Query embedding | 500 ms | 500 ms (unchanged) |
| Vector search | 300 ms | 300 ms (unchanged) |
| LLM call (cache hit) | 7,000 ms | 0 ms (skipped) |
| LLM call (cache miss) | 7,000 ms | 4,500 ms (shorter context) |
| Time to first token (streaming) | 7,000 ms | 900 ms |
With a 60% cache hit rate, the average total latency becomes approximately:
- 60% of requests: under 1 second (cache hit)
- 40% of requests: approximately 5.5 seconds (cache miss, shorter context)
- Weighted average: approximately 2.8 seconds
That is a drop from 8 seconds to 2.8 seconds — a 65% improvement — from caching and context trimming alone.
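The weighted average above is a simple expected-value calculation, which makes it easy to re-run with your own hit rate. A minimal sketch follows: `expected_latency_ms` is a hypothetical helper, and the inputs are the figures quoted above (1,000 ms as the upper bound for a cache hit, 5,500 ms for a miss, 8,000 ms as the original P50).

```python
def expected_latency_ms(hit_rate: float, hit_ms: float, miss_ms: float) -> float:
    """Average latency given a cache hit rate and per-path latencies."""
    if not 0.0 <= hit_rate <= 1.0:
        raise ValueError("hit_rate must be between 0 and 1")
    return hit_rate * hit_ms + (1.0 - hit_rate) * miss_ms


baseline_ms = 8000.0
after_ms = expected_latency_ms(hit_rate=0.60, hit_ms=1000.0, miss_ms=5500.0)
improvement = 1.0 - after_ms / baseline_ms

# Sensitivity check: the same pipeline at a lower, assumed 40% hit rate.
after_low_hit_ms = expected_latency_ms(hit_rate=0.40, hit_ms=1000.0, miss_ms=5500.0)

print(f"after: {after_ms:.0f} ms, improvement: {improvement:.0%}")
print(f"at 40% hit rate: {after_low_hit_ms:.0f} ms")
```

The result is sensitive to the hit rate: dropping from 60% to 40% moves the average from roughly 2.8 seconds to roughly 3.7 seconds, which is why the hit rate is worth monitoring.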
Cache Hit Rate Tracking
Monitor your cache performance continuously:
from collections import deque
from prometheus_client import Counter, Gauge

cache_hits = Counter("rag_cache_hits_total", "Total semantic cache hits")
cache_misses = Counter("rag_cache_misses_total", "Total semantic cache misses")
cache_hit_rate = Gauge("rag_cache_hit_rate", "Rolling cache hit rate")

# Outcomes of the most recent requests (True = hit); the window size is a tunable choice
recent_outcomes = deque(maxlen=1000)

async def monitored_rag_query(query: str) -> dict:
    result = await cached_rag_query(query)
    hit = result["source"] == "cache"
    if hit:
        cache_hits.inc()
    else:
        cache_misses.inc()
    # Compute the rate from our own window rather than the counters' private _value attribute
    recent_outcomes.append(hit)
    cache_hit_rate.set(sum(recent_outcomes) / len(recent_outcomes))
    return result

Summary
| Optimization | Latency Impact | Effort |
|---|---|---|
| Semantic cache | Eliminates 40-70% of LLM calls | Medium |
| Streaming | Reduces perceived latency to 1-2 seconds | Low |
| Async parallel retrieval | Saves 200-600 ms | Low |
| Context trimming | Saves 1-3 seconds per LLM call | Low |
| Smaller model for simple queries | Saves 3-5 seconds | Medium |
Start with streaming (low effort, immediate user experience win), then implement the semantic cache (highest latency impact), then measure what remains.