Learnixo

LangChain Mastery · Lesson 31 of 33

Callbacks: Monitoring Tokens and Latency

What are Callbacks?

Callbacks let you hook into LangChain's event lifecycle without modifying chain logic. They fire at:

  • LLM calls start and end
  • Chain steps start and end
  • Tool invocations start and end
  • Token-level streaming events
  • Errors at any stage

This lets you add logging, cost tracking, streaming UI updates, and custom observability — all outside your main chain code.


BaseCallbackHandler: The Core Interface

Python
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from typing import Any, Union
import time

class TimingCallbackHandler(BaseCallbackHandler):
    """Track latency for each LLM call and each chain step.

    Start times are stored per ``run_id``, so concurrent or nested runs (for
    example ``chain.batch(...)``) each get their own timer. Entries are removed
    on both success and error, so the bookkeeping dicts do not grow without
    bound.

    Attributes:
        llm_calls: One dict per completed LLM call with ``latency_ms``,
            ``prompt_tokens`` and ``completion_tokens``.
    """

    def __init__(self):
        self._llm_starts: dict[str, float] = {}
        self._chain_starts: dict[str, float] = {}
        self.llm_calls: list[dict] = []

    @staticmethod
    def _model_name(serialized: Union[dict, None], kwargs: dict) -> str:
        """Best-effort model name from the serialized model or its invocation params."""
        sources = (
            (serialized or {}).get("kwargs") or {},
            kwargs.get("invocation_params") or {},
        )
        for source in sources:
            for key in ("model_name", "model"):
                if source.get(key):
                    return str(source[key])
        return "unknown"

    @staticmethod
    def _elapsed_ms(starts: dict[str, float], run_id: Any) -> int:
        """Pop the start time recorded for ``run_id`` and return elapsed ms (0 if unknown)."""
        now = time.perf_counter()
        return round((now - starts.pop(str(run_id), now)) * 1000)

    def on_llm_start(self, serialized: dict, prompts: list[str], **kwargs: Any) -> None:
        # perf_counter is monotonic, so latencies are immune to wall-clock jumps.
        self._llm_starts[str(kwargs.get("run_id"))] = time.perf_counter()
        print(f"[LLM] Starting: {self._model_name(serialized, kwargs)}")
        print(f"[LLM] Prompt length: {sum(len(p) for p in prompts)} chars")

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        latency_ms = self._elapsed_ms(self._llm_starts, kwargs.get("run_id"))
        # llm_output is Optional and can be None (e.g. for some providers or
        # streamed responses); fall back to an empty usage dict.
        token_usage = (response.llm_output or {}).get("token_usage") or {}

        self.llm_calls.append({
            "latency_ms": latency_ms,
            "prompt_tokens": token_usage.get("prompt_tokens", 0),
            "completion_tokens": token_usage.get("completion_tokens", 0),
        })
        print(f"[LLM] Done in {latency_ms}ms | Tokens: {token_usage}")

    def on_llm_error(self, error: Exception, **kwargs: Any) -> None:
        # Drop the timer so failed calls do not leak entries.
        self._llm_starts.pop(str(kwargs.get("run_id")), None)
        print(f"[LLM] Error: {error}")

    def on_chain_start(self, serialized: dict, inputs: dict, run_id, **kwargs: Any) -> None:
        # serialized may be None for some runnables in recent langchain-core;
        # the runnable's name is then available as the ``name`` keyword.
        chain_name = (serialized or {}).get("name") or kwargs.get("name") or "unknown"
        self._chain_starts[str(run_id)] = time.perf_counter()
        print(f"[Chain] Start: {chain_name}")

    def on_chain_end(self, outputs: dict, run_id, **kwargs: Any) -> None:
        latency_ms = self._elapsed_ms(self._chain_starts, run_id)
        print(f"[Chain] Done in {latency_ms}ms")

    def on_chain_error(self, error: Exception, **kwargs: Any) -> None:
        # Without this pop, every failed chain step would leave a stale entry.
        self._chain_starts.pop(str(kwargs.get("run_id")), None)
        print(f"[Chain] Error: {error}")

    def on_tool_start(self, serialized: dict, input_str: str, **kwargs: Any) -> None:
        tool_name = (serialized or {}).get("name", "unknown")
        print(f"[Tool] Calling: {tool_name}({str(input_str)[:80]})")

    def on_tool_end(self, output: str, **kwargs: Any) -> None:
        # Tool output is not always a str (e.g. a ToolMessage), so stringify before slicing.
        print(f"[Tool] Result: {str(output)[:100]}")

    def on_tool_error(self, error: Exception, **kwargs: Any) -> None:
        print(f"[Tool] Error: {error}")

Attaching Callbacks

Callbacks can be attached at three levels:

Python
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableConfig

timing_callback = TimingCallbackHandler()

# Level 1: Per-invocation (most common). Runtime callbacks are inherited by
# every child run of this invocation (prompt, model, parser).
chain = (
    ChatPromptTemplate.from_messages([("human", "{question}")])
    | ChatOpenAI(model="gpt-4o", temperature=0)
    | StrOutputParser()
)

result = chain.invoke(
    {"question": "What is warfarin?"},
    config=RunnableConfig(callbacks=[timing_callback]),
)

# Level 2: Constructor. Fires on every call to THIS model instance only;
# constructor callbacks are scoped to the object and not inherited by children.
model_with_callback = ChatOpenAI(
    model="gpt-4o",
    temperature=0,
    callbacks=[timing_callback],
)

# Level 3: Bound to a runnable. `with_config` returns a new runnable that applies
# these callbacks to every invocation and, like runtime callbacks, passes them to
# child runs. Note: `langchain.callbacks.set_handler` does not exist in current
# LangChain releases, so there is no one-line "global handler" switch.
chain_with_callbacks = chain.with_config(callbacks=[timing_callback])
# chain_with_callbacks.invoke({"question": "What is warfarin?"})

Cost Tracking Callback

Python
class CostTrackingCallback(BaseCallbackHandler):
    """Track token usage and estimated cost across one or more runs.

    Token counts are read from each chat message's ``usage_metadata`` (the
    provider-agnostic usage field) when present, otherwise from the
    OpenAI-style ``llm_output["token_usage"]``. The model name comes from the
    start event, or failing that ``llm_output["model_name"]``. It is matched
    against ``PRICING`` by exact name or longest prefix, so dated variants such
    as ``gpt-4o-2024-08-06`` use their base entry. Unknown models are priced
    like ``gpt-4o``.
    """

    # Prices in USD per 1M tokens (input / output). Provider prices change —
    # verify against the current price lists before relying on these estimates.
    PRICING = {
        "gpt-4o": {"input": 2.50, "output": 10.00},
        "gpt-4o-mini": {"input": 0.15, "output": 0.60},
        "claude-sonnet-4-6": {"input": 3.00, "output": 15.00},
        "claude-haiku-4-5": {"input": 0.80, "output": 4.00},
    }

    def __init__(self):
        self.total_input_tokens = 0
        self.total_output_tokens = 0
        self.total_cost_usd = 0.0
        # Model name per in-flight run_id, so concurrent calls don't overwrite each other.
        self._models_by_run: dict[str, str] = {}

    @staticmethod
    def _detect_model(serialized: Union[dict, None], kwargs: dict) -> str:
        """Best-effort model name from the serialized model or invocation params ("" if unknown)."""
        sources = (
            (serialized or {}).get("kwargs") or {},
            kwargs.get("invocation_params") or {},
        )
        for source in sources:
            for key in ("model_name", "model"):
                if source.get(key):
                    return str(source[key])
        return ""

    def _price_for(self, model: str) -> dict:
        """Return the PRICING entry for ``model`` (exact, then longest prefix, then gpt-4o rates)."""
        if model in self.PRICING:
            return self.PRICING[model]
        prefixes = [name for name in self.PRICING if model.startswith(name)]
        if prefixes:
            return self.PRICING[max(prefixes, key=len)]
        return {"input": 2.50, "output": 10.00}

    @staticmethod
    def _extract_usage(response: LLMResult) -> tuple[int, int]:
        """Return ``(input_tokens, output_tokens)`` for one LLM call."""
        input_tokens = output_tokens = 0
        found = False
        # Preferred source: usage_metadata on chat messages (works across providers).
        for generation_list in response.generations:
            for generation in generation_list:
                message = getattr(generation, "message", None)
                usage = getattr(message, "usage_metadata", None)
                if usage:
                    input_tokens += usage.get("input_tokens") or 0
                    output_tokens += usage.get("output_tokens") or 0
                    found = True
        if found:
            return input_tokens, output_tokens
        # Fallback: OpenAI-style llm_output, which is Optional (None when streaming).
        token_usage = (response.llm_output or {}).get("token_usage") or {}
        return (
            token_usage.get("prompt_tokens") or 0,
            token_usage.get("completion_tokens") or 0,
        )

    def on_llm_start(self, serialized: dict, prompts: list[str], **kwargs) -> None:
        self._models_by_run[str(kwargs.get("run_id"))] = self._detect_model(serialized, kwargs)

    def on_llm_end(self, response: LLMResult, **kwargs) -> None:
        model = str(
            self._models_by_run.pop(str(kwargs.get("run_id")), "")
            or (response.llm_output or {}).get("model_name")
            or ""
        )
        input_tokens, output_tokens = self._extract_usage(response)

        pricing = self._price_for(model)
        cost = (
            input_tokens * pricing["input"] / 1_000_000 +
            output_tokens * pricing["output"] / 1_000_000
        )

        self.total_input_tokens += input_tokens
        self.total_output_tokens += output_tokens
        self.total_cost_usd += cost

    def on_llm_error(self, error: Exception, **kwargs) -> None:
        # Failed calls are not billed here; just drop the pending model entry.
        self._models_by_run.pop(str(kwargs.get("run_id")), None)

    def report(self) -> dict:
        """Summarize accumulated token counts and estimated cost (USD, 6 decimals)."""
        return {
            "total_input_tokens": self.total_input_tokens,
            "total_output_tokens": self.total_output_tokens,
            "total_tokens": self.total_input_tokens + self.total_output_tokens,
            "total_cost_usd": round(self.total_cost_usd, 6),
        }


cost_tracker = CostTrackingCallback()

# The same handler instance accumulates totals across both invocations.
chain.invoke({"question": "What is warfarin?"}, config=RunnableConfig(callbacks=[cost_tracker]))
chain.invoke({"question": "What is metformin?"}, config=RunnableConfig(callbacks=[cost_tracker]))

print(cost_tracker.report())
# Example output (token counts vary per run; cost shown at gpt-4o prices):
# {'total_input_tokens': 146, 'total_output_tokens': 89, 'total_tokens': 235, 'total_cost_usd': 0.001255}

Streaming Callback

Stream tokens to a UI or websocket as they're generated:

Python
from langchain_core.callbacks import StreamingStdOutCallbackHandler

# Built-in handler: writes every token to stdout the moment it is generated.
streaming_handler = StreamingStdOutCallbackHandler()

# streaming=True makes the model emit on_llm_new_token events for the handler.
streaming_model = ChatOpenAI(model="gpt-4o", streaming=True, callbacks=[streaming_handler])

# The answer appears incrementally while the request is still running.
streaming_model.invoke("Explain warfarin mechanism in detail.")


# Custom: stream to a queue (for WebSocket or SSE)
import queue

class StreamToQueueCallback(BaseCallbackHandler):
    """Stream LLM tokens into a thread-safe ``queue.Queue``.

    The producer (the thread running the model) puts each token as a ``str``.
    When the LLM call finishes it puts ``None``; if the call fails it puts an
    ``Exception``. A consumer reads until it sees ``None`` or an exception.

    Note: this is a blocking queue, not an ``asyncio.Queue``; do not call
    ``get()`` directly inside an event loop. ``on_llm_end`` fires once per LLM
    call, so in multi-call chains (e.g. agents) the sentinel arrives after the
    first call.
    """

    def __init__(self, q: queue.Queue):
        self.q = q

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        self.q.put(token)

    def on_llm_end(self, response: LLMResult, **kwargs) -> None:
        self.q.put(None)   # Sentinel: signal stream is complete

    def on_llm_error(self, error: Exception, **kwargs) -> None:
        # Forward the original exception so the consumer re-raises it with its
        # real type and traceback. Non-Exception errors (e.g. KeyboardInterrupt)
        # are wrapped so an ``isinstance(x, Exception)`` check still sees them.
        self.q.put(error if isinstance(error, Exception) else RuntimeError(repr(error)))


# Usage with a web framework: the model runs in a worker thread (producer)
# while the request handler drains the queue (consumer).
import threading

token_queue = queue.Queue()
stream_callback = StreamToQueueCallback(token_queue)


def run_chain_in_thread():
    try:
        ChatOpenAI(model="gpt-4o", streaming=True, callbacks=[stream_callback]).invoke(
            "Explain warfarin mechanism."
        )
    except Exception as exc:
        # Errors raised before the LLM run starts never reach on_llm_error;
        # forward them so the consumer loop below cannot block forever.
        # (If on_llm_error already queued it, the consumer stops at the first copy.)
        token_queue.put(exc)


# daemon=True: a consumer that bails out early cannot keep the process alive.
thread = threading.Thread(target=run_chain_in_thread, daemon=True)
thread.start()

# Consumer: in a real app, send each token to the WebSocket/SSE response
# instead of printing it.
while True:
    token = token_queue.get()
    if token is None:
        break
    if isinstance(token, Exception):
        raise token
    print(token, end="", flush=True)

thread.join()

Audit Logging Callback

Python
import logging
import json

# Dedicated audit logger. Records are emitted at INFO, which the default
# WARNING root level suppresses — configure logging (e.g. basicConfig) to see them.
logger = logging.getLogger(name="langchain_audit")

class AuditCallback(BaseCallbackHandler):
    """Structured JSON audit log for every LLM call and tool use.

    Each event is one JSON object written to ``logger`` at INFO level and
    tagged with ``session_id`` and ``user_id``. Failed LLM and tool calls are
    logged as ``llm_call_error`` / ``tool_error`` events, so a failure is
    recorded instead of leaving a dangling start event.

    Args:
        session_id: Identifier of the session being audited.
        user_id: Identifier of the user on whose behalf the calls run.
    """

    def __init__(self, session_id: str, user_id: str):
        self.session_id = session_id
        self.user_id = user_id

    def _log(self, event: str, data: dict) -> None:
        """Emit one JSON audit record; session and user IDs are added to every event."""
        logger.info(json.dumps({
            "event": event,
            "session_id": self.session_id,
            "user_id": self.user_id,
            **data,
        }))

    def on_llm_start(self, serialized: dict, prompts: list[str], **kwargs) -> None:
        # serialized can be None for some callers; guard before reading it.
        model_kwargs = (serialized or {}).get("kwargs") or {}
        self._log("llm_call_start", {
            "model": model_kwargs.get("model_name"),
            "prompt_chars": sum(len(p) for p in prompts),
        })

    def on_llm_end(self, response: LLMResult, **kwargs) -> None:
        # llm_output is Optional (e.g. None for streamed responses); log nulls instead of crashing.
        usage = (response.llm_output or {}).get("token_usage") or {}
        self._log("llm_call_end", {
            "prompt_tokens": usage.get("prompt_tokens"),
            "completion_tokens": usage.get("completion_tokens"),
        })

    def on_llm_error(self, error: Exception, **kwargs) -> None:
        self._log("llm_call_error", {
            "error_type": type(error).__name__,
            "error": str(error)[:200],   # Truncate for log size
        })

    def on_tool_start(self, serialized: dict, input_str: str, **kwargs) -> None:
        self._log("tool_call", {
            "tool": (serialized or {}).get("name"),
            "input": str(input_str)[:200],   # Truncate for log size
        })

    def on_tool_end(self, output: str, **kwargs) -> None:
        # Log only the size of the output, not its content.
        self._log("tool_result", {
            "output_chars": len(str(output)),
        })

    def on_tool_error(self, error: Exception, **kwargs) -> None:
        self._log("tool_error", {
            "error_type": type(error).__name__,
            "error": str(error)[:200],   # Truncate for log size
        })


# Attach per request: a fresh handler carries this request's session and user IDs.
audit = AuditCallback(user_id="pharmacist_456", session_id="sess_123")
result = chain.invoke(
    {"question": "What is the warfarin dose for AFib?"},
    config=RunnableConfig(callbacks=[audit]),
)

Chaining Multiple Callbacks

Python
# Combine callbacks: every handler in the list receives every event.
callbacks: list[BaseCallbackHandler] = [
    TimingCallbackHandler(),
    CostTrackingCallback(),
    AuditCallback(user_id="user_1", session_id="sess_abc"),
]

result = chain.invoke(
    {"question": "What are warfarin interactions?"},
    config=RunnableConfig(callbacks=callbacks),
)

Callback Events Reference

| Method | Fires When | Key Args |
|---|---|---|
| on_llm_start | LLM call begins (also used for chat models when on_chat_model_start is not implemented) | serialized (model info), prompts |
| on_chat_model_start | Chat model call begins | serialized, messages: list[list[BaseMessage]] |
| on_llm_new_token | Each streamed token arrives | token: str |
| on_llm_end | LLM call completes | response: LLMResult (token usage in message usage_metadata and/or llm_output, which may be None) |
| on_llm_error | LLM call fails | error: Exception |
| on_chain_start | Chain step begins | serialized (may be None for some runnables), inputs: dict, run_id |
| on_chain_end | Chain step completes | outputs: dict, run_id |
| on_chain_error | Chain step fails | error: Exception |
| on_tool_start | Tool invocation begins | serialized (tool info), input_str |
| on_tool_end | Tool invocation completes | output (often str, but may be another object such as a ToolMessage) |
| on_tool_error | Tool invocation fails | error: Exception |
| on_retriever_start | Retriever query begins | serialized, query |
| on_retriever_end | Retriever returns docs | documents: list[Document] |