RAG Cost Optimization
Reduce RAG system costs with model routing, caching strategies, embedding cost reduction, chunking optimization, and batch processing. Build cost-efficient clinical AI.
Where RAG Money Goes
A typical RAG query costs:
- Embedding (query): ~$0.0000004 per 20-token query (text-embedding-3-small at $0.02 per 1M tokens)
- Vector search: Negligible (free on self-hosted, pennies per million on managed)
- LLM generation: $0.005 - $0.15 per query (the dominant cost)
- Reranking: $0.001 - $0.01 per query (if using Cohere/cross-encoder)
For 100k queries/month at $0.05/query = $5,000/month. Optimization can cut this by 60-80%.
Cost Model
from dataclasses import dataclass
# Vendor list prices in USD, captured mid-2025. Providers revise these often,
# so confirm every rate against the provider's pricing page before relying on
# any projection built from them.
# NOTE(review): the claude-haiku-4-5 row matches the older Claude 3.5 Haiku
# rates ($0.80 / $4.00); newer Claude models may be priced differently — confirm.
LLM_PRICING = {
    model_id: {"input_per_1m": input_rate, "output_per_1m": output_rate}
    for model_id, input_rate, output_rate in (
        # (model id, USD per 1M prompt tokens, USD per 1M completion tokens)
        ("gpt-4o", 2.50, 10.00),
        ("gpt-4o-mini", 0.15, 0.60),
        ("claude-opus-4-7", 15.00, 75.00),
        ("claude-sonnet-4-6", 3.00, 15.00),
        ("claude-haiku-4-5", 0.80, 4.00),
        ("gemini-1.5-pro", 3.50, 10.50),
        ("gemini-1.5-flash", 0.075, 0.30),
    )
}

# Embedding models — value is USD per 1M input tokens.
EMBEDDING_PRICING = {
    "text-embedding-3-small": 0.02,
    "text-embedding-3-large": 0.13,
    "text-embedding-ada-002": 0.10,
}

# Rerankers — value is USD per 1,000 rerank queries (not per token).
RERANKING_PRICING = {
    "cohere-rerank-english-v3": 2.00,
}
@dataclass
class RAGQueryCost:
    """Per-query spend for one RAG request, split by pipeline stage.

    Amounts are in the same currency as the pricing tables (USD in this guide).
    """

    embedding_cost: float = 0.0  # embedding the user query
    retrieval_cost: float = 0.0  # vector-store lookup
    reranking_cost: float = 0.0  # optional reranker call
    llm_input_cost: float = 0.0  # prompt tokens sent to the generator
    llm_output_cost: float = 0.0  # completion tokens produced by the generator

    @property
    def total(self) -> float:
        """Sum of every stage's cost for this query."""
        stage_costs = (
            self.embedding_cost,
            self.retrieval_cost,
            self.reranking_cost,
            self.llm_input_cost,
            self.llm_output_cost,
        )
        return sum(stage_costs)
def estimate_query_cost(
    query_tokens: int = 20,
    context_tokens: int = 3000,  # Retrieved context sent to LLM
    output_tokens: int = 400,
    llm_model: str = "gpt-4o",
    embedding_model: str = "text-embedding-3-small",
    use_reranker: bool = True,
    system_prompt_tokens: int = 200,
    reranker_model: str = "cohere-rerank-english-v3",
) -> RAGQueryCost:
    """Estimate the cost of a single RAG query, stage by stage.

    Args:
        query_tokens: Tokens in the user query. The query is embedded once and
            is also part of the LLM input.
        context_tokens: Tokens of retrieved context sent to the LLM.
        output_tokens: Tokens the LLM generates.
        llm_model: Key into LLM_PRICING. Unknown models are priced as "gpt-4o".
        embedding_model: Key into EMBEDDING_PRICING. Unknown models are priced
            at $0.02 per 1M tokens.
        use_reranker: Whether a reranker call is included in the query.
        system_prompt_tokens: Fixed prompt overhead (system prompt and
            instructions) added to the LLM input on every query.
        reranker_model: Key into RERANKING_PRICING, which is quoted per 1,000
            queries. Unknown rerankers fall back to $2.00 per 1,000.

    Returns:
        RAGQueryCost. retrieval_cost is left at 0.0, i.e. vector search is
        treated as negligible.
    """
    emb_price_per_1m = EMBEDDING_PRICING.get(embedding_model, 0.02)
    llm_price = LLM_PRICING.get(llm_model, LLM_PRICING["gpt-4o"])
    input_tokens = query_tokens + context_tokens + system_prompt_tokens

    reranking_cost = 0.0
    if use_reranker:
        # Read the rate from RERANKING_PRICING so there is one source of truth.
        # One RAG query issues one rerank call; the table is priced per 1,000.
        reranking_cost = RERANKING_PRICING.get(reranker_model, 2.00) / 1000

    return RAGQueryCost(
        embedding_cost=query_tokens / 1_000_000 * emb_price_per_1m,
        llm_input_cost=input_tokens / 1_000_000 * llm_price["input_per_1m"],
        llm_output_cost=output_tokens / 1_000_000 * llm_price["output_per_1m"],
        reranking_cost=reranking_cost,
    )
def monthly_cost_projection(
    queries_per_day: int,
    avg_query_tokens: int = 20,
    avg_context_tokens: int = 3000,
    avg_output_tokens: int = 400,
    llm_model: str = "gpt-4o",
    cache_hit_rate: float = 0.30,
    days_per_month: int = 30,
) -> dict:
    """Project monthly RAG costs, with and without a response cache.

    Args:
        queries_per_day: Average daily query volume.
        avg_query_tokens: Average query length in tokens.
        avg_context_tokens: Average retrieved-context length in tokens.
        avg_output_tokens: Average generated-answer length in tokens.
        llm_model: Generator model (key into LLM_PRICING).
        cache_hit_rate: Fraction of queries answered from cache, in [0, 1].
        days_per_month: Days used to scale daily volume to a month.

    Returns:
        dict with queries_per_month, cost_per_query, effective_cost_per_query,
        monthly_cost_without_cache, monthly_cost_with_cache and
        monthly_savings_from_cache (rounded).

    Raises:
        ValueError: if cache_hit_rate is outside [0, 1] (a rate above 1 could
            produce negative costs), or if queries_per_day or days_per_month
            is negative.
    """
    if not 0.0 <= cache_hit_rate <= 1.0:
        raise ValueError(f"cache_hit_rate must be between 0 and 1, got {cache_hit_rate!r}")
    if queries_per_day < 0 or days_per_month < 0:
        raise ValueError("queries_per_day and days_per_month must be non-negative")

    per_query = estimate_query_cost(
        query_tokens=avg_query_tokens,
        context_tokens=avg_context_tokens,
        output_tokens=avg_output_tokens,
        llm_model=llm_model,
    )
    # Cache hits skip the LLM call and reranking; the query embedding is still charged.
    cache_savings_per_query = (
        per_query.llm_input_cost + per_query.llm_output_cost + per_query.reranking_cost
    )
    effective_cost = per_query.total - (cache_hit_rate * cache_savings_per_query)
    monthly_queries = queries_per_day * days_per_month
    return {
        "queries_per_month": monthly_queries,
        "cost_per_query": round(per_query.total, 6),
        "effective_cost_per_query": round(effective_cost, 6),
        "monthly_cost_without_cache": round(per_query.total * monthly_queries, 2),
        "monthly_cost_with_cache": round(effective_cost * monthly_queries, 2),
        "monthly_savings_from_cache": round(
            (per_query.total - effective_cost) * monthly_queries, 2
        ),
    }
# Model Routing by Query Complexity
Route simple queries to cheaper models:
from openai import OpenAI
import json
client = OpenAI()
# Complexity tier -> model id; the cheaper models take the easier tiers.
# NOTE: "claude-*" ids are Anthropic models. OpenAI's Chat Completions API does
# not serve them, so whatever sends the request must dispatch that tier to
# Anthropic's API rather than to `client`.
ROUTING_MODELS = dict(
    simple="gpt-4o-mini",  # Factual lookups, definitions
    standard="gpt-4o",  # Clinical questions, reasoning
    complex="claude-opus-4-7",  # Multi-step reasoning, differential diagnosis
)
def classify_query_complexity(query: str, context_length: int) -> str:
    """Classify a query as "simple", "standard" or "complex" for model routing.

    Cheap keyword heuristics run first. Only queries they cannot settle fall
    through to a gpt-4o-mini classification call.

    Args:
        query: The user's question.
        context_length: Size of the retrieved context in tokens. The thresholds
            below and the classifier prompt both assume tokens.

    Returns:
        One of "simple", "standard" or "complex". Falls back to "standard" when
        the classifier returns malformed JSON or a label outside those three.
    """
    valid_labels = ("simple", "standard", "complex")
    # Normalise so leading whitespace or capitalisation cannot defeat the prefix checks.
    query_lower = query.strip().lower()

    # Simple: definition lookups, single-fact questions
    simple_prefixes = ("what is", "define", "what does", "what are the side effects of", "dose of")
    if query_lower.startswith(simple_prefixes) and context_length < 1500:
        return "simple"

    # Complex: multi-drug interactions, differential diagnosis
    complex_patterns = (
        "compare", "differential", "evaluate", "assess", "step-by-step",
        "explain the mechanism", "analyze", "which is better",
    )
    complex_hits = sum(1 for pattern in complex_patterns if pattern in query_lower)
    if complex_hits >= 2 or context_length > 4000:
        return "complex"

    # LLM classification for ambiguous cases
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "user",
                "content": f"""Classify this clinical query complexity.
Query: {query}
Context length: {context_length} tokens
Classification:
- simple: single factual lookup, definition, dose check
- standard: clinical reasoning, drug interaction check, treatment question
- complex: multi-drug analysis, differential diagnosis, protocol design
Return JSON: {{"complexity": "simple|standard|complex", "reason": "brief reason"}}""",
            }
        ],
        response_format={"type": "json_object"},
        temperature=0,
    )
    try:
        # content may be None; treat that like unparsable output.
        result = json.loads(response.choices[0].message.content or "")
    except json.JSONDecodeError:
        return "standard"
    label = result.get("complexity") if isinstance(result, dict) else None
    # JSON mode guarantees valid JSON, not this schema: reject unknown labels.
    return label if label in valid_labels else "standard"
def _approx_tokens(text: str, chars_per_token: int = 4) -> int:
    """Rough token count using the ~4 characters/token rule of thumb (rounded up)."""
    return (len(text) + chars_per_token - 1) // chars_per_token


def _chat_completion(model: str, system_prompt: str, user_prompt: str) -> tuple:
    """Run one system+user turn on `model`; return (answer_text, output_tokens or None)."""
    if model.startswith("claude"):
        # Claude models are not served by OpenAI's Chat Completions API, so send
        # them through Anthropic's Messages API instead.
        import anthropic

        reply = anthropic.Anthropic().messages.create(
            model=model,
            max_tokens=1024,  # the Messages API requires an explicit output cap
            system=system_prompt,
            messages=[{"role": "user", "content": user_prompt}],
            temperature=0,
        )
        text = "".join(
            block.text for block in reply.content if getattr(block, "type", None) == "text"
        )
        return text, getattr(getattr(reply, "usage", None), "output_tokens", None)

    reply = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=0,
    )
    # message.content can be None (e.g. on a refusal); keep a str for downstream use.
    text = reply.choices[0].message.content or ""
    return text, getattr(getattr(reply, "usage", None), "completion_tokens", None)


def routed_rag_generate(
    query: str,
    context: str,
    force_model: str = None,
) -> dict:
    """Generate a response with cost-optimized model routing.

    Args:
        query: User question.
        context: Retrieved context the answer must be grounded in.
        force_model: If set, skip classification and use this model id.

    Returns:
        dict with "response", "model_used", "complexity" ("manual" when
        force_model is given) and "cost", the estimated cost of this query.
        The cost does not include the classifier call that
        classify_query_complexity may make.
    """
    if force_model:
        model = force_model
        complexity = "manual"
    else:
        # The classifier's thresholds are in tokens. A whitespace word count
        # undercounts tokens, so pass an approximate token count instead.
        complexity = classify_query_complexity(query, _approx_tokens(context))
        model = ROUTING_MODELS.get(complexity, "gpt-4o")

    answer, reported_output_tokens = _chat_completion(
        model,
        "Answer using only the provided context.",
        f"Context:\n{context}\n\nQuestion: {query}",
    )
    cost = estimate_query_cost(
        query_tokens=_approx_tokens(query),
        context_tokens=_approx_tokens(context),
        # Prefer the provider's own completion-token count when it reports one.
        output_tokens=(
            reported_output_tokens
            if reported_output_tokens is not None
            else _approx_tokens(answer)
        ),
        llm_model=model,
    )
    return {
        "response": answer,
        "model_used": model,
        "complexity": complexity,
        "cost": cost.total,
    }
# Context Window Optimization
Reduce the amount of context sent to the LLM:
def trim_context_to_budget(
    documents: list[dict],
    max_tokens: int = 3000,  # Context token budget
    chars_per_token: int = 4,  # Rough approximation
    min_partial_chars: int = 200,
) -> list[dict]:
    """
    Trim retrieved documents to fit within a token budget.

    Documents are taken in the order given and are NOT re-sorted, so pass them
    best-first (retrieval results usually already are). Whole documents are
    kept while they fit. The first document that does not fit is included as a
    truncated copy, flagged ``"truncated": True``, if more than
    ``min_partial_chars`` characters of budget remain. Selection stops there.

    Args:
        documents: Dicts with a "content" string; other keys are passed through.
        max_tokens: Context budget in tokens.
        chars_per_token: Characters per token used to convert the budget to characters.
        min_partial_chars: A partial document is only included if more than this
            many characters of budget remain. The default of 200 characters is
            about 50 tokens.

    Returns:
        A new list of the selected documents. The input dicts are not modified.
    """
    max_chars = max_tokens * chars_per_token
    selected = []
    total_chars = 0
    for doc in documents:
        # Treat a missing or None "content" as empty rather than failing on len(None).
        content = doc.get("content") or ""
        if total_chars + len(content) <= max_chars:
            selected.append(doc)
            total_chars += len(content)
            continue
        remaining = max_chars - total_chars
        if remaining > min_partial_chars:
            selected.append({**doc, "content": content[:remaining], "truncated": True})
        break
    return selected
def compress_context_with_llm(
    query: str,
    documents: list[dict],
    target_tokens: int = 1500,
    max_input_chars: int = 6000,
) -> str:
    """
    Use a cheap LLM to compress context before sending it to an expensive model.

    The documents' "content" fields are joined and cut to ``max_input_chars``
    characters before compression. Anything past that cut is dropped without
    the compressor ever seeing it. The default of 6000 characters is about
    1500 tokens at ~4 characters/token, roughly the same as the default
    ``target_tokens``. Raise ``max_input_chars`` if longer context should be
    condensed rather than simply truncated.

    Can reduce LLM input costs by 40-60%.

    Returns:
        The compressed context. Returns "" when there is no document content;
        no LLM call is made in that case.
    """
    # Skip documents with a missing or empty "content" instead of failing on them.
    contents = [d.get("content") or "" for d in documents]
    full_context = "\n\n".join(text for text in contents if text)
    if not full_context.strip():
        return ""
    compressed = client.chat.completions.create(
        model="gpt-4o-mini",  # Cheap model for compression
        messages=[
            {
                "role": "user",
                "content": f"""Extract only the information from this context that is needed to answer the question.
Remove: irrelevant sections, repetition, preamble, disclaimers.
Keep: specific facts, numbers, clinical details directly relevant to the question.
Target length: around {target_tokens} tokens.
Question: {query}
Context:
{full_context[:max_input_chars]}
Compressed context:""",
            }
        ],
        temperature=0,
        max_tokens=target_tokens + 100,
    ).choices[0].message.content
    # message.content may be None; keep the declared str return type.
    return compressed or ""
# Prompt Caching (Anthropic)
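Anthropic's prompt caching bills cache reads at about 10% of the normal input-token price, a ~90% saving on repeated context. The first request, which writes the cache, costs about 25% more than normal input: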
import anthropic
claude_client = anthropic.Anthropic()
def rag_with_claude_caching(
    system_prompt: str,
    static_context: str,  # Large static context (clinical guidelines, etc.)
    query: str,
    model: str = "claude-sonnet-4-6",
) -> dict:
    """
    Answer `query` against a large, reused context using Claude prompt caching.

    The system prompt and the static reference material are both marked
    ``cache_control: ephemeral``. Only the trailing question varies per call,
    so repeat calls with the same system prompt and context can read that
    prefix from cache.

    Per Anthropic's prompt-caching docs (confirm current values):
    - cache reads are billed at ~10% of the base input price, a ~90% saving on
      the cached tokens;
    - cache writes are billed at ~125% of the base input price;
    - the default cache TTL is 5 minutes, refreshed on each hit;
    - prefixes shorter than the model's minimum cacheable length are not
      cached at all.

    Returns:
        dict with:
        - "response": all text blocks concatenated;
        - "input_tokens" and "output_tokens": presumably input_tokens counts
          only uncached input, per Anthropic's usage semantics (confirm);
        - "cache_creation_tokens" and "cache_read_tokens": these default to 0
          when the API reports none.
    """
    response = claude_client.messages.create(
        model=model,
        max_tokens=1000,
        system=[
            {
                "type": "text",
                "text": system_prompt,
                "cache_control": {"type": "ephemeral"},  # Cache the system prompt
            }
        ],
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": f"Reference material:\n{static_context}",
                        "cache_control": {"type": "ephemeral"},  # Cache static context
                    },
                    {
                        "type": "text",
                        "text": f"\nQuestion: {query}",  # Not cached (varies per query)
                    },
                ],
            }
        ],
    )
    # Usage breakdown shows cached vs uncached tokens
    usage = response.usage
    # Join every text block: content[0] alone can miss text, or fail when the
    # first block is not a text block.
    answer = "".join(
        block.text for block in response.content if getattr(block, "type", None) == "text"
    )
    return {
        "response": answer,
        "input_tokens": usage.input_tokens,
        "output_tokens": usage.output_tokens,
        # The SDK exposes these as optional fields that may be None, not merely
        # absent, so getattr's default alone does not guarantee an int.
        "cache_creation_tokens": getattr(usage, "cache_creation_input_tokens", None) or 0,
        "cache_read_tokens": getattr(usage, "cache_read_input_tokens", None) or 0,
    }
# Embedding Cost Reduction
from sentence_transformers import SentenceTransformer
import numpy as np
# Replace the OpenAI embeddings API with a model that runs locally: for
# high-volume pipelines this removes the per-token embedding bill entirely.
class LocalEmbedder:
    """Embed text with a local sentence-transformers model (no API charges)."""

    def __init__(self, model_name: str = "BAAI/bge-small-en-v1.5"):
        # Default bge-small: 384-dim vectors, ~133MB of weights, runs on CPU.
        loaded = SentenceTransformer(model_name)
        self.model = loaded
        self.dim = loaded.get_sentence_embedding_dimension()

    def embed(self, texts: list[str], batch_size: int = 64) -> np.ndarray:
        """Return unit-length embeddings for *texts*, computed locally."""
        encode_options = {
            "batch_size": batch_size,
            # Unit-length vectors make a plain dot product equal cosine similarity.
            "normalize_embeddings": True,
            "show_progress_bar": False,
        }
        return self.model.encode(texts, **encode_options)

    def embed_single(self, text: str) -> list[float]:
        """Embed one string and return the vector as a plain Python list."""
        batch = self.embed([text])
        first_row = batch[0]
        return first_row.tolist()
# Cost comparison at 1M queries/month (each query = 20 tokens = 1 embedding call):
# OpenAI text-embedding-3-small: $0.02/1M tokens = $0.0000004 per query = $0.40/month at 1M/month (20M tokens)
# Local BAAI/bge-small: $0 per query (compute only)
# Break-even: local only wins if its compute costs less than ~$0.40/month, so for query
# embedding alone the API is usually cheaper. Local models pay off mainly when bulk
# (re-)embedding large document corpora.
Cost Dashboard
class RAGCostTracker:
    """Track and report RAG costs in production (records are kept in memory)."""

    def __init__(self, daily_budget: float = 100.0):
        """Create an empty tracker.

        Args:
            daily_budget: Spend limit per day in USD that summaries report against.

        Raises:
            ValueError: if daily_budget is not positive. Budget percentages
                divide by it.
        """
        if daily_budget <= 0:
            raise ValueError(f"daily_budget must be positive, got {daily_budget!r}")
        self.records: list[dict] = []
        self.daily_budget: float = daily_budget  # USD

    def record(self, query_cost: "RAGQueryCost", model: str, cache_hit: bool) -> None:
        """Append one query's cost breakdown, stamped with the current UTC time."""
        import datetime

        self.records.append({
            # Timezone-aware UTC; get_daily_summary buckets by the UTC date as well.
            "timestamp": datetime.datetime.now(datetime.timezone.utc).isoformat(),
            "total_cost": query_cost.total,
            "llm_cost": query_cost.llm_input_cost + query_cost.llm_output_cost,
            "embedding_cost": query_cost.embedding_cost,
            "model": model,
            "cache_hit": cache_hit,
        })

    def get_daily_summary(self) -> dict:
        """Summarise today's records (UTC day) against the daily budget.

        Always returns the same keys. On a day with no records the counts and
        costs are zero.
        """
        import datetime

        # Use the UTC date because timestamps are UTC. The local date would
        # mis-bucket records near midnight in non-UTC timezones.
        today = datetime.datetime.now(datetime.timezone.utc).date().isoformat()
        today_records = [r for r in self.records if r["timestamp"].startswith(today)]
        count = len(today_records)
        total = sum(r["total_cost"] for r in today_records)
        llm_cost = sum(r["llm_cost"] for r in today_records)
        embedding_cost = sum(r["embedding_cost"] for r in today_records)
        cache_hits = sum(1 for r in today_records if r["cache_hit"])
        return {
            "date": today,
            "queries": count,
            "total_cost": round(total, 4),
            "llm_cost": round(llm_cost, 4),
            "embedding_cost": round(embedding_cost, 4),
            "cache_hit_rate": cache_hits / count if count else 0.0,
            "avg_cost_per_query": round(total / count, 6) if count else 0.0,
            "budget_used_pct": round(total / self.daily_budget * 100, 1),
            "budget_remaining": round(self.daily_budget - total, 2),
        }
# Cost Reduction Strategy Summary
| Strategy | Typical Savings | Complexity |
|---|---|---|
| Exact + semantic caching | 30-60% | Low |
| Model routing (cheap for simple queries) | 20-40% | Medium |
| Context window trimming | 15-30% | Low |
| Local embeddings | 100% on embedding cost | Medium |
| Claude prompt caching | 50-80% on static context | Low |
| Context compression (LLM) | 30-50% on LLM input | Medium |
| Streaming (stop generation early to cut wasted output) | 5-15% | Low |
| Shorter system prompts | 5-10% | Low |
Found this helpful?
Leave a comment
Have a question, correction, or just found this helpful? Leave a note below.