LangChain Mastery · Lesson 29 of 33
Building a Full RAG Chain with LCEL
Basic RAG Chain
A RAG chain retrieves relevant documents, formats them into a prompt, and passes the combined context to the LLM.
Python
from langchain_chroma import Chroma
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.documents import Document
# Setup: a Chroma collection of drug information, stored under ./chroma_db
# and embedded with OpenAI's "text-embedding-3-small" model.
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
vectorstore = Chroma(
    persist_directory="./chroma_db",
    collection_name="drug_formulary",
    embedding_function=embeddings,
)

# Each query asks the store for k=4 documents.
retriever = vectorstore.as_retriever(search_kwargs=dict(k=4))
# Format retrieved docs into a single context string
def format_docs(docs: list[Document]) -> str:
    """Join retrieved documents into one numbered context string.

    Args:
        docs: Retrieved documents. Only ``page_content`` is read.

    Returns:
        One ``[Source N]: <text>`` entry per document, numbered from 1 in the
        order given and separated by blank lines. An empty list yields ``""``.
    """
    return "\n\n".join(
        f"[Source {number}]: {doc.page_content}"
        for number, doc in enumerate(docs, start=1)
    )
# RAG prompt: the system message holds the grounding rules and the retrieved
# context; the human message is the user's question.
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a clinical pharmacist. Answer questions using only the retrieved context below. "
            "If the answer is not in the context, say 'I don't have enough information on this.' "
            "Never invent drug information.\n\nContext:\n{context}",
        ),
        ("human", "{question}"),
    ]
)

model = ChatOpenAI(temperature=0, model="gpt-4o")

# LCEL RAG chain. The leading dict sends the input question down two branches:
# "context" retrieves documents and formats them, "question" forwards the
# input unchanged. Both feed the prompt's template variables.
rag_chain = (
    {
        "context": retriever | format_docs,
        "question": RunnablePassthrough(),
    }
    | prompt
    | model
    | StrOutputParser()
)

answer = rag_chain.invoke("What does warfarin inhibit?")
print(answer)
RAG Chain with Source Citations
Return both the answer and the sources it came from:
Python
from langchain_core.runnables import RunnableParallel
# Build a chain that returns answer + source documents together.
# Retrieval runs exactly once: the retrieved documents are kept under
# "sources" and the same documents are formatted into the prompt, so the
# sources listed are the ones the model actually saw. (Pairing `rag_chain`
# with a second `retriever` branch would query the vector store twice.)
answer_from_sources = (
    RunnablePassthrough.assign(context=lambda x: format_docs(x["sources"]))
    | prompt
    | model
    | StrOutputParser()
)
rag_chain_with_sources = RunnableParallel(
    sources=retriever,               # raw Document objects
    question=RunnablePassthrough(),  # the user's question, unchanged
).assign(answer=answer_from_sources)

result = rag_chain_with_sources.invoke("How is warfarin monitored?")
print(f"Answer: {result['answer']}\n")
print("Sources:")
for doc in result["sources"]:
    source = doc.metadata.get("source", "Unknown")
    page = doc.metadata.get("page")
    # Test for "missing" explicitly rather than by truthiness: a loader that
    # numbers pages from 0 would otherwise never show its first page.
    has_page = page is not None and page != ""
    print(f" - {source}" + (f", page {page}" if has_page else ""))
# Inline citations: the system message asks the model to cite the numbered
# context entries as [1], [2], ... within its answer.
citation_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a clinical pharmacist. Answer using the retrieved context. "
            "Cite sources using [1], [2], etc. inline in your answer. "
            "Only cite sources that actually support the statement.\n\n"
            "Context:\n{context}",
        ),
        ("human", "{question}"),
    ]
)
def format_docs_numbered(docs: list[Document]) -> str:
    """Format documents as numbered entries that inline citations can refer to.

    Args:
        docs: Retrieved documents. ``page_content`` and ``metadata`` are read.

    Returns:
        One entry per document of the form ``[N] <text>`` followed by a
        ``(Source: ...)`` line, numbered from 1 and separated by blank lines.
        The source is ``metadata["source"]``, or ``unknown`` when that key is
        absent. An empty list yields ``""``.
    """
    return "\n\n".join(
        f"[{number}] {doc.page_content}\n(Source: {doc.metadata.get('source', 'unknown')})"
        for number, doc in enumerate(docs, start=1)
    )
# Citation chain: retrieved documents are formatted with their [N] numbers
# and source names, then passed to the citation prompt with the question.
citation_chain = (
    {
        "context": retriever | format_docs_numbered,
        "question": RunnablePassthrough(),
    }
    | citation_prompt
    | model
    | StrOutputParser()
)

answer = citation_chain.invoke("What monitoring is required for warfarin?")
# Answer: "Warfarin requires regular INR monitoring [1]. Frequency depends on stability [2]."
Conversational RAG
A basic RAG chain has no memory — each question is independent. Conversational RAG rewrites follow-up questions using chat history before retrieving.
Python
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.messages import HumanMessage, AIMessage
# Step 1: History-aware retriever
# (question + chat history) → self-contained question → retrieved docs
contextualize_q_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "Given the chat history and the latest question, reformulate the question "
            "so it is self-contained (no references to 'it', 'that drug', 'the previous one'). "
            "Do NOT answer — only reformulate the question if needed. Return it as-is if it's already clear.",
        ),
        ("placeholder", "{chat_history}"),
        ("human", "{input}"),
    ]
)
history_aware_retriever = create_history_aware_retriever(
    model,
    retriever,
    contextualize_q_prompt,
)

# Step 2: QA chain — answers from the retrieved docs placed in {context}
qa_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a clinical pharmacist. Answer using the retrieved context only. "
            "If unsure, say so — never invent drug information.\n\nContext:\n{context}",
        ),
        ("placeholder", "{chat_history}"),
        ("human", "{input}"),
    ]
)
question_answer_chain = create_stuff_documents_chain(model, qa_prompt)

# Step 3: Combine retrieval and answering into one conversational chain
conversational_rag = create_retrieval_chain(
    history_aware_retriever,
    question_answer_chain,
)

# History is tracked by hand: a module-level list that chat() appends to
chat_history = []
def chat(question: str) -> str:
    """Ask one question and record the exchange in the shared history.

    Args:
        question: The user's latest message.

    Returns:
        The chain's answer text.
    """
    payload = {
        "input": question,
        "chat_history": chat_history,
    }
    reply = conversational_rag.invoke(payload)["answer"]

    # Record the turn only after the chain has produced an answer.
    chat_history.append(HumanMessage(content=question))
    chat_history.append(AIMessage(content=reply))
    return reply
# Multi-turn conversation: every call appends to chat_history, so later
# questions are sent together with the earlier turns.
print(chat("What is warfarin?"))
print(chat("What are its main interactions?")) # "its" → resolved to "warfarin" by rewriter
print(chat("What about with aspirin specifically?")) # Follow-up also resolved
RAG with LangChain Memory
Attach memory to avoid managing chat_history manually:
Python
from langchain.memory import ConversationBufferWindowMemory
from langchain_core.runnables import RunnableLambda
from langchain_core.messages import BaseMessage
class RAGChatbot:
    """Conversational RAG bot that keeps its own chat history in memory.

    The chain is built from the module-level ``contextualize_q_prompt`` and
    ``qa_prompt``. History is read from and written to a
    ``ConversationBufferWindowMemory`` instead of a caller-managed list.
    """

    def __init__(self, retriever, model, k: int = 6):
        """Build the memory and the retrieval chain.

        Args:
            retriever: Retriever used to fetch context documents.
            model: Chat model used for question rewriting and answering.
            k: Forwarded to ``ConversationBufferWindowMemory(k=...)``.
        """
        self.retriever = retriever
        self.model = model
        self.memory = ConversationBufferWindowMemory(
            k=k,
            memory_key="chat_history",
            return_messages=True,
            output_key="answer",
        )
        rewriting_retriever = create_history_aware_retriever(
            model, retriever, contextualize_q_prompt
        )
        answering_chain = create_stuff_documents_chain(model, qa_prompt)
        self.chain = create_retrieval_chain(rewriting_retriever, answering_chain)

    def chat(self, question: str) -> dict:
        """Answer one question and store the exchange in memory.

        Args:
            question: The user's latest message.

        Returns:
            A dict with ``"answer"`` (the answer text) and ``"sources"`` (the
            ``source`` metadata of each document under the result's
            ``"context"`` key, ``"?"`` where a document has none).
        """
        past_messages = self.memory.load_memory_variables({})["chat_history"]
        result = self.chain.invoke(
            {
                "input": question,
                "chat_history": past_messages,
            }
        )
        answer = result["answer"]
        self.memory.save_context({"input": question}, {"answer": answer})

        sources = [d.metadata.get("source", "?") for d in result.get("context", [])]
        return {"answer": answer, "sources": sources}
# Demo: two turns on one bot. The first exchange is saved to the bot's memory
# before the second question is asked.
bot = RAGChatbot(retriever=retriever, model=model)
r1 = bot.chat("What is warfarin?")
r2 = bot.chat("What are its interactions with aspirin?")
print(r2["answer"])
print(f"Sources: {r2['sources']}")
Streaming RAG
Stream the answer token by token as it's generated:
Python
# Stream from a basic RAG chain: print each chunk as it arrives with no
# separator, then end the line once the stream is exhausted.
for chunk in rag_chain.stream("How does warfarin interact with aspirin?"):
    print(chunk, end="", flush=True)
print()
# Stream from conversational RAG (only stream the final answer, not internal steps)
async def stream_conversational(question: str, chat_history: list):
    """Print the conversational chain's answer incrementally.

    Args:
        question: The user's latest message.
        chat_history: Earlier messages, passed to the chain as "chat_history".
    """
    payload = {
        "input": question,
        "chat_history": chat_history,
    }
    async for chunk in conversational_rag.astream(payload):
        # Chunks are dicts; only those carrying an "answer" key are printed.
        if "answer" not in chunk:
            continue
        print(chunk["answer"], end="", flush=True)
    print()
import asyncio
asyncio.run(stream_conversational("What is warfarin?", chat_history=[]))
Multi-Turn RAG with Session Management
Python
from collections import defaultdict
class SessionRAGPipeline:
    """RAG chatbot with per-session conversation history.

    Each ``session_id`` maps to its own list of messages. Only the most
    recent ``max_history_turns`` question/answer pairs are sent to the chain,
    and only that many are kept in storage, so a long-lived session cannot
    grow without bound.
    """

    def __init__(self, retriever, model, max_history_turns: int = 8):
        """Build the retrieval chain and an empty session store.

        Args:
            retriever: Retriever used to fetch context documents.
            model: Chat model used for question rewriting and answering.
            max_history_turns: Number of question/answer pairs to keep per
                session. Zero or a negative value means no history.
        """
        self.retriever = retriever
        self.model = model
        self.max_turns = max_history_turns
        self._sessions: dict[str, list[BaseMessage]] = defaultdict(list)
        self._chain = create_retrieval_chain(
            create_history_aware_retriever(model, retriever, contextualize_q_prompt),
            create_stuff_documents_chain(model, qa_prompt),
        )

    def _window(self, history):
        """Return the last ``max_turns`` pairs of *history* as a new list."""
        max_messages = self.max_turns * 2
        # Guard the zero case explicitly: history[-0:] is the WHOLE list,
        # not an empty one.
        if max_messages <= 0:
            return []
        return history[-max_messages:]

    def _get_history(self, session_id: str) -> list[BaseMessage]:
        """Return the windowed history for *session_id* (empty if new)."""
        return self._window(self._sessions[session_id])

    def chat(self, session_id: str, question: str) -> str:
        """Answer *question* within *session_id* and record the exchange.

        Args:
            session_id: Key identifying the conversation.
            question: The user's latest message.

        Returns:
            The chain's answer text.
        """
        history = self._get_history(session_id)
        result = self._chain.invoke({
            "input": question,
            "chat_history": history,
        })
        answer = result["answer"]

        stored = self._sessions[session_id]
        stored.extend([
            HumanMessage(content=question),
            AIMessage(content=answer),
        ])
        # Drop messages that have fallen out of the window so stored history
        # stays bounded; they would never be sent to the chain again anyway.
        self._sessions[session_id] = self._window(stored)
        return answer

    def reset_session(self, session_id: str) -> None:
        """Forget all history for *session_id*; unknown ids are ignored."""
        self._sessions.pop(session_id, None)
pipeline = SessionRAGPipeline(retriever=retriever, model=model)
# Two users interleaved on one pipeline; each session_id keeps its own history
print(pipeline.chat("user_1", "What is warfarin?"))
print(pipeline.chat("user_2", "What is metformin?"))
print(pipeline.chat("user_1", "What are its interactions?")) # "its" = warfarin for user_1
print(pipeline.chat("user_2", "What are its side effects?")) # "its" = metformin for user_2
Production RAG Pattern
Python
import time
import logging
logger = logging.getLogger("rag")


def production_rag_query(
    question: str,
    session_id: str,
    pipeline: SessionRAGPipeline,
    *,
    max_question_chars: int = 1000,
) -> dict:
    """Run one RAG query with input validation, timing, and logging.

    Args:
        question: The user's question.
        session_id: Conversation key passed to ``pipeline.chat``.
        pipeline: Object whose ``chat(session_id, question)`` returns the
            answer text.
        max_question_chars: Longest question accepted, in characters.

    Returns:
        A dict that always has ``"answer"`` and ``"success"``. On success it
        also has ``"latency_ms"`` (int) and ``"disclaimer"``. This function
        does not raise for pipeline failures; they are logged and reported
        as ``success=False``.
    """
    # Input validation happens before the pipeline is touched. A non-string
    # (e.g. None) is treated the same as an empty question.
    if not isinstance(question, str) or not question.strip():
        return {"answer": "Please provide a question.", "success": False}
    if len(question) > max_question_chars:
        return {"answer": "Question too long. Please shorten it.", "success": False}

    # perf_counter is monotonic, so the latency cannot go negative or jump
    # if the wall clock is adjusted mid-request.
    start = time.perf_counter()
    try:
        answer = pipeline.chat(session_id, question)
    except Exception as e:
        # Top-level boundary: report a generic message to the caller, but
        # keep the traceback in the log so the failure can be diagnosed.
        logger.exception("rag_error", extra={"session_id": session_id, "error": str(e)})
        return {
            "answer": "I encountered an error. Please try rephrasing your question.",
            "success": False,
        }
    else:
        latency_ms = round((time.perf_counter() - start) * 1000)
        logger.info("rag_success", extra={
            "session_id": session_id,
            "latency_ms": latency_ms,
            "question_chars": len(question),
        })
        return {
            "answer": answer,
            "latency_ms": latency_ms,
            "success": True,
            "disclaimer": "Verify all clinical information with current references before use.",
        }