Skip to content
114 changes: 107 additions & 7 deletions graphrag_sdk/src/graphrag_sdk/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
ChatMessage,
DeleteDocumentResult,
DocumentInfo,
EntityType,
FinalizeResult,
GraphSchema,
IngestionResult,
Expand All @@ -35,6 +36,9 @@
SentenceTokenCapChunking,
)
from graphrag_sdk.ingestion.extraction_strategies.base import ExtractionStrategy
from graphrag_sdk.ingestion.extraction_strategies.entity_extractors import (
DEFAULT_ENTITY_TYPES,
)
from graphrag_sdk.ingestion.extraction_strategies.graph_extraction import GraphExtraction
from graphrag_sdk.ingestion.loaders.base import LoaderStrategy
from graphrag_sdk.ingestion.loaders.markdown_loader import MarkdownLoader
Expand All @@ -48,6 +52,7 @@
from graphrag_sdk.retrieval.strategies.multi_path import MultiPathRetrieval
from graphrag_sdk.storage.deduplicator import EntityDeduplicator
from graphrag_sdk.storage.graph_store import GraphStore
from graphrag_sdk.storage.ontology_store import OntologyStore
from graphrag_sdk.storage.vector_store import VectorStore

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -194,12 +199,23 @@ def __init__(
# Deduplication engine
self._deduplicator = EntityDeduplicator(self._graph_store, self.embedder)

# Persistent ontology graph (``<data_graph>__ontology``). Always-on,
# always the anchor: ``self.schema`` is registered into it on first
# connection, and ``get_ontology()`` / retrieval always read from it.
self._ontology_store = OntologyStore(self._conn, self._conn.config.graph_name)
# Lazy-init flag; the first async call that needs the schema fires
# ``_ensure_ontology_initialized()`` to load + register the user's schema.
self._ontology_initialized = False
# Working schema used by retrieval; populated by ``_ensure_ontology_initialized()``.
self._global_schema: GraphSchema = self.schema

# Default retrieval strategy
self._retrieval_strategy = retrieval_strategy or MultiPathRetrieval(
graph_store=self._graph_store,
vector_store=self._vector_store,
embedder=self.embedder,
llm=self.llm,
schema=self._global_schema,
)

# -- Async context manager -------------------------------------------
Expand Down Expand Up @@ -233,6 +249,76 @@ async def close(self) -> None:
"""Close the underlying database connection."""
await self._conn.close()

# ── Ontology ─────────────────────────────────────────────────

async def _ensure_ontology_initialized(self) -> None:
    """Load the persisted ontology once and settle on the working schema.

    Runs its work on first use only; once ``_ontology_initialized`` is set,
    later calls return immediately. The working schema is picked as follows:

    1. ``self.schema`` declares entities or relations → it is registered
       with the ontology store and the store's return value is used.
    2. Otherwise, if the persisted ontology already holds entities or
       relations (previous session, another writer) → it is used as-is.
    3. Otherwise (first connection, no user schema) → a schema built from
       :py:data:`DEFAULT_ENTITY_TYPES` is registered so the ontology graph
       reflects the labels the extractor will produce.

    The result is stored on ``self._global_schema`` and, when the retrieval
    strategy exposes a ``_schema`` attribute, mirrored onto it.

    Raises :py:class:`OntologyContradictionError` if ``self.schema``
    re-defines an existing property's type.
    """
    if self._ontology_initialized:
        return

    # The persisted ontology is read up front, even when the user-supplied
    # schema ends up taking precedence below.
    persisted = await self._ontology_store.load()
    user_schema = self.schema

    if user_schema.entities or user_schema.relations:
        resolved = await self._ontology_store.register(user_schema)
    elif persisted.entities or persisted.relations:
        resolved = persisted
    else:
        fallback = GraphSchema(
            entities=[EntityType(label=name) for name in DEFAULT_ENTITY_TYPES],
        )
        resolved = await self._ontology_store.register(fallback)

    self._global_schema = resolved
    # Custom retrieval strategies may not carry a schema; only update those
    # that do.
    strategy = self._retrieval_strategy
    if hasattr(strategy, "_schema"):
        strategy._schema = resolved
    self._ontology_initialized = True

async def get_ontology(self) -> GraphSchema:
    """Return the persisted global ontology.

    The ontology graph is the single source of truth, so every call
    re-reads it. The first call additionally goes through
    ``_ensure_ontology_initialized()``, which registers any ``schema``
    passed to :py:class:`GraphRAG` into the ontology graph (validating no
    type contradictions with what's already persisted).

    The freshly loaded schema replaces ``self._global_schema`` and is
    mirrored onto the retrieval strategy when it has a ``_schema``
    attribute.
    """
    await self._ensure_ontology_initialized()

    # Re-read unconditionally: another writer may have registered new types
    # since initialization.
    latest = await self._ontology_store.load()
    self._global_schema = latest

    strategy = self._retrieval_strategy
    if hasattr(strategy, "_schema"):
        strategy._schema = latest
    return latest

async def refresh_ontology(self) -> GraphSchema:
    """Reload the global ontology and propagate it to the retrieval path.

    Thin alias for :py:meth:`get_ontology`. Call it explicitly when another
    process has registered new schema and the next retrieval should see it
    without re-ingesting first.
    """
    refreshed = await self.get_ontology()
    return refreshed

async def save_ontology(self, path: str, *, indent: int = 2) -> None:
    """Write the current global ontology to ``path`` as JSON.

    Fetches the ontology through :py:meth:`get_ontology` and hands ``path``
    and ``indent`` to its ``save_to_file``. This bridges the persisted
    ontology graph to a versionable JSON artifact:
    ``rag.save_ontology("ontology.json")``, hand-edit / version-control it,
    then load with ``GraphSchema.from_file("ontology.json")`` on the next
    run. The ontology graph remains the canonical copy.
    """
    (await self.get_ontology()).save_to_file(path, indent=indent)

# ── Graph admin ──────────────────────────────────────────────

async def get_statistics(self) -> dict[str, Any]:
Expand All @@ -244,20 +330,29 @@ async def get_statistics(self) -> dict[str, Any]:
return await self._graph_store.get_statistics()

async def delete_all(self) -> None:
    """Drop the entire knowledge graph (data + ontology).

    Irreversible. Removes all nodes, relationships, and indexes managed by
    this ``GraphRAG`` instance, plus the paired ontology graph at
    ``<graph_name>__ontology``. Cached state is invalidated as well, so a
    follow-up ``ingest()`` on the same instance re-runs validation,
    re-creates indexes, and re-registers the user's schema.

    A failure while clearing the ontology graph is logged as a warning and
    does not abort the call.
    """
    await self._graph_store.delete_all()

    # The ontology graph is dropped together with the data graph so the two
    # never outlive each other. Best-effort: a failure here is logged and
    # the cached flags are still reset below.
    try:
        await self._ontology_store.clear()
    except Exception as exc:
        logger.warning("Ontology graph clear failed during delete_all (continuing): %s", exc)

    # Indexes went away with the graph: rebuild on the next ensure_indices().
    self._vector_store._indices_ensured = False
    # The __GraphRAGConfig__ node is gone as well: validate again next time.
    self._config_validated = False
    # Make _ensure_ontology_initialized() register self.schema from scratch.
    self._ontology_initialized = False

# ── Ingestion ────────────────────────────────────────────────

Expand Down Expand Up @@ -532,14 +627,19 @@ async def _ingest_single(

doc_info = DocumentInfo(uid=resolved_id, path=path_for_node)

# Load the persisted ontology and register self.schema before the
# pipeline runs — contradictions surface early, before any expensive
# extraction work is done.
await self._ensure_ontology_initialized()

pipeline = IngestionPipeline(
loader=loader or TextLoader(),
chunker=chunker or SentenceTokenCapChunking(),
extractor=extractor or self._default_extractor(),
resolver=resolver or ExactMatchResolution(),
graph_store=self._graph_store,
vector_store=self._vector_store,
schema=self.schema,
schema=self._global_schema,
)

result = await pipeline.run(source, ctx, text=text, document_info=doc_info)
Expand Down
168 changes: 147 additions & 21 deletions graphrag_sdk/src/graphrag_sdk/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,44 @@ class DocumentRecord(DataModel):
# ── Schema Types ─────────────────────────────────────────────────


# Type names accepted for ``PropertyType.type`` (compared after upper-casing).
_PROPERTY_TYPES: frozenset[str] = frozenset(
    ("BOOLEAN", "DATE", "FLOAT", "INTEGER", "LIST", "STRING")
)

# Property names reserved by the SDK; user schemas may not declare them.
RESERVED_PROPERTY_NAMES: frozenset[str] = frozenset(
    (
        "description",
        "fact",
        "id",
        "name",
        "rel_type",
        "source_chunk_ids",
        "spans",
        "src_name",
        "tgt_name",
        "type",
    )
)


class PropertyType(DataModel):
    """Declaration of a single property on a node or relationship type."""

    # Property key as it appears in the schema.
    name: str
    # One of STRING, INTEGER, FLOAT, BOOLEAN, DATE, LIST; stored upper-cased.
    type: str = "STRING"
    description: str | None = None
    required: bool = False

    @model_validator(mode="after")
    def _normalize_type(self) -> PropertyType:
        """Canonicalize ``type`` and reject values outside the allowed set.

        A falsy ``type`` falls back to ``"STRING"``; otherwise the value is
        stripped of surrounding whitespace and upper-cased before the check.

        Raises:
            ValueError: if the canonical form is not in ``_PROPERTY_TYPES``.
        """
        canonical = (self.type or "STRING").strip().upper()
        if canonical in _PROPERTY_TYPES:
            self.type = canonical
            return self
        raise ValueError(
            f"PropertyType '{self.name}' has unsupported type "
            f"{self.type!r}. Allowed: {sorted(_PROPERTY_TYPES)}"
        )


class EntityType(DataModel):
Expand Down Expand Up @@ -185,6 +216,7 @@ class RelationType(DataModel):
label: str
description: str | None = None
patterns: list[tuple[str, str]] = Field(default_factory=list)
properties: list[PropertyType] = Field(default_factory=list)

# Identity is by label only — two RelationType instances with the same
# label but different patterns compare/hash equal. Schemas are expected
Expand All @@ -209,30 +241,122 @@ class GraphSchema(DataModel):
relations: list[RelationType] = Field(default_factory=list)

@model_validator(mode="after")
def _validate_schema(self) -> GraphSchema:
    """Validate the schema at config time.

    Two checks run, in this order:

    - Reserved names (error): any property declared on an entity or
      relation type whose name collides with an SDK-reserved key
      (``name``, ``description``, ``source_chunk_ids``, ``spans``,
      ``rel_type``, ``fact``, ``src_name``, ``tgt_name``, ``id``,
      ``type``) raises ``ValueError``. Entity types are checked before
      relation types.
    - Undeclared pattern labels (warning): when ``entities`` is non-empty,
      each ``RelationType.patterns`` pair that names a label missing from
      ``entities`` is logged. This is a typo guard only — open-schema
      setups may legitimately reference labels not yet listed in
      ``entities`` — so it never raises.
    """
    declared_groups = (
        ("EntityType", self.entities),
        ("RelationType", self.relations),
    )
    for kind, declared_types in declared_groups:
        for declared_type in declared_types:
            for prop in declared_type.properties:
                if prop.name not in RESERVED_PROPERTY_NAMES:
                    continue
                raise ValueError(
                    f"{kind} '{declared_type.label}' declares property "
                    f"'{prop.name}', which is reserved by the SDK. "
                    f"Reserved names: {sorted(RESERVED_PROPERTY_NAMES)}"
                )

    # With no declared entities there is nothing to compare patterns against.
    if not self.entities:
        return self

    known_labels = {entity.label for entity in self.entities}
    for relation in self.relations:
        for source_label, target_label in relation.patterns:
            undeclared = [
                label
                for label in (source_label, target_label)
                if label not in known_labels
            ]
            if not undeclared:
                continue
            logger.warning(
                "RelationType '%s' pattern (%s, %s) references "
                "entity label(s) not declared in schema.entities: %s",
                relation.label,
                source_label,
                target_label,
                ", ".join(undeclared),
            )
    return self

@classmethod
def from_file(cls, path: str) -> GraphSchema:
    """Build a ``GraphSchema`` from the JSON document stored at ``path``.

    Supports the schema-as-config workflow: keep the canonical schema in a
    JSON file under version control and load it into the SDK with one call.
    The file is read as UTF-8 and passed to ``cls.model_validate_json``.
    :py:meth:`save_to_file` is the reverse direction.
    """
    from pathlib import Path

    raw_json = Path(path).read_text(encoding="utf-8")
    return cls.model_validate_json(raw_json)

def save_to_file(self, path: str, *, indent: int = 2) -> None:
    """Serialize this schema to JSON and write it to ``path`` as UTF-8.

    An existing file at ``path`` is overwritten. ``indent`` is forwarded to
    ``model_dump_json``.
    """
    from pathlib import Path

    target = Path(path)
    payload = self.model_dump_json(indent=indent)
    target.write_text(payload, encoding="utf-8")

def merge(self, other: GraphSchema) -> GraphSchema:
    """Return a new ``GraphSchema`` that is the union of ``self`` and ``other``.

    - Entity / relation types are unioned by ``label``; types present on
      only one side are carried over unchanged.
    - For a label present on both sides, ``properties`` are unioned by
      ``name``. When the same property name appears in both, the incoming
      definition replaces the existing one (last-write-wins) while keeping
      the existing position.
    - ``description`` comes from ``other`` when it is truthy, otherwise from
      ``self``.
    - For relations, ``patterns`` are unioned (order-preserving, deduped).

    Neither ``self`` nor ``other`` is modified.
    """

    def _union_properties(
        base: list[PropertyType], extra: list[PropertyType]
    ) -> list[PropertyType]:
        # dict keeps first-insertion order; update() swaps in the later value.
        combined: dict[str, PropertyType] = {prop.name: prop for prop in base}
        combined.update((prop.name, prop) for prop in extra)
        return list(combined.values())

    entities: dict[str, EntityType] = {
        entity.label: entity for entity in self.entities
    }
    for incoming_entity in other.entities:
        current_entity = entities.get(incoming_entity.label)
        if current_entity is None:
            entities[incoming_entity.label] = incoming_entity
            continue
        entities[incoming_entity.label] = EntityType(
            label=current_entity.label,
            description=incoming_entity.description or current_entity.description,
            properties=_union_properties(
                current_entity.properties, incoming_entity.properties
            ),
        )

    relations: dict[str, RelationType] = {
        relation.label: relation for relation in self.relations
    }
    for incoming_relation in other.relations:
        current_relation = relations.get(incoming_relation.label)
        if current_relation is None:
            relations[incoming_relation.label] = incoming_relation
            continue
        # dict.fromkeys drops repeats and keeps first-seen order.
        unique_patterns: list[tuple[str, str]] = list(
            dict.fromkeys([*current_relation.patterns, *incoming_relation.patterns])
        )
        relations[incoming_relation.label] = RelationType(
            label=current_relation.label,
            description=incoming_relation.description or current_relation.description,
            patterns=unique_patterns,
            properties=_union_properties(
                current_relation.properties, incoming_relation.properties
            ),
        )

    return GraphSchema(
        entities=list(entities.values()),
        relations=list(relations.values()),
    )


# ── Extraction / Resolution Output Types ─────────────────────────

Expand All @@ -254,6 +378,7 @@ class ExtractedEntity(DataModel):
type: str
description: str = ""
source_chunk_ids: list[str] = Field(default_factory=list)
attributes: dict[str, Any] = Field(default_factory=dict)


class ExtractedRelation(DataModel):
Expand All @@ -266,6 +391,7 @@ class ExtractedRelation(DataModel):
description: str = ""
weight: float = 1.0
source_chunk_ids: list[str] = Field(default_factory=list)
attributes: dict[str, Any] = Field(default_factory=dict)


class EntityMention(DataModel):
Expand Down
Loading
Loading