Learnixo

LangChain Mastery · Lesson 10 of 33

Interview: Design a Multi-Step LangChain Pipeline

Q1: Design a pipeline that takes a raw clinical note and extracts structured drug information, then flags interactions.

Answer:

This requires a sequential pipeline: extract first, then analyze.

Python
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.runnables import RunnablePassthrough
from pydantic import BaseModel, Field

# Chat model used by the chains below; temperature=0 keeps sampling randomness to a minimum.
model = ChatOpenAI(model="gpt-4o", temperature=0)

# Structured result of step 1: the fields the extraction chain is asked to
# pull out of a clinical note.
class ExtractedMedications(BaseModel):
    medications: list[str] = Field(description="List of medication names")
    conditions: list[str] = Field(description="Active diagnoses")
    # Kept as free text; defaults to "unknown" when the field is not supplied.
    patient_age: str = Field(default="unknown")

# Structured result of step 2: the interaction analysis produced from the
# extracted medications and conditions.
class InteractionAnalysis(BaseModel):
    # Untyped dicts: the per-pair layout is left to the model's output.
    pairs: list[dict] = Field(description="Drug pairs with interaction analysis")
    high_risk_interactions: list[str] = Field(description="Interactions requiring immediate attention")
    # Required field (no default), so the model must always supply it.
    recommendation: str

# Step 1: Extract medications
extraction_parser = PydanticOutputParser(pydantic_object=ExtractedMedications)

# Build the prompt first, pre-filling the parser's format instructions so the
# only variable left for callers to supply is {note}.
_extraction_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "Extract medications and conditions. {format_instructions}"),
        ("human", "Clinical note:\n{note}"),
    ]
).partial(format_instructions=extraction_parser.get_format_instructions())

# prompt -> model -> parsed ExtractedMedications
extraction_chain = _extraction_prompt | model | extraction_parser

# Step 2: Analyze interactions
interaction_parser = PydanticOutputParser(pydantic_object=InteractionAnalysis)

# The prompt takes {medications} and {conditions}; the format instructions are
# bound up front from the parser.
_interaction_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "Analyze drug interactions. {format_instructions}"),
        ("human", "Medications: {medications}\nPatient conditions: {conditions}"),
    ]
).partial(format_instructions=interaction_parser.get_format_instructions())

# prompt -> model -> parsed InteractionAnalysis
interaction_chain = _interaction_prompt | model | interaction_parser

# Full pipeline


def _interaction_inputs(state: dict) -> dict:
    """Map the step-1 result onto the variables interaction_chain's prompt expects.

    Args:
        state: Pipeline state holding the parsed ExtractedMedications under
            the "extracted" key.

    Returns:
        A dict with comma-joined "medications" and "conditions" strings.
    """
    extracted = state["extracted"]
    return {
        "medications": ", ".join(extracted.medications),
        "conditions": ", ".join(extracted.conditions),
    }


clinical_pipeline = (
    # Step 1 adds "extracted" next to the original "note".
    RunnablePassthrough.assign(extracted=extraction_chain)
    # Step 2 is composed with `|` rather than calling interaction_chain.invoke()
    # inside a lambda. That keeps it a real node of the runnable graph, so an
    # explicitly passed config as well as batch/async/stream execution reach
    # the interaction step too. The plain function is coerced to a runnable
    # by `|`.
    | RunnablePassthrough.assign(
        interactions=_interaction_inputs | interaction_chain
    )
)

result = clinical_pipeline.invoke({
    "note": "Patient is a 72-year-old with AFib and T2DM. Current medications: warfarin 5mg daily, metformin 1000mg BID, aspirin 81mg daily."
})
print(result["extracted"].medications)
print(result["interactions"].high_risk_interactions)

Key design decisions:

  • Sequential because step 2 depends on step 1's output
  • Pydantic parsers for structured, typed output
  • Pass all context forward with RunnablePassthrough.assign, so the original input stays available alongside each step's output

Q2: How would you build a chain that routes queries to either a drug database or a patient counseling specialist?

Answer:

Python
import json

from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
from langchain_core.runnables import RunnableBranch, RunnableLambda

parser = StrOutputParser()
classifier_model = ChatOpenAI(model="gpt-4o-mini", temperature=0)
specialist_model = ChatOpenAI(model="gpt-4o", temperature=0)

# Classify query type
classifier_prompt = ChatPromptTemplate.from_template(
    """Classify query: is it a factual drug lookup or patient counseling?
Query: {query}
Return JSON: {{"type": "drug_lookup" or "patient_counseling"}}"""
)

# JsonOutputParser is used instead of StrOutputParser + json.loads because it
# also accepts replies wrapped in a markdown ```json fence, which a bare
# json.loads rejects. The final step reduces the parsed object to the
# category string that the routers compare against.
classifier = (
    classifier_prompt
    | classifier_model
    | JsonOutputParser()
    | RunnableLambda(lambda d: d["type"])
)

# Specialized chains
def _specialist_chain(system_message: str):
    """Build a system + human({query}) prompt answered by specialist_model as plain text."""
    prompt = ChatPromptTemplate.from_messages(
        [("system", system_message), ("human", "{query}")]
    )
    return prompt | specialist_model | parser


drug_lookup_chain = _specialist_chain(
    "You are a clinical reference. Provide factual drug information."
)

counseling_chain = _specialist_chain(
    "You are a clinical pharmacist. Counsel patients clearly and empathetically. Use plain language."
)

# Route based on classification
def route(inputs: dict) -> str:
    """Classify inputs["query"] and answer it with the matching specialist chain.

    Only the exact label "drug_lookup" selects the drug lookup chain; every
    other classifier output goes to the counseling chain.
    """
    label = classifier.invoke({"query": inputs["query"]})
    if label == "drug_lookup":
        return drug_lookup_chain.invoke(inputs)
    return counseling_chain.invoke(inputs)

routing_chain = RunnableLambda(route)

Interview follow-up: "What if you have 10 categories, not 2?"

Use a dict-based router:

Python
# Every chain in CHAIN_MAP is invoked with the same {"query": ...} payload and
# must return a string. The structured `interaction_chain` from the extraction
# pipeline does not fit: its prompt needs {medications} and {conditions}, and
# it returns a Pydantic object. Interaction questions therefore get their own
# query-based specialist.
interaction_check_chain = (
    ChatPromptTemplate.from_messages([
        ("system", "You are a clinical pharmacist. Identify and explain clinically relevant drug interactions."),
        ("human", "{query}"),
    ])
    | specialist_model | parser
)

CHAIN_MAP = {
    "drug_lookup": drug_lookup_chain,
    "patient_counseling": counseling_chain,
    "interaction_check": interaction_check_chain,
    # ... more categories
}

def dict_router(inputs: dict) -> str:
    """Classify inputs["query"] and dispatch to the chain registered for that category.

    Categories missing from CHAIN_MAP fall back to drug_lookup_chain.
    NOTE: the classifier prompt has to list every category key used here,
    otherwise the extra entries are never selected.
    """
    category = classifier.invoke({"query": inputs["query"]})
    return CHAIN_MAP.get(category, drug_lookup_chain).invoke(inputs)

Q3: Design a chain with a fallback — if the primary chain fails or returns low-confidence output, try a backup chain.

Answer:

Python
from langchain_core.runnables import RunnableParallel, RunnableLambda
from langchain_core.output_parsers import JsonOutputParser

# Approach 1: LangChain's built-in .with_fallbacks()
_primary_prompt = ChatPromptTemplate.from_template("Answer precisely: {query}")
_primary_model = ChatOpenAI(model="gpt-4o", temperature=0)
primary_chain = _primary_prompt | _primary_model | parser

_backup_prompt = ChatPromptTemplate.from_template("Answer simply: {query}")
_backup_model = ChatOpenAI(model="gpt-4o-mini", temperature=0)
backup_chain = _backup_prompt | _backup_model | parser

# The primary chain runs first; the backup is only tried if it raises.
chain_with_fallback = primary_chain.with_fallbacks([backup_chain])

# Approach 2: Confidence-based fallback (semantic check)
# Schema the model fills in so the caller can decide whether to trust the answer.
class ConfidentAnswer(BaseModel):
    answer: str
    # Validated to lie in [0.0, 1.0] by the ge/le constraints.
    confidence: float = Field(ge=0.0, le=1.0)
    # When True the answer is not returned directly, whatever the confidence
    # (see confident_or_escalate).
    requires_more_info: bool

conf_parser = PydanticOutputParser(pydantic_object=ConfidentAnswer)

# The prompt asks for an answer plus self-reported confidence; the parser's
# format instructions are bound ahead of time, leaving {query} for the caller.
_confident_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "Answer with a confidence score. {format_instructions}"),
        ("human", "{query}"),
    ]
).partial(format_instructions=conf_parser.get_format_instructions())

# prompt -> gpt-4o -> parsed ConfidentAnswer
confident_chain = (
    _confident_prompt
    | ChatOpenAI(model="gpt-4o", temperature=0)
    | conf_parser
)

# Escalation path, built once at import time instead of on every
# low-confidence call (the original rebuilt the prompt and a new ChatOpenAI
# client per call).
# NOTE: this uses the same gpt-4o model with a specialist system prompt and a
# 2000-token budget. No retrieval step is wired in here; add one in front of
# this chain if grounding documents are available.
_escalation_chain = (
    ChatPromptTemplate.from_messages([
        ("system", "You specialize in clinical pharmacology. Be thorough."),
        ("human", "Answer in detail: {query}"),
    ])
    | ChatOpenAI(model="gpt-4o", temperature=0, max_tokens=2000)
    | parser
)


def confident_or_escalate(inputs: dict) -> str:
    """Return the first-pass answer if it is trusted, otherwise escalate.

    Args:
        inputs: Mapping with a "query" key, passed unchanged to both chains.

    Returns:
        The first-pass answer when its confidence is at least 0.8 and it does
        not ask for more information; otherwise the escalation chain's text.
    """
    result = confident_chain.invoke(inputs)

    if result.confidence >= 0.8 and not result.requires_more_info:
        return result.answer

    # Low confidence or missing information: re-ask with the detailed prompt.
    return _escalation_chain.invoke(inputs)

adaptive_chain = RunnableLambda(confident_or_escalate)

Q4: How would you implement a chain that runs in parallel stages, then aggregates results?

Answer:

Python
from langchain_core.runnables import RunnableParallel, RunnablePassthrough

# Parallel analysis → synthesis pipeline

# Example drug analysed in this snippet.
drug = "warfarin"

# Define independent analytical chains
# NOTE(review): DrugAnalysis is not referenced by the chains visible in this
# snippet; they all end in the plain-string `parser`. Presumably it is meant
# for a structured-output variant; confirm or remove.
class DrugAnalysis(BaseModel):
    findings: list[str]
    risk_level: str  # low, moderate, high

# Four parallel analyses
def _analysis_chain(template: str):
    """Wrap a single {drug} prompt template as prompt -> model -> text."""
    return ChatPromptTemplate.from_template(template) | model | parser


pharmacokinetics_chain = _analysis_chain(
    "Analyze {drug} pharmacokinetics: absorption, distribution, metabolism, excretion."
)
pharmacodynamics_chain = _analysis_chain(
    "Analyze {drug} pharmacodynamics: receptor targets, mechanisms, effects."
)
safety_chain = _analysis_chain(
    "Analyze {drug} safety: adverse effects, contraindications, monitoring."
)
interactions_chain = _analysis_chain(
    "Analyze {drug} major drug interactions and mechanisms."
)

# Synthesis prompt
# Expects {drug} plus the four analysis outputs: {pk}, {pd}, {safety}, {interactions}.
_SYNTHESIS_TEMPLATE = """Synthesize this {drug} analysis into a 3-bullet clinical summary:

Pharmacokinetics: {pk}
Pharmacodynamics: {pd}
Safety: {safety}
Interactions: {interactions}

Focus on what prescribers need to know."""

synthesis_prompt = ChatPromptTemplate.from_template(_SYNTHESIS_TEMPLATE)

# Full pipeline: parallel analysis → sequential synthesis
#
# RunnablePassthrough.assign runs the four analyses concurrently and merges
# their outputs into the *incoming* dict, so the "drug" key the caller passed
# is what reaches the synthesis prompt. The previous version re-injected the
# module-level `drug` constant, so invoking the pipeline with any other drug
# produced a synthesis labelled with the wrong name.
drug_analysis_pipeline = (
    RunnablePassthrough.assign(
        pk=pharmacokinetics_chain,
        pd=pharmacodynamics_chain,
        safety=safety_chain,
        interactions=interactions_chain,
    )
    | synthesis_prompt
    | model | parser
)

result = drug_analysis_pipeline.invoke({"drug": drug})
print(result)

Q5: A chain is taking 8 seconds for a complex multi-step query. How do you optimize it?

Answer:

Diagnosis first:

Python
import time

def profile_chain(chain, inputs):
    """Time each step to find the bottleneck.

    The previous version defined a callback class but never attached it to the
    chain, so the timing list was always empty. This version runs the
    top-level steps one after another and times each call directly.

    Args:
        chain: An object with ``invoke``. If it exposes a non-empty ``steps``
            sequence (as a RunnableSequence does), each step is timed
            separately; otherwise the whole chain is timed as one step.
        inputs: Value passed to the first step.

    Returns:
        ``(result, with_timing)``: the final output and a list of elapsed
        milliseconds, one entry per top-level step in execution order. Work
        nested inside a single step (e.g. a parallel block) is reported as
        that step's total.
    """
    steps = getattr(chain, "steps", None) or [chain]
    with_timing = []

    result = inputs
    for step in steps:
        # perf_counter is monotonic, unlike time.time(), so intervals stay
        # valid across clock adjustments.
        started = time.perf_counter()
        result = step.invoke(result)
        with_timing.append((time.perf_counter() - started) * 1000)

    return result, with_timing

Optimization strategies:

Python
# Strategy 1: Parallelize independent steps
# Before: sequential (4s + 3s + 1s = 8s)
slow_chain = (
    RunnablePassthrough.assign(classification=classify_chain)   # 4s
    | RunnablePassthrough.assign(safety=safety_chain)           # 3s (independent!)
    | RunnablePassthrough.assign(summary=summary_chain)         # 1s
)

# After: parallelize independent steps (max(4,3) + 1 = 5s)
# A single assign() call already runs all of its keyword runnables
# concurrently, so the two independent steps just share one assign().
# The earlier `**RunnableParallel(...).steps` wrapper is dropped: it was
# unnecessary, and RunnableParallel has no `steps` attribute to unpack.
fast_chain = (
    RunnablePassthrough.assign(
        classification=classify_chain,  # 4s ─┐
        safety=safety_chain,            # 3s ─┘ run simultaneously
    )
    | RunnablePassthrough.assign(summary=summary_chain)  # 1s after
)

# Strategy 2: Use cheaper model for classification/preprocessing
model_smart, model_cheap = (
    ChatOpenAI(model="gpt-4o"),
    ChatOpenAI(model="gpt-4o-mini"),
)

# Classification doesn't need the smartest model, so the one-word classifier
# runs on model_cheap instead of model_smart to cut cost and latency.
_classify_prompt = ChatPromptTemplate.from_template("Classify: {query}. One word.")
classify_chain_fast = _classify_prompt | model_cheap | parser

# Strategy 3: Cache repeated sub-chains
from functools import lru_cache

@lru_cache(maxsize=1000)
def cached_classify(query: str) -> str:
    """Classify `query`, reusing the stored label for an identical query string.

    At most 1000 distinct queries are kept in the cache.
    """
    payload = {"query": query}
    label = classify_chain_fast.invoke(payload)
    return label

# Strategy 4: Reduce context size
# Rather than passing every retrieved doc to the LLM, keep only the first one.
def top_doc_only(docs):
    """Return the "content" of the first doc, or "" when `docs` is empty or falsy."""
    if not docs:
        return ""
    best = docs[0]
    return best["content"]

# Strategy 5: Stream — doesn't reduce total time, but perceived latency drops
def stream_chain(chain, inputs):
    """Yield output chunks from `chain` as soon as they are produced.

    The original loop used `yield` at module level, which is a SyntaxError,
    and read an undefined `inputs`. Wrapping it in a generator function makes
    it runnable and reusable, e.g. ``stream_chain(slow_chain, inputs)``.

    Args:
        chain: Any object exposing ``stream(inputs)`` that returns an iterable.
        inputs: Payload forwarded unchanged to ``chain.stream``.

    Yields:
        Each chunk in the order the chain emits it, so the consumer can show
        the first tokens before the full response has finished.
    """
    yield from chain.stream(inputs)

Results: For a typical 8s chain: parallelizing 2 independent 3-4s steps + cheap classifier reduces to ~3s. Add streaming and perceived latency is 1-2s.