Callbacks: Hooking into LangChain Events
Use LangChain callbacks for logging, cost tracking, streaming progress, and custom observability. Implement BaseCallbackHandler for chain, LLM, and tool events.
What are Callbacks?
Callbacks let you hook into LangChain's event lifecycle without modifying chain logic. They fire at:
- LLM calls start and end
- Chain steps start and end
- Tool invocations start and end
- Token-level streaming events
- Errors at any stage
This lets you add logging, cost tracking, streaming UI updates, and custom observability — all outside your main chain code.
BaseCallbackHandler: The Core Interface
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from typing import Any, Union
import time
class TimingCallbackHandler(BaseCallbackHandler):
    """Track latency for each LLM call and each chain step.

    Every event is printed as it fires. Each finished LLM call is also
    appended to ``llm_calls`` as a dict with the keys ``latency_ms``,
    ``prompt_tokens`` and ``completion_tokens``.
    """

    def __init__(self):
        # Start time of the most recent LLM call (single slot, so overlapping
        # LLM calls on one handler instance overwrite each other).
        self._llm_start: float = 0
        # Chain start times keyed by str(run_id), so nested steps do not clash.
        self._chain_starts: dict[str, float] = {}
        self.llm_calls: list[dict] = []

    def on_llm_start(self, serialized: dict, prompts: list[str], **kwargs) -> None:
        """Record the start time and print the model name and prompt size."""
        # perf_counter is monotonic, so latency is immune to wall-clock jumps.
        self._llm_start = time.perf_counter()
        # `serialized` can arrive as None; treat that as "no metadata".
        model_name = (serialized or {}).get("kwargs", {}).get("model_name", "unknown")
        print(f"[LLM] Starting: {model_name}")
        print(f"[LLM] Prompt length: {sum(len(p) for p in prompts)} chars")

    def on_llm_end(self, response: LLMResult, **kwargs) -> None:
        """Record latency and token usage for the LLM call that just ended."""
        latency_ms = round((time.perf_counter() - self._llm_start) * 1000)
        # llm_output is optional on LLMResult; fall back to empty usage
        # instead of raising AttributeError on None.
        token_usage = (response.llm_output or {}).get("token_usage") or {}
        self.llm_calls.append({
            "latency_ms": latency_ms,
            "prompt_tokens": token_usage.get("prompt_tokens", 0),
            "completion_tokens": token_usage.get("completion_tokens", 0),
        })
        print(f"[LLM] Done in {latency_ms}ms | Tokens: {token_usage}")

    def on_llm_error(self, error: Exception, **kwargs) -> None:
        """Print the error raised by an LLM call."""
        print(f"[LLM] Error: {error}")

    def on_chain_start(self, serialized: dict, inputs: dict, run_id, **kwargs) -> None:
        """Remember when the chain step identified by ``run_id`` started."""
        # Prefer the serialized name; otherwise use a `name` keyword if the
        # caller supplied one (NOTE(review): confirm the callback manager
        # passes `name` in your langchain-core version).
        chain_name = (serialized or {}).get("name") or kwargs.get("name") or "unknown"
        self._chain_starts[str(run_id)] = time.perf_counter()
        print(f"[Chain] Start: {chain_name}")

    def on_chain_end(self, outputs: dict, run_id, **kwargs) -> None:
        """Print the latency of the chain step identified by ``run_id``."""
        # Unknown run_id (no matching start event) yields a ~0ms latency.
        start = self._chain_starts.pop(str(run_id), time.perf_counter())
        latency_ms = round((time.perf_counter() - start) * 1000)
        print(f"[Chain] Done in {latency_ms}ms")

    def on_chain_error(self, error: Exception, **kwargs) -> None:
        """Print the error raised by a chain step."""
        print(f"[Chain] Error: {error}")

    def on_tool_start(self, serialized: dict, input_str: str, **kwargs) -> None:
        """Print the tool name and the first 80 characters of its input."""
        tool_name = (serialized or {}).get("name", "unknown")
        print(f"[Tool] Calling: {tool_name}({str(input_str)[:80]})")

    def on_tool_end(self, output: str, **kwargs) -> None:
        """Print the first 100 characters of the tool result."""
        # Coerce first: a non-string tool result would not support slicing.
        print(f"[Tool] Result: {str(output)[:100]}")

    def on_tool_error(self, error: Exception, **kwargs) -> None:
        """Print the error raised by a tool invocation."""
        print(f"[Tool] Error: {error}")


# Attaching Callbacks
Callbacks can be attached at three levels:
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableConfig
timing_callback = TimingCallbackHandler()

# Level 1: Per-invocation (most common)
chain = (
    ChatPromptTemplate.from_messages([("human", "{question}")])
    | ChatOpenAI(model="gpt-4o", temperature=0)
    | StrOutputParser()
)
result = chain.invoke(
    {"question": "What is warfarin?"},
    config=RunnableConfig(callbacks=[timing_callback]),
)

# Level 2: Constructor (always applies for this model instance)
model_with_callback = ChatOpenAI(
    model="gpt-4o",
    temperature=0,
    callbacks=[timing_callback],  # Fires on every call to this model instance
)

# Level 3: Global (applies to all LangChain calls — use sparingly)
# NOTE(review): set_handler belongs to the legacy callbacks API and does not
# appear to be exported by current langchain releases. The import is therefore
# kept commented out together with the call, so this snippet does not fail
# with ImportError; confirm against your installed version before enabling.
# from langchain.callbacks import set_handler
# set_handler(timing_callback)  # Uncomment both lines to enable globally

# Cost Tracking Callback
class CostTrackingCallback(BaseCallbackHandler):
    """Track token usage and estimated cost per run.

    Totals accumulate across every LLM call this handler instance observes;
    call ``report()`` to read them.
    """

    # Prices per 1M tokens (input / output)
    PRICING = {
        "gpt-4o": {"input": 2.50, "output": 10.00},
        "gpt-4o-mini": {"input": 0.15, "output": 0.60},
        "claude-sonnet-4-6": {"input": 3.00, "output": 15.00},
        "claude-haiku-4-5": {"input": 0.80, "output": 4.00},
    }

    def __init__(self):
        self.total_input_tokens = 0
        self.total_output_tokens = 0
        self.total_cost_usd = 0.0
        # Model of the most recently started LLM call; used to pick a price.
        self._current_model = "gpt-4o"

    def on_llm_start(self, serialized: dict, prompts: list[str], **kwargs) -> None:
        """Remember which model is being called so on_llm_end can price it."""
        # `serialized` can arrive as None; fall back to the default model.
        self._current_model = (serialized or {}).get("kwargs", {}).get("model_name", "gpt-4o")

    def on_llm_end(self, response: LLMResult, **kwargs) -> None:
        """Add the finished call's token counts and estimated cost to the totals."""
        # llm_output is optional on LLMResult; a missing payload counts as
        # zero tokens instead of raising AttributeError on None.
        usage = (response.llm_output or {}).get("token_usage") or {}
        # `or 0` also covers providers that report a count as None.
        input_tokens = usage.get("prompt_tokens") or 0
        output_tokens = usage.get("completion_tokens") or 0
        # Models missing from PRICING are charged at the gpt-4o rate.
        pricing = self.PRICING.get(self._current_model, {"input": 2.50, "output": 10.00})
        cost = (
            input_tokens * pricing["input"] / 1_000_000 +
            output_tokens * pricing["output"] / 1_000_000
        )
        self.total_input_tokens += input_tokens
        self.total_output_tokens += output_tokens
        self.total_cost_usd += cost

    def report(self) -> dict:
        """Return the accumulated token counts and cost (USD, 6 decimals)."""
        return {
            "total_input_tokens": self.total_input_tokens,
            "total_output_tokens": self.total_output_tokens,
            "total_tokens": self.total_input_tokens + self.total_output_tokens,
            "total_cost_usd": round(self.total_cost_usd, 6),
        }
cost_tracker = CostTrackingCallback()

# Send both questions through the same tracker so its totals accumulate.
for question in ("What is warfarin?", "What is metformin?"):
    chain.invoke({"question": question}, config=RunnableConfig(callbacks=[cost_tracker]))

print(cost_tracker.report())
# Example output (actual numbers vary per run):
# {'total_input_tokens': 145, 'total_output_tokens': 89, 'total_tokens': 234, 'total_cost_usd': 0.001257}

# Streaming Callback
Stream tokens to a UI or websocket as they're generated:
from langchain_core.callbacks import StreamingStdOutCallbackHandler
# Built-in handler: writes each token to stdout as soon as it arrives.
streaming_handler = StreamingStdOutCallbackHandler()
streaming_model = ChatOpenAI(model="gpt-4o", streaming=True, callbacks=[streaming_handler])

# The answer is printed incrementally while the call is still in progress.
streaming_model.invoke("Explain warfarin mechanism in detail.")
# Custom: stream to a queue (for WebSocket or SSE)
import queue
class StreamToQueueCallback(BaseCallbackHandler):
    """Stream LLM tokens into a thread-safe ``queue.Queue``.

    Items the consumer will see on the queue:
      * each generated token, as a ``str``;
      * ``None`` once the LLM call completes (end-of-stream sentinel);
      * an ``Exception`` instance if the LLM call fails.

    ``queue.Queue`` is a blocking, thread-safe queue (not an asyncio queue),
    so read from it in a worker thread rather than directly in an event loop.
    """

    def __init__(self, q: queue.Queue):
        self.q = q

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        """Queue one streamed token."""
        self.q.put(token)

    def on_llm_end(self, response: LLMResult, **kwargs) -> None:
        """Queue the ``None`` sentinel to signal that the stream is complete."""
        self.q.put(None)

    def on_llm_error(self, error: Exception, **kwargs) -> None:
        """Queue the failure so the consumer can re-raise it."""
        # Forward the original exception so the consumer keeps its type and
        # traceback. Anything that is not an Exception subclass is wrapped,
        # so consumers can still rely on isinstance(item, Exception).
        self.q.put(error if isinstance(error, Exception) else Exception(str(error)))
# Usage with a web framework
token_queue = queue.Queue()
stream_callback = StreamToQueueCallback(token_queue)

import threading


def run_chain_in_thread():
    """Run the streaming model call; stream_callback feeds token_queue."""
    model = ChatOpenAI(model="gpt-4o", streaming=True, callbacks=[stream_callback])
    model.invoke("Explain warfarin mechanism.")


# Produce tokens on a background thread so this thread is free to consume them.
thread = threading.Thread(target=run_chain_in_thread)
thread.start()
# Consumer: yield tokens to a WebSocket
# Block on the queue until the None sentinel marks the end of the stream.
while (token := token_queue.get()) is not None:
    # A queued exception means the producer failed; surface it here.
    if isinstance(token, Exception):
        raise token
    print(token, end="", flush=True)

# Audit Logging Callback
import logging
import json
logger = logging.getLogger("langchain_audit")


class AuditCallback(BaseCallbackHandler):
    """Structured JSON audit log for every LLM call and tool use.

    Each event is written at INFO level to the ``langchain_audit`` logger as
    one JSON object holding the event name, ``session_id``, ``user_id`` and
    the event-specific fields.
    """

    def __init__(self, session_id: str, user_id: str):
        self.session_id = session_id
        self.user_id = user_id

    def _log(self, event: str, data: dict) -> None:
        """Emit one JSON audit record; keys in ``data`` override the base keys."""
        logger.info(json.dumps({
            "event": event,
            "session_id": self.session_id,
            "user_id": self.user_id,
            **data,
        }))

    def on_llm_start(self, serialized: dict, prompts: list[str], **kwargs) -> None:
        """Log the model name and the total prompt size in characters."""
        self._log("llm_call_start", {
            # `serialized` can arrive as None; the model is then logged as null.
            "model": (serialized or {}).get("kwargs", {}).get("model_name"),
            "prompt_chars": sum(len(p) for p in prompts),
        })

    def on_llm_end(self, response: LLMResult, **kwargs) -> None:
        """Log the token counts reported for the finished LLM call."""
        # llm_output is optional on LLMResult; log null counts instead of
        # raising AttributeError when it is missing.
        usage = (response.llm_output or {}).get("token_usage") or {}
        self._log("llm_call_end", {
            "prompt_tokens": usage.get("prompt_tokens"),
            "completion_tokens": usage.get("completion_tokens"),
        })

    def on_tool_start(self, serialized: dict, input_str: str, **kwargs) -> None:
        """Log the tool name and a truncated copy of its input."""
        self._log("tool_call", {
            "tool": (serialized or {}).get("name"),
            "input": str(input_str)[:200],  # Truncate for log size
        })

    def on_tool_end(self, output: str, **kwargs) -> None:
        """Log only the size of the tool result, not its content."""
        self._log("tool_result", {
            "output_chars": len(str(output)),
        })
# Attach per request: build a fresh AuditCallback carrying this request's
# session and user, then pass it through the invocation config.
audit = AuditCallback(user_id="pharmacist_456", session_id="sess_123")
result = chain.invoke({"question": "What is the warfarin dose for AFib?"}, config=RunnableConfig(callbacks=[audit]))

# Chaining Multiple Callbacks
# Combine callbacks — each receives all events
callbacks = [TimingCallbackHandler(), CostTrackingCallback()]
callbacks.append(AuditCallback(session_id="sess_abc", user_id="user_1"))

# One invocation, observed by all three handlers.
result = chain.invoke({"question": "What are warfarin interactions?"}, config=RunnableConfig(callbacks=callbacks))

# Callback Events Reference
| Method | Fires When | Key Args |
|---|---|---|
| on_llm_start | LLM call begins | serialized (model info), prompts |
| on_llm_new_token | Each streamed token arrives | token: str |
| on_llm_end | LLM call completes | response: LLMResult (token usage in llm_output, when the provider reports it) |
| on_llm_error | LLM call fails | error: Exception |
| on_chain_start | Chain step begins | serialized, inputs: dict, run_id |
| on_chain_end | Chain step completes | outputs: dict, run_id |
| on_chain_error | Chain step fails | error: Exception |
| on_tool_start | Tool invocation begins | serialized (tool info), input_str |
| on_tool_end | Tool invocation completes | output: str |
| on_tool_error | Tool invocation fails | error: Exception |
| on_retriever_start | Retriever query begins | serialized, query |
| on_retriever_end | Retriever returns docs | documents: list[Document] |
Found this helpful?
Leave a comment
Have a question, correction, or just found this helpful? Leave a note below.