Defining Custom Tools with @tool

Create LangChain tools with the @tool decorator, StructuredTool, the BaseTool class, and Pydantic input schemas. Build validated, type-safe tools for clinical AI agents.

Asma Hafeez Khan · May 16, 2026 · 7 min read
Tags: LangChain, Tools, @tool, StructuredTool, BaseTool, Agent Tools

What Makes a Good Tool?

A LangChain tool has three components:

  1. Name — what the model calls to invoke it
  2. Description — how the model decides when to use it (critical)
  3. Function — the actual Python code that runs

The description matters most of the three: it is the text the LLM reads to decide whether the tool is relevant to the current step. A vague description leads to wrong tool selection.
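To make the three components concrete, here is a minimal standard-library sketch of how a name, a description, and an argument list can be read off a plain Python function. It only approximates the idea behind `@tool` and is not LangChain's implementation; `describe_function` and `lookup_drug` are hypothetical names used for illustration.

```python
import inspect


def describe_function(func):
    """Collect the pieces a model would see: name, description, and arguments."""
    signature = inspect.signature(func)
    return {
        "name": func.__name__,
        # inspect.getdoc returns the docstring with its indentation cleaned up
        "description": inspect.getdoc(func) or "",
        # Map each parameter to the name of its type annotation
        "args": {
            name: getattr(param.annotation, "__name__", str(param.annotation))
            for name, param in signature.parameters.items()
        },
    }


def lookup_drug(drug_name: str) -> str:
    """Return basic facts about a drug. Use for single-drug questions only."""
    return f"(placeholder result for {drug_name})"


spec = describe_function(lookup_drug)
print(spec["name"])
print(spec["description"])
print(spec["args"])
```

The point of the sketch is that the docstring doubles as the description, which is why it deserves the same care as the function body.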


Method 1: @tool Decorator (Simplest)

Python
from langchain_core.tools import tool
from typing import Optional

@tool
def get_drug_information(drug_name: str) -> str:
    """
    Retrieve clinical information about a medication from the pharmacy database.
    
    Use this tool when you need factual information about a drug including:
    - Mechanism of action
    - Standard dosing
    - Common indications
    - Drug class
    
    Do NOT use for drug interaction checks — use check_drug_interaction instead.
    """
    # In production: query a real database
    drug_db = {
        "warfarin": {
            "class": "Anticoagulant",
            "mechanism": "Inhibits vitamin K epoxide reductase (VKORC1)",
            "typical_dose": "2-10mg daily, INR-guided",
            "indication": "Atrial fibrillation, DVT, PE, prosthetic heart valves",
        },
        "metformin": {
            "class": "Biguanide antidiabetic",
            "mechanism": "Activates AMPK, reduces hepatic glucose output",
            "typical_dose": "500-2550mg/day in divided doses",
            "indication": "Type 2 diabetes mellitus",
        },
    }
    drug_data = drug_db.get(drug_name.lower().strip())
    if not drug_data:
        return f"Drug '{drug_name}' not found. Please check the spelling or use a generic name."
    return f"Drug: {drug_name}\n" + "\n".join(f"{k}: {v}" for k, v in drug_data.items())


# Tool metadata
print(get_drug_information.name)           # "get_drug_information"
print(get_drug_information.description[:100])  # First 100 chars of docstring
print(get_drug_information.args)           # {"drug_name": {"title": "Drug Name", "type": "string"}}

# Call the tool directly
result = get_drug_information.invoke({"drug_name": "warfarin"})
print(result)

Multi-Argument Tools

Python
@tool
def check_drug_interaction(
    drug_a: str,
    drug_b: str,
    severity_filter: Optional[str] = None,
) -> str:
    """
    Check for pharmacological interactions between two drugs.
    
    Returns interaction severity (Major/Moderate/Minor/None), mechanism, and clinical recommendation.
    
    Args:
        drug_a: First drug name (generic preferred)
        drug_b: Second drug name (generic preferred)
        severity_filter: Optional filter for minimum severity (e.g., "Major")
    
    Use this tool specifically when asked about drug-drug interactions, combination therapy safety,
    or whether two medications can be co-administered.
    """
    # Simplified interaction database
    interactions = {
        ("warfarin", "aspirin"): {
            "severity": "Major",
            "mechanism": "Additive bleeding risk: warfarin reduces clotting-factor synthesis while aspirin inhibits platelet aggregation",
            "recommendation": "Avoid unless benefit outweighs substantial bleeding risk. If unavoidable, use lowest aspirin dose (81mg) and monitor closely.",
        },
        ("warfarin", "amiodarone"): {
            "severity": "Major",
            "mechanism": "Amiodarone inhibits CYP2C9, reducing warfarin metabolism and raising INR significantly",
            "recommendation": "Reduce warfarin dose by 30-50%. Check INR within 1 week of starting amiodarone and weekly for first month.",
        },
        ("metformin", "contrast_dye"): {
            "severity": "Major",
            "mechanism": "Iodinated contrast can cause acute kidney injury, leading to metformin accumulation and lactic acidosis",
            "recommendation": "Hold metformin 48 hours before and after contrast. Resume only after normal renal function confirmed.",
        },
    }

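    # Compare as unordered pairs so the result does not depend on argument order
    # or on how the pair happens to be written in the table above
    key = frozenset([drug_a.lower().strip(), drug_b.lower().strip()])
    interaction = next((v for k, v in interactions.items() if frozenset(k) == key), None)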

    if not interaction:
        return f"No documented major interaction between {drug_a} and {drug_b}. Always verify with current clinical resources."

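    # Treat severity_filter as a minimum severity, as the docstring describes
    severity_rank = {"Minor": 1, "Moderate": 2, "Major": 3}
    if severity_filter:
        minimum = severity_rank.get(severity_filter.strip().capitalize(), 0)
        if severity_rank[interaction["severity"]] < minimum:
            return f"Interaction found but severity ({interaction['severity']}) is below the requested minimum ({severity_filter})."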

    return (
        f"Interaction: {drug_a} + {drug_b}\n"
        f"Severity: {interaction['severity']}\n"
        f"Mechanism: {interaction['mechanism']}\n"
        f"Recommendation: {interaction['recommendation']}"
    )
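A pairwise lookup like the one in `check_drug_interaction` should give the same answer whichever drug is passed first, and a "minimum severity" filter needs an ordering over the severity labels. The following is a minimal sketch of both ideas using only the standard library; the trimmed table, `find_severity`, and `meets_minimum` are hypothetical names for illustration.

```python
# Hypothetical, trimmed-down interaction table keyed by unordered drug pairs.
INTERACTIONS = {
    frozenset({"warfarin", "aspirin"}): "Major",
    frozenset({"warfarin", "amiodarone"}): "Major",
}

# Rank severities so a filter can mean "at least this severe".
SEVERITY_RANK = {"None": 0, "Minor": 1, "Moderate": 2, "Major": 3}


def find_severity(drug_a, drug_b):
    """Return the stored severity for a pair, regardless of argument order."""
    key = frozenset({drug_a.lower().strip(), drug_b.lower().strip()})
    return INTERACTIONS.get(key)


def meets_minimum(severity, minimum):
    """True when `severity` is at least as severe as `minimum`."""
    return SEVERITY_RANK[severity] >= SEVERITY_RANK[minimum]


forward = find_severity("warfarin", "aspirin")
reverse = find_severity(" Aspirin", "Warfarin")
print(forward, reverse)
print(meets_minimum("Major", "Moderate"))
print(meets_minimum("Minor", "Moderate"))
```

Using `frozenset` keys removes the need to remember a sort order when adding entries to the table.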

Method 2: StructuredTool with Pydantic Schema

For complex inputs with validation:

Python
from langchain_core.tools import StructuredTool
# Pydantic v2 API; on Pydantic v1, use `validator` instead of `field_validator`
from pydantic import BaseModel, Field, field_validator
from typing import Optional, Literal

class DoseCalculationInput(BaseModel):
    """Input schema for dose calculation tool."""
    drug: str = Field(description="Drug name (generic)")
    patient_weight_kg: float = Field(description="Patient weight in kilograms", gt=0, le=500)
    patient_age_years: int = Field(description="Patient age in years", ge=0, le=120)
    renal_function: Literal["normal", "mild_impairment", "moderate_impairment", "severe_impairment"] = Field(
        default="normal",
        description="Level of renal impairment based on eGFR"
    )
    indication: Optional[str] = Field(default=None, description="Clinical indication for the drug")

    # Extra check on top of the gt/le bounds declared on the field
    @field_validator("patient_weight_kg")
    @classmethod
    def weight_must_be_realistic(cls, v):
        if v < 2:
            raise ValueError("Weight seems too low; please verify")
        return v


def calculate_dose(
    drug: str,
    patient_weight_kg: float,
    patient_age_years: int,
    renal_function: str = "normal",
    indication: Optional[str] = None,
) -> str:
    """Calculate appropriate drug dose based on patient characteristics."""
    base_doses = {
        "vancomycin": 15,   # mg/kg
        "gentamicin": 5,    # mg/kg
    }

    base_dose_per_kg = base_doses.get(drug.lower())
    if not base_dose_per_kg:
        return f"Weight-based dosing not available for {drug}"

    dose = base_dose_per_kg * patient_weight_kg

    # Adjust for renal function
    adjustments = {
        "mild_impairment": 0.85,
        "moderate_impairment": 0.65,
        "severe_impairment": 0.40,
    }
    adjustment = adjustments.get(renal_function, 1.0)
    adjusted_dose = dose * adjustment

    # Adjust for age
    if patient_age_years >= 65:
        adjusted_dose *= 0.9
        age_note = " (10% reduction for elderly patient)"
    else:
        age_note = ""

    return (
        f"Calculated dose for {drug}:\n"
        f"Weight-based: {dose:.0f}mg\n"
        f"Renal adjustment ({renal_function}): ×{adjustment:.2f}\n"
        f"Final dose: {adjusted_dose:.0f}mg{age_note}\n"
        f"Note: Always confirm with pharmacist and adjust based on drug levels."
    )


dose_calculator = StructuredTool.from_function(
    func=calculate_dose,
    name="calculate_drug_dose",
    description=(
        "Calculate weight-based drug doses adjusted for renal function and age. "
        "Use for medications requiring precise dosing (e.g., vancomycin, aminoglycosides). "
        "Input schema validates patient parameters automatically."
    ),
    args_schema=DoseCalculationInput,
    return_direct=False,
)
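To sanity-check the adjustment arithmetic in `calculate_dose`, the steps can be traced by hand for one case. This sketch reuses the factors from the function above; the 70 kg, 70-year-old patient with moderate renal impairment is a hypothetical example, not clinical guidance.

```python
# Factors copied from calculate_dose above; the patient is hypothetical.
base_dose_per_kg = 15        # vancomycin, mg/kg
patient_weight_kg = 70.0
renal_factor = 0.65          # "moderate_impairment"
age_factor = 0.9             # applied when patient_age_years >= 65

weight_based = base_dose_per_kg * patient_weight_kg   # 15 * 70 = 1050 mg
after_renal = weight_based * renal_factor             # 1050 * 0.65 = 682.5 mg
final_dose = after_renal * age_factor                 # 682.5 * 0.9 = 614.25 mg

print(f"Weight-based: {weight_based:.0f}mg")
print(f"After renal adjustment: {after_renal:.1f}mg")
print(f"Final dose: {final_dose:.0f}mg")
```

Note that the two reductions multiply, so the combined factor here is 0.65 × 0.9 = 0.585 of the weight-based dose.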

Method 3: BaseTool Subclass (Full Control)

Python
from langchain_core.tools import BaseTool
from pydantic import BaseModel
from typing import Any, Literal, Optional, Type

class EHRQueryInput(BaseModel):
    patient_id: str
    query_type: Literal["medications", "labs", "allergies", "diagnoses"]
    date_range_days: int = 30

class EHRQueryTool(BaseTool):
    """Query a patient's Electronic Health Record for clinical information."""

    name: str = "query_ehr"
    description: str = (
        "Query the Electronic Health Record (EHR) for patient-specific information. "
        "Use when you need: current medications, recent labs, known allergies, or active diagnoses. "
        "Requires patient_id — only use when you have explicit patient context."
    )
    args_schema: Type[BaseModel] = EHRQueryInput

    def _run(
        self,
        patient_id: str,
        query_type: str,
        date_range_days: int = 30,
        run_manager=None,
    ) -> str:
        """Synchronous EHR query."""
        # In production: make API call to EHR system
        mock_data = {
            "medications": f"Patient {patient_id}: warfarin 5mg daily, metformin 1000mg BID",
            "labs": f"Patient {patient_id}: INR 2.4 (2026-05-10), HbA1c 7.2% (2026-05-01)",
            "allergies": f"Patient {patient_id}: penicillin (anaphylaxis), sulfa drugs (rash)",
            "diagnoses": f"Patient {patient_id}: Atrial fibrillation, Type 2 diabetes mellitus",
        }
        return mock_data.get(query_type, f"No data found for query type: {query_type}")

    async def _arun(self, patient_id: str, query_type: str, date_range_days: int = 30) -> str:
        """Async EHR query."""
        import asyncio
        await asyncio.sleep(0)   # Simulate async I/O
        return self._run(patient_id, query_type, date_range_days)


ehr_tool = EHRQueryTool()
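The `_arun` above delegates straight to `_run`, which is fine for mock data but would block the event loop once `_run` performs real I/O. One common option, sketched here with the standard library only, is to push the blocking call onto a worker thread with `asyncio.to_thread` (Python 3.9+). `BlockingLookup` and the patient IDs are hypothetical and not part of LangChain.

```python
import asyncio
import time


class BlockingLookup:
    """Hypothetical stand-in for a tool whose sync path does blocking I/O."""

    def run(self, patient_id: str) -> str:
        time.sleep(0.01)  # pretend this is a slow network call
        return f"record for {patient_id}"

    async def arun(self, patient_id: str) -> str:
        # Run the blocking call in a worker thread so the event loop stays free
        return await asyncio.to_thread(self.run, patient_id)


async def main():
    lookup = BlockingLookup()
    # Both lookups are in flight at the same time instead of back to back
    return await asyncio.gather(lookup.arun("P001"), lookup.arun("P002"))


results = asyncio.run(main())
print(results)
```

If the EHR client offers a native async API, calling that directly in `_arun` is usually preferable to a thread.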

Tool Input Validation and Error Handling

Python
@tool
def validate_and_lookup(drug_name: str) -> str:
    """
    Look up validated drug information. Normalizes drug names and handles common misspellings.
    """
    # Normalize input
    normalized = drug_name.lower().strip()

    # Common name normalization
    aliases = {
        "tylenol": "acetaminophen",
        "advil": "ibuprofen",
        "motrin": "ibuprofen",
        "coumadin": "warfarin",
        "glucophage": "metformin",
    }
    normalized = aliases.get(normalized, normalized)

    # Validate: no special characters, reasonable length
    if not normalized.replace(" ", "").replace("-", "").isalpha():
        return f"Invalid drug name '{drug_name}' — contains invalid characters"

    if len(normalized) > 100:
        return "Drug name too long — please provide the generic drug name"

    return get_drug_information.invoke({"drug_name": normalized})


# Tool with error handling wrapper
@tool
def safe_interaction_check(drug_a: str, drug_b: str) -> str:
    """
    Safely check drug interactions with automatic error handling.
    Returns a safe default message if the interaction database is unavailable.
    """
    try:
        return check_drug_interaction.invoke({"drug_a": drug_a, "drug_b": drug_b})
    except Exception as e:
        return (
            f"Unable to retrieve interaction data for {drug_a} + {drug_b}. "
            f"Please consult Lexicomp or Micromedex directly. Error: {str(e)}"
        )
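Keeping the normalization rules in a plain function makes them easy to unit-test without going through the tool interface. The sketch below mirrors the checks in `validate_and_lookup`; the helper name `normalize_drug_name` and its `None`-on-invalid convention are assumptions made for illustration.

```python
from typing import Optional

# Same brand-to-generic aliases as in validate_and_lookup above.
ALIASES = {
    "tylenol": "acetaminophen",
    "advil": "ibuprofen",
    "motrin": "ibuprofen",
    "coumadin": "warfarin",
    "glucophage": "metformin",
}


def normalize_drug_name(raw: str) -> Optional[str]:
    """Return a cleaned generic name, or None when the input looks invalid."""
    name = raw.lower().strip()
    name = ALIASES.get(name, name)
    if len(name) > 100:
        return None
    # Allow letters, plus spaces and hyphens between them
    if not name.replace(" ", "").replace("-", "").isalpha():
        return None
    return name


print(normalize_drug_name("  Coumadin "))
print(normalize_drug_name("warfarin; DROP TABLE"))
print(normalize_drug_name(""))
```

The tool body then reduces to one call to the helper followed by the lookup, and the edge cases live in ordinary tests.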

Assembling a Tool Set

Python
from langchain_openai import ChatOpenAI
from langchain.agents import create_tool_calling_agent, AgentExecutor
from langchain_core.prompts import ChatPromptTemplate

# Define the complete tool set
clinical_tools = [
    get_drug_information,
    check_drug_interaction,
    dose_calculator,
    ehr_tool,
    validate_and_lookup,
]

# Chat model; create_tool_calling_agent binds the tools to it, and the model decides which to call
model = ChatOpenAI(model="gpt-4o", temperature=0)

prompt = ChatPromptTemplate.from_messages([
    ("system",
     "You are a clinical pharmacist. Use your tools to provide accurate, "
     "evidence-based drug information. When checking interactions, always "
     "include severity and clinical recommendation."),
    ("placeholder", "{chat_history}"),
    ("human", "{input}"),
    ("placeholder", "{agent_scratchpad}"),
])

agent = create_tool_calling_agent(model, clinical_tools, prompt)
executor = AgentExecutor(
    agent=agent,
    tools=clinical_tools,
    verbose=True,
    max_iterations=6,
)

result = executor.invoke({
    "input": "My patient is on warfarin and needs aspirin for a cardiac event. What should I know?",
    "chat_history": [],
})
print(result["output"])
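Conceptually, each agent step comes down to the model naming a tool and supplying arguments, and the runtime looking that tool up and calling it. The following is a deliberately simplified, standard-library illustration of that dispatch step, not how `AgentExecutor` is implemented; the stub functions, `TOOLS` registry, `dispatch`, and the hand-written tool calls are all hypothetical.

```python
def get_drug_information(drug_name):
    """Stub standing in for the real tool."""
    return f"info for {drug_name}"


def check_drug_interaction(drug_a, drug_b):
    """Stub standing in for the real tool."""
    return f"interaction check for {drug_a} + {drug_b}"


# Registry keyed by tool name, the identifier the model emits in a tool call.
TOOLS = {
    "get_drug_information": get_drug_information,
    "check_drug_interaction": check_drug_interaction,
}


def dispatch(tool_call):
    """Run one tool call of the form {"name": ..., "args": {...}}."""
    func = TOOLS.get(tool_call["name"])
    if func is None:
        return f"Unknown tool: {tool_call['name']}"
    return func(**tool_call["args"])


# Hand-written tool calls standing in for what a model might request.
calls = [
    {"name": "check_drug_interaction", "args": {"drug_a": "warfarin", "drug_b": "aspirin"}},
    {"name": "order_labs", "args": {}},
]
outputs = [dispatch(call) for call in calls]
print(outputs)
```

This is why tool names must be unique within a tool set and why descriptions should state when not to use a tool: the name is the only handle the runtime has.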
