Learnixo
Back to blog
AI Systems · Intermediate

Conversational RAG: Multi-Turn Dialogue

Build RAG systems that maintain conversation history, resolve coreference, rewrite follow-up queries, and manage context across multi-turn clinical dialogues.

Asma Hafeez Khan · May 16, 2026 · 7 min read
RAG · Conversational AI · Multi-Turn · Context Management · Coreference
Share: 𝕏

The Conversational RAG Challenge

Single-turn RAG is stateless — each query is independent. In conversation, users ask follow-up questions that assume shared context:

  • Turn 1: "What are the indications for warfarin?"
  • Turn 2: "What are the main drug interactions?" (which drug? warfarin)
  • Turn 3: "Is it safe for elderly patients?" (which drug? still warfarin)

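A stateless RAG retrieves the wrong context for turns 2 and 3, because the follow-up text on its own never names warfarin and the retriever has nothing else to anchor on. Conversational RAG maintains dialogue state and rewrites each follow-up into a standalone query before retrieval.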


Conversation State Management

Python
from dataclasses import dataclass, field
from typing import Optional
import json

@dataclass
class ConversationTurn:
    """One message exchanged in a conversation, plus its retrieval metadata.

    Attributes:
        role: Speaker of the message, ``"user"`` or ``"assistant"``.
        content: Text of the message.
        retrieved_docs: Documents attached to this turn. Each instance gets
            its own fresh empty list by default.
        query_used: The (possibly rewritten) query used for retrieval on this
            turn; empty string when none was recorded.
    """

    role: str
    content: str
    retrieved_docs: list[dict] = field(default_factory=list)
    query_used: str = ""


@dataclass
class ConversationSession:
    """Mutable state for one multi-turn RAG conversation.

    Attributes:
        session_id: Identifier of the conversation.
        turns: Every user/assistant turn recorded so far, oldest first.
        topic_entities: Entities mentioned during the conversation.
        active_drug: Most recently discussed drug, or ``None`` if none yet.
    """

    session_id: str
    turns: list[ConversationTurn] = field(default_factory=list)
    topic_entities: list[str] = field(default_factory=list)
    active_drug: Optional[str] = None

    def add_turn(self, role: str, content: str, **kwargs) -> None:
        """Append a turn; extra keyword arguments go to ``ConversationTurn``."""
        self.turns.append(ConversationTurn(role=role, content=content, **kwargs))

    def get_history_text(self, max_turns: int = 6) -> str:
        """Return the last ``max_turns`` turns as ``"User: ..."`` lines.

        Args:
            max_turns: Maximum number of most-recent turns to include. Zero
                or a negative value yields an empty string.

        Returns:
            One line per turn, joined with newlines.
        """
        # Guard explicitly: ``turns[-0:]`` is the WHOLE list, so a plain
        # negative slice would return everything when max_turns is 0.
        if max_turns <= 0:
            return ""
        return "\n".join(
            f"{'User' if turn.role == 'user' else 'Assistant'}: {turn.content}"
            for turn in self.turns[-max_turns:]
        )

    def get_context_summary(self) -> str:
        """Summarize the conversation topic for query rewriting.

        Returns:
            A one-sentence description of the active drug and up to five
            other tracked entities, or ``""`` when nothing is tracked.
        """
        # The active drug is normally also present in topic_entities; leave
        # it out of the "other" list so the summary does not repeat it.
        others = [e for e in self.topic_entities if e != self.active_drug][:5]
        listed = ", ".join(others)

        if self.active_drug:
            if listed:
                return (
                    f"Currently discussing: {self.active_drug}. "
                    f"Other topics: {listed}."
                )
            return f"Currently discussing: {self.active_drug}."
        if listed:
            return f"Topics discussed: {listed}."
        return ""

Query Rewriting for Conversation

Transform follow-up questions into standalone, retrievable queries:

Python
from openai import OpenAI

# Module-level OpenAI client used by the helper functions in this file.
# It is constructed with no explicit arguments.
client = OpenAI()

# System prompt sent as the "system" message by rewrite_follow_up_query.
# It tells the model to turn a context-dependent follow-up into a
# self-contained search query and to return only that query.
REWRITE_SYSTEM = """You are a query rewriting assistant for a clinical RAG system.
Given a conversation history and a follow-up question, rewrite the follow-up into a complete, 
standalone search query that does not rely on the conversation history.

Rules:
- Replace pronouns (it, they, this drug, the medication) with specific names
- Make implicit references explicit ("the interaction" → "warfarin-aspirin interaction")  
- Preserve medical specificity — don't over-generalize
- If the question is already standalone, return it unchanged
- Return ONLY the rewritten query, nothing else"""


def rewrite_follow_up_query(
    follow_up: str,
    conversation_history: str,
    context_summary: str = "",
) -> str:
    """Rewrite a conversational follow-up into a standalone search query.

    Args:
        follow_up: The user's latest question, possibly containing pronouns
            or implicit references to earlier turns.
        conversation_history: Formatted text of the preceding turns.
        context_summary: Optional one-line topic summary added to the prompt.

    Returns:
        The rewritten query with surrounding whitespace stripped. The
        original ``follow_up`` is returned unchanged when there is no history
        or context to resolve against, or when the model returns no text.
    """
    # With neither history nor a topic summary there is nothing to resolve
    # references against, so skip the model call entirely.
    if not conversation_history.strip() and not context_summary.strip():
        return follow_up

    prompt = f"""Conversation history:
{conversation_history}

{f'Context: {context_summary}' if context_summary else ''}

Follow-up question: {follow_up}

Rewrite as a standalone search query:"""

    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": REWRITE_SYSTEM},
            {"role": "user", "content": prompt},
        ],
        temperature=0,
        max_tokens=100,
    )
    # ``content`` can be None; an empty rewrite would otherwise be embedded
    # and searched as an empty query, so fall back to the user's wording.
    rewritten = (response.choices[0].message.content or "").strip()
    return rewritten or follow_up


def is_follow_up_question(
    query: str,
    conversation_history: str,
) -> bool:
    """Detect whether a query depends on conversation context.

    The check runs in three stages:

    1. No history: the query cannot depend on earlier turns, so ``False``.
    2. Fast heuristic: if the query contains a referring phrase ("it",
       "what about", ...), it is treated as a follow-up without a model call.
       A false positive only costs one rewrite call, and the rewrite prompt
       is told to return standalone questions unchanged.
    3. Otherwise the model classifies the ambiguous case.

    Args:
        query: The user's latest question.
        conversation_history: Formatted text of the preceding turns; only the
            last 500 characters are sent to the model.

    Returns:
        ``True`` when the query should be rewritten using the history.
        ``False`` when there is no history or the model's reply is unusable.
    """
    if not conversation_history.strip():
        return False

    follow_up_indicators = (
        "it", "they", "this", "that", "the drug", "the medication",
        "what about", "and also", "what else", "more about",
        "is it", "does it", "can it", "how about",
        "same", "other", "another",
    )
    # Replace punctuation with spaces and collapse whitespace so a trailing
    # "it?" or "that," still matches on whole words.
    normalized = " ".join(
        "".join(ch if ch.isalnum() else " " for ch in query.lower()).split()
    )
    padded = f" {normalized} "
    if any(f" {indicator} " in padded for indicator in follow_up_indicators):
        return True

    # LLM check for ambiguous cases
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "user",
                "content": f"""Given this conversation history and a new query, does the query depend on 
the conversation history to be understood?

Conversation:
{conversation_history[-500:]}

New query: {query}

Return JSON: {{"is_follow_up": true/false, "reason": "brief explanation"}}""",
            }
        ],
        response_format={"type": "json_object"},
        temperature=0,
    )
    try:
        result = json.loads(response.choices[0].message.content or "")
    except json.JSONDecodeError:
        # Unusable reply: keep the user's query as-is rather than crash.
        return False
    if not isinstance(result, dict):
        return False

    flag = result.get("is_follow_up", False)
    if isinstance(flag, str):
        # A quoted "false" would otherwise be truthy.
        return flag.strip().lower() == "true"
    return bool(flag)

Entity Tracking

Track which medical entities are being discussed:

Python
def extract_conversation_entities(text: str) -> dict:
    """Extract drug names and medical entities from text.

    Args:
        text: Free text (a user query or an assistant answer).

    Returns:
        A dict with list-valued ``"drugs"``, ``"conditions"`` and
        ``"procedures"`` keys and a string ``"primary_entity"`` key. Blank
        input, an empty model reply, or a reply that is not a JSON object all
        yield the empty structure, so callers can always concatenate the
        lists safely.
    """
    empty = {"drugs": [], "conditions": [], "procedures": [], "primary_entity": ""}

    # Nothing to extract: avoid a model call on blank input.
    if not text or not text.strip():
        return empty

    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "user",
                "content": f"""Extract drug names, medical conditions, and procedures from this text.

Text: {text}

Return JSON:
{{
  "drugs": ["list of drug names mentioned"],
  "conditions": ["list of conditions/diagnoses"],
  "procedures": ["list of procedures/tests"],
  "primary_entity": "the main subject being asked about"
}}""",
            }
        ],
        response_format={"type": "json_object"},
        temperature=0,
    )
    try:
        parsed = json.loads(response.choices[0].message.content or "")
    except json.JSONDecodeError:
        return empty
    if not isinstance(parsed, dict):
        return empty

    # The prompt only requests this shape; nothing enforces it. Coerce
    # missing/null/non-list fields to empty lists and drop non-string items.
    for key in ("drugs", "conditions", "procedures"):
        value = parsed.get(key)
        if isinstance(value, list):
            parsed[key] = [item for item in value if isinstance(item, str) and item]
        else:
            parsed[key] = []
    if not isinstance(parsed.get("primary_entity"), str):
        parsed["primary_entity"] = ""
    return parsed


def update_session_entities(
    session: ConversationSession,
    query: str,
    response: str,
) -> None:
    """Update the session's tracked entities after a completed turn.

    Entities are extracted from the user query and from the first 500
    characters of the assistant response. The session is mutated in place:

    * ``active_drug`` becomes the first drug found, with drugs from the query
      taking precedence over drugs from the response. It is left untouched
      when the turn mentions no drug.
    * ``topic_entities`` is reordered so that everything mentioned this turn
      comes first, in order of mention. That includes entities already
      tracked, which move to the front instead of keeping a stale position.
      The list is then capped at 20 entries.

    Args:
        session: Conversation state to update.
        query: The user's question for this turn.
        response: The assistant's answer for this turn.
    """
    # Extract from both query and response
    query_entities = extract_conversation_entities(query)
    resp_entities = extract_conversation_entities(response[:500])

    # ``or []`` covers a key that is present but null in the model's JSON.
    all_drugs = (query_entities.get("drugs") or []) + (
        resp_entities.get("drugs") or []
    )
    first_drug = next((drug for drug in all_drugs if drug), None)
    if first_drug is not None:
        session.active_drug = first_drug

    mentioned = (
        all_drugs
        + (query_entities.get("conditions") or [])
        + (resp_entities.get("conditions") or [])
    )
    # dict.fromkeys de-duplicates while keeping first-mention order.
    fresh = list(
        dict.fromkeys(e for e in mentioned if isinstance(e, str) and e)
    )
    fresh_lookup = set(fresh)
    carried_over = [e for e in session.topic_entities if e not in fresh_lookup]

    # Most recent first, bounded so long sessions do not grow without limit.
    session.topic_entities = (fresh + carried_over)[:20]

Context Window Management

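Long conversations eventually outgrow the model's context window, so keep the most recent turns verbatim and compress the older turns into a short summary: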

Python
def summarize_conversation(
    turns: list[ConversationTurn],
    max_tokens: int = 500,
) -> str:
    """Summarize older conversation turns to keep the context window small.

    Args:
        turns: Turns to condense, oldest first.
        max_tokens: Upper bound passed to the model for the summary length.

    Returns:
        The model's summary with surrounding whitespace stripped, or ``""``
        when ``turns`` is empty or the model returns no text.
    """
    # Do not ask the model to summarize an empty transcript.
    if not turns:
        return ""

    history_text = "\n".join(
        f"{'User' if t.role == 'user' else 'Assistant'}: {t.content}"
        for t in turns
    )

    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "user",
                "content": f"""Summarize this clinical conversation in 2-3 sentences.
Focus on: which drugs/conditions were discussed, what was established, what questions were answered.

Conversation:
{history_text}

Summary:""",
            }
        ],
        temperature=0,
        max_tokens=max_tokens,
    )
    # ``content`` can be None; treat that as "no summary".
    return (response.choices[0].message.content or "").strip()


def build_compact_history(
    session: ConversationSession,
    recent_turns: int = 4,
    summarize_older: bool = True,
) -> str:
    """Build a compact conversation history.

    Keeps the last ``recent_turns`` turns verbatim and, when
    ``summarize_older`` is true, replaces everything before them with a
    bracketed one-paragraph summary.

    Args:
        session: Conversation whose turns are rendered.
        recent_turns: Number of most-recent turns kept verbatim. Zero or a
            negative value keeps none.
        summarize_older: If false, older turns are dropped, not summarized.

    Returns:
        Newline-joined history text; ``""`` for a session with no turns.
    """
    all_turns = session.turns
    if not all_turns:
        return ""

    keep = max(recent_turns, 0)
    if len(all_turns) <= keep:
        # Pass the cap explicitly: the method's own default (6) would
        # silently drop turns whenever recent_turns is larger than that.
        return session.get_history_text(max_turns=keep)

    # Split by index, not negative slicing: ``[:-0]`` is empty and
    # ``[-0:]`` is the whole list, which inverts the split when keep == 0.
    split_at = len(all_turns) - keep
    older_turns = all_turns[:split_at]
    recent = all_turns[split_at:]

    parts = []
    if summarize_older and older_turns:
        summary = summarize_conversation(older_turns)
        parts.append(f"[Earlier conversation summary: {summary}]")

    for turn in recent:
        prefix = "User" if turn.role == "user" else "Assistant"
        parts.append(f"{prefix}: {turn.content}")

    return "\n".join(parts)

Full Conversational RAG Pipeline

Python
class ConversationalRAGPipeline:
    """Multi-turn RAG pipeline with conversation state management.

    For each user query the pipeline decides whether the query depends on
    earlier turns, rewrites it into a standalone search query if so,
    retrieves documents for that query, generates an answer with the
    conversation history in the prompt, and records the turn in the session.
    """

    def __init__(self, retriever, llm_client, session_store: Optional[dict] = None):
        """Create a pipeline.

        Args:
            retriever: Object exposing ``retrieve(query_embedding, top_k=...)``
                that returns a list of dicts, each with a ``"content"`` key.
            llm_client: Client exposing ``embeddings.create`` and
                ``chat.completions.create``.
            session_store: Optional mapping of session id to
                ``ConversationSession``. When given, this exact object is
                used, so the caller sees sessions created here.
        """
        self.retriever = retriever
        self.llm = llm_client
        # Compare against None: ``session_store or {}`` would throw away a
        # caller-supplied EMPTY dict and sessions would never reach it.
        self.sessions: dict[str, ConversationSession] = (
            session_store if session_store is not None else {}
        )

    def get_or_create_session(self, session_id: str) -> ConversationSession:
        """Return the session for ``session_id``, creating it on first use."""
        if session_id not in self.sessions:
            self.sessions[session_id] = ConversationSession(session_id=session_id)
        return self.sessions[session_id]

    def _embed_query(self, query: str) -> list[float]:
        """Embed ``query`` with the pipeline's client and return the vector."""
        resp = self.llm.embeddings.create(
            model="text-embedding-3-small",
            input=[query],
        )
        return resp.data[0].embedding

    def query(self, session_id: str, user_query: str) -> dict:
        """Process a conversational query with full context management.

        Args:
            session_id: Conversation to continue (created if unknown).
            user_query: The user's question exactly as typed.

        Returns:
            A dict with ``"response"`` (answer text), ``"query_rewritten"``
            (whether a rewrite happened), ``"search_query"`` (the query
            actually used for retrieval), ``"session_id"`` and
            ``"turn_number"`` (user/assistant pairs recorded so far).
        """
        session = self.get_or_create_session(session_id)

        # Step 1: Determine if rewriting is needed.
        # NOTE(review): the follow-up detection, rewriting and entity
        # helpers use the module-level ``client``, not ``self.llm``.
        history = build_compact_history(session) if session.turns else ""
        context_summary = session.get_context_summary()

        needs_rewrite = is_follow_up_question(user_query, history)
        if needs_rewrite:
            search_query = rewrite_follow_up_query(user_query, history, context_summary)
        else:
            search_query = user_query

        # Step 2: Retrieve with the (possibly rewritten) query
        query_emb = self._embed_query(search_query)
        docs = self.retriever.retrieve(query_emb, top_k=5)
        context_text = "\n\n".join(d["content"] for d in docs)

        # Step 3: Generate with conversation history
        messages = [
            {
                "role": "system",
                "content": (
                    "You are a clinical pharmacology assistant. "
                    "Answer questions using only the provided context. "
                    "Maintain conversational continuity with the chat history."
                ),
            },
        ]

        # Add summarized history as system context
        if history:
            messages.append({
                "role": "system",
                "content": f"Conversation history:\n{history}",
            })

        # The model answers the user's ORIGINAL wording; the rewritten
        # query is used for retrieval only.
        messages.append({
            "role": "user",
            "content": f"Relevant context:\n{context_text}\n\nQuestion: {user_query}",
        })

        response = self.llm.chat.completions.create(
            model="gpt-4o",
            messages=messages,
            temperature=0,
        ).choices[0].message.content

        # Step 4: Update session state
        session.add_turn(
            role="user",
            content=user_query,
            query_used=search_query,
        )
        session.add_turn(
            role="assistant",
            content=response,
            retrieved_docs=docs,
        )
        update_session_entities(session, user_query, response)

        return {
            "response": response,
            "query_rewritten": needs_rewrite,
            # Already equals user_query when no rewrite happened.
            "search_query": search_query,
            "session_id": session_id,
            "turn_number": len(session.turns) // 2,
        }

    def reset_session(self, session_id: str) -> None:
        """Clear conversation history for a session (no-op if unknown)."""
        self.sessions.pop(session_id, None)

    def get_session_summary(self, session_id: str) -> dict:
        """Return metadata about a session, or ``{}`` if it does not exist.

        The dict has ``"turns"`` (user/assistant pairs), ``"active_drug"``
        and ``"topic_entities"`` (first ten tracked entities).
        """
        session = self.sessions.get(session_id)
        if session is None:
            return {}
        return {
            "turns": len(session.turns) // 2,
            "active_drug": session.active_drug,
            "topic_entities": session.topic_entities[:10],
        }

Example Conversation Flow

Python
def demo_conversation():
    """Run a scripted five-turn dialogue through the pipeline and print it.

    For each turn this prints the user's question, the rewritten search
    query when one was produced, and the first 200 characters of the answer.
    """
    # ``retriever=...`` is a literal Ellipsis placeholder: substitute a real
    # retriever object before running this demo.
    rag = ConversationalRAGPipeline(retriever=..., llm_client=client)
    session_id = "session_001"

    questions = (
        "What is the mechanism of action of warfarin?",
        "What drugs interact with it?",            # "it" refers to warfarin
        "Is it safe in elderly patients?",         # still warfarin
        "What about rivaroxaban instead?",         # introduces an alternative
        "How do their dosing schedules compare?",  # warfarin vs rivaroxaban
    )

    for question in questions:
        outcome = rag.query(session_id, question)
        print(f"\nUser: {question}")
        was_rewritten = outcome["query_rewritten"]
        if was_rewritten:
            print(f"[Rewritten to: {outcome['search_query']}]")
        answer_preview = outcome["response"][:200]
        print(f"Assistant: {answer_preview}...")

# Turn 1: "mechanism of action of warfarin" → direct retrieval
# Turn 2: rewritten to "drug interactions with warfarin"
# Turn 3: rewritten to "warfarin safety in elderly patients"
# Turn 4: "rivaroxaban as alternative to warfarin" — new entity introduced
# Turn 5: "warfarin vs rivaroxaban dosing comparison"

Enjoyed this article?

Explore the AI Systems learning path for more.

Found this helpful?

Share: 𝕏

Leave a comment

Have a question, correction, or just found this helpful? Leave a note below.