Learnixo

LangChain Mastery · Lesson 9 of 33

Conditional Routing with RunnableBranch

Why Conditional Routing?

Not all queries are equal. Some are simple lookups, others require multi-step reasoning. Some need drug information, others need patient counseling. Routing different queries to different specialized chains improves quality and reduces cost.

Query → Classifier → Route A (simple fact)    → Cheap model
                  → Route B (drug interaction) → Specialist chain
                  → Route C (patient question) → Empathetic chain
                  → Default                   → General chain

RunnableBranch Basics

Python
from langchain_core.runnables import RunnableBranch, RunnableLambda
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

# Two model tiers: a cheaper model for simple lookups and a stronger one for
# clinical reasoning. temperature=0 keeps answers as repeatable as possible.
model_fast = ChatOpenAI(model="gpt-4o-mini", temperature=0)
model_smart = ChatOpenAI(model="gpt-4o", temperature=0)
parser = StrOutputParser()


def _pharmacist_chain(system_text: str, human_text: str):
    """Build a system+human prompt piped into ``model_smart`` and ``parser``."""
    prompt = ChatPromptTemplate.from_messages(
        [("system", system_text), ("human", human_text)]
    )
    return prompt | model_smart | parser


# Specialized chains, one per route.
# Short definitions are cheap to answer, so they use the fast model.
definition_chain = (
    ChatPromptTemplate.from_template("Define {query} in medical terms.")
    | model_fast
    | parser
)
interaction_chain = _pharmacist_chain(
    "You are a clinical pharmacist specializing in drug interactions. Be precise about severity.",
    "Drug interaction query: {query}",
)
dosing_chain = _pharmacist_chain(
    "You are a clinical pharmacist. Provide evidence-based dosing guidance.",
    "Dosing question: {query}",
)
# Catch-all route for queries no keyword check recognizes.
general_chain = (
    ChatPromptTemplate.from_template("Answer this clinical question: {query}")
    | model_smart
    | parser
)

# Route based on input content
def is_interaction_query(inputs: dict) -> bool:
    """Return True if ``inputs["query"]`` looks like a drug-interaction question.

    Matching is a case-insensitive substring test. The stem ``"combin"`` covers
    "combine", "combined", "combining" and "combination", so everyday phrasing
    such as "Can I combine warfarin with aspirin?" is recognized.

    Args:
        inputs: Mapping with a ``"query"`` string.

    Returns:
        True when any interaction keyword appears in the query.
    """
    q = inputs["query"].lower()
    keywords = ("interact", "combin", "together", "contraindicated", "avoid")
    return any(w in q for w in keywords)

def is_dosing_query(inputs: dict) -> bool:
    """Heuristic: does the query ask about a dose amount or frequency?

    Performs a case-insensitive substring match against a fixed keyword list.
    """
    text = inputs["query"].lower()
    for keyword in ("dose", "dosing", "mg", "frequency", "how much", "how often"):
        if keyword in text:
            return True
    return False

def is_definition_query(inputs: dict) -> bool:
    """Return True for short "what is ..." questions (at most 8 words).

    Surrounding whitespace is stripped first, so a query typed with a leading
    space (e.g. "  What is warfarin?") is still recognized.

    Args:
        inputs: Mapping with a ``"query"`` string.

    Returns:
        True when the query starts with "what is" and has 8 or fewer words.
    """
    q = inputs["query"].strip().lower()
    return q.startswith("what is") and len(q.split()) <= 8

# RunnableBranch takes ordered (condition, runnable) pairs followed by a
# default runnable. The first condition returning True wins, so the
# interaction and dosing checks take priority over the definition check.
router = RunnableBranch(
    *[
        (is_interaction_query, interaction_chain),
        (is_dosing_query, dosing_chain),
        (is_definition_query, definition_chain),
    ],
    general_chain,  # default when no condition matches
)

# Smoke-test the keyword router with a few sample clinical questions.
queries = [
    "What is warfarin?",
    "What is the dose of warfarin for AFib?",
    "Can I combine warfarin with aspirin?",
    "What monitoring is needed for a patient on warfarin?",
]

for q in queries:
    result = router.invoke({"query": q})
    # One print per Q/A pair; print's own newline after the trailing "\n"
    # leaves a blank line between pairs.
    print(f"Q: {q[:50]}\nA: {result[:80]}\n")

LLM-Based Classification Router

For more nuanced routing, use an LLM to classify the query:

Python
import json
from langchain_core.runnables import RunnableLambda

CATEGORIES = ["definition", "drug_interaction", "dosing", "safety", "general"]

classifier_prompt = ChatPromptTemplate.from_template(
    """Classify this clinical query into one category.

Categories:
- definition: asking what something is
- drug_interaction: asking about interactions between drugs  
- dosing: asking about doses, frequencies, routes of administration
- safety: asking about contraindications, adverse effects, warnings
- general: anything else

Query: {query}

Return JSON: {{"category": "one of the categories above"}}"""
)


def _extract_category(payload) -> str:
    """Return a normalized category label from the classifier's parsed JSON.

    LLM output is not guaranteed to follow the requested schema. This falls
    back to ``"general"`` when the JSON is not an object, lacks a
    ``"category"`` key, or names a label outside ``CATEGORIES``. Labels are
    compared case- and whitespace-insensitively.
    """
    if not isinstance(payload, dict):
        return "general"
    category = str(payload.get("category", "")).strip().lower()
    return category if category in CATEGORIES else "general"


# prompt -> JSON-mode model -> raw text -> parsed JSON -> validated label
classifier_chain = (
    classifier_prompt
    | ChatOpenAI(model="gpt-4o-mini", temperature=0, response_format={"type": "json_object"})
    | StrOutputParser()
    | RunnableLambda(json.loads)
    | RunnableLambda(_extract_category)
)


def route_by_category(inputs: dict) -> str:
    """Ask the LLM classifier which category ``inputs["query"]`` belongs to."""
    classifier_input = {"query": inputs["query"]}
    return classifier_chain.invoke(classifier_input)


# Safety-focused chain, answered by the stronger model.
safety_chain = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a clinical pharmacist. Emphasize safety warnings clearly."),
        ("human", "Safety question: {query}"),
    ]
).pipe(model_smart, parser)

# Route map: category label (as produced by the classifier) -> chain that answers it.
ROUTE_MAP = dict(
    definition=definition_chain,
    drug_interaction=interaction_chain,
    dosing=dosing_chain,
    safety=safety_chain,
    general=general_chain,
)


def llm_router(inputs: dict) -> str:
    """Classify the query with the LLM, then answer it with the matching chain.

    Any label missing from ``ROUTE_MAP`` falls back to ``general_chain``.
    """
    target = ROUTE_MAP.get(route_by_category(inputs), general_chain)
    return target.invoke(inputs)


# Wrap the imperative router so it composes like any other Runnable.
llm_routing_chain = RunnableLambda(llm_router)

result = llm_routing_chain.invoke(
    {"query": "What are the signs of warfarin toxicity?"}
)
print(result)  # Expected to be classified as "safety" and answered by safety_chain

RunnableBranch vs RunnableLambda for Routing

Two ways to express the same routing decision:

Python
# Approach 1: RunnableBranch (declarative).
# .get() (rather than x["complexity"]) makes an input without a "complexity"
# key fall through to default_chain instead of raising KeyError, exactly
# like Approach 2 below.
branch = RunnableBranch(
    (lambda x: x.get("complexity") == "simple", simple_chain),
    (lambda x: x.get("complexity") == "complex", complex_chain),
    default_chain,
)

# Approach 2: RunnableLambda with if-else (imperative)
def route_by_complexity(inputs: dict) -> str:
    """Invoke the chain matching ``inputs["complexity"]``; default if absent or unknown."""
    if inputs.get("complexity") == "simple":
        return simple_chain.invoke(inputs)
    elif inputs.get("complexity") == "complex":
        return complex_chain.invoke(inputs)
    return default_chain.invoke(inputs)

lambda_branch = RunnableLambda(route_by_complexity)

# Both route identically, including inputs with no "complexity" key.
# RunnableBranch: more readable for simple condition lists
# RunnableLambda: more flexible for complex logic, early returns, loops
# NOTE: simple_chain, complex_chain and default_chain are placeholders that
# are not defined in this snippet -- define them before running it.

Model Routing: Cheap vs Expensive

A cost-optimization pattern — route simple queries to cheaper models:

Python
def estimate_complexity(query: str) -> str:
    """Classify a query as "simple", "complex" or "standard" using keyword heuristics.

    Complex indicators are checked first. A short question such as
    "What is better: warfarin vs apixaban?" contains "what is" but is really a
    comparison, and should not be sent to the cheapest model.

    Args:
        query: Raw user question.

    Returns:
        "complex" if any comparison/reasoning indicator appears; otherwise
        "simple" for short (<= 8 words) definition-style questions; otherwise
        "standard".
    """
    text = query.lower()
    # Whitespace tokens with surrounding punctuation removed ("vs." -> "vs").
    words = [w.strip(".,;:!?()'\"") for w in text.split()]

    # Complex: requires reasoning, multi-drug, comparisons.
    # "vs" is matched as a whole word so tokens such as "IVs" or "CVS" don't
    # count as comparisons; the other indicators are distinctive phrases.
    complex_phrases = ["compare", "versus", "which is better", "step by step", "why does"]
    if "vs" in words or any(phrase in text for phrase in complex_phrases):
        return "complex"

    # Simple: short questions with basic definition keywords
    simple_markers = ["what is", "define", "what does"]
    if len(words) <= 8 and any(marker in text for marker in simple_markers):
        return "simple"

    return "standard"


def _bare_chain(llm):
    """Pipe the raw query (no system prompt) into *llm*, then into ``parser``."""
    return ChatPromptTemplate.from_template("{query}") | llm | parser


# One chain per cost tier; all of them send the query verbatim.
cheap_chain = _bare_chain(ChatOpenAI(model="gpt-4o-mini", temperature=0))
standard_chain = _bare_chain(ChatOpenAI(model="gpt-4o", temperature=0))
# NOTE(review): "premium" uses the same model as standard_chain. The only
# difference is max_tokens=2000, which puts an upper bound on answer length
# rather than making the model more capable. Swap in a stronger model here if
# a genuinely premium tier is intended.
premium_chain = _bare_chain(ChatOpenAI(model="gpt-4o", temperature=0, max_tokens=2000))

def _complexity_is(level: str):
    """Return a RunnableBranch condition testing whether the query's tier is *level*."""
    return lambda inputs: estimate_complexity(inputs["query"]) == level


# Route by heuristic tier:
#   simple -> cheap model
#   complex -> premium chain
#   anything else -> standard chain
cost_optimized_router = RunnableBranch(
    (_complexity_is("simple"), cheap_chain),
    (_complexity_is("complex"), premium_chain),
    standard_chain,
)

# Show which cost tier each sample query is routed to.
test_queries = [
    "What is aspirin?",                          # simple   → gpt-4o-mini
    "Compare warfarin vs rivaroxaban dosing",    # complex  → gpt-4o (premium)
    "What monitoring is needed for warfarin?",   # standard → gpt-4o
]

for q in test_queries:
    # Recompute the tier only for display; the router computes it again internally.
    complexity = estimate_complexity(q)
    result = cost_optimized_router.invoke({"query": q})
    print(f"[{complexity.upper()}] {q[:40]} → {result[:60]}")

Adding Metadata to Track Routing Decisions

Python
from langchain_core.runnables import RunnablePassthrough

def route_with_metadata(inputs: dict) -> dict:
    """Route a query the way ``router`` does and report which branch answered it.

    Branches are tried in the same order as ``router``:
    drug interaction, dosing, definition, then the general fallback.

    Args:
        inputs: Mapping with a ``"query"`` string.

    Returns:
        Dict with the original ``query``, the chain's ``answer``, the
        ``route_taken`` label and the ``model_used`` by that route.
    """
    query = inputs["query"]

    # (label, predicate, chain, model name). The model names mirror the
    # ChatOpenAI instance each chain was built with: definition_chain uses
    # model_fast (gpt-4o-mini); interaction/dosing/general use model_smart (gpt-4o).
    routes = [
        ("drug_interaction", is_interaction_query, interaction_chain, "gpt-4o"),
        ("dosing", is_dosing_query, dosing_chain, "gpt-4o"),
        ("definition", is_definition_query, definition_chain, "gpt-4o-mini"),
    ]
    category, chain, model = "general", general_chain, "gpt-4o"
    for label, predicate, candidate, model_name in routes:
        if predicate(inputs):
            category, chain, model = label, candidate, model_name
            break

    return {
        "query": query,
        "answer": chain.invoke(inputs),
        "route_taken": category,
        "model_used": model,
    }

tracked_router = RunnableLambda(route_with_metadata)

result = tracked_router.invoke({"query": "What dose of metformin for diabetes?"})
# Print the routing metadata, then a 100-character preview of the answer.
for label, value in (
    ("Route", result["route_taken"]),
    ("Model", result["model_used"]),
    ("Answer", result["answer"][:100]),
):
    print(f"{label}: {value}")

Fallback Routing

Python
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import JsonOutputParser

# Primary chain with fallback if it fails
primary_chain = interaction_chain.with_fallbacks([general_chain])

# Fallback for specific error types.
# The prompt shows a valid JSON example (quoted key and value); the doubled
# braces escape the literal braces in the template.
json_chain = (
    ChatPromptTemplate.from_template('Return JSON for {query}: {{"answer": "..."}}')
    | model_fast
    | JsonOutputParser()
)

# Fallback to the string parser only when JSON parsing fails.
# Note: the fallback returns a plain string while json_chain returns a parsed
# JSON value, so callers of robust_chain must handle both result types.
robust_chain = json_chain.with_fallbacks(
    [ChatPromptTemplate.from_template("{query}") | model_fast | parser],
    exceptions_to_handle=(OutputParserException,),
)