LangChain Mastery · Lesson 21 of 33
Defining Custom Tools with @tool
What Makes a Good Tool?
A LangChain tool has three components:
- Name — what the model calls to invoke it
- Description — how the model decides when to use it (critical)
- Function — the actual Python code that runs
The description is the most important of the three: it is what the LLM reads to decide whether this tool is relevant to the current step. A vague description leads to the wrong tool being selected.
Method 1: @tool Decorator (Simplest)
Python
from langchain_core.tools import tool
from typing import Optional
@tool
def get_drug_information(drug_name: str) -> str:
    """
    Retrieve clinical information about a medication from the pharmacy database.
    Use this tool when you need factual information about a drug including:
    - Mechanism of action
    - Standard dosing
    - Common indications
    - Drug class
    Do NOT use for drug interaction checks — use check_drug_interaction instead.
    """
    # The docstring above is what the tool exposes as its description, so its
    # wording is left untouched; implementation notes live in comments instead.

    # Stand-in for a real pharmacy database (in production: query it here).
    formulary = {
        "warfarin": {
            "class": "Anticoagulant",
            "mechanism": "Inhibits vitamin K epoxide reductase (VKORC1)",
            "typical_dose": "2-10mg daily, INR-guided",
            "indication": "Atrial fibrillation, DVT, PE, prosthetic heart valves",
        },
        "metformin": {
            "class": "Biguanide antidiabetic",
            "mechanism": "Activates AMPK, reduces hepatic glucose output",
            "typical_dose": "500-2550mg/day in divided doses",
            "indication": "Type 2 diabetes mellitus",
        },
    }

    # Lookup is case-insensitive and ignores surrounding whitespace.
    lookup_key = drug_name.lower().strip()
    record = formulary.get(lookup_key)

    if record:
        # Header echoes the name exactly as supplied, then one "field: value"
        # line per entry in the record.
        report_lines = [f"Drug: {drug_name}"]
        report_lines.extend(f"{field}: {value}" for field, value in record.items())
        return "\n".join(report_lines)

    return f"Drug '{drug_name}' not found. Please check the spelling or use a generic name."
# Tool metadata: the decorated object carries a name, a description and an
# argument schema alongside the wrapped function.
print(get_drug_information.name) # "get_drug_information"
print(get_drug_information.description[:100]) # First 100 chars of docstring
print(get_drug_information.args) # {"drug_name": {"title": "Drug Name", "type": "string"}}
# Call the tool directly: arguments are passed as a dict keyed by parameter name.
result = get_drug_information.invoke({"drug_name": "warfarin"})
print(result)
Multi-Argument Tools
Python
@tool
def check_drug_interaction(
    drug_a: str,
    drug_b: str,
    severity_filter: Optional[str] = None,
) -> str:
    """
    Check for pharmacological interactions between two drugs.
    Returns interaction severity (Major/Moderate/Minor/None), mechanism, and clinical recommendation.
    Args:
        drug_a: First drug name (generic preferred)
        drug_b: Second drug name (generic preferred)
        severity_filter: Optional filter for minimum severity (e.g., "Major")
    Use this tool specifically when asked about drug-drug interactions, combination therapy safety,
    or whether two medications can be co-administered.
    """
    # Simplified interaction database.
    # Keys are frozensets so the lookup does not depend on argument order.
    # (The previous tuple keys were never matched: the lookup key was sorted
    # alphabetically while the table keys were not, so every query fell
    # through to the "no documented interaction" message.)
    interactions = {
        frozenset({"warfarin", "aspirin"}): {
            "severity": "Major",
            "mechanism": "Both inhibit platelet aggregation and increase bleeding time",
            "recommendation": "Avoid unless benefit outweighs substantial bleeding risk. If unavoidable, use lowest aspirin dose (81mg) and monitor closely.",
        },
        frozenset({"warfarin", "amiodarone"}): {
            "severity": "Major",
            "mechanism": "Amiodarone inhibits CYP2C9, reducing warfarin metabolism and raising INR significantly",
            "recommendation": "Reduce warfarin dose by 30-50%. Check INR within 1 week of starting amiodarone and weekly for first month.",
        },
        frozenset({"metformin", "contrast_dye"}): {
            "severity": "Major",
            "mechanism": "Iodinated contrast can cause acute kidney injury, leading to metformin accumulation and lactic acidosis",
            "recommendation": "Hold metformin 48 hours before and after contrast. Resume only after normal renal function confirmed.",
        },
    }
    # Ordering used to honour severity_filter as a *minimum* severity, which is
    # what the tool description promises.
    severity_rank = {"Minor": 1, "Moderate": 2, "Major": 3}

    # Case-insensitive, whitespace-tolerant, order-independent lookup.
    pair = frozenset({drug_a.strip().lower(), drug_b.strip().lower()})
    interaction = interactions.get(pair)
    if not interaction:
        return f"No documented major interaction between {drug_a} and {drug_b}. Always verify with current clinical resources."

    if severity_filter:
        minimum = severity_rank.get(severity_filter.strip().capitalize())
        # An unrecognized filter value can never be satisfied, so it is
        # reported with the same mismatch message as a too-low severity.
        if minimum is None or severity_rank[interaction["severity"]] < minimum:
            return f"Interaction found but severity ({interaction['severity']}) does not match filter ({severity_filter})."

    return (
        f"Interaction: {drug_a} + {drug_b}\n"
        f"Severity: {interaction['severity']}\n"
        f"Mechanism: {interaction['mechanism']}\n"
        f"Recommendation: {interaction['recommendation']}"
    )
# Method 2: StructuredTool with Pydantic Schema
For complex inputs with validation:
Python
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field, validator
from typing import Optional, Literal
class DoseCalculationInput(BaseModel):
    """Input schema for dose calculation tool."""

    # Field declaration order is preserved; only keyword order inside each
    # Field(...) call differs, which does not affect the resulting field.
    drug: str = Field(description="Drug name (generic)")
    patient_weight_kg: float = Field(
        gt=0,
        le=500,
        description="Patient weight in kilograms",
    )
    patient_age_years: int = Field(
        ge=0,
        le=120,
        description="Patient age in years",
    )
    renal_function: Literal[
        "normal",
        "mild_impairment",
        "moderate_impairment",
        "severe_impairment",
    ] = Field(
        description="Level of renal impairment based on eGFR",
        default="normal",
    )
    indication: Optional[str] = Field(
        description="Clinical indication for the drug",
        default=None,
    )

    @validator("patient_weight_kg")
    def weight_must_be_realistic(cls, v):
        # Extra plausibility check on top of the gt/le bounds declared above:
        # anything under 2 kg is rejected so the caller re-checks the value.
        implausibly_low = v < 2
        if implausibly_low:
            raise ValueError("Weight seems too low — please verify")
        return v
def calculate_dose(
    drug: str,
    patient_weight_kg: float,
    patient_age_years: int,
    renal_function: str = "normal",
    indication: Optional[str] = None,
) -> str:
    """Calculate appropriate drug dose based on patient characteristics.

    The dose is the per-kg base dose times the patient's weight, scaled by a
    renal-function multiplier and reduced by a further 10% for patients aged
    65 or older.

    Args:
        drug: Drug name; matched case-insensitively, ignoring surrounding
            whitespace, against the built-in per-kg dose table.
        patient_weight_kg: Patient weight in kilograms; must be positive.
        patient_age_years: Patient age in years.
        renal_function: One of "normal", "mild_impairment",
            "moderate_impairment" or "severe_impairment".
        indication: Accepted for interface compatibility; not used by the
            calculation.

    Returns:
        A multi-line summary of the calculation, or a single-line message
        explaining why no dose could be calculated (unknown drug,
        non-positive weight, unrecognized renal_function).
    """
    base_doses = {
        "vancomycin": 15,  # mg/kg
        "gentamicin": 5,  # mg/kg
    }
    renal_multipliers = {
        "normal": 1.0,
        "mild_impairment": 0.85,
        "moderate_impairment": 0.65,
        "severe_impairment": 0.40,
    }

    base_dose_per_kg = base_doses.get(drug.strip().lower())
    if base_dose_per_kg is None:
        return f"Weight-based dosing not available for {drug}"

    # This function can be called without the schema validation that the tool
    # wrapper provides, so guard the inputs that would otherwise yield a
    # silently wrong dose.
    if patient_weight_kg <= 0:
        return f"Invalid patient weight ({patient_weight_kg} kg) — weight must be greater than 0"

    adjustment = renal_multipliers.get(renal_function)
    if adjustment is None:
        # Previously an unknown category fell back to x1.00 (no reduction),
        # which reads like a valid result; report it instead.
        expected = ", ".join(renal_multipliers)
        return f"Unrecognized renal_function '{renal_function}' — expected one of: {expected}"

    dose = base_dose_per_kg * patient_weight_kg
    adjusted_dose = dose * adjustment

    # Adjust for age
    if patient_age_years >= 65:
        adjusted_dose *= 0.9
        age_note = " (10% reduction for elderly patient)"
    else:
        age_note = ""

    return (
        f"Calculated dose for {drug}:\n"
        f"Weight-based: {dose:.0f}mg\n"
        f"Renal adjustment ({renal_function}): ×{adjustment:.2f}\n"
        f"Final dose: {adjusted_dose:.0f}mg{age_note}\n"
        f"Note: Always confirm with pharmacist and adjust based on drug levels."
    )
# Wrap calculate_dose as a tool, with DoseCalculationInput as its argument
# schema and an explicit name and description.
dose_calculator = StructuredTool.from_function(
    name="calculate_drug_dose",
    func=calculate_dose,
    args_schema=DoseCalculationInput,
    description=" ".join(
        [
            "Calculate weight-based drug doses adjusted for renal function and age.",
            "Use for medications requiring precise dosing (e.g., vancomycin, aminoglycosides).",
            "Input schema validates patient parameters automatically.",
        ]
    ),
    return_direct=False,
)
# Method 3: BaseTool Subclass (Full Control)
Python
from langchain_core.tools import BaseTool
from pydantic import BaseModel
from typing import Any, Optional, Type
class EHRQueryInput(BaseModel):
    # Argument schema for EHRQueryTool (attached through its `args_schema`).
    patient_id: str
    # Limited to the four record categories the tool has data for.
    query_type: Literal["medications", "labs", "allergies", "diagnoses"]
    # Defaults to 30; the mock lookup in EHRQueryTool does not read it.
    date_range_days: int = 30


class EHRQueryTool(BaseTool):
    """Query a patient's Electronic Health Record for clinical information."""

    name: str = "query_ehr"
    description: str = (
        "Query the Electronic Health Record (EHR) for patient-specific information. "
        "Use when you need: current medications, recent labs, known allergies, or active diagnoses. "
        "Requires patient_id — only use when you have explicit patient context."
    )
    args_schema: Type[BaseModel] = EHRQueryInput

    def _run(
        self,
        patient_id: str,
        query_type: str,
        date_range_days: int = 30,
        run_manager=None,
    ) -> str:
        """Return the mock EHR entry for ``query_type``, or a not-found message."""
        # In production: make API call to EHR system. Here every category maps
        # to a fixed sample entry prefixed with the patient identifier.
        patient_label = f"Patient {patient_id}"
        sample_entries = {
            "medications": "warfarin 5mg daily, metformin 1000mg BID",
            "labs": "INR 2.4 (2026-05-10), HbA1c 7.2% (2026-05-01)",
            "allergies": "penicillin (anaphylaxis), sulfa drugs (rash)",
            "diagnoses": "Atrial fibrillation, Type 2 diabetes mellitus",
        }

        if query_type in sample_entries:
            return f"{patient_label}: {sample_entries[query_type]}"
        return f"No data found for query type: {query_type}"

    async def _arun(self, patient_id: str, query_type: str, date_range_days: int = 30) -> str:
        """Async variant: yields to the event loop once, then reuses ``_run``."""
        import asyncio

        await asyncio.sleep(0)  # Simulate async I/O
        return self._run(
            patient_id=patient_id,
            query_type=query_type,
            date_range_days=date_range_days,
        )
ehr_tool = EHRQueryTool()
Tool Input Validation and Error Handling
Python
@tool
def validate_and_lookup(drug_name: str) -> str:
    """
    Look up validated drug information. Normalizes drug names and handles common misspellings.
    """
    # Brand name -> generic name.
    brand_to_generic = dict(
        tylenol="acetaminophen",
        advil="ibuprofen",
        motrin="ibuprofen",
        coumadin="warfarin",
        glucophage="metformin",
    )

    # Normalize, then swap a known brand name for its generic; any other
    # name passes through unchanged.
    cleaned = drug_name.lower().strip()
    canonical = brand_to_generic.get(cleaned, cleaned)

    # Spaces and hyphens are tolerated; every remaining character must be a letter.
    letters_only = canonical.replace(" ", "").replace("-", "")
    if not letters_only.isalpha():
        return f"Invalid drug name '{drug_name}' — contains invalid characters"

    if len(canonical) > 100:
        return "Drug name too long — please provide the generic drug name"

    # Delegate the actual lookup to the basic information tool.
    return get_drug_information.invoke({"drug_name": canonical})
# Tool with error handling wrapper
@tool
def safe_interaction_check(drug_a: str, drug_b: str) -> str:
    """
    Safely check drug interactions with automatic error handling.
    Returns a safe default message if the interaction database is unavailable.
    """
    request = {"drug_a": drug_a, "drug_b": drug_b}
    try:
        outcome = check_drug_interaction.invoke(request)
    except Exception as e:
        # Deliberately broad: any failure of the underlying tool is turned
        # into a readable fallback message instead of propagating.
        return (
            f"Unable to retrieve interaction data for {drug_a} + {drug_b}. "
            f"Please consult Lexicomp or Micromedex directly. Error: {e!s}"
        )
    return outcome
# Assembling a Tool Set
Python
from langchain_openai import ChatOpenAI
from langchain.agents import create_tool_calling_agent, AgentExecutor
from langchain_core.prompts import ChatPromptTemplate
# The complete tool set handed to the agent.
clinical_tools = [
    get_drug_information,
    check_drug_interaction,
    dose_calculator,
    ehr_tool,
    validate_and_lookup,
]

# Prompt layout: system instructions, prior conversation, the user's
# question, then the agent's scratchpad.
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a clinical pharmacist. Use your tools to provide accurate, "
            "evidence-based drug information. When checking interactions, always "
            "include severity and clinical recommendation.",
        ),
        ("placeholder", "{chat_history}"),
        ("human", "{input}"),
        ("placeholder", "{agent_scratchpad}"),
    ]
)

# Bind to model (model decides which tools to call)
model = ChatOpenAI(model="gpt-4o", temperature=0)
agent = create_tool_calling_agent(model, clinical_tools, prompt)

executor = AgentExecutor(
    agent=agent,
    tools=clinical_tools,
    verbose=True,
    max_iterations=6,
)

# Run one question through the agent with an empty conversation history.
result = executor.invoke(
    {
        "input": "My patient is on warfarin and needs aspirin for a cardiac event. What should I know?",
        "chat_history": [],
    }
)
print(result["output"])