Learnixo
Back to blog
AI Systems · Advanced

Graph RAG: Knowledge Graph-Enhanced Retrieval

Enhance RAG with knowledge graphs. GraphRAG by Microsoft, entity extraction, relationship indexing, and combining vector search with graph traversal.

Asma Hafeez Khan · May 16, 2026 · 6 min read
RAG · Knowledge Graph · GraphRAG · Entity Extraction · Neo4j
Share:𝕏

Why Graph RAG

Standard vector RAG treats each document chunk independently. Knowledge graphs capture relationships between entities:

  • Drug A inhibits Enzyme C
  • Drug B is a substrate of Enzyme C
  • Therefore: Drug A raises levels of Drug B (by slowing its metabolism)

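This two-hop inference (Drug A → enzyme ← Drug B) is hard for standard RAG. The two facts usually sit in different chunks, so unless some document states "Drug A raises Drug B levels" outright, no single retrieved chunk contains the conclusion. Graph RAG makes the shared enzyme an explicit node and can reason over entity relationships directly.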

Key advantages:

  • Multi-hop reasoning across entity relationships
  • Global summaries (what topics is this corpus about?)
  • Entity-centric queries (everything about Drug A)
  • Relationship traversal (what drugs interact with warfarin via CYP2C9?)

Entity and Relationship Extraction

The first step: extract entities and relationships from documents:

Python
import json
from collections.abc import Callable
from dataclasses import dataclass
from typing import Optional

from openai import OpenAI

# Shared OpenAI client, used by extract_knowledge_graph and extract_entities_from_query below.
client = OpenAI()

@dataclass
class Entity:
    """A node extracted from clinical text (drug, enzyme, condition, ...)."""

    id: str
    name: str
    entity_type: str    # "drug", "enzyme", "condition", "gene", etc.
    description: str
    aliases: list[str]

    @classmethod
    def from_dict(cls, data: dict) -> "Entity":
        """Build an Entity from one item of the extractor's ``"entities"`` list.

        The extraction JSON names the type field ``"type"`` (this class calls it
        ``entity_type``). A missing type becomes ``"Unknown"``, matching the
        Neo4j loader's default. Missing ``description``/``aliases`` become empty.
        """
        return cls(
            id=data["id"],
            name=data["name"],
            entity_type=data.get("type", data.get("entity_type", "Unknown")),
            description=data.get("description", ""),
            aliases=list(data.get("aliases") or []),
        )

@dataclass
class Relationship:
    """A directed, typed edge between two extracted entities."""

    source_id: str
    target_id: str
    relation_type: str  # "inhibits", "activates", "treats", "causes", etc.
    description: str
    # "strong", "moderate", "weak"; "unknown" when the extractor omits it
    # (same default the Neo4j loader uses).
    strength: str = "unknown"
    # Source text supporting this relationship; the extractor may not provide it.
    evidence: str = ""

    @classmethod
    def from_dict(cls, data: dict) -> "Relationship":
        """Build a Relationship from one item of the extractor's ``"relationships"`` list.

        Maps the JSON keys ``source``/``target``/``type`` onto
        ``source_id``/``target_id``/``relation_type``.
        """
        return cls(
            source_id=data["source"],
            target_id=data["target"],
            relation_type=data["type"],
            description=data.get("description", ""),
            strength=data.get("strength", "unknown"),
            evidence=data.get("evidence", ""),
        )

# System prompt for extract_knowledge_graph. Relationship "source"/"target"
# refer to entity ids that are only unique within one extraction call.
EXTRACTION_PROMPT = """Extract entities and relationships from this clinical pharmacology text.

Entities to identify:
- DRUG: medications, pharmaceutical agents
- ENZYME: metabolic enzymes (CYP450, UGT, etc.)
- CONDITION: diseases, disorders, symptoms
- PROTEIN: receptors, transporters
- GENE: genetic variants affecting drug metabolism

Relationships to identify:
- inhibits: Drug A inhibits Enzyme B
- induces: Drug A induces Enzyme B
- treats: Drug A treats Condition B
- causes: Drug A causes Condition B
- substrate_of: Drug A is metabolized by Enzyme B
- interacts_with: Drug A interacts with Drug B

Rules:
- Give every entity a unique id ("e1", "e2", ...) and use those ids in each relationship's "source" and "target".
- Use only the relationship types listed above.
- "strength" must be one of "strong", "moderate", or "weak".
- "evidence" is the sentence from the text that supports the relationship.

Return JSON:
{
  "entities": [{"id": "e1", "name": "...", "type": "DRUG", "description": "...", "aliases": []}],
  "relationships": [{"source": "e1", "target": "e2", "type": "inhibits", "description": "...", "strength": "strong", "evidence": "..."}]
}"""


def extract_knowledge_graph(text: str, max_chars: Optional[int] = 3000) -> dict:
    """Extract entities and relationships from ``text`` with the LLM.

    Args:
        text: Source document text.
        max_chars: Only the first ``max_chars`` characters are sent to the
            model (default 3000, the previous hard-coded limit). Pass ``None``
            to send the whole text. Chunk long documents before calling so
            nothing is silently dropped.

    Returns:
        The JSON object returned by the model. It is expected to contain
        ``"entities"`` and ``"relationships"`` lists in the shape requested by
        EXTRACTION_PROMPT.

    Raises:
        ValueError: if ``max_chars`` is not positive, if the model returns no
            content, or if the content is not valid JSON.
    """
    if max_chars is not None:
        if max_chars < 1:
            raise ValueError(f"max_chars must be positive or None, got {max_chars!r}")
        text = text[:max_chars]

    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": EXTRACTION_PROMPT},
            {"role": "user", "content": f"Text:\n{text}"},
        ],
        response_format={"type": "json_object"},
        temperature=0,
    )
    content = response.choices[0].message.content
    if not content:
        # e.g. a refusal; fail with a clear message instead of json.loads(None).
        raise ValueError("Model returned no content for knowledge-graph extraction")
    return json.loads(content)


def build_knowledge_graph(
    documents: list[dict],
    extractor: Optional[Callable[[str], dict]] = None,
) -> dict:
    """Extract entities/relationships from every document and merge them into one graph.

    Entities are deduplicated case-insensitively by (stripped) name. The first
    spelling seen becomes the canonical name, and aliases are merged in
    first-seen order.

    Relationship endpoints are rewritten from the extractor's per-document
    ids (e.g. ``"e1"``) to canonical entity names. Those ids restart in every
    extraction and would otherwise collide across documents, and
    ``ClinicalKnowledgeGraph.create_relationship`` matches nodes by name.

    Args:
        documents: Dicts with at least ``"id"`` and ``"content"`` keys.
        extractor: Callable mapping document text to a dict with
            ``"entities"`` and ``"relationships"`` lists. Defaults to
            ``extract_knowledge_graph``.

    Returns:
        ``{"entities": [...], "relationships": [...]}``. Each relationship also
        carries the ``"source_document"`` id it was extracted from. The
        extractor's own output is not mutated.
    """
    if extractor is None:
        extractor = extract_knowledge_graph

    all_entities: dict[str, dict] = {}  # casefolded name -> merged entity dict
    all_relationships: list[dict] = []

    for doc in documents:
        extracted = extractor(doc["content"])
        # Extractor ids are only meaningful inside this one extraction.
        local_names: dict = {}  # this document's entity id -> canonical name

        for entity in extracted.get("entities") or []:
            name = (entity.get("name") or "").strip()
            if not name:
                continue  # a nameless entity can be neither merged nor matched
            key = name.casefold()
            merged = all_entities.get(key)
            if merged is None:
                # Store a copy so later alias merges don't mutate extractor output.
                merged = {**entity, "name": name}
                all_entities[key] = merged
            else:
                # dict.fromkeys deduplicates while keeping first-seen order
                # (set() would make alias order arbitrary).
                merged["aliases"] = list(dict.fromkeys(
                    (merged.get("aliases") or []) + (entity.get("aliases") or [])
                ))
            if "id" in entity:
                local_names[entity["id"]] = merged["name"]

        for rel in extracted.get("relationships") or []:
            source, target = rel.get("source"), rel.get("target")
            all_relationships.append({
                **rel,
                # Resolve local ids to canonical names. Keep the raw value if the
                # model referenced an entity by name or by an unknown id.
                "source": local_names.get(source, source),
                "target": local_names.get(target, target),
                "source_document": doc["id"],
            })

    return {
        "entities": list(all_entities.values()),
        "relationships": all_relationships,
    }

Storing in Neo4j

Python
from neo4j import GraphDatabase

class ClinicalKnowledgeGraph:
    """Neo4j-backed knowledge graph for clinical pharmacology.

    Entities are ``:Entity`` nodes keyed by ``name``; relationships are typed
    edges between them. ``create_entity`` calls ``apoc.create.addLabels``, so
    the APOC plugin must be installed.

    The class works as a context manager (``with ClinicalKnowledgeGraph(...) as kg:``),
    so the driver is closed even on error.
    """

    def __init__(self, uri: str, username: str, password: str):
        self.driver = GraphDatabase.driver(uri, auth=(username, password))

    def close(self):
        """Close the underlying Neo4j driver."""
        self.driver.close()

    def __enter__(self) -> "ClinicalKnowledgeGraph":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    @staticmethod
    def _relationship_type(raw: str) -> str:
        """Normalise a free-text relation name into a safe Cypher relationship type.

        Runs of characters other than ASCII letters/digits collapse to ``_``
        and the result is upper-cased. For example, ``"substrate_of"`` becomes
        ``SUBSTRATE_OF`` and ``"is metabolized by"`` becomes ``IS_METABOLIZED_BY``.
        Model-generated text therefore cannot inject Cypher.

        Raises:
            ValueError: if nothing usable remains or the result starts with a digit.
        """
        import re

        rel_type = re.sub(r"[^A-Za-z0-9]+", "_", raw).strip("_").upper()
        if not rel_type or rel_type[0].isdigit():
            raise ValueError(f"Invalid relationship type: {raw!r}")
        return rel_type

    @staticmethod
    def _positive_int(value, name: str) -> int:
        """Return ``value`` if it is a positive ``int`` (not bool); raise ValueError otherwise."""
        if isinstance(value, bool) or not isinstance(value, int) or value < 1:
            raise ValueError(f"{name} must be a positive integer, got {value!r}")
        return value

    def create_entity(self, entity: dict) -> None:
        """Create the entity node, or update the one that already has this name.

        The entity's ``type`` (default ``"Unknown"``) is stored as a property
        and also added as an extra node label via APOC.
        """
        with self.driver.session() as session:
            session.run(
                """
                MERGE (e:Entity {name: $name})
                SET e.type = $type,
                    e.description = $description,
                    e.aliases = $aliases
                WITH e
                CALL apoc.create.addLabels(e, [$type]) YIELD node
                RETURN node
                """,
                name=entity["name"],
                type=entity.get("type", "Unknown"),
                description=entity.get("description", ""),
                aliases=entity.get("aliases", []),
            )

    def create_relationship(self, rel: dict) -> None:
        """Create (or update) a typed edge between two existing entity nodes.

        ``rel["source"]`` and ``rel["target"]`` must be entity *names*, because
        nodes are matched by name. If either node is missing, nothing is written.

        Raises:
            ValueError: if ``rel["type"]`` cannot be turned into a valid type.
        """
        # Cypher cannot take a relationship type as a query parameter, so the
        # type is sanitised and then interpolated into the query text.
        rel_type = self._relationship_type(rel["type"])
        with self.driver.session() as session:
            session.run(
                f"""
                MATCH (a:Entity {{name: $source}})
                MATCH (b:Entity {{name: $target}})
                MERGE (a)-[r:{rel_type}]->(b)
                SET r.description = $description,
                    r.strength = $strength
                """,
                source=rel["source"],
                target=rel["target"],
                description=rel.get("description", ""),
                strength=rel.get("strength", "unknown"),
            )

    def query_entity_relationships(
        self, entity_name: str, depth: int = 2, limit: int = 50
    ) -> list:
        """Return up to ``limit`` paths of 1..``depth`` hops starting at the named entity.

        Returns:
            The ``path`` value of each Neo4j record (driver ``Path`` objects).

        Raises:
            ValueError: if ``depth`` or ``limit`` is not a positive integer.
        """
        # Variable-length bounds cannot be parameterised, so ``depth`` must be
        # validated before it is interpolated into the query.
        depth = self._positive_int(depth, "depth")
        limit = self._positive_int(limit, "limit")
        with self.driver.session() as session:
            result = session.run(
                f"""
                MATCH path = (start:Entity {{name: $name}})-[*1..{depth}]-(related)
                RETURN path
                LIMIT $limit
                """,
                name=entity_name,
                limit=limit,
            )
            return [record["path"] for record in result]

    def find_interaction_path(
        self, drug_a: str, drug_b: str, max_hops: int = 5
    ) -> list[dict]:
        """Find the shortest path (at most ``max_hops`` edges) between two entities.

        Returns:
            Up to 3 items of the form ``{"path": ..., "hops": int}``. Returns
            ``[]`` when no path exists or when both names are the same.

        Raises:
            ValueError: if ``max_hops`` is not a positive integer.
        """
        max_hops = self._positive_int(max_hops, "max_hops")
        if drug_a == drug_b:
            # shortestPath raises an error when start and end are the same node.
            return []
        with self.driver.session() as session:
            result = session.run(
                f"""
                MATCH path = shortestPath(
                    (a:Entity {{name: $drug_a}})-[*..{max_hops}]-(b:Entity {{name: $drug_b}})
                )
                RETURN path, length(path) AS hops
                ORDER BY hops
                LIMIT 3
                """,
                drug_a=drug_a,
                drug_b=drug_b,
            )
            return [{"path": r["path"], "hops": r["hops"]} for r in result]

    def get_entity_community(self, entity_name: str, limit: int = 20) -> list[str]:
        """Return names of the entity's direct neighbours, most-connected first.

        Only INTERACTS_WITH, INHIBITS and SUBSTRATE_OF edges are followed, in
        either direction. Neighbours are ranked by how many such edges link
        them to the entity. This is a one-hop neighbourhood, not the output of
        a community-detection algorithm.

        Raises:
            ValueError: if ``limit`` is not a positive integer.
        """
        limit = self._positive_int(limit, "limit")
        with self.driver.session() as session:
            result = session.run(
                """
                MATCH (e:Entity {name: $name})-[:INTERACTS_WITH|INHIBITS|SUBSTRATE_OF]-(related)
                RETURN related.name AS name, count(*) AS connections
                ORDER BY connections DESC
                LIMIT $limit
                """,
                name=entity_name,
                limit=limit,
            )
            return [r["name"] for r in result]

Microsoft GraphRAG

Microsoft's GraphRAG builds hierarchical community summaries for global question-answering:

Python
# Using Microsoft GraphRAG library
# pip install graphrag
#
# NOTE(review): check the import paths and call signatures below against the
# installed graphrag version before relying on them. Recent releases expose
# graphrag.api.local_search / graphrag.api.global_search. As far as the
# reviewer knows, these take a loaded GraphRagConfig plus the index's output
# DataFrames (no config_path argument) and return a (response, context) tuple
# rather than an object with a .response attribute. Confirm before use.

# Setup
# NOTE(review): run_pipeline_with_config is imported but never used in this snippet.
from graphrag.index import run_pipeline_with_config
from graphrag.query.api import global_search, local_search

# GraphRAG builds two indexes:
# 1. Local index: entity-level information + vector embeddings
# 2. Global index: community summaries at multiple levels of abstraction

# Local search: best for specific entity questions
# "What is the mechanism of warfarin?"
async def local_rag_query(query: str, data_dir: str) -> str:
    """GraphRAG local search — best for specific entity questions.

    Args:
        query: Question about specific entities.
        data_dir: GraphRAG project directory. The config path passed is
            ``{data_dir}/settings.yml`` (NOTE(review): ``graphrag init`` may
            generate ``settings.yaml`` instead; confirm the file name).

    Returns:
        The ``.response`` attribute of the search result, queried at
        community level 2.
    """
    result = await local_search(
        config_path=f"{data_dir}/settings.yml",
        query=query,
        community_level=2,
    )
    return result.response


# Global search: best for holistic questions across the whole corpus
# "What are the main themes in this clinical pharmacology database?"
async def global_rag_query(query: str, data_dir: str) -> str:
    """GraphRAG global search — best for high-level questions.

    Args:
        query: Corpus-wide question (themes, summaries).
        data_dir: GraphRAG project directory. The config path passed is
            ``{data_dir}/settings.yml`` (same file-name caveat as
            ``local_rag_query``).

    Returns:
        The ``.response`` attribute of the search result, queried at
        community level 2.
    """
    result = await global_search(
        config_path=f"{data_dir}/settings.yml",
        query=query,
        community_level=2,
    )
    return result.response

Hybrid: Vector + Graph Retrieval

Combine standard vector search with graph traversal:

Python
class HybridGraphRetriever:
    """Combines vector retrieval with graph-based entity expansion.

    ``vector_retriever`` must provide ``retrieve(embedding, top_k=...)``
    returning a list of context items.

    NOTE(review): ``retrieve`` also calls ``embed_text`` and
    ``format_graph_path``. They are not defined in this snippet and must be
    supplied alongside it.
    """

    def __init__(self, vector_retriever, knowledge_graph: ClinicalKnowledgeGraph):
        self.vector_retriever = vector_retriever
        self.graph = knowledge_graph

    def retrieve(
        self,
        query: str,
        query_embedding: list[float],
        top_k: int = 5,
        max_related: int = 5,
        docs_per_related: int = 2,
    ) -> dict:
        """Retrieve context for ``query`` from both the vector index and the graph.

        Args:
            query: User question; entity names are extracted from it.
            query_embedding: Embedding of ``query`` for the direct vector search.
            top_k: Number of direct vector hits.
            max_related: Graph neighbours considered per query entity.
            docs_per_related: Vector hits fetched for each neighbour.

        Returns:
            ``{"vector_results": [...], "graph_context": [...]}``.

            Each neighbour is expanded at most once, even when it neighbours
            several query entities. If the query names two or more distinct
            entities, the shortest graph path between the first two is appended
            as a ``"knowledge_graph"`` item.
        """
        # 1. Standard vector retrieval
        vector_results = self.vector_retriever.retrieve(query_embedding, top_k=top_k)

        # 2. Entities mentioned in the query. Deduplicate (keeping order) so a
        #    repeated entity is not expanded twice and step 4 never receives
        #    identical endpoints.
        entities_in_query = list(dict.fromkeys(extract_entities_from_query(query)))

        # 3. Graph expansion: fetch text about each related entity, once.
        graph_context = []
        expanded = set()
        for entity in entities_in_query:
            for related_entity in self.graph.get_entity_community(entity)[:max_related]:
                if related_entity in expanded:
                    continue  # already fetched for an earlier query entity
                expanded.add(related_entity)
                related_emb = embed_text(related_entity)
                graph_context.extend(
                    self.vector_retriever.retrieve(related_emb, top_k=docs_per_related)
                )

        # 4. If the query involves two entities, add the path connecting them.
        if len(entities_in_query) >= 2:
            path = self.graph.find_interaction_path(entities_in_query[0], entities_in_query[1])
            if path:
                graph_context.append({
                    "content": format_graph_path(path[0]),
                    "source": "knowledge_graph",
                    "score": 0.9,
                })

        return {
            "vector_results": vector_results,
            "graph_context": graph_context,
        }


def extract_entities_from_query(query: str) -> list[str]:
    """Ask a small LLM for the drug / medical entity names mentioned in ``query``.

    Returns:
        Entity names as stripped, non-empty strings, deduplicated in the order
        the model listed them. Non-string items are dropped. Returns ``[]`` if
        the model returns no content or its JSON has no ``"entities"`` list.

    Raises:
        json.JSONDecodeError: if the model returns content that is not valid JSON.
    """
    import json

    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "user",
                "content": f"""Extract drug names and medical entity names from this query.
Return JSON: {{"entities": ["entity1", "entity2"]}}

Query: {query}""",
            }
        ],
        response_format={"type": "json_object"},
        temperature=0,
    )
    content = response.choices[0].message.content
    if not content:
        return []
    result = json.loads(content)
    entities = result.get("entities") if isinstance(result, dict) else None
    if not isinstance(entities, list):
        return []
    # The model's output is untrusted: keep only non-blank strings, once each.
    names = (item.strip() for item in entities if isinstance(item, str))
    return list(dict.fromkeys(name for name in names if name))

Enjoyed this article?

Explore the AI Systems learning path for more.

Found this helpful?

Share:𝕏

Leave a comment

Have a question, correction, or just found this helpful? Leave a note below.