Skip to content
63 changes: 63 additions & 0 deletions graphrag_sdk/src/graphrag_sdk/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from graphrag_sdk.retrieval.strategies.multi_path import MultiPathRetrieval
from graphrag_sdk.storage.deduplicator import EntityDeduplicator
from graphrag_sdk.storage.graph_store import GraphStore
from graphrag_sdk.storage.ontology_store import OntologyStore
from graphrag_sdk.storage.vector_store import VectorStore

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -194,12 +195,23 @@ def __init__(
# Deduplication engine
self._deduplicator = EntityDeduplicator(self._graph_store, self.embedder)

# Ontology is inferred from the live data graph (no separate
# persistent graph). User-supplied ``schema`` is merged on top at
# retrieval time so declared descriptions / required flags / not-yet-
# extracted properties survive.
self._ontology_store = OntologyStore(self._conn)
# Global ontology used at retrieval time. Initially just the local
# schema; refresh_ontology() merges in the inferred view on demand
# and after each ingest.
self._global_schema: GraphSchema = self.schema

# Default retrieval strategy
self._retrieval_strategy = retrieval_strategy or MultiPathRetrieval(
graph_store=self._graph_store,
vector_store=self._vector_store,
embedder=self.embedder,
llm=self.llm,
schema=self._global_schema,
)

# -- Async context manager -------------------------------------------
Expand Down Expand Up @@ -233,6 +245,50 @@ async def close(self) -> None:
"""Close the underlying database connection."""
await self._conn.close()

# ── Ontology ─────────────────────────────────────────────────

async def get_ontology(self) -> GraphSchema:
"""Return the **global** ontology used at retrieval time.

The global ontology is the union of:
- the ontology inferred from the live data graph
(``db.labels()`` / ``db.relationshipTypes()`` + property sampling), and
- the user-supplied ``schema`` passed to :py:class:`GraphRAG`
(carries descriptions, ``required`` flags, properties not yet
present in the data).

Returns an empty schema when both are empty.
"""
inferred = await self._ontology_store.infer()
if inferred.entities or inferred.relations:
return inferred.merge(self.schema)
return self.schema

async def refresh_ontology(self) -> GraphSchema:
"""Recompute the global ontology and propagate it to the retrieval path.

Called automatically after each ``ingest()``. Call explicitly when
the data graph has changed under your feet (concurrent writer,
manual edits) and the next retrieval should see it.
"""
self._global_schema = await self.get_ontology()
# Best-effort propagation to the retrieval strategy. Built-in
# MultiPathRetrieval honours this; custom strategies opt in by exposing
# a ``_schema`` attribute or accepting it via constructor.
if hasattr(self._retrieval_strategy, "_schema"):
self._retrieval_strategy._schema = self._global_schema
return self._global_schema

async def save_ontology(self, path: str, *, indent: int = 2) -> None:
"""Write the current global ontology to ``path`` as JSON.

Convenience for the schema-as-config workflow: ``rag.save_ontology(
"ontology.json")``, hand-edit / version-control it, then load with
``GraphSchema.from_file("ontology.json")`` on the next run.
"""
ontology = await self.get_ontology()
ontology.save_to_file(path, indent=indent)

# ── Graph admin ──────────────────────────────────────────────

async def get_statistics(self) -> dict[str, Any]:
Expand Down Expand Up @@ -544,6 +600,13 @@ async def _ingest_single(

result = await pipeline.run(source, ctx, text=text, document_info=doc_info)

# Refresh the global ontology so the next retrieval call sees any
# new properties declared by this run.
try:
await self.refresh_ontology()
except Exception as exc:
logger.warning("Ontology refresh failed (continuing): %s", exc)

if not _skip_post:
# Post-ingestion: create indices only.
# backfill_entity_embeddings() is intentionally NOT called here —
Expand Down
168 changes: 148 additions & 20 deletions graphrag_sdk/src/graphrag_sdk/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,27 @@ class DocumentRecord(DataModel):
# ── Schema Types ─────────────────────────────────────────────────


_PROPERTY_TYPES: frozenset[str] = frozenset(
{"STRING", "INTEGER", "FLOAT", "BOOLEAN", "DATE", "LIST"}
)

# Property names reserved by the SDK; user schemas may not declare them.
RESERVED_PROPERTY_NAMES: frozenset[str] = frozenset(
{
"name",
"type",
"description",
"source_chunk_ids",
"spans",
"rel_type",
"fact",
"src_name",
"tgt_name",
"id",
}
)


class PropertyType(DataModel):
"""A property definition for a node or relationship type."""

Expand All @@ -144,6 +165,17 @@ class PropertyType(DataModel):
description: str | None = None
required: bool = False

@model_validator(mode="after")
def _normalize_type(self) -> PropertyType:
normalized = (self.type or "STRING").strip().upper()
if normalized not in _PROPERTY_TYPES:
raise ValueError(
f"PropertyType '{self.name}' has unsupported type "
f"{self.type!r}. Allowed: {sorted(_PROPERTY_TYPES)}"
)
self.type = normalized
return self


class EntityType(DataModel):
"""Definition of a node/entity type in the graph schema."""
Expand Down Expand Up @@ -185,6 +217,7 @@ class RelationType(DataModel):
label: str
description: str | None = None
patterns: list[tuple[str, str]] = Field(default_factory=list)
properties: list[PropertyType] = Field(default_factory=list)

# Identity is by label only — two RelationType instances with the same
# label but different patterns compare/hash equal. Schemas are expected
Expand All @@ -209,30 +242,123 @@ class GraphSchema(DataModel):
relations: list[RelationType] = Field(default_factory=list)

@model_validator(mode="after")
def _warn_on_undeclared_pattern_labels(self) -> GraphSchema:
"""Warn when a ``RelationType.patterns`` references undeclared entity labels.

Catches typos like ``("Persn", "Company")`` at config time, before any
extraction has run. We warn rather than raise: open-schema setups may
legitimately reference labels not (yet) listed in ``entities``.
def _validate_schema(self) -> GraphSchema:
"""Validate the schema at config time.

- Rejects user-declared property names that collide with SDK-reserved
keys (``name``, ``description``, ``source_chunk_ids``, ``spans``,
``rel_type``, ``fact``, ``src_name``, ``tgt_name``, ``id``, ``type``).
- Warns when a ``RelationType.patterns`` references undeclared entity
labels (typo guard) — open-schema setups may legitimately reference
labels not yet listed in ``entities``, so this is a warning, not an
error.
"""
if not self.entities:
return self
declared = {e.label for e in self.entities}
for rel in self.relations:
for src, tgt in rel.patterns:
missing = [lbl for lbl in (src, tgt) if lbl not in declared]
if missing:
logger.warning(
"RelationType '%s' pattern (%s, %s) references "
"entity label(s) not declared in schema.entities: %s",
rel.label,
src,
tgt,
", ".join(missing),
for et in self.entities:
for prop in et.properties:
if prop.name in RESERVED_PROPERTY_NAMES:
raise ValueError(
f"EntityType '{et.label}' declares property "
f"'{prop.name}', which is reserved by the SDK. "
f"Reserved names: {sorted(RESERVED_PROPERTY_NAMES)}"
)
for rt in self.relations:
for prop in rt.properties:
if prop.name in RESERVED_PROPERTY_NAMES:
raise ValueError(
f"RelationType '{rt.label}' declares property "
f"'{prop.name}', which is reserved by the SDK. "
f"Reserved names: {sorted(RESERVED_PROPERTY_NAMES)}"
)

if self.entities:
declared = {e.label for e in self.entities}
for rel in self.relations:
for src, tgt in rel.patterns:
missing = [lbl for lbl in (src, tgt) if lbl not in declared]
if missing:
logger.warning(
"RelationType '%s' pattern (%s, %s) references "
"entity label(s) not declared in schema.entities: %s",
rel.label,
src,
tgt,
", ".join(missing),
)
return self

@classmethod
def from_file(cls, path: str) -> GraphSchema:
"""Load a ``GraphSchema`` from a JSON file.

The schema-as-config workflow: keep the canonical schema in a JSON
file under version control, load it into the SDK with one call. See
:py:meth:`save_to_file` for the reverse direction.
"""
from pathlib import Path

return cls.model_validate_json(Path(path).read_text(encoding="utf-8"))

def save_to_file(self, path: str, *, indent: int = 2) -> None:
"""Write this schema to ``path`` as JSON (overwrites existing files)."""
from pathlib import Path

Path(path).write_text(self.model_dump_json(indent=indent), encoding="utf-8")

def merge(self, other: GraphSchema) -> GraphSchema:
"""Return a new ``GraphSchema`` that is the union of ``self`` and ``other``.

- Entity / relation types are unioned by ``label``.
- For each type, ``properties`` are unioned by ``name``. When the same
property name appears in both, the incoming type/description/required
overrides — last-write-wins, matching the persisted ontology's
register() semantics.
- For relations, ``patterns`` are unioned (order-preserving, deduped).
"""

def _merge_props(
existing: list[PropertyType], incoming: list[PropertyType]
) -> list[PropertyType]:
by_name: dict[str, PropertyType] = {p.name: p for p in existing}
for p in incoming:
by_name[p.name] = p
return list(by_name.values())

ent_by_label: dict[str, EntityType] = {e.label: e for e in self.entities}
for e in other.entities:
if e.label in ent_by_label:
cur = ent_by_label[e.label]
ent_by_label[e.label] = EntityType(
label=cur.label,
description=e.description or cur.description,
properties=_merge_props(cur.properties, e.properties),
Comment thread
galshubeli marked this conversation as resolved.
)
else:
ent_by_label[e.label] = e

rel_by_label: dict[str, RelationType] = {r.label: r for r in self.relations}
for r in other.relations:
if r.label in rel_by_label:
cur = rel_by_label[r.label]
seen: set[tuple[str, str]] = set()
merged_patterns: list[tuple[str, str]] = []
for pat in list(cur.patterns) + list(r.patterns):
if pat not in seen:
seen.add(pat)
merged_patterns.append(pat)
rel_by_label[r.label] = RelationType(
label=cur.label,
description=r.description or cur.description,
patterns=merged_patterns,
properties=_merge_props(cur.properties, r.properties),
)
else:
rel_by_label[r.label] = r

return GraphSchema(
entities=list(ent_by_label.values()),
relations=list(rel_by_label.values()),
)


# ── Extraction / Resolution Output Types ─────────────────────────

Expand All @@ -254,6 +380,7 @@ class ExtractedEntity(DataModel):
type: str
description: str = ""
source_chunk_ids: list[str] = Field(default_factory=list)
attributes: dict[str, Any] = Field(default_factory=dict)


class ExtractedRelation(DataModel):
Expand All @@ -266,6 +393,7 @@ class ExtractedRelation(DataModel):
description: str = ""
weight: float = 1.0
source_chunk_ids: list[str] = Field(default_factory=list)
attributes: dict[str, Any] = Field(default_factory=dict)


class EntityMention(DataModel):
Expand Down
Loading
Loading