Prompt Chaining: Decomposing Complex Tasks
Break complex tasks into sequential prompts where each output feeds the next. Build reliable pipelines with validation, branching, and error recovery between steps.
Why Chain Prompts?
A single prompt that asks a model to do many things at once often produces lower-quality output than the same task broken into focused steps. Each step can be:
- Optimized independently
- Validated before proceeding
- Retried in isolation if it fails
- Tested in isolation during development
Think of prompt chaining like a function call pipeline: each function has one job and a clear interface.
Basic Linear Chain
from openai import OpenAI
from typing import Any, Callable
client = OpenAI()


def llm_call(prompt: str, system: str = "", temperature: float = 0, model: str = "gpt-4o") -> str:
    """Send one chat-completion request and return the assistant's reply text.

    Args:
        prompt: Content of the user message.
        system: Optional system message; left out of the request when empty.
        temperature: Sampling temperature passed straight to the API.
        model: Chat model name (defaults to the model used throughout this guide).

    Returns:
        The assistant message text, or "" when the response carries no text
        content (the SDK types ``message.content`` as optional).
    """
    messages = [{"role": "system", "content": system}] if system else []
    messages.append({"role": "user", "content": prompt})
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=temperature,
    )
    # Normalise a missing content field to "" so every chain step can safely
    # call string methods (.strip(), .split(), .upper()) on the result.
    return response.choices[0].message.content or ""
# Chain: Extract → Classify → Summarize → Format
def process_clinical_note(raw_note: str) -> dict:
    """Multi-step processing of a clinical note: extract -> assess -> recommend -> format.

    Each step is a separate ``llm_call`` whose text output is interpolated into
    the prompts of later steps. A progress line is printed after each step.

    Args:
        raw_note: Free-text clinical note.

    Returns:
        dict with the raw text of each step under ``extracted_data``,
        ``risk_assessment``, ``recommendations`` and ``formatted_output``.
    """
    # Step 1: Extract key facts
    extracted = llm_call(
        prompt=f"Extract all medications, diagnoses, and lab values from this clinical note. Return a structured list.\n\n{raw_note}",
        system="You are a clinical data extraction specialist. Be precise and complete.",
    )
    print("Step 1 complete: extraction")

    # Step 2: Classify risk
    risk_assessment = llm_call(
        prompt=f"Based on these extracted clinical facts, assess the patient's medication risk level (low/moderate/high) and explain why:\n\n{extracted}",
        system="You are a clinical pharmacist. Focus on drug interactions, dosing safety, and contraindications.",
    )
    print("Step 2 complete: risk assessment")

    # Step 3: Generate recommendations
    recommendations = llm_call(
        prompt=f"""Clinical facts: {extracted}
Risk assessment: {risk_assessment}
Generate specific, prioritized recommendations for the clinical team. Number them.""",
        system="You are a clinical pharmacist generating action items. Be specific and actionable.",
    )
    print("Step 3 complete: recommendations")

    # Step 4: Format for EHR output. The extracted facts are included so the
    # SUMMARY and MONITORING PLAN sections are grounded in the note's actual
    # data instead of being reconstructed from the downstream prose alone.
    formatted = llm_call(
        prompt=f"""Format the following clinical pharmacy review for entry into an EHR system.
Use these exact sections: SUMMARY | RISK LEVEL | RECOMMENDATIONS | MONITORING PLAN
Clinical Facts: {extracted}
Risk Assessment: {risk_assessment}
Recommendations: {recommendations}""",
        system="Format clinical documentation. Be concise and professional.",
    )
    print("Step 4 complete: formatting")

    return {
        "extracted_data": extracted,
        "risk_assessment": risk_assessment,
        "recommendations": recommendations,
        "formatted_output": formatted,
    }
# Run
note = """
Patient: 68F with AFib and T2DM
Medications: Warfarin 5mg daily, Metformin 1000mg BID, Clarithromycin 500mg BID (new, started 2 days ago)
INR today: 3.8 (was 2.4 two weeks ago, target 2.0-3.0)
Labs: SCr 1.2 mg/dL, eGFR 52
"""
result = process_clinical_note(note)
print("\n=== FINAL OUTPUT ===")
print(result["formatted_output"])Chain with Validation Gates
Add validation steps that catch errors before they propagate:
from pydantic import BaseModel, ValidationError
import json
class ExtractedMedications(BaseModel):
    """Schema for the JSON object that ``extract_with_validation`` asks the model to return."""

    medications: list[str]  # medication names with doses, one string per medication
    diagnoses: list[str]  # diagnosis strings
    abnormal_labs: list[str]  # abnormal lab findings, including their values
    concerns_identified: bool  # model's flag; chain_with_gates stops after step 1 when False
def extract_with_validation(note: str, max_attempts: int = 3) -> ExtractedMedications:
    """Extract and validate structured data from a clinical note.

    Asks the model for JSON matching ``ExtractedMedications``, then parses and
    validates the reply. On failure, the full original instructions are re-sent
    together with the error, so every retry still knows the required schema.

    Args:
        note: Free-text clinical note.
        max_attempts: Total number of model calls before giving up; must be >= 1.

    Returns:
        The validated ``ExtractedMedications`` instance.

    Raises:
        ValueError: if ``max_attempts`` < 1, or if no attempt yields valid data
            (the last parse/validation error is chained as ``__cause__``).
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")

    base_prompt = f"""Extract from this clinical note. Return valid JSON matching this schema:
{{
"medications": ["list of medication names with doses"],
"diagnoses": ["list of diagnoses"],
"abnormal_labs": ["list of abnormal lab findings with values"],
"concerns_identified": true/false
}}
Return ONLY the JSON, no other text.
Note: {note}"""

    prompt = base_prompt
    last_error = None
    for _ in range(max_attempts):
        raw = llm_call(prompt)
        # The schema is a JSON object, so keep only the outermost {...} span. This
        # drops Markdown fences (```json ... ```) and any stray text around them.
        # (str.lstrip("```json") would strip a *character set*, not a prefix.)
        start, end = raw.find("{"), raw.rfind("}")
        clean = raw[start:end + 1] if start != -1 and end > start else raw.strip()
        try:
            data = json.loads(clean)
            return ExtractedMedications(**data)
        except (json.JSONDecodeError, ValidationError, TypeError) as e:
            # TypeError: the reply parsed to a non-object (e.g. a list), so ** fails.
            last_error = e
            # Re-send the complete instructions; repeating only the note would
            # lose the schema the model is supposed to follow.
            prompt = (
                f"Your previous response was invalid. Error: {e}\n"
                "Please try again with valid JSON, following these instructions exactly:\n\n"
                f"{base_prompt}"
            )
    raise ValueError(
        f"Failed to extract valid data after {max_attempts} attempts: {last_error}"
    ) from last_error
def chain_with_gates(note: str) -> dict:
"""Chain with explicit validation between steps."""
# Step 1 with validation
extracted = extract_with_validation(note)
print(f"Validated: {len(extracted.medications)} medications, {len(extracted.diagnoses)} diagnoses")
if not extracted.concerns_identified:
return {"status": "no_concerns", "extracted": extracted}
# Step 2: Only run if concerns were found
risk_prompt = f"""Evaluate these medications for interaction risks:
Medications: {extracted.medications}
Diagnoses: {extracted.diagnoses}
Abnormal labs: {extracted.abnormal_labs}
Return JSON:
{{"risk_level": "low|moderate|high", "primary_concern": "string", "interactions": ["list of interactions"]}}"""
risk_raw = llm_call(risk_prompt)
risk = json.loads(risk_raw.strip().lstrip("```json").rstrip("```").strip())
return {"status": "concerns_found", "extracted": extracted, "risk": risk}Conditional Branching
Route each input to a chain specialized for its case:
def adaptive_chain(question: str) -> str:
    """Route to different processing chains based on question type.

    One ``llm_call`` classifies the question. The question is then handed to
    the chain registered for that label, falling back to ``general_chain``.

    Args:
        question: The user's pharmacology question.

    Returns:
        The text answer produced by the selected chain.
    """
    # Step 1: Classify the question type
    raw_label = llm_call(
        f"""Classify this pharmacology question into exactly one category:
- INTERACTION (about drug-drug or drug-food interactions)
- DOSING (about dose selection, adjustment, or titration)
- MECHANISM (about how a drug works)
- GENERAL (other pharmacology questions)
Question: {question}
Respond with only the category name."""
    )

    # Step 2: Route to specialized chain
    routes = {
        "INTERACTION": interaction_chain,
        "DOSING": dosing_chain,
        "MECHANISM": mechanism_chain,
        "GENERAL": general_chain,
    }
    # Models often decorate a one-word answer ("INTERACTION.", "**DOSING**",
    # "Category: MECHANISM"). Split the reply into alphabetic words and take the
    # first known label instead of requiring an exact match.
    words = "".join(ch if ch.isalpha() else " " for ch in raw_label.upper()).split()
    classification = next((word for word in words if word in routes), "GENERAL")
    print(f"Question classified as: {classification}")
    return routes[classification](question)
def interaction_chain(question: str) -> str:
    """Answer a drug-interaction question in four calls: drugs -> interactions -> management -> combined answer."""
    drug_names = llm_call(f"Extract only the drug names from this question: {question}")
    interaction_text = llm_call(
        f"What are the clinical interactions between {drug_names}? Focus on mechanism and severity."
    )
    management_text = llm_call(
        f"For these interactions:\n{interaction_text}\n\nWhat is the recommended clinical management?"
    )
    final_prompt = (
        "Combine into a clear clinical response:\n"
        f"Interactions: {interaction_text}\n"
        f"Management: {management_text}"
    )
    return llm_call(final_prompt)
def dosing_chain(question: str) -> str:
    """Answer a dosing question: pull patient parameters and a standard dose, then ask for an adjusted dose."""
    patient_params = llm_call(f"Extract all patient parameters relevant to dosing from: {question}")
    baseline_dose = llm_call(f"What is the standard dose for the drug in this question: {question}")
    return llm_call(
        f"Given patient parameters:\n{patient_params}\n\nAdjust the standard dose:\n{baseline_dose}"
    )
def mechanism_chain(question: str) -> str:
    """Explain the pharmacological mechanism behind *question* with a single temperature-0 call."""
    mechanism_prompt = f"Explain the pharmacological mechanism for: {question}"
    return llm_call(mechanism_prompt, temperature=0)
def general_chain(question: str) -> str:
    """Answer *question* directly under a clinical-pharmacology system prompt at temperature 0."""
    expert_persona = "You are a clinical pharmacology expert."
    return llm_call(question, system=expert_persona, temperature=0)
# Example
print(adaptive_chain("How does clarithromycin affect warfarin levels?"))Map-Reduce Chains
Process multiple items independently (in parallel where possible), then combine the results:
import asyncio
async def process_one_medication(
    medication: str,
    patient_context: str,
    async_client=None,
) -> dict:
    """Async processing of a single medication.

    Intended for concurrent fan-out (e.g. with ``asyncio.gather``) during the
    MAP phase of a map-reduce review.

    Args:
        medication: Medication string, e.g. "warfarin 5mg daily".
        patient_context: Free-text patient summary included in the prompt.
        async_client: Optional ``openai.AsyncOpenAI`` instance to reuse across
            calls. When omitted, a client is created for this call and closed
            afterwards.

    Returns:
        ``{"medication": medication, "analysis": <parsed JSON object>}``.

    Raises:
        json.JSONDecodeError: if the reply does not contain valid JSON.
    """
    # The module-level ``client`` is synchronous; awaiting its return value is a
    # TypeError, so this coroutine needs an async client.
    from openai import AsyncOpenAI

    prompt = f"""For the medication: {medication}
Patient context: {patient_context}
Provide:
1. Appropriate dose given the patient context
2. Key monitoring parameters
3. Any safety concerns
Return as JSON: {{"dose": "...", "monitoring": ["..."], "concerns": ["..."]}}"""

    owns_client = async_client is None
    api = AsyncOpenAI() if owns_client else async_client
    try:
        response = await api.chat.completions.create(
            model="gpt-4o",
            messages=[{"role": "user", "content": prompt}],
            temperature=0,
        )
    finally:
        # Only close clients created here; a caller-supplied client stays open.
        if owns_client:
            await api.close()

    raw = response.choices[0].message.content or ""
    # The expected reply is a JSON object: keep the outermost {...} span so
    # Markdown fences or surrounding text don't break json.loads.
    start, end = raw.find("{"), raw.rfind("}")
    payload = raw[start:end + 1] if start != -1 and end > start else raw.strip()
    return {"medication": medication, "analysis": json.loads(payload)}
def map_reduce_medication_review(
    medications: list[str],
    patient_context: str,
    max_workers: int = 4,
) -> str:
    """MAP: analyze each medication → REDUCE: synthesize into a combined review.

    The MAP calls are independent of each other, so they are issued concurrently
    from a thread pool. Results keep the input order. The REDUCE step is a single
    call over all per-medication analyses.

    Args:
        medications: Medication strings to review individually.
        patient_context: Free-text patient summary shared by every MAP call.
        max_workers: Maximum number of concurrent MAP calls (values < 1 are
            treated as 1, which runs them sequentially).

    Returns:
        The synthesized review text returned by the REDUCE call.
    """
    from concurrent.futures import ThreadPoolExecutor

    def review_one(med: str) -> dict:
        # One MAP call; the reply text is kept as-is for the REDUCE prompt.
        review = llm_call(
            f"""Medication: {med}
Patient context: {patient_context}
Provide brief analysis: dose appropriateness, monitoring, concerns.
Return JSON: {{"dose_ok": bool, "monitoring": ["params"], "concerns": ["issues"]}}"""
        )
        return {"medication": med, "review": review}

    # MAP phase: the calls are I/O-bound, so threads overlap the network waits.
    # Executor.map yields results in input order.
    with ThreadPoolExecutor(max_workers=max(1, max_workers)) as pool:
        individual_reviews = list(pool.map(review_one, medications))

    # REDUCE phase: synthesize all reviews
    combined = "\n\n".join(
        f"Medication: {r['medication']}\nAnalysis: {r['review']}"
        for r in individual_reviews
    )
    synthesis = llm_call(
        f"""Multiple medication reviews:
{combined}
Synthesize into a single pharmacy review that:
1. Highlights cross-medication interactions
2. Prioritizes the top 3 concerns
3. Provides an overall risk assessment
4. Lists recommended monitoring"""
    )
    return synthesis
result = map_reduce_medication_review(
medications=["warfarin 5mg daily", "clarithromycin 500mg BID", "metformin 1000mg BID"],
patient_context="68-year-old female, eGFR 52, AFib and T2DM, INR 3.8",
)
print(result)Chain Observability
Log each step for debugging and monitoring:
import time
from dataclasses import dataclass
@dataclass
class ChainStep:
    """Record of one executed chain step, as appended by ``ObservableChain.run_step``."""

    step_name: str  # label supplied by the caller of run_step
    input_summary: str  # shortened prompt text (about the first 100 characters)
    output: str  # full model reply for this step
    duration_seconds: float  # elapsed time of the LLM call, in seconds
    token_count: int  # rough estimate: whitespace-separated words of prompt + output
class ObservableChain:
    """Run LLM steps one at a time and keep a per-step record for debugging.

    Each ``run_step`` call sends its prompt through ``llm_call``, times the call,
    and appends a ``ChainStep`` to ``self.steps``. ``summary`` aggregates those
    records.
    """

    def __init__(self, chain_name: str):
        self.chain_name = chain_name
        self.steps: list[ChainStep] = []

    @staticmethod
    def _summarize_input(prompt: str, limit: int = 100) -> str:
        """Return *prompt* cut to *limit* characters, appending "..." only if it was cut."""
        if len(prompt) <= limit:
            return prompt
        return prompt[:limit] + "..."

    def run_step(self, step_name: str, prompt: str, system: str = "") -> str:
        """Execute one step, record it, and return the model's output text.

        Args:
            step_name: Label stored on the recorded step.
            prompt: User prompt sent to ``llm_call``.
            system: Optional system prompt passed through to ``llm_call``.

        Returns:
            The output string returned by ``llm_call``.
        """
        # perf_counter is monotonic. time.time() follows the wall clock, which can
        # jump (NTP/manual changes) and yield wrong or even negative durations.
        start = time.perf_counter()
        output = llm_call(prompt, system)
        duration = time.perf_counter() - start
        # Rough estimate: whitespace-separated words, not tokenizer tokens.
        token_count = len(prompt.split()) + len(output.split())
        self.steps.append(
            ChainStep(
                step_name=step_name,
                input_summary=self._summarize_input(prompt),
                output=output,
                duration_seconds=duration,
                token_count=token_count,
            )
        )
        return output

    def summary(self) -> dict:
        """Return step count, total seconds, total estimated tokens, and step names in run order."""
        return {
            "chain": self.chain_name,
            "total_steps": len(self.steps),
            "total_duration": sum(s.duration_seconds for s in self.steps),
            "total_tokens": sum(s.token_count for s in self.steps),
            "steps": [s.step_name for s in self.steps],
        }
# Usage
chain = ObservableChain("clinical_note_processing")
extracted = chain.run_step("extraction", f"Extract medications from: {note}")
risk = chain.run_step("risk_assessment", f"Assess risks for: {extracted}")
print(chain.summary())When to Chain vs Single Prompt
| Situation | Use single prompt | Use chain |
|---|---|---|
| Simple factual Q&A | Yes | No |
| Complex multi-step task | No | Yes |
| Need intermediate validation | No | Yes |
| Different experts for different steps | No | Yes |
| Task under 500 output tokens | Yes | No |
| High-stakes task where errors must be caught | No | Yes |
Chains add latency, because each step is a separate LLM call that waits on the previous one. If a single well-crafted prompt with structured output can do the job, prefer it over a chain.
Found this helpful?
Leave a comment
Have a question, correction, or just found this helpful? Leave a note below.