Skip to content
59 changes: 59 additions & 0 deletions graphrag_sdk/src/graphrag_sdk/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from graphrag_sdk.retrieval.strategies.multi_path import MultiPathRetrieval
from graphrag_sdk.storage.deduplicator import EntityDeduplicator
from graphrag_sdk.storage.graph_store import GraphStore
from graphrag_sdk.storage.ontology_store import OntologyStore
from graphrag_sdk.storage.vector_store import VectorStore

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -192,12 +193,21 @@ def __init__(
# Deduplication engine
self._deduplicator = EntityDeduplicator(self._graph_store, self.embedder)

# The persistent ontology lives in `<data_graph>__ontology`. It stays in sync
# across multiple ingest passes, which may each declare a different schema.
self._ontology_store = OntologyStore(self._conn, self._conn.config.graph_name)
# Global ontology used at retrieval time. Initially the user-supplied
# local schema; refreshed from the persisted ontology after each ingest
# (see refresh_ontology()) and on demand.
self._global_schema: GraphSchema = self.schema

Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
# Default retrieval strategy
self._retrieval_strategy = retrieval_strategy or MultiPathRetrieval(
graph_store=self._graph_store,
vector_store=self._vector_store,
embedder=self.embedder,
llm=self.llm,
schema=self._global_schema,
)

# -- Async context manager -------------------------------------------
Expand Down Expand Up @@ -231,6 +241,37 @@ async def close(self) -> None:
"""Close the underlying database connection."""
await self._conn.close()

# ── Ontology ─────────────────────────────────────────────────

async def get_ontology(self) -> GraphSchema:
    """Fetch the persisted **global** ontology for this graph.

    The global ontology is the union of every schema that has been
    registered against this graph. It is read from the dedicated
    ``<data_graph>__ontology`` graph through the ontology store; nothing
    is cached or modified on this instance.

    Returns:
        The schema produced by the ontology store's ``load()``. This is an
        empty schema until the first ingest has run.
    """
    persisted = await self._ontology_store.load()
    return persisted

async def refresh_ontology(self) -> GraphSchema:
    """Re-read the persisted ontology and make it the active retrieval schema.

    Runs automatically after each ``ingest()``. Call it yourself when a
    different process has registered new schema and the next retrieval
    should see it without another ingest.

    The ontology store is loaded once. A result with at least one entity
    or relation type becomes the global schema; an empty result leaves the
    user-supplied local ``schema`` in charge.

    Returns:
        The schema now stored in ``self._global_schema``.
    """
    persisted = await self._ontology_store.load()
    has_content = bool(persisted.entities or persisted.relations)
    # An empty persisted ontology means nothing has been registered yet,
    # so fall back to the locally supplied schema.
    self._global_schema = persisted if has_content else self.schema

    # Propagation is best-effort: the built-in MultiPathRetrieval honours
    # it, and custom strategies opt in simply by exposing ``_schema``.
    strategy = self._retrieval_strategy
    if hasattr(strategy, "_schema"):
        strategy._schema = self._global_schema
    return self._global_schema

# ── Graph admin ──────────────────────────────────────────────

async def get_statistics(self) -> dict[str, Any]:
Expand Down Expand Up @@ -527,6 +568,17 @@ async def _ingest_single(

doc_info = DocumentInfo(uid=resolved_id, path=path_for_node)

# Register this run's local schema into the persisted ontology so the
# global ontology is the union of every schema ever registered. The
# local schema continues to drive *this* run's extraction.
if self.schema.entities or self.schema.relations:
try:
await self._ontology_store.register(self.schema)
except Exception as exc:
logger.warning(
"Ontology registration failed (continuing ingest): %s", exc
)

pipeline = IngestionPipeline(
loader=loader or TextLoader(),
chunker=chunker or FixedSizeChunking(),
Expand All @@ -539,6 +591,13 @@ async def _ingest_single(

result = await pipeline.run(source, ctx, text=text, document_info=doc_info)

# Refresh the global ontology so the next retrieval call sees any
# new properties declared by this run.
try:
await self.refresh_ontology()
except Exception as exc:
logger.warning("Ontology refresh failed (continuing): %s", exc)

if not _skip_post:
# Post-ingestion: create indices only.
# backfill_entity_embeddings() is intentionally NOT called here —
Expand Down
150 changes: 130 additions & 20 deletions graphrag_sdk/src/graphrag_sdk/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,27 @@ class DocumentRecord(DataModel):
# ── Schema Types ─────────────────────────────────────────────────


# Type names accepted for ``PropertyType.type``; values are compared against
# this set after being stripped and upper-cased.
_PROPERTY_TYPES: frozenset[str] = frozenset(
    ("BOOLEAN", "DATE", "FLOAT", "INTEGER", "LIST", "STRING")
)

# Property names the SDK keeps for itself. ``GraphSchema`` validation raises
# ``ValueError`` when an entity or relation type declares one of them.
RESERVED_PROPERTY_NAMES: frozenset[str] = frozenset(
    (
        "description",
        "fact",
        "id",
        "name",
        "rel_type",
        "source_chunk_ids",
        "spans",
        "src_name",
        "tgt_name",
        "type",
    )
)


class PropertyType(DataModel):
"""A property definition for a node or relationship type."""

Expand All @@ -144,6 +165,17 @@ class PropertyType(DataModel):
description: str | None = None
required: bool = False

@model_validator(mode="after")
def _normalize_type(self) -> PropertyType:
    """Canonicalise ``type`` to an upper-case member of ``_PROPERTY_TYPES``.

    A falsy ``type`` is treated as ``"STRING"``. Surrounding whitespace is
    stripped and the value is upper-cased before the membership check, so
    ``" string "`` and ``"String"`` both end up stored as ``"STRING"``.

    Returns:
        This instance, with ``type`` replaced by its canonical form.

    Raises:
        ValueError: If the canonical form is not in ``_PROPERTY_TYPES``.
    """
    raw = self.type or "STRING"
    canonical = raw.strip().upper()
    if canonical in _PROPERTY_TYPES:
        self.type = canonical
        return self
    # Report the value exactly as the caller supplied it, not the
    # canonicalised form, so the message matches their input.
    raise ValueError(
        f"PropertyType '{self.name}' has unsupported type "
        f"{self.type!r}. Allowed: {sorted(_PROPERTY_TYPES)}"
    )


class EntityType(DataModel):
"""Definition of a node/entity type in the graph schema."""
Expand Down Expand Up @@ -185,6 +217,7 @@ class RelationType(DataModel):
label: str
description: str | None = None
patterns: list[tuple[str, str]] = Field(default_factory=list)
properties: list[PropertyType] = Field(default_factory=list)

# Identity is by label only — two RelationType instances with the same
# label but different patterns compare/hash equal. Schemas are expected
Expand All @@ -209,30 +242,105 @@ class GraphSchema(DataModel):
relations: list[RelationType] = Field(default_factory=list)

@model_validator(mode="after")
def _validate_schema(self) -> GraphSchema:
    """Validate the schema at config time.

    Two checks run, in this order:

    1. Reserved property names. Any property declared on an entity or
       relation type whose name is in ``RESERVED_PROPERTY_NAMES``
       (``name``, ``description``, ``source_chunk_ids``, ``spans``,
       ``rel_type``, ``fact``, ``src_name``, ``tgt_name``, ``id``,
       ``type``) is rejected.
    2. Undeclared pattern labels. When at least one entity type is
       declared, every ``RelationType.patterns`` pair is checked against
       the declared entity labels. This is a typo guard that only logs a
       warning, because open-schema setups may legitimately reference
       labels not yet listed in ``entities``.

    Returns:
        This instance, unchanged.

    Raises:
        ValueError: If an entity or relation type declares a reserved
            property name.
    """
    # The reserved-name check is identical for both kinds of type; only
    # the prefix of the error message differs. It runs unconditionally,
    # so a schema with relations but no entities is still validated.
    declared_types = (
        ("EntityType", self.entities),
        ("RelationType", self.relations),
    )
    for kind, types in declared_types:
        for typ in types:
            for prop in typ.properties:
                if prop.name in RESERVED_PROPERTY_NAMES:
                    raise ValueError(
                        f"{kind} '{typ.label}' declares property "
                        f"'{prop.name}', which is reserved by the SDK. "
                        f"Reserved names: {sorted(RESERVED_PROPERTY_NAMES)}"
                    )

    # With no declared entities there is nothing to compare pattern
    # labels against, so the typo guard is skipped entirely.
    if self.entities:
        declared = {e.label for e in self.entities}
        for rel in self.relations:
            for src, tgt in rel.patterns:
                missing = [lbl for lbl in (src, tgt) if lbl not in declared]
                if missing:
                    logger.warning(
                        "RelationType '%s' pattern (%s, %s) references "
                        "entity label(s) not declared in schema.entities: %s",
                        rel.label,
                        src,
                        tgt,
                        ", ".join(missing),
                    )
    return self

def merge(self, other: GraphSchema) -> GraphSchema:
    """Build a new ``GraphSchema`` holding the union of ``self`` and ``other``.

    Neither input schema is modified.

    - Entity and relation types are matched by ``label``. A label present
      in only one schema is carried over as-is.
    - For a label present in both, ``properties`` are unioned by ``name``.
      On a name clash the property from ``other`` replaces the existing
      one (last-write-wins, matching the persisted ontology's register()
      semantics). The description from ``other`` is used when it is
      truthy, otherwise the existing one is kept.
    - For relations, ``patterns`` from both sides are concatenated and
      de-duplicated, keeping first-seen order.

    Args:
        other: The incoming schema whose declarations take precedence.

    Returns:
        A freshly constructed ``GraphSchema``.
    """

    def _union_props(
        base: list[PropertyType], extra: list[PropertyType]
    ) -> list[PropertyType]:
        # Updating an existing key keeps its original position, so the
        # result lists base properties first, then new incoming ones.
        combined: dict[str, PropertyType] = {prop.name: prop for prop in base}
        combined.update((prop.name, prop) for prop in extra)
        return list(combined.values())

    entities: dict[str, EntityType] = {ent.label: ent for ent in self.entities}
    for incoming in other.entities:
        current = entities.get(incoming.label)
        if current is None:
            entities[incoming.label] = incoming
            continue
        entities[incoming.label] = EntityType(
            label=current.label,
            description=incoming.description or current.description,
            properties=_union_props(current.properties, incoming.properties),
        )

    relations: dict[str, RelationType] = {rel.label: rel for rel in self.relations}
    for incoming in other.relations:
        current = relations.get(incoming.label)
        if current is None:
            relations[incoming.label] = incoming
            continue
        # dict.fromkeys de-duplicates while preserving first-seen order.
        merged_patterns: list[tuple[str, str]] = list(
            dict.fromkeys([*current.patterns, *incoming.patterns])
        )
        relations[incoming.label] = RelationType(
            label=current.label,
            description=incoming.description or current.description,
            patterns=merged_patterns,
            properties=_union_props(current.properties, incoming.properties),
        )

    return GraphSchema(
        entities=list(entities.values()),
        relations=list(relations.values()),
    )


# ── Extraction / Resolution Output Types ─────────────────────────

Expand All @@ -254,6 +362,7 @@ class ExtractedEntity(DataModel):
type: str
description: str = ""
source_chunk_ids: list[str] = Field(default_factory=list)
attributes: dict[str, Any] = Field(default_factory=dict)


class ExtractedRelation(DataModel):
Expand All @@ -266,6 +375,7 @@ class ExtractedRelation(DataModel):
description: str = ""
weight: float = 1.0
source_chunk_ids: list[str] = Field(default_factory=list)
attributes: dict[str, Any] = Field(default_factory=dict)


class EntityMention(DataModel):
Expand Down
Loading
Loading