AI Systems · Intermediate
Conversational RAG: Multi-Turn Dialogue
Build RAG systems that maintain conversation history, resolve coreference, rewrite follow-up queries, and manage context across multi-turn clinical dialogues.
Asma Hafeez Khan · May 16, 2026 · 7 min read
Tags: RAG, Conversational AI, Multi-Turn, Context Management, Coreference
The Conversational RAG Challenge
Single-turn RAG is stateless — each query is independent. In conversation, users ask follow-up questions that assume shared context:
- Turn 1: "What are the indications for warfarin?"
- Turn 2: "What are the main drug interactions?" (which drug? warfarin)
- Turn 3: "Is it safe for elderly patients?" (which drug? still warfarin)
A stateless RAG retrieves the wrong context for turns 2 and 3. Conversational RAG maintains dialogue state and rewrites queries accordingly.
Conversation State Management
Python
from dataclasses import dataclass, field
from typing import Optional
import json
@dataclass
class ConversationTurn:
    """One message exchanged during a conversation.

    Besides the speaker and the message text, a turn can carry the
    retrieval details that were attached to it when it was stored.
    """

    # Speaker of this turn: "user" or "assistant".
    role: str
    # Raw message text.
    content: str
    # Documents attached to this turn; every instance gets its own list.
    retrieved_docs: list[dict] = field(default_factory=list)
    # The (possibly rewritten) query that was used for retrieval.
    query_used: str = ""
@dataclass
class ConversationSession:
    """State for a multi-turn RAG conversation.

    Holds the ordered list of turns plus the entities tracked so far, and
    renders both as plain text for use in prompts.
    """

    session_id: str
    turns: list[ConversationTurn] = field(default_factory=list)
    # Entities mentioned so far.
    topic_entities: list[str] = field(default_factory=list)
    # Most recently discussed drug, if any.
    active_drug: Optional[str] = None

    def add_turn(self, role: str, content: str, **kwargs) -> None:
        """Append a new turn; extra keyword arguments go to ``ConversationTurn``."""
        self.turns.append(ConversationTurn(role=role, content=content, **kwargs))

    def get_history_text(self, max_turns: int = 6) -> str:
        """Return recent conversation history as formatted text.

        Args:
            max_turns: Maximum number of trailing turns to include. Zero or a
                negative value yields an empty string.

        Returns:
            One ``"User: ..."`` / ``"Assistant: ..."`` line per turn, joined
            with newlines.
        """
        # Guard explicitly: ``turns[-0:]`` is the *whole* list, so a plain
        # negative-index slice would return everything for max_turns == 0.
        if max_turns <= 0:
            return ""
        # Slicing already copes with histories shorter than max_turns.
        return "\n".join(
            f"{'User' if turn.role == 'user' else 'Assistant'}: {turn.content}"
            for turn in self.turns[-max_turns:]
        )

    def get_context_summary(self) -> str:
        """Summarize the conversation topic for query rewriting.

        Returns:
            A one-sentence summary built from ``active_drug`` and the first
            five ``topic_entities``, or ``""`` when neither is available.
        """
        entities = ", ".join(self.topic_entities[:5])
        if self.active_drug:
            # Still report the active drug when no other entity is tracked,
            # instead of dropping the only context we have.
            if not entities:
                return f"Currently discussing: {self.active_drug}."
            return f"Currently discussing: {self.active_drug}. Other topics: {entities}."
        if not entities:
            return ""
        return f"Topics discussed: {entities}."


# Query Rewriting for Conversation
Transform follow-up questions into standalone, retrievable queries:
Python
from openai import OpenAI
client = OpenAI()
# System prompt for the query-rewriting call. The arrow and the dash below
# replace mis-encoded characters that would otherwise be sent to the model.
REWRITE_SYSTEM = """You are a query rewriting assistant for a clinical RAG system.
Given a conversation history and a follow-up question, rewrite the follow-up into a complete,
standalone search query that does not rely on the conversation history.
Rules:
- Replace pronouns (it, they, this drug, the medication) with specific names
- Make implicit references explicit ("the interaction" → "warfarin-aspirin interaction")
- Preserve medical specificity — don't over-generalize
- If the question is already standalone, return it unchanged
- Return ONLY the rewritten query, nothing else"""
def rewrite_follow_up_query(
    follow_up: str,
    conversation_history: str,
    context_summary: str = "",
) -> str:
    """Rewrite a conversational follow-up into a standalone search query.

    Args:
        follow_up: The user's latest question, possibly containing pronouns
            or other references to earlier turns.
        conversation_history: Formatted text of the preceding turns.
        context_summary: Optional topic summary; added to the prompt as a
            ``Context:`` line when non-empty.

    Returns:
        The rewritten query with surrounding whitespace removed. If the model
        returns no text, ``follow_up`` is returned unchanged so the caller
        always has something to search with.
    """
    context_line = f"Context: {context_summary}" if context_summary else ""
    prompt = f"""Conversation history:
{conversation_history}
{context_line}
Follow-up question: {follow_up}
Rewrite as a standalone search query:"""
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": REWRITE_SYSTEM},
            {"role": "user", "content": prompt},
        ],
        temperature=0,
        max_tokens=100,
    )
    # message.content may be None; treat that the same as an empty reply.
    rewritten = (response.choices[0].message.content or "").strip()
    return rewritten or follow_up
def is_follow_up_question(
    query: str,
    conversation_history: str,
) -> bool:
    """Detect whether a query depends on conversation context.

    The decision is made in three steps:

    1. With no conversation history there is nothing to refer back to, so
       the answer is ``False``.
    2. If the query contains one of the indicator words or phrases, it is
       treated as a follow-up without calling the model.
    3. Otherwise the chat model is asked, and its JSON verdict is returned.

    Args:
        query: The new user query.
        conversation_history: Formatted text of the preceding turns; only
            the last 500 characters are sent to the model.

    Returns:
        ``True`` when the query should be rewritten using the history.
        A missing, malformed, or non-object model reply counts as ``False``.
    """
    if not conversation_history:
        return False

    # Fast heuristic check first.
    follow_up_indicators = (
        "it", "they", "this", "that", "the drug", "the medication",
        "what about", "and also", "what else", "more about",
        "is it", "does it", "can it", "how about",
        "same", "other", "another",
    )
    # Turn punctuation into spaces and collapse whitespace so that an
    # indicator next to punctuation ("... with it?") still matches as a
    # whole word.
    normalized = "".join(ch if ch.isalnum() else " " for ch in query.lower())
    padded = " " + " ".join(normalized.split()) + " "
    if any(f" {ind} " in padded for ind in follow_up_indicators):
        return True

    # LLM check for ambiguous cases.
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "user",
                "content": f"""Given this conversation history and a new query, does the query depend on
the conversation history to be understood?
Conversation:
{conversation_history[-500:]}
New query: {query}
Return JSON: {{"is_follow_up": true/false, "reason": "brief explanation"}}""",
            }
        ],
        response_format={"type": "json_object"},
        temperature=0,
    )
    try:
        result = json.loads(response.choices[0].message.content or "")
    except json.JSONDecodeError:
        return False
    if not isinstance(result, dict):
        return False
    verdict = result.get("is_follow_up", False)
    # Tolerate a model that quotes the boolean ("true" / "false").
    if isinstance(verdict, str):
        return verdict.strip().lower() == "true"
    return bool(verdict)


# Entity Tracking
Track which medical entities are being discussed:
Python
def extract_conversation_entities(text: str) -> dict:
    """Ask the chat model which medical entities appear in ``text``.

    The prompt requests a JSON object with the keys ``drugs``,
    ``conditions``, ``procedures`` and ``primary_entity``. The reply is
    parsed with ``json.loads`` and returned as-is; its shape is not
    validated here.
    """
    instruction = f"""Extract drug names, medical conditions, and procedures from this text.
Text: {text}
Return JSON:
{{
"drugs": ["list of drug names mentioned"],
"conditions": ["list of conditions/diagnoses"],
"procedures": ["list of procedures/tests"],
"primary_entity": "the main subject being asked about"
}}"""
    completion = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": instruction}],
        response_format={"type": "json_object"},
        temperature=0,
    )
    raw_json = completion.choices[0].message.content
    return json.loads(raw_json)
def update_session_entities(
    session: ConversationSession,
    query: str,
    response: str,
) -> None:
    """Update tracked entities after each turn.

    Entities are extracted from the user query and from the first 500
    characters of the response. ``session.active_drug`` becomes the first
    drug found (query drugs are considered before response drugs), and
    ``session.topic_entities`` is rebuilt most-recent-first.

    Args:
        session: The session whose ``active_drug`` and ``topic_entities``
            are updated in place.
        query: The user's query for this turn.
        response: The assistant's answer for this turn.
    """

    def _names(entities: dict, key: str) -> list:
        # A key that is missing or null yields no names; blank or
        # non-string items are dropped.
        values = entities.get(key) or []
        return [v for v in values if isinstance(v, str) and v.strip()]

    # Extract from both query and response.
    query_entities = extract_conversation_entities(query)
    resp_entities = extract_conversation_entities(response[:500])

    drugs = _names(query_entities, "drugs") + _names(resp_entities, "drugs")
    if drugs:
        session.active_drug = drugs[0]

    mentioned = (
        drugs
        + _names(query_entities, "conditions")
        + _names(resp_entities, "conditions")
    )

    # De-duplicate this turn's mentions case-insensitively, keeping the
    # order in which they were found.
    seen = set()
    fresh = []
    for name in mentioned:
        key = name.casefold()
        if key not in seen:
            seen.add(key)
            fresh.append(name)

    # This turn's entities go first; an entity that was already tracked is
    # moved to the front rather than left at its old position, so the list
    # really is "most recent first".
    carried = [e for e in session.topic_entities if e.casefold() not in seen]

    # Keep entity list bounded.
    session.topic_entities = (fresh + carried)[:20]


# Context Window Management
Long conversations need intelligent summarization:
Python
def summarize_conversation(
    turns: list[ConversationTurn],
    max_tokens: int = 500,
) -> str:
    """Summarize older conversation turns to keep context window manageable.

    Args:
        turns: Turns to summarize; each must expose ``role`` and ``content``.
        max_tokens: Passed to the chat completion as its ``max_tokens`` limit.

    Returns:
        The model's summary with surrounding whitespace removed, or ``""``
        when ``turns`` is empty or the model returns no text.
    """
    # Nothing to summarize: skip the model call entirely.
    if not turns:
        return ""

    history_text = "\n".join(
        f"{'User' if t.role == 'user' else 'Assistant'}: {t.content}"
        for t in turns
    )
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "user",
                "content": f"""Summarize this clinical conversation in 2-3 sentences.
Focus on: which drugs/conditions were discussed, what was established, what questions were answered.
Conversation:
{history_text}
Summary:""",
            }
        ],
        temperature=0,
        max_tokens=max_tokens,
    )
    # message.content may be None; treat that as an empty summary.
    return (response.choices[0].message.content or "").strip()
def build_compact_history(
    session: ConversationSession,
    recent_turns: int = 4,
    summarize_older: bool = True,
) -> str:
    """Build a compact conversation history.

    Keeps recent turns verbatim, summarizes older turns.

    Args:
        session: Object whose ``turns`` list is read.
        recent_turns: Number of trailing turns to keep verbatim. Zero or a
            negative value keeps none, so every turn counts as "older".
        summarize_older: When true, the older turns are condensed with
            ``summarize_conversation`` and placed first. When false, the
            older turns are omitted.

    Returns:
        The optional summary line followed by one ``"User: ..."`` /
        ``"Assistant: ..."`` line per recent turn, joined with newlines.
    """
    all_turns = session.turns

    # Compute an explicit split index instead of slicing with -recent_turns:
    # ``[:-0]`` is empty and ``[-0:]`` is the whole list, which would invert
    # the meaning of recent_turns == 0.
    keep = max(recent_turns, 0)
    split_at = max(len(all_turns) - keep, 0)
    older_turns = all_turns[:split_at]
    recent = all_turns[split_at:]

    parts = []
    if summarize_older and older_turns:
        summary = summarize_conversation(older_turns)
        parts.append(f"[Earlier conversation summary: {summary}]")

    # Format the verbatim turns here so a short history is never capped by
    # another helper's default limit.
    parts.extend(
        f"{'User' if turn.role == 'user' else 'Assistant'}: {turn.content}"
        for turn in recent
    )
    return "\n".join(parts)


# Full Conversational RAG Pipeline
Python
class ConversationalRAGPipeline:
    """Multi-turn RAG pipeline with conversation state management.

    For each query the pipeline decides whether the question needs
    rewriting, retrieves documents with the (possibly rewritten) query,
    generates an answer with the conversation history in the prompt, and
    records the turn in the session.
    """

    def __init__(self, retriever, llm_client, session_store: Optional[dict] = None):
        """Store the collaborators and the session mapping.

        Args:
            retriever: Object providing ``retrieve(query_embedding, top_k=...)``.
            llm_client: Client providing ``embeddings.create`` and
                ``chat.completions.create``.
            session_store: Optional mapping of session id to session. When
                given, it is used as-is (even if currently empty) so the
                caller sees the sessions this pipeline creates.
        """
        self.retriever = retriever
        self.llm = llm_client
        # Compare against None rather than truthiness: an empty dict passed
        # in by the caller is a valid shared store and must not be replaced.
        self.sessions: dict[str, ConversationSession] = (
            session_store if session_store is not None else {}
        )

    def get_or_create_session(self, session_id: str) -> ConversationSession:
        """Return the session for ``session_id``, creating it on first use."""
        if session_id not in self.sessions:
            self.sessions[session_id] = ConversationSession(session_id=session_id)
        return self.sessions[session_id]

    def _embed_query(self, query: str) -> list[float]:
        """Embed ``query`` with the ``text-embedding-3-small`` model."""
        resp = self.llm.embeddings.create(
            model="text-embedding-3-small",
            input=[query],
        )
        return resp.data[0].embedding

    def query(self, session_id: str, user_query: str) -> dict:
        """Process a conversational query with full context management.

        Args:
            session_id: Identifier of the conversation; a new session is
                created if it is not known yet.
            user_query: The user's question for this turn.

        Returns:
            A dict with the keys ``response``, ``query_rewritten``,
            ``search_query``, ``session_id`` and ``turn_number``.
        """
        session = self.get_or_create_session(session_id)

        # Step 1: Determine if rewriting is needed.
        history = build_compact_history(session) if session.turns else ""
        context_summary = session.get_context_summary()
        needs_rewrite = is_follow_up_question(user_query, history)
        search_query = user_query
        if needs_rewrite:
            # Fall back to the original question if the rewrite comes back
            # empty, so an empty string is never sent for embedding.
            search_query = (
                rewrite_follow_up_query(user_query, history, context_summary)
                or user_query
            )

        # Step 2: Retrieve with the (possibly rewritten) query.
        query_emb = self._embed_query(search_query)
        docs = self.retriever.retrieve(query_emb, top_k=5)
        context_text = "\n\n".join(d["content"] for d in docs)

        # Step 3: Generate with conversation history.
        messages = [
            {
                "role": "system",
                "content": (
                    "You are a clinical pharmacology assistant. "
                    "Answer questions using only the provided context. "
                    "Maintain conversational continuity with the chat history."
                ),
            },
        ]
        # Add summarized history as system context.
        if history:
            messages.append({
                "role": "system",
                "content": f"Conversation history:\n{history}",
            })
        messages.append({
            "role": "user",
            "content": f"Relevant context:\n{context_text}\n\nQuestion: {user_query}",
        })
        completion = self.llm.chat.completions.create(
            model="gpt-4o",
            messages=messages,
            temperature=0,
        )
        # message.content may be None; store and return an empty string
        # instead so the slicing and formatting done later cannot fail.
        response = completion.choices[0].message.content or ""

        # Step 4: Update session state.
        session.add_turn(
            role="user",
            content=user_query,
            query_used=search_query,
        )
        session.add_turn(
            role="assistant",
            content=response,
            retrieved_docs=docs,
        )
        update_session_entities(session, user_query, response)

        return {
            "response": response,
            "query_rewritten": needs_rewrite,
            # Equals user_query whenever no rewrite took place.
            "search_query": search_query,
            "session_id": session_id,
            # Each exchange stores two turns (user + assistant).
            "turn_number": len(session.turns) // 2,
        }

    def reset_session(self, session_id: str) -> None:
        """Clear conversation history for a session; unknown ids are ignored."""
        self.sessions.pop(session_id, None)

    def get_session_summary(self, session_id: str) -> dict:
        """Return metadata about a session, or ``{}`` if it does not exist."""
        session = self.sessions.get(session_id)
        if session is None:
            return {}
        return {
            "turns": len(session.turns) // 2,
            "active_drug": session.active_drug,
            "topic_entities": session.topic_entities[:10],
        }


# Example Conversation Flow
Python
def demo_conversation():
    """Walk a five-question dialogue through the pipeline and print each step.

    For every question the user text is printed, then the rewritten search
    query when one was produced, then the first 200 characters of the
    answer. The retriever argument is an ``...`` placeholder.
    """
    session_id = "session_001"
    rag = ConversationalRAGPipeline(retriever=..., llm_client=client)

    questions = [
        "What is the mechanism of action of warfarin?",
        "What drugs interact with it?",  # "it" = warfarin
        "Is it safe in elderly patients?",  # still about warfarin
        "What about rivaroxaban instead?",  # comparing alternatives
        "How do their dosing schedules compare?",  # comparing warfarin vs rivaroxaban
    ]

    for question in questions:
        outcome = rag.query(session_id, question)
        print(f"\nUser: {question}")
        rewritten = outcome["query_rewritten"]
        if rewritten:
            print(f"[Rewritten to: {outcome['search_query']}]")
        print(f"Assistant: {outcome['response'][:200]}...")

    # Turn 1: "mechanism of action of warfarin" — direct retrieval
    # Turn 2: rewritten to "drug interactions with warfarin"
    # Turn 3: rewritten to "warfarin safety in elderly patients"
    # Turn 4: "rivaroxaban as alternative to warfarin" — new entity introduced
    # Turn 5: "warfarin vs rivaroxaban dosing comparison"
Leave a comment
Have a question, correction, or just found this helpful? Leave a note below.