GenAI & LLM Interviews · Lesson 13 of 30
RAG in Production: Pipelines & Pitfalls
Production RAG Architecture
A production RAG system needs more than working retrieval. It needs:
- Observability: Know when it fails and why
- Reliability: Fallbacks when components fail
- Scalability: Handle concurrent load
- Maintainability: Easy to update knowledge and iterate
                    ┌──────────────┐
User Query ─────────► API Gateway  ├─── Rate Limiting ─── Auth
                    └──────┬───────┘
                           │
                    ┌──────▼───────┐
                    │   RAG API    │ ← Health checks, metrics
                    └──────┬───────┘
                    ┌──────┴───────────────────┐
                    │                          │
             ┌──────▼──────┐          ┌────────▼──────┐
             │  Retrieval  │          │   Response    │
             │   Service   │          │     Cache     │
             └──────┬──────┘          └────────┬──────┘
                    │                          │
             ┌──────▼──────┐                   │
             │ Vector Store│                   │
             │ (Pinecone/  │                   │
             │  Qdrant)    │                   │
             └──────┬──────┘                   │
                    │                          │
             ┌──────▼──────────────────────────▼──────┐
             │              LLM Service               │
             │           (OpenAI / Claude)            │
             └────────────────────────────────────────┘

Async RAG API with FastAPI
from fastapi import FastAPI, HTTPException, Depends, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import Optional
import asyncio
import time
import uuid
# FastAPI application object; the route decorators below register on it.
app = FastAPI(title="Clinical RAG API", version="1.0")

# CORS policy: one explicit origin is allowed (credentials included), and
# every HTTP method and request header is accepted from that origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://yourdomain.com"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
class QueryRequest(BaseModel):
    """Request body accepted by the query endpoint."""

    # Question text; must be between 5 and 2000 characters long.
    query: str = Field(..., min_length=5, max_length=2000)
    # Optional conversation identifier.
    session_id: Optional[str] = None
    # Identifier of the calling user (required).
    user_id: str
    # Optional free-form filter mapping.
    filters: Optional[dict] = None
    # Requested number of documents; defaults to 5, limited to 1..20.
    top_k: int = Field(5, ge=1, le=20)
class QueryResponse(BaseModel):
    """Response body returned by the query endpoint."""

    # Identifier generated for this request.
    query_id: str
    # Answer text produced by the pipeline.
    response: str
    # Citation records attached to the answer (may be empty).
    citations: list[dict]
    # Time spent serving the request, in milliseconds.
    latency_ms: float
    # Whether the answer was served from a cache.
    cache_hit: bool
    # Name of the model reported for this answer.
    model_used: str
# FastAPI dependency that hands the shared RAG pipeline to route handlers
def get_rag_pipeline():
    """Return the pipeline object stored on ``app.state.rag_pipeline``."""
    pipeline = app.state.rag_pipeline
    return pipeline
@app.post("/api/v1/query", response_model=QueryResponse)
async def query_endpoint(
    request: QueryRequest,
    background_tasks: BackgroundTasks,
    pipeline=Depends(get_rag_pipeline),
):
    """Run one RAG query and return the answer with timing metadata.

    The blocking ``pipeline.query`` call is executed via
    ``asyncio.to_thread`` so it does not stall the event loop. Analytics
    logging is handed to ``background_tasks`` instead of being awaited here.

    Args:
        request: Validated request body.
        background_tasks: FastAPI hook used to schedule ``log_query_event``.
        pipeline: Object returned by ``get_rag_pipeline``. It is called as
            ``pipeline.query(text, session_id=...)`` and must return a dict
            with a ``"response"`` key; ``"citations"``, ``"cache"`` and
            ``"model"`` are read when present.

    Returns:
        A ``QueryResponse`` for the request.

    Raises:
        HTTPException: Status 500 when the pipeline call or response
            construction fails. The detail carries only the query id; the
            underlying error is logged server-side.
    """
    import logging  # local import: not part of this snippet's import header

    query_id = str(uuid.uuid4())
    # perf_counter is monotonic, so the measured duration cannot be skewed
    # by wall-clock adjustments.
    started = time.perf_counter()
    try:
        # NOTE(review): request.top_k and request.filters are validated but
        # not forwarded; confirm whether pipeline.query accepts them.
        result = await asyncio.to_thread(
            pipeline.query,
            request.query,
            session_id=request.session_id or request.user_id,
        )
        latency_ms = (time.perf_counter() - started) * 1000

        # Computed once so the analytics event and the API response agree
        # (a missing "cache" key is a miss, not a hit).
        cache_hit = result.get("cache") in ("exact", "semantic")

        # Log query for analytics in background
        background_tasks.add_task(
            log_query_event,
            query_id=query_id,
            user_id=request.user_id,
            query=request.query,
            latency_ms=latency_ms,
            cache_hit=cache_hit,
        )
        return QueryResponse(
            query_id=query_id,
            response=result["response"],
            citations=result.get("citations", []),
            latency_ms=latency_ms,
            cache_hit=cache_hit,
            # Prefer the model the pipeline reports (e.g. after a fallback).
            model_used=result.get("model") or "gpt-4o",
        )
    except HTTPException:
        # Deliberate HTTP errors must keep their own status code.
        raise
    except Exception as exc:
        # Keep internal error text out of the client response; the query id
        # lets operators find the logged traceback.
        logging.getLogger(__name__).exception("Query %s failed", query_id)
        raise HTTPException(
            status_code=500,
            detail=f"Query failed (query_id={query_id})",
        ) from exc
@app.get("/health")
async def health_check():
    """Health check endpoint for load balancers.

    Calls ``pipeline.ping`` in a worker thread as a quick liveness probe.

    Returns:
        ``{"status": "healthy", "timestamp": ...}`` with the default 200
        status when the probe succeeds, otherwise a 503 JSON response of
        the form ``{"status": "unhealthy", "error": ...}``.
    """
    # Local import keeps this fix self-contained within the block.
    from fastapi.responses import JSONResponse

    try:
        pipeline = app.state.rag_pipeline
        # Quick retrieval test
        await asyncio.to_thread(pipeline.ping)
    except Exception as exc:
        # Returning ``(body, 503)`` is a Flask idiom: FastAPI would serialise
        # the tuple as a JSON array and still answer 200. An explicit
        # response object is needed for the status code to reach the caller.
        return JSONResponse(
            status_code=503,
            content={"status": "unhealthy", "error": str(exc)},
        )
    return {"status": "healthy", "timestamp": time.time()}
async def log_query_event(**kwargs) -> None:
    """Placeholder hook for shipping query metadata to an analytics store.

    Accepts arbitrary keyword fields and currently ignores them.
    """
    # Send to your analytics system (Datadog, Segment, internal DB)
    return None


# Observability and Metrics
import time
from collections import defaultdict
from dataclasses import dataclass, field
@dataclass
class RAGMetrics:
    """Aggregated metrics for a RAG pipeline.

    Counters are plain integers; latency samples are kept in lists so that
    ``get_summary`` can derive medians and percentiles from them.
    """

    total_queries: int = 0
    cache_hits: int = 0
    retrieval_failures: int = 0
    llm_failures: int = 0
    latencies_ms: list[float] = field(default_factory=list)
    retrieval_latencies_ms: list[float] = field(default_factory=list)
    llm_latencies_ms: list[float] = field(default_factory=list)

    def record_query(
        self,
        total_ms: float,
        retrieval_ms: float,
        llm_ms: float,
        cache_hit: bool = False,
    ) -> None:
        """Record one query's latencies.

        Args:
            total_ms: End-to-end latency in milliseconds.
            retrieval_ms: Retrieval-stage latency in milliseconds.
            llm_ms: Generation-stage latency in milliseconds.
            cache_hit: Set to True when the query was served from a cache;
                this is what feeds ``cache_hit_rate``.
        """
        self.total_queries += 1
        if cache_hit:
            self.cache_hits += 1
        self.latencies_ms.append(total_ms)
        self.retrieval_latencies_ms.append(retrieval_ms)
        self.llm_latencies_ms.append(llm_ms)

    @staticmethod
    def _percentile(ordered: list[float], percent: int) -> float:
        """Return the nearest-rank percentile of an ascending, non-empty list."""
        # Integer ceil(n * percent / 100) avoids float rounding surprises
        # such as 0.95 * 20 landing just above or below 19.
        rank = -(-len(ordered) * percent // 100)
        return ordered[max(rank, 1) - 1]

    def get_summary(self) -> dict:
        """Summarise the recorded metrics.

        Returns:
            An empty dict when no latency has been recorded; otherwise a dict
            with ``total_queries``, ``cache_hit_rate``, ``error_rate``,
            ``p50_ms``, ``p95_ms``, ``p99_ms``, ``retrieval_p50_ms`` and
            ``llm_p50_ms``.
        """
        import statistics

        if not self.latencies_ms:
            return {}

        ordered = sorted(self.latencies_ms)  # sort once for both percentiles
        denominator = max(self.total_queries, 1)
        failures = self.retrieval_failures + self.llm_failures
        return {
            "total_queries": self.total_queries,
            "cache_hit_rate": self.cache_hits / denominator,
            "error_rate": failures / denominator,
            "p50_ms": statistics.median(ordered),
            "p95_ms": self._percentile(ordered, 95),
            "p99_ms": self._percentile(ordered, 99),
            "retrieval_p50_ms": statistics.median(self.retrieval_latencies_ms),
            "llm_p50_ms": statistics.median(self.llm_latencies_ms),
        }
class InstrumentedRAGPipeline:
    """RAG pipeline wrapper with built-in metrics collection.

    Wraps a base pipeline that exposes ``retriever.retrieve(query)`` and
    ``generate(query, docs)``, timing each stage and counting failures in a
    ``RAGMetrics`` instance. When an exporter is supplied, the same numbers
    are forwarded through its ``gauge`` and ``increment`` methods.
    """

    def __init__(self, base_pipeline, metrics_exporter=None):
        self.pipeline = base_pipeline
        self.metrics = RAGMetrics()
        self.exporter = metrics_exporter  # Datadog, Prometheus, etc.

    def query(self, query: str, **kwargs) -> dict:
        """Run retrieval then generation, recording latency and failures.

        Args:
            query: The user question.
            **kwargs: Accepted for call compatibility; currently ignored.

        Returns:
            ``{"response": <generate() result>, "retrieved_docs": <docs>}``.

        Raises:
            Exception: Whatever the retriever or generator raises is
                re-raised after the failure has been counted.
        """
        # perf_counter is monotonic, which makes it the safe clock for durations.
        started = time.perf_counter()

        # A failure is attributed to the stage that raised it, rather than
        # guessed from the exception's class name.
        try:
            docs = self.pipeline.retriever.retrieve(query)
        except Exception:
            self._record_failure("retrieval_failures")
            raise
        retrieval_ms = (time.perf_counter() - started) * 1000

        llm_started = time.perf_counter()
        try:
            response = self.pipeline.generate(query, docs)
        except Exception:
            self._record_failure("llm_failures")
            raise
        finished = time.perf_counter()
        llm_ms = (finished - llm_started) * 1000
        total_ms = (finished - started) * 1000

        self.metrics.record_query(total_ms, retrieval_ms, llm_ms)
        if self.exporter:
            self.exporter.gauge("rag.latency.total_ms", total_ms)
            self.exporter.gauge("rag.latency.retrieval_ms", retrieval_ms)
            self.exporter.gauge("rag.latency.llm_ms", llm_ms)
            self.exporter.increment("rag.queries.total")
        return {"response": response, "retrieved_docs": docs}

    def _record_failure(self, counter_name: str) -> None:
        """Bump the named failure counter and count the failed attempt."""
        metrics = self.metrics
        setattr(metrics, counter_name, getattr(metrics, counter_name) + 1)
        # Failed attempts are queries too; without this the error rate is
        # failures / successes and can exceed 1.
        metrics.total_queries += 1
        if self.exporter:
            self.exporter.increment("rag.errors.total")


# Error Handling and Fallbacks
import logging
from typing import Optional
logger = logging.getLogger(__name__)
class RAGWithFallbacks:
    """RAG pipeline with graceful degradation on failures.

    Retrieval tries the primary retriever and switches to the fallback one
    when the primary raises or returns nothing. Generation tries the primary
    LLM client, then the fallback client, and finally returns a canned
    message instead of raising.
    """

    def __init__(
        self,
        primary_retriever,
        fallback_retriever,
        primary_llm,
        fallback_llm,
        primary_model: str = "gpt-4o",
        fallback_model: str = "gpt-4o-mini",
        primary_timeout: float = 30,
        fallback_timeout: float = 20,
    ):
        """Store the collaborators and per-client generation settings.

        Args:
            primary_retriever: Object with ``retrieve(embedding, top_k=...)``.
            fallback_retriever: Same interface; used when the primary fails
                or returns no documents.
            primary_llm: Client exposing ``chat.completions.create(...)``.
            fallback_llm: Same interface; used when the primary call raises.
            primary_model: Model name sent to ``primary_llm``.
            fallback_model: Model name sent to ``fallback_llm``.
            primary_timeout: ``timeout`` value passed to the primary call.
            fallback_timeout: ``timeout`` value passed to the fallback call.
        """
        self.primary_retriever = primary_retriever
        self.fallback_retriever = fallback_retriever
        self.primary_llm = primary_llm
        self.fallback_llm = fallback_llm
        self.primary_model = primary_model
        self.fallback_model = fallback_model
        self.primary_timeout = primary_timeout
        self.fallback_timeout = fallback_timeout

    def retrieve(self, query_embedding: list[float], top_k: int = 5) -> list[dict]:
        """Retrieve with automatic fallback.

        Returns:
            The primary retriever's documents when it returns any, else the
            fallback retriever's result, else ``[]`` if both fail.
        """
        try:
            docs = self.primary_retriever.retrieve(query_embedding, top_k=top_k)
        except Exception as exc:
            # logger.exception keeps the traceback alongside the message.
            logger.exception("Primary retriever failed: %s", exc)
        else:
            if docs:
                return docs
            logger.warning("Primary retriever returned empty results")

        # Fallback retriever (e.g., BM25 or secondary vector store)
        try:
            logger.info("Using fallback retriever")
            return self.fallback_retriever.retrieve(query_embedding, top_k=top_k)
        except Exception as exc:
            logger.exception("Fallback retriever failed: %s", exc)
            return []

    @staticmethod
    def _complete(client, model: str, messages: list[dict], timeout: float) -> str:
        """Run one temperature-0 chat completion and return its text."""
        completion = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=0,
            timeout=timeout,
        )
        return completion.choices[0].message.content

    def generate(self, query: str, docs: list[dict]) -> dict:
        """Generate with LLM fallback.

        Args:
            query: The user question.
            docs: Retrieved documents; each must carry a ``"content"`` key.
                An empty list is replaced by a "No context available." note.

        Returns:
            A dict with ``"response"``, ``"model"`` and ``"fallback"`` keys.
            When both clients fail, ``"model"`` is ``"none"`` and an
            ``"error": True`` entry is added.
        """
        context = "\n\n".join(d["content"] for d in docs) if docs else "No context available."
        messages = [
            {"role": "system", "content": "Answer using only the provided context."},
            {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"},
        ]

        # Primary first, then the fallback (e.g., Claude or cheaper model).
        attempts = (
            ("Primary", self.primary_llm, self.primary_model, self.primary_timeout, False),
            ("Fallback", self.fallback_llm, self.fallback_model, self.fallback_timeout, True),
        )
        for label, client, model, timeout, is_fallback in attempts:
            if is_fallback:
                logger.info("Using fallback LLM")
            try:
                text = self._complete(client, model, messages, timeout)
            except Exception as exc:
                logger.exception("%s LLM failed: %s", label, exc)
                continue
            return {"response": text, "model": model, "fallback": is_fallback}

        return {
            "response": "I'm temporarily unable to answer. Please try again shortly.",
            "model": "none",
            "fallback": True,
            "error": True,
        }

    def query(self, query: str, query_embedding: list[float], top_k: int = 5) -> dict:
        """Retrieve up to ``top_k`` documents and generate an answer from them."""
        docs = self.retrieve(query_embedding, top_k=top_k)
        return self.generate(query, docs)


# A/B Testing RAG Configurations
import random
from dataclasses import dataclass
@dataclass
class RAGVariant:
    """One arm of a RAG A/B experiment."""

    # Label of the variant; also the key its metrics are stored under.
    name: str
    # Retriever used for this arm; called as retrieve(embedding, top_k=...).
    retriever: object
    # Optional reranker; when set it is called as rerank(query, docs, top_k=...).
    reranker: Optional[object]
    # Number of documents requested from the retriever.
    top_k: int
    # Model name sent with the chat completion request.
    model: str
    # Share of traffic routed to this arm; shares are expected to sum to 1.0.
    traffic_fraction: float
class RAGABTester:
    """Route traffic across RAG variants for experimentation.

    Each user is assigned to one variant by hashing the user id, and the
    latency of every query is recorded in that variant's ``RAGMetrics``.
    """

    def __init__(self, variants: "list[RAGVariant]", llm_client=None):
        """Validate the traffic split and create per-variant metrics.

        Args:
            variants: Experiment arms; their ``traffic_fraction`` values must
                sum to 1.0 (within 0.01).
            llm_client: Optional client exposing ``chat.completions.create``.
                When omitted, an ``OpenAI()`` client is created on first use
                and reused for later queries.

        Raises:
            ValueError: If the traffic fractions do not sum to 1.0.
        """
        total = sum(v.traffic_fraction for v in variants)
        # An explicit raise (not ``assert``) so the check survives ``python -O``.
        if abs(total - 1.0) >= 0.01:
            raise ValueError("Traffic fractions must sum to 1.0")
        self.variants = variants
        self.variant_metrics: "dict[str, RAGMetrics]" = {
            v.name: RAGMetrics() for v in variants
        }
        self._llm_client = llm_client

    def select_variant(self, user_id: str) -> "RAGVariant":
        """Deterministic variant selection based on user_id for consistency.

        The bucket comes from a SHA-256 digest rather than the built-in
        ``hash()``, whose string hashes are salted per process and would
        reshuffle users on every restart or across workers.
        """
        import hashlib

        digest = hashlib.sha256(str(user_id).encode("utf-8")).digest()
        # Map the user to one of 10,000 equal buckets in [0, 1).
        user_bucket = int.from_bytes(digest[:8], "big") % 10_000 / 10_000

        cumulative = 0.0
        for variant in self.variants:
            cumulative += variant.traffic_fraction
            # Strict "<" keeps each share exact: with "<=" bucket 0 would be
            # routed to a variant whose fraction is 0.
            if user_bucket < cumulative:
                return variant
        # Reached only when rounding leaves the cumulative sum just under 1.
        return self.variants[-1]

    def _get_client(self):
        """Return the shared LLM client, creating an OpenAI client on first use."""
        if self._llm_client is None:
            from openai import OpenAI

            self._llm_client = OpenAI()
        return self._llm_client

    def query(self, user_id: str, query: str, query_embedding: list[float]) -> dict:
        """Query with variant selection and metrics tracking.

        Returns:
            A dict with ``"response"`` (answer text), ``"variant"`` (name of
            the selected arm) and ``"latency_ms"`` (end-to-end latency).
        """
        variant = self.select_variant(user_id)
        started = time.perf_counter()

        # Retrieve
        docs = variant.retriever.retrieve(query_embedding, top_k=variant.top_k)
        # Rerank if available (timed as part of the retrieval stage)
        if variant.reranker:
            docs = variant.reranker.rerank(query, docs, top_k=5)
        retrieval_ms = (time.perf_counter() - started) * 1000

        # Generate
        llm_started = time.perf_counter()
        context = "\n\n".join(d["content"] for d in docs)
        completion = self._get_client().chat.completions.create(
            model=variant.model,
            messages=[
                {"role": "system", "content": "Answer using only the provided context."},
                {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"},
            ],
            temperature=0,
        )
        response = completion.choices[0].message.content
        finished = time.perf_counter()

        llm_ms = (finished - llm_started) * 1000
        latency_ms = (finished - started) * 1000
        # Real stage timings make the per-variant retrieval/LLM medians usable.
        self.variant_metrics[variant.name].record_query(latency_ms, retrieval_ms, llm_ms)
        return {
            "response": response,
            "variant": variant.name,
            "latency_ms": latency_ms,
        }

    def get_experiment_results(self) -> dict:
        """Compare metrics across variants."""
        return {
            name: metrics.get_summary()
            for name, metrics in self.variant_metrics.items()
        }


# Knowledge Base Update Pipeline
import schedule
import time
def incremental_update_pipeline(
    document_source,  # Where new/updated docs come from
    ingestion_pipeline,  # IngestionPipeline instance
    vector_store,  # The vector store to update
    check_interval_minutes: int = 60,
) -> None:
    """
    Periodically check for new or updated documents and ingest them.
    Runs as a background job.

    Registers a job with ``schedule`` that fires every
    ``check_interval_minutes`` minutes, then loops forever polling the
    scheduler once per minute. This function never returns.

    Args:
        document_source: Provides ``get_updated_since_last_run()`` and
            ``mark_last_run()``.
        ingestion_pipeline: Provides ``process_document(doc)`` returning a
            dict with a ``"status"`` key.
        vector_store: Not used directly by this function.
        check_interval_minutes: Minutes between update runs.
    """

    def run_update() -> None:
        """Fetch changed documents once and ingest each of them."""
        print("Checking for document updates...")
        new_docs = document_source.get_updated_since_last_run()
        if not new_docs:
            print("No new documents")
            return

        results = {"processed": 0, "skipped": 0, "failed": 0}
        for doc in new_docs:
            # One bad document must not abort the rest of the batch.
            try:
                result = ingestion_pipeline.process_document(doc)
                if result["status"] == "processed":
                    results["processed"] += 1
                else:
                    results["skipped"] += 1
            except Exception as e:
                results["failed"] += 1
                print(f"Failed to ingest {doc.id}: {e}")
        print(f"Update complete: {results}")
        # NOTE(review): the run is marked complete even when some documents
        # failed; confirm the source offers failed documents again.
        document_source.mark_last_run()

    def update_job() -> None:
        """Scheduler entry point; contains failures so the loop keeps running."""
        try:
            run_update()
        except Exception as e:
            # An escaping exception would propagate out of run_pending() and
            # end the polling loop below for good; report it and retry at
            # the next interval instead.
            print(f"Update run failed: {e}")

    schedule.every(check_interval_minutes).minutes.do(update_job)
    while True:
        schedule.run_pending()
        time.sleep(60)


# Production Readiness Checklist
| Category | Item | Status |
|---|---|---|
| Reliability | Health endpoint returning 200 | Required |
| Reliability | Retriever fallback on failure | Required |
| Reliability | LLM fallback to cheaper model | Required |
| Reliability | Retry with exponential backoff | Required |
| Performance | Response cache (exact + semantic) | Required |
| Performance | Async request handling | Required |
| Performance | p95 latency under 5 seconds | Required |
| Observability | Structured logging on every request | Required |
| Observability | Latency metrics (total, retrieval, LLM) | Required |
| Observability | Error rate alerting | Required |
| Security | Input validation and sanitization | Required |
| Security | Rate limiting per user | Required |
| Security | PII detection before serving | Required |
| Operations | Incremental knowledge base updates | Recommended |
| Operations | A/B testing infrastructure | Recommended |
| Operations | Automated evaluation on each deployment | Recommended |