Graph RAG: Knowledge Graph-Enhanced Retrieval
Enhance retrieval-augmented generation (RAG) with knowledge graphs: Microsoft's GraphRAG, entity extraction, relationship indexing, and combining vector search with graph traversal.
Why Graph RAG
Standard vector RAG treats each document chunk independently. Knowledge graphs capture relationships between entities:
- Drug A inhibits Enzyme B
- Drug B is a substrate of Enzyme B (Enzyme B metabolizes it)
- Therefore: Drug A raises levels of Drug B, because inhibiting Enzyme B slows Drug B's clearance
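This two-hop inference (Drug A → Enzyme B ← Drug B) is out of reach for standard RAG unless some chunk happens to state the conclusion outright, e.g. "Drug A raises Drug B levels"; the two supporting facts usually live in different chunks that are retrieved independently, if at all. Graph RAG can instead traverse the entity relationships and connect them.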
Key advantages:
- Multi-hop reasoning across entity relationships
- Global summaries (what topics is this corpus about?)
- Entity-centric queries (everything about Drug A)
- Relationship traversal (what drugs interact with warfarin via CYP2C9?)
Entity and Relationship Extraction
The first step: extract entities and relationships from documents:
from openai import OpenAI
import json
from dataclasses import dataclass
# Module-level OpenAI client shared by the extraction and query-NER helpers.
client: OpenAI = OpenAI()
@dataclass(init=True, repr=True, eq=True)
class Entity:
    """A knowledge-graph node: one named thing mentioned in the corpus."""

    # Identifier for the entity.
    id: str
    # Display name of the entity.
    name: str
    # Category label, e.g. "drug", "enzyme", "condition", "gene".
    entity_type: str
    # Free-text summary of the entity.
    description: str
    # Alternative names that refer to the same entity.
    aliases: list[str]
@dataclass(init=True, repr=True, eq=True)
class Relationship:
    """A typed edge from one entity to another, with its supporting evidence."""

    # Id of the entity the relationship starts from.
    source_id: str
    # Id of the entity the relationship points to.
    target_id: str
    # Relation label, e.g. "inhibits", "activates", "treats", "causes".
    relation_type: str
    # Free-text explanation of the relationship.
    description: str
    # Qualitative strength label: "strong", "moderate" or "weak".
    strength: str
    # Source text supporting this relationship.
    evidence: str
# System prompt used by extract_knowledge_graph: the entity vocabulary, the
# relationship vocabulary, and the JSON shape the model must return.
# Assembled line by line; the resulting text has no trailing newline.
EXTRACTION_PROMPT = "\n".join(
    (
        "Extract entities and relationships from this clinical pharmacology text.",
        # Entity vocabulary (becomes each entity's "type" field).
        "Entities to identify:",
        "- DRUG: medications, pharmaceutical agents",
        "- ENZYME: metabolic enzymes (CYP450, UGT, etc.)",
        "- CONDITION: diseases, disorders, symptoms",
        "- PROTEIN: receptors, transporters",
        "- GENE: genetic variants affecting drug metabolism",
        # Relationship vocabulary (becomes each relationship's "type" field).
        "Relationships to identify:",
        "- inhibits: Drug A inhibits Enzyme B",
        "- induces: Drug A induces Enzyme B",
        "- treats: Drug A treats Condition B",
        "- causes: Drug A causes Condition B",
        "- substrate_of: Drug A is metabolized by Enzyme B",
        "- interacts_with: Drug A interacts with Drug B",
        # Output schema example.
        "Return JSON:",
        "{",
        '"entities": [{"id": "e1", "name": "...", "type": "DRUG", "description": "...", "aliases": []}],',
        '"relationships": [{"source": "e1", "target": "e2", "type": "inhibits", "description": "...", "strength": "strong"}]',
        "}",
    )
)
def extract_knowledge_graph(text: str, *, max_chars: int = 3000, model: str = "gpt-4o") -> dict:
    """Ask the LLM to extract entities and relationships from ``text``.

    Only the first ``max_chars`` characters are sent to the model (3000 by
    default). Anything beyond that is ignored, so raise the limit or chunk
    long documents upstream if the whole text matters.

    Args:
        text: Source document text.
        max_chars: Maximum number of leading characters of ``text`` to send.
        model: Chat model used for extraction.

    Returns:
        The parsed JSON object. Per EXTRACTION_PROMPT it should contain
        "entities" and "relationships" lists. The model's output is not
        validated beyond being a JSON object.

    Raises:
        ValueError: if the model returns no content, or returns a JSON value
            that is not an object.
    """
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": EXTRACTION_PROMPT},
            {"role": "user", "content": f"Text:\n{text[:max_chars]}"},
        ],
        response_format={"type": "json_object"},
        temperature=0,
    )
    content = response.choices[0].message.content
    if not content:
        # json.loads(None) would otherwise fail with an opaque TypeError.
        raise ValueError("extraction model returned no content")
    extracted = json.loads(content)
    if not isinstance(extracted, dict):
        raise ValueError(f"expected a JSON object, got {type(extracted).__name__}")
    return extracted
def build_knowledge_graph(documents: list[dict], *, extractor=None) -> dict:
    """Extract entities and relationships from every document and merge them.

    Entities are deduplicated case-insensitively by name, ignoring surrounding
    whitespace. The first occurrence is kept, and later occurrences only
    contribute new aliases (order-preserving union). Entities without a usable
    name are skipped.

    The extractor's entity ids ("e1", "e2", ...) are local to a single
    document. Each relationship's "source"/"target" id is therefore rewritten
    to the canonical (first-seen) name of the entity it referred to, which is
    the key ``ClinicalKnowledgeGraph.create_relationship`` matches nodes on.
    Endpoints that do not match an id extracted from the same document are
    left unchanged.

    Args:
        documents: Dicts with at least "id" and "content" keys.
        extractor: Callable mapping document text to a dict with "entities"
            and "relationships" lists. Defaults to ``extract_knowledge_graph``.

    Returns:
        {"entities": [...], "relationships": [...]}. Each relationship also
        carries the "source_document" id it was extracted from.
    """
    if extractor is None:
        extractor = extract_knowledge_graph

    all_entities: dict[str, dict] = {}  # normalized name -> first entity seen
    all_relationships: list[dict] = []

    for doc in documents:
        extracted = extractor(doc["content"])
        local_names: dict[str, str] = {}  # this document's entity id -> canonical name

        for entity in extracted.get("entities") or []:
            name = entity.get("name")
            if not isinstance(name, str) or not name.strip():
                continue  # nothing to key on; skip rather than crash
            key = name.strip().lower()
            aliases = entity.get("aliases") or []
            if isinstance(aliases, str):
                aliases = [aliases]
            if key not in all_entities:
                entity["aliases"] = list(dict.fromkeys(aliases))
                all_entities[key] = entity
            else:
                existing = all_entities[key]
                # dict.fromkeys keeps first-seen order, unlike set(), so the
                # merged alias list is deterministic.
                existing["aliases"] = list(dict.fromkeys([*existing["aliases"], *aliases]))
            entity_id = entity.get("id")
            if isinstance(entity_id, str):
                local_names[entity_id] = all_entities[key]["name"]

        for rel in extracted.get("relationships") or []:
            resolved = dict(rel)
            for end in ("source", "target"):
                ref = resolved.get(end)
                if isinstance(ref, str) and ref in local_names:
                    resolved[end] = local_names[ref]
            resolved["source_document"] = doc["id"]
            all_relationships.append(resolved)

    return {
        "entities": list(all_entities.values()),
        "relationships": all_relationships,
    }
# Storing in Neo4j
from neo4j import GraphDatabase
class ClinicalKnowledgeGraph:
    """Neo4j-backed knowledge graph for clinical pharmacology.

    Every node carries the ``Entity`` label and is keyed by its exact
    ``name``. ``create_entity`` also attaches the entity's type as an extra
    label via ``apoc.create.addLabels``, so the APOC plugin must be available
    on the server. Instances can be used as context managers to close the
    driver on exit.
    """

    def __init__(self, uri: str, username: str, password: str):
        self.driver = GraphDatabase.driver(uri, auth=(username, password))

    def close(self):
        """Close the underlying Neo4j driver."""
        self.driver.close()

    def __enter__(self) -> "ClinicalKnowledgeGraph":
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        self.close()
        return False  # never suppress exceptions

    @staticmethod
    def _relationship_type(raw: str) -> str:
        """Normalize an extracted relation name into a safe relationship type.

        The relationship type is spliced into the query text rather than
        passed as a $parameter. Every run of non-word characters is therefore
        collapsed to "_" and the result upper-cased (e.g. "substrate of" ->
        "SUBSTRATE_OF"). Only word characters remain, which blocks Cypher
        injection from LLM output and avoids syntax errors.

        Raises:
            ValueError: if nothing usable remains after normalization.
        """
        import re

        rel_type = re.sub(r"\W+", "_", str(raw).strip()).strip("_").upper()
        if not rel_type:
            raise ValueError(f"invalid relationship type: {raw!r}")
        return rel_type

    def create_entity(self, entity: dict) -> None:
        """Create or update an ``Entity`` node keyed by ``entity["name"]``.

        Args:
            entity: Extracted entity dict with a required "name" and optional
                "type", "description" and "aliases". A missing or null type
                falls back to "Unknown", and missing or null aliases become [].
        """
        with self.driver.session() as session:
            session.run(
                """
                MERGE (e:Entity {name: $name})
                SET e.type = $type,
                    e.description = $description,
                    e.aliases = $aliases
                WITH e
                CALL apoc.create.addLabels(e, [$type]) YIELD node
                RETURN node
                """,
                name=entity["name"],
                type=entity.get("type") or "Unknown",
                description=entity.get("description") or "",
                aliases=list(entity.get("aliases") or []),
            )

    def create_relationship(self, rel: dict) -> None:
        """Create (or reuse) a typed relationship between two existing entities.

        Both endpoints are matched by exact ``Entity.name``, so
        ``rel["source"]`` and ``rel["target"]`` must be entity names. If
        either node is missing, the MATCH finds nothing and nothing is
        written.

        Args:
            rel: Dict with "source", "target" and "type", plus optional
                "description" and "strength" (default "unknown").

        Raises:
            ValueError: if ``rel["type"]`` normalizes to an empty string.
        """
        rel_type = self._relationship_type(rel["type"])
        with self.driver.session() as session:
            session.run(
                f"""
                MATCH (a:Entity {{name: $source}})
                MATCH (b:Entity {{name: $target}})
                MERGE (a)-[r:`{rel_type}`]->(b)
                SET r.description = $description,
                    r.strength = $strength
                """,
                source=rel["source"],
                target=rel["target"],
                description=rel.get("description", ""),
                strength=rel.get("strength", "unknown"),
            )

    def query_entity_relationships(self, entity_name: str, depth: int = 2, limit: int = 50) -> list:
        """Return up to ``limit`` paths of 1..``depth`` hops starting at an entity.

        Relationship direction is ignored. The items returned are the driver's
        path objects (``record["path"]``), one per matching path.

        Raises:
            ValueError: if ``depth`` is not a positive int. It is spliced into
                the query text, so it is validated rather than trusted.
        """
        if isinstance(depth, bool) or not isinstance(depth, int) or depth < 1:
            raise ValueError(f"depth must be a positive integer, got {depth!r}")
        with self.driver.session() as session:
            result = session.run(
                f"""
                MATCH path = (start:Entity {{name: $name}})-[*1..{depth}]-(related)
                RETURN path
                LIMIT $limit
                """,
                name=entity_name,
                limit=limit,
            )
            return [record["path"] for record in result]

    def find_interaction_path(self, drug_a: str, drug_b: str) -> list[dict]:
        """Find the shortest path (up to 5 hops, any direction) between two entities.

        Returns:
            A list of {"path", "hops"} dicts. ``shortestPath`` yields a single
            path per start/end pair, so with unique names the list holds at
            most one item. When both names are equal, [] is returned without
            querying, because Neo4j rejects shortestPath from a node to itself
            by default.
        """
        if drug_a == drug_b:
            return []
        with self.driver.session() as session:
            result = session.run(
                """
                MATCH path = shortestPath(
                    (a:Entity {name: $drug_a})-[*..5]-(b:Entity {name: $drug_b})
                )
                RETURN path, length(path) AS hops
                ORDER BY hops
                LIMIT 3
                """,
                drug_a=drug_a,
                drug_b=drug_b,
            )
            return [{"path": r["path"], "hops": r["hops"]} for r in result]

    def get_entity_community(self, entity_name: str, limit: int = 20) -> list[str]:
        """Return the entity's direct neighbors, most-connected first.

        Only INTERACTS_WITH, INHIBITS and SUBSTRATE_OF relationships are
        followed (one hop, either direction). Neighbors are ranked by how many
        such relationships link them to the entity. This is a neighborhood
        lookup, not a community-detection algorithm.
        """
        with self.driver.session() as session:
            result = session.run(
                """
                MATCH (e:Entity {name: $name})-[:INTERACTS_WITH|INHIBITS|SUBSTRATE_OF]-(related)
                RETURN related.name AS name, count(*) AS connections
                ORDER BY connections DESC
                LIMIT $limit
                """,
                name=entity_name,
                limit=limit,
            )
            return [r["name"] for r in result]
# Microsoft GraphRAG
Microsoft's GraphRAG builds hierarchical community summaries for global question-answering:
# Using Microsoft GraphRAG library
# pip install graphrag
# Setup
from graphrag.index import run_pipeline_with_config
from graphrag.query.api import global_search, local_search
# GraphRAG builds two indexes:
# 1. Local index: entity-level information + vector embeddings
# 2. Global index: community summaries at multiple levels of abstraction
# Local search: best for specific entity questions
# "What is the mechanism of warfarin?"
async def local_rag_query(query: str, data_dir: str, community_level: int = 2) -> str:
    """Answer an entity-specific question with GraphRAG local search.

    Args:
        query: The user question, e.g. "What is the mechanism of warfarin?".
        data_dir: GraphRAG project directory containing ``settings.yml``.
        community_level: Community-hierarchy level passed to the search
            (default 2).

    Returns:
        The ``response`` attribute of the search result.
    """
    # NOTE(review): confirm this call shape against the installed graphrag
    # version. Published releases may expose query functions that take a
    # loaded config plus the indexed DataFrames (not ``config_path``) and
    # return a (response, context) tuple rather than an object with
    # ``.response``.
    result = await local_search(
        config_path=f"{data_dir}/settings.yml",
        query=query,
        community_level=community_level,
    )
    return result.response
# Global search: best for holistic questions across the whole corpus
# "What are the main themes in this clinical pharmacology database?"
async def global_rag_query(query: str, data_dir: str, community_level: int = 2) -> str:
    """Answer a corpus-wide, high-level question with GraphRAG global search.

    Args:
        query: The user question, e.g. "What are the main themes in this
            clinical pharmacology database?".
        data_dir: GraphRAG project directory containing ``settings.yml``.
        community_level: Community-hierarchy level whose summaries are used
            (default 2).

    Returns:
        The ``response`` attribute of the search result.
    """
    # NOTE(review): confirm this call shape against the installed graphrag
    # version. Published releases may expose query functions that take a
    # loaded config plus the indexed DataFrames (not ``config_path``) and
    # return a (response, context) tuple rather than an object with
    # ``.response``.
    result = await global_search(
        config_path=f"{data_dir}/settings.yml",
        query=query,
        community_level=community_level,
    )
    return result.response
# Hybrid: Vector + Graph Retrieval
Combine standard vector search with graph traversal:
class HybridGraphRetriever:
    """Combine vector retrieval with knowledge-graph entity expansion.

    ``retrieve`` returns the plain vector hits for the query plus "graph
    context". The graph context holds documents about graph neighbors of the
    entities named in the query. When the query names two or more distinct
    entities, it also holds a rendered shortest path between the first two.
    """

    # Fixed heuristic score given to the synthetic knowledge-graph path entry
    # so it can sit alongside vector hits; it is not a similarity score.
    GRAPH_PATH_SCORE = 0.9

    def __init__(self, vector_retriever, knowledge_graph: "ClinicalKnowledgeGraph"):
        # NOTE(review): vector_retriever is assumed to expose
        # retrieve(embedding, top_k=...) returning a list of hits — confirm.
        self.vector_retriever = vector_retriever
        self.graph = knowledge_graph

    @staticmethod
    def _unique_names(names) -> list[str]:
        """Strip names and drop blanks, non-strings and case-insensitive duplicates (first wins)."""
        seen = set()
        unique = []
        for raw in names:
            if not isinstance(raw, str):
                continue
            name = raw.strip()
            key = name.casefold()
            if not name or key in seen:
                continue
            seen.add(key)
            unique.append(name)
        return unique

    def retrieve(
        self,
        query: str,
        query_embedding: list[float],
        top_k: int = 5,
        *,
        max_related_per_entity: int = 5,
        docs_per_related: int = 2,
    ) -> dict:
        """Retrieve using both vector search and graph expansion.

        Args:
            query: Raw user query; entity names are extracted from it.
            query_embedding: Embedding of ``query`` used for the vector search.
            top_k: Number of vector hits for the query itself.
            max_related_per_entity: Graph neighbors of each query entity to
                expand (most-connected first).
            docs_per_related: Vector hits fetched per expanded neighbor.

        Returns:
            {"vector_results": <hits for the query>, "graph_context": [...]}.
        """
        # 1. Standard vector retrieval on the query itself.
        vector_results = self.vector_retriever.retrieve(query_embedding, top_k=top_k)

        # 2. Entities named in the query. Duplicates are removed so step 4
        #    never asks for a path from an entity to itself.
        entities_in_query = self._unique_names(extract_entities_from_query(query))

        # 3. Graph expansion. A neighbor shared by several query entities is
        #    embedded and retrieved only once, avoiding duplicate context.
        graph_context = []
        expanded = set()
        for entity in entities_in_query:
            neighbors = self.graph.get_entity_community(entity)[:max_related_per_entity]
            for related_entity in neighbors:
                if related_entity in expanded:
                    continue
                expanded.add(related_entity)
                # NOTE(review): embed_text is not defined in this section;
                # presumably the same embedding model as query_embedding — confirm.
                related_emb = embed_text(related_entity)
                graph_context.extend(
                    self.vector_retriever.retrieve(related_emb, top_k=docs_per_related)
                )

        # 4. If the query names two or more distinct entities, add the shortest
        #    graph path between the first two as an extra context item.
        if len(entities_in_query) >= 2:
            paths = self.graph.find_interaction_path(entities_in_query[0], entities_in_query[1])
            if paths:
                graph_context.append({
                    # NOTE(review): format_graph_path is not defined in this section.
                    "content": format_graph_path(paths[0]),
                    "source": "knowledge_graph",
                    "score": self.GRAPH_PATH_SCORE,
                })

        return {
            "vector_results": vector_results,
            "graph_context": graph_context,
        }
def extract_entities_from_query(query: str, *, model: str = "gpt-4o-mini") -> list[str]:
    """Extract drug and other medical entity names from a user query via an LLM.

    Args:
        query: Free-text user question.
        model: Chat model used for the extraction.

    Returns:
        Entity names in the order the model listed them, stripped of
        surrounding whitespace. Blank and non-string items and
        case-insensitive duplicates are dropped. Returns [] when the model
        gives no content, or JSON without an "entities" list.
    """
    prompt = (
        "Extract drug names and medical entity names from this query.\n"
        'Return JSON: {"entities": ["entity1", "entity2"]}\n'
        f"Query: {query}"
    )
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        response_format={"type": "json_object"},
        temperature=0,
    )
    content = response.choices[0].message.content
    if not content:
        return []  # no entities; callers still have plain vector retrieval
    result = json.loads(content)
    raw_entities = result.get("entities", []) if isinstance(result, dict) else []
    if not isinstance(raw_entities, list):
        return []

    seen = set()
    entities = []
    for item in raw_entities:
        if not isinstance(item, str):
            continue
        name = item.strip()
        key = name.casefold()
        if not name or key in seen:
            continue
        seen.add(key)
        entities.append(name)
    return entities
# Found this helpful?
Leave a comment
Have a question, correction, or just found this helpful? Leave a note below.