Learnixo

Scenario Based Questions · Lesson 3 of 13

Scenario: RAG System Is Too Slow at Scale

The Scenario

Your RAG chatbot is functionally correct but painfully slow. Telemetry shows:

  • P50 latency: 8 seconds
  • P95 latency: 12 seconds
  • User drop-off (session abandonment) has increased by 35% since launch

Widely cited web-performance studies suggest that users expect responses in under 2 seconds and that each additional second of delay can cut conversions by around 7%. A 12-second RAG response is a product-killing problem.

Let's break down where the time goes, then attack each component.

Step 1: Instrument the Pipeline

You cannot optimize what you do not measure. Add timing spans to every stage:

Python
import time
import logging
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Dict

logger = logging.getLogger("rag.latency")

@dataclass
class LatencyProfile:
    query_id: str
    spans: Dict[str, float] = field(default_factory=dict)
    total_ms: float = 0.0

    def record(self, stage: str, duration_ms: float):
        self.spans[stage] = duration_ms

    def report(self):
        logger.info({
            "query_id": self.query_id,
            "total_ms": self.total_ms,
            "breakdown": self.spans,
        })

@contextmanager
def timed_stage(profile: LatencyProfile, stage: str):
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed_ms = (time.perf_counter() - start) * 1000
        profile.record(stage, elapsed_ms)

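# embed_query, vector_store, build_context_string and call_llm are placeholders
# for your own pipeline components.
async def rag_with_profiling(query: str, query_id: str):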
    profile = LatencyProfile(query_id=query_id)
    total_start = time.perf_counter()

    # Stage 1: Embed the query
    with timed_stage(profile, "query_embedding"):
        query_embedding = await embed_query(query)

    # Stage 2: Vector search
    with timed_stage(profile, "vector_search"):
        chunks = await vector_store.search(query_embedding, k=5)

    # Stage 3: Parse and format context
    with timed_stage(profile, "context_assembly"):
        context = build_context_string(chunks)

    # Stage 4: LLM completion
    with timed_stage(profile, "llm_completion"):
        response = await call_llm(context, query)

    profile.total_ms = (time.perf_counter() - total_start) * 1000
    profile.report()
    return response

After one day of profiling, a typical breakdown looks like this:

| Stage | Typical Duration | Share of Total |
|---|---|---|
| Query embedding | 400-600 ms | 6% |
| Vector search | 200-400 ms | 4% |
| Context assembly | 50-100 ms | 1% |
| LLM completion | 6,000-9,000 ms | 89% |
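The table summarizes many queries, not a single one. Here is a minimal sketch of one way to roll the per-query `LatencyProfile.spans` dictionaries up into per-stage P50, P95 and share of total. The helper name `summarize_profiles` and the sample timings are hypothetical, and `statistics.quantiles` needs at least two samples per stage.

```python
import statistics
from collections import defaultdict


def summarize_profiles(profiles: list[dict[str, float]]) -> dict[str, dict[str, float]]:
    """Aggregate per-query stage timings (in ms) into P50, P95 and share of total time."""
    by_stage: dict[str, list[float]] = defaultdict(list)
    for spans in profiles:
        for stage, duration_ms in spans.items():
            by_stage[stage].append(duration_ms)

    grand_total_ms = sum(sum(durations) for durations in by_stage.values())
    summary: dict[str, dict[str, float]] = {}
    for stage, durations in by_stage.items():
        # n=20 yields 19 cut points at 5%, 10%, ..., 95%: index 9 is P50, index 18 is P95
        cuts = statistics.quantiles(durations, n=20, method="inclusive")
        summary[stage] = {
            "p50_ms": cuts[9],
            "p95_ms": cuts[18],
            "share": sum(durations) / grand_total_ms,
        }
    return summary


# Made-up example spans in the shape LatencyProfile.spans produces
profiles = [
    {"query_embedding": 450.0, "vector_search": 250.0, "llm_completion": 6500.0},
    {"query_embedding": 550.0, "vector_search": 350.0, "llm_completion": 8500.0},
]
for stage, stats in summarize_profiles(profiles).items():
    print(f"{stage}: p50={stats['p50_ms']:.0f} ms, p95={stats['p95_ms']:.0f} ms, share={stats['share']:.0%}")
```

In practice you would feed it the spans collected from a day of logs rather than two hand-written samples.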

The LLM is almost always the bottleneck. Any optimization that avoids calling the LLM entirely is the highest-leverage change.

Fix 1: Semantic Cache (Biggest Win)

A semantic cache stores LLM responses and returns cached answers for semantically similar queries — not just exact duplicates. If someone asks "how many vacation days do I get?" and the cache has a response for "what is my annual leave entitlement?", the cache hits and you skip the LLM entirely.

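In steady-state production this can eliminate 40-70% of LLM calls, provided your traffic contains many repeated or paraphrased questions (HR policies, product FAQs). The actual hit rate depends on how repetitive your users' queries are.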

Architecture

User Query
    │
    ▼
[Embed Query]
    │
    ▼
[Redis: search for similar cached query]
    │
    ├── HIT (cosine similarity ≥ 0.92): return cached response (lookup under 50 ms, plus the query embedding)
    │
    └── MISS: run full RAG pipeline → store result in Redis → return response

Implementation

Python
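import json
import hashlib
import os

import numpy as np
import redis
from openai import AzureOpenAI

# Reads the key from the AZURE_OPENAI_API_KEY environment variable
client = AzureOpenAI(
    azure_endpoint="https://your-resource.openai.azure.com",
    api_version="2024-02-01",
)

# Azure Cache for Redis listens for TLS on port 6380 and requires an access key
# (REDIS_ACCESS_KEY is a placeholder environment variable name)
redis_client = redis.Redis(
    host="your-redis.cache.windows.net",
    port=6380,
    ssl=True,
    password=os.environ["REDIS_ACCESS_KEY"],
)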

CACHE_SIMILARITY_THRESHOLD = 0.92
CACHE_TTL_SECONDS = 3600  # 1 hour

def embed_text(text: str) -> list[float]:
    # On Azure OpenAI, "model" is the name of your embedding deployment
    return client.embeddings.create(
        model="text-embedding-3-large",
        input=text,
    ).data[0].embedding

def cosine_similarity(a: list[float], b: list[float]) -> float:
    a_np, b_np = np.array(a), np.array(b)
    return float(np.dot(a_np, b_np) / (np.linalg.norm(a_np) * np.linalg.norm(b_np)))

class SemanticCache:
    def __init__(self, redis_client, namespace: str = "rag_cache"):
        self.redis = redis_client
        self.namespace = namespace
        self._cache_index = []  # In production, use Redis Search or a separate vector store

    def _load_index(self):
        """Load all cached embeddings into memory for similarity search."""
        self._cache_index = []
        # SCAN iterates incrementally; KEYS blocks Redis while it walks the whole keyspace
        for key in self.redis.scan_iter(match=f"{self.namespace}:*"):
            raw = self.redis.get(key)
            if raw is None:  # entry expired between SCAN and GET
                continue
            entry = json.loads(raw)
            self._cache_index.append({
                "key": key.decode(),
                "embedding": entry["query_embedding"],
                "response": entry["response"],
            })

    def get(self, query_embedding: list[float]) -> str | None:
        """
        Search the cache for a similar query.
        Returns the cached response if found, else None.
        """
        self._load_index()  # In production, use an efficient index
        best_score = 0.0
        best_response = None

        for entry in self._cache_index:
            score = cosine_similarity(query_embedding, entry["embedding"])
            if score > best_score:
                best_score = score
                best_response = entry["response"]

        if best_score >= CACHE_SIMILARITY_THRESHOLD:
            return best_response
        return None

    def set(self, query: str, query_embedding: list[float], response: str):
        """Store a query-response pair in the cache."""
        cache_key = f"{self.namespace}:{hashlib.sha256(query.encode()).hexdigest()[:16]}"
        entry = {
            "query": query,
            "query_embedding": query_embedding,
            "response": response,
        }
        self.redis.setex(cache_key, CACHE_TTL_SECONDS, json.dumps(entry))

semantic_cache = SemanticCache(redis_client)

async def cached_rag_query(query: str) -> dict:
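    # Step 1: Embed the query (~400-600 ms in the profile above).
    # embed_text() and the Redis calls are synchronous and block the event loop;
    # consider AsyncAzureOpenAI and redis.asyncio under real concurrency.
    query_embedding = embed_text(query)

    # Step 2: Check the semantic cache
    cached_response = semantic_cache.get(query_embedding)
    if cached_response is not None:
        return {
            "response": cached_response,
            "source": "cache",
            "latency_saved_ms": 7000,  # approximate LLM call cost
        }

    # Step 3: On a cache miss, run the full RAG pipeline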
    chunks = await vector_store.search(query_embedding, k=5)
    context = build_context_string(chunks)
    response = await call_llm(context, query)

    # Step 4: Store in cache for future queries
    semantic_cache.set(query, query_embedding, response)

    return {
        "response": response,
        "source": "llm",
        "latency_saved_ms": 0,
    }

Production note: The in-memory index scan above is fine for caches under a few thousand entries. For larger caches, use Redis Search with vector indexing, or a dedicated vector store like Qdrant as your cache backend.
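A middle ground, sketched here under the assumption that the cache fits in process memory, is to stop re-reading Redis on every lookup and keep unit-normalized embeddings in a NumPy matrix, so a lookup becomes a single matrix-vector product. `InMemorySemanticIndex` is a hypothetical helper, not part of any library. It does not handle TTL expiry or sharing across workers, which is exactly what Redis Search or a vector store provides.

```python
from __future__ import annotations

import numpy as np


class InMemorySemanticIndex:
    """Hypothetical in-process semantic cache index: one matrix row per cached query."""

    def __init__(self, dim: int, threshold: float = 0.92):
        self.threshold = threshold
        self.matrix = np.empty((0, dim), dtype=np.float32)
        self.responses: list[str] = []

    @staticmethod
    def _normalize(embedding) -> np.ndarray:
        # Unit length, so a plain dot product equals cosine similarity
        v = np.asarray(embedding, dtype=np.float32)
        return v / np.linalg.norm(v)

    def add(self, embedding, response: str) -> None:
        # vstack copies the matrix on every add; fine for a sketch, batch inserts at scale
        self.matrix = np.vstack([self.matrix, self._normalize(embedding)])
        self.responses.append(response)

    def lookup(self, embedding) -> str | None:
        if not self.responses:
            return None
        scores = self.matrix @ self._normalize(embedding)  # cosine vs every cached query
        best = int(np.argmax(scores))
        return self.responses[best] if scores[best] >= self.threshold else None


index = InMemorySemanticIndex(dim=3)
index.add([1.0, 0.0, 0.0], "cached answer about annual leave")
print(index.lookup([0.99, 0.05, 0.0]))  # near-duplicate query: cache hit
print(index.lookup([0.0, 0.0, 1.0]))    # unrelated query: None
```

Normalizing once at insert time is the design choice that matters: the per-query cost drops to one vectorized product instead of a Python loop of `cosine_similarity` calls.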

Fix 2: Streaming to Reduce Perceived Latency

Streaming does not reduce total time-to-completion, but it dramatically improves perceived responsiveness. Users start reading after the first token arrives (typically 1-2 seconds) rather than waiting for the full 8-second response.

Python
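import json

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from openai import AsyncAzureOpenAI

app = FastAPI()

# Async client, so waiting on the LLM stream does not block the event loop.
# Reads the key from the AZURE_OPENAI_API_KEY environment variable.
async_client = AsyncAzureOpenAI(
    azure_endpoint="https://your-resource.openai.azure.com",
    api_version="2024-02-01",
)

async def stream_rag_response(query: str):
    """
    Async generator that yields Server-Sent Events as tokens arrive from the LLM.
    """
    # Retrieve context first (non-streamed, fast)
    query_embedding = embed_text(query)
    chunks = await vector_store.search(query_embedding, k=5)
    context = build_context_string(chunks)

    # Stream the LLM response ("model" is your Azure deployment name)
    stream = await async_client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": f"Context:\n{context}"},
            {"role": "user", "content": query},
        ],
        stream=True,
    )

    async for chunk in stream:
        # Azure can send chunks with an empty choices list (e.g. content-filter results)
        if not chunk.choices:
            continue
        delta = chunk.choices[0].delta.content
        if delta:
            yield f"data: {json.dumps({'token': delta})}\n\n"

    yield "data: [DONE]\n\n"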

@app.get("/query/stream")
async def stream_endpoint(query: str):
    return StreamingResponse(
        stream_rag_response(query),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",  # disable nginx buffering
        },
    )

On the frontend, display tokens as they arrive:

TYPESCRIPT
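async function streamQuery(query: string, onToken: (t: string) => void) {
  const response = await fetch(`/query/stream?query=${encodeURIComponent(query)}`);
  if (!response.ok || !response.body) {
    throw new Error(`Stream request failed: ${response.status}`);
  }
  const reader = response.body.getReader();
  const decoder = new TextDecoder();
  let buffer = "";

  while (true) {
    const { done, value } = await reader.read();
    if (done) break;

    // stream: true keeps multi-byte characters that span two chunks intact
    buffer += decoder.decode(value, { stream: true });

    // A network chunk can end mid-line; keep the trailing partial line for the next read
    const lines = buffer.split("\n");
    buffer = lines.pop() ?? "";

    for (const line of lines) {
      if (line.startsWith("data: ") && line !== "data: [DONE]") {
        const data = JSON.parse(line.slice(6));
        onToken(data.token);
      }
    }
  }
}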
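To confirm that streaming actually improves perceived latency, record time to first token (TTFT) separately from total time. A minimal sketch: `measure_stream` works on any async iterator of text tokens, and `fake_llm_stream` is a hypothetical stand-in for a real LLM token stream.

```python
from __future__ import annotations

import asyncio
import time
from typing import AsyncIterator


async def measure_stream(tokens: AsyncIterator[str]) -> tuple[float | None, float, str]:
    """Return (time_to_first_token_ms, total_ms, full_text) for an async token stream."""
    start = time.perf_counter()
    first_token_ms: float | None = None
    parts: list[str] = []
    async for token in tokens:
        if first_token_ms is None:
            first_token_ms = (time.perf_counter() - start) * 1000
        parts.append(token)
    total_ms = (time.perf_counter() - start) * 1000
    return first_token_ms, total_ms, "".join(parts)


async def fake_llm_stream(n_tokens: int = 20, delay_s: float = 0.01) -> AsyncIterator[str]:
    """Hypothetical stand-in for an LLM stream that emits one token every delay_s seconds."""
    for i in range(n_tokens):
        await asyncio.sleep(delay_s)
        yield f"tok{i} "


ttft_ms, total_ms, text = asyncio.run(measure_stream(fake_llm_stream()))
print(f"TTFT: {ttft_ms:.0f} ms, total: {total_ms:.0f} ms")
```

Exporting both numbers (for example as two histograms) lets you see TTFT fall after enabling streaming even though total time stays the same.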

Fix 3: Async Parallel Retrieval

If you query multiple vector stores or knowledge bases, run them in parallel:

Python
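import asyncio
from typing import List

async def parallel_retrieval(query: str, k_per_source: int = 3) -> List[dict]:
    """
    Retrieve from multiple sources concurrently instead of sequentially.
    The stores and deduplicate_and_rerank are placeholders; the searches only
    overlap if each store's search() is truly async (non-blocking I/O).
    """
    query_embedding = embed_text(query)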

    # Launch all retrievals concurrently
    tasks = [
        product_docs_store.search(query_embedding, k=k_per_source),
        policy_store.search(query_embedding, k=k_per_source),
        faq_store.search(query_embedding, k=k_per_source),
    ]
    results = await asyncio.gather(*tasks)

    # Merge and deduplicate
    all_chunks = [chunk for source_results in results for chunk in source_results]
    return deduplicate_and_rerank(all_chunks, query)

Sequential retrieval from three sources takes about 3 x 300 ms = 900 ms. In parallel, total time is roughly that of the slowest source, about 300 ms.
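You can check that arithmetic without any real vector store. In this sketch `fake_search` is a hypothetical stand-in that simply sleeps for 300 ms, and the source names mirror the example above.

```python
import asyncio
import time


async def fake_search(source: str, delay_s: float) -> list[dict]:
    """Hypothetical stand-in for an async vector store call that takes delay_s seconds."""
    await asyncio.sleep(delay_s)
    return [{"source": source, "text": f"chunk from {source}"}]


async def sequential(sources: dict[str, float]) -> list[list[dict]]:
    # Each await finishes before the next search starts
    return [await fake_search(name, delay) for name, delay in sources.items()]


async def parallel(sources: dict[str, float]) -> list[list[dict]]:
    # All searches are in flight at once; total time is roughly the slowest one
    return await asyncio.gather(*(fake_search(n, d) for n, d in sources.items()))


async def elapsed_ms(strategy, sources: dict[str, float]) -> float:
    start = time.perf_counter()
    await strategy(sources)
    return (time.perf_counter() - start) * 1000


sources = {"product_docs": 0.3, "policy": 0.3, "faq": 0.3}
seq_ms = asyncio.run(elapsed_ms(sequential, sources))
par_ms = asyncio.run(elapsed_ms(parallel, sources))
print(f"sequential: {seq_ms:.0f} ms, parallel: {par_ms:.0f} ms")  # roughly 900 vs 300
```

If a store's client is synchronous under the hood, `gather` gives no overlap; wrapping the call in `asyncio.to_thread` is one way around that.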

Fix 4: Reduce Context Size

More context tokens = slower LLM response. Trim aggressively:

Python
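from typing import List

import tiktoken

# Build the encoder once, not on every call (mapping "gpt-4o" needs a recent tiktoken)
ENC = tiktoken.encoding_for_model("gpt-4o")

def trim_context(chunks: List[str], max_tokens: int = 1500) -> str:
    """
    Include chunks, in ranked order, until the token budget is reached.
    1,500 tokens is a reasonable starting budget; tune it against answer quality.
    """
    context_parts = []
    used_tokens = 0

    for chunk in chunks:  # assumes chunks are already sorted best-first
        chunk_tokens = len(ENC.encode(chunk))
        if used_tokens + chunk_tokens > max_tokens:
            break  # stop at the first chunk that does not fit; lower-ranked chunks are dropped
        context_parts.append(chunk)
        used_tokens += chunk_tokens

    return "\n\n".join(context_parts)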

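Cutting context from 4,000 to 1,500 tokens can reduce LLM latency by 30-40% in prompt-heavy workloads. The saving comes from prompt processing (which also shortens time to first token); output generation time is largely unchanged, so calls that produce long answers see a smaller relative gain.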

Latency Budget After Optimization

| Stage | Before | After |
|---|---|---|
| Query embedding | 500 ms | 500 ms (unchanged) |
| Vector search | 300 ms | 300 ms (unchanged) |
| LLM call (cache hit) | 7,000 ms | 0 ms (skipped) |
| LLM call (cache miss) | 7,000 ms | 4,500 ms (shorter context) |
| Time to first token (streaming) | 7,000 ms | 900 ms |

With a 60% cache hit rate, the average total latency becomes approximately:

  • 60% of requests: under 1 second (cache hit)
  • 40% of requests: approximately 5.5 seconds (cache miss, shorter context)
  • Weighted average: approximately 2.8 seconds

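That takes typical latency from 8 seconds (the original P50) to an average of about 2.8 seconds, a 65% improvement, from caching and context trimming alone. The P95 is still set by cache misses (roughly 5.5 seconds), which is where streaming does the remaining work.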
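The weighted average is the expected latency over the hit and miss paths. A small sketch (the function name is hypothetical) that reproduces the estimate and lets you plug in your own measured hit rate and per-path latencies:

```python
def expected_latency_ms(hit_rate: float, hit_ms: float, miss_ms: float) -> float:
    """Expected end-to-end latency when a fraction hit_rate of requests are cache hits."""
    if not 0.0 <= hit_rate <= 1.0:
        raise ValueError("hit_rate must be between 0 and 1")
    return hit_rate * hit_ms + (1.0 - hit_rate) * miss_ms


# Figures from the scenario: hits under ~1 s, misses ~5.5 s, 60% hit rate
print(expected_latency_ms(0.6, 1_000, 5_500))  # 2800.0
```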

Cache Hit Rate Tracking

Monitor your cache performance continuously:

Python
from prometheus_client import Counter

cache_hits = Counter("rag_cache_hits_total", "Total semantic cache hits")
cache_misses = Counter("rag_cache_misses_total", "Total semantic cache misses")

# Derive the hit rate in Prometheus rather than reading counter internals in-process,
# e.g. over the last 5 minutes:
#   sum(rate(rag_cache_hits_total[5m]))
#     / (sum(rate(rag_cache_hits_total[5m])) + sum(rate(rag_cache_misses_total[5m])))

async def monitored_rag_query(query: str) -> dict:
    result = await cached_rag_query(query)

    if result["source"] == "cache":
        cache_hits.inc()
    else:
        cache_misses.inc()

    return result

Summary

| Optimization | Latency Impact | Effort |
|---|---|---|
| Semantic cache | Eliminates 40-70% of LLM calls | Medium |
| Streaming | Reduces perceived latency to 1-2 seconds | Low |
| Async parallel retrieval | Saves 200-600 ms | Low |
| Context trimming | Saves 1-3 seconds per LLM call | Low |
| Smaller model for simple queries | Saves 3-5 seconds | Medium |
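The last row, routing simple queries to a smaller model, is not walked through in this lesson. Here is a deliberately naive sketch of the idea. The word-count limit, the keyword list, and the deployment names (`gpt-4o-mini`, `gpt-4o`) are all assumptions; replace them with your own rules or a trained classifier.

```python
import re

# Hypothetical deployment names; on Azure OpenAI these are whatever you named your deployments
SMALL_MODEL = "gpt-4o-mini"
LARGE_MODEL = "gpt-4o"

# Hypothetical signals that a query needs multi-step reasoning
COMPLEX_MARKERS = re.compile(
    r"\b(compare|why|explain|difference|versus|vs\.?|analy[sz]e)\b", re.IGNORECASE
)


def choose_model(query: str, max_simple_words: int = 12) -> str:
    """Route short, lookup-style questions to the smaller, faster model."""
    is_short = len(query.split()) <= max_simple_words
    if is_short and not COMPLEX_MARKERS.search(query):
        return SMALL_MODEL
    return LARGE_MODEL


print(choose_model("How many vacation days do I get?"))  # gpt-4o-mini
print(choose_model("Compare the parental leave and sick leave policies"))  # gpt-4o
```

Whatever rule you use, check answer quality on the queries it routes to the small model before trusting the latency win.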

Start with streaming (low effort, immediate user experience win), then implement the semantic cache (highest latency impact), then measure what remains.