Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions graphrag_sdk/src/graphrag_sdk/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from graphrag_sdk.retrieval.strategies.multi_path import MultiPathRetrieval
from graphrag_sdk.storage.deduplicator import EntityDeduplicator
from graphrag_sdk.storage.graph_store import GraphStore
from graphrag_sdk.storage.ontology_store import OntologyStore
from graphrag_sdk.storage.vector_store import VectorStore

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -192,12 +193,23 @@ def __init__(
# Deduplication engine
self._deduplicator = EntityDeduplicator(self._graph_store, self.embedder)

# Ontology is inferred from the live data graph (no separate
# persistent graph). User-supplied ``schema`` is merged on top at
# retrieval time so declared descriptions / required flags / not-yet-
# extracted properties survive.
self._ontology_store = OntologyStore(self._conn)
# Global ontology used at retrieval time. Initially just the local
# schema; refresh_ontology() merges in the inferred view on demand
# and after each ingest.
self._global_schema: GraphSchema = self.schema

# Default retrieval strategy
self._retrieval_strategy = retrieval_strategy or MultiPathRetrieval(
graph_store=self._graph_store,
vector_store=self._vector_store,
embedder=self.embedder,
llm=self.llm,
schema=self._global_schema,
)
Comment thread
galshubeli marked this conversation as resolved.
Outdated

# -- Async context manager -------------------------------------------
Expand Down Expand Up @@ -231,6 +243,50 @@ async def close(self) -> None:
"""Close the underlying database connection."""
await self._conn.close()

# ── Ontology ─────────────────────────────────────────────────

async def get_ontology(self) -> GraphSchema:
"""Return the **global** ontology used at retrieval time.

The global ontology is the union of:
- the ontology inferred from the live data graph
(``db.labels()`` / ``db.relationshipTypes()`` + property sampling), and
- the user-supplied ``schema`` passed to :py:class:`GraphRAG`
(carries descriptions, ``required`` flags, properties not yet
present in the data).

Returns an empty schema when both are empty.
"""
inferred = await self._ontology_store.infer()
if inferred.entities or inferred.relations:
return inferred.merge(self.schema)
return self.schema

async def refresh_ontology(self) -> GraphSchema:
"""Recompute the global ontology and propagate it to the retrieval path.

Called automatically after each ``ingest()``. Call explicitly when
the data graph has changed under your feet (concurrent writer,
manual edits) and the next retrieval should see it.
"""
self._global_schema = await self.get_ontology()
# Best-effort propagation to the retrieval strategy. Built-in
# MultiPathRetrieval honours this; custom strategies opt in by exposing
# a ``_schema`` attribute or accepting it via constructor.
if hasattr(self._retrieval_strategy, "_schema"):
self._retrieval_strategy._schema = self._global_schema
return self._global_schema

async def save_ontology(self, path: str, *, indent: int = 2) -> None:
"""Write the current global ontology to ``path`` as JSON.

Convenience for the schema-as-config workflow: ``rag.save_ontology(
"ontology.json")``, hand-edit / version-control it, then load with
``GraphSchema.from_file("ontology.json")`` on the next run.
"""
ontology = await self.get_ontology()
ontology.save_to_file(path, indent=indent)

# ── Graph admin ──────────────────────────────────────────────

async def get_statistics(self) -> dict[str, Any]:
Expand Down Expand Up @@ -539,6 +595,13 @@ async def _ingest_single(

result = await pipeline.run(source, ctx, text=text, document_info=doc_info)

# Refresh the global ontology so the next retrieval call sees any
# new properties declared by this run.
try:
await self.refresh_ontology()
except Exception as exc:
logger.warning("Ontology refresh failed (continuing): %s", exc)

if not _skip_post:
# Post-ingestion: create indices only.
# backfill_entity_embeddings() is intentionally NOT called here —
Expand Down
168 changes: 148 additions & 20 deletions graphrag_sdk/src/graphrag_sdk/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,27 @@ class DocumentRecord(DataModel):
# ── Schema Types ─────────────────────────────────────────────────


_PROPERTY_TYPES: frozenset[str] = frozenset(
{"STRING", "INTEGER", "FLOAT", "BOOLEAN", "DATE", "LIST"}
)

# Property names reserved by the SDK; user schemas may not declare them.
RESERVED_PROPERTY_NAMES: frozenset[str] = frozenset(
{
"name",
"type",
"description",
"source_chunk_ids",
"spans",
"rel_type",
"fact",
"src_name",
"tgt_name",
"id",
}
)


class PropertyType(DataModel):
"""A property definition for a node or relationship type."""

Expand All @@ -144,6 +165,17 @@ class PropertyType(DataModel):
description: str | None = None
required: bool = False

@model_validator(mode="after")
def _normalize_type(self) -> PropertyType:
normalized = (self.type or "STRING").strip().upper()
if normalized not in _PROPERTY_TYPES:
raise ValueError(
f"PropertyType '{self.name}' has unsupported type "
f"{self.type!r}. Allowed: {sorted(_PROPERTY_TYPES)}"
)
self.type = normalized
return self


class EntityType(DataModel):
"""Definition of a node/entity type in the graph schema."""
Expand Down Expand Up @@ -185,6 +217,7 @@ class RelationType(DataModel):
label: str
description: str | None = None
patterns: list[tuple[str, str]] = Field(default_factory=list)
properties: list[PropertyType] = Field(default_factory=list)

# Identity is by label only — two RelationType instances with the same
# label but different patterns compare/hash equal. Schemas are expected
Expand All @@ -209,30 +242,123 @@ class GraphSchema(DataModel):
relations: list[RelationType] = Field(default_factory=list)

@model_validator(mode="after")
def _warn_on_undeclared_pattern_labels(self) -> GraphSchema:
"""Warn when a ``RelationType.patterns`` references undeclared entity labels.

Catches typos like ``("Persn", "Company")`` at config time, before any
extraction has run. We warn rather than raise: open-schema setups may
legitimately reference labels not (yet) listed in ``entities``.
def _validate_schema(self) -> GraphSchema:
"""Validate the schema at config time.

- Rejects user-declared property names that collide with SDK-reserved
keys (``name``, ``description``, ``source_chunk_ids``, ``spans``,
``rel_type``, ``fact``, ``src_name``, ``tgt_name``, ``id``, ``type``).
- Warns when a ``RelationType.patterns`` references undeclared entity
labels (typo guard) — open-schema setups may legitimately reference
labels not yet listed in ``entities``, so this is a warning, not an
error.
"""
if not self.entities:
return self
declared = {e.label for e in self.entities}
for rel in self.relations:
for src, tgt in rel.patterns:
missing = [lbl for lbl in (src, tgt) if lbl not in declared]
if missing:
logger.warning(
"RelationType '%s' pattern (%s, %s) references "
"entity label(s) not declared in schema.entities: %s",
rel.label,
src,
tgt,
", ".join(missing),
for et in self.entities:
for prop in et.properties:
if prop.name in RESERVED_PROPERTY_NAMES:
raise ValueError(
f"EntityType '{et.label}' declares property "
f"'{prop.name}', which is reserved by the SDK. "
f"Reserved names: {sorted(RESERVED_PROPERTY_NAMES)}"
)
for rt in self.relations:
for prop in rt.properties:
if prop.name in RESERVED_PROPERTY_NAMES:
raise ValueError(
f"RelationType '{rt.label}' declares property "
f"'{prop.name}', which is reserved by the SDK. "
f"Reserved names: {sorted(RESERVED_PROPERTY_NAMES)}"
)

if self.entities:
declared = {e.label for e in self.entities}
for rel in self.relations:
for src, tgt in rel.patterns:
missing = [lbl for lbl in (src, tgt) if lbl not in declared]
if missing:
logger.warning(
"RelationType '%s' pattern (%s, %s) references "
"entity label(s) not declared in schema.entities: %s",
rel.label,
src,
tgt,
", ".join(missing),
)
return self

@classmethod
def from_file(cls, path: str) -> GraphSchema:
"""Load a ``GraphSchema`` from a JSON file.

The schema-as-config workflow: keep the canonical schema in a JSON
file under version control, load it into the SDK with one call. See
:py:meth:`save_to_file` for the reverse direction.
"""
from pathlib import Path

return cls.model_validate_json(Path(path).read_text(encoding="utf-8"))

def save_to_file(self, path: str, *, indent: int = 2) -> None:
"""Write this schema to ``path`` as JSON (overwrites existing files)."""
from pathlib import Path

Path(path).write_text(self.model_dump_json(indent=indent), encoding="utf-8")

def merge(self, other: GraphSchema) -> GraphSchema:
"""Return a new ``GraphSchema`` that is the union of ``self`` and ``other``.

- Entity / relation types are unioned by ``label``.
- For each type, ``properties`` are unioned by ``name``. When the same
property name appears in both, the incoming type/description/required
overrides — last-write-wins, matching the persisted ontology's
register() semantics.
Comment thread
galshubeli marked this conversation as resolved.
Outdated
- For relations, ``patterns`` are unioned (order-preserving, deduped).
"""

def _merge_props(
existing: list[PropertyType], incoming: list[PropertyType]
) -> list[PropertyType]:
by_name: dict[str, PropertyType] = {p.name: p for p in existing}
for p in incoming:
by_name[p.name] = p
return list(by_name.values())

ent_by_label: dict[str, EntityType] = {e.label: e for e in self.entities}
for e in other.entities:
if e.label in ent_by_label:
cur = ent_by_label[e.label]
ent_by_label[e.label] = EntityType(
label=cur.label,
description=e.description or cur.description,
properties=_merge_props(cur.properties, e.properties),
Comment thread
galshubeli marked this conversation as resolved.
)
else:
ent_by_label[e.label] = e

rel_by_label: dict[str, RelationType] = {r.label: r for r in self.relations}
for r in other.relations:
if r.label in rel_by_label:
cur = rel_by_label[r.label]
seen: set[tuple[str, str]] = set()
merged_patterns: list[tuple[str, str]] = []
for pat in list(cur.patterns) + list(r.patterns):
if pat not in seen:
seen.add(pat)
merged_patterns.append(pat)
rel_by_label[r.label] = RelationType(
label=cur.label,
description=r.description or cur.description,
patterns=merged_patterns,
properties=_merge_props(cur.properties, r.properties),
)
else:
rel_by_label[r.label] = r

return GraphSchema(
entities=list(ent_by_label.values()),
relations=list(rel_by_label.values()),
)


# ── Extraction / Resolution Output Types ─────────────────────────

Expand All @@ -254,6 +380,7 @@ class ExtractedEntity(DataModel):
type: str
description: str = ""
source_chunk_ids: list[str] = Field(default_factory=list)
attributes: dict[str, Any] = Field(default_factory=dict)


class ExtractedRelation(DataModel):
Expand All @@ -266,6 +393,7 @@ class ExtractedRelation(DataModel):
description: str = ""
weight: float = 1.0
source_chunk_ids: list[str] = Field(default_factory=list)
attributes: dict[str, Any] = Field(default_factory=dict)


class EntityMention(DataModel):
Expand Down
Loading
Loading