Learnixo
Back to blog
AI Systems · Intermediate

RAG Cost Optimization

Reduce RAG system costs with model routing, caching strategies, embedding cost reduction, chunking optimization, and batch processing. Build cost-efficient clinical AI.

Asma Hafeez Khan · May 16, 2026 · 8 min read
RAG · Cost Optimization · LLM Pricing · Caching · Model Routing · Production
Share:𝕏

Where RAG Money Goes

A typical RAG query costs:

  • Embedding (query): ~$0.0000004 per query (a ~20-token query with text-embedding-3-small at $0.02 per 1M tokens)
  • Vector search: Negligible (free on self-hosted, pennies per million on managed)
  • LLM generation: $0.005 - $0.15 per query (the dominant cost)
  • Reranking: $0.001 - $0.01 per query (if using Cohere/cross-encoder)

For 100k queries/month at $0.05/query = $5,000/month. Optimization can cut this by 60-80%.


Cost Model

Python
from dataclasses import dataclass

# Pricing as of mid-2025 (verify at provider websites for current rates).
# Each row: (model name, USD per 1M input tokens, USD per 1M output tokens).
_LLM_RATE_ROWS = (
    ("gpt-4o", 2.50, 10.00),
    ("gpt-4o-mini", 0.15, 0.60),
    ("claude-opus-4-7", 15.00, 75.00),
    ("claude-sonnet-4-6", 3.00, 15.00),
    ("claude-haiku-4-5", 0.80, 4.00),
    ("gemini-1.5-pro", 3.50, 10.50),
    ("gemini-1.5-flash", 0.075, 0.30),
)
LLM_PRICING = {
    model_name: {"input_per_1m": input_rate, "output_per_1m": output_rate}
    for model_name, input_rate, output_rate in _LLM_RATE_ROWS
}

# Embedding models: USD per 1M tokens embedded.
EMBEDDING_PRICING = {
    "text-embedding-3-small": 0.02,
    "text-embedding-3-large": 0.13,
    "text-embedding-ada-002": 0.10,
}

# Rerankers: USD per 1,000 rerank queries.
RERANKING_PRICING = {"cohere-rerank-english-v3": 2.00}


@dataclass
class RAGQueryCost:
    """Per-query cost components of one RAG request, each in USD.

    Every component defaults to 0.0, so callers only fill in the stages they
    actually ran; ``total`` adds all of them together.
    """

    embedding_cost: float = 0.0   # embedding the query
    retrieval_cost: float = 0.0   # vector search
    reranking_cost: float = 0.0   # optional reranker call
    llm_input_cost: float = 0.0   # prompt tokens sent to the LLM
    llm_output_cost: float = 0.0  # tokens generated by the LLM

    @property
    def total(self) -> float:
        """Return the sum of all five cost components."""
        running = self.embedding_cost + self.retrieval_cost
        running += self.reranking_cost
        running += self.llm_input_cost
        return running + self.llm_output_cost


def estimate_query_cost(
    query_tokens: int = 20,
    context_tokens: int = 3000,        # Retrieved context sent to LLM
    output_tokens: int = 400,
    llm_model: str = "gpt-4o",
    embedding_model: str = "text-embedding-3-small",
    use_reranker: bool = True,
    system_prompt_tokens: int = 200,
    reranker_model: str = "cohere-rerank-english-v3",
) -> RAGQueryCost:
    """Estimate the USD cost of a single RAG query.

    Args:
        query_tokens: Tokens in the user query. They are embedded once and also
            sent to the LLM.
        context_tokens: Retrieved-context tokens sent to the LLM.
        output_tokens: Tokens the LLM generates.
        llm_model: Key into ``LLM_PRICING``. Unknown models are priced as gpt-4o.
        embedding_model: Key into ``EMBEDDING_PRICING``. Unknown models are
            priced at $0.02 per 1M tokens.
        use_reranker: Whether to add the cost of one reranker call.
        system_prompt_tokens: Fixed system-prompt overhead added to the LLM input.
        reranker_model: Key into ``RERANKING_PRICING`` (USD per 1,000 queries).
            Unknown rerankers are priced as cohere-rerank-english-v3.

    Returns:
        A ``RAGQueryCost`` with the embedding, reranking and LLM components
        filled in. ``retrieval_cost`` is left at 0.0.
    """
    emb_price = EMBEDDING_PRICING.get(embedding_model, 0.02)
    llm_price = LLM_PRICING.get(llm_model, LLM_PRICING["gpt-4o"])

    input_tokens = query_tokens + context_tokens + system_prompt_tokens

    reranking_cost = 0.0
    if use_reranker:
        # Reranker prices are quoted per 1,000 queries; derive the per-query
        # figure from the pricing table instead of a hard-coded constant.
        per_1k_queries = RERANKING_PRICING.get(
            reranker_model, RERANKING_PRICING["cohere-rerank-english-v3"]
        )
        reranking_cost = per_1k_queries / 1000

    return RAGQueryCost(
        embedding_cost=query_tokens / 1_000_000 * emb_price,
        llm_input_cost=input_tokens / 1_000_000 * llm_price["input_per_1m"],
        llm_output_cost=output_tokens / 1_000_000 * llm_price["output_per_1m"],
        reranking_cost=reranking_cost,
    )


def monthly_cost_projection(
    queries_per_day: int,
    avg_query_tokens: int = 20,
    avg_context_tokens: int = 3000,
    avg_output_tokens: int = 400,
    llm_model: str = "gpt-4o",
    cache_hit_rate: float = 0.30,
    days_per_month: int = 30,
) -> dict:
    """Project monthly RAG costs, with and without a response cache.

    A cache hit is assumed to skip the LLM call and reranking. The query
    embedding cost is still charged on every query.

    Args:
        queries_per_day: Average daily query volume.
        avg_query_tokens: Average query length in tokens.
        avg_context_tokens: Average retrieved-context length in tokens.
        avg_output_tokens: Average generated answer length in tokens.
        llm_model: Key into ``LLM_PRICING``.
        cache_hit_rate: Fraction of queries served from cache, in [0, 1].
        days_per_month: Days used to scale daily volume to a month.

    Returns:
        A dict with per-query and monthly costs (USD) and cache savings.

    Raises:
        ValueError: If ``cache_hit_rate`` is outside [0, 1]. Values above 1
            would yield negative costs; negative values would inflate them.
    """
    if not 0.0 <= cache_hit_rate <= 1.0:
        raise ValueError(f"cache_hit_rate must be in [0, 1], got {cache_hit_rate!r}")

    per_query = estimate_query_cost(
        query_tokens=avg_query_tokens,
        context_tokens=avg_context_tokens,
        output_tokens=avg_output_tokens,
        llm_model=llm_model,
    )

    # Cache hits skip LLM + reranking
    cache_savings_per_query = (
        per_query.llm_input_cost + per_query.llm_output_cost + per_query.reranking_cost
    )
    effective_cost = per_query.total - (cache_hit_rate * cache_savings_per_query)

    monthly_queries = queries_per_day * days_per_month
    return {
        "queries_per_month": monthly_queries,
        "cost_per_query": round(per_query.total, 6),
        "effective_cost_per_query": round(effective_cost, 6),
        "monthly_cost_without_cache": round(per_query.total * monthly_queries, 2),
        "monthly_cost_with_cache": round(effective_cost * monthly_queries, 2),
        "monthly_savings_from_cache": round(
            (per_query.total - effective_cost) * monthly_queries, 2
        ),
    }

Model Routing by Query Complexity

Route simple queries to cheaper models:

Python
from openai import OpenAI
import json

client = OpenAI()

# Complexity tier -> model name. The "complex" tier names an Anthropic model,
# so whatever sends the request must be able to reach Anthropic's API.
ROUTING_MODELS = dict(
    simple="gpt-4o-mini",        # Factual lookups, definitions
    standard="gpt-4o",           # Clinical questions, reasoning
    complex="claude-opus-4-7",   # Multi-step reasoning, differential diagnosis
)


def classify_query_complexity(query: str, context_length: int) -> str:
    """Classify a clinical query as "simple", "standard" or "complex" for routing.

    Keyword heuristics decide the clear-cut cases for free. Only ambiguous
    queries pay for a gpt-4o-mini classification call.

    Args:
        query: The user's question.
        context_length: Approximate size of the retrieved context, in tokens.

    Returns:
        One of "simple", "standard" or "complex". Returns "standard" when the
        LLM classifier produces malformed JSON or a label outside that set.
    """
    allowed_labels = ("simple", "standard", "complex")
    # Normalise so leading whitespace or capitalisation can't defeat startswith().
    query_lower = query.strip().lower()

    # Simple: definition lookups, single-fact questions over a small context.
    simple_prefixes = ("what is", "define", "what does", "what are the side effects of", "dose of")
    if query_lower.startswith(simple_prefixes) and context_length < 1500:
        return "simple"

    # Complex: multi-drug interactions, differential diagnosis
    complex_patterns = (
        "compare", "differential", "evaluate", "assess", "step-by-step",
        "explain the mechanism", "analyze", "which is better",
    )
    complex_word_count = sum(1 for p in complex_patterns if p in query_lower)
    if complex_word_count >= 2 or context_length > 4000:
        return "complex"

    # LLM classification for ambiguous cases
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "user",
                "content": f"""Classify this clinical query complexity.

Query: {query}
Context length: {context_length} tokens

Classification:
- simple: single factual lookup, definition, dose check
- standard: clinical reasoning, drug interaction check, treatment question
- complex: multi-drug analysis, differential diagnosis, protocol design

Return JSON: {{"complexity": "simple|standard|complex", "reason": "brief reason"}}""",
            }
        ],
        response_format={"type": "json_object"},
        temperature=0,
    )

    # The model's output is untrusted: tolerate bad JSON and unexpected labels
    # so that routing always gets one of the known tiers.
    try:
        result = json.loads(response.choices[0].message.content or "")
    except json.JSONDecodeError:
        return "standard"
    if not isinstance(result, dict):
        return "standard"
    label = str(result.get("complexity", "standard")).strip().lower()
    return label if label in allowed_labels else "standard"


def routed_rag_generate(
    query: str,
    context: str,
    force_model: str = None,
) -> dict:
    """Answer ``query`` from ``context`` with the cheapest adequate model.

    Args:
        query: The user's question.
        context: Retrieved context the answer must be grounded in.
        force_model: If given, skip classification and use this model.

    Returns:
        A dict with "response" (answer text) and "model_used". It also has
        "complexity", which is the routing tier or "manual" when
        ``force_model`` is set. "cost" is the estimated USD total from
        ``estimate_query_cost``, using a ~4 chars/token approximation.
    """

    def approx_tokens(text: str) -> int:
        # Rough 4-characters-per-token heuristic, the same one
        # trim_context_to_budget uses; word counts undercount tokens.
        return len(text) // 4

    if force_model:
        model = force_model
        complexity = "manual"
    else:
        complexity = classify_query_complexity(query, approx_tokens(context))
        model = ROUTING_MODELS.get(complexity, "gpt-4o")

    system_prompt = "Answer using only the provided context."
    user_prompt = f"Context:\n{context}\n\nQuestion: {query}"

    if model.startswith("claude"):
        # Claude models are not served by the OpenAI chat-completions endpoint,
        # so send them through the Anthropic Messages API instead.
        import anthropic

        message = anthropic.Anthropic().messages.create(
            model=model,
            max_tokens=1000,
            system=system_prompt,
            messages=[{"role": "user", "content": user_prompt}],
            temperature=0,
        )
        answer = "".join(
            block.text for block in message.content if block.type == "text"
        )
    else:
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            temperature=0,
        )
        # content can be None (e.g. refusals); treat that as an empty answer.
        answer = response.choices[0].message.content or ""

    cost = estimate_query_cost(
        query_tokens=approx_tokens(query),
        context_tokens=approx_tokens(context),
        output_tokens=approx_tokens(answer),
        llm_model=model,
    )

    return {
        "response": answer,
        "model_used": model,
        "complexity": complexity,
        "cost": cost.total,
    }

Context Window Optimization

Reduce the amount of context sent to the LLM:

Python
def trim_context_to_budget(
    documents: list[dict],
    max_tokens: int = 3000,     # Context token budget
    chars_per_token: int = 4,   # Rough approximation
    min_truncated_chars: int = 200,
) -> list[dict]:
    """
    Trim retrieved documents to fit within a token budget.

    Documents are consumed in the order given. Callers should pass them sorted
    by retrieval score, best first, because this function does not sort.
    Whole documents are kept while they fit. The first one that does not fit
    is truncated to the remaining space, but only if more than
    ``min_truncated_chars`` characters remain. Everything after it is dropped.
    Truncated entries are new dicts carrying ``"truncated": True``. The input
    dicts are never mutated.

    Args:
        documents: Dicts whose "content" key holds the text. A missing or None
            content counts as empty.
        max_tokens: Context budget in tokens.
        chars_per_token: Characters assumed per token when converting the budget.
        min_truncated_chars: Smallest leftover space worth filling with a
            truncated document (200 chars is about 50 tokens).

    Returns:
        The selected documents, in their original order.
    """
    max_chars = max_tokens * chars_per_token
    selected = []
    total_chars = 0

    for doc in documents:
        # Treat a missing or None "content" as empty rather than crashing in len().
        content = doc.get("content") or ""
        if total_chars + len(content) <= max_chars:
            selected.append(doc)
            total_chars += len(content)
            continue

        # Include a truncated version of this doc if enough space remains.
        remaining = max_chars - total_chars
        if remaining > min_truncated_chars:
            selected.append({**doc, "content": content[:remaining], "truncated": True})
        break

    return selected


def compress_context_with_llm(
    query: str,
    documents: list[dict],
    target_tokens: int = 1500,
    max_input_chars: int = 6000,
) -> str:
    """
    Use a cheap LLM to compress context before sending to expensive model.
    Can reduce LLM input costs by 40-60%.

    If the joined context already fits in ``target_tokens`` (at ~4 chars per
    token), it is returned unchanged and no LLM call is made. Otherwise, at
    most ``max_input_chars`` characters are sent to gpt-4o-mini for compression.
    With the defaults, that cap equals the target size, so pass a larger
    ``max_input_chars`` to actually shrink longer contexts.

    Args:
        query: The question the compressed context must still answer.
        documents: Dicts whose "content" key holds text. A missing or None
            content counts as empty.
        target_tokens: Desired size of the compressed context, in tokens.
        max_input_chars: Hard cap on the characters of context sent to the
            compressor. Anything beyond it is dropped.

    Returns:
        The compressed context. If the model returns no text, the (possibly
        truncated) raw context is returned instead.
    """
    full_context = "\n\n".join(d.get("content") or "" for d in documents)
    target_chars = target_tokens * 4

    # Already within budget: compressing would only add cost and latency.
    if len(full_context) <= target_chars:
        return full_context

    source_text = full_context[:max_input_chars]
    compressed = client.chat.completions.create(
        model="gpt-4o-mini",    # Cheap model for compression
        messages=[
            {
                "role": "user",
                "content": f"""Extract only the information from this context that is needed to answer the question.
Remove: irrelevant sections, repetition, preamble, disclaimers.
Keep: specific facts, numbers, clinical details directly relevant to the question.
Target length: around {target_tokens} tokens.

Question: {query}

Context:
{source_text}

Compressed context:""",
            }
        ],
        temperature=0,
        max_tokens=target_tokens + 100,
    ).choices[0].message.content

    return compressed or source_text

Prompt Caching (Anthropic)

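Anthropic's prompt caching bills cache reads at about 10% of the normal input price, so repeated context costs roughly 90% less. Cache writes cost somewhat more than normal input, so savings come from reuse: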

Python
import anthropic

claude_client = anthropic.Anthropic()

def rag_with_claude_caching(
    system_prompt: str,
    static_context: str,     # Large static context (clinical guidelines, etc.)
    query: str,
    model: str = "claude-sonnet-4-6",
    max_tokens: int = 1000,
) -> dict:
    """
    Use Claude prompt caching for repeated context.
    Cache the system prompt and static context to save input tokens.

    Cache saves ~90% on cached tokens.
    Cache TTL: 5 minutes (refreshed on each use).

    Cache breakpoints are placed after the system prompt and after the static
    reference material. Only the per-query question is sent uncached. Prompts
    shorter than the model's minimum cacheable length are processed without
    caching (check Anthropic's docs for the current minimums).

    Args:
        system_prompt: Instructions shared by every query.
        static_context: Large reference text reused across queries.
        query: The per-request question.
        model: Claude model name.
        max_tokens: Upper bound on generated tokens.

    Returns:
        A dict with the answer text and token usage. "input_tokens" counts only
        uncached input. Cache writes and reads are reported separately, and
        default to 0 when the API omits them.
    """
    response = claude_client.messages.create(
        model=model,
        max_tokens=max_tokens,
        system=[
            {
                "type": "text",
                "text": system_prompt,
                "cache_control": {"type": "ephemeral"},     # Cache the system prompt
            }
        ],
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": f"Reference material:\n{static_context}",
                        "cache_control": {"type": "ephemeral"},  # Cache static context
                    },
                    {
                        "type": "text",
                        "text": f"\nQuestion: {query}",     # Not cached (varies per query)
                    },
                ],
            }
        ],
    )

    # Usage breakdown shows cached vs uncached tokens
    usage = response.usage
    answer = "".join(block.text for block in response.content if block.type == "text")
    return {
        "response": answer,
        "input_tokens": usage.input_tokens,
        "output_tokens": usage.output_tokens,
        # The SDK exposes these fields even when unset (value None), so a
        # getattr default alone is not enough; coerce None to 0.
        "cache_creation_tokens": getattr(usage, "cache_creation_input_tokens", None) or 0,
        "cache_read_tokens": getattr(usage, "cache_read_input_tokens", None) or 0,
    }

Embedding Cost Reduction

Python
from sentence_transformers import SentenceTransformer
import numpy as np

# Use a local model instead of OpenAI API for embeddings
# Eliminates embedding costs entirely for high-volume pipelines

class LocalEmbedder:
    """Embed text on the local machine with a sentence-transformers model (no API fees)."""

    def __init__(self, model_name: str = "BAAI/bge-small-en-v1.5"):
        # Default bge-small: 384-dim vectors, ~133MB, runs on CPU.
        encoder = SentenceTransformer(model_name)
        self.model = encoder
        self.dim = encoder.get_sentence_embedding_dimension()

    def embed(self, texts: list[str], batch_size: int = 64) -> np.ndarray:
        """Encode ``texts`` in batches of ``batch_size``, returning normalised vectors."""
        encode_options = dict(
            batch_size=batch_size,
            normalize_embeddings=True,  # unit-length vectors, needed for cosine similarity
            show_progress_bar=False,
        )
        return self.model.encode(texts, **encode_options)

    def embed_single(self, text: str) -> list[float]:
        """Embed one string and return its vector as a plain Python list."""
        vectors = self.embed([text])
        first_vector = vectors[0]
        return first_vector.tolist()


# Cost comparison for 1M queries/month (each query = ~20 tokens = 1 embedding call):
# OpenAI text-embedding-3-small: $0.02/1M tokens x 20 tokens = $0.0000004 per query = ~$0.40/month at 1M queries/month
# Local BAAI/bge-small: $0 per query (compute only)
# Break-even: query embeddings are cheap either way; going local pays off mainly for large
# (re)indexing jobs, e.g. embedding 1B corpus tokens costs ~$20 per pass via the API

Cost Dashboard

Python
class RAGCostTracker:
    """Track and report RAG costs in production (records are kept in memory)."""

    def __init__(self, daily_budget: float = 100.0):
        # One dict per recorded query; keys are written in ``record``.
        self.records: list[dict] = []
        self.daily_budget: float = daily_budget  # USD

    def record(self, query_cost: "RAGQueryCost", model: str, cache_hit: bool) -> None:
        """Append one query's cost breakdown, timestamped in UTC (ISO 8601)."""
        import datetime
        self.records.append({
            # Timezone-aware UTC; datetime.utcnow() is deprecated since Python 3.12.
            "timestamp": datetime.datetime.now(datetime.timezone.utc).isoformat(),
            "total_cost": query_cost.total,
            "llm_cost": query_cost.llm_input_cost + query_cost.llm_output_cost,
            "embedding_cost": query_cost.embedding_cost,
            "model": model,
            "cache_hit": cache_hit,
        })

    def get_daily_summary(self) -> dict:
        """Summarise today's records, where "today" is the current UTC date.

        The same keys are always returned. They are zero-filled when no
        queries were recorded today, so callers can read e.g.
        "budget_remaining" unconditionally.
        """
        import datetime
        # Timestamps are written in UTC, so "today" must be the UTC date too;
        # date.today() is the local date and disagrees around midnight.
        today = datetime.datetime.now(datetime.timezone.utc).date().isoformat()
        today_records = [r for r in self.records if r["timestamp"].startswith(today)]

        count = len(today_records)
        total = sum(r["total_cost"] for r in today_records)
        llm_cost = sum(r["llm_cost"] for r in today_records)
        embedding_cost = sum(r["embedding_cost"] for r in today_records)
        cache_hits = sum(1 for r in today_records if r["cache_hit"])

        return {
            "date": today,
            "queries": count,
            "total_cost": round(total, 4),
            "llm_cost": round(llm_cost, 4),
            "embedding_cost": round(embedding_cost, 4),
            "cache_hit_rate": cache_hits / count if count else 0.0,
            "avg_cost_per_query": round(total / count, 6) if count else 0.0,
            "budget_used_pct": round(total / self.daily_budget * 100, 1),
            "budget_remaining": round(self.daily_budget - total, 2),
        }

Cost Reduction Strategy Summary

| Strategy | Typical Savings | Complexity |
|---|---|---|
| Exact + semantic caching | 30-60% | Low |
| Model routing (cheap for simple queries) | 20-40% | Medium |
| Context window trimming | 15-30% | Low |
| Local embeddings | 100% on embedding cost | Medium |
| Claude prompt caching | 50-80% on static context | Low |
| Context compression (LLM) | 30-50% on LLM input | Medium |
| Streaming (reduce wasted output) | 5-15% | Low |
| Shorter system prompts | 5-10% | Low |

Enjoyed this article?

Explore the AI Systems learning path for more.

Found this helpful?

Share:𝕏

Leave a comment

Have a question, correction, or just found this helpful? Leave a note below.