Learnixo
Back to blog
AI Systems · Intermediate

Building a RAG Chain with LangChain

Build retrieval-augmented generation chains with LCEL. Covers basic RAG, conversational RAG with history, source citation, streaming, and production patterns.

Asma Hafeez Khan · May 16, 2026 · 6 min read
LangChain · RAG · LCEL · Retrieval · Conversational RAG · Chat History
Share:𝕏

Basic RAG Chain

A RAG chain retrieves relevant documents, formats them into a prompt, and passes the combined context to the LLM.

Python
from langchain_chroma import Chroma
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.documents import Document

# Setup: vector store with drug information.
# Document chunks are embedded with OpenAI's text-embedding-3-small model and
# kept in the "drug_formulary" Chroma collection; persist_directory points
# Chroma at an on-disk store under ./chroma_db.
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
vectorstore = Chroma(
    embedding_function=embeddings,
    collection_name="drug_formulary",
    persist_directory="./chroma_db",
)
# Every retrieval call returns the top k=4 documents for the query.
retriever = vectorstore.as_retriever(search_kwargs=dict(k=4))

# Format retrieved docs into a single context string
def format_docs(docs: list[Document]) -> str:
    """Join retrieved documents into one prompt-ready context string.

    Each document's text is prefixed with a 1-based label ``[Source N]:`` and
    consecutive documents are separated by a blank line. An empty list
    produces an empty string.
    """
    labelled = [
        f"[Source {position}]: {doc.page_content}"
        for position, doc in enumerate(docs, start=1)
    ]
    return "\n\n".join(labelled)

# RAG prompt: the system message carries the retrieved context, the human
# message carries the user's question.
_RAG_SYSTEM = (
    "You are a clinical pharmacist. Answer questions using only the retrieved context below. "
    "If the answer is not in the context, say 'I don't have enough information on this.' "
    "Never invent drug information.\n\nContext:\n{context}"
)
prompt = ChatPromptTemplate.from_messages(
    [("system", _RAG_SYSTEM), ("human", "{question}")]
)

# temperature=0 minimises sampling randomness for factual answers.
model = ChatOpenAI(model="gpt-4o", temperature=0)

# LCEL RAG chain. The leading dict fans the raw question string out to two
# branches: the retriever (whose documents format_docs flattens into one
# string) and a passthrough that forwards the question unchanged. The
# resulting {"context", "question"} dict fills the prompt.
retrieve_and_forward = {
    "context": retriever | format_docs,
    "question": RunnablePassthrough(),
}
rag_chain = retrieve_and_forward | prompt | model | StrOutputParser()

answer = rag_chain.invoke("What does warfarin inhibit?")
print(answer)

RAG Chain with Source Citations

Return both the answer and the sources it came from:

Python
from langchain_core.runnables import RunnableParallel

# Build a chain that returns the answer AND the documents it was grounded in.
# The retriever runs exactly once: its documents are returned as "sources" and
# the same documents are formatted into the prompt's {context}. (Running
# `retriever` in parallel with the full `rag_chain` would query the vector
# store twice, and the listed sources would not be guaranteed to match the
# context the model actually saw.)
rag_chain_with_sources = RunnableParallel(
    sources=retriever,                # raw Document objects
    question=RunnablePassthrough(),   # the original question string
).assign(
    answer=(
        RunnablePassthrough.assign(context=lambda x: format_docs(x["sources"]))
        | prompt
        | model
        | StrOutputParser()
    ),
)

result = rag_chain_with_sources.invoke("How is warfarin monitored?")

print(f"Answer: {result['answer']}\n")
print("Sources:")
for doc in result["sources"]:
    source = doc.metadata.get("source", "Unknown")
    page = doc.metadata.get("page")
    # Compare explicitly rather than by truthiness: page 0 can be a valid page
    # number and would otherwise be silently omitted.
    page_suffix = f", page {page}" if page not in (None, "") else ""
    print(f"  - {source}{page_suffix}")


# Inline citations: ask the model to tag each statement with the number of the
# context entry ([1], [2], ...) that supports it. The numbers refer to the
# labels produced by format_docs_numbered.
citation_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a clinical pharmacist. Answer using the retrieved context. "
            "Cite sources using [1], [2], etc. inline in your answer. "
            "Only cite sources that actually support the statement.\n\n"
            "Context:\n{context}",
        ),
        ("human", "{question}"),
    ]
)

def format_docs_numbered(docs: list[Document]) -> str:
    """Render documents as numbered context entries for a citation prompt.

    Entry N (1-based) is ``[N] <page_content>`` followed on the next line by
    ``(Source: <metadata "source", or 'unknown'>)``. Entries are separated by
    a blank line; the numbers are what the model cites inline.
    """
    entries = []
    for number, doc in enumerate(docs, start=1):
        origin = doc.metadata.get('source', 'unknown')
        entries.append(f"[{number}] {doc.page_content}\n(Source: {origin})")
    return "\n\n".join(entries)

# Same LCEL shape as the basic chain, but with numbered context entries and
# the citation-aware prompt.
citation_inputs = {
    "context": retriever | format_docs_numbered,
    "question": RunnablePassthrough(),
}
citation_chain = citation_inputs | citation_prompt | model | StrOutputParser()

answer = citation_chain.invoke("What monitoring is required for warfarin?")
# Illustrative output (exact wording depends on the model and retrieved chunks):
# Answer: "Warfarin requires regular INR monitoring [1]. Frequency depends on stability [2]."

Conversational RAG

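A basic RAG chain has no memory — each question is answered independently. Conversational RAG first uses the chat history to rewrite a follow-up question (e.g. "What are its interactions?") into a standalone question, then retrieves documents for the rewritten question.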

Python
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.messages import HumanMessage, AIMessage

# Step 1: History-aware retriever
# Takes (question + chat history) -> rewrites the question into a standalone
# query -> retrieves docs for that rewritten query. When chat_history is
# empty, create_history_aware_retriever sends the question to the retriever
# unchanged.
contextualize_q_prompt = ChatPromptTemplate.from_messages([
    ("system",
     "Given the chat history and the latest question, reformulate the question "
     "so it is self-contained (no references to 'it', 'that drug', 'the previous one'). "
     "Do NOT answer — only reformulate the question if needed. Return it as-is if it's already clear."),
    # "placeholder" is an optional messages slot: it expands to the prior
    # messages passed as chat_history.
    ("placeholder", "{chat_history}"),
    ("human", "{input}"),
])

history_aware_retriever = create_history_aware_retriever(
    llm=model,
    retriever=retriever,
    prompt=contextualize_q_prompt,
)

# Step 2: QA chain (uses retrieved docs to answer).
# create_stuff_documents_chain places the retrieved documents into the
# prompt's {context} variable; the chat history is replayed before the
# latest question.
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system",
         "You are a clinical pharmacist. Answer using the retrieved context only. "
         "If unsure, say so — never invent drug information.\n\nContext:\n{context}"),
        ("placeholder", "{chat_history}"),
        ("human", "{input}"),
    ]
)
question_answer_chain = create_stuff_documents_chain(model, qa_prompt)

# Step 3: Combine into a conversational RAG chain. Its output dict carries the
# retrieved documents under "context" and the reply under "answer".
conversational_rag = create_retrieval_chain(history_aware_retriever, question_answer_chain)

# Use with manual history tracking: one module-level list shared by every call.
chat_history = []


def chat(question: str) -> str:
    """Ask one conversational-RAG question and record the exchange.

    The current module-level ``chat_history`` is sent along with the question.
    After the chain answers, the question and answer are appended to that list
    as a HumanMessage/AIMessage pair (the list is never trimmed).

    Returns the model's answer text.
    """
    payload = {"input": question, "chat_history": chat_history}
    answer = conversational_rag.invoke(payload)["answer"]

    # Record the turn only after the chain has produced an answer.
    chat_history.append(HumanMessage(content=question))
    chat_history.append(AIMessage(content=answer))
    return answer


# Multi-turn conversation: later questions rely on the recorded history.
print(chat("What is warfarin?"))
print(chat("What are its main interactions?"))    # "its" -> resolved to "warfarin" by the rewriter
print(chat("What about with aspirin specifically?"))  # follow-up also resolved via the history

RAG with LangChain Memory

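Wrap the chain in a class that owns a LangChain memory object, so callers no longer pass `chat_history` themselves; `ConversationBufferWindowMemory` keeps only the last `k` exchanges. (Note: LangChain 0.3 deprecates the `langchain.memory` classes in favor of `RunnableWithMessageHistory` or LangGraph persistence.)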

Python
from langchain.memory import ConversationBufferWindowMemory
from langchain_core.runnables import RunnableLambda
from langchain_core.messages import BaseMessage

class RAGChatbot:
    """Conversational RAG bot whose history lives in a LangChain window memory.

    The memory keeps the last ``k`` question/answer exchanges as message
    objects and replays them to the history-aware chain on every call.

    NOTE(review): ``ConversationBufferWindowMemory`` is deprecated in recent
    LangChain releases; consider migrating to ``RunnableWithMessageHistory``.
    """

    def __init__(self, retriever, model, k: int = 6):
        self.retriever = retriever
        self.model = model
        # return_messages=True -> history comes back as message objects;
        # output_key tells the memory which result field holds the reply.
        self.memory = ConversationBufferWindowMemory(
            k=k,
            memory_key="chat_history",
            return_messages=True,
            output_key="answer",
        )
        rewriter = create_history_aware_retriever(model, retriever, contextualize_q_prompt)
        answerer = create_stuff_documents_chain(model, qa_prompt)
        self.chain = create_retrieval_chain(rewriter, answerer)

    def chat(self, question: str) -> dict:
        """Answer ``question`` using the remembered history, then remember this turn.

        Returns a dict with ``"answer"`` (the reply) and ``"sources"`` (the
        ``"source"`` metadata of each retrieved document, ``"?"`` when absent).
        """
        past_messages = self.memory.load_memory_variables({})["chat_history"]
        output = self.chain.invoke({"input": question, "chat_history": past_messages})
        reply = output["answer"]

        self.memory.save_context({"input": question}, {"answer": reply})

        retrieved = output.get("context", [])
        return {
            "answer": reply,
            "sources": [doc.metadata.get("source", "?") for doc in retrieved],
        }


bot = RAGChatbot(retriever=retriever, model=model)
r1 = bot.chat("What is warfarin?")                         # turn 1 seeds the window memory
r2 = bot.chat("What are its interactions with aspirin?")   # turn 2: "its" resolved from memory
print(r2["answer"], f"Sources: {r2['sources']}", sep="\n")

Streaming RAG

Stream the answer token by token as it's generated:

Python
# Stream from a basic RAG chain: StrOutputParser yields plain-text pieces,
# which are echoed immediately as they arrive.
for piece in rag_chain.stream("How does warfarin interact with aspirin?"):
    print(piece, end="", flush=True)
print()  # finish the streamed line

# Stream from conversational RAG (only stream the final answer, not internal steps)
async def stream_conversational(question: str, chat_history: list) -> str:
    """Stream the answer for ``question`` to stdout and return the full text.

    ``conversational_rag.astream`` emits dict chunks for the chain's steps;
    only chunks carrying an ``"answer"`` piece are printed. The pieces are also
    accumulated and returned, so the caller can append the finished exchange to
    its history. This function does not modify ``chat_history`` itself.
    """
    parts: list[str] = []
    async for chunk in conversational_rag.astream({
        "input": question,
        "chat_history": chat_history,
    }):
        # The chain emits dict chunks at each step; keep only answer tokens.
        if "answer" in chunk:
            piece = chunk["answer"]
            parts.append(piece)
            print(piece, end="", flush=True)
    print()
    return "".join(parts)

import asyncio
# Drive the coroutine from synchronous code with an empty history (first turn).
# asyncio.run() cannot be called while another event loop is running in the
# same thread (e.g. inside Jupyter) -- use `await stream_conversational(...)` there.
asyncio.run(stream_conversational(question="What is warfarin?", chat_history=[]))

Multi-Turn RAG with Session Management

Python
from collections import defaultdict

class SessionRAGPipeline:
    """RAG chatbot with per-session conversation history.

    Each ``session_id`` owns its own message list, so follow-up questions are
    resolved only against that session's turns. At most ``max_history_turns``
    question/answer pairs are kept (and sent to the chain) per session; a
    value of 0 or less disables history entirely.
    """

    def __init__(self, retriever, model, max_history_turns: int = 8):
        self.retriever = retriever
        self.model = model
        self.max_turns = max_history_turns
        # session_id -> [HumanMessage, AIMessage, HumanMessage, AIMessage, ...]
        self._sessions: dict[str, list[BaseMessage]] = defaultdict(list)

        self._chain = create_retrieval_chain(
            create_history_aware_retriever(model, retriever, contextualize_q_prompt),
            create_stuff_documents_chain(model, qa_prompt),
        )

    def _max_messages(self) -> int:
        """Return the history window size in messages (two per turn, never negative)."""
        return max(self.max_turns, 0) * 2

    def _get_history(self, session_id: str) -> list[BaseMessage]:
        """Return a copy of the last ``max_turns`` pairs for ``session_id``.

        Reading does not create a session entry. The explicit zero check
        matters: ``history[-0:]`` would return the whole list, not nothing.
        """
        limit = self._max_messages()
        if limit == 0:
            return []
        return self._sessions.get(session_id, [])[-limit:]

    def chat(self, session_id: str, question: str) -> str:
        """Answer ``question`` in the context of ``session_id``'s history.

        The new question/answer pair is stored only after the chain succeeds.
        The stored list is then trimmed to the history window, so long-running
        sessions do not grow without bound (older turns are discarded, which
        also means raising ``max_turns`` later cannot bring them back).
        """
        history = self._get_history(session_id)

        result = self._chain.invoke({
            "input": question,
            "chat_history": history,
        })
        answer = result["answer"]

        messages = self._sessions[session_id]
        messages.extend([
            HumanMessage(content=question),
            AIMessage(content=answer),
        ])
        # Only the last max_turns pairs are ever read back; drop the rest.
        excess = len(messages) - self._max_messages()
        if excess > 0:
            del messages[:excess]

        return answer

    def reset_session(self, session_id: str) -> None:
        """Forget all history for ``session_id`` (no-op if it does not exist)."""
        self._sessions.pop(session_id, None)


pipeline = SessionRAGPipeline(retriever=retriever, model=model)

# Two users with interleaved turns; each session keeps an isolated history.
conversation = [
    ("user_1", "What is warfarin?"),
    ("user_2", "What is metformin?"),
    ("user_1", "What are its interactions?"),   # "its" = warfarin for user_1
    ("user_2", "What are its side effects?"),   # "its" = metformin for user_2
]
for session_id, question in conversation:
    print(pipeline.chat(session_id, question))

Production RAG Pattern

Python
import time
import logging

logger = logging.getLogger("rag")


def production_rag_query(
    question: str,
    session_id: str,
    pipeline: "SessionRAGPipeline",
    max_question_chars: int = 1000,
) -> dict:
    """Run one validated, logged RAG turn and return a JSON-friendly result.

    Args:
        question: The user's question. Blank or over-long input is rejected
            without calling the pipeline.
        session_id: Conversation key forwarded to ``pipeline.chat``.
        pipeline: Object exposing ``chat(session_id, question) -> str``.
        max_question_chars: Maximum accepted ``len(question)`` (default 1000).

    Returns:
        On success: ``answer``, ``latency_ms``, ``success=True`` and a clinical
        ``disclaimer``. On rejected input: ``answer`` (a user-facing message)
        and ``success=False``. On a pipeline error: a generic ``answer``,
        ``success=False`` and ``latency_ms``.

    Any exception raised while handling the request is logged with its
    traceback and turned into the generic error result; it does not propagate.
    """
    # perf_counter is monotonic, so latency is immune to wall-clock adjustments.
    start = time.perf_counter()

    def _elapsed_ms() -> int:
        return round((time.perf_counter() - start) * 1000)

    try:
        # Input validation
        if not question.strip():
            return {"answer": "Please provide a question.", "success": False}
        if len(question) > max_question_chars:
            return {"answer": "Question too long. Please shorten it.", "success": False}

        answer = pipeline.chat(session_id, question)
        latency_ms = _elapsed_ms()

        logger.info("rag_success", extra={
            "session_id": session_id,
            "latency_ms": latency_ms,
            "question_chars": len(question),
        })

        return {
            "answer": answer,
            "latency_ms": latency_ms,
            "success": True,
            "disclaimer": "Verify all clinical information with current references before use.",
        }

    except Exception as e:
        # Top-level request boundary: never leak internals to the user, but keep
        # the full traceback in the logs (logger.exception logs at ERROR level
        # with exc_info attached).
        latency_ms = _elapsed_ms()
        logger.exception("rag_error", extra={
            "session_id": session_id,
            "error": str(e),
            "latency_ms": latency_ms,
        })
        return {
            "answer": "I encountered an error. Please try rephrasing your question.",
            "success": False,
            "latency_ms": latency_ms,
        }

Enjoyed this article?

Explore the AI Systems learning path for more.

Found this helpful?

Share:𝕏

Leave a comment

Have a question, correction, or just found this helpful? Leave a note below.