AI Systems · Intermediate
Building a RAG Chain with LangChain
Build retrieval-augmented generation chains with LCEL. Covers basic RAG, conversational RAG with history, source citation, streaming, and production patterns.
Asma Hafeez Khan · May 16, 2026 · 6 min read
LangChain · RAG · LCEL · Retrieval · Conversational RAG · Chat History
Basic RAG Chain
A RAG chain retrieves relevant documents, formats them into a prompt, and passes the combined context to the LLM.
Python
from langchain_chroma import Chroma
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.documents import Document
# Setup: vector store with drug information
# Embedding model handed to Chroma as the collection's embedding function.
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
# Collection "drug_formulary", persisted under ./chroma_db.
# NOTE(review): nothing in this snippet adds documents — presumably the
# collection was populated beforehand; confirm before running.
vectorstore = Chroma(
    collection_name="drug_formulary",
    embedding_function=embeddings,
    persist_directory="./chroma_db",
)
# Retriever over the store; search_kwargs asks for k=4 documents per query.
retriever = vectorstore.as_retriever(search_kwargs={"k": 4})
# Format retrieved docs into a single context string
def format_docs(docs: list[Document]) -> str:
    """Join retrieved documents into one context string for the prompt.

    Each document's ``page_content`` is prefixed with a 1-based
    ``[Source N]: `` label; entries are separated by a blank line.
    An empty list yields an empty string.
    """
    sections = [
        f"[Source {number}]: {doc.page_content}"
        for number, doc in enumerate(docs, start=1)
    ]
    return "\n\n".join(sections)
# RAG prompt
# The system message carries the retrieved context ({context}); the human
# message carries the user's question ({question}).
prompt = ChatPromptTemplate.from_messages([
    ("system",
     "You are a clinical pharmacist. Answer questions using only the retrieved context below. "
     "If the answer is not in the context, say 'I don't have enough information on this.' "
     "Never invent drug information.\n\nContext:\n{context}"),
    ("human", "{question}"),
])
# Chat model reused by the later chains in this article; temperature is set to 0.
model = ChatOpenAI(model="gpt-4o", temperature=0)
# LCEL RAG chain
# The leading dict sends the same input string down two branches: "context"
# pipes it through the retriever and format_docs, "question" passes it on
# unchanged. The filled prompt goes to the model, and StrOutputParser turns
# the model output into a plain string.
rag_chain = (
    {
        "context": retriever | format_docs,
        "question": RunnablePassthrough(),
    }
    | prompt
    | model
    | StrOutputParser()
)
# The chain input is the raw question string.
answer = rag_chain.invoke("What does warfarin inhibit?")
print(answer)
RAG Chain with Source Citations
Return both the answer and the sources it came from:
Python
from langchain_core.runnables import RunnableParallel
# Build a chain that returns answer + source documents together
# NOTE(review): rag_chain already runs the retriever internally, so the
# retriever is invoked a second time for "sources"; the documents listed
# below come from that separate call.
rag_chain_with_sources = RunnableParallel(
    answer=rag_chain,
    sources=retriever,  # Returns the raw Document objects in parallel
)
# The result is a dict with the two keys named above: "answer" and "sources".
result = rag_chain_with_sources.invoke("How is warfarin monitored?")
print(f"Answer: {result['answer']}\n")
print("Sources:")
for doc in result["sources"]:
    # Fall back to placeholders when a document has no source/page metadata.
    source = doc.metadata.get("source", "Unknown")
    page = doc.metadata.get("page", "")
    # The page suffix is appended only when page is truthy (so a page value
    # of 0 is omitted as well).
    print(f" - {source}" + (f", page {page}" if page else ""))
# Inline citations: include source numbers in the answer
# The [1], [2] markers requested here correspond to the numbering that
# format_docs_numbered puts in front of each document.
citation_prompt = ChatPromptTemplate.from_messages([
    ("system",
     "You are a clinical pharmacist. Answer using the retrieved context. "
     "Cite sources using [1], [2], etc. inline in your answer. "
     "Only cite sources that actually support the statement.\n\n"
     "Context:\n{context}"),
    ("human", "{question}"),
])
def format_docs_numbered(docs: list[Document]) -> str:
    """Render documents as a numbered list suitable for inline citation.

    Every entry is ``[N] <page_content>`` followed on the next line by
    ``(Source: <metadata['source']>)``; a missing ``source`` key is shown as
    ``unknown``. Numbering starts at 1 and entries are separated by a blank
    line.
    """
    entries = []
    for number, doc in enumerate(docs, start=1):
        origin = doc.metadata.get('source', 'unknown')
        entries.append(f"[{number}] {doc.page_content}\n(Source: {origin})")
    return "\n\n".join(entries)
# Same shape as the basic RAG chain, but the context is numbered by
# format_docs_numbered and the prompt asks for inline [n] citations.
citation_chain = (
    {
        "context": retriever | format_docs_numbered,
        "question": RunnablePassthrough(),
    }
    | citation_prompt
    | model
    | StrOutputParser()
)
# The chain input is the raw question string; the output is a plain string.
answer = citation_chain.invoke("What monitoring is required for warfarin?")
# Answer: "Warfarin requires regular INR monitoring [1]. Frequency depends on stability [2]."
Conversational RAG
A basic RAG chain has no memory — each question is independent. Conversational RAG rewrites follow-up questions using chat history before retrieving.
Python
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.messages import HumanMessage, AIMessage
# Step 1: History-aware retriever
# Takes (question + chat history) → rewrites question → retrieves docs
# The "placeholder" slot is filled with the chat_history message list that
# callers pass at invoke time; "{input}" is the latest user question.
contextualize_q_prompt = ChatPromptTemplate.from_messages([
    ("system",
     "Given the chat history and the latest question, reformulate the question "
     "so it is self-contained (no references to 'it', 'that drug', 'the previous one'). "
     "Do NOT answer — only reformulate the question if needed. Return it as-is if it's already clear."),
    ("placeholder", "{chat_history}"),
    ("human", "{input}"),
])
history_aware_retriever = create_history_aware_retriever(
    llm=model,
    retriever=retriever,
    prompt=contextualize_q_prompt,
)
# Step 2: QA chain (uses retrieved docs to answer)
# Uses the same input keys as the rewriting prompt, plus {context} for the
# retrieved documents.
qa_prompt = ChatPromptTemplate.from_messages([
    ("system",
     "You are a clinical pharmacist. Answer using the retrieved context only. "
     "If unsure, say so — never invent drug information.\n\nContext:\n{context}"),
    ("placeholder", "{chat_history}"),
    ("human", "{input}"),
])
question_answer_chain = create_stuff_documents_chain(model, qa_prompt)
# Step 3: Combine into conversational RAG chain
# Callers invoke it with {"input": ..., "chat_history": ...} and read the
# "answer" key (and, where sources are needed, "context") from the result.
conversational_rag = create_retrieval_chain(history_aware_retriever, question_answer_chain)
# Manual history tracking: one module-level list shared by every chat() call.
chat_history = []


def chat(question: str) -> str:
    """Ask the conversational RAG chain one question and record the turn.

    The current ``chat_history`` is sent along with the question. After the
    chain answers, the question and the answer are appended to
    ``chat_history`` as a HumanMessage/AIMessage pair.

    Returns:
        The chain's ``"answer"`` string.
    """
    payload = {"input": question, "chat_history": chat_history}
    outcome = conversational_rag.invoke(payload)
    # Build both messages before touching the list so a failure while
    # constructing either one leaves the history unchanged.
    new_turn = [
        HumanMessage(content=question),
        AIMessage(content=outcome["answer"]),
    ]
    chat_history.extend(new_turn)
    return outcome["answer"]
# Multi-turn conversation
# Each chat() call appends to the module-level chat_history, so the later
# questions are sent together with the earlier turns.
print(chat("What is warfarin?"))
print(chat("What are its main interactions?"))  # "its" → resolved to "warfarin" by rewriter
print(chat("What about with aspirin specifically?")) # Follow-up also resolved
RAG with LangChain Memory
Attach memory to avoid managing chat_history manually:
Python
from langchain.memory import ConversationBufferWindowMemory
from langchain_core.runnables import RunnableLambda
from langchain_core.messages import BaseMessage
class RAGChatbot:
    """Conversational RAG chain bundled with its own chat memory.

    The instance owns a ConversationBufferWindowMemory (window size ``k``)
    and a retrieval chain built from the module-level
    ``contextualize_q_prompt`` and ``qa_prompt``, so callers do not have to
    carry a ``chat_history`` list around themselves.
    """

    def __init__(self, retriever, model, k: int = 6):
        """Build the memory and the history-aware retrieval chain.

        Args:
            retriever: Retriever used to fetch context documents.
            model: Chat model used both for question rewriting and answering.
            k: Window size passed to ConversationBufferWindowMemory.
        """
        self.retriever = retriever
        self.model = model
        memory_settings = {
            "k": k,
            "memory_key": "chat_history",
            "return_messages": True,
            "output_key": "answer",
        }
        self.memory = ConversationBufferWindowMemory(**memory_settings)
        rewriting_retriever = create_history_aware_retriever(
            model, retriever, contextualize_q_prompt
        )
        answering_chain = create_stuff_documents_chain(model, qa_prompt)
        self.chain = create_retrieval_chain(rewriting_retriever, answering_chain)

    def chat(self, question: str) -> dict:
        """Answer ``question`` using stored history, then store the new turn.

        Returns:
            A dict with ``"answer"`` (the chain's answer) and ``"sources"``
            (the ``source`` metadata of each document under the result's
            ``"context"`` key, ``"?"`` where it is missing; empty when the
            result has no ``"context"``).
        """
        past_messages = self.memory.load_memory_variables({})["chat_history"]
        outcome = self.chain.invoke({"input": question, "chat_history": past_messages})
        reply = outcome["answer"]
        # Persist the turn before assembling the response.
        self.memory.save_context({"input": question}, {"answer": reply})
        cited = []
        for doc in outcome.get("context", []):
            cited.append(doc.metadata.get("source", "?"))
        return {"answer": reply, "sources": cited}
bot = RAGChatbot(retriever=retriever, model=model)
# The first turn is saved into the bot's memory, so the second call is sent
# with that turn as chat history.
r1 = bot.chat("What is warfarin?")
r2 = bot.chat("What are its interactions with aspirin?")
# chat() returns a dict with "answer" and "sources" keys.
print(r2["answer"])
print(f"Sources: {r2['sources']}")
Streaming RAG
Stream the answer token by token as it's generated:
Python
# Stream from a basic RAG chain
# end="" keeps successive chunks on one line; flush=True pushes each chunk
# out as soon as it is printed.
for chunk in rag_chain.stream("How does warfarin interact with aspirin?"):
    print(chunk, end="", flush=True)
# Terminate the streamed line.
print()
# Stream from conversational RAG (only stream the final answer, not internal steps)
async def stream_conversational(question: str, chat_history: list):
    """Print the conversational RAG answer to stdout as it streams in.

    Args:
        question: The user's latest question.
        chat_history: Prior messages passed to the chain as ``chat_history``.

    Only chunks that contain an ``"answer"`` key are printed; every other
    chunk emitted by the chain is skipped. A final newline is printed once
    the stream ends.
    """
    request = {"input": question, "chat_history": chat_history}
    async for piece in conversational_rag.astream(request):
        # The chain emits dict chunks at each step; ignore the non-answer ones.
        if "answer" not in piece:
            continue
        print(piece["answer"], end="", flush=True)
    print()
import asyncio
asyncio.run(stream_conversational("What is warfarin?", chat_history=[]))
Multi-Turn RAG with Session Management
Python
from collections import defaultdict
class SessionRAGPipeline:
    """RAG chatbot with per-session conversation history.

    Each ``session_id`` maps to its own list of messages. At most
    ``max_history_turns`` question/answer pairs are kept per session and
    sent to the chain; older messages are discarded.
    """

    def __init__(self, retriever, model, max_history_turns: int = 8):
        """Build the retrieval chain and an empty session store.

        Args:
            retriever: Retriever used to fetch context documents.
            model: Chat model used for question rewriting and answering.
            max_history_turns: Number of question/answer pairs kept per
                session. Zero or a negative value disables history.
        """
        self.retriever = retriever
        self.model = model
        self.max_turns = max_history_turns
        self._sessions: dict[str, list[BaseMessage]] = defaultdict(list)
        self._chain = create_retrieval_chain(
            create_history_aware_retriever(model, retriever, contextualize_q_prompt),
            create_stuff_documents_chain(model, qa_prompt),
        )

    def _max_messages(self) -> int:
        """Return the message cap (two messages per turn), never negative."""
        return max(self.max_turns, 0) * 2

    def _get_history(self, session_id: str) -> list[BaseMessage]:
        """Return a copy of the most recent messages for ``session_id``.

        Unknown sessions yield an empty list and are not created as a side
        effect of reading.
        """
        limit = self._max_messages()
        if limit == 0:
            # history[-0:] would return the whole list, so handle the
            # "no history" configuration explicitly.
            return []
        # Copy so the chain receives a snapshot rather than the live list
        # that chat() mutates afterwards.
        return list(self._sessions.get(session_id, ())[-limit:])

    def chat(self, session_id: str, question: str) -> str:
        """Answer ``question`` within ``session_id`` and record the turn.

        Returns:
            The chain's ``"answer"`` string.
        """
        history = self._get_history(session_id)
        result = self._chain.invoke({
            "input": question,
            "chat_history": history,
        })
        answer = result["answer"]
        stored = self._sessions.setdefault(session_id, [])
        stored.extend([
            HumanMessage(content=question),
            AIMessage(content=answer),
        ])
        # Drop messages that can never be sent again so long-lived sessions
        # do not grow without bound.
        overflow = len(stored) - self._max_messages()
        if overflow > 0:
            del stored[:overflow]
        return answer

    def reset_session(self, session_id: str) -> None:
        """Forget all history for ``session_id``; unknown ids are ignored."""
        self._sessions.pop(session_id, None)
pipeline = SessionRAGPipeline(retriever=retriever, model=model)
# Two users interleaved on one pipeline; history is kept per session_id, so
# each follow-up is sent only with that user's earlier turns.
print(pipeline.chat("user_1", "What is warfarin?"))
print(pipeline.chat("user_2", "What is metformin?"))
print(pipeline.chat("user_1", "What are its interactions?"))  # "its" = warfarin for user_1
print(pipeline.chat("user_2", "What are its side effects?")) # "its" = metformin for user_2
Production RAG Pattern
Python
import time
import logging
logger = logging.getLogger("rag")


def production_rag_query(
    question: str,
    session_id: str,
    pipeline: SessionRAGPipeline,
    *,
    max_question_chars: int = 1000,
) -> dict:
    """Run one validated, logged RAG query and return a response dict.

    Args:
        question: The user's question.
        session_id: Conversation identifier passed to ``pipeline.chat``.
        pipeline: Object whose ``chat(session_id, question)`` returns the
            answer text.
        max_question_chars: Longest accepted question, in characters.

    Returns:
        On success: ``answer``, ``latency_ms``, ``success=True`` and
        ``disclaimer``. On a validation failure or pipeline error:
        ``answer`` (a user-facing message) and ``success=False``.
        This function does not propagate exceptions raised by
        ``pipeline.chat``; they are logged and reported in the result.
    """
    # Input validation happens before any pipeline work. A non-string
    # (e.g. None) is a bad request, not an internal error.
    if not isinstance(question, str) or not question.strip():
        return {"answer": "Please provide a question.", "success": False}
    if len(question) > max_question_chars:
        return {"answer": "Question too long. Please shorten it.", "success": False}

    # perf_counter is monotonic, so the latency cannot be skewed by
    # wall-clock adjustments.
    start = time.perf_counter()
    try:
        answer = pipeline.chat(session_id, question)
    except Exception as e:
        # Top-level boundary: log with the traceback and return a safe
        # message instead of leaking the error to the user.
        logger.exception("rag_error", extra={"session_id": session_id, "error": str(e)})
        return {
            "answer": "I encountered an error. Please try rephrasing your question.",
            "success": False,
        }

    latency_ms = round((time.perf_counter() - start) * 1000)
    logger.info("rag_success", extra={
        "session_id": session_id,
        "latency_ms": latency_ms,
        "question_chars": len(question),
    })
    return {
        "answer": answer,
        "latency_ms": latency_ms,
        "success": True,
        "disclaimer": "Verify all clinical information with current references before use.",
    }
Leave a comment
Have a question, correction, or just found this helpful? Leave a note below.