Learnixo

GenAI & LLM Interviews · Lesson 13 of 30

RAG in Production: Pipelines & Pitfalls

Production RAG Architecture

A production RAG system needs more than working retrieval. It needs:

  • Observability: Know when it fails and why
  • Reliability: Fallbacks when components fail
  • Scalability: Handle concurrent load
  • Maintainability: Easy to update knowledge and iterate
                    ┌──────────────┐
User Query ─────────► API Gateway  ├─── Rate Limiting ─── Auth
                    └──────┬───────┘
                           │
                    ┌──────▼───────┐
                    │   RAG API    │  ← Health checks, metrics
                    └──────┬───────┘
                    ┌──────┴────────────────────┐
                    │                           │
             ┌──────▼──────┐          ┌────────▼──────┐
             │  Retrieval  │          │  Response     │
             │  Service    │          │  Cache        │
             └──────┬──────┘          └────────┬──────┘
                    │                          │
             ┌──────▼──────┐                  │
             │ Vector Store│                  │
             │ (Pinecone/  │                  │
             │  Qdrant)    │                  │
             └─────────────┘                  │
                    │                          │
             ┌──────▼──────────────────────────▼──────┐
             │              LLM Service                │
             │         (OpenAI / Claude)               │
             └────────────────────────────────────────┘

Async RAG API with FastAPI

Python
import asyncio
import logging
import time
import uuid
from typing import Optional

from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

# Application object; the route handlers below are registered against it.
app = FastAPI(title="Clinical RAG API", version="1.0")

# CORS: only the listed origin may call the API from a browser, and
# credentialed requests are enabled for it.
# NOTE(review): FastAPI's CORS documentation advises listing methods and
# headers explicitly rather than "*" when allow_credentials=True — confirm
# the wildcards below are intentional.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://yourdomain.com"],  # placeholder origin; replace with the real front-end
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class QueryRequest(BaseModel):
    # Request body accepted by POST /api/v1/query.

    # The user's question; must be between 5 and 2000 characters.
    query: str = Field(..., min_length=5, max_length=2000)
    # Optional conversation identifier; the endpoint falls back to user_id
    # when this is omitted or empty.
    session_id: Optional[str] = None
    # Required caller identifier; also forwarded to the analytics log.
    user_id: str
    # Optional filter mapping.
    # NOTE(review): the query endpoint in this file does not read it — confirm intended use.
    filters: Optional[dict] = None
    # Requested number of results, 1-20 (default 5).
    # NOTE(review): likewise not forwarded to the pipeline by the endpoint.
    top_k: int = Field(default=5, ge=1, le=20)


class QueryResponse(BaseModel):
    # Response body returned by POST /api/v1/query.

    # Identifier generated per request by the endpoint (a UUID4 string).
    query_id: str
    # The generated answer text.
    response: str
    # Source citations as dicts; the endpoint supplies an empty list when
    # the pipeline result carries none.
    citations: list[dict]
    # Time the endpoint spent serving the request, in milliseconds.
    latency_ms: float
    # True when the pipeline reported an "exact" or "semantic" cache result.
    cache_hit: bool
    # Name of the model reported for this answer.
    model_used: str


# Dependency injection for RAG pipeline
def get_rag_pipeline():
    """Return the singleton RAG pipeline stored on ``app.state``.

    Used as a FastAPI dependency. The pipeline is expected to have been
    assigned to ``app.state.rag_pipeline`` elsewhere (e.g. at startup).

    Raises:
        HTTPException: 503 when no pipeline has been attached yet, so the
            caller gets an explicit "not ready" answer instead of an
            opaque AttributeError-driven 500.
    """
    pipeline = getattr(app.state, "rag_pipeline", None)
    if pipeline is None:
        raise HTTPException(
            status_code=503,
            detail="RAG pipeline is not initialized",
        )
    return pipeline


@app.post("/api/v1/query", response_model=QueryResponse)
async def query_endpoint(
    request: QueryRequest,
    background_tasks: BackgroundTasks,
    pipeline=Depends(get_rag_pipeline),
):
    """Main RAG query endpoint."""
    # Flow: run the (synchronous) pipeline in a worker thread, schedule an
    # analytics log entry, and return the answer with timing metadata.
    # Any failure is logged server-side with the query_id and reported to
    # the client as a generic 500 carrying that same id for correlation.
    # NOTE(review): request.filters and request.top_k are validated but not
    # forwarded to pipeline.query — confirm whether the pipeline accepts them.
    query_id = str(uuid.uuid4())
    # perf_counter is monotonic, so the latency cannot be skewed by
    # wall-clock adjustments.
    started = time.perf_counter()

    try:
        # pipeline.query blocks; to_thread keeps the event loop responsive.
        result = await asyncio.to_thread(
            pipeline.query,
            request.query,
            session_id=request.session_id or request.user_id,
        )

        latency_ms = (time.perf_counter() - started) * 1000

        # Computed once so the analytics log and the API response can never
        # disagree; a missing or unknown "cache" value counts as a miss.
        cache_hit = result.get("cache") in ("exact", "semantic")

        # Log query for analytics in background
        background_tasks.add_task(
            log_query_event,
            query_id=query_id,
            user_id=request.user_id,
            query=request.query,
            latency_ms=latency_ms,
            cache_hit=cache_hit,
        )

        return QueryResponse(
            query_id=query_id,
            response=result["response"],
            citations=result.get("citations", []),
            latency_ms=latency_ms,
            cache_hit=cache_hit,
            # Prefer the model the pipeline says it used (e.g. after an LLM
            # fallback); keep the previous constant as the default.
            model_used=result.get("model", "gpt-4o"),
        )

    except Exception as exc:
        # Keep the full error and traceback in server logs only; exception
        # text can contain internal details that must not reach clients.
        logging.getLogger(__name__).exception("Query %s failed", query_id)
        raise HTTPException(
            status_code=500,
            detail=f"Query failed (query_id={query_id})",
        ) from exc


@app.get("/health")
async def health_check():
    """Health check endpoint for load balancers."""
    # Returns 200 with {"status": "healthy", ...} when the pipeline answers
    # its ping, and 503 with {"status": "unhealthy", ...} otherwise.
    try:
        pipeline = app.state.rag_pipeline
        # Quick retrieval test
        await asyncio.to_thread(pipeline.ping)
    except Exception as e:
        # FastAPI does not interpret a (body, status) tuple; returning one
        # would be serialized as a JSON array with HTTP 200. An explicit
        # response object is needed for the 503 to reach the load balancer.
        return JSONResponse(
            status_code=503,
            content={"status": "unhealthy", "error": str(e)},
        )
    return {"status": "healthy", "timestamp": time.time()}


async def log_query_event(**kwargs) -> None:
    """Record metadata about one served query in the analytics store.

    Placeholder implementation: accepts arbitrary keyword fields and
    currently discards them.
    """
    # Hook for your analytics sink (Datadog, Segment, internal DB).
    return None

Observability and Metrics

Python
import time
from collections import defaultdict
from dataclasses import dataclass, field


@dataclass
class RAGMetrics:
    """Aggregated in-memory metrics for a RAG pipeline.

    Counters are plain integers and latency samples are kept in
    append-only lists (milliseconds); ``get_summary`` derives rates and
    percentiles from them on demand.
    """
    total_queries: int = 0
    cache_hits: int = 0
    retrieval_failures: int = 0
    llm_failures: int = 0
    latencies_ms: list[float] = field(default_factory=list)
    retrieval_latencies_ms: list[float] = field(default_factory=list)
    llm_latencies_ms: list[float] = field(default_factory=list)

    def record_query(
        self,
        total_ms: float,
        retrieval_ms: float,
        llm_ms: float,
        *,
        cache_hit: bool = False,
    ) -> None:
        """Record one completed query.

        Args:
            total_ms: End-to-end latency in milliseconds.
            retrieval_ms: Time spent in retrieval, in milliseconds.
            llm_ms: Time spent in generation, in milliseconds.
            cache_hit: Whether the answer was served from cache. Optional
                and keyword-only, so existing three-argument callers are
                unaffected; without it ``cache_hit_rate`` could never be
                anything but zero.
        """
        self.total_queries += 1
        if cache_hit:
            self.cache_hits += 1
        self.latencies_ms.append(total_ms)
        self.retrieval_latencies_ms.append(retrieval_ms)
        self.llm_latencies_ms.append(llm_ms)

    def get_summary(self) -> dict:
        """Return rates and latency percentiles, or ``{}`` with no samples.

        p95/p99 are taken as the sample at index ``int(q * n)`` of the
        sorted latencies.
        """
        import statistics

        if not self.latencies_ms:
            return {}

        # Sort once and reuse for every percentile.
        ordered = sorted(self.latencies_ms)
        sample_count = len(ordered)
        # Guard against division by zero if counters were set manually.
        query_count = max(self.total_queries, 1)
        failures = self.retrieval_failures + self.llm_failures

        return {
            "total_queries": self.total_queries,
            "cache_hit_rate": self.cache_hits / query_count,
            "error_rate": failures / query_count,
            "p50_ms": statistics.median(ordered),
            "p95_ms": ordered[int(0.95 * sample_count)],
            "p99_ms": ordered[int(0.99 * sample_count)],
            "retrieval_p50_ms": statistics.median(self.retrieval_latencies_ms),
            "llm_p50_ms": statistics.median(self.llm_latencies_ms),
        }


class InstrumentedRAGPipeline:
    """RAG pipeline wrapper that times each stage and counts failures.

    Wraps a base pipeline exposing ``retriever.retrieve(query)`` and
    ``generate(query, docs)``. Timings go to an in-memory ``RAGMetrics``
    and, when an exporter is supplied, to its ``gauge``/``increment``
    methods as well.
    """

    def __init__(self, base_pipeline, metrics_exporter=None):
        """Store the wrapped pipeline and an optional metrics exporter."""
        self.pipeline = base_pipeline
        self.metrics = RAGMetrics()
        self.exporter = metrics_exporter  # Datadog, Prometheus, etc.

    def query(self, query: str, **kwargs) -> dict:
        """Run retrieval then generation, recording per-stage latency.

        Args:
            query: The user question.
            **kwargs: Accepted so callers may pass extras such as
                ``session_id``; not forwarded to the wrapped pipeline.

        Returns:
            ``{"response": ..., "retrieved_docs": ...}``.

        Raises:
            Exception: Whatever the retriever or generator raised, after
                the matching failure counter has been incremented.
        """
        # Track which stage is executing so a failure is attributed to the
        # stage that raised, not guessed from the exception's class name
        # (a ConnectionError in the retriever is still a retrieval failure).
        stage = "retrieval"
        # perf_counter is monotonic and intended for measuring durations.
        start = time.perf_counter()

        try:
            docs = self.pipeline.retriever.retrieve(query)
            retrieval_ms = (time.perf_counter() - start) * 1000

            stage = "llm"
            llm_start = time.perf_counter()
            response = self.pipeline.generate(query, docs)
            finished = time.perf_counter()
        except Exception:
            if stage == "retrieval":
                self.metrics.retrieval_failures += 1
            else:
                self.metrics.llm_failures += 1

            if self.exporter:
                self.exporter.increment("rag.errors.total")

            raise

        llm_ms = (finished - llm_start) * 1000
        total_ms = (finished - start) * 1000

        # Recorded outside the try block so a misbehaving exporter is not
        # miscounted as an LLM failure.
        self.metrics.record_query(total_ms, retrieval_ms, llm_ms)

        if self.exporter:
            self.exporter.gauge("rag.latency.total_ms", total_ms)
            self.exporter.gauge("rag.latency.retrieval_ms", retrieval_ms)
            self.exporter.gauge("rag.latency.llm_ms", llm_ms)
            self.exporter.increment("rag.queries.total")

        return {"response": response, "retrieved_docs": docs}

Error Handling and Fallbacks

Python
import logging
from typing import Optional

logger = logging.getLogger(__name__)


class RAGWithFallbacks:
    """RAG pipeline with graceful degradation on failures.

    Retrieval tries the primary retriever, then the fallback, then yields
    no documents. Generation tries the primary LLM, then the fallback LLM,
    then returns a canned apology. Neither ``retrieve`` nor ``generate``
    raises for a failing dependency; failures are logged with tracebacks.
    """

    # Per-request timeouts, in seconds, passed to the chat completion call.
    _PRIMARY_TIMEOUT_S = 30
    _FALLBACK_TIMEOUT_S = 20

    def __init__(
        self,
        primary_retriever,
        fallback_retriever,
        primary_llm,
        fallback_llm,
        primary_model: str = "gpt-4o",
        fallback_model: str = "gpt-4o-mini",
    ):
        """Store the retrievers, LLM clients, and model names.

        Args:
            primary_retriever: Object with ``retrieve(embedding, top_k=...)``.
            fallback_retriever: Same interface; used when the primary
                fails or returns nothing.
            primary_llm: Client exposing ``chat.completions.create(...)``.
            fallback_llm: Client with the same call shape (a provider with
                a different API needs an adapter).
            primary_model: Model name sent to ``primary_llm``.
            fallback_model: Model name sent to ``fallback_llm``.
        """
        self.primary_retriever = primary_retriever
        self.fallback_retriever = fallback_retriever
        self.primary_llm = primary_llm
        self.fallback_llm = fallback_llm
        self.primary_model = primary_model
        self.fallback_model = fallback_model

    def retrieve(self, query_embedding: list[float], top_k: int = 5) -> list[dict]:
        """Retrieve with automatic fallback.

        Returns the primary retriever's documents when it succeeds with a
        non-empty result; otherwise the fallback's result; otherwise ``[]``.
        """
        try:
            docs = self.primary_retriever.retrieve(query_embedding, top_k=top_k)
        except Exception as exc:
            # logger.exception keeps the traceback for diagnosis.
            logger.exception("Primary retriever failed: %s", exc)
        else:
            if docs:
                return docs
            logger.warning("Primary retriever returned empty results")

        # Fallback retriever (e.g., BM25 or secondary vector store)
        try:
            logger.info("Using fallback retriever")
            return self.fallback_retriever.retrieve(query_embedding, top_k=top_k)
        except Exception as exc:
            logger.exception("Fallback retriever failed: %s", exc)
            return []

    @staticmethod
    def _complete(client, model: str, messages: list[dict], timeout: int) -> str:
        """Send one deterministic chat request and return the reply text."""
        completion = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=0,
            timeout=timeout,
        )
        return completion.choices[0].message.content

    def generate(self, query: str, docs: list[dict]) -> dict:
        """Generate with LLM fallback.

        Args:
            query: The user question.
            docs: Retrieved documents; each must have a ``"content"`` key.

        Returns:
            A dict with ``response``, ``model`` and ``fallback`` keys. When
            both LLMs fail, ``model`` is ``"none"`` and ``error`` is True.
        """
        if docs:
            context = "\n\n".join(d["content"] for d in docs)
        else:
            context = "No context available."
        messages = [
            {"role": "system", "content": "Answer using only the provided context."},
            {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"},
        ]

        # Try primary LLM
        try:
            return {
                "response": self._complete(
                    self.primary_llm,
                    self.primary_model,
                    messages,
                    self._PRIMARY_TIMEOUT_S,
                ),
                "model": self.primary_model,
                "fallback": False,
            }
        except Exception as exc:
            logger.exception("Primary LLM failed: %s", exc)

        # Fallback LLM (e.g., a cheaper model)
        try:
            logger.info("Using fallback LLM")
            return {
                "response": self._complete(
                    self.fallback_llm,
                    self.fallback_model,
                    messages,
                    self._FALLBACK_TIMEOUT_S,
                ),
                "model": self.fallback_model,
                "fallback": True,
            }
        except Exception as exc:
            logger.exception("Fallback LLM failed: %s", exc)
            return {
                "response": "I'm temporarily unable to answer. Please try again shortly.",
                "model": "none",
                "fallback": True,
                "error": True,
            }

    def query(self, query: str, query_embedding: list[float]) -> dict:
        """Retrieve documents for the embedding, then generate an answer."""
        docs = self.retrieve(query_embedding)
        return self.generate(query, docs)

A/B Testing RAG Configurations

Python
import random
from dataclasses import dataclass


@dataclass
class RAGVariant:
    """A RAG configuration variant for A/B testing.

    Bundles the components and settings that define one experiment arm.
    """
    # Variant label; used as the metrics key and echoed in query results.
    name: str
    # Retriever for this arm; called as retrieve(query_embedding, top_k=...).
    retriever: object
    # Optional reranker; when set, called as rerank(query, docs, top_k=...).
    reranker: Optional[object]
    # Number of documents requested from the retriever.
    top_k: int
    # Model name passed to the chat completion call.
    model: str
    # Share of traffic for this arm; fractions across all variants must
    # sum to 1.0.
    traffic_fraction: float


class RAGABTester:
    """Route traffic across RAG variants for experimentation.

    Each user is assigned to one variant by hashing the user id, so the
    same user sees the same variant on every request. Latencies are
    tracked per variant in ``variant_metrics``.
    """

    def __init__(self, variants: "list[RAGVariant]", llm_client=None):
        """Validate the traffic split and set up per-variant metrics.

        Args:
            variants: Experiment arms; their ``traffic_fraction`` values
                must sum to 1.0 (within 0.01).
            llm_client: Optional client exposing
                ``chat.completions.create``. When omitted, an ``OpenAI()``
                client is created on first use and then reused.

        Raises:
            ValueError: If the traffic fractions do not sum to 1.0.
        """
        total = sum(v.traffic_fraction for v in variants)
        # Explicit raise instead of assert: asserts are stripped under -O.
        if abs(total - 1.0) >= 0.01:
            raise ValueError("Traffic fractions must sum to 1.0")
        self.variants = variants
        self.variant_metrics: dict[str, RAGMetrics] = {
            v.name: RAGMetrics() for v in variants
        }
        self._llm_client = llm_client

    def select_variant(self, user_id: str) -> "RAGVariant":
        """Deterministic variant selection based on user_id for consistency.

        The built-in ``hash()`` of a string is salted per process, so it
        would reassign users on every restart and across workers. A SHA-256
        digest gives the same bucket everywhere.
        """
        import hashlib

        digest = hashlib.sha256(user_id.encode("utf-8")).digest()
        # Keep 53 bits so the quotient is an exact float in [0, 1).
        user_bucket = (int.from_bytes(digest[:8], "big") >> 11) / float(1 << 53)

        cumulative = 0.0
        for variant in self.variants:
            cumulative += variant.traffic_fraction
            # Strict "<" makes each arm own the half-open range
            # [previous, cumulative); a 0 % arm therefore gets no traffic.
            if user_bucket < cumulative:
                return variant
        # Only reachable when rounding leaves cumulative just under 1.0.
        return self.variants[-1]

    def _get_llm_client(self):
        """Return the shared LLM client, creating it on first use."""
        if self._llm_client is None:
            from openai import OpenAI

            self._llm_client = OpenAI()
        return self._llm_client

    def query(self, user_id: str, query: str, query_embedding: list[float]) -> dict:
        """Query with variant selection and metrics tracking.

        Returns:
            ``{"response": str, "variant": str, "latency_ms": float}``.
        """
        variant = self.select_variant(user_id)
        start = time.perf_counter()

        # Retrieve
        docs = variant.retriever.retrieve(query_embedding, top_k=variant.top_k)

        # Rerank if available
        # NOTE(review): the rerank cut-off is fixed at 5 regardless of
        # variant.top_k — confirm this is intended.
        if variant.reranker:
            docs = variant.reranker.rerank(query, docs, top_k=5)
        retrieval_ms = (time.perf_counter() - start) * 1000

        # Generate
        client = self._get_llm_client()
        context = "\n\n".join(d["content"] for d in docs)
        llm_start = time.perf_counter()
        response = client.chat.completions.create(
            model=variant.model,
            messages=[
                {"role": "system", "content": "Answer using only the provided context."},
                {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"},
            ],
            temperature=0,
        ).choices[0].message.content
        finished = time.perf_counter()

        llm_ms = (finished - llm_start) * 1000
        latency_ms = (finished - start) * 1000
        # Record real stage timings so per-variant retrieval/LLM medians
        # are meaningful rather than always zero.
        self.variant_metrics[variant.name].record_query(latency_ms, retrieval_ms, llm_ms)

        return {
            "response": response,
            "variant": variant.name,
            "latency_ms": latency_ms,
        }

    def get_experiment_results(self) -> dict:
        """Compare metrics across variants."""
        return {
            name: metrics.get_summary()
            for name, metrics in self.variant_metrics.items()
        }

Knowledge Base Update Pipeline

Python
import schedule
import time

def incremental_update_pipeline(
    document_source,          # Where new/updated docs come from
    ingestion_pipeline,       # IngestionPipeline instance
    vector_store,             # The vector store to update
    check_interval_minutes: int = 60,
) -> None:
    """
    Poll for new or changed documents on a fixed interval and ingest them.

    Registers a recurring job with ``schedule`` and then loops forever,
    waking once a minute to run whatever is due. This call never returns,
    so run it in a dedicated background process or thread.

    ``vector_store`` is accepted but not referenced inside this function.
    """

    def _ingest(doc) -> str:
        """Ingest one document and return the tally bucket it falls into."""
        try:
            outcome = ingestion_pipeline.process_document(doc)
            return "processed" if outcome["status"] == "processed" else "skipped"
        except Exception as exc:
            print(f"Failed to ingest {doc.id}: {exc}")
            return "failed"

    def update_job():
        print("Checking for document updates...")
        pending = document_source.get_updated_since_last_run()
        if not pending:
            print("No new documents")
            return

        tally = {"processed": 0, "skipped": 0, "failed": 0}
        for doc in pending:
            tally[_ingest(doc)] += 1

        print(f"Update complete: {tally}")
        # NOTE(review): the watermark advances even when some documents
        # failed — confirm failed documents are retried by some other path.
        document_source.mark_last_run()

    schedule.every(check_interval_minutes).minutes.do(update_job)

    # Scheduler loop: poll once a minute for due jobs.
    while True:
        schedule.run_pending()
        time.sleep(60)

Production Readiness Checklist

| Category | Item | Status |
|---|---|---|
| Reliability | Health endpoint returning 200 | Required |
| Reliability | Retriever fallback on failure | Required |
| Reliability | LLM fallback to cheaper model | Required |
| Reliability | Retry with exponential backoff | Required |
| Performance | Response cache (exact + semantic) | Required |
| Performance | Async request handling | Required |
| Performance | p95 latency under 5 seconds | Required |
| Observability | Structured logging on every request | Required |
| Observability | Latency metrics (total, retrieval, LLM) | Required |
| Observability | Error rate alerting | Required |
| Security | Input validation and sanitization | Required |
| Security | Rate limiting per user | Required |
| Security | PII detection before serving | Required |
| Operations | Incremental knowledge base updates | Recommended |
| Operations | A/B testing infrastructure | Recommended |
| Operations | Automated evaluation on each deployment | Recommended |