LangChain Mastery · Lesson 9 of 33
Conditional Routing with RunnableBranch
Why Conditional Routing?
Not all queries are equal. Some are simple lookups, others require multi-step reasoning. Some need drug information, others need patient counseling. Routing different queries to different specialized chains improves quality and reduces cost.
Query → Classifier → Route A (simple fact)      → Cheap model
                   → Route B (drug interaction) → Specialist chain
                   → Route C (patient question) → Empathetic chain
                   → Default                    → General chain

RunnableBranch Basics
Python
from langchain_core.runnables import RunnableBranch, RunnableLambda
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
# Shared building blocks: one string parser and two model tiers.
parser = StrOutputParser()
model_fast = ChatOpenAI(model="gpt-4o-mini", temperature=0)
model_smart = ChatOpenAI(model="gpt-4o", temperature=0)

# Prompts, one per specialty. Each expects a single "query" variable.
_definition_prompt = ChatPromptTemplate.from_template(
    "Define {query} in medical terms."
)
_interaction_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a clinical pharmacist specializing in drug interactions. Be precise about severity."),
        ("human", "Drug interaction query: {query}"),
    ]
)
_dosing_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a clinical pharmacist. Provide evidence-based dosing guidance."),
        ("human", "Dosing question: {query}"),
    ]
)
_general_prompt = ChatPromptTemplate.from_template(
    "Answer this clinical question: {query}"
)

# Specialized chains: prompt -> model -> string output.
# Only the definition chain runs on the fast model; the rest use the smart one.
definition_chain = _definition_prompt | model_fast | parser
interaction_chain = _interaction_prompt | model_smart | parser
dosing_chain = _dosing_prompt | model_smart | parser
general_chain = _general_prompt | model_smart | parser
# Route based on input content
def is_interaction_query(inputs: dict) -> bool:
    """Return True when the query text looks like a drug-interaction question.

    Args:
        inputs: Chain input; its "query" entry holds the user's question.

    Returns:
        True if the lower-cased query contains any interaction keyword.

    Raises:
        KeyError: If ``inputs`` has no "query" entry.
    """
    text = inputs["query"].lower()
    # "combine"/"combining" are listed explicitly: "combination" alone does
    # not match "Can I combine X with Y?". A bare "combin" stem is avoided
    # because it would also match unrelated words such as "recombinant".
    keywords = (
        "interact",
        "combine",
        "combining",
        "combination",
        "together",
        "contraindicated",
        "avoid",
    )
    return any(keyword in text for keyword in keywords)
def is_dosing_query(inputs: dict) -> bool:
    """Return True when the query text looks like a dosing question.

    Args:
        inputs: Chain input; its "query" entry holds the user's question.

    Returns:
        True if the lower-cased query contains a dosing phrase or the unit
        "mg" written as a unit (e.g. "5mg", "500 mg", "mg/kg").

    Raises:
        KeyError: If ``inputs`` has no "query" entry.
    """
    import re

    text = inputs["query"].lower()
    phrases = ("dose", "dosing", "frequency", "how much", "how often")
    if any(phrase in text for phrase in phrases):
        return True
    # Match "mg" only when it is not embedded in a longer word; a plain
    # substring test would also fire on terms like "HMG-CoA".
    return re.search(r"(?<![a-z])mg(?![a-z])", text) is not None
def is_definition_query(inputs: dict) -> bool:
    """Return True for short questions that begin with "what is".

    A query qualifies when, after lower-casing, it starts with "what is"
    and contains at most eight whitespace-separated words.
    """
    text = inputs["query"].lower()
    if not text.startswith("what is"):
        return False
    word_count = len(text.split())
    return word_count <= 8
# RunnableBranch: (condition, chain_if_true) pairs + default.
# Conditions are checked in the order listed, so the interaction check
# takes priority over dosing, and dosing over definition.
router = RunnableBranch(
    (is_interaction_query, interaction_chain),
    (is_dosing_query, dosing_chain),
    (is_definition_query, definition_chain),
    general_chain,  # Default: no condition matched
)

# Test routing: send a few sample queries through the router and print a
# truncated question/answer pair for each.
queries = [
    "What is warfarin?",
    "What is the dose of warfarin for AFib?",
    "Can I combine warfarin with aspirin?",
    "What monitoring is needed for a patient on warfarin?",
]
for q in queries:
    result = router.invoke({"query": q})
    print(f"Q: {q[:50]}")
    print(f"A: {result[:80]}\n")

# LLM-Based Classification Router
For more nuanced routing, use an LLM to classify the query:
Python
import json
from langchain_core.runnables import RunnableLambda
# Closed set of labels the classifier is allowed to produce.
CATEGORIES = ["definition", "drug_interaction", "dosing", "safety", "general"]

# NOTE: the template text is kept flush-left on purpose; it is sent to the
# model verbatim. Doubled braces are literal braces in the rendered prompt.
classifier_prompt = ChatPromptTemplate.from_template(
    """Classify this clinical query into one category.
Categories:
- definition: asking what something is
- drug_interaction: asking about interactions between drugs
- dosing: asking about doses, frequencies, routes of administration
- safety: asking about contraindications, adverse effects, warnings
- general: anything else
Query: {query}
Return JSON: {{"category": "one of the categories above"}}"""
)


def _extract_category(payload) -> str:
    """Return the category named in the classifier's parsed JSON reply.

    Falls back to "general" when the reply is not a JSON object, has no
    "category" key, or names a label outside ``CATEGORIES``; indexing
    ``payload["category"]`` directly would raise on such replies.
    """
    if not isinstance(payload, dict):
        return "general"
    category = payload.get("category")
    return category if category in CATEGORIES else "general"


# prompt -> JSON-mode model -> raw text -> parsed dict -> validated label
classifier_chain = (
    classifier_prompt
    | ChatOpenAI(model="gpt-4o-mini", temperature=0, response_format={"type": "json_object"})
    | StrOutputParser()
    | RunnableLambda(json.loads)
    | RunnableLambda(_extract_category)
)
def route_by_category(inputs: dict) -> str:
    """Run the classifier chain on the input's query and return its output.

    Only the "query" entry of ``inputs`` is forwarded to the classifier.
    """
    classifier_input = {"query": inputs["query"]}
    return classifier_chain.invoke(classifier_input)
# Safety-focused chain: prompt -> smart model -> string output.
_safety_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a clinical pharmacist. Emphasize safety warnings clearly."),
        ("human", "Safety question: {query}"),
    ]
)
safety_chain = _safety_prompt | model_smart | parser

# Route map: classifier label -> chain that answers that kind of query.
ROUTE_MAP = dict(
    definition=definition_chain,
    drug_interaction=interaction_chain,
    dosing=dosing_chain,
    safety=safety_chain,
    general=general_chain,
)
def llm_router(inputs: dict) -> str:
    """Classify the query, then answer it with the chain mapped to that label.

    Labels that are not keys of ``ROUTE_MAP`` are answered by
    ``general_chain``.
    """
    label = route_by_category(inputs)
    selected = ROUTE_MAP[label] if label in ROUTE_MAP else general_chain
    return selected.invoke(inputs)
# Wrap the routing function so it can be used like any other runnable.
llm_routing_chain = RunnableLambda(llm_router)
result = llm_routing_chain.invoke({"query": "What are the signs of warfarin toxicity?"})
print(result)  # Expected route: safety_chain (the LLM classifier decides)

# RunnableBranch vs RunnableLambda for Routing
Two equivalent approaches:
Python
# Approach 1: RunnableBranch (declarative)
# .get() is used so an input without a "complexity" key falls through to
# default_chain instead of raising KeyError, matching the imperative version.
branch = RunnableBranch(
    (lambda x: x.get("complexity") == "simple", simple_chain),
    (lambda x: x.get("complexity") == "complex", complex_chain),
    default_chain,
)
# Approach 2: RunnableLambda with if-else (imperative)
def route_by_complexity(inputs: dict) -> str:
    """Invoke the chain that matches the input's "complexity" value.

    "simple" and "complex" select their dedicated chains; any other value,
    or a missing key, selects ``default_chain``.
    """
    level = inputs.get("complexity")
    if level == "simple":
        chosen = simple_chain
    elif level == "complex":
        chosen = complex_chain
    else:
        chosen = default_chain
    return chosen.invoke(inputs)
lambda_branch = RunnableLambda(route_by_complexity)
# Both produce identical results for inputs that carry a "complexity" key
# RunnableBranch: more readable for simple condition lists
# RunnableLambda: more flexible for complex logic, early returns, loops

# Model Routing: Cheap vs Expensive
A cost-optimization pattern — route simple queries to cheaper models:
Python
def estimate_complexity(query: str) -> str:
    """Heuristic-based complexity classification.

    Args:
        query: The user's question.

    Returns:
        "simple" for short (at most eight words) questions containing a
        basic lookup phrase, "complex" for questions containing a
        comparison/reasoning indicator, otherwise "standard". The simple
        check runs first, so it wins when both would apply.
    """
    import re

    text = query.lower()
    words = text.split()
    # Simple: short questions with basic keywords
    simple_phrases = ("what is", "define", "what does")
    if len(words) <= 8 and any(phrase in text for phrase in simple_phrases):
        return "simple"
    # Complex: requires reasoning, multi-drug, comparisons
    complex_phrases = ("compare", "versus", "which is better", "step by step", "why does")
    if any(phrase in text for phrase in complex_phrases):
        return "complex"
    # "vs" must stand alone as a token ("a vs b", "a vs. b"); a substring
    # test would also fire inside ordinary words such as "IVs".
    if re.search(r"\bvs\b", text):
        return "complex"
    return "standard"
def _passthrough_chain(llm):
    """Build a chain that sends the query text as-is to ``llm`` and pipes the reply through ``parser``."""
    return ChatPromptTemplate.from_template("{query}") | llm | parser


def _has_complexity(level):
    """Return a predicate that is true when the input's query is estimated at ``level``."""
    return lambda x: estimate_complexity(x["query"]) == level


# Three tiers: same bare prompt, different model configuration.
cheap_chain = _passthrough_chain(ChatOpenAI(model="gpt-4o-mini", temperature=0))
standard_chain = _passthrough_chain(ChatOpenAI(model="gpt-4o", temperature=0))
premium_chain = _passthrough_chain(
    ChatOpenAI(model="gpt-4o", temperature=0, max_tokens=2000)
)

# "simple" -> cheap tier, "complex" -> premium tier, anything else -> standard.
cost_optimized_router = RunnableBranch(
    (_has_complexity("simple"), cheap_chain),
    (_has_complexity("complex"), premium_chain),
    standard_chain,
)
# Show which tier each sample query lands on (cheaper tier = lower cost).
# The loop prints the estimated tier and a truncated answer; it does not
# compute actual spend.
test_queries = [
    "What is aspirin?",  # simple → gpt-4o-mini
    "Compare warfarin vs rivaroxaban dosing",  # complex → gpt-4o (premium)
    "What monitoring is needed for warfarin?",  # standard → gpt-4o
]
for q in test_queries:
    complexity = estimate_complexity(q)
    result = cost_optimized_router.invoke({"query": q})
    print(f"[{complexity.upper()}] {q[:40]} → {result[:60]}")

# Adding Metadata to Track Routing Decisions
Python
from langchain_core.runnables import RunnablePassthrough
def route_with_metadata(inputs: dict) -> dict:
    """Route query and track which branch was taken.

    Args:
        inputs: Chain input; its "query" entry holds the user's question.

    Returns:
        A dict with the original "query", the chain's "answer", the
        "route_taken" label ("drug_interaction", "dosing" or "general"),
        and "model_used".

    Raises:
        KeyError: If ``inputs`` has no "query" entry.
    """
    query = inputs["query"]
    if is_interaction_query(inputs):
        category, chain = "drug_interaction", interaction_chain
    elif is_dosing_query(inputs):
        category, chain = "dosing", dosing_chain
    else:
        category, chain = "general", general_chain
    return {
        "query": query,
        "answer": chain.invoke(inputs),
        "route_taken": category,
        # All three chains reachable here are built on model_smart, so the
        # general route must not be reported as "gpt-4o-mini".
        "model_used": "gpt-4o",
    }
# Wrap the tracking router as a runnable and inspect the metadata it returns.
tracked_router = RunnableLambda(route_with_metadata)
result = tracked_router.invoke({"query": "What dose of metformin for diabetes?"})
print(f"Route: {result['route_taken']}")
print(f"Model: {result['model_used']}")
print(f"Answer: {result['answer'][:100]}")

# Fallback Routing
Python
# Primary chain with fallback if it fails
primary_chain = interaction_chain.with_fallbacks([general_chain])

# Fallback for specific error types
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import JsonOutputParser  # used below; was never imported

# Chain whose final step parses the model reply as JSON.
json_chain = (
    ChatPromptTemplate.from_template("Return JSON for {query}: {{answer: ...}}")
    | model_fast
    | JsonOutputParser()
)
# Fallback to string parser if JSON parsing fails.
# Only OutputParserException triggers the fallback; other errors propagate.
robust_chain = json_chain.with_fallbacks(
    [ChatPromptTemplate.from_template("{query}") | model_fast | parser],
    exceptions_to_handle=(OutputParserException,),
)