LangChain Mastery · Lesson 31 of 33
Callbacks: Monitoring Tokens and Latency
What are Callbacks?
Callbacks let you hook into LangChain's event lifecycle without modifying chain logic. They fire at:
- LLM calls start and end
- Chain steps start and end
- Tool invocations start and end
- Token-level streaming events
- Errors at any stage
This lets you add logging, cost tracking, streaming UI updates, and custom observability — all outside your main chain code.
BaseCallbackHandler: The Core Interface
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from typing import Any, Union
import time
class TimingCallbackHandler(BaseCallbackHandler):
"""Track latency for each LLM call and each chain step."""
def __init__(self):
self._llm_start: float = 0
self._chain_starts: dict[str, float] = {}
self.llm_calls: list[dict] = []
def on_llm_start(self, serialized: dict, prompts: list[str], **kwargs) -> None:
self._llm_start = time.time()
model_name = serialized.get("kwargs", {}).get("model_name", "unknown")
print(f"[LLM] Starting: {model_name}")
print(f"[LLM] Prompt length: {sum(len(p) for p in prompts)} chars")
def on_llm_end(self, response: LLMResult, **kwargs) -> None:
latency_ms = round((time.time() - self._llm_start) * 1000)
token_usage = response.llm_output.get("token_usage", {})
self.llm_calls.append({
"latency_ms": latency_ms,
"prompt_tokens": token_usage.get("prompt_tokens", 0),
"completion_tokens": token_usage.get("completion_tokens", 0),
})
print(f"[LLM] Done in {latency_ms}ms | Tokens: {token_usage}")
def on_llm_error(self, error: Exception, **kwargs) -> None:
print(f"[LLM] Error: {error}")
def on_chain_start(self, serialized: dict, inputs: dict, run_id, **kwargs) -> None:
chain_name = serialized.get("name", "unknown")
self._chain_starts[str(run_id)] = time.time()
print(f"[Chain] Start: {chain_name}")
def on_chain_end(self, outputs: dict, run_id, **kwargs) -> None:
start = self._chain_starts.pop(str(run_id), time.time())
latency_ms = round((time.time() - start) * 1000)
print(f"[Chain] Done in {latency_ms}ms")
def on_chain_error(self, error: Exception, **kwargs) -> None:
print(f"[Chain] Error: {error}")
def on_tool_start(self, serialized: dict, input_str: str, **kwargs) -> None:
tool_name = serialized.get("name", "unknown")
print(f"[Tool] Calling: {tool_name}({input_str[:80]})")
def on_tool_end(self, output: str, **kwargs) -> None:
print(f"[Tool] Result: {output[:100]}")
def on_tool_error(self, error: Exception, **kwargs) -> None:
print(f"[Tool] Error: {error}")Attaching Callbacks
Callbacks can be attached at three levels:
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableConfig
timing_callback = TimingCallbackHandler()
# Level 1: Per-invocation (most common)
chain = (
ChatPromptTemplate.from_messages([("human", "{question}")])
| ChatOpenAI(model="gpt-4o", temperature=0)
| StrOutputParser()
)
result = chain.invoke(
{"question": "What is warfarin?"},
config=RunnableConfig(callbacks=[timing_callback]),
)
# Level 2: Constructor (always applies for this chain instance)
model_with_callback = ChatOpenAI(
model="gpt-4o",
temperature=0,
callbacks=[timing_callback], # Fires on every call to this model instance
)
# Level 3: Global (applies to all LangChain calls — use sparingly)
from langchain.callbacks import set_handler
# set_handler(timing_callback) # Uncomment to enable globallyCost Tracking Callback
class CostTrackingCallback(BaseCallbackHandler):
    """Accumulate token usage and estimated USD cost across LLM calls.

    Costs are estimates computed from a static price table (USD per 1M
    tokens). Models that match no table entry are billed at gpt-4o rates, so
    keep the table in sync with the models you actually call.
    """

    # USD per 1M tokens (input / output). Provider prices change — verify
    # against the official pricing pages before relying on these figures.
    PRICING = {
        "gpt-4o": {"input": 2.50, "output": 10.00},
        "gpt-4o-mini": {"input": 0.15, "output": 0.60},
        "claude-sonnet-4-6": {"input": 3.00, "output": 15.00},
        "claude-haiku-4-5": {"input": 1.00, "output": 5.00},
    }
    # Rates applied when a model matches no pricing key (gpt-4o rates).
    _FALLBACK_PRICE = {"input": 2.50, "output": 10.00}

    def __init__(self, pricing: Union[dict[str, dict[str, float]], None] = None):
        """Create a tracker; pass ``pricing`` to replace the default PRICING table."""
        super().__init__()
        self.pricing = dict(self.PRICING if pricing is None else pricing)
        self.total_input_tokens = 0
        self.total_output_tokens = 0
        self.total_cost_usd = 0.0
        # Most recently started model; used if a run_id has no recorded model.
        self._current_model = "gpt-4o"
        # run_id (as str) -> model name, so concurrent runs are priced correctly.
        self._run_models: dict[str, str] = {}

    @staticmethod
    def _model_name(
        serialized: Union[dict[str, Any], None], callback_kwargs: dict[str, Any]
    ) -> Union[str, None]:
        """Best-effort model name from the on_llm_start arguments (None if absent)."""
        metadata = callback_kwargs.get("metadata") or {}
        params = callback_kwargs.get("invocation_params") or {}
        ctor_kwargs = (serialized or {}).get("kwargs") or {}
        # ls_model_name is LangChain's provider-neutral key; the remaining keys
        # differ by provider (ChatOpenAI serializes `model_name`, some use `model`).
        for candidate in (
            metadata.get("ls_model_name"),
            params.get("model"),
            params.get("model_name"),
            ctor_kwargs.get("model_name"),
            ctor_kwargs.get("model"),
        ):
            if candidate:
                return str(candidate)
        return None

    def _price_for(self, model: str) -> dict[str, float]:
        """Return per-1M-token prices for ``model``, matching dated variants by prefix."""
        if model in self.pricing:
            return self.pricing[model]
        # Providers often report dated snapshots (e.g. "gpt-4o-2024-08-06").
        # Take the longest matching key so "gpt-4o-mini-..." isn't priced as "gpt-4o".
        matches = [key for key in self.pricing if model.startswith(key + "-")]
        if matches:
            return self.pricing[max(matches, key=len)]
        return self._FALLBACK_PRICE

    @staticmethod
    def _token_usage(response: LLMResult) -> dict:
        """Return OpenAI-style token usage for ``response``, or {} if none was reported."""
        # llm_output is Optional (frequently None for streamed responses).
        usage = (response.llm_output or {}).get("token_usage")
        if usage:
            return usage
        # Fall back to LangChain's standard usage_metadata on chat messages, if present.
        for generation_list in response.generations:
            for generation in generation_list:
                message = getattr(generation, "message", None)
                meta = getattr(message, "usage_metadata", None)
                if meta:
                    return {
                        "prompt_tokens": meta.get("input_tokens", 0),
                        "completion_tokens": meta.get("output_tokens", 0),
                    }
        return {}

    def on_llm_start(
        self,
        serialized: Union[dict[str, Any], None],
        prompts: list[str],
        run_id: Any = None,
        **kwargs,
    ) -> None:
        model = self._model_name(serialized, kwargs) or "gpt-4o"
        self._current_model = model
        self._run_models[str(run_id)] = model

    def on_llm_end(self, response: LLMResult, run_id: Any = None, **kwargs) -> None:
        model = self._run_models.pop(str(run_id), self._current_model)
        usage = self._token_usage(response)
        input_tokens = usage.get("prompt_tokens") or 0
        output_tokens = usage.get("completion_tokens") or 0
        pricing = self._price_for(model)
        cost = (
            input_tokens * pricing["input"] / 1_000_000 +
            output_tokens * pricing["output"] / 1_000_000
        )
        self.total_input_tokens += input_tokens
        self.total_output_tokens += output_tokens
        self.total_cost_usd += cost

    def on_llm_error(self, error: BaseException, run_id: Any = None, **kwargs) -> None:
        # Forget the failed run's model so _run_models doesn't grow unbounded.
        self._run_models.pop(str(run_id), None)

    def report(self) -> dict:
        """Return accumulated token counts and cost (USD, rounded to 6 decimals)."""
        return {
            "total_input_tokens": self.total_input_tokens,
            "total_output_tokens": self.total_output_tokens,
            "total_tokens": self.total_input_tokens + self.total_output_tokens,
            "total_cost_usd": round(self.total_cost_usd, 6),
        }
cost_tracker = CostTrackingCallback()
chain.invoke({"question": "What is warfarin?"}, config=RunnableConfig(callbacks=[cost_tracker]))
chain.invoke({"question": "What is metformin?"}, config=RunnableConfig(callbacks=[cost_tracker]))
print(cost_tracker.report())
# {'total_input_tokens': 145, 'total_output_tokens': 89, 'total_tokens': 234, 'total_cost_usd': 0.001257}Streaming Callback
Stream tokens to a UI or WebSocket as they're generated (create the model with `streaming=True` so that `on_llm_new_token` fires during `invoke()`):
from langchain_core.callbacks import StreamingStdOutCallbackHandler
# Built-in handler: writes every token to stdout the moment it is generated
streaming_handler = StreamingStdOutCallbackHandler()
# streaming=True is what makes invoke() deliver tokens to the handler incrementally
streaming_model = ChatOpenAI(
    callbacks=[streaming_handler],
    streaming=True,
    model="gpt-4o",
)
# The handler prints output piece by piece while this call is running
streaming_model.invoke(input="Explain warfarin mechanism in detail.")
# Custom: stream to a queue (for WebSocket or SSE)
import queue
class StreamToQueueCallback(BaseCallbackHandler):
    """Push streamed LLM tokens onto a thread-safe ``queue.Queue``.

    This is the producer half of a producer/consumer stream. The LLM, usually
    running in a worker thread, calls these hooks, and a consumer (e.g. a
    WebSocket or SSE handler) reads from the same queue. Items put on the queue:

    * ``str`` — one streamed token;
    * ``None`` — sentinel: the LLM call finished;
    * an ``Exception`` — the LLM call failed; the consumer should raise it.

    ``queue.Queue`` is thread-safe but not awaitable. From asyncio code, read it
    through ``loop.run_in_executor``, or adapt this class to an ``asyncio.Queue``
    fed via ``loop.call_soon_threadsafe``.

    A sentinel is sent at the end of *every* LLM call. A consumer that stops at
    the first ``None`` therefore sees only the first call of a multi-call
    chain or agent.
    """

    def __init__(self, q: queue.Queue):
        super().__init__()
        self.q = q

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        self.q.put(token)

    def on_llm_end(self, response: LLMResult, **kwargs) -> None:
        self.q.put(None)  # Sentinel: signal stream is complete

    def on_llm_error(self, error: BaseException, **kwargs) -> None:
        # Forward the original exception so the consumer sees its real type and
        # traceback. Non-Exception BaseExceptions are wrapped so consumers that
        # test isinstance(item, Exception) still recognize the failure.
        self.q.put(error if isinstance(error, Exception) else Exception(str(error)))
# Usage with a web framework
token_queue = queue.Queue()
stream_callback = StreamToQueueCallback(token_queue)
import threading
def run_chain_in_thread():
ChatOpenAI(model="gpt-4o", streaming=True, callbacks=[stream_callback]).invoke(
"Explain warfarin mechanism."
)
thread = threading.Thread(target=run_chain_in_thread)
thread.start()
# Consumer: yield tokens to a WebSocket
while True:
token = token_queue.get()
if token is None:
break
if isinstance(token, Exception):
raise token
print(token, end="", flush=True)Audit Logging Callback
import logging
import json
logger = logging.getLogger("langchain_audit")


class AuditCallback(BaseCallbackHandler):
    """Write a structured JSON audit record for every LLM call and tool use.

    Each record is logged at INFO level on the ``langchain_audit`` logger. It
    contains ``event``, ``session_id`` and ``user_id``, plus ``run_id`` and
    ``parent_run_id`` when LangChain supplies them, so that the start, end and
    error records of one call can be correlated.
    """

    def __init__(self, session_id: str, user_id: str):
        super().__init__()
        self.session_id = session_id
        self.user_id = user_id

    def _log(
        self, event: str, data: dict, callback_kwargs: Union[dict[str, Any], None] = None
    ) -> None:
        """Emit one JSON record; run ids are taken from ``callback_kwargs`` if present."""
        record: dict[str, Any] = {
            "event": event,
            "session_id": self.session_id,
            "user_id": self.user_id,
        }
        for key in ("run_id", "parent_run_id"):
            value = (callback_kwargs or {}).get(key)
            if value is not None:
                record[key] = str(value)  # UUIDs are not JSON-serializable
        record.update(data)
        logger.info(json.dumps(record))

    @staticmethod
    def _token_usage(response: LLMResult) -> dict:
        """Return OpenAI-style token usage for ``response``, or {} if none was reported."""
        # llm_output is Optional (frequently None for streamed responses).
        usage = (response.llm_output or {}).get("token_usage")
        if usage:
            return usage
        # Fall back to LangChain's standard usage_metadata on chat messages, if present.
        for generation_list in response.generations:
            for generation in generation_list:
                message = getattr(generation, "message", None)
                meta = getattr(message, "usage_metadata", None)
                if meta:
                    return {
                        "prompt_tokens": meta.get("input_tokens"),
                        "completion_tokens": meta.get("output_tokens"),
                    }
        return {}

    def on_llm_start(
        self,
        serialized: Union[dict[str, Any], None],
        prompts: list[str],
        **kwargs,
    ) -> None:
        ctor_kwargs = (serialized or {}).get("kwargs") or {}
        self._log("llm_call_start", {
            "model": (kwargs.get("metadata") or {}).get("ls_model_name")
            or ctor_kwargs.get("model_name"),
            "prompt_chars": sum(len(p) for p in prompts),
        }, kwargs)

    def on_llm_end(self, response: LLMResult, **kwargs) -> None:
        usage = self._token_usage(response)
        self._log("llm_call_end", {
            "prompt_tokens": usage.get("prompt_tokens"),
            "completion_tokens": usage.get("completion_tokens"),
        }, kwargs)

    def on_llm_error(self, error: BaseException, **kwargs) -> None:
        self._log("llm_call_error", {
            "error_type": type(error).__name__,
            "error": str(error)[:200],  # Truncate for log size
        }, kwargs)

    def on_tool_start(
        self,
        serialized: Union[dict[str, Any], None],
        input_str: str,
        **kwargs,
    ) -> None:
        self._log("tool_call", {
            "tool": (serialized or {}).get("name"),
            "input": input_str[:200],  # Truncate for log size
        }, kwargs)

    def on_tool_end(self, output: Any, **kwargs) -> None:
        self._log("tool_result", {
            "output_chars": len(str(output)),
        }, kwargs)

    def on_tool_error(self, error: BaseException, **kwargs) -> None:
        self._log("tool_error", {
            "error_type": type(error).__name__,
            "error": str(error)[:200],  # Truncate for log size
        }, kwargs)
# Attach per request
audit = AuditCallback(session_id="sess_123", user_id="pharmacist_456")
result = chain.invoke(
{"question": "What is the warfarin dose for AFib?"},
config=RunnableConfig(callbacks=[audit]),
)Chaining Multiple Callbacks
# Combine callbacks — each receives all events
callbacks = [
TimingCallbackHandler(),
CostTrackingCallback(),
AuditCallback(session_id="sess_abc", user_id="user_1"),
]
result = chain.invoke(
{"question": "What are warfarin interactions?"},
config=RunnableConfig(callbacks=callbacks),
)Callback Events Reference
| Method | Fires When | Key Args |
|---|---|---|
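| on_llm_start | LLM call begins | serialized (model info), prompts |
| on_chat_model_start | Chat model call begins (if not implemented, LangChain falls back to on_llm_start with stringified messages) | serialized, messages: list[list[BaseMessage]] |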
| on_llm_new_token | Each streamed token arrives | token: str |
| on_llm_end | LLM call completes | response: LLMResult (includes token usage) |
| on_llm_error | LLM call fails | error: Exception |
| on_chain_start | Chain step begins | serialized (may be None — read the run name from kwargs["name"]), inputs, run_id |
| on_chain_end | Chain step completes | outputs: dict, run_id |
| on_chain_error | Chain step fails | error: Exception |
| on_tool_start | Tool invocation begins | serialized (tool info), input_str |
| on_tool_end | Tool invocation completes | output (often str, but may be another object such as a ToolMessage — use str(output)) |
| on_tool_error | Tool invocation fails | error: Exception |
| on_retriever_start | Retriever query begins | serialized, query |
| on_retriever_end | Retriever returns docs | documents: list[Document] |