Building a Custom Tool End-to-End

Full walkthrough of building a production-ready custom tool: schema design, implementation, input validation, structured output, FastAPI integration, and testing.

Asma Hafeez Khan · May 15, 2026 · 8 min read
Tags: Tool Calling, FastAPI, PostgreSQL, Python, AI Agents

What We're Building

This lesson walks through building a search_drug_database tool end-to-end. By the end you'll have:

  • A well-designed JSON Schema
  • A FastAPI endpoint that the tool calls
  • Input validation with Pydantic
  • Structured JSON output
  • A complete agent that uses the tool
  • A test suite

The example uses a pharmaceutical drug database, but the pattern applies to any domain.


Step 1: Design the Schema First

Before writing any code, design the schema. Ask:

  • What is the exact purpose of this tool? (not "search drugs" — too vague)
  • What inputs does it need?
  • What outputs should it return?
  • When should the LLM call it vs other tools?
Python
SEARCH_DRUG_DATABASE_SCHEMA = {
    "type": "function",
    "function": {
        "name": "search_drug_database",
        "description": (
            "Search the internal drug formulary database by drug name, active ingredient, "
            "or therapeutic class. Returns a list of matching drugs with their IDs and basic info. "
            "Use this to look up whether a drug is in the formulary and to get its drug_id for "
            "further lookups. Do NOT use this for dosage or interaction questions — use "
            "get_drug_details instead after you have a drug_id."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": (
                        "The search term. Can be a brand name (e.g. 'Glucophage'), "
                        "generic name (e.g. 'metformin'), or therapeutic class "
                        "(e.g. 'biguanide' or 'antidiabetic')."
                    )
                },
                "formulary": {
                    "type": "string",
                    "enum": ["hospital", "outpatient", "all"],
                    "description": (
                        "Which formulary to search. 'hospital' for inpatient drugs, "
                        "'outpatient' for retail pharmacy drugs, 'all' to search both."
                    )
                },
                "include_discontinued": {
                    "type": "boolean",
                    "description": (
                        "Whether to include discontinued drugs in results. "
                        "Defaults to false. Set to true only when specifically asked about "
                        "historical or discontinued medications."
                    )
                },
                "limit": {
                    "type": "integer",
                    "description": "Maximum number of results to return. Between 1 and 20. Defaults to 5."
                }
            },
            "required": ["query", "formulary"]
        }
    }
}
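
To make the "schema as contract" idea concrete, here is a minimal sketch of checking LLM-supplied arguments against the `parameters` part of a schema like the one above. The `check_args` helper and the trimmed `PARAMS` dict are illustrative and not part of any library, and the check covers only `required`, `type`, and `enum`. A real project would more likely rely on a Pydantic model or a full JSON Schema validator.

```python
# Illustrative only: a tiny subset of JSON Schema checking, standard library only.
PARAMS = {
    "type": "object",
    "properties": {
        "query": {"type": "string"},
        "formulary": {"type": "string", "enum": ["hospital", "outpatient", "all"]},
        "include_discontinued": {"type": "boolean"},
        "limit": {"type": "integer"},
    },
    "required": ["query", "formulary"],
}

# Map JSON Schema type names to Python types.
_JSON_TYPES = {"string": str, "boolean": bool, "integer": int}


def check_args(params_schema: dict, args: dict) -> list[str]:
    """Return a list of human-readable problems; an empty list means the args pass."""
    problems = []
    props = params_schema["properties"]
    for name in params_schema.get("required", []):
        if name not in args:
            problems.append(f"missing required field: {name}")
    for name, value in args.items():
        if name not in props:
            problems.append(f"unknown field: {name}")
            continue
        spec = props[name]
        expected = _JSON_TYPES[spec["type"]]
        # bool is a subclass of int in Python, so reject True/False for integers.
        if not isinstance(value, expected) or (expected is int and isinstance(value, bool)):
            problems.append(f"{name}: expected {spec['type']}")
        elif "enum" in spec and value not in spec["enum"]:
            problems.append(f"{name}: must be one of {spec['enum']}")
    return problems


print(check_args(PARAMS, {"query": "metformin", "formulary": "hospital"}))
print(check_args(PARAMS, {"query": "metformin", "formulary": "inpatient", "limit": "5"}))
```

The first call passes; the second reports the invalid enum value and the string passed where an integer was expected, which is the kind of mistake an LLM can plausibly make.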

Step 2: Define the Data Models

Python
# models.py
from pydantic import BaseModel, Field, field_validator
from typing import Optional
from enum import Enum

class FormularyType(str, Enum):
    hospital = "hospital"
    outpatient = "outpatient"
    all = "all"

class DrugSearchRequest(BaseModel):
    """Input model — validates everything before hitting the database."""
    query: str = Field(..., min_length=2, max_length=200)
    formulary: FormularyType
    include_discontinued: bool = False
    limit: int = Field(default=5, ge=1, le=20)

    @field_validator("query")
    @classmethod
    def sanitize_query(cls, v: str) -> str:
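        # Normalize whitespace only. SQL injection is prevented by the
        # parameterized queries in the tool function, not by this validator.
        stripped = v.strip()
        # This validator runs after the Field constraints, so re-check the
        # minimum length on the stripped value (" x" would otherwise pass).
        if len(stripped) < 2:
            raise ValueError("Query must be at least 2 characters after stripping whitespace")
        return stripped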

class DrugSummary(BaseModel):
    """Individual drug result in search output."""
    drug_id: str
    name: str
    generic_name: str
    therapeutic_class: str
    formulary: str
    is_active: bool
    strength: Optional[str] = None

class DrugSearchResponse(BaseModel):
    """Structured output — what the LLM receives."""
    query: str
    formulary: str
    total_found: int
    returned: int
    results: list[DrugSummary]
    has_more: bool
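
For orientation, this is roughly the JSON the LLM would receive for a one-row result. The sketch uses plain dicts and the standard library rather than the Pydantic models so that it runs on its own. The `build_payload` helper is hypothetical and the drug row is illustrative sample data.

```python
import json


def build_payload(query: str, formulary: str, total_found: int, rows: list[dict]) -> dict:
    """Mirror the fields of DrugSearchResponse using plain dicts (illustrative)."""
    return {
        "query": query,
        "formulary": formulary,
        "total_found": total_found,
        "returned": len(rows),
        "results": rows,
        # has_more signals that more matches exist than were returned.
        "has_more": total_found > len(rows),
    }


sample_row = {  # illustrative sample data, not a real formulary entry
    "drug_id": "D-001",
    "name": "Glucophage",
    "generic_name": "Metformin",
    "therapeutic_class": "Biguanide",
    "formulary": "hospital",
    "is_active": True,
    "strength": "500mg",
}

payload = build_payload("metformin", "hospital", total_found=3, rows=[sample_row])
print(json.dumps(payload, indent=2))
```

Echoing `query` and `formulary` back, and including `total_found` alongside `returned`, gives the model enough context to decide whether to narrow the search or ask for more results.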

Step 3: Implement the Tool Function

Python
# tools/drug_search.py
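import logging
import os

import asyncpg

from models import DrugSearchRequest, DrugSearchResponse, DrugSummary

logger = logging.getLogger(__name__)

# Read the DSN from the environment so credentials stay out of source control.
# DRUG_DB_DSN is an illustrative variable name; the fallback is for local development only.
DB_DSN = os.environ.get(
    "DRUG_DB_DSN",
    "postgresql://readonly_user:password@localhost:5432/pharmacy_db",
)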

async def search_drug_database(request: DrugSearchRequest) -> DrugSearchResponse:
    """
    Search the drug formulary database.
    Uses a read-only DB user — cannot modify any data.
    """
    conn = await asyncpg.connect(DB_DSN)

    try:
        # Build the WHERE clause based on request parameters
        conditions = [
            "( LOWER(brand_name) LIKE $1 OR LOWER(generic_name) LIKE $1 OR LOWER(therapeutic_class) LIKE $1 )"
        ]
        params = [f"%{request.query.lower()}%"]
        param_index = 2

        if not request.include_discontinued:
            conditions.append(f"is_active = ${param_index}")
            params.append(True)
            param_index += 1

        if request.formulary != "all":
            conditions.append(f"formulary_type = ${param_index}")
            params.append(request.formulary.value)
            param_index += 1

        where_clause = " AND ".join(conditions)

        # Count total matches
        count_sql = f"SELECT COUNT(*) FROM drugs WHERE {where_clause}"
        total = await conn.fetchval(count_sql, *params)

        # Fetch paginated results
        fetch_sql = f"""
            SELECT drug_id, brand_name, generic_name, therapeutic_class,
                   formulary_type, is_active, strength
            FROM drugs
            WHERE {where_clause}
            ORDER BY brand_name
            LIMIT ${param_index}
        """
        params.append(request.limit)
        rows = await conn.fetch(fetch_sql, *params)

        results = [
            DrugSummary(
                drug_id=row["drug_id"],
                name=row["brand_name"],
                generic_name=row["generic_name"],
                therapeutic_class=row["therapeutic_class"],
                formulary=row["formulary_type"],
                is_active=row["is_active"],
                strength=row["strength"]
            )
            for row in rows
        ]

        return DrugSearchResponse(
            query=request.query,
            formulary=request.formulary.value,
            total_found=total,
            returned=len(results),
            results=results,
            has_more=total > request.limit
        )

    except asyncpg.PostgresError as e:
        logger.error("Database error in drug search: %s", e, extra={"query": request.query})
        raise
    finally:
        await conn.close()
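
One detail the function above leaves open is that `%` and `_` inside the user's query act as LIKE wildcards, so a query such as `50%` matches more than intended. The following sketch pulls the clause-building logic into a pure function, which also makes the placeholder numbering testable without a database. `escape_like` and `build_where` are hypothetical helper names, and the escaping assumes PostgreSQL's default LIKE escape character (backslash).

```python
def escape_like(term: str) -> str:
    """Escape LIKE metacharacters so they match literally (assumes backslash escape)."""
    return term.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")


def build_where(query: str, formulary: str, include_discontinued: bool) -> tuple[str, list]:
    """Return (where_clause, params) using $1, $2, ... placeholders."""
    conditions = [
        "(LOWER(brand_name) LIKE $1 OR LOWER(generic_name) LIKE $1 "
        "OR LOWER(therapeutic_class) LIKE $1)"
    ]
    params: list = [f"%{escape_like(query.lower())}%"]

    # Deriving the placeholder number from len(params) avoids a separate counter.
    if not include_discontinued:
        params.append(True)
        conditions.append(f"is_active = ${len(params)}")

    if formulary != "all":
        params.append(formulary)
        conditions.append(f"formulary_type = ${len(params)}")

    return " AND ".join(conditions), params


where, params = build_where("50%", "hospital", include_discontinued=False)
print(where)
print(params)
```

Because the function touches no I/O, it can be covered by plain synchronous tests, leaving the mocked-connection tests to check only the fetch and mapping logic.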

Step 4: Wrap in a FastAPI Endpoint

Python
# main.py
import logging

from fastapi import FastAPI, HTTPException

from models import DrugSearchRequest, DrugSearchResponse
from tools.drug_search import search_drug_database

app = FastAPI(title="Drug Formulary Tool API")
logger = logging.getLogger(__name__)

@app.post("/tools/search-drug", response_model=DrugSearchResponse)
async def search_drug_endpoint(request: DrugSearchRequest) -> DrugSearchResponse:
    """
    Tool endpoint called by the AI agent.
    Accepts validated input, returns structured JSON.
    """
    try:
        result = await search_drug_database(request)
        logger.info(
            "Drug search completed",
            extra={
                "query": request.query,
                "formulary": request.formulary,
                "results_returned": result.returned,
                "total_found": result.total_found
            }
        )
        return result
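    except Exception:
        # Log the full traceback server-side, but return a generic message so
        # internal details (SQL, hostnames) do not leak into the LLM's context.
        logger.exception("Drug search failed", extra={"query": request.query})
        raise HTTPException(status_code=500, detail="Drug search failed due to an internal error")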
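
Exception messages from the database layer can carry SQL fragments or hostnames, so it is worth keeping them out of the HTTP response. One way to keep failures traceable without exposing them, sketched here with the standard library only, is to log the full exception under a correlation ID and return only that ID with a generic message. The `safe_error_detail` helper and the `error_id` field are assumptions for illustration, not FastAPI features.

```python
import logging
import uuid

logger = logging.getLogger("drug_tool")


def safe_error_detail(exc: Exception) -> dict:
    """Log the real exception server-side; return only a generic message and an ID."""
    error_id = uuid.uuid4().hex[:12]
    # exc_info attaches the traceback to the log record for later debugging.
    logger.error("Drug search failed [error_id=%s]", error_id, exc_info=exc)
    return {"message": "Drug search failed due to an internal error", "error_id": error_id}


try:
    raise ConnectionError("could not connect to db-internal-01:5432")
except Exception as exc:
    detail = safe_error_detail(exc)

print(detail)
```

In the endpoint, a dict like this could be passed as the `detail` argument of `HTTPException`, which accepts any JSON-serializable value.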

Step 5: Connect Tool to the LLM Agent

Python
# agent.py
import json

import httpx
import openai
from pydantic import ValidationError

from models import DrugSearchRequest
from tools.schemas import SEARCH_DRUG_DATABASE_SCHEMA  # the schema from Step 1

client = openai.OpenAI()
TOOL_API_BASE = "http://localhost:8000"

tools = [SEARCH_DRUG_DATABASE_SCHEMA]

async def call_search_drug_tool(raw_args: dict) -> dict:
    """
    Validate LLM arguments and call the tool API.
    Returns a dict (which will be JSON-serialized for the LLM).
    """
    # Validate before calling
    try:
        request = DrugSearchRequest(**raw_args)
    except ValidationError as e:
        return {
            "error": "Invalid tool arguments",
            "details": e.errors(),
            "hint": "Check parameter types and required fields"
        }

    async with httpx.AsyncClient() as http:
        try:
            response = await http.post(
                f"{TOOL_API_BASE}/tools/search-drug",
                json=request.model_dump(mode="json"),  # serialize the enum as its string value
                timeout=10.0
            )
            response.raise_for_status()
            return response.json()
        except httpx.TimeoutException:
            return {"error": "Tool request timed out. Try again or narrow your search."}
        except httpx.HTTPStatusError as e:
            return {"error": f"Tool API error: {e.response.status_code}", "detail": e.response.text}
        except httpx.RequestError as e:
            # Connection refused, DNS failure, and other transport errors
            return {"error": "Could not reach the tool API", "detail": str(e)}

async def run_drug_agent(user_question: str) -> str:
    messages = [
        {
            "role": "system",
            "content": (
                "You are a hospital pharmacy assistant. "
                "Use the search_drug_database tool to look up drugs in our formulary. "
                "Always verify drug availability before confirming to clinical staff."
            )
        },
        {"role": "user", "content": user_question}
    ]

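    # The OpenAI client above is synchronous; run the call in a worker thread
    # so it does not block the event loop (openai.AsyncOpenAI is an alternative).
    response = await asyncio.to_thread(
        client.chat.completions.create,
        model="gpt-4o",
        messages=messages,
        tools=tools,
        tool_choice="auto",
    )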

    msg = response.choices[0].message

    if not msg.tool_calls:
        return msg.content

    messages.append(msg)

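    for tc in msg.tool_calls:
        try:
            raw_args = json.loads(tc.function.arguments)
            result = await call_search_drug_tool(raw_args)
        except json.JSONDecodeError as e:
            # The model can emit malformed JSON; report it back as data.
            result = {"error": "Tool arguments were not valid JSON", "details": str(e)}
        messages.append({
            "role": "tool",
            "tool_call_id": tc.id,
            "content": json.dumps(result)
        })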

    # Same synchronous client, so keep the call off the event loop here too.
    final = await asyncio.to_thread(
        client.chat.completions.create,
        model="gpt-4o",
        messages=messages,
        tools=tools,
    )
    return final.choices[0].message.content

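# Run it
import asyncio

# Guard the entry point so importing agent.py does not trigger an API call.
if __name__ == "__main__":
    answer = asyncio.run(run_drug_agent(
        "Is Metformin available in the hospital formulary?"
    ))
    print(answer)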
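
The agent above handles exactly one round of tool calls. If the model may need several (search, then refine), the same logic can be wrapped in a bounded loop. The sketch below injects the model and tool calls as plain functions so it runs without an API key. `run_tool_loop`, `fake_model`, and `fake_tool` are hypothetical stand-ins, and the message dicts only imitate the shape used earlier.

```python
import json
from typing import Callable


def run_tool_loop(
    call_model: Callable[[list], dict],
    call_tool: Callable[[str, dict], dict],
    messages: list,
    max_rounds: int = 5,
) -> str:
    """Call the model repeatedly until it answers without requesting a tool."""
    for _ in range(max_rounds):
        msg = call_model(messages)
        messages.append(msg)
        tool_calls = msg.get("tool_calls") or []
        if not tool_calls:
            return msg["content"]
        for tc in tool_calls:
            result = call_tool(tc["name"], json.loads(tc["arguments"]))
            messages.append(
                {"role": "tool", "tool_call_id": tc["id"], "content": json.dumps(result)}
            )
    # Bound the loop so a confused model cannot call tools forever.
    return "Stopped: tool-call limit reached without a final answer."


def fake_model(messages: list) -> dict:
    """Stand-in for the LLM: request one search, then answer once a tool result exists."""
    if any(m["role"] == "tool" for m in messages):
        return {"role": "assistant", "content": "Metformin is on the hospital formulary."}
    return {
        "role": "assistant",
        "content": None,
        "tool_calls": [{
            "id": "call_1",
            "name": "search_drug_database",
            "arguments": json.dumps({"query": "metformin", "formulary": "hospital"}),
        }],
    }


def fake_tool(name: str, args: dict) -> dict:
    """Stand-in for the HTTP tool call; returns canned data."""
    return {"total_found": 1, "returned": 1, "has_more": False}


history = [{"role": "user", "content": "Is Metformin available in the hospital formulary?"}]
print(run_tool_loop(fake_model, fake_tool, history))
```

Passing the tool name to `call_tool` also leaves room for a dispatch table once a second tool such as `get_drug_details` exists.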

Step 6: Write Tests

Python
# tests/test_drug_search_tool.py
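import json
from unittest.mock import AsyncMock, patch

import pytest  # the async tests below also need the pytest-asyncio plugin

from models import DrugSearchRequest, FormularyType
from tools.drug_search import search_drug_database
from tools.schemas import SEARCH_DRUG_DATABASE_SCHEMA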

@pytest.mark.asyncio
async def test_basic_search_returns_results():
    mock_rows = [
        {
            "drug_id": "D-001",
            "brand_name": "Glucophage",
            "generic_name": "Metformin",
            "therapeutic_class": "Biguanide",
            "formulary_type": "hospital",
            "is_active": True,
            "strength": "500mg"
        }
    ]

    with patch("tools.drug_search.asyncpg.connect") as mock_connect:
        mock_conn = AsyncMock()
        mock_connect.return_value = mock_conn
        mock_conn.fetchval.return_value = 1
        mock_conn.fetch.return_value = mock_rows

        request = DrugSearchRequest(query="metformin", formulary=FormularyType.hospital)
        result = await search_drug_database(request)

        assert result.total_found == 1
        assert result.returned == 1
        assert result.results[0].generic_name == "Metformin"
        assert not result.has_more

@pytest.mark.asyncio
async def test_empty_results():
    with patch("tools.drug_search.asyncpg.connect") as mock_connect:
        mock_conn = AsyncMock()
        mock_connect.return_value = mock_conn
        mock_conn.fetchval.return_value = 0
        mock_conn.fetch.return_value = []

        request = DrugSearchRequest(query="nonexistentdrug", formulary=FormularyType.all)
        result = await search_drug_database(request)

        assert result.total_found == 0
        assert result.results == []

def test_input_validation_rejects_short_query():
    # pydantic's ValidationError subclasses ValueError, so this is narrower than Exception
    with pytest.raises(ValueError):
        DrugSearchRequest(query="x", formulary=FormularyType.hospital)  # Too short

def test_input_validation_rejects_large_limit():
    with pytest.raises(ValueError):
        DrugSearchRequest(query="metformin", formulary=FormularyType.all, limit=100)  # Over max

def test_tool_schema_is_valid_json():
    """Ensure the schema serializes without error."""
    schema_json = json.dumps(SEARCH_DRUG_DATABASE_SCHEMA)
    parsed = json.loads(schema_json)
    assert parsed["function"]["name"] == "search_drug_database"
    assert "query" in parsed["function"]["parameters"]["properties"]
    assert "query" in parsed["function"]["parameters"]["required"]

Full File Structure

drug_agent/
├── main.py              # FastAPI app
├── agent.py             # LLM agent loop
├── models.py            # Pydantic models
├── tools/
│   ├── __init__.py
│   ├── drug_search.py   # Tool implementation
│   └── schemas.py       # JSON Schema definitions
└── tests/
    ├── __init__.py
    └── test_drug_search_tool.py

Key Principles This Example Demonstrates

  1. Schema-first design: Write the JSON Schema before any code. It is the contract between the LLM and your tool.

  2. Validate at the boundary: Parse and validate LLM arguments with Pydantic before any I/O happens. The LLM can and will pass invalid arguments.

  3. Structured output only: Tools should return dicts that serialize to clean JSON. Never return bare strings with embedded data.

  4. Errors are data: When the tool fails, return a structured error dict instead of letting the exception escape. The LLM can then retry with corrected arguments or explain the failure to the user.

  5. Separate tool from agent: The tool function (database query) and the agent loop (LLM calls) are independent. Test them separately.

  6. Least privilege at the DB level: The readonly_user in the connection string should be granted SELECT only. Then, even if the LLM were somehow manipulated into attempting a write, the database would reject it.
