Learnixo
Back to blog
AI Systems · Intermediate

Callbacks: Hooking into LangChain Events

Use LangChain callbacks for logging, cost tracking, streaming progress, and custom observability. Implement BaseCallbackHandler for chain, LLM, and tool events.

Asma Hafeez Khan · May 16, 2026 · 6 min read
Tags: LangChain · Callbacks · Logging · Observability · Streaming · Cost Tracking
Share: X

What are Callbacks?

Callbacks let you hook into LangChain's event lifecycle without modifying chain logic. They fire at:

  • LLM calls start and end
  • Chain steps start and end
  • Tool invocations start and end
  • Token-level streaming events
  • Errors at any stage

This lets you add logging, cost tracking, streaming UI updates, and custom observability — all outside your main chain code.


BaseCallbackHandler: The Core Interface

Python
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from typing import Any, Union
import time

class TimingCallbackHandler(BaseCallbackHandler):
    """Track latency for each LLM call and each chain step.

    Start times are stored per ``run_id`` so nested or concurrent runs (several
    LLM calls inside one chain, parallel branches) never overwrite each other's
    timer. Latency is measured with ``time.perf_counter()``, a monotonic clock
    that is unaffected by wall-clock adjustments.

    Attributes:
        llm_calls: One dict per finished LLM call with ``latency_ms``,
            ``prompt_tokens`` and ``completion_tokens``.
    """

    def __init__(self):
        super().__init__()
        self._llm_starts: dict[str, float] = {}
        self._chain_starts: dict[str, float] = {}
        self.llm_calls: list[dict] = []

    @staticmethod
    def _elapsed_ms(starts: dict[str, float], run_id: Any) -> int:
        """Pop the start time recorded for *run_id* and return elapsed ms.

        A run with no recorded start (e.g. the handler was attached mid-run)
        reports 0 ms instead of raising.
        """
        now = time.perf_counter()
        return round((now - starts.pop(str(run_id), now)) * 1000)

    def on_llm_start(self, serialized: dict, prompts: list[str], run_id: Any = None, **kwargs) -> None:
        self._llm_starts[str(run_id)] = time.perf_counter()
        init_kwargs = (serialized or {}).get("kwargs", {})
        model_name = init_kwargs.get("model_name") or init_kwargs.get("model") or "unknown"
        print(f"[LLM] Starting: {model_name}")
        print(f"[LLM] Prompt length: {sum(len(p) for p in prompts)} chars")

    def on_llm_end(self, response: LLMResult, run_id: Any = None, **kwargs) -> None:
        latency_ms = self._elapsed_ms(self._llm_starts, run_id)
        # llm_output is Optional and provider-specific; it is None e.g. when the
        # model streamed its answer, so never call .get() on it directly.
        token_usage = (response.llm_output or {}).get("token_usage") or {}

        self.llm_calls.append({
            "latency_ms": latency_ms,
            "prompt_tokens": token_usage.get("prompt_tokens", 0),
            "completion_tokens": token_usage.get("completion_tokens", 0),
        })
        print(f"[LLM] Done in {latency_ms}ms | Tokens: {token_usage}")

    def on_llm_error(self, error: BaseException, run_id: Any = None, **kwargs) -> None:
        self._llm_starts.pop(str(run_id), None)  # drop the failed call's timer
        print(f"[LLM] Error: {error}")

    def on_chain_start(self, serialized: dict, inputs: dict, run_id: Any = None, **kwargs) -> None:
        # Recent LangChain versions pass serialized=None for runnables and
        # supply the display name via the `name` keyword instead.
        chain_name = kwargs.get("name") or (serialized or {}).get("name", "unknown")
        self._chain_starts[str(run_id)] = time.perf_counter()
        print(f"[Chain] Start: {chain_name}")

    def on_chain_end(self, outputs: dict, run_id: Any = None, **kwargs) -> None:
        latency_ms = self._elapsed_ms(self._chain_starts, run_id)
        print(f"[Chain] Done in {latency_ms}ms")

    def on_chain_error(self, error: BaseException, run_id: Any = None, **kwargs) -> None:
        self._chain_starts.pop(str(run_id), None)  # drop the failed step's timer
        print(f"[Chain] Error: {error}")

    def on_tool_start(self, serialized: dict, input_str: str, **kwargs) -> None:
        tool_name = (serialized or {}).get("name", "unknown")
        print(f"[Tool] Calling: {tool_name}({input_str[:80]})")

    def on_tool_end(self, output: Any, **kwargs) -> None:
        # Tool output is not guaranteed to be a str (it may be a message object).
        print(f"[Tool] Result: {str(output)[:100]}")

    def on_tool_error(self, error: BaseException, **kwargs) -> None:
        print(f"[Tool] Error: {error}")

Attaching Callbacks

Callbacks can be attached at three levels:

Python
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableConfig

timing_callback = TimingCallbackHandler()

# Level 1: Per-invocation (most common). Handlers passed through `config` are
# handed down to every child run: prompt, model and output parser.
chain = (
    ChatPromptTemplate.from_messages([("human", "{question}")])
    | ChatOpenAI(model="gpt-4o", temperature=0)
    | StrOutputParser()
)

result = chain.invoke(
    {"question": "What is warfarin?"},
    config=RunnableConfig(callbacks=[timing_callback]),
)

# Level 2: Constructor (applies to every call on this *model instance* only;
# constructor callbacks are not propagated to other runnables in a chain)
model_with_callback = ChatOpenAI(
    model="gpt-4o",
    temperature=0,
    callbacks=[timing_callback],   # Fires on every call to this model instance
)

# Level 3: Bound to a runnable with `.with_config()`. Every later call on the
# returned runnable uses these handlers without passing a config each time.
# (The legacy global `langchain.callbacks.set_handler` API no longer exists;
# importing it raises ImportError.)
chain_with_callbacks = chain.with_config(callbacks=[timing_callback])
# chain_with_callbacks.invoke({"question": "What is warfarin?"})

Cost Tracking Callback

Python
class CostTrackingCallback(BaseCallbackHandler):
    """Accumulate token usage and estimated USD cost across LLM calls.

    Token counts come from the standardized ``usage_metadata`` on chat
    messages when available, falling back to the provider-specific
    ``llm_output`` dict (OpenAI ``token_usage`` / Anthropic ``usage``).
    Model names are matched against ``PRICING`` by longest prefix, so dated
    snapshots such as ``gpt-4o-mini-2024-07-18`` get the right rate. Unknown
    models are priced at ``DEFAULT_PRICING``.
    """

    # Prices in USD per 1M tokens (input / output). Verify against the
    # providers' current price lists before relying on these numbers.
    PRICING = {
        "gpt-4o": {"input": 2.50, "output": 10.00},
        "gpt-4o-mini": {"input": 0.15, "output": 0.60},
        "claude-sonnet-4-6": {"input": 3.00, "output": 15.00},
        "claude-haiku-4-5": {"input": 1.00, "output": 5.00},
    }
    DEFAULT_PRICING = {"input": 2.50, "output": 10.00}

    def __init__(self):
        super().__init__()
        self.total_input_tokens = 0
        self.total_output_tokens = 0
        self.total_cost_usd = 0.0
        self._current_model = "gpt-4o"

    def on_llm_start(self, serialized: dict, prompts: list[str], **kwargs) -> None:
        init_kwargs = (serialized or {}).get("kwargs", {})
        self._current_model = init_kwargs.get("model_name") or init_kwargs.get("model") or "gpt-4o"

    def on_llm_end(self, response: LLMResult, **kwargs) -> None:
        llm_output = response.llm_output or {}  # Optional: None e.g. when streaming
        input_tokens, output_tokens = self._extract_usage(response, llm_output)

        # Prefer the model the provider reports it actually used.
        model = llm_output.get("model_name") or llm_output.get("model") or self._current_model
        pricing = self._pricing_for(model)
        cost = (
            input_tokens * pricing["input"] / 1_000_000 +
            output_tokens * pricing["output"] / 1_000_000
        )

        self.total_input_tokens += input_tokens
        self.total_output_tokens += output_tokens
        self.total_cost_usd += cost

    @staticmethod
    def _extract_usage(response: LLMResult, llm_output: dict) -> tuple[int, int]:
        """Return ``(input_tokens, output_tokens)`` for one LLM result."""
        input_tokens = output_tokens = 0
        found = False
        # One generation list per prompt; with n>1 candidates every candidate
        # carries the same call-level usage, so read only the first one.
        for candidates in response.generations:
            message = getattr(candidates[0], "message", None) if candidates else None
            usage = getattr(message, "usage_metadata", None)
            if usage:
                input_tokens += usage.get("input_tokens", 0)
                output_tokens += usage.get("output_tokens", 0)
                found = True
        if found:
            return input_tokens, output_tokens

        usage = llm_output.get("token_usage") or llm_output.get("usage") or {}
        return (
            usage.get("prompt_tokens", usage.get("input_tokens", 0)) or 0,
            usage.get("completion_tokens", usage.get("output_tokens", 0)) or 0,
        )

    @classmethod
    def _pricing_for(cls, model: str) -> dict:
        """Return the PRICING entry whose key is the longest prefix of *model*."""
        model = model or ""
        matches = [name for name in cls.PRICING if model.startswith(name)]
        if not matches:
            return cls.DEFAULT_PRICING
        return cls.PRICING[max(matches, key=len)]

    def report(self) -> dict:
        """Return the accumulated token totals and cost (rounded to 6 dp)."""
        return {
            "total_input_tokens": self.total_input_tokens,
            "total_output_tokens": self.total_output_tokens,
            "total_tokens": self.total_input_tokens + self.total_output_tokens,
            "total_cost_usd": round(self.total_cost_usd, 6),
        }


cost_tracker = CostTrackingCallback()

# Reusing one tracker across calls accumulates totals for the whole session.
chain.invoke({"question": "What is warfarin?"}, config=RunnableConfig(callbacks=[cost_tracker]))
chain.invoke({"question": "What is metformin?"}, config=RunnableConfig(callbacks=[cost_tracker]))

print(cost_tracker.report())
# Example output (token counts vary per run). At gpt-4o rates:
# 140 * $2.50/1M + 90 * $10.00/1M = $0.00035 + $0.0009 = $0.00125
# {'total_input_tokens': 140, 'total_output_tokens': 90, 'total_tokens': 230, 'total_cost_usd': 0.00125}

Streaming Callback

Stream tokens to a UI or WebSocket as they are generated:

Python
from langchain_core.callbacks import StreamingStdOutCallbackHandler

# Built-in handler: writes each streamed token to stdout as it arrives.
streaming_handler = StreamingStdOutCallbackHandler()

# Constructor-level callback: attached to every call made on this model
# instance. streaming=True asks the model to emit tokens incrementally, which
# is what produces the per-token (on_llm_new_token) events the handler prints.
streaming_model = ChatOpenAI(
    model="gpt-4o",
    streaming=True,
    callbacks=[streaming_handler],
)

# Tokens appear on stdout as they arrive
streaming_model.invoke("Explain warfarin mechanism in detail.")


# Custom: stream to a queue (for WebSocket or SSE)
import queue

class StreamToQueueCallback(BaseCallbackHandler):
    """Push streamed LLM tokens onto a thread-safe ``queue.Queue``.

    Intended for a producer thread that runs the model while a consumer
    (for example a WebSocket or SSE handler) drains the queue. Items are:

    * ``str``: one streamed token;
    * ``None``: sentinel, meaning the LLM call finished successfully;
    * an ``Exception``: the LLM call failed, and the consumer should raise it.

    Note: ``queue.Queue`` is a threading queue, not an ``asyncio.Queue``. From
    async code, read it through an executor (e.g. ``loop.run_in_executor``).
    """

    def __init__(self, q: queue.Queue):
        super().__init__()
        self.q = q

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        self.q.put(token)

    def on_llm_end(self, response: LLMResult, **kwargs) -> None:
        self.q.put(None)   # Sentinel: signal stream is complete

    def on_llm_error(self, error: BaseException, **kwargs) -> None:
        # Forward the original exception so its type and traceback survive.
        # Non-Exception errors (e.g. KeyboardInterrupt) are wrapped so the
        # consumer's isinstance(item, Exception) check still catches them.
        self.q.put(error if isinstance(error, Exception) else RuntimeError(repr(error)))


# Usage with a web framework: the model runs in a worker thread while the
# main thread consumes tokens from the queue.
token_queue = queue.Queue()
stream_callback = StreamToQueueCallback(token_queue)

import threading

def run_chain_in_thread():
    """Run the streaming model, reporting failures the callback cannot see."""
    try:
        ChatOpenAI(model="gpt-4o", streaming=True, callbacks=[stream_callback]).invoke(
            "Explain warfarin mechanism."
        )
    except Exception as exc:
        # Errors raised before the LLM run starts (e.g. the ChatOpenAI
        # constructor failing validation) never reach on_llm_error. Without
        # this, the consumer below would block on get() forever.
        token_queue.put(exc)

thread = threading.Thread(target=run_chain_in_thread)
thread.start()

# Consumer: forward tokens to a WebSocket (printed to stdout here)
try:
    while True:
        token = token_queue.get()
        if token is None:
            break
        if isinstance(token, Exception):
            raise token
        print(token, end="", flush=True)
finally:
    thread.join()   # Always reap the worker, even when re-raising its error

Audit Logging Callback

Python
import logging
import json

logger = logging.getLogger("langchain_audit")

class AuditCallback(BaseCallbackHandler):
    """Structured JSON audit log for every LLM call and tool use.

    Each event is emitted as one JSON object on the ``langchain_audit`` logger,
    tagged with ``session_id``, ``user_id`` and (when LangChain supplies it)
    the ``run_id``, so start/end/error events of concurrent calls can be
    correlated. Failures are logged at ERROR level, so the trail also shows
    calls that never completed.
    """

    def __init__(self, session_id: str, user_id: str):
        super().__init__()
        self.session_id = session_id
        self.user_id = user_id

    def _log(self, event: str, data: dict, level: int = logging.INFO, run_id=None) -> None:
        record = {
            "event": event,
            "session_id": self.session_id,
            "user_id": self.user_id,
        }
        if run_id is not None:
            record["run_id"] = str(run_id)
        logger.log(level, json.dumps({**record, **data}))

    def on_llm_start(self, serialized: dict, prompts: list[str], **kwargs) -> None:
        init_kwargs = (serialized or {}).get("kwargs", {})
        self._log("llm_call_start", {
            "model": init_kwargs.get("model_name") or init_kwargs.get("model"),
            "prompt_chars": sum(len(p) for p in prompts),
        }, run_id=kwargs.get("run_id"))

    def on_llm_end(self, response: LLMResult, **kwargs) -> None:
        # llm_output is Optional; it is None e.g. when the model streamed.
        usage = (response.llm_output or {}).get("token_usage") or {}
        self._log("llm_call_end", {
            "prompt_tokens": usage.get("prompt_tokens"),
            "completion_tokens": usage.get("completion_tokens"),
        }, run_id=kwargs.get("run_id"))

    def on_llm_error(self, error: BaseException, **kwargs) -> None:
        self._log("llm_call_error", {
            "error_type": type(error).__name__,
            "error": str(error)[:200],   # Truncate for log size
        }, level=logging.ERROR, run_id=kwargs.get("run_id"))

    def on_tool_start(self, serialized: dict, input_str: str, **kwargs) -> None:
        self._log("tool_call", {
            "tool": (serialized or {}).get("name"),
            "input": input_str[:200],   # Truncate for log size
        }, run_id=kwargs.get("run_id"))

    def on_tool_end(self, output, **kwargs) -> None:
        self._log("tool_result", {
            "output_chars": len(str(output)),
        }, run_id=kwargs.get("run_id"))

    def on_tool_error(self, error: BaseException, **kwargs) -> None:
        self._log("tool_error", {
            "error_type": type(error).__name__,
            "error": str(error)[:200],   # Truncate for log size
        }, level=logging.ERROR, run_id=kwargs.get("run_id"))


# Attach per request: building a fresh AuditCallback for each request stamps
# every logged event with that request's session_id and user_id.
audit = AuditCallback(session_id="sess_123", user_id="pharmacist_456")
result = chain.invoke(
    {"question": "What is the warfarin dose for AFib?"},
    config=RunnableConfig(callbacks=[audit]),
)

Chaining Multiple Callbacks

Python
# Combine callbacks — each handler receives every event. Keep references to
# stateful handlers so their results can be read after the call.
timing = TimingCallbackHandler()
costs = CostTrackingCallback()
audit_log = AuditCallback(session_id="sess_abc", user_id="user_1")
callbacks = [timing, costs, audit_log]

result = chain.invoke(
    {"question": "What are warfarin interactions?"},
    config=RunnableConfig(callbacks=callbacks),
)

print(timing.llm_calls)   # per-call latency and token counts
print(costs.report())     # accumulated tokens and estimated cost

Callback Events Reference

| Method | Fires When | Key Args |
|---|---|---|
| on_llm_start | LLM call begins | serialized (model info), prompts |
| on_llm_new_token | Each streamed token arrives | token: str |
| on_llm_end | LLM call completes | response: LLMResult (includes token usage) |
| on_llm_error | LLM call fails | error: Exception |
| on_chain_start | Chain step begins | serialized, inputs: dict, run_id |
| on_chain_end | Chain step completes | outputs: dict, run_id |
| on_chain_error | Chain step fails | error: Exception |
| on_tool_start | Tool invocation begins | serialized (tool info), input_str |
| on_tool_end | Tool invocation completes | output: str |
| on_tool_error | Tool invocation fails | error: Exception |
| on_retriever_start | Retriever query begins | serialized, query |
| on_retriever_end | Retriever returns docs | documents: list[Document] |

Enjoyed this article?

Explore the AI Systems learning path for more.

Found this helpful?

Share: X

Leave a comment

Have a question, correction, or just found this helpful? Leave a note below.