From 576bafad243640baab9e0473a7e30797fe25946c Mon Sep 17 00:00:00 2001 From: Gal Shubeli Date: Sun, 17 May 2026 16:49:05 +0300 Subject: [PATCH 1/7] feat(schema): schema-driven attributes end-to-end with persistent ontology Wires declared PropertyType / RelationType.properties from the GraphSchema through extraction, storage, and retrieval. The ontology persists in a dedicated __ontology FalkorDB graph and accumulates as the union of every schema ever registered, so ingest passes can use subset/ extension schemas while retrieval always sees the global ontology. - core/models: PropertyType type-normalising validator, RelationType.properties, ExtractedEntity/Relation.attributes, reserved-name rejection, GraphSchema.merge. - storage/ontology_store (new): OntologyStore.load/register/clear against __ontology; idempotent unions; pattern lists deduped. - ingestion/graph_extraction: conditional attribute block in VERIFY_EXTRACT_RELS_PROMPT (empty for property-less schemas, zero drift), per-type coercion at the extractor boundary, required-missing record drop, aggregators carry attributes (last-write-wins), conversion merges into props. - ingestion/pipeline: _validate_attributes pass after _prune, then re-run _filter_quality to cascade-clean dangling relationships. - retrieval/cypher_generation: replaces hardcoded SCHEMA_PROMPT with build_schema_prompt(schema, question) + render_schema_block; synthesises one numeric-attribute example per declared INTEGER/FLOAT property; validate_cypher accepts dynamic schema labels. - retrieval/multi_path: schema threaded into execute_cypher_retrieval. - api/main: GraphRAG owns the OntologyStore, registers local schema on each ingest, refresh_ontology() propagates the global ontology to retrieval, public get_ontology() / refresh_ontology() methods. Forward-only evolution: nodes ingested before a property was declared remain without that property; FalkorDB's MERGE ... SET n += props handles fill-in naturally and WHERE p.attr > N naturally excludes them. Tests: 124 new/extended unit tests; full suite 819 passed, 23 skipped (integration tests gated by RUN_INTEGRATION=1). Property-less schemas produce identical output to the prior pipeline. Co-Authored-By: Claude Opus 4.7 (1M context) --- graphrag_sdk/src/graphrag_sdk/api/main.py | 59 +++++ graphrag_sdk/src/graphrag_sdk/core/models.py | 150 +++++++++-- .../extraction_strategies/graph_extraction.py | 236 +++++++++++++++++- .../src/graphrag_sdk/ingestion/pipeline.py | 115 +++++++++ .../retrieval/strategies/cypher_generation.py | 178 +++++++++++-- .../retrieval/strategies/multi_path.py | 6 +- .../src/graphrag_sdk/storage/__init__.py | 3 +- .../graphrag_sdk/storage/ontology_store.py | 228 +++++++++++++++++ graphrag_sdk/tests/test_attribute_prompt.py | 209 ++++++++++++++++ graphrag_sdk/tests/test_cypher_generation.py | 92 +++++++ graphrag_sdk/tests/test_graph_extraction.py | 126 ++++++++++ graphrag_sdk/tests/test_models.py | 83 +++++- graphrag_sdk/tests/test_ontology_store.py | 226 +++++++++++++++++ graphrag_sdk/tests/test_pipeline.py | 141 +++++++++++ 14 files changed, 1804 insertions(+), 48 deletions(-) create mode 100644 graphrag_sdk/src/graphrag_sdk/storage/ontology_store.py create mode 100644 graphrag_sdk/tests/test_attribute_prompt.py create mode 100644 graphrag_sdk/tests/test_ontology_store.py diff --git a/graphrag_sdk/src/graphrag_sdk/api/main.py b/graphrag_sdk/src/graphrag_sdk/api/main.py index 9b2486c6..1524108b 100644 --- a/graphrag_sdk/src/graphrag_sdk/api/main.py +++ b/graphrag_sdk/src/graphrag_sdk/api/main.py @@ -46,6 +46,7 @@ from graphrag_sdk.retrieval.strategies.multi_path import MultiPathRetrieval from graphrag_sdk.storage.deduplicator import EntityDeduplicator from graphrag_sdk.storage.graph_store import GraphStore +from graphrag_sdk.storage.ontology_store import OntologyStore from graphrag_sdk.storage.vector_store import VectorStore logger = logging.getLogger(__name__) @@ -192,12 +193,21 @@ def __init__( # Deduplication engine self._deduplicator = EntityDeduplicator(self._graph_store, self.embedder) + # Persistent ontology lives in `__ontology`. Stays in sync + # across multiple ingest passes that may declare different schemas. + self._ontology_store = OntologyStore(self._conn, self._conn.config.graph_name) + # Global ontology used at retrieval time. Initially the user-supplied + # local schema; refreshed from the persisted ontology after each ingest + # (see refresh_ontology()) and on demand. + self._global_schema: GraphSchema = self.schema + # Default retrieval strategy self._retrieval_strategy = retrieval_strategy or MultiPathRetrieval( graph_store=self._graph_store, vector_store=self._vector_store, embedder=self.embedder, llm=self.llm, + schema=self._global_schema, ) # -- Async context manager ------------------------------------------- @@ -231,6 +241,37 @@ async def close(self) -> None: """Close the underlying database connection.""" await self._conn.close() + # ── Ontology ───────────────────────────────────────────────── + + async def get_ontology(self) -> GraphSchema: + """Return the persisted **global** ontology (union of every schema + ever registered against this graph). + + Reads from the dedicated ``__ontology`` graph. Returns an + empty schema before the first ingest has run. + """ + return await self._ontology_store.load() + + async def refresh_ontology(self) -> GraphSchema: + """Reload the global ontology and propagate it to the retrieval path. + + Called automatically after each ``ingest()``. Call explicitly when + another process has registered new schema and you want the next + retrieval to see it without re-ingesting first. + """ + loaded = await self._ontology_store.load() + if loaded.entities or loaded.relations: + self._global_schema = loaded + else: + # No persisted ontology yet — keep the user-provided local schema. + self._global_schema = self.schema + # Best-effort propagation to the retrieval strategy. Built-in + # MultiPathRetrieval honours this; custom strategies opt in by exposing + # a ``_schema`` attribute or accepting it via constructor. + if hasattr(self._retrieval_strategy, "_schema"): + self._retrieval_strategy._schema = self._global_schema + return self._global_schema + # ── Graph admin ────────────────────────────────────────────── async def get_statistics(self) -> dict[str, Any]: @@ -527,6 +568,17 @@ async def _ingest_single( doc_info = DocumentInfo(uid=resolved_id, path=path_for_node) + # Register this run's local schema into the persisted ontology so the + # global ontology is the union of every schema ever registered. The + # local schema continues to drive *this* run's extraction. + if self.schema.entities or self.schema.relations: + try: + await self._ontology_store.register(self.schema) + except Exception as exc: + logger.warning( + "Ontology registration failed (continuing ingest): %s", exc + ) + pipeline = IngestionPipeline( loader=loader or TextLoader(), chunker=chunker or FixedSizeChunking(), @@ -539,6 +591,13 @@ async def _ingest_single( result = await pipeline.run(source, ctx, text=text, document_info=doc_info) + # Refresh the global ontology so the next retrieval call sees any + # new properties declared by this run. + try: + await self.refresh_ontology() + except Exception as exc: + logger.warning("Ontology refresh failed (continuing): %s", exc) + if not _skip_post: # Post-ingestion: create indices only. # backfill_entity_embeddings() is intentionally NOT called here — diff --git a/graphrag_sdk/src/graphrag_sdk/core/models.py b/graphrag_sdk/src/graphrag_sdk/core/models.py index 8d88619b..86119263 100644 --- a/graphrag_sdk/src/graphrag_sdk/core/models.py +++ b/graphrag_sdk/src/graphrag_sdk/core/models.py @@ -136,6 +136,27 @@ class DocumentRecord(DataModel): # ── Schema Types ───────────────────────────────────────────────── +_PROPERTY_TYPES: frozenset[str] = frozenset( + {"STRING", "INTEGER", "FLOAT", "BOOLEAN", "DATE", "LIST"} +) + +# Property names reserved by the SDK; user schemas may not declare them. +RESERVED_PROPERTY_NAMES: frozenset[str] = frozenset( + { + "name", + "type", + "description", + "source_chunk_ids", + "spans", + "rel_type", + "fact", + "src_name", + "tgt_name", + "id", + } +) + + class PropertyType(DataModel): """A property definition for a node or relationship type.""" @@ -144,6 +165,17 @@ class PropertyType(DataModel): description: str | None = None required: bool = False + @model_validator(mode="after") + def _normalize_type(self) -> PropertyType: + normalized = (self.type or "STRING").strip().upper() + if normalized not in _PROPERTY_TYPES: + raise ValueError( + f"PropertyType '{self.name}' has unsupported type " + f"{self.type!r}. Allowed: {sorted(_PROPERTY_TYPES)}" + ) + self.type = normalized + return self + class EntityType(DataModel): """Definition of a node/entity type in the graph schema.""" @@ -185,6 +217,7 @@ class RelationType(DataModel): label: str description: str | None = None patterns: list[tuple[str, str]] = Field(default_factory=list) + properties: list[PropertyType] = Field(default_factory=list) # Identity is by label only — two RelationType instances with the same # label but different patterns compare/hash equal. Schemas are expected @@ -209,30 +242,105 @@ class GraphSchema(DataModel): relations: list[RelationType] = Field(default_factory=list) @model_validator(mode="after") - def _warn_on_undeclared_pattern_labels(self) -> GraphSchema: - """Warn when a ``RelationType.patterns`` references undeclared entity labels. - - Catches typos like ``("Persn", "Company")`` at config time, before any - extraction has run. We warn rather than raise: open-schema setups may - legitimately reference labels not (yet) listed in ``entities``. + def _validate_schema(self) -> GraphSchema: + """Validate the schema at config time. + + - Rejects user-declared property names that collide with SDK-reserved + keys (``name``, ``description``, ``source_chunk_ids``, ``spans``, + ``rel_type``, ``fact``, ``src_name``, ``tgt_name``, ``id``, ``type``). + - Warns when a ``RelationType.patterns`` references undeclared entity + labels (typo guard) — open-schema setups may legitimately reference + labels not yet listed in ``entities``, so this is a warning, not an + error. """ - if not self.entities: - return self - declared = {e.label for e in self.entities} - for rel in self.relations: - for src, tgt in rel.patterns: - missing = [lbl for lbl in (src, tgt) if lbl not in declared] - if missing: - logger.warning( - "RelationType '%s' pattern (%s, %s) references " - "entity label(s) not declared in schema.entities: %s", - rel.label, - src, - tgt, - ", ".join(missing), + for et in self.entities: + for prop in et.properties: + if prop.name in RESERVED_PROPERTY_NAMES: + raise ValueError( + f"EntityType '{et.label}' declares property " + f"'{prop.name}', which is reserved by the SDK. " + f"Reserved names: {sorted(RESERVED_PROPERTY_NAMES)}" ) + for rt in self.relations: + for prop in rt.properties: + if prop.name in RESERVED_PROPERTY_NAMES: + raise ValueError( + f"RelationType '{rt.label}' declares property " + f"'{prop.name}', which is reserved by the SDK. " + f"Reserved names: {sorted(RESERVED_PROPERTY_NAMES)}" + ) + + if self.entities: + declared = {e.label for e in self.entities} + for rel in self.relations: + for src, tgt in rel.patterns: + missing = [lbl for lbl in (src, tgt) if lbl not in declared] + if missing: + logger.warning( + "RelationType '%s' pattern (%s, %s) references " + "entity label(s) not declared in schema.entities: %s", + rel.label, + src, + tgt, + ", ".join(missing), + ) return self + def merge(self, other: GraphSchema) -> GraphSchema: + """Return a new ``GraphSchema`` that is the union of ``self`` and ``other``. + + - Entity / relation types are unioned by ``label``. + - For each type, ``properties`` are unioned by ``name``. When the same + property name appears in both, the incoming type/description/required + overrides — last-write-wins, matching the persisted ontology's + register() semantics. + - For relations, ``patterns`` are unioned (order-preserving, deduped). + """ + + def _merge_props( + existing: list[PropertyType], incoming: list[PropertyType] + ) -> list[PropertyType]: + by_name: dict[str, PropertyType] = {p.name: p for p in existing} + for p in incoming: + by_name[p.name] = p + return list(by_name.values()) + + ent_by_label: dict[str, EntityType] = {e.label: e for e in self.entities} + for e in other.entities: + if e.label in ent_by_label: + cur = ent_by_label[e.label] + ent_by_label[e.label] = EntityType( + label=cur.label, + description=e.description or cur.description, + properties=_merge_props(cur.properties, e.properties), + ) + else: + ent_by_label[e.label] = e + + rel_by_label: dict[str, RelationType] = {r.label: r for r in self.relations} + for r in other.relations: + if r.label in rel_by_label: + cur = rel_by_label[r.label] + seen: set[tuple[str, str]] = set() + merged_patterns: list[tuple[str, str]] = [] + for pat in list(cur.patterns) + list(r.patterns): + if pat not in seen: + seen.add(pat) + merged_patterns.append(pat) + rel_by_label[r.label] = RelationType( + label=cur.label, + description=r.description or cur.description, + patterns=merged_patterns, + properties=_merge_props(cur.properties, r.properties), + ) + else: + rel_by_label[r.label] = r + + return GraphSchema( + entities=list(ent_by_label.values()), + relations=list(rel_by_label.values()), + ) + # ── Extraction / Resolution Output Types ───────────────────────── @@ -254,6 +362,7 @@ class ExtractedEntity(DataModel): type: str description: str = "" source_chunk_ids: list[str] = Field(default_factory=list) + attributes: dict[str, Any] = Field(default_factory=dict) class ExtractedRelation(DataModel): @@ -266,6 +375,7 @@ class ExtractedRelation(DataModel): description: str = "" weight: float = 1.0 source_chunk_ids: list[str] = Field(default_factory=list) + attributes: dict[str, Any] = Field(default_factory=dict) class EntityMention(DataModel): diff --git a/graphrag_sdk/src/graphrag_sdk/ingestion/extraction_strategies/graph_extraction.py b/graphrag_sdk/src/graphrag_sdk/ingestion/extraction_strategies/graph_extraction.py index 09a56121..02ec9615 100644 --- a/graphrag_sdk/src/graphrag_sdk/ingestion/extraction_strategies/graph_extraction.py +++ b/graphrag_sdk/src/graphrag_sdk/ingestion/extraction_strategies/graph_extraction.py @@ -10,14 +10,18 @@ from typing import Any from graphrag_sdk.core.context import Context +from datetime import date, datetime + from graphrag_sdk.core.models import ( EntityMention, + EntityType, ExtractedEntity, ExtractedRelation, GraphData, GraphNode, GraphRelationship, GraphSchema, + PropertyType, RelationType, TextChunks, ) @@ -48,6 +52,7 @@ "## Entity Types\n" "{entity_types}\n\n" "{relation_patterns}" + "{attribute_schema_block}" "## Pre-extracted Entities\n" "{entities_json}\n\n" "## Text\n" @@ -74,12 +79,165 @@ "sentence for this relationship starts.\n" "- span_end: the character offset where the evidence sentence ends.\n\n" "Return ONLY a JSON object with two arrays:\n" + "{json_example}\n\n" + "Return ONLY valid JSON, nothing else." +) + +_DEFAULT_JSON_EXAMPLE = ( '{{"entities": [{{"name": "...", "type": "...", "description": "..."}}], ' '"relationships": [{{"source": "...", "target": "...", "type": "...", ' - '"description": "...", "span_start": 0, "span_end": 50}}]}}\n\n' - "Return ONLY valid JSON, nothing else." + '"description": "...", "span_start": 0, "span_end": 50}}]}}' ) +_JSON_EXAMPLE_WITH_ATTRS = ( + '{{"entities": [{{"name": "...", "type": "...", "description": "...", ' + '"attributes": {{}}}}], ' + '"relationships": [{{"source": "...", "target": "...", "type": "...", ' + '"description": "...", "span_start": 0, "span_end": 50, ' + '"attributes": {{}}}}]}}' +) + + +def _format_property_for_prompt(prop: PropertyType) -> str: + desc = f" — {prop.description}" if prop.description else "" + req = " (required)" if prop.required else "" + return f" - {prop.name} ({prop.type}){req}{desc}" + + +def _render_attribute_schema_block(schema: GraphSchema) -> str: + """Return the prompt block listing declared entity/relation attributes. + + Returns ``""`` when no entity or relation type declares any property — this + keeps the prompt identical to today for property-less schemas (no token + drift, no quality regression for the existing extraction path). + """ + ent_lines: list[str] = [] + for et in schema.entities: + if not et.properties: + continue + ent_lines.append(f"- {et.label}:") + ent_lines.extend(_format_property_for_prompt(p) for p in et.properties) + + rel_lines: list[str] = [] + for rt in schema.relations: + if not rt.properties: + continue + rel_lines.append(f"- {rt.label}:") + rel_lines.extend(_format_property_for_prompt(p) for p in rt.properties) + + if not ent_lines and not rel_lines: + return "" + + parts = [ + "## Attribute extraction", + "For each verified entity and each extracted relationship, also extract any " + "of the declared properties below. Use null when the property is not stated " + "in the text. Coerce to the declared type (INTEGER, FLOAT, BOOLEAN, DATE in " + "ISO 8601, STRING, LIST). Place extracted values inside an `attributes` " + "object on each entity / relationship.", + ] + if ent_lines: + parts.append("### Entity attributes") + parts.extend(ent_lines) + if rel_lines: + parts.append("### Relation attributes") + parts.extend(rel_lines) + return "\n".join(parts) + "\n\n" + + +def _schema_has_attributes(schema: GraphSchema) -> bool: + return any(et.properties for et in schema.entities) or any( + rt.properties for rt in schema.relations + ) + + +def _coerce_attribute_value(value: Any, prop_type: str) -> tuple[bool, Any]: + """Coerce ``value`` to ``prop_type``. Returns ``(ok, coerced)``. + + ``ok=False`` means the value should be dropped (incl. ``None``). + """ + if value is None: + return False, None + pt = prop_type.upper() + try: + if pt == "STRING": + s = str(value).strip() + return (True, s) if s else (False, None) + if pt == "INTEGER": + if isinstance(value, bool): + return False, None + return True, int(float(value)) + if pt == "FLOAT": + if isinstance(value, bool): + return False, None + return True, float(value) + if pt == "BOOLEAN": + if isinstance(value, bool): + return True, value + if isinstance(value, (int, float)): + return True, bool(value) + if isinstance(value, str): + v = value.strip().lower() + if v in {"true", "yes", "y", "1"}: + return True, True + if v in {"false", "no", "n", "0"}: + return True, False + return False, None + if pt == "DATE": + if isinstance(value, (date, datetime)): + return True, (value.date() if isinstance(value, datetime) else value).isoformat() + if isinstance(value, str): + v = value.strip() + if not v: + return False, None + # Accept either pure date or full ISO 8601 + try: + parsed = date.fromisoformat(v) + return True, parsed.isoformat() + except ValueError: + try: + parsed_dt = datetime.fromisoformat(v.replace("Z", "+00:00")) + return True, parsed_dt.date().isoformat() + except ValueError: + return False, None + return False, None + if pt == "LIST": + if isinstance(value, dict): + return False, None + if isinstance(value, (list, tuple, set)): + return True, list(value) + return True, [value] + except (TypeError, ValueError): + return False, None + return False, None + + +def _coerce_attributes( + raw: dict[str, Any] | None, + declared: dict[str, PropertyType], +) -> tuple[dict[str, Any], list[str]]: + """Apply per-type coercion against declared properties. + + Returns ``(coerced, missing_required)``. ``missing_required`` lists names of + declared ``required=True`` properties whose value is absent or failed to + coerce — callers use it to decide whether to drop the whole record. + """ + coerced: dict[str, Any] = {} + raw = raw or {} + missing_required: list[str] = [] + for prop_name, prop in declared.items(): + if prop_name not in raw or raw[prop_name] is None: + if prop.required: + missing_required.append(prop_name) + continue + ok, val = _coerce_attribute_value(raw[prop_name], prop.type) + if not ok: + if prop.required: + missing_required.append(prop_name) + continue + coerced[prop_name] = val + return coerced, missing_required + def _optional_extras(obj: Any) -> dict[str, Any]: """Extract optional spans/confidence from a Pydantic extra-allow object.""" @@ -306,12 +464,15 @@ async def _step1(text: str, chunk_uid: str) -> list[ExtractedEntity]: entities_json = json.dumps( [{"name": e.name, "type": e.type, "description": e.description} for e in ents] ) + has_attrs = _schema_has_attributes(schema) prompt = VERIFY_EXTRACT_RELS_PROMPT.format( entity_types=_format_entity_types(entity_types, entity_type_descs), relation_patterns=_format_relation_patterns(schema.relations), + attribute_schema_block=_render_attribute_schema_block(schema), relationship_type_instruction=_relationship_type_instruction(schema.relations), entities_json=entities_json, text=text, + json_example=_JSON_EXAMPLE_WITH_ATTRS if has_attrs else _DEFAULT_JSON_EXAMPLE, ) step2_prompts.append(prompt) step2_indices.append(i) @@ -340,7 +501,7 @@ async def _step1(text: str, chunk_uid: str) -> list[ExtractedEntity]: assert item.response is not None verified_ents, rels = self._parse_step2_response( - item.response.content, entity_types, chunk.uid + item.response.content, entity_types, chunk.uid, schema ) if verified_ents: # Carry over spans/confidence from step 1 entities @@ -436,8 +597,15 @@ def _parse_step2_response( content: str, entity_types: list[str], source_chunk_id: str, + schema: GraphSchema | None = None, ) -> tuple[list[ExtractedEntity], list[ExtractedRelation]]: - """Parse the step 2 LLM response (verified entities + relationships).""" + """Parse the step 2 LLM response (verified entities + relationships). + + When ``schema`` declares attributes for an entity / relation type, the + response's ``attributes`` object is coerced to the declared types. + Records whose ``required=True`` properties fail coercion are dropped + and counted in an aggregated warning per type. + """ text = _strip_markdown_fences(content) try: @@ -449,6 +617,18 @@ def _parse_step2_response( if not isinstance(data, dict): return [], [] + ent_props_by_label: dict[str, dict[str, PropertyType]] = {} + rel_props_by_label: dict[str, dict[str, PropertyType]] = {} + if schema is not None: + for et in schema.entities: + if et.properties: + ent_props_by_label[et.label] = {p.name: p for p in et.properties} + for rt in schema.relations: + if rt.properties: + rel_props_by_label[rt.label] = {p.name: p for p in rt.properties} + + dropped_required: dict[str, int] = {} + # Parse entities entities: list[ExtractedEntity] = [] for item in data.get("entities", []): @@ -463,12 +643,24 @@ def _parse_step2_response( etype = label_for_type(raw_type, entity_types) description = str(item.get("description", "")).strip() + declared = ent_props_by_label.get(etype, {}) + attributes: dict[str, Any] = {} + if declared: + raw_attrs = item.get("attributes") if isinstance(item.get("attributes"), dict) else {} + attributes, missing = _coerce_attributes(raw_attrs, declared) + if missing: + dropped_required[f"entity:{etype}"] = ( + dropped_required.get(f"entity:{etype}", 0) + 1 + ) + continue + entities.append( ExtractedEntity( name=name, type=etype, description=description, source_chunk_ids=[source_chunk_id], + attributes=attributes, ) ) @@ -492,6 +684,17 @@ def _parse_step2_response( if spans: extra["spans"] = spans + declared_rel = rel_props_by_label.get(rel_type, {}) + attributes_rel: dict[str, Any] = {} + if declared_rel: + raw_attrs = item.get("attributes") if isinstance(item.get("attributes"), dict) else {} + attributes_rel, missing = _coerce_attributes(raw_attrs, declared_rel) + if missing: + dropped_required[f"relation:{rel_type}"] = ( + dropped_required.get(f"relation:{rel_type}", 0) + 1 + ) + continue + relations.append( ExtractedRelation( source=source, @@ -499,10 +702,21 @@ def _parse_step2_response( type=rel_type, description=description, source_chunk_ids=[source_chunk_id], + attributes=attributes_rel, **extra, ) ) + if dropped_required: + for key, n in dropped_required.items(): + logger.warning( + "Dropped %d %s record(s) with missing required attribute(s) " + "for chunk %s", + n, + key, + source_chunk_id, + ) + return entities, relations # ── Aggregation ────────────────────────────────────────────── @@ -538,12 +752,17 @@ def _aggregate_entities( for chunk_id, offsets in ent_spans.items(): existing_spans.setdefault(chunk_id, []).extend(offsets) existing.spans = existing_spans # type: ignore[attr-defined] + # Merge attributes: last-write-wins per key, matching FalkorDB's + # ``SET n += props`` semantics at the storage layer. + if ent.attributes: + existing.attributes.update(ent.attributes) else: seen[key] = ExtractedEntity( name=ent.name, type=ent.type, description=ent.description, source_chunk_ids=list(ent.source_chunk_ids), + attributes=dict(ent.attributes), **_optional_extras(ent), ) return list(seen.values()) @@ -577,6 +796,8 @@ def _aggregate_relations( for chunk_id, offsets in rel_spans.items(): existing_spans.setdefault(chunk_id, []).extend(offsets) existing.spans = existing_spans # type: ignore[attr-defined] + if rel.attributes: + existing.attributes.update(rel.attributes) else: seen[key] = ExtractedRelation( source=rel.source, @@ -584,6 +805,7 @@ def _aggregate_relations( type=rel.type, description=rel.description, source_chunk_ids=list(rel.source_chunk_ids), + attributes=dict(rel.attributes), **_optional_extras(rel), ) return list(seen.values()) @@ -609,6 +831,10 @@ def _entities_to_nodes( spans = getattr(ent, "spans", None) if spans: props["spans"] = spans + # Merge schema-declared attributes. Reserved-name collisions are + # rejected at schema-validation time, so update() is safe. + if ent.attributes: + props.update(ent.attributes) nodes.append( GraphNode( id=node_id, @@ -644,6 +870,8 @@ def _relations_to_relationships( spans = getattr(rel, "spans", None) if spans: props["spans"] = spans + if rel.attributes: + props.update(rel.attributes) relationships.append( GraphRelationship( start_node_id=compute_entity_id(rel.source, src_type), diff --git a/graphrag_sdk/src/graphrag_sdk/ingestion/pipeline.py b/graphrag_sdk/src/graphrag_sdk/ingestion/pipeline.py index 165cbeeb..4dbcf87e 100644 --- a/graphrag_sdk/src/graphrag_sdk/ingestion/pipeline.py +++ b/graphrag_sdk/src/graphrag_sdk/ingestion/pipeline.py @@ -188,6 +188,8 @@ async def run( # Step 5: Prune against schema ctx.log("Step 5/9: Pruning against schema") graph_data = self._prune(graph_data, self.schema) + graph_data = self._validate_attributes(graph_data, self.schema) + graph_data = self._filter_quality(graph_data) # Step 6: Resolve duplicate entities ctx.log("Step 6/9: Resolving duplicates") @@ -449,6 +451,119 @@ def _prune(self, graph_data: GraphData, schema: GraphSchema) -> GraphData: extracted_relations=graph_data.extracted_relations, ) + def _validate_attributes( + self, graph_data: GraphData, schema: GraphSchema + ) -> GraphData: + """Validate node / relationship attributes against the declared schema. + + - Attribute keys not declared in the schema are dropped (debug-logged). + - Nodes / relationships missing a ``required=True`` declared attribute + are dropped entirely, with a single aggregated warning per type. + - Skips ``"Unknown"`` nodes (the low-confidence label preserved by + :py:meth:`_prune`). + + Reserved SDK keys (``name``, ``description``, ``source_chunk_ids``, + ``spans``, ``rel_type``, ``fact``, ``src_name``, ``tgt_name``, + ``type``, ``id``) are always preserved. + """ + if not schema.entities and not schema.relations: + return graph_data + + from graphrag_sdk.core.models import RESERVED_PROPERTY_NAMES + + ent_declared: dict[str, dict[str, Any]] = { + e.label: {p.name: p for p in e.properties} for e in schema.entities + } + rel_declared: dict[str, dict[str, Any]] = { + r.label: {p.name: p for p in r.properties} for r in schema.relations + } + + kept_nodes: list[GraphNode] = [] + required_missing_by_label: dict[str, int] = {} + unknown_dropped_by_label: dict[str, int] = {} + for node in graph_data.nodes: + if node.label == "Unknown" or node.label not in ent_declared: + kept_nodes.append(node) + continue + declared = ent_declared[node.label] + # Drop unknown attribute keys (debug only — high volume). + unknown_keys = [ + k + for k in list(node.properties.keys()) + if k not in declared and k not in RESERVED_PROPERTY_NAMES + ] + for k in unknown_keys: + node.properties.pop(k, None) + unknown_dropped_by_label[node.label] = ( + unknown_dropped_by_label.get(node.label, 0) + 1 + ) + # Drop record on missing required attribute. + missing = [ + name + for name, prop in declared.items() + if prop.required and node.properties.get(name) is None + ] + if missing: + required_missing_by_label[node.label] = ( + required_missing_by_label.get(node.label, 0) + 1 + ) + continue + kept_nodes.append(node) + + for label, count in unknown_dropped_by_label.items(): + logger.debug( + "Dropped %d undeclared attribute key(s) on %s nodes", + count, + label, + ) + for label, count in required_missing_by_label.items(): + logger.warning( + "Dropped %d %s node(s) with missing required attribute(s)", + count, + label, + ) + + kept_rels: list[GraphRelationship] = [] + rel_required_missing_by_label: dict[str, int] = {} + for rel in graph_data.relationships: + rel_label = rel.properties.get("rel_type", rel.type) + if rel_label not in rel_declared: + kept_rels.append(rel) + continue + declared = rel_declared[rel_label] + for k in [ + k + for k in list(rel.properties.keys()) + if k not in declared and k not in RESERVED_PROPERTY_NAMES + ]: + rel.properties.pop(k, None) + missing = [ + name + for name, prop in declared.items() + if prop.required and rel.properties.get(name) is None + ] + if missing: + rel_required_missing_by_label[rel_label] = ( + rel_required_missing_by_label.get(rel_label, 0) + 1 + ) + continue + kept_rels.append(rel) + + for label, count in rel_required_missing_by_label.items(): + logger.warning( + "Dropped %d [%s] relationship(s) with missing required attribute(s)", + count, + label, + ) + + return GraphData( + nodes=kept_nodes, + relationships=kept_rels, + mentions=graph_data.mentions, + extracted_entities=graph_data.extracted_entities, + extracted_relations=graph_data.extracted_relations, + ) + def _filter_quality(self, graph_data: GraphData) -> GraphData: """Remove nodes with empty IDs or labels, and dangling relationships.""" valid_nodes = [n for n in graph_data.nodes if n.id and n.label] diff --git a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/cypher_generation.py b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/cypher_generation.py index 0dd3d289..cc7c89a7 100644 --- a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/cypher_generation.py +++ b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/cypher_generation.py @@ -12,6 +12,8 @@ import re from typing import Any +from graphrag_sdk.core.models import GraphSchema + logger = logging.getLogger(__name__) # ── Valid labels for our graph schema ──────────────────────────── @@ -36,31 +38,134 @@ _ALL_LABELS = _ENTITY_LABELS | _STRUCTURAL_LABELS +_RESERVED_ENTITY_PROPS = ( + ("name", "STRING", "entity name"), + ("description", "STRING", "entity description"), +) +_RESERVED_REL_PROPS = ( + ("rel_type", "STRING", "original relation type as a string"), + ("fact", "STRING", "evidence text for the relation"), + ("src_name", "STRING", "source entity name"), + ("tgt_name", "STRING", "target entity name"), +) + +_NUMERIC_TYPES = frozenset({"INTEGER", "FLOAT"}) + _WRITE_KEYWORDS = re.compile( r"\b(CREATE|DELETE|DETACH|SET|REMOVE|MERGE|DROP|CALL\s+db\.idx)\b", re.IGNORECASE, ) + +def _labels_from_schema(schema: GraphSchema | None) -> frozenset[str]: + """Return the entity labels declared in ``schema``, falling back to the + historical hardcoded set when the schema is empty (open-schema mode). + """ + if schema is None or not schema.entities: + return _ENTITY_LABELS + return frozenset(e.label for e in schema.entities) + + +def render_schema_block(schema: GraphSchema | None) -> str: + """Render a Markdown schema block listing declared entity / relation + properties, derived from the live ``GraphSchema``. + + Mirrors LangChain's ``Neo4jGraph.get_schema()`` and LlamaIndex's + ``Neo4jPropertyGraphStore.get_schema_str()``. Always emits the reserved + SDK keys (``name``, ``description`` on entities; ``rel_type``, ``fact``, + ``src_name``, ``tgt_name`` on RELATES) so the LLM still knows about them + even when the schema declares no custom properties. + """ + labels: list[str] + rel_labels: list[str] + ent_props: dict[str, list[tuple[str, str, str | None]]] = {} + rel_props: dict[str, list[tuple[str, str, str | None]]] = {} + + if schema is not None and schema.entities: + labels = [e.label for e in schema.entities] + for e in schema.entities: + ent_props[e.label] = [(p.name, p.type, p.description) for p in e.properties] + else: + labels = sorted(_ENTITY_LABELS) + + if schema is not None and schema.relations: + rel_labels = [r.label for r in schema.relations] + for r in schema.relations: + rel_props[r.label] = [(p.name, p.type, p.description) for p in r.properties] + else: + rel_labels = [] + + lines: list[str] = [] + lines.append("### Entity node labels (all entities also carry `__Entity__`):") + for label in labels: + lines.append(f"- {label}") + lines.append(" Properties:") + for name, typ, desc in _RESERVED_ENTITY_PROPS: + lines.append(f" - {name} ({typ}) — {desc}") + for name, typ, desc in ent_props.get(label, []): + d = f" — {desc}" if desc else "" + lines.append(f" - {name} ({typ}){d}") + + lines.append("") + lines.append("### Edge types:") + lines.append("- RELATES: connects any entity to any entity.") + lines.append(" Properties:") + for name, typ, desc in _RESERVED_REL_PROPS: + lines.append(f" - {name} ({typ}) — {desc}") + union_rel_props: dict[str, tuple[str, str | None]] = {} + for label in rel_labels: + for name, typ, desc in rel_props.get(label, []): + if name not in union_rel_props: + union_rel_props[name] = (typ, desc) + for name, (typ, desc) in union_rel_props.items(): + d = f" — {desc}" if desc else "" + lines.append(f" - {name} ({typ}){d} # declared on RELATES via rel_type filters") + if rel_labels: + lines.append( + " Allowed `rel_type` values: " + ", ".join(rel_labels) + ) + lines.append("- MENTIONED_IN: connects entity to Chunk node (provenance)") + lines.append("- PART_OF: connects Document to Chunk") + lines.append("- NEXT_CHUNK: connects Chunk to next sequential Chunk") + return "\n".join(lines) + + +def _render_attribute_examples(schema: GraphSchema | None) -> str: + """Synthesize one filter example per declared numeric attribute. + + Helps the LLM learn that custom numeric properties exist and can be used + in ``WHERE`` / ``ORDER BY`` / aggregations. Returns ``""`` when no + numeric attributes are declared. + """ + if schema is None: + return "" + examples: list[str] = [] + for et in schema.entities: + for p in et.properties: + if p.type in _NUMERIC_TYPES and len(examples) < 2: + var = et.label[0].lower() + examples.append( + f"Question: \"Which {et.label} has the highest {p.name}?\"\n" + f"```cypher\n" + f"MATCH ({var}:{et.label})\n" + f"WHERE {var}.{p.name} IS NOT NULL\n" + f"RETURN {var}.name AS name, {var}.{p.name} AS {p.name}\n" + f"ORDER BY {var}.{p.name} DESC\n" + f"LIMIT 10\n" + f"```" + ) + if not examples: + return "" + return "\n\n" + "\n\n".join(examples) + # ── Schema prompt ──────────────────────────────────────────────── -SCHEMA_PROMPT = """\ +_SCHEMA_PROMPT_TEMPLATE = """\ You are a Cypher query generator for a FalkorDB graph database. ## Graph Schema -### Entity node labels (all entities also carry the label `__Entity__`): -Person, Organization, Technology, Product, Location, Date, Event, Concept, Law, Dataset, Method - -### Entity node properties: -- name (string) — entity name -- description (string) — entity description - -### Edge types: -- RELATES: connects any entity to any entity. - Properties: rel_type (string), fact (string — evidence text), src_name (string), tgt_name (string) -- MENTIONED_IN: connects entity to Chunk node (provenance) -- PART_OF: connects Document to Chunk -- NEXT_CHUNK: connects Chunk to next sequential Chunk +{schema_block} ## FalkorDB-specific rules (CRITICAL — violating these causes execution errors): 1. Do NOT use shortestPath() or allShortestPaths() — FalkorDB returns @@ -127,7 +232,7 @@ MATCH (o:Organization)-[r:RELATES]-(t:Technology) RETURN o.name AS organization, t.name AS technology, r.rel_type AS relation, r.fact AS evidence LIMIT 20 -``` +```{attribute_examples} ## Your task @@ -139,6 +244,27 @@ """ +def build_schema_prompt(schema: GraphSchema | None, question: str) -> str: + """Build the full Cypher generation prompt for ``question`` from ``schema``. + + When ``schema`` is empty, the prompt falls back to the historical + hardcoded label set and matches today's behavior bit-for-bit aside from + the new schema block formatting. + """ + return _SCHEMA_PROMPT_TEMPLATE.format( + schema_block=render_schema_block(schema), + attribute_examples=_render_attribute_examples(schema), + question=question, + ) + + +# Backwards-compatible alias for callers that import SCHEMA_PROMPT. +# It exposes the template form (still expects ``{schema_block}``, +# ``{attribute_examples}``, and ``{question}`` placeholders) — direct +# ``.format(question=...)`` callers should migrate to ``build_schema_prompt``. +SCHEMA_PROMPT = _SCHEMA_PROMPT_TEMPLATE + + # ── Cypher extraction ──────────────────────────────────────────── @@ -184,12 +310,16 @@ def _sanitize_cypher(cypher: str) -> str: # ── Cypher validation ──────────────────────────────────────────── -def validate_cypher(cypher: str) -> list[str]: +def validate_cypher(cypher: str, schema: GraphSchema | None = None) -> list[str]: """Validate generated Cypher for safety and correctness. Uses an allowlist approach: the query must start with a read-only keyword, and dangerous constructs are explicitly rejected. + When ``schema`` is provided, label validation uses the labels declared + in the schema (plus structural labels); otherwise it falls back to the + historical hardcoded label set. + Returns list of error strings; empty list means valid. """ errors: list[str] = [] @@ -229,9 +359,10 @@ def validate_cypher(cypher: str) -> list[str]: errors.append("Missing RETURN clause") # Check referenced labels exist in schema + allowed_labels = _labels_from_schema(schema) | _STRUCTURAL_LABELS label_pattern = re.findall(r"\((?:\w+)?:(\w+)", cypher_norm) for label in label_pattern: - if label not in _ALL_LABELS: + if label not in allowed_labels: errors.append(f"Unknown label: {label}") return errors @@ -244,13 +375,17 @@ async def generate_cypher( llm: Any, question: str, *, + schema: GraphSchema | None = None, max_retries: int = 3, ) -> str | None: """Generate a Cypher query from a natural language question. + When ``schema`` is provided, the prompt and validator both use the + declared labels and properties. + Returns the Cypher string, or None if all retries fail. """ - prompt = SCHEMA_PROMPT.format(question=question) + prompt = build_schema_prompt(schema, question) last_error = "" for attempt in range(max_retries): @@ -270,7 +405,7 @@ async def generate_cypher( last_error = "Empty query generated" continue - errors = validate_cypher(cypher) + errors = validate_cypher(cypher, schema) if errors: last_error = "; ".join(errors) continue @@ -292,6 +427,7 @@ async def execute_cypher_retrieval( llm: Any, question: str, *, + schema: GraphSchema | None = None, max_retries: int = 3, ) -> tuple[list[str], dict[str, dict]]: """Full text-to-cypher retrieval: generate -> validate -> execute -> parse. @@ -306,7 +442,9 @@ async def execute_cypher_retrieval( On any failure, returns empty results (silent degradation). """ - cypher = await generate_cypher(llm, question, max_retries=max_retries) + cypher = await generate_cypher( + llm, question, schema=schema, max_retries=max_retries + ) if not cypher: return [], {} diff --git a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/multi_path.py b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/multi_path.py index d7bcbd9b..80e29b34 100644 --- a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/multi_path.py +++ b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/multi_path.py @@ -166,6 +166,7 @@ def __init__( rel_top_k: int = 15, # Matched to chunk_top_k for balanced fact/passage coverage keyword_limit: int = 10, # Fulltext search keyword budget from query decomposition enable_cypher: bool = False, # Text-to-Cypher path (experimental, off by default) + schema: Any | None = None, # GraphSchema; forwarded to Cypher generation when enable_cypher ) -> None: super().__init__(graph_store=graph_store, vector_store=vector_store) self._embedder = embedder @@ -176,6 +177,7 @@ def __init__( self._rel_top_k = rel_top_k self._keyword_limit = keyword_limit self._enable_cypher = enable_cypher + self._schema = schema # -- Template Method hook -- @@ -197,7 +199,9 @@ async def _execute( if self._enable_cypher: results = await asyncio.gather( search_relates_edges(self._vector, query_vector, self._rel_top_k), - execute_cypher_retrieval(self._graph, self._llm, query), + execute_cypher_retrieval( + self._graph, self._llm, query, schema=self._schema + ), return_exceptions=True, ) # Unpack RELATES results diff --git a/graphrag_sdk/src/graphrag_sdk/storage/__init__.py b/graphrag_sdk/src/graphrag_sdk/storage/__init__.py index 610abe39..7f33d890 100644 --- a/graphrag_sdk/src/graphrag_sdk/storage/__init__.py +++ b/graphrag_sdk/src/graphrag_sdk/storage/__init__.py @@ -3,6 +3,7 @@ from graphrag_sdk.storage.deduplicator import EntityDeduplicator from graphrag_sdk.storage.graph_store import GraphStore +from graphrag_sdk.storage.ontology_store import OntologyStore from graphrag_sdk.storage.vector_store import VectorStore -__all__ = ["EntityDeduplicator", "GraphStore", "VectorStore"] +__all__ = ["EntityDeduplicator", "GraphStore", "OntologyStore", "VectorStore"] diff --git a/graphrag_sdk/src/graphrag_sdk/storage/ontology_store.py b/graphrag_sdk/src/graphrag_sdk/storage/ontology_store.py new file mode 100644 index 00000000..71b7c8b5 --- /dev/null +++ b/graphrag_sdk/src/graphrag_sdk/storage/ontology_store.py @@ -0,0 +1,228 @@ +"""Persistent ontology storage in a dedicated FalkorDB graph. + +The ontology lives in a separate FalkorDB graph named ``__ontology`` +so it survives drops of the data graph and can be inspected via Cypher. + +Ingest passes call :py:meth:`OntologyStore.register` with the run's local schema; +each register call is an idempotent union into the persisted ontology. Retrieval +calls :py:meth:`OntologyStore.load` to fetch the **global** ontology (union of +every schema ever registered) and feeds it into the Cypher generation prompt. +""" + +from __future__ import annotations + +import logging +from typing import Any + +from graphrag_sdk.core.connection import FalkorDBConnection +from graphrag_sdk.core.models import ( + EntityType, + GraphSchema, + PropertyType, + RelationType, +) + +logger = logging.getLogger(__name__) + + +def _encode_patterns(patterns: list[tuple[str, str]]) -> list[str]: + return [f"{src}|{tgt}" for src, tgt in patterns] + + +def _decode_patterns(encoded: list[str] | None) -> list[tuple[str, str]]: + out: list[tuple[str, str]] = [] + for s in encoded or []: + if not isinstance(s, str) or "|" not in s: + continue + src, tgt = s.split("|", 1) + out.append((src, tgt)) + return out + + +class OntologyStore: + """Persists and loads :py:class:`GraphSchema` in a dedicated FalkorDB graph. + + The store owns its own graph handle, derived from the same FalkorDB driver + as the data-graph connection. Queries go directly to the FalkorDB driver + and bypass the connection's retry / circuit-breaker layer — ontology + operations are infrequent, idempotent, and tolerant of a single failure + (the caller can retry by re-registering). + """ + + ONTOLOGY_GRAPH_SUFFIX = "__ontology" + + def __init__(self, connection: FalkorDBConnection, data_graph_name: str) -> None: + self._conn = connection + self._graph_name = f"{data_graph_name}{self.ONTOLOGY_GRAPH_SUFFIX}" + self._graph: Any | None = None + + def _ensure_graph(self) -> Any: + if self._graph is not None: + return self._graph + self._conn._ensure_client() + driver = self._conn._driver + if driver is None: + raise RuntimeError("FalkorDB driver not initialised on connection") + self._graph = driver.select_graph(self._graph_name) + return self._graph + + @property + def graph_name(self) -> str: + return self._graph_name + + async def _query(self, cypher: str, params: dict[str, Any] | None = None) -> Any: + graph = self._ensure_graph() + return await graph.query(cypher, params=params) + + async def load(self) -> GraphSchema: + """Read the ontology graph and reconstruct a :py:class:`GraphSchema`. + + Returns an empty schema if the ontology graph does not yet exist. + """ + try: + ent_result = await self._query( + "MATCH (e:OntologyEntityType) " + "OPTIONAL MATCH (e)-[:HAS_PROPERTY]->(p:OntologyProperty) " + "RETURN e.label AS label, e.description AS description, " + "collect({name: p.name, type: p.type, description: p.description, " + "required: p.required}) AS properties" + ) + rel_result = await self._query( + "MATCH (r:OntologyRelationType) " + "OPTIONAL MATCH (r)-[:HAS_PROPERTY]->(p:OntologyProperty) " + "RETURN r.label AS label, r.description AS description, " + "r.patterns AS patterns, " + "collect({name: p.name, type: p.type, description: p.description, " + "required: p.required}) AS properties" + ) + except Exception as exc: + logger.debug("Ontology load failed (returning empty schema): %s", exc) + return GraphSchema() + + entities = [ + EntityType( + label=row[0], + description=row[1], + properties=_props_from_rows(row[2]), + ) + for row in (ent_result.result_set or []) + if row[0] + ] + relations = [ + RelationType( + label=row[0], + description=row[1], + patterns=_decode_patterns(row[2]), + properties=_props_from_rows(row[3]), + ) + for row in (rel_result.result_set or []) + if row[0] + ] + return GraphSchema(entities=entities, relations=relations) + + async def register(self, schema: GraphSchema) -> GraphSchema: + """Merge ``schema`` into the persisted ontology; return the new global ontology. + + Idempotent. ``MERGE`` keys on ``(label, name)``; descriptions/types use + last-write-wins; relation ``patterns`` are union-merged. + """ + if not schema.entities and not schema.relations: + return await self.load() + + for et in schema.entities: + await self._upsert_entity_type(et) + for rt in schema.relations: + await self._upsert_relation_type(rt) + + return await self.load() + + async def _upsert_entity_type(self, et: EntityType) -> None: + await self._query( + "MERGE (e:OntologyEntityType {label: $label}) " + "SET e.description = coalesce($description, e.description)", + {"label": et.label, "description": et.description}, + ) + for prop in et.properties: + await self._upsert_property(et.label, prop, owner_label="OntologyEntityType") + + async def _upsert_relation_type(self, rt: RelationType) -> None: + new_patterns = _encode_patterns(rt.patterns) + result = await self._query( + "MATCH (r:OntologyRelationType {label: $label}) " + "RETURN r.patterns AS patterns", + {"label": rt.label}, + ) + existing: list[str] = [] + if result.result_set: + existing = list(result.result_set[0][0] or []) + seen: set[str] = set() + merged: list[str] = [] + for s in existing + new_patterns: + if s not in seen: + seen.add(s) + merged.append(s) + await self._query( + "MERGE (r:OntologyRelationType {label: $label}) " + "SET r.description = coalesce($description, r.description), " + "r.patterns = $patterns", + {"label": rt.label, "description": rt.description, "patterns": merged}, + ) + for prop in rt.properties: + await self._upsert_property(rt.label, prop, owner_label="OntologyRelationType") + + async def _upsert_property( + self, owner_label_value: str, prop: PropertyType, *, owner_label: str + ) -> None: + await self._query( + f"MATCH (o:{owner_label} {{label: $owner}}) " + "MERGE (o)-[:HAS_PROPERTY]->(p:OntologyProperty {name: $name}) " + "SET p.type = $type, " + "p.description = coalesce($description, p.description), " + "p.required = $required", + { + "owner": owner_label_value, + "name": prop.name, + "type": prop.type, + "description": prop.description, + "required": prop.required, + }, + ) + + async def clear(self) -> None: + """Drop the ontology graph (``GRAPH.DELETE``). Idempotent for empty graphs.""" + self._conn._ensure_client() + from redis.asyncio import Redis + + redis: Redis = Redis(connection_pool=self._conn._pool) + try: + await redis.execute_command("GRAPH.DELETE", self._graph_name) + except Exception as exc: + if "empty" in str(exc).lower() or "invalid" in str(exc).lower(): + logger.debug("Ontology graph '%s' already empty", self._graph_name) + else: + raise + self._graph = None + + +def _props_from_rows(rows: list[Any] | None) -> list[PropertyType]: + """Convert a ``collect(...)`` result into ``PropertyType`` objects. + + FalkorDB ``OPTIONAL MATCH`` with ``collect`` yields a list containing one + null-keyed dict when there are no matches; we filter those out. + """ + out: list[PropertyType] = [] + for row in rows or []: + if not row or not isinstance(row, dict): + continue + name = row.get("name") + if not name: + continue + out.append( + PropertyType( + name=name, + type=row.get("type") or "STRING", + description=row.get("description"), + required=bool(row.get("required")), + ) + ) + return out diff --git a/graphrag_sdk/tests/test_attribute_prompt.py b/graphrag_sdk/tests/test_attribute_prompt.py new file mode 100644 index 00000000..0db5a57a --- /dev/null +++ b/graphrag_sdk/tests/test_attribute_prompt.py @@ -0,0 +1,209 @@ +"""Tests for the attribute-aware prompt and coercion helpers in +``graph_extraction.py``.""" +from __future__ import annotations + +import pytest + +from graphrag_sdk.core.models import ( + EntityType, + GraphSchema, + PropertyType, + RelationType, +) +from graphrag_sdk.ingestion.extraction_strategies.graph_extraction import ( + VERIFY_EXTRACT_RELS_PROMPT, + _DEFAULT_JSON_EXAMPLE, + _JSON_EXAMPLE_WITH_ATTRS, + _coerce_attribute_value, + _coerce_attributes, + _render_attribute_schema_block, + _schema_has_attributes, +) + + +# ── _render_attribute_schema_block ────────────────────────────── + + +class TestRenderAttributeSchemaBlock: + def test_empty_schema_renders_empty(self): + assert _render_attribute_schema_block(GraphSchema()) == "" + + def test_schema_with_no_attributes_renders_empty(self): + s = GraphSchema(entities=[EntityType(label="Person")]) + assert _render_attribute_schema_block(s) == "" + + def test_includes_only_types_with_declared_properties(self): + s = GraphSchema( + entities=[ + EntityType( + label="Person", + properties=[PropertyType(name="age", type="INTEGER")], + ), + EntityType(label="Company"), # no properties; should not appear + ], + ) + block = _render_attribute_schema_block(s) + assert "Person" in block + assert "age (INTEGER)" in block + # Bare Company without a colon shouldn't appear; the only "Company" + # in the rendered output would be as part of an entity bullet. + assert "- Company:" not in block + + def test_marks_required_attributes(self): + s = GraphSchema( + entities=[ + EntityType( + label="Person", + properties=[ + PropertyType(name="birth_date", type="DATE", required=True) + ], + ), + ], + ) + block = _render_attribute_schema_block(s) + assert "(required)" in block + + def test_renders_relation_attributes(self): + s = GraphSchema( + entities=[EntityType(label="Person"), EntityType(label="Company")], + relations=[ + RelationType( + label="WORKS_AT", + properties=[PropertyType(name="since", type="DATE")], + ), + ], + ) + block = _render_attribute_schema_block(s) + assert "Relation attributes" in block + assert "WORKS_AT" in block + assert "since (DATE)" in block + + +class TestSchemaHasAttributes: + def test_empty_schema(self): + assert _schema_has_attributes(GraphSchema()) is False + + def test_only_entity_attrs(self): + s = GraphSchema( + entities=[ + EntityType( + label="Person", + properties=[PropertyType(name="age", type="INTEGER")], + ) + ] + ) + assert _schema_has_attributes(s) is True + + def test_only_relation_attrs(self): + s = GraphSchema( + relations=[ + RelationType( + label="WORKS_AT", + properties=[PropertyType(name="since", type="DATE")], + ) + ] + ) + assert _schema_has_attributes(s) is True + + +# ── _coerce_attribute_value ────────────────────────────────────── + + +class TestCoerceAttributeValue: + @pytest.mark.parametrize( + "value,prop_type,expected", + [ + ("56", "INTEGER", (True, 56)), + (56.7, "INTEGER", (True, 56)), + ("abc", "INTEGER", (False, None)), + (None, "INTEGER", (False, None)), + ("1867-11-07", "DATE", (True, "1867-11-07")), + ("1867-11-07T12:00:00", "DATE", (True, "1867-11-07")), + ("not a date", "DATE", (False, None)), + ("yes", "BOOLEAN", (True, True)), + ("False", "BOOLEAN", (True, False)), + ("maybe", "BOOLEAN", (False, None)), + (1, "BOOLEAN", (True, True)), + ("hello", "STRING", (True, "hello")), + (" ", "STRING", (False, None)), + (42, "STRING", (True, "42")), + (["a", "b"], "LIST", (True, ["a", "b"])), + ("solo", "LIST", (True, ["solo"])), + ({"bad": True}, "LIST", (False, None)), + ], + ) + def test_matrix(self, value, prop_type, expected): + assert _coerce_attribute_value(value, prop_type) == expected + + +class TestCoerceAttributes: + def test_drops_unknown_and_optional_invalid(self): + declared = { + "age": PropertyType(name="age", type="INTEGER", required=False), + } + result, missing = _coerce_attributes( + {"age": "abc", "unknown": "x"}, declared + ) + # 'unknown' is not in declared so it never enters the result. + assert result == {} + # 'age' was optional, so coercion failure does NOT flag it as missing. + assert missing == [] + + def test_required_missing_when_value_absent(self): + declared = { + "birth_date": PropertyType(name="birth_date", type="DATE", required=True), + } + _, missing = _coerce_attributes({}, declared) + assert missing == ["birth_date"] + + def test_required_missing_when_value_uncoercible(self): + declared = { + "age": PropertyType(name="age", type="INTEGER", required=True), + } + _, missing = _coerce_attributes({"age": "abc"}, declared) + assert missing == ["age"] + + +# ── prompt template integration ────────────────────────────────── + + +class TestPromptTemplate: + def test_property_less_schema_keeps_block_empty(self): + s = GraphSchema(entities=[EntityType(label="Person")]) + prompt = VERIFY_EXTRACT_RELS_PROMPT.format( + entity_types="Person", + relation_patterns="", + attribute_schema_block=_render_attribute_schema_block(s), + relationship_type_instruction="- type: ...\n", + entities_json="[]", + text="...", + json_example=_DEFAULT_JSON_EXAMPLE, + ) + # No attribute-extraction section should appear. + assert "## Attribute extraction" not in prompt + # Original entity-shape example must be intact. + assert '"description": "..."' in prompt + + def test_schema_with_attributes_injects_section_and_example(self): + s = GraphSchema( + entities=[ + EntityType( + label="Person", + properties=[PropertyType(name="age", type="INTEGER")], + ), + ] + ) + prompt = VERIFY_EXTRACT_RELS_PROMPT.format( + entity_types="Person", + relation_patterns="", + attribute_schema_block=_render_attribute_schema_block(s), + relationship_type_instruction="- type: ...\n", + entities_json="[]", + text="...", + json_example=_JSON_EXAMPLE_WITH_ATTRS, + ) + assert "## Attribute extraction" in prompt + assert "age (INTEGER)" in prompt + # The example output schema must mention attributes so the LLM + # knows to include the field. + assert '"attributes"' in prompt diff --git a/graphrag_sdk/tests/test_cypher_generation.py b/graphrag_sdk/tests/test_cypher_generation.py index 58694158..2942287a 100644 --- a/graphrag_sdk/tests/test_cypher_generation.py +++ b/graphrag_sdk/tests/test_cypher_generation.py @@ -197,3 +197,95 @@ async def test_parses_result_rows(self): assert "Alice" in facts[0] assert "alice" in entities assert "bob" in entities + + +# ── Schema-aware prompt + validator ────────────────────────────── + + +class TestRenderSchemaBlock: + def test_open_schema_falls_back_to_historic_labels(self): + from graphrag_sdk.core.models import GraphSchema + from graphrag_sdk.retrieval.strategies.cypher_generation import ( + render_schema_block, + ) + block = render_schema_block(GraphSchema()) + assert "- Person" in block + assert "name (STRING)" in block + assert "rel_type (STRING)" in block + + def test_schema_block_includes_declared_attributes(self): + from graphrag_sdk.core.models import ( + EntityType, + GraphSchema, + PropertyType, + RelationType, + ) + from graphrag_sdk.retrieval.strategies.cypher_generation import ( + render_schema_block, + ) + s = GraphSchema( + entities=[ + EntityType( + label="Person", + properties=[ + PropertyType(name="age", type="INTEGER", description="years") + ], + ), + EntityType(label="Company"), + ], + relations=[ + RelationType( + label="WORKS_AT", + patterns=[("Person", "Company")], + properties=[PropertyType(name="since", type="DATE")], + ), + ], + ) + block = render_schema_block(s) + assert "age (INTEGER)" in block + assert "since (DATE)" in block + assert "WORKS_AT" in block + + +class TestBuildSchemaPrompt: + def test_includes_question_and_schema(self): + from graphrag_sdk.core.models import ( + EntityType, + GraphSchema, + PropertyType, + ) + from graphrag_sdk.retrieval.strategies.cypher_generation import ( + build_schema_prompt, + ) + s = GraphSchema( + entities=[ + EntityType( + label="Person", + properties=[PropertyType(name="age", type="INTEGER")], + ) + ] + ) + prompt = build_schema_prompt(s, "Who is older than 30?") + assert "Who is older than 30?" in prompt + assert "age (INTEGER)" in prompt + assert ".age" in prompt # synthesized numeric-filter example + + +class TestValidateCypherWithSchema: + def test_unknown_label_flagged_when_schema_provided(self): + from graphrag_sdk.core.models import EntityType, GraphSchema + s = GraphSchema(entities=[EntityType(label="Person")]) + errors = validate_cypher("MATCH (x:Bogus) RETURN x LIMIT 10", s) + assert any("Unknown label: Bogus" in e for e in errors) + + def test_declared_label_accepted(self): + from graphrag_sdk.core.models import EntityType, GraphSchema + s = GraphSchema(entities=[EntityType(label="Customer")]) + errors = validate_cypher( + "MATCH (c:Customer) RETURN c.name LIMIT 10", s + ) + assert errors == [] + + def test_no_schema_falls_back_to_historic_labels(self): + errors = validate_cypher("MATCH (p:Person) RETURN p LIMIT 10") + assert errors == [] diff --git a/graphrag_sdk/tests/test_graph_extraction.py b/graphrag_sdk/tests/test_graph_extraction.py index a7c77937..494ffe6d 100644 --- a/graphrag_sdk/tests/test_graph_extraction.py +++ b/graphrag_sdk/tests/test_graph_extraction.py @@ -712,3 +712,129 @@ def invoke(self, prompt, **kwargs): step2_prompt = captured_prompts[1] assert "## Allowed Relationships" in step2_prompt assert "WORKS_AT (Person \u2192 Company): Works at" in step2_prompt + + +class TestGraphExtractionSchemaAttributes: + """Coverage for schema-declared attributes flowing through extraction.""" + + def test_parse_step2_coerces_declared_attributes(self): + from graphrag_sdk.core.models import PropertyType + content = json.dumps({ + "entities": [ + { + "name": "Marie Curie", + "type": "Person", + "description": "Scientist", + "attributes": {"age": "56", "birth_date": "1867-11-07"}, + }, + ], + "relationships": [], + }) + schema = GraphSchema( + entities=[ + EntityType( + label="Person", + properties=[ + PropertyType(name="age", type="INTEGER"), + PropertyType(name="birth_date", type="DATE"), + ], + ) + ] + ) + ents, _ = GraphExtraction._parse_step2_response( + content, ["Person"], "c1", schema + ) + assert len(ents) == 1 + assert ents[0].attributes == {"age": 56, "birth_date": "1867-11-07"} + + def test_parse_step2_drops_entity_missing_required_attribute(self): + from graphrag_sdk.core.models import PropertyType + content = json.dumps({ + "entities": [ + { + "name": "Marie Curie", + "type": "Person", + "description": "", + "attributes": {"age": 56}, + }, + { + "name": "Pierre Curie", + "type": "Person", + "description": "", + "attributes": {"age": 39, "birth_date": "1859-05-15"}, + }, + ], + "relationships": [], + }) + schema = GraphSchema( + entities=[ + EntityType( + label="Person", + properties=[ + PropertyType(name="age", type="INTEGER"), + PropertyType(name="birth_date", type="DATE", required=True), + ], + ) + ] + ) + ents, _ = GraphExtraction._parse_step2_response( + content, ["Person"], "c1", schema + ) + assert [e.name for e in ents] == ["Pierre Curie"] + + def test_aggregator_carries_attributes_with_last_write_wins(self): + from graphrag_sdk.core.models import ExtractedEntity + e1 = ExtractedEntity( + name="Marie Curie", + type="Person", + description="", + source_chunk_ids=["c1"], + attributes={"age": 56, "country": "France"}, + ) + e2 = ExtractedEntity( + name="Marie Curie", + type="Person", + description="", + source_chunk_ids=["c2"], + attributes={"age": 58, "birth_year": 1867}, + ) + merged = GraphExtraction._aggregate_entities([e1, e2]) + assert len(merged) == 1 + # `age` is overwritten last-write-wins; the other keys are unioned. + assert merged[0].attributes == { + "age": 58, + "country": "France", + "birth_year": 1867, + } + + def test_attributes_merged_into_node_properties(self): + from graphrag_sdk.core.models import ExtractedEntity + ent = ExtractedEntity( + name="Marie Curie", + type="Person", + description="d", + source_chunk_ids=["c1"], + attributes={"age": 56, "birth_date": "1867-11-07"}, + ) + nodes = GraphExtraction._entities_to_nodes([ent]) + assert nodes[0].properties["age"] == 56 + assert nodes[0].properties["birth_date"] == "1867-11-07" + # Reserved keys still present. + assert nodes[0].properties["name"] == "Marie Curie" + + def test_property_less_schema_keeps_attributes_empty(self): + """Deterministic-extractor regression: schema declares no + properties \u2192 records have empty attributes \u2192 storage gets only the + reserved keys.""" + content = json.dumps({ + "entities": [ + {"name": "Alice", "type": "Person", "description": "A person"}, + ], + "relationships": [], + }) + schema = GraphSchema(entities=[EntityType(label="Person")]) + ents, _ = GraphExtraction._parse_step2_response( + content, ["Person"], "c1", schema + ) + assert len(ents) == 1 + assert ents[0].attributes == {} diff --git a/graphrag_sdk/tests/test_models.py b/graphrag_sdk/tests/test_models.py index b4fdd022..d4306127 100644 --- a/graphrag_sdk/tests/test_models.py +++ b/graphrag_sdk/tests/test_models.py @@ -136,11 +136,90 @@ def test_document_output(self): class TestSchemaTypes: def test_property_type(self): - pt = PropertyType(name="age", type="INT", required=True) + pt = PropertyType(name="age", type="INTEGER", required=True) assert pt.name == "age" - assert pt.type == "INT" + assert pt.type == "INTEGER" assert pt.required is True + def test_property_type_normalizes_case(self): + pt = PropertyType(name="age", type="integer") + assert pt.type == "INTEGER" + + def test_property_type_rejects_unknown_type(self): + with pytest.raises(ValidationError): + PropertyType(name="age", type="OBJECT") + + def test_relation_type_properties_default_empty(self): + rt = RelationType(label="WORKS_AT") + assert rt.properties == [] + + def test_relation_type_with_properties(self): + rt = RelationType( + label="WORKS_AT", + properties=[PropertyType(name="since", type="DATE")], + ) + assert len(rt.properties) == 1 + assert rt.properties[0].type == "DATE" + + def test_graph_schema_rejects_reserved_entity_property_name(self): + with pytest.raises(ValidationError): + GraphSchema( + entities=[ + EntityType( + label="Person", + properties=[PropertyType(name="name", type="STRING")], + ) + ] + ) + + def test_graph_schema_rejects_reserved_relation_property_name(self): + with pytest.raises(ValidationError): + GraphSchema( + relations=[ + RelationType( + label="WORKS_AT", + properties=[PropertyType(name="rel_type", type="STRING")], + ) + ] + ) + + def test_graph_schema_merge_unions_entities_relations_and_properties(self): + a = GraphSchema( + entities=[ + EntityType( + label="Person", + properties=[PropertyType(name="age", type="INTEGER")], + ) + ], + relations=[ + RelationType(label="WORKS_AT", patterns=[("Person", "Company")]), + ], + ) + b = GraphSchema( + entities=[ + EntityType( + label="Person", + properties=[PropertyType(name="dob", type="DATE")], + ), + EntityType(label="Company"), + ], + relations=[ + RelationType( + label="WORKS_AT", + patterns=[("Person", "Organization")], + properties=[PropertyType(name="since", type="DATE")], + ), + ], + ) + merged = a.merge(b) + assert {e.label for e in merged.entities} == {"Person", "Company"} + person = next(e for e in merged.entities if e.label == "Person") + assert {p.name for p in person.properties} == {"age", "dob"} + works = next(r for r in merged.relations if r.label == "WORKS_AT") + assert ("Person", "Company") in works.patterns + assert ("Person", "Organization") in works.patterns + assert {p.name for p in works.properties} == {"since"} + def test_entity_type(self): et = EntityType(label="Person", description="A human") assert et.label == "Person" diff --git a/graphrag_sdk/tests/test_ontology_store.py b/graphrag_sdk/tests/test_ontology_store.py new file mode 100644 index 00000000..b12ac427 --- /dev/null +++ b/graphrag_sdk/tests/test_ontology_store.py @@ -0,0 +1,226 @@ +"""Tests for storage/ontology_store.py — persistent ontology graph layer. + +The store talks to FalkorDB directly; unit tests here mock the graph handle +through ``FalkorDBConnection``'s private ``_driver.select_graph()`` seam. +Real-FalkorDB exercise is left for the integration suite. +""" +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from graphrag_sdk.core.models import ( + EntityType, + GraphSchema, + PropertyType, + RelationType, +) +from graphrag_sdk.storage.ontology_store import ( + OntologyStore, + _decode_patterns, + _encode_patterns, +) + + +class _FakeQueryResult: + """Stand-in for falkordb's QueryResult.""" + + def __init__(self, rows: list[list]): + self.result_set = rows + + +class _FakeGraph: + """In-memory ``AsyncGraph`` substitute. + + Captures every ``query()`` call (cypher + params) so tests can assert on + them, and serves canned responses for the load() queries. + """ + + def __init__(self): + self.calls: list[tuple[str, dict | None]] = [] + self._ent_rows: list[list] = [] + self._rel_rows: list[list] = [] + self._patterns_for_label: dict[str, list[str]] = {} + + def set_load_response(self, entity_rows, relation_rows): + self._ent_rows = entity_rows + self._rel_rows = relation_rows + + def set_existing_patterns(self, label: str, patterns: list[str]): + self._patterns_for_label[label] = patterns + + async def query(self, cypher: str, params: dict | None = None): + self.calls.append((cypher, params)) + if "MATCH (e:OntologyEntityType)" in cypher and "collect" in cypher: + return _FakeQueryResult(self._ent_rows) + if "MATCH (r:OntologyRelationType)" in cypher and "collect" in cypher: + return _FakeQueryResult(self._rel_rows) + if "MATCH (r:OntologyRelationType {label: $label})" in cypher and "patterns" in cypher: + label = (params or {}).get("label", "") + return _FakeQueryResult([[self._patterns_for_label.get(label, [])]]) + return _FakeQueryResult([]) + + +@pytest.fixture +def fake_graph(): + return _FakeGraph() + + +@pytest.fixture +def store_factory(fake_graph): + """Returns a callable producing an OntologyStore whose `_query` is wired + to ``fake_graph``. Skips the real driver/connection plumbing entirely.""" + + def _make(data_graph_name: str = "kg") -> OntologyStore: + conn = MagicMock() + conn._ensure_client = MagicMock() + conn._driver = SimpleNamespace( + select_graph=MagicMock(return_value=fake_graph), + ) + return OntologyStore(conn, data_graph_name) + + return _make + + +class TestEncoders: + def test_encode_then_decode_roundtrip(self): + patterns = [("Person", "Company"), ("Person", "Organization")] + encoded = _encode_patterns(patterns) + assert encoded == ["Person|Company", "Person|Organization"] + assert _decode_patterns(encoded) == patterns + + def test_decode_handles_none_and_bad_strings(self): + assert _decode_patterns(None) == [] + assert _decode_patterns(["no-pipe", "a|b"]) == [("a", "b")] + + +class TestOntologyStoreGraphName: + def test_suffix(self, store_factory): + store = store_factory("my_kg") + assert store.graph_name == "my_kg__ontology" + + +class TestOntologyStoreRegister: + @pytest.mark.asyncio + async def test_empty_schema_short_circuits_to_load(self, store_factory, fake_graph): + store = store_factory() + result = await store.register(GraphSchema()) + # Only the two load queries should have been issued. + ents_q = [c for c in fake_graph.calls if "MATCH (e:OntologyEntityType)" in c[0]] + assert len(ents_q) == 1 + assert isinstance(result, GraphSchema) + + @pytest.mark.asyncio + async def test_registers_entity_type_and_its_properties(self, store_factory, fake_graph): + store = store_factory() + schema = GraphSchema( + entities=[ + EntityType( + label="Person", + description="A human", + properties=[ + PropertyType(name="age", type="INTEGER"), + PropertyType(name="birth_date", type="DATE", required=True), + ], + ), + ], + ) + await store.register(schema) + # One MERGE for the entity-type node, plus one MERGE per property edge. + entity_merges = [ + c for c in fake_graph.calls if "MERGE (e:OntologyEntityType" in c[0] + ] + property_merges = [ + c for c in fake_graph.calls if "MERGE (o)-[:HAS_PROPERTY]->" in c[0] + ] + assert len(entity_merges) == 1 + assert len(property_merges) == 2 + # Property params carry the declared type and required flag. + prop_names = {(c[1] or {}).get("name") for c in property_merges} + assert prop_names == {"age", "birth_date"} + + @pytest.mark.asyncio + async def test_unions_relation_patterns_with_existing( + self, store_factory, fake_graph + ): + store = store_factory() + fake_graph.set_existing_patterns("WORKS_AT", ["Person|Company"]) + schema = GraphSchema( + entities=[EntityType(label="Person"), EntityType(label="Org")], + relations=[ + RelationType(label="WORKS_AT", patterns=[("Person", "Org")]), + ], + ) + await store.register(schema) + rel_set = [ + c + for c in fake_graph.calls + if "MERGE (r:OntologyRelationType {label: $label})" in c[0] and "SET r." in c[0] + ] + assert rel_set, "expected a SET on the RelationType node" + params = rel_set[-1][1] or {} + assert "Person|Company" in params["patterns"] + assert "Person|Org" in params["patterns"] + + +class TestOntologyStoreLoad: + @pytest.mark.asyncio + async def test_empty_graph_returns_empty_schema(self, store_factory, fake_graph): + store = store_factory() + fake_graph.set_load_response([], []) + result = await store.load() + assert result.entities == [] + assert result.relations == [] + + @pytest.mark.asyncio + async def test_reconstructs_schema_from_query_rows(self, store_factory, fake_graph): + store = store_factory() + fake_graph.set_load_response( + entity_rows=[ + [ + "Person", + "A human", + [ + { + "name": "age", + "type": "INTEGER", + "description": None, + "required": False, + }, + # collect() of an OPTIONAL MATCH that found nothing + # may produce an all-None dict — must be filtered. + { + "name": None, + "type": None, + "description": None, + "required": None, + }, + ], + ], + ["Company", None, []], + ], + relation_rows=[ + [ + "WORKS_AT", + "Employment", + ["Person|Company"], + [ + { + "name": "since", + "type": "DATE", + "description": None, + "required": False, + } + ], + ], + ], + ) + schema = await store.load() + assert {e.label for e in schema.entities} == {"Person", "Company"} + person = next(e for e in schema.entities if e.label == "Person") + assert {p.name for p in person.properties} == {"age"} + works = next(r for r in schema.relations if r.label == "WORKS_AT") + assert works.patterns == [("Person", "Company")] + assert {p.name for p in works.properties} == {"since"} diff --git a/graphrag_sdk/tests/test_pipeline.py b/graphrag_sdk/tests/test_pipeline.py index 3945fcf1..3f04161d 100644 --- a/graphrag_sdk/tests/test_pipeline.py +++ b/graphrag_sdk/tests/test_pipeline.py @@ -694,3 +694,144 @@ def test_prune_pattern_mismatch_sample_is_bounded(self, caplog): # Total count is reported, but the sampled list does not contain 50 entries. assert "Pruned 50" in msg assert msg.count("('Company', 'Person')") <= 3 + + +class TestValidateAttributes: + """Schema-attribute validation pass on top of ``_prune``.""" + + def _pipeline(self): + # We only exercise the validator; bypass construction by mocking the + # required strategy collaborators. + return IngestionPipeline( + loader=MagicMock(spec=LoaderStrategy), + chunker=MagicMock(spec=ChunkingStrategy), + extractor=MagicMock(spec=ExtractionStrategy), + resolver=MagicMock(spec=ResolutionStrategy), + graph_store=MagicMock(), + vector_store=MagicMock(), + schema=GraphSchema(), + ) + + def test_empty_schema_is_noop(self): + from graphrag_sdk.core.models import PropertyType + pipe = self._pipeline() + gd = GraphData( + nodes=[GraphNode(id="1", label="Foo", properties={"a": 1})], + ) + out = pipe._validate_attributes(gd, GraphSchema()) + assert len(out.nodes) == 1 + assert out.nodes[0].properties == {"a": 1} + + def test_unknown_attributes_dropped_on_declared_label(self): + from graphrag_sdk.core.models import PropertyType + pipe = self._pipeline() + schema = GraphSchema( + entities=[ + EntityType( + label="Person", + properties=[PropertyType(name="age", type="INTEGER")], + ) + ] + ) + node = GraphNode( + id="p1", + label="Person", + properties={ + "name": "Marie", + "age": 56, + "ssn": "123-45-6789", + }, + ) + out = pipe._validate_attributes(GraphData(nodes=[node]), schema) + assert len(out.nodes) == 1 + kept = out.nodes[0].properties + assert "ssn" not in kept + assert kept["age"] == 56 + # Reserved keys preserved. + assert kept["name"] == "Marie" + + def test_missing_required_drops_node(self, caplog): + import logging + from graphrag_sdk.core.models import PropertyType + pipe = self._pipeline() + schema = GraphSchema( + entities=[ + EntityType( + label="Person", + properties=[ + PropertyType(name="age", type="INTEGER"), + PropertyType(name="birth_date", type="DATE", required=True), + ], + ) + ] + ) + nodes = [ + GraphNode( + id="ok", + label="Person", + properties={"name": "Marie", "age": 56, "birth_date": "1867-11-07"}, + ), + GraphNode( + id="missing", + label="Person", + properties={"name": "Pierre", "age": 39}, + ), + ] + with caplog.at_level(logging.WARNING, logger="graphrag_sdk.ingestion.pipeline"): + out = pipe._validate_attributes(GraphData(nodes=nodes), schema) + assert [n.id for n in out.nodes] == ["ok"] + assert any( + "missing required attribute" in r.getMessage() + for r in caplog.records + if r.levelno == logging.WARNING + ) + + def test_unknown_node_label_passes_through(self): + """``"Unknown"`` typed nodes (preserved by ``_prune``) must not be + validated against any declared properties.""" + from graphrag_sdk.core.models import PropertyType + pipe = self._pipeline() + schema = GraphSchema( + entities=[ + EntityType( + label="Person", + properties=[ + PropertyType(name="age", type="INTEGER", required=True) + ], + ) + ] + ) + nodes = [GraphNode(id="u", label="Unknown", properties={"name": "x"})] + out = pipe._validate_attributes(GraphData(nodes=nodes), schema) + assert len(out.nodes) == 1 + assert out.nodes[0].id == "u" + + def test_drops_relationship_missing_required_attribute(self): + from graphrag_sdk.core.models import PropertyType + pipe = self._pipeline() + schema = GraphSchema( + entities=[EntityType(label="Person"), EntityType(label="Company")], + relations=[ + RelationType( + label="WORKS_AT", + properties=[PropertyType(name="since", type="DATE", required=True)], + ) + ], + ) + rels = [ + GraphRelationship( + start_node_id="a", + end_node_id="b", + type="RELATES", + properties={"rel_type": "WORKS_AT", "fact": "f1", "since": "2020-01-01"}, + ), + GraphRelationship( + start_node_id="a", + end_node_id="c", + type="RELATES", + properties={"rel_type": "WORKS_AT", "fact": "f2"}, + ), + ] + out = pipe._validate_attributes(GraphData(relationships=rels), schema) + kept = [r.properties.get("fact") for r in out.relationships] + assert kept == ["f1"] From 54e4e05a6397cb20e8e9a2e96292fb2b7e1f4f03 Mon Sep 17 00:00:00 2001 From: Gal Shubeli Date: Mon, 18 May 2026 09:39:31 +0300 Subject: [PATCH 2/7] refactor(ontology): infer from data graph, drop dedicated ontology graph MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The data graph already contains labels, relationship types, properties, and endpoint patterns — querying it is the source of truth, not a parallel persisted graph. Schema-as-config is supported via plain JSON files. - storage/ontology_store: rewritten as inference-only. OntologyStore.infer() calls db.labels() / db.relationshipTypes(), samples keys()+typeof() per label, derives RELATES sub-types from rel_type values, and pulls endpoint patterns. Structural labels/edges and reserved property keys are filtered out so they never leak into the LLM-facing schema. - core/models: GraphSchema.from_file / save_to_file convenience methods for the schema-as-config workflow (versioned JSON, hand-editable, shareable). - api/main: GraphRAG.get_ontology() returns inferred ∪ self.schema so user- declared descriptions and required flags survive (and properties declared but not yet extracted still appear in the retrieval prompt). New save_ontology(path) writes the global ontology to a JSON file. Ingest no longer writes to a separate ontology graph. The previous PR introduced a persistent __ontology FalkorDB graph; that was solving "declared property with zero instances" — a case that is essentially invisible to retrieval (empty results either way). The runtime cost was a second graph to back up, clean up, and reason about. Inference + an opt-in JSON file covers the same ground in less code. Tests: 830 passed, 23 skipped. test_ontology_store rewritten around mocked-driver introspection. test_models gains a JSON round-trip test. Co-Authored-By: Claude Opus 4.7 (1M context) --- graphrag_sdk/src/graphrag_sdk/api/main.py | 64 +-- graphrag_sdk/src/graphrag_sdk/core/models.py | 18 + .../graphrag_sdk/storage/ontology_store.py | 372 +++++++++--------- graphrag_sdk/tests/test_models.py | 30 ++ graphrag_sdk/tests/test_ontology_store.py | 362 ++++++++--------- 5 files changed, 438 insertions(+), 408 deletions(-) diff --git a/graphrag_sdk/src/graphrag_sdk/api/main.py b/graphrag_sdk/src/graphrag_sdk/api/main.py index 1524108b..0698f1c9 100644 --- a/graphrag_sdk/src/graphrag_sdk/api/main.py +++ b/graphrag_sdk/src/graphrag_sdk/api/main.py @@ -193,12 +193,14 @@ def __init__( # Deduplication engine self._deduplicator = EntityDeduplicator(self._graph_store, self.embedder) - # Persistent ontology lives in `__ontology`. Stays in sync - # across multiple ingest passes that may declare different schemas. - self._ontology_store = OntologyStore(self._conn, self._conn.config.graph_name) - # Global ontology used at retrieval time. Initially the user-supplied - # local schema; refreshed from the persisted ontology after each ingest - # (see refresh_ontology()) and on demand. + # Ontology is inferred from the live data graph (no separate + # persistent graph). User-supplied ``schema`` is merged on top at + # retrieval time so declared descriptions / required flags / not-yet- + # extracted properties survive. + self._ontology_store = OntologyStore(self._conn) + # Global ontology used at retrieval time. Initially just the local + # schema; refresh_ontology() merges in the inferred view on demand + # and after each ingest. self._global_schema: GraphSchema = self.schema # Default retrieval strategy @@ -244,27 +246,28 @@ async def close(self) -> None: # ── Ontology ───────────────────────────────────────────────── async def get_ontology(self) -> GraphSchema: - """Return the persisted **global** ontology (union of every schema - ever registered against this graph). + """Return the **global** ontology used at retrieval time. - Reads from the dedicated ``__ontology`` graph. Returns an - empty schema before the first ingest has run. + The global ontology is the union of: + - the ontology inferred from the live data graph + (``db.labels()`` / ``db.relationshipTypes()`` + property sampling), and + - the user-supplied ``schema`` passed to :py:class:`GraphRAG` + (carries descriptions, ``required`` flags, properties not yet + present in the data). + + Returns an empty schema when both are empty. """ - return await self._ontology_store.load() + inferred = await self._ontology_store.infer() + return inferred.merge(self.schema) if inferred.entities or inferred.relations else self.schema async def refresh_ontology(self) -> GraphSchema: - """Reload the global ontology and propagate it to the retrieval path. + """Recompute the global ontology and propagate it to the retrieval path. Called automatically after each ``ingest()``. Call explicitly when - another process has registered new schema and you want the next - retrieval to see it without re-ingesting first. + the data graph has changed under your feet (concurrent writer, + manual edits) and the next retrieval should see it. """ - loaded = await self._ontology_store.load() - if loaded.entities or loaded.relations: - self._global_schema = loaded - else: - # No persisted ontology yet — keep the user-provided local schema. - self._global_schema = self.schema + self._global_schema = await self.get_ontology() # Best-effort propagation to the retrieval strategy. Built-in # MultiPathRetrieval honours this; custom strategies opt in by exposing # a ``_schema`` attribute or accepting it via constructor. @@ -272,6 +275,16 @@ async def refresh_ontology(self) -> GraphSchema: self._retrieval_strategy._schema = self._global_schema return self._global_schema + async def save_ontology(self, path: str, *, indent: int = 2) -> None: + """Write the current global ontology to ``path`` as JSON. + + Convenience for the schema-as-config workflow: ``rag.save_ontology( + "ontology.json")``, hand-edit / version-control it, then load with + ``GraphSchema.from_file("ontology.json")`` on the next run. + """ + ontology = await self.get_ontology() + ontology.save_to_file(path, indent=indent) + # ── Graph admin ────────────────────────────────────────────── async def get_statistics(self) -> dict[str, Any]: @@ -568,17 +581,6 @@ async def _ingest_single( doc_info = DocumentInfo(uid=resolved_id, path=path_for_node) - # Register this run's local schema into the persisted ontology so the - # global ontology is the union of every schema ever registered. The - # local schema continues to drive *this* run's extraction. - if self.schema.entities or self.schema.relations: - try: - await self._ontology_store.register(self.schema) - except Exception as exc: - logger.warning( - "Ontology registration failed (continuing ingest): %s", exc - ) - pipeline = IngestionPipeline( loader=loader or TextLoader(), chunker=chunker or FixedSizeChunking(), diff --git a/graphrag_sdk/src/graphrag_sdk/core/models.py b/graphrag_sdk/src/graphrag_sdk/core/models.py index 86119263..c2dc7937 100644 --- a/graphrag_sdk/src/graphrag_sdk/core/models.py +++ b/graphrag_sdk/src/graphrag_sdk/core/models.py @@ -286,6 +286,24 @@ def _validate_schema(self) -> GraphSchema: ) return self + @classmethod + def from_file(cls, path: str) -> GraphSchema: + """Load a ``GraphSchema`` from a JSON file. + + The schema-as-config workflow: keep the canonical schema in a JSON + file under version control, load it into the SDK with one call. See + :py:meth:`save_to_file` for the reverse direction. + """ + from pathlib import Path + + return cls.model_validate_json(Path(path).read_text(encoding="utf-8")) + + def save_to_file(self, path: str, *, indent: int = 2) -> None: + """Write this schema to ``path`` as JSON (overwrites existing files).""" + from pathlib import Path + + Path(path).write_text(self.model_dump_json(indent=indent), encoding="utf-8") + def merge(self, other: GraphSchema) -> GraphSchema: """Return a new ``GraphSchema`` that is the union of ``self`` and ``other``. diff --git a/graphrag_sdk/src/graphrag_sdk/storage/ontology_store.py b/graphrag_sdk/src/graphrag_sdk/storage/ontology_store.py index 71b7c8b5..7a424762 100644 --- a/graphrag_sdk/src/graphrag_sdk/storage/ontology_store.py +++ b/graphrag_sdk/src/graphrag_sdk/storage/ontology_store.py @@ -1,12 +1,18 @@ -"""Persistent ontology storage in a dedicated FalkorDB graph. +"""Ontology inference from the live data graph. -The ontology lives in a separate FalkorDB graph named ``__ontology`` -so it survives drops of the data graph and can be inspected via Cypher. +The schema is **derived** from what's in the data graph, not maintained in a +separate persistent graph. This keeps the architecture honest: the source of +truth for "what entities and relations exist" is the data itself. -Ingest passes call :py:meth:`OntologyStore.register` with the run's local schema; -each register call is an idempotent union into the persisted ontology. Retrieval -calls :py:meth:`OntologyStore.load` to fetch the **global** ontology (union of -every schema ever registered) and feeds it into the Cypher generation prompt. +Two consumers: +- Retrieval reads the inferred schema each session to build the Cypher prompt. +- ``GraphRAG.get_ontology()`` returns it for inspection. + +Users who want a curated, declarative schema (descriptions, required flags, +not-yet-extracted properties) pass a ``local_schema`` to ``GraphRAG`` — it's +unioned with the inferred schema at retrieval time so declared metadata +survives. ``GraphSchema.save_to_file`` / ``GraphSchema.from_file`` cover the +schema-as-config workflow. """ from __future__ import annotations @@ -16,6 +22,7 @@ from graphrag_sdk.core.connection import FalkorDBConnection from graphrag_sdk.core.models import ( + RESERVED_PROPERTY_NAMES, EntityType, GraphSchema, PropertyType, @@ -25,204 +32,217 @@ logger = logging.getLogger(__name__) -def _encode_patterns(patterns: list[tuple[str, str]]) -> list[str]: - return [f"{src}|{tgt}" for src, tgt in patterns] +# Labels created by the SDK that are not user entities. +_STRUCTURAL_LABELS: frozenset[str] = frozenset({"Chunk", "Document", "__Entity__"}) +# Edge labels created by the SDK that are not user relations. +_STRUCTURAL_REL_TYPES: frozenset[str] = frozenset( + {"PART_OF", "NEXT_CHUNK", "MENTIONED_IN"} +) + +# Property keys we never want to expose to the LLM as "custom attributes". +# These are SDK-internal or reserved meanings; the Cypher prompt already +# emits the reserved ones it cares about (``name``, ``description``, etc.). +_INFER_SKIP_KEYS: frozenset[str] = RESERVED_PROPERTY_NAMES | frozenset( + {"content_hash", "path", "text", "uid", "index", "metadata", "embedding"} +) -def _decode_patterns(encoded: list[str] | None) -> list[tuple[str, str]]: - out: list[tuple[str, str]] = [] - for s in encoded or []: - if not isinstance(s, str) or "|" not in s: - continue - src, tgt = s.split("|", 1) - out.append((src, tgt)) - return out + +# FalkorDB ``typeof()`` returns lowercase strings; map to our PropertyType vocabulary. +_TYPE_MAP: dict[str, str] = { + "string": "STRING", + "integer": "INTEGER", + "double": "FLOAT", + "float": "FLOAT", + "boolean": "BOOLEAN", + "array": "LIST", + "list": "LIST", +} + + +def _normalize_type(raw: str | None) -> str | None: + if not raw: + return None + return _TYPE_MAP.get(raw.strip().lower()) class OntologyStore: - """Persists and loads :py:class:`GraphSchema` in a dedicated FalkorDB graph. + """Infers the working :py:class:`GraphSchema` from the data graph. - The store owns its own graph handle, derived from the same FalkorDB driver - as the data-graph connection. Queries go directly to the FalkorDB driver - and bypass the connection's retry / circuit-breaker layer — ontology - operations are infrequent, idempotent, and tolerant of a single failure - (the caller can retry by re-registering). + No separate FalkorDB graph: this introspects the data graph directly via + ``db.labels()`` / ``db.relationshipTypes()`` plus per-label sampling of + property keys and types. """ - ONTOLOGY_GRAPH_SUFFIX = "__ontology" - - def __init__(self, connection: FalkorDBConnection, data_graph_name: str) -> None: + def __init__(self, connection: FalkorDBConnection) -> None: self._conn = connection - self._graph_name = f"{data_graph_name}{self.ONTOLOGY_GRAPH_SUFFIX}" - self._graph: Any | None = None - - def _ensure_graph(self) -> Any: - if self._graph is not None: - return self._graph - self._conn._ensure_client() - driver = self._conn._driver - if driver is None: - raise RuntimeError("FalkorDB driver not initialised on connection") - self._graph = driver.select_graph(self._graph_name) - return self._graph - - @property - def graph_name(self) -> str: - return self._graph_name - - async def _query(self, cypher: str, params: dict[str, Any] | None = None) -> Any: - graph = self._ensure_graph() - return await graph.query(cypher, params=params) - - async def load(self) -> GraphSchema: - """Read the ontology graph and reconstruct a :py:class:`GraphSchema`. - - Returns an empty schema if the ontology graph does not yet exist. + + async def infer(self, *, sample_size: int = 200) -> GraphSchema: + """Build a :py:class:`GraphSchema` from what is currently in the data graph. + + ``sample_size`` caps the per-label scan used to discover property keys + and types. Returns an empty schema on any introspection failure. """ try: - ent_result = await self._query( - "MATCH (e:OntologyEntityType) " - "OPTIONAL MATCH (e)-[:HAS_PROPERTY]->(p:OntologyProperty) " - "RETURN e.label AS label, e.description AS description, " - "collect({name: p.name, type: p.type, description: p.description, " - "required: p.required}) AS properties" - ) - rel_result = await self._query( - "MATCH (r:OntologyRelationType) " - "OPTIONAL MATCH (r)-[:HAS_PROPERTY]->(p:OntologyProperty) " - "RETURN r.label AS label, r.description AS description, " - "r.patterns AS patterns, " - "collect({name: p.name, type: p.type, description: p.description, " - "required: p.required}) AS properties" - ) + labels_result = await self._conn.query("CALL db.labels()") + rel_types_result = await self._conn.query("CALL db.relationshipTypes()") except Exception as exc: - logger.debug("Ontology load failed (returning empty schema): %s", exc) + logger.debug("Ontology inference: labels/types query failed: %s", exc) return GraphSchema() + labels: list[str] = [ + row[0] + for row in (labels_result.result_set or []) + if row and row[0] and row[0] not in _STRUCTURAL_LABELS + ] + rel_types: list[str] = [ + row[0] + for row in (rel_types_result.result_set or []) + if row and row[0] and row[0] not in _STRUCTURAL_REL_TYPES + ] + entities = [ EntityType( - label=row[0], - description=row[1], - properties=_props_from_rows(row[2]), - ) - for row in (ent_result.result_set or []) - if row[0] - ] - relations = [ - RelationType( - label=row[0], - description=row[1], - patterns=_decode_patterns(row[2]), - properties=_props_from_rows(row[3]), + label=label, + properties=await self._properties_for_node(label, sample_size), ) - for row in (rel_result.result_set or []) - if row[0] + for label in labels ] + # The unified data model writes every user relation as a ``RELATES`` + # edge whose ``rel_type`` property carries the original label; the + # SDK's structural edges (PART_OF/NEXT_CHUNK/MENTIONED_IN) are excluded + # above. We surface the distinct ``rel_type`` values as RelationTypes + # so the Cypher prompt knows the allowed values, and expose their + # property keys + endpoint patterns. + relations: list[RelationType] = [] + if "RELATES" in rel_types: + relations = await self._infer_relates_subtypes(sample_size) + return GraphSchema(entities=entities, relations=relations) - async def register(self, schema: GraphSchema) -> GraphSchema: - """Merge ``schema`` into the persisted ontology; return the new global ontology. + async def _properties_for_node( + self, label: str, sample_size: int + ) -> list[PropertyType]: + try: + result = await self._conn.query( + f"MATCH (n:`{label}`) " + "WITH n LIMIT $limit " + "UNWIND keys(n) AS k " + "WITH k, typeof(n[k]) AS t " + "RETURN k AS key, t AS type, count(*) AS c " + "ORDER BY c DESC", + {"limit": sample_size}, + ) + except Exception as exc: + logger.debug( + "Ontology inference: properties query failed for %s: %s", label, exc + ) + return [] + return _props_from_rows(result.result_set) - Idempotent. ``MERGE`` keys on ``(label, name)``; descriptions/types use - last-write-wins; relation ``patterns`` are union-merged. - """ - if not schema.entities and not schema.relations: - return await self.load() - - for et in schema.entities: - await self._upsert_entity_type(et) - for rt in schema.relations: - await self._upsert_relation_type(rt) - - return await self.load() - - async def _upsert_entity_type(self, et: EntityType) -> None: - await self._query( - "MERGE (e:OntologyEntityType {label: $label}) " - "SET e.description = coalesce($description, e.description)", - {"label": et.label, "description": et.description}, - ) - for prop in et.properties: - await self._upsert_property(et.label, prop, owner_label="OntologyEntityType") - - async def _upsert_relation_type(self, rt: RelationType) -> None: - new_patterns = _encode_patterns(rt.patterns) - result = await self._query( - "MATCH (r:OntologyRelationType {label: $label}) " - "RETURN r.patterns AS patterns", - {"label": rt.label}, - ) - existing: list[str] = [] - if result.result_set: - existing = list(result.result_set[0][0] or []) - seen: set[str] = set() - merged: list[str] = [] - for s in existing + new_patterns: - if s not in seen: - seen.add(s) - merged.append(s) - await self._query( - "MERGE (r:OntologyRelationType {label: $label}) " - "SET r.description = coalesce($description, r.description), " - "r.patterns = $patterns", - {"label": rt.label, "description": rt.description, "patterns": merged}, - ) - for prop in rt.properties: - await self._upsert_property(rt.label, prop, owner_label="OntologyRelationType") - - async def _upsert_property( - self, owner_label_value: str, prop: PropertyType, *, owner_label: str - ) -> None: - await self._query( - f"MATCH (o:{owner_label} {{label: $owner}}) " - "MERGE (o)-[:HAS_PROPERTY]->(p:OntologyProperty {name: $name}) " - "SET p.type = $type, " - "p.description = coalesce($description, p.description), " - "p.required = $required", - { - "owner": owner_label_value, - "name": prop.name, - "type": prop.type, - "description": prop.description, - "required": prop.required, - }, - ) - - async def clear(self) -> None: - """Drop the ontology graph (``GRAPH.DELETE``). Idempotent for empty graphs.""" - self._conn._ensure_client() - from redis.asyncio import Redis - - redis: Redis = Redis(connection_pool=self._conn._pool) + async def _infer_relates_subtypes(self, sample_size: int) -> list[RelationType]: + """Group ``RELATES`` edges by ``rel_type`` and infer per-subtype properties.""" + try: + subtypes_result = await self._conn.query( + "MATCH ()-[r:RELATES]->() " + "WITH r LIMIT $limit " + "WITH DISTINCT r.rel_type AS rel_type " + "WHERE rel_type IS NOT NULL " + "RETURN rel_type", + {"limit": sample_size * 5}, # broader pool to capture rare subtypes + ) + except Exception as exc: + logger.debug("Ontology inference: RELATES subtypes query failed: %s", exc) + return [] + + relations: list[RelationType] = [] + for row in subtypes_result.result_set or []: + subtype = row[0] + if not subtype: + continue + properties = await self._properties_for_relates_subtype(subtype, sample_size) + patterns = await self._patterns_for_relates_subtype(subtype) + relations.append( + RelationType(label=subtype, patterns=patterns, properties=properties) + ) + return relations + + async def _properties_for_relates_subtype( + self, subtype: str, sample_size: int + ) -> list[PropertyType]: + try: + result = await self._conn.query( + "MATCH ()-[r:RELATES {rel_type: $sub}]->() " + "WITH r LIMIT $limit " + "UNWIND keys(r) AS k " + "WITH k, typeof(r[k]) AS t " + "RETURN k AS key, t AS type, count(*) AS c " + "ORDER BY c DESC", + {"sub": subtype, "limit": sample_size}, + ) + except Exception as exc: + logger.debug( + "Ontology inference: relation properties query failed for %s: %s", + subtype, + exc, + ) + return [] + return _props_from_rows(result.result_set) + + async def _patterns_for_relates_subtype(self, subtype: str) -> list[tuple[str, str]]: try: - await redis.execute_command("GRAPH.DELETE", self._graph_name) + result = await self._conn.query( + "MATCH (a)-[r:RELATES {rel_type: $sub}]->(b) " + "WITH labels(a) AS la, labels(b) AS lb " + "RETURN DISTINCT la, lb LIMIT 25", + {"sub": subtype}, + ) except Exception as exc: - if "empty" in str(exc).lower() or "invalid" in str(exc).lower(): - logger.debug("Ontology graph '%s' already empty", self._graph_name) - else: - raise - self._graph = None + logger.debug( + "Ontology inference: endpoint patterns query failed for %s: %s", + subtype, + exc, + ) + return [] + patterns: list[tuple[str, str]] = [] + seen: set[tuple[str, str]] = set() + for row in result.result_set or []: + src_labels, tgt_labels = row[0] or [], row[1] or [] + src = next( + (lbl for lbl in src_labels if lbl not in _STRUCTURAL_LABELS), + None, + ) + tgt = next( + (lbl for lbl in tgt_labels if lbl not in _STRUCTURAL_LABELS), + None, + ) + if not src or not tgt: + continue + key = (src, tgt) + if key in seen: + continue + seen.add(key) + patterns.append(key) + return patterns -def _props_from_rows(rows: list[Any] | None) -> list[PropertyType]: - """Convert a ``collect(...)`` result into ``PropertyType`` objects. +def _props_from_rows(rows: list[list[Any]] | None) -> list[PropertyType]: + """Turn ``(key, typeof, count)`` rows into :py:class:`PropertyType` objects. - FalkorDB ``OPTIONAL MATCH`` with ``collect`` yields a list containing one - null-keyed dict when there are no matches; we filter those out. + Skips reserved/system keys and unmappable types so they never leak into the + LLM-facing schema. """ out: list[PropertyType] = [] + seen: set[str] = set() for row in rows or []: - if not row or not isinstance(row, dict): + if not row or len(row) < 2: continue - name = row.get("name") - if not name: + key, raw_type = row[0], row[1] + if not isinstance(key, str) or key in _INFER_SKIP_KEYS or key in seen: continue - out.append( - PropertyType( - name=name, - type=row.get("type") or "STRING", - description=row.get("description"), - required=bool(row.get("required")), - ) - ) + normalized = _normalize_type(raw_type if isinstance(raw_type, str) else None) + if not normalized: + continue + seen.add(key) + out.append(PropertyType(name=key, type=normalized)) return out diff --git a/graphrag_sdk/tests/test_models.py b/graphrag_sdk/tests/test_models.py index d4306127..abfb1a53 100644 --- a/graphrag_sdk/tests/test_models.py +++ b/graphrag_sdk/tests/test_models.py @@ -183,6 +183,36 @@ def test_graph_schema_rejects_reserved_relation_property_name(self): ] ) + def test_graph_schema_roundtrips_through_json_file(self, tmp_path): + schema = GraphSchema( + entities=[ + EntityType( + label="Person", + description="A human", + properties=[ + PropertyType(name="age", type="INTEGER", required=True), + PropertyType(name="birth_date", type="DATE"), + ], + ) + ], + relations=[ + RelationType( + label="WORKS_AT", + patterns=[("Person", "Company")], + properties=[PropertyType(name="since", type="DATE")], + ), + ], + ) + path = tmp_path / "ontology.json" + schema.save_to_file(str(path)) + # File is real JSON and round-trippable. + loaded = GraphSchema.from_file(str(path)) + assert loaded == schema + # Indentation default is set so files are diff-friendly. + text = path.read_text() + assert "\n" in text + assert " " in text + def test_graph_schema_merge_unions_entities_relations_and_properties(self): a = GraphSchema( entities=[ diff --git a/graphrag_sdk/tests/test_ontology_store.py b/graphrag_sdk/tests/test_ontology_store.py index b12ac427..cd466b45 100644 --- a/graphrag_sdk/tests/test_ontology_store.py +++ b/graphrag_sdk/tests/test_ontology_store.py @@ -1,226 +1,186 @@ -"""Tests for storage/ontology_store.py — persistent ontology graph layer. - -The store talks to FalkorDB directly; unit tests here mock the graph handle -through ``FalkorDBConnection``'s private ``_driver.select_graph()`` seam. -Real-FalkorDB exercise is left for the integration suite. -""" +"""Tests for storage/ontology_store.py — data-graph ontology inference.""" from __future__ import annotations -from types import SimpleNamespace -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import MagicMock import pytest -from graphrag_sdk.core.models import ( - EntityType, - GraphSchema, - PropertyType, - RelationType, -) from graphrag_sdk.storage.ontology_store import ( OntologyStore, - _decode_patterns, - _encode_patterns, + _normalize_type, + _props_from_rows, ) -class _FakeQueryResult: - """Stand-in for falkordb's QueryResult.""" - - def __init__(self, rows: list[list]): +class _FakeResult: + def __init__(self, rows): self.result_set = rows -class _FakeGraph: - """In-memory ``AsyncGraph`` substitute. - - Captures every ``query()`` call (cypher + params) so tests can assert on - them, and serves canned responses for the load() queries. - """ - - def __init__(self): - self.calls: list[tuple[str, dict | None]] = [] - self._ent_rows: list[list] = [] - self._rel_rows: list[list] = [] - self._patterns_for_label: dict[str, list[str]] = {} - - def set_load_response(self, entity_rows, relation_rows): - self._ent_rows = entity_rows - self._rel_rows = relation_rows - - def set_existing_patterns(self, label: str, patterns: list[str]): - self._patterns_for_label[label] = patterns - - async def query(self, cypher: str, params: dict | None = None): - self.calls.append((cypher, params)) - if "MATCH (e:OntologyEntityType)" in cypher and "collect" in cypher: - return _FakeQueryResult(self._ent_rows) - if "MATCH (r:OntologyRelationType)" in cypher and "collect" in cypher: - return _FakeQueryResult(self._rel_rows) - if "MATCH (r:OntologyRelationType {label: $label})" in cypher and "patterns" in cypher: - label = (params or {}).get("label", "") - return _FakeQueryResult([[self._patterns_for_label.get(label, [])]]) - return _FakeQueryResult([]) - - -@pytest.fixture -def fake_graph(): - return _FakeGraph() - - -@pytest.fixture -def store_factory(fake_graph): - """Returns a callable producing an OntologyStore whose `_query` is wired - to ``fake_graph``. Skips the real driver/connection plumbing entirely.""" - - def _make(data_graph_name: str = "kg") -> OntologyStore: - conn = MagicMock() - conn._ensure_client = MagicMock() - conn._driver = SimpleNamespace( - select_graph=MagicMock(return_value=fake_graph), - ) - return OntologyStore(conn, data_graph_name) - - return _make - - -class TestEncoders: - def test_encode_then_decode_roundtrip(self): - patterns = [("Person", "Company"), ("Person", "Organization")] - encoded = _encode_patterns(patterns) - assert encoded == ["Person|Company", "Person|Organization"] - assert _decode_patterns(encoded) == patterns +def _make_connection(handler): + """Wrap an async function (cypher, params) -> _FakeResult into a fake + ``FalkorDBConnection`` that just routes ``query`` to ``handler``.""" + conn = MagicMock() + conn.query = handler + return conn + + +# ── small helpers ──────────────────────────────────────────────── + + +class TestNormalizeType: + @pytest.mark.parametrize( + "raw,expected", + [ + ("string", "STRING"), + ("STRING", "STRING"), + ("integer", "INTEGER"), + ("double", "FLOAT"), + ("float", "FLOAT"), + ("boolean", "BOOLEAN"), + ("array", "LIST"), + ("list", "LIST"), + ("point", None), + ("null", None), + (None, None), + ("", None), + ], + ) + def test_matrix(self, raw, expected): + assert _normalize_type(raw) == expected + + +class TestPropsFromRows: + def test_skips_reserved_and_unknown_types(self): + rows = [ + ["age", "integer", 5], + ["name", "string", 10], # reserved + ["source_chunk_ids", "array", 5], # reserved + ["weird", "point", 1], # unmapped type + ["email", "string", 8], + ["age", "string", 1], # duplicate -> first wins + ] + out = _props_from_rows(rows) + assert [(p.name, p.type) for p in out] == [ + ("age", "INTEGER"), + ("email", "STRING"), + ] - def test_decode_handles_none_and_bad_strings(self): - assert _decode_patterns(None) == [] - assert _decode_patterns(["no-pipe", "a|b"]) == [("a", "b")] + def test_empty_input(self): + assert _props_from_rows(None) == [] + assert _props_from_rows([]) == [] -class TestOntologyStoreGraphName: - def test_suffix(self, store_factory): - store = store_factory("my_kg") - assert store.graph_name == "my_kg__ontology" +# ── infer() end-to-end (mocked driver) ─────────────────────────── -class TestOntologyStoreRegister: +class TestOntologyStoreInfer: @pytest.mark.asyncio - async def test_empty_schema_short_circuits_to_load(self, store_factory, fake_graph): - store = store_factory() - result = await store.register(GraphSchema()) - # Only the two load queries should have been issued. - ents_q = [c for c in fake_graph.calls if "MATCH (e:OntologyEntityType)" in c[0]] - assert len(ents_q) == 1 - assert isinstance(result, GraphSchema) - - @pytest.mark.asyncio - async def test_registers_entity_type_and_its_properties(self, store_factory, fake_graph): - store = store_factory() - schema = GraphSchema( - entities=[ - EntityType( - label="Person", - description="A human", - properties=[ - PropertyType(name="age", type="INTEGER"), - PropertyType(name="birth_date", type="DATE", required=True), - ], - ), - ], - ) - await store.register(schema) - # One MERGE for the entity-type node, plus one MERGE per property edge. - entity_merges = [ - c for c in fake_graph.calls if "MERGE (e:OntologyEntityType" in c[0] - ] - property_merges = [ - c for c in fake_graph.calls if "MERGE (o)-[:HAS_PROPERTY]->" in c[0] - ] - assert len(entity_merges) == 1 - assert len(property_merges) == 2 - # Property params carry the declared type and required flag. - prop_names = {(c[1] or {}).get("name") for c in property_merges} - assert prop_names == {"age", "birth_date"} + async def test_filters_structural_labels_and_edge_types(self): + calls = [] + + async def handler(cypher, params=None): + calls.append(cypher) + if "db.labels()" in cypher: + return _FakeResult( + [["Person"], ["Chunk"], ["Document"], ["__Entity__"]] + ) + if "db.relationshipTypes()" in cypher: + return _FakeResult( + [["PART_OF"], ["NEXT_CHUNK"], ["MENTIONED_IN"]] + ) + return _FakeResult([]) + + store = OntologyStore(_make_connection(handler)) + schema = await store.infer() + assert [e.label for e in schema.entities] == ["Person"] + assert schema.relations == [] @pytest.mark.asyncio - async def test_unions_relation_patterns_with_existing( - self, store_factory, fake_graph - ): - store = store_factory() - fake_graph.set_existing_patterns("WORKS_AT", ["Person|Company"]) - schema = GraphSchema( - entities=[EntityType(label="Person"), EntityType(label="Org")], - relations=[ - RelationType(label="WORKS_AT", patterns=[("Person", "Org")]), - ], - ) - await store.register(schema) - rel_set = [ - c - for c in fake_graph.calls - if "MERGE (r:OntologyRelationType {label: $label})" in c[0] and "SET r." in c[0] - ] - assert rel_set, "expected a SET on the RelationType node" - params = rel_set[-1][1] or {} - assert "Person|Company" in params["patterns"] - assert "Person|Org" in params["patterns"] - + async def test_relates_subtypes_are_surfaced_with_patterns_and_props(self): + async def handler(cypher, params=None): + if "db.labels()" in cypher: + return _FakeResult([["Person"], ["Company"], ["Location"]]) + if "db.relationshipTypes()" in cypher: + return _FakeResult([["RELATES"], ["MENTIONED_IN"]]) + if "MATCH (n:`Person`)" in cypher: + return _FakeResult([["age", "integer", 3], ["name", "string", 3]]) + if "MATCH (n:`Company`)" in cypher or "MATCH (n:`Location`)" in cypher: + return _FakeResult([]) + if "DISTINCT r.rel_type" in cypher: + return _FakeResult([["WORKS_AT"], ["LOCATED_IN"]]) + if "labels(a)" in cypher: + sub = (params or {}).get("sub", "") + if sub == "WORKS_AT": + return _FakeResult( + [[["Person"], ["Company", "__Entity__"]]] + ) + if sub == "LOCATED_IN": + return _FakeResult( + [[["Person"], ["Location", "__Entity__"]]] + ) + if "UNWIND keys(r)" in cypher: + sub = (params or {}).get("sub", "") + if sub == "WORKS_AT": + # rel_type is reserved and must be skipped from + # discovered properties. + return _FakeResult( + [["since", "string", 2], ["rel_type", "string", 2]] + ) + return _FakeResult([]) + return _FakeResult([]) + + store = OntologyStore(_make_connection(handler)) + schema = await store.infer() + + assert {e.label for e in schema.entities} == {"Person", "Company", "Location"} + person = next(e for e in schema.entities if e.label == "Person") + assert [(p.name, p.type) for p in person.properties] == [("age", "INTEGER")] + + rel_by_label = {r.label: r for r in schema.relations} + assert set(rel_by_label) == {"WORKS_AT", "LOCATED_IN"} + # Patterns strip __Entity__ and pick the user-visible label. + assert rel_by_label["WORKS_AT"].patterns == [("Person", "Company")] + # Reserved rel property is suppressed from the inferred schema. + assert [ + (p.name, p.type) for p in rel_by_label["WORKS_AT"].properties + ] == [("since", "STRING")] + assert rel_by_label["LOCATED_IN"].properties == [] -class TestOntologyStoreLoad: @pytest.mark.asyncio - async def test_empty_graph_returns_empty_schema(self, store_factory, fake_graph): - store = store_factory() - fake_graph.set_load_response([], []) - result = await store.load() - assert result.entities == [] - assert result.relations == [] + async def test_endpoint_pattern_drops_structural_targets(self): + async def handler(cypher, params=None): + if "db.labels()" in cypher: + return _FakeResult([["Person"]]) + if "db.relationshipTypes()" in cypher: + return _FakeResult([["RELATES"]]) + if "MATCH (n:`Person`)" in cypher: + return _FakeResult([]) + if "DISTINCT r.rel_type" in cypher: + return _FakeResult([["KNOWS"]]) + if "labels(a)" in cypher: + # Stale data has both a real and a structural target. + return _FakeResult( + [ + [["Person"], ["Chunk"]], # structural target -> dropped + [["Person"], ["Person"]], # real -> kept + ] + ) + if "UNWIND keys(r)" in cypher: + return _FakeResult([]) + return _FakeResult([]) + + store = OntologyStore(_make_connection(handler)) + schema = await store.infer() + knows = next(r for r in schema.relations if r.label == "KNOWS") + assert knows.patterns == [("Person", "Person")] @pytest.mark.asyncio - async def test_reconstructs_schema_from_query_rows(self, store_factory, fake_graph): - store = store_factory() - fake_graph.set_load_response( - entity_rows=[ - [ - "Person", - "A human", - [ - { - "name": "age", - "type": "INTEGER", - "description": None, - "required": False, - }, - # collect() of an OPTIONAL MATCH that found nothing - # may produce an all-None dict — must be filtered. - { - "name": None, - "type": None, - "description": None, - "required": None, - }, - ], - ], - ["Company", None, []], - ], - relation_rows=[ - [ - "WORKS_AT", - "Employment", - ["Person|Company"], - [ - { - "name": "since", - "type": "DATE", - "description": None, - "required": False, - } - ], - ], - ], - ) - schema = await store.load() - assert {e.label for e in schema.entities} == {"Person", "Company"} - person = next(e for e in schema.entities if e.label == "Person") - assert {p.name for p in person.properties} == {"age"} - works = next(r for r in schema.relations if r.label == "WORKS_AT") - assert works.patterns == [("Person", "Company")] - assert {p.name for p in works.properties} == {"since"} + async def test_introspection_failure_returns_empty_schema(self): + async def handler(cypher, params=None): + raise RuntimeError("connection blew up") + + store = OntologyStore(_make_connection(handler)) + schema = await store.infer() + assert schema.entities == [] + assert schema.relations == [] From 0597de99fc9811462bda80c44bd80d6767aecd34 Mon Sep 17 00:00:00 2001 From: Gal Shubeli Date: Mon, 18 May 2026 13:13:18 +0300 Subject: [PATCH 3/7] chore(lint): fix ruff errors flagged on CI - graph_extraction.py: regroup imports (datetime alongside stdlib); remove unused EntityType import; wrap two over-length attribute-fetch lines. - api/main.py: split a long get_ontology return expression. ruff check src/graphrag_sdk/ passes locally. Tests: 830 passed. Co-Authored-By: Claude Opus 4.7 (1M context) --- graphrag_sdk/src/graphrag_sdk/api/main.py | 4 +++- .../extraction_strategies/graph_extraction.py | 10 +++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/graphrag_sdk/src/graphrag_sdk/api/main.py b/graphrag_sdk/src/graphrag_sdk/api/main.py index 0698f1c9..a3d3b36d 100644 --- a/graphrag_sdk/src/graphrag_sdk/api/main.py +++ b/graphrag_sdk/src/graphrag_sdk/api/main.py @@ -258,7 +258,9 @@ async def get_ontology(self) -> GraphSchema: Returns an empty schema when both are empty. """ inferred = await self._ontology_store.infer() - return inferred.merge(self.schema) if inferred.entities or inferred.relations else self.schema + if inferred.entities or inferred.relations: + return inferred.merge(self.schema) + return self.schema async def refresh_ontology(self) -> GraphSchema: """Recompute the global ontology and propagate it to the retrieval path. diff --git a/graphrag_sdk/src/graphrag_sdk/ingestion/extraction_strategies/graph_extraction.py b/graphrag_sdk/src/graphrag_sdk/ingestion/extraction_strategies/graph_extraction.py index 02ec9615..1274c124 100644 --- a/graphrag_sdk/src/graphrag_sdk/ingestion/extraction_strategies/graph_extraction.py +++ b/graphrag_sdk/src/graphrag_sdk/ingestion/extraction_strategies/graph_extraction.py @@ -7,14 +7,12 @@ import asyncio import json import logging +from datetime import date, datetime from typing import Any from graphrag_sdk.core.context import Context -from datetime import date, datetime - from graphrag_sdk.core.models import ( EntityMention, - EntityType, ExtractedEntity, ExtractedRelation, GraphData, @@ -646,7 +644,8 @@ def _parse_step2_response( declared = ent_props_by_label.get(etype, {}) attributes: dict[str, Any] = {} if declared: - raw_attrs = item.get("attributes") if isinstance(item.get("attributes"), dict) else {} + _raw = item.get("attributes") + raw_attrs = _raw if isinstance(_raw, dict) else {} attributes, missing = _coerce_attributes(raw_attrs, declared) if missing: dropped_required[f"entity:{etype}"] = ( @@ -687,7 +686,8 @@ def _parse_step2_response( declared_rel = rel_props_by_label.get(rel_type, {}) attributes_rel: dict[str, Any] = {} if declared_rel: - raw_attrs = item.get("attributes") if isinstance(item.get("attributes"), dict) else {} + _raw = item.get("attributes") + raw_attrs = _raw if isinstance(_raw, dict) else {} attributes_rel, missing = _coerce_attributes(raw_attrs, declared_rel) if missing: dropped_required[f"relation:{rel_type}"] = ( From adc223f91ed29f6ed892798f005bf9a4107bd5c4 Mon Sep 17 00:00:00 2001 From: Gal Shubeli Date: Tue, 19 May 2026 11:48:49 +0300 Subject: [PATCH 4/7] style(ruff): apply ruff format to feature files CI's "ruff format --check src/" flagged 5 files. Pure formatting; no behavioral change. Tests: 830 passed. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../extraction_strategies/graph_extraction.py | 3 +-- .../src/graphrag_sdk/ingestion/pipeline.py | 4 +--- .../retrieval/strategies/cypher_generation.py | 11 ++++------- .../retrieval/strategies/multi_path.py | 4 +--- .../src/graphrag_sdk/storage/ontology_store.py | 16 ++++------------ 5 files changed, 11 insertions(+), 27 deletions(-) diff --git a/graphrag_sdk/src/graphrag_sdk/ingestion/extraction_strategies/graph_extraction.py b/graphrag_sdk/src/graphrag_sdk/ingestion/extraction_strategies/graph_extraction.py index 1274c124..5d766b93 100644 --- a/graphrag_sdk/src/graphrag_sdk/ingestion/extraction_strategies/graph_extraction.py +++ b/graphrag_sdk/src/graphrag_sdk/ingestion/extraction_strategies/graph_extraction.py @@ -710,8 +710,7 @@ def _parse_step2_response( if dropped_required: for key, n in dropped_required.items(): logger.warning( - "Dropped %d %s record(s) with missing required attribute(s) " - "for chunk %s", + "Dropped %d %s record(s) with missing required attribute(s) for chunk %s", n, key, source_chunk_id, diff --git a/graphrag_sdk/src/graphrag_sdk/ingestion/pipeline.py b/graphrag_sdk/src/graphrag_sdk/ingestion/pipeline.py index 4dbcf87e..6b16617d 100644 --- a/graphrag_sdk/src/graphrag_sdk/ingestion/pipeline.py +++ b/graphrag_sdk/src/graphrag_sdk/ingestion/pipeline.py @@ -451,9 +451,7 @@ def _prune(self, graph_data: GraphData, schema: GraphSchema) -> GraphData: extracted_relations=graph_data.extracted_relations, ) - def _validate_attributes( - self, graph_data: GraphData, schema: GraphSchema - ) -> GraphData: + def _validate_attributes(self, graph_data: GraphData, schema: GraphSchema) -> GraphData: """Validate node / relationship attributes against the declared schema. - Attribute keys not declared in the schema are dropped (debug-logged). diff --git a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/cypher_generation.py b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/cypher_generation.py index cc7c89a7..e5e6aca4 100644 --- a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/cypher_generation.py +++ b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/cypher_generation.py @@ -121,9 +121,7 @@ def render_schema_block(schema: GraphSchema | None) -> str: d = f" — {desc}" if desc else "" lines.append(f" - {name} ({typ}){d} # declared on RELATES via rel_type filters") if rel_labels: - lines.append( - " Allowed `rel_type` values: " + ", ".join(rel_labels) - ) + lines.append(" Allowed `rel_type` values: " + ", ".join(rel_labels)) lines.append("- MENTIONED_IN: connects entity to Chunk node (provenance)") lines.append("- PART_OF: connects Document to Chunk") lines.append("- NEXT_CHUNK: connects Chunk to next sequential Chunk") @@ -145,7 +143,7 @@ def _render_attribute_examples(schema: GraphSchema | None) -> str: if p.type in _NUMERIC_TYPES and len(examples) < 2: var = et.label[0].lower() examples.append( - f"Question: \"Which {et.label} has the highest {p.name}?\"\n" + f'Question: "Which {et.label} has the highest {p.name}?"\n' f"```cypher\n" f"MATCH ({var}:{et.label})\n" f"WHERE {var}.{p.name} IS NOT NULL\n" @@ -158,6 +156,7 @@ def _render_attribute_examples(schema: GraphSchema | None) -> str: return "" return "\n\n" + "\n\n".join(examples) + # ── Schema prompt ──────────────────────────────────────────────── _SCHEMA_PROMPT_TEMPLATE = """\ @@ -442,9 +441,7 @@ async def execute_cypher_retrieval( On any failure, returns empty results (silent degradation). """ - cypher = await generate_cypher( - llm, question, schema=schema, max_retries=max_retries - ) + cypher = await generate_cypher(llm, question, schema=schema, max_retries=max_retries) if not cypher: return [], {} diff --git a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/multi_path.py b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/multi_path.py index 80e29b34..7397c587 100644 --- a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/multi_path.py +++ b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/multi_path.py @@ -199,9 +199,7 @@ async def _execute( if self._enable_cypher: results = await asyncio.gather( search_relates_edges(self._vector, query_vector, self._rel_top_k), - execute_cypher_retrieval( - self._graph, self._llm, query, schema=self._schema - ), + execute_cypher_retrieval(self._graph, self._llm, query, schema=self._schema), return_exceptions=True, ) # Unpack RELATES results diff --git a/graphrag_sdk/src/graphrag_sdk/storage/ontology_store.py b/graphrag_sdk/src/graphrag_sdk/storage/ontology_store.py index 7a424762..ff80687d 100644 --- a/graphrag_sdk/src/graphrag_sdk/storage/ontology_store.py +++ b/graphrag_sdk/src/graphrag_sdk/storage/ontology_store.py @@ -36,9 +36,7 @@ _STRUCTURAL_LABELS: frozenset[str] = frozenset({"Chunk", "Document", "__Entity__"}) # Edge labels created by the SDK that are not user relations. -_STRUCTURAL_REL_TYPES: frozenset[str] = frozenset( - {"PART_OF", "NEXT_CHUNK", "MENTIONED_IN"} -) +_STRUCTURAL_REL_TYPES: frozenset[str] = frozenset({"PART_OF", "NEXT_CHUNK", "MENTIONED_IN"}) # Property keys we never want to expose to the LLM as "custom attributes". # These are SDK-internal or reserved meanings; the Cypher prompt already @@ -120,9 +118,7 @@ async def infer(self, *, sample_size: int = 200) -> GraphSchema: return GraphSchema(entities=entities, relations=relations) - async def _properties_for_node( - self, label: str, sample_size: int - ) -> list[PropertyType]: + async def _properties_for_node(self, label: str, sample_size: int) -> list[PropertyType]: try: result = await self._conn.query( f"MATCH (n:`{label}`) " @@ -134,9 +130,7 @@ async def _properties_for_node( {"limit": sample_size}, ) except Exception as exc: - logger.debug( - "Ontology inference: properties query failed for %s: %s", label, exc - ) + logger.debug("Ontology inference: properties query failed for %s: %s", label, exc) return [] return _props_from_rows(result.result_set) @@ -162,9 +156,7 @@ async def _infer_relates_subtypes(self, sample_size: int) -> list[RelationType]: continue properties = await self._properties_for_relates_subtype(subtype, sample_size) patterns = await self._patterns_for_relates_subtype(subtype) - relations.append( - RelationType(label=subtype, patterns=patterns, properties=properties) - ) + relations.append(RelationType(label=subtype, patterns=patterns, properties=properties)) return relations async def _properties_for_relates_subtype( From 297d473dad225b1f924c370d5314de293b4b8590 Mon Sep 17 00:00:00 2001 From: Gal Shubeli Date: Tue, 19 May 2026 14:03:55 +0300 Subject: [PATCH 5/7] refactor(schema): drop PropertyType.required; missing values become null MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `required=True` was a footgun for an LLM-extraction pipeline: when the LLM couldn't find a value in the text (often because the text didn't state it), the entire entity was silently dropped. That destroyed real data. New behavior: every declared property name appears in `ExtractedEntity.attributes` with either the coerced value or `None`. The storage layer's existing `_clean_properties` strips `None` before writing so the graph sees "key missing" — the right null semantics for retrieval (`WHERE p.age > N` naturally excludes nodes without `age`). - core/models: remove `required` field from PropertyType. - ingestion/graph_extraction: `_coerce_attributes` now returns just a dict with `None` for missing/uncoercible values; never drops records. The drop-on-missing-required branches and aggregated warnings are gone. - ingestion/pipeline: `_validate_attributes` strips undeclared keys but never drops records. Cascade `_filter_quality` still runs but no longer has anything to cascade in the attribute-validation case. - storage/ontology_store: docstring no longer mentions required flags. - tests: collapsed the three "required missing" cases into a single "every declared property appears in result; uncoercible -> None" test per file. 829 passed, 23 skipped. Co-Authored-By: Claude Opus 4.7 (1M context) --- graphrag_sdk/src/graphrag_sdk/core/models.py | 6 +- .../extraction_strategies/graph_extraction.py | 59 ++++++------------- .../src/graphrag_sdk/ingestion/pipeline.py | 56 +++--------------- .../graphrag_sdk/storage/ontology_store.py | 8 +-- graphrag_sdk/tests/test_attribute_prompt.py | 50 +++++----------- graphrag_sdk/tests/test_graph_extraction.py | 16 +++-- graphrag_sdk/tests/test_models.py | 5 +- graphrag_sdk/tests/test_pipeline.py | 52 ++++++++-------- 8 files changed, 81 insertions(+), 171 deletions(-) diff --git a/graphrag_sdk/src/graphrag_sdk/core/models.py b/graphrag_sdk/src/graphrag_sdk/core/models.py index c2dc7937..cba79f84 100644 --- a/graphrag_sdk/src/graphrag_sdk/core/models.py +++ b/graphrag_sdk/src/graphrag_sdk/core/models.py @@ -163,7 +163,6 @@ class PropertyType(DataModel): name: str type: str = "STRING" # STRING, INTEGER, FLOAT, BOOLEAN, DATE, LIST description: str | None = None - required: bool = False @model_validator(mode="after") def _normalize_type(self) -> PropertyType: @@ -309,9 +308,8 @@ def merge(self, other: GraphSchema) -> GraphSchema: - Entity / relation types are unioned by ``label``. - For each type, ``properties`` are unioned by ``name``. When the same - property name appears in both, the incoming type/description/required - overrides — last-write-wins, matching the persisted ontology's - register() semantics. + property name appears in both, the incoming type/description overrides + (last-write-wins). - For relations, ``patterns`` are unioned (order-preserving, deduped). """ diff --git a/graphrag_sdk/src/graphrag_sdk/ingestion/extraction_strategies/graph_extraction.py b/graphrag_sdk/src/graphrag_sdk/ingestion/extraction_strategies/graph_extraction.py index 5d766b93..7ec67b4f 100644 --- a/graphrag_sdk/src/graphrag_sdk/ingestion/extraction_strategies/graph_extraction.py +++ b/graphrag_sdk/src/graphrag_sdk/ingestion/extraction_strategies/graph_extraction.py @@ -98,8 +98,7 @@ def _format_property_for_prompt(prop: PropertyType) -> str: desc = f" — {prop.description}" if prop.description else "" - req = " (required)" if prop.required else "" - return f" - {prop.name} ({prop.type}){req}{desc}" + return f" - {prop.name} ({prop.type}){desc}" def _render_attribute_schema_block(schema: GraphSchema) -> str: @@ -213,28 +212,24 @@ def _coerce_attribute_value(value: Any, prop_type: str) -> tuple[bool, Any]: def _coerce_attributes( raw: dict[str, Any] | None, declared: dict[str, PropertyType], -) -> tuple[dict[str, Any], list[str]]: +) -> dict[str, Any]: """Apply per-type coercion against declared properties. - Returns ``(coerced, missing_required)``. ``missing_required`` lists names of - declared ``required=True`` properties whose value is absent or failed to - coerce — callers use it to decide whether to drop the whole record. + Every declared property name appears in the result. Values that are absent + from ``raw`` or fail to coerce are recorded as ``None`` — never drop the + enclosing record. Downstream storage strips ``None`` before writing, so + these turn into "key missing" on the graph node, which is the right + behavior for retrieval (``WHERE p.age > N`` naturally excludes them). """ coerced: dict[str, Any] = {} raw = raw or {} - missing_required: list[str] = [] for prop_name, prop in declared.items(): if prop_name not in raw or raw[prop_name] is None: - if prop.required: - missing_required.append(prop_name) + coerced[prop_name] = None continue ok, val = _coerce_attribute_value(raw[prop_name], prop.type) - if not ok: - if prop.required: - missing_required.append(prop_name) - continue - coerced[prop_name] = val - return coerced, missing_required + coerced[prop_name] = val if ok else None + return coerced def _optional_extras(obj: Any) -> dict[str, Any]: @@ -599,10 +594,11 @@ def _parse_step2_response( ) -> tuple[list[ExtractedEntity], list[ExtractedRelation]]: """Parse the step 2 LLM response (verified entities + relationships). - When ``schema`` declares attributes for an entity / relation type, the - response's ``attributes`` object is coerced to the declared types. - Records whose ``required=True`` properties fail coercion are dropped - and counted in an aggregated warning per type. + When ``schema`` declares attributes for an entity / relation type, every + declared attribute appears in the record's ``attributes`` dict, with + ``None`` for values the LLM didn't supply or couldn't coerce. Records + are never dropped — downstream storage strips ``None`` so the graph + sees "key missing" for the unfilled slots. """ text = _strip_markdown_fences(content) @@ -625,8 +621,6 @@ def _parse_step2_response( if rt.properties: rel_props_by_label[rt.label] = {p.name: p for p in rt.properties} - dropped_required: dict[str, int] = {} - # Parse entities entities: list[ExtractedEntity] = [] for item in data.get("entities", []): @@ -646,12 +640,7 @@ def _parse_step2_response( if declared: _raw = item.get("attributes") raw_attrs = _raw if isinstance(_raw, dict) else {} - attributes, missing = _coerce_attributes(raw_attrs, declared) - if missing: - dropped_required[f"entity:{etype}"] = ( - dropped_required.get(f"entity:{etype}", 0) + 1 - ) - continue + attributes = _coerce_attributes(raw_attrs, declared) entities.append( ExtractedEntity( @@ -688,12 +677,7 @@ def _parse_step2_response( if declared_rel: _raw = item.get("attributes") raw_attrs = _raw if isinstance(_raw, dict) else {} - attributes_rel, missing = _coerce_attributes(raw_attrs, declared_rel) - if missing: - dropped_required[f"relation:{rel_type}"] = ( - dropped_required.get(f"relation:{rel_type}", 0) + 1 - ) - continue + attributes_rel = _coerce_attributes(raw_attrs, declared_rel) relations.append( ExtractedRelation( @@ -707,15 +691,6 @@ def _parse_step2_response( ) ) - if dropped_required: - for key, n in dropped_required.items(): - logger.warning( - "Dropped %d %s record(s) with missing required attribute(s) for chunk %s", - n, - key, - source_chunk_id, - ) - return entities, relations # ── Aggregation ────────────────────────────────────────────── diff --git a/graphrag_sdk/src/graphrag_sdk/ingestion/pipeline.py b/graphrag_sdk/src/graphrag_sdk/ingestion/pipeline.py index 6b16617d..98d5d99f 100644 --- a/graphrag_sdk/src/graphrag_sdk/ingestion/pipeline.py +++ b/graphrag_sdk/src/graphrag_sdk/ingestion/pipeline.py @@ -452,17 +452,14 @@ def _prune(self, graph_data: GraphData, schema: GraphSchema) -> GraphData: ) def _validate_attributes(self, graph_data: GraphData, schema: GraphSchema) -> GraphData: - """Validate node / relationship attributes against the declared schema. - - - Attribute keys not declared in the schema are dropped (debug-logged). - - Nodes / relationships missing a ``required=True`` declared attribute - are dropped entirely, with a single aggregated warning per type. - - Skips ``"Unknown"`` nodes (the low-confidence label preserved by - :py:meth:`_prune`). - - Reserved SDK keys (``name``, ``description``, ``source_chunk_ids``, - ``spans``, ``rel_type``, ``fact``, ``src_name``, ``tgt_name``, - ``type``, ``id``) are always preserved. + """Strip undeclared attribute keys from nodes / relationships. + + - Keys not declared in the schema and not in the SDK-reserved set are + dropped (debug-logged). + - Records are never dropped. Missing values stay missing on the graph, + which is the correct null semantics for retrieval queries + (``WHERE p.age > N`` naturally excludes nodes without ``age``). + - Skips ``"Unknown"`` nodes (preserved by :py:meth:`_prune`). """ if not schema.entities and not schema.relations: return graph_data @@ -477,14 +474,12 @@ def _validate_attributes(self, graph_data: GraphData, schema: GraphSchema) -> Gr } kept_nodes: list[GraphNode] = [] - required_missing_by_label: dict[str, int] = {} unknown_dropped_by_label: dict[str, int] = {} for node in graph_data.nodes: if node.label == "Unknown" or node.label not in ent_declared: kept_nodes.append(node) continue declared = ent_declared[node.label] - # Drop unknown attribute keys (debug only — high volume). unknown_keys = [ k for k in list(node.properties.keys()) @@ -495,17 +490,6 @@ def _validate_attributes(self, graph_data: GraphData, schema: GraphSchema) -> Gr unknown_dropped_by_label[node.label] = ( unknown_dropped_by_label.get(node.label, 0) + 1 ) - # Drop record on missing required attribute. - missing = [ - name - for name, prop in declared.items() - if prop.required and node.properties.get(name) is None - ] - if missing: - required_missing_by_label[node.label] = ( - required_missing_by_label.get(node.label, 0) + 1 - ) - continue kept_nodes.append(node) for label, count in unknown_dropped_by_label.items(): @@ -514,15 +498,8 @@ def _validate_attributes(self, graph_data: GraphData, schema: GraphSchema) -> Gr count, label, ) - for label, count in required_missing_by_label.items(): - logger.warning( - "Dropped %d %s node(s) with missing required attribute(s)", - count, - label, - ) kept_rels: list[GraphRelationship] = [] - rel_required_missing_by_label: dict[str, int] = {} for rel in graph_data.relationships: rel_label = rel.properties.get("rel_type", rel.type) if rel_label not in rel_declared: @@ -535,25 +512,8 @@ def _validate_attributes(self, graph_data: GraphData, schema: GraphSchema) -> Gr if k not in declared and k not in RESERVED_PROPERTY_NAMES ]: rel.properties.pop(k, None) - missing = [ - name - for name, prop in declared.items() - if prop.required and rel.properties.get(name) is None - ] - if missing: - rel_required_missing_by_label[rel_label] = ( - rel_required_missing_by_label.get(rel_label, 0) + 1 - ) - continue kept_rels.append(rel) - for label, count in rel_required_missing_by_label.items(): - logger.warning( - "Dropped %d [%s] relationship(s) with missing required attribute(s)", - count, - label, - ) - return GraphData( nodes=kept_nodes, relationships=kept_rels, diff --git a/graphrag_sdk/src/graphrag_sdk/storage/ontology_store.py b/graphrag_sdk/src/graphrag_sdk/storage/ontology_store.py index ff80687d..30835086 100644 --- a/graphrag_sdk/src/graphrag_sdk/storage/ontology_store.py +++ b/graphrag_sdk/src/graphrag_sdk/storage/ontology_store.py @@ -8,10 +8,10 @@ - Retrieval reads the inferred schema each session to build the Cypher prompt. - ``GraphRAG.get_ontology()`` returns it for inspection. -Users who want a curated, declarative schema (descriptions, required flags, -not-yet-extracted properties) pass a ``local_schema`` to ``GraphRAG`` — it's -unioned with the inferred schema at retrieval time so declared metadata -survives. ``GraphSchema.save_to_file`` / ``GraphSchema.from_file`` cover the +Users who want a curated, declarative schema (descriptions, not-yet-extracted +properties) pass a ``local_schema`` to ``GraphRAG`` — it's unioned with the +inferred schema at retrieval time so declared metadata survives. +``GraphSchema.save_to_file`` / ``GraphSchema.from_file`` cover the schema-as-config workflow. """ diff --git a/graphrag_sdk/tests/test_attribute_prompt.py b/graphrag_sdk/tests/test_attribute_prompt.py index 0db5a57a..905a1fbc 100644 --- a/graphrag_sdk/tests/test_attribute_prompt.py +++ b/graphrag_sdk/tests/test_attribute_prompt.py @@ -49,20 +49,6 @@ def test_includes_only_types_with_declared_properties(self): # in the rendered output would be as part of an entity bullet. assert "- Company:" not in block - def test_marks_required_attributes(self): - s = GraphSchema( - entities=[ - EntityType( - label="Person", - properties=[ - PropertyType(name="birth_date", type="DATE", required=True) - ], - ), - ], - ) - block = _render_attribute_schema_block(s) - assert "(required)" in block - def test_renders_relation_attributes(self): s = GraphSchema( entities=[EntityType(label="Person"), EntityType(label="Company")], @@ -137,31 +123,27 @@ def test_matrix(self, value, prop_type, expected): class TestCoerceAttributes: - def test_drops_unknown_and_optional_invalid(self): + def test_every_declared_property_appears_in_result(self): + """Result shape is uniform: every declared key is present, with the + coerced value or None. Never drop the enclosing record.""" declared = { - "age": PropertyType(name="age", type="INTEGER", required=False), + "age": PropertyType(name="age", type="INTEGER"), + "birth_date": PropertyType(name="birth_date", type="DATE"), + "nickname": PropertyType(name="nickname", type="STRING"), } - result, missing = _coerce_attributes( - {"age": "abc", "unknown": "x"}, declared + result = _coerce_attributes( + {"age": "56", "birth_date": "1867-11-07"}, declared ) - # 'unknown' is not in declared so it never enters the result. - assert result == {} - # 'age' was optional, so coercion failure does NOT flag it as missing. - assert missing == [] + assert result == {"age": 56, "birth_date": "1867-11-07", "nickname": None} - def test_required_missing_when_value_absent(self): - declared = { - "birth_date": PropertyType(name="birth_date", type="DATE", required=True), - } - _, missing = _coerce_attributes({}, declared) - assert missing == ["birth_date"] + def test_uncoercible_value_becomes_none(self): + declared = {"age": PropertyType(name="age", type="INTEGER")} + assert _coerce_attributes({"age": "abc"}, declared) == {"age": None} - def test_required_missing_when_value_uncoercible(self): - declared = { - "age": PropertyType(name="age", type="INTEGER", required=True), - } - _, missing = _coerce_attributes({"age": "abc"}, declared) - assert missing == ["age"] + def test_undeclared_keys_are_ignored(self): + declared = {"age": PropertyType(name="age", type="INTEGER")} + result = _coerce_attributes({"age": 56, "unknown": "x"}, declared) + assert result == {"age": 56} # ── prompt template integration ────────────────────────────────── diff --git a/graphrag_sdk/tests/test_graph_extraction.py b/graphrag_sdk/tests/test_graph_extraction.py index 494ffe6d..7041ecbc 100644 --- a/graphrag_sdk/tests/test_graph_extraction.py +++ b/graphrag_sdk/tests/test_graph_extraction.py @@ -747,7 +747,10 @@ def test_parse_step2_coerces_declared_attributes(self): assert len(ents) == 1 assert ents[0].attributes == {"age": 56, "birth_date": "1867-11-07"} - def test_parse_step2_drops_entity_missing_required_attribute(self): + def test_parse_step2_missing_attribute_becomes_none(self): + """Missing values are represented as ``None`` in attributes — the + entity is NEVER dropped. Storage strips ``None`` before writing so + the graph sees "key missing", which is the right null semantics.""" from graphrag_sdk.core.models import PropertyType content = json.dumps({ "entities": [ @@ -757,12 +760,6 @@ def test_parse_step2_drops_entity_missing_required_attribute(self): "description": "", "attributes": {"age": 56}, }, - { - "name": "Pierre Curie", - "type": "Person", - "description": "", - "attributes": {"age": 39, "birth_date": "1859-05-15"}, - }, ], "relationships": [], }) @@ -772,7 +769,7 @@ def test_parse_step2_drops_entity_missing_required_attribute(self): label="Person", properties=[ PropertyType(name="age", type="INTEGER"), - PropertyType(name="birth_date", type="DATE", required=True), + PropertyType(name="birth_date", type="DATE"), ], ) ] @@ -780,7 +777,8 @@ def test_parse_step2_drops_entity_missing_required_attribute(self): ents, _ = GraphExtraction._parse_step2_response( content, ["Person"], "c1", schema ) - assert [e.name for e in ents] == ["Pierre Curie"] + assert len(ents) == 1 + assert ents[0].attributes == {"age": 56, "birth_date": None} def test_aggregator_carries_attributes_with_last_write_wins(self): from graphrag_sdk.core.models import ExtractedEntity diff --git a/graphrag_sdk/tests/test_models.py b/graphrag_sdk/tests/test_models.py index abfb1a53..7a177024 100644 --- a/graphrag_sdk/tests/test_models.py +++ b/graphrag_sdk/tests/test_models.py @@ -136,10 +136,9 @@ def test_document_output(self): class TestSchemaTypes: def test_property_type(self): - pt = PropertyType(name="age", type="INTEGER", required=True) + pt = PropertyType(name="age", type="INTEGER") assert pt.name == "age" assert pt.type == "INTEGER" - assert pt.required is True def test_property_type_normalizes_case(self): pt = PropertyType(name="age", type="integer") @@ -190,7 +189,7 @@ def test_graph_schema_roundtrips_through_json_file(self, tmp_path): label="Person", description="A human", properties=[ - PropertyType(name="age", type="INTEGER", required=True), + PropertyType(name="age", type="INTEGER"), PropertyType(name="birth_date", type="DATE"), ], ) diff --git a/graphrag_sdk/tests/test_pipeline.py b/graphrag_sdk/tests/test_pipeline.py index 3f04161d..8e020c07 100644 --- a/graphrag_sdk/tests/test_pipeline.py +++ b/graphrag_sdk/tests/test_pipeline.py @@ -750,8 +750,10 @@ def test_unknown_attributes_dropped_on_declared_label(self): # Reserved keys preserved. assert kept["name"] == "Marie" - def test_missing_required_drops_node(self, caplog): - import logging + def test_missing_attribute_keeps_node(self): + """Missing declared attributes don't drop the node — they stay + missing on the graph, which is the right null semantics for + retrieval (``WHERE p.age > N`` naturally excludes them).""" from graphrag_sdk.core.models import PropertyType pipe = self._pipeline() schema = GraphSchema( @@ -760,31 +762,25 @@ def test_missing_required_drops_node(self, caplog): label="Person", properties=[ PropertyType(name="age", type="INTEGER"), - PropertyType(name="birth_date", type="DATE", required=True), + PropertyType(name="birth_date", type="DATE"), ], ) ] ) nodes = [ GraphNode( - id="ok", + id="full", label="Person", properties={"name": "Marie", "age": 56, "birth_date": "1867-11-07"}, ), GraphNode( - id="missing", + id="partial", label="Person", properties={"name": "Pierre", "age": 39}, ), ] - with caplog.at_level(logging.WARNING, logger="graphrag_sdk.ingestion.pipeline"): - out = pipe._validate_attributes(GraphData(nodes=nodes), schema) - assert [n.id for n in out.nodes] == ["ok"] - assert any( - "missing required attribute" in r.getMessage() - for r in caplog.records - if r.levelno == logging.WARNING - ) + out = pipe._validate_attributes(GraphData(nodes=nodes), schema) + assert sorted(n.id for n in out.nodes) == ["full", "partial"] def test_unknown_node_label_passes_through(self): """``"Unknown"`` typed nodes (preserved by ``_prune``) must not be @@ -795,9 +791,7 @@ def test_unknown_node_label_passes_through(self): entities=[ EntityType( label="Person", - properties=[ - PropertyType(name="age", type="INTEGER", required=True) - ], + properties=[PropertyType(name="age", type="INTEGER")], ) ] ) @@ -806,7 +800,7 @@ def test_unknown_node_label_passes_through(self): assert len(out.nodes) == 1 assert out.nodes[0].id == "u" - def test_drops_relationship_missing_required_attribute(self): + def test_undeclared_relationship_attribute_keys_are_stripped(self): from graphrag_sdk.core.models import PropertyType pipe = self._pipeline() schema = GraphSchema( @@ -814,7 +808,7 @@ def test_drops_relationship_missing_required_attribute(self): relations=[ RelationType( label="WORKS_AT", - properties=[PropertyType(name="since", type="DATE", required=True)], + properties=[PropertyType(name="since", type="DATE")], ) ], ) @@ -823,15 +817,19 @@ def test_drops_relationship_missing_required_attribute(self): start_node_id="a", end_node_id="b", type="RELATES", - properties={"rel_type": "WORKS_AT", "fact": "f1", "since": "2020-01-01"}, - ), - GraphRelationship( - start_node_id="a", - end_node_id="c", - type="RELATES", - properties={"rel_type": "WORKS_AT", "fact": "f2"}, + properties={ + "rel_type": "WORKS_AT", + "fact": "f1", + "since": "2020-01-01", + "leaked_key": "noise", + }, ), ] out = pipe._validate_attributes(GraphData(relationships=rels), schema) - kept = [r.properties.get("fact") for r in out.relationships] - assert kept == ["f1"] + assert len(out.relationships) == 1 + props = out.relationships[0].properties + assert "leaked_key" not in props + assert props["since"] == "2020-01-01" + # Reserved keys still preserved. + assert props["rel_type"] == "WORKS_AT" + assert props["fact"] == "f1" From 2a88b2917f3b3395b803507f4a2d62d6b761911a Mon Sep 17 00:00:00 2001 From: Gal Shubeli Date: Tue, 19 May 2026 15:58:03 +0300 Subject: [PATCH 6/7] refactor(ontology): restore ontology graph as anchor with validation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The data graph alone wasn't enough to anchor the schema: - types are lossy (FalkorDB `typeof()` collapses DATE → "string") - multiple processes/instances couldn't share a declarative view - nothing caught schema-typo contradictions across sessions - declarative metadata (descriptions, future flags) had no canonical home Restore the persistent ontology graph (``__ontology``) as the single source of truth, with additive-only semantics: - OntologyStore.register() validates incoming schema against the persisted ontology. Re-typing an existing property raises OntologyContradictionError before any partial state is persisted; additions (new labels, properties, patterns) go through unchanged. - OntologyStore.load() / clear() round out the lifecycle. - GraphRAG._ensure_ontology_initialized() lazy first-touch loads the persisted ontology and registers self.schema. Called from ingest (so contradictions surface before expensive extraction) and from get_ontology / retrieval paths. - delete_all() now drops both data and ontology graphs and resets the initialised flag so the next ingest re-registers self.schema cleanly. - get_ontology() always reads from the ontology graph; refresh_ontology() is a thin re-load for cross-process freshness. save_ontology(path) remains as a JSON-file bridge. This reverts the architecture of 54e4e05 (drop ontology graph, infer from data). The "save a graph" simplicity wasn't worth losing validation, multi-process safety, type fidelity, and a queryable schema artifact. Tests: 825 passed, 23 skipped. Co-Authored-By: Claude Opus 4.7 (1M context) --- graphrag_sdk/src/graphrag_sdk/api/main.py | 117 +++-- .../graphrag_sdk/storage/ontology_store.py | 450 ++++++++++-------- graphrag_sdk/tests/test_ontology_store.py | 418 ++++++++++------ 3 files changed, 601 insertions(+), 384 deletions(-) diff --git a/graphrag_sdk/src/graphrag_sdk/api/main.py b/graphrag_sdk/src/graphrag_sdk/api/main.py index b50d8e14..cc21f838 100644 --- a/graphrag_sdk/src/graphrag_sdk/api/main.py +++ b/graphrag_sdk/src/graphrag_sdk/api/main.py @@ -195,14 +195,14 @@ def __init__( # Deduplication engine self._deduplicator = EntityDeduplicator(self._graph_store, self.embedder) - # Ontology is inferred from the live data graph (no separate - # persistent graph). User-supplied ``schema`` is merged on top at - # retrieval time so declared descriptions / required flags / not-yet- - # extracted properties survive. - self._ontology_store = OntologyStore(self._conn) - # Global ontology used at retrieval time. Initially just the local - # schema; refresh_ontology() merges in the inferred view on demand - # and after each ingest. + # Persistent ontology graph (``__ontology``). Always-on, + # always the anchor: ``self.schema`` is registered into it on first + # connection, and ``get_ontology()`` / retrieval always read from it. + self._ontology_store = OntologyStore(self._conn, self._conn.config.graph_name) + # Lazy-init flag; the first async call that needs the schema fires + # ``_ensure_ontology_initialized()`` to load + register the user's schema. + self._ontology_initialized = False + # Working schema used by retrieval; populated by ``_ensure_ontology_initialized()``. self._global_schema: GraphSchema = self.schema # Default retrieval strategy @@ -247,44 +247,58 @@ async def close(self) -> None: # ── Ontology ───────────────────────────────────────────────── - async def get_ontology(self) -> GraphSchema: - """Return the **global** ontology used at retrieval time. - - The global ontology is the union of: - - the ontology inferred from the live data graph - (``db.labels()`` / ``db.relationshipTypes()`` + property sampling), and - - the user-supplied ``schema`` passed to :py:class:`GraphRAG` - (carries descriptions, ``required`` flags, properties not yet - present in the data). + async def _ensure_ontology_initialized(self) -> None: + """Lazy first-touch: load the persisted ontology and register the + user-supplied :py:attr:`schema` into it. - Returns an empty schema when both are empty. + Idempotent. The first async call that touches the ontology (ingest, + get_ontology, retrieval) fires this; subsequent calls are no-ops. + Raises :py:class:`OntologyContradictionError` if ``self.schema`` + re-defines an existing property's type. """ - inferred = await self._ontology_store.infer() - if inferred.entities or inferred.relations: - return inferred.merge(self.schema) - return self.schema + if self._ontology_initialized: + return + loaded = await self._ontology_store.load() + if self.schema.entities or self.schema.relations: + self._global_schema = await self._ontology_store.register(self.schema) + else: + self._global_schema = loaded + if hasattr(self._retrieval_strategy, "_schema"): + self._retrieval_strategy._schema = self._global_schema + self._ontology_initialized = True - async def refresh_ontology(self) -> GraphSchema: - """Recompute the global ontology and propagate it to the retrieval path. + async def get_ontology(self) -> GraphSchema: + """Return the persisted global ontology. - Called automatically after each ``ingest()``. Call explicitly when - the data graph has changed under your feet (concurrent writer, - manual edits) and the next retrieval should see it. + Reads from the ontology graph — the single source of truth. On the + first call, also registers any ``schema`` passed to :py:class:`GraphRAG` + into the ontology graph (validating no type contradictions with what's + already persisted). """ - self._global_schema = await self.get_ontology() - # Best-effort propagation to the retrieval strategy. Built-in - # MultiPathRetrieval honours this; custom strategies opt in by exposing - # a ``_schema`` attribute or accepting it via constructor. + await self._ensure_ontology_initialized() + # Always reflect the latest persisted state in case another writer + # has registered new types since our last ensure_initialized. + loaded = await self._ontology_store.load() + self._global_schema = loaded if hasattr(self._retrieval_strategy, "_schema"): self._retrieval_strategy._schema = self._global_schema return self._global_schema + async def refresh_ontology(self) -> GraphSchema: + """Reload the global ontology and propagate it to the retrieval path. + + Call explicitly when another process has registered new schema and + you want the next retrieval to see it without re-ingesting first. + """ + return await self.get_ontology() + async def save_ontology(self, path: str, *, indent: int = 2) -> None: """Write the current global ontology to ``path`` as JSON. - Convenience for the schema-as-config workflow: ``rag.save_ontology( - "ontology.json")``, hand-edit / version-control it, then load with - ``GraphSchema.from_file("ontology.json")`` on the next run. + Bridges the persisted ontology graph to a versionable JSON artifact: + ``rag.save_ontology("ontology.json")``, hand-edit / version-control + it, then load with ``GraphSchema.from_file("ontology.json")`` on the + next run. The ontology graph remains the canonical copy. """ ontology = await self.get_ontology() ontology.save_to_file(path, indent=indent) @@ -300,20 +314,29 @@ async def get_statistics(self) -> dict[str, Any]: return await self._graph_store.get_statistics() async def delete_all(self) -> None: - """Drop the entire knowledge graph. + """Drop the entire knowledge graph (data + ontology). - Irreversible. Removes all nodes, relationships, and indexes - managed by this ``GraphRAG`` instance. Also invalidates the - cached config and index flags so a follow-up ``ingest()`` on - the same instance re-runs validation and re-creates indexes - instead of trusting stale state. + Irreversible. Removes all nodes, relationships, and indexes managed + by this ``GraphRAG`` instance, plus the paired ontology graph at + ``__ontology``. Also invalidates the cached config and + index flags so a follow-up ``ingest()`` on the same instance re-runs + validation, re-creates indexes, and re-registers the user's schema. """ await self._graph_store.delete_all() + # Drop the ontology graph alongside the data graph so the two never + # outlive each other; a fresh ingest re-registers self.schema from + # scratch via _ensure_ontology_initialized(). + try: + await self._ontology_store.clear() + except Exception as exc: + logger.warning("Ontology graph clear failed during delete_all (continuing): %s", exc) # Indexes were dropped along with the graph; force re-creation # on the next ensure_indices() call. self._vector_store._indices_ensured = False # The __GraphRAGConfig__ node is gone too; re-validate next time. self._config_validated = False + # Force re-registration of self.schema next call. + self._ontology_initialized = False # ── Ingestion ──────────────────────────────────────────────── @@ -588,6 +611,11 @@ async def _ingest_single( doc_info = DocumentInfo(uid=resolved_id, path=path_for_node) + # Load the persisted ontology and register self.schema before the + # pipeline runs — contradictions surface early, before any expensive + # extraction work is done. + await self._ensure_ontology_initialized() + pipeline = IngestionPipeline( loader=loader or TextLoader(), chunker=chunker or SentenceTokenCapChunking(), @@ -595,18 +623,11 @@ async def _ingest_single( resolver=resolver or ExactMatchResolution(), graph_store=self._graph_store, vector_store=self._vector_store, - schema=self.schema, + schema=self._global_schema, ) result = await pipeline.run(source, ctx, text=text, document_info=doc_info) - # Refresh the global ontology so the next retrieval call sees any - # new properties declared by this run. - try: - await self.refresh_ontology() - except Exception as exc: - logger.warning("Ontology refresh failed (continuing): %s", exc) - if not _skip_post: # Post-ingestion: create indices only. # backfill_entity_embeddings() is intentionally NOT called here — diff --git a/graphrag_sdk/src/graphrag_sdk/storage/ontology_store.py b/graphrag_sdk/src/graphrag_sdk/storage/ontology_store.py index 30835086..81342639 100644 --- a/graphrag_sdk/src/graphrag_sdk/storage/ontology_store.py +++ b/graphrag_sdk/src/graphrag_sdk/storage/ontology_store.py @@ -1,18 +1,22 @@ -"""Ontology inference from the live data graph. - -The schema is **derived** from what's in the data graph, not maintained in a -separate persistent graph. This keeps the architecture honest: the source of -truth for "what entities and relations exist" is the data itself. - -Two consumers: -- Retrieval reads the inferred schema each session to build the Cypher prompt. -- ``GraphRAG.get_ontology()`` returns it for inspection. - -Users who want a curated, declarative schema (descriptions, not-yet-extracted -properties) pass a ``local_schema`` to ``GraphRAG`` — it's unioned with the -inferred schema at retrieval time so declared metadata survives. -``GraphSchema.save_to_file`` / ``GraphSchema.from_file`` cover the -schema-as-config workflow. +"""Persistent ontology storage in a dedicated FalkorDB graph. + +The ontology lives in a separate FalkorDB graph named ``__ontology`` +and is the **anchor** for the working schema: + +- Always-on: every :py:class:`GraphRAG` has exactly one ontology graph, + created lazily on first use, dropped on ``delete_all()``. +- Single source of truth: retrieval, ``get_ontology()``, and any cross-process + worker all read from the same graph. +- Additive only: :py:meth:`register` validates incoming schema against what's + already persisted and refuses **type contradictions** on existing properties. + New entity types, relation types, properties, and relation patterns are all + welcome. Re-typing an existing property is not. + +Users who want a curated, declarative schema (descriptions, future flags, +properties not yet observed in the data) supply a ``schema`` to ``GraphRAG``; +it gets registered into the ontology graph on first connection. JSON +import/export via :py:meth:`GraphSchema.save_to_file` / ``from_file`` is a +review / version-control bridge — the ontology graph is the canonical copy. """ from __future__ import annotations @@ -22,7 +26,6 @@ from graphrag_sdk.core.connection import FalkorDBConnection from graphrag_sdk.core.models import ( - RESERVED_PROPERTY_NAMES, EntityType, GraphSchema, PropertyType, @@ -32,209 +35,270 @@ logger = logging.getLogger(__name__) -# Labels created by the SDK that are not user entities. -_STRUCTURAL_LABELS: frozenset[str] = frozenset({"Chunk", "Document", "__Entity__"}) +class OntologyContradictionError(ValueError): + """Raised when an incoming schema redefines an existing property's type. -# Edge labels created by the SDK that are not user relations. -_STRUCTURAL_REL_TYPES: frozenset[str] = frozenset({"PART_OF", "NEXT_CHUNK", "MENTIONED_IN"}) + The ontology is additive: new labels, properties, and relation patterns are + welcome, but re-typing a property already registered on a label is + explicitly rejected so downstream Cypher queries don't break silently. + """ -# Property keys we never want to expose to the LLM as "custom attributes". -# These are SDK-internal or reserved meanings; the Cypher prompt already -# emits the reserved ones it cares about (``name``, ``description``, etc.). -_INFER_SKIP_KEYS: frozenset[str] = RESERVED_PROPERTY_NAMES | frozenset( - {"content_hash", "path", "text", "uid", "index", "metadata", "embedding"} -) +def _encode_patterns(patterns: list[tuple[str, str]]) -> list[str]: + return [f"{src}|{tgt}" for src, tgt in patterns] -# FalkorDB ``typeof()`` returns lowercase strings; map to our PropertyType vocabulary. -_TYPE_MAP: dict[str, str] = { - "string": "STRING", - "integer": "INTEGER", - "double": "FLOAT", - "float": "FLOAT", - "boolean": "BOOLEAN", - "array": "LIST", - "list": "LIST", -} +def _decode_patterns(encoded: list[str] | None) -> list[tuple[str, str]]: + out: list[tuple[str, str]] = [] + for s in encoded or []: + if not isinstance(s, str) or "|" not in s: + continue + src, tgt = s.split("|", 1) + out.append((src, tgt)) + return out -def _normalize_type(raw: str | None) -> str | None: - if not raw: - return None - return _TYPE_MAP.get(raw.strip().lower()) + +def _props_from_rows(rows: list[Any] | None) -> list[PropertyType]: + """Reconstruct PropertyType objects from a ``collect(...)`` query result. + + Filters out the null-keyed dict FalkorDB returns for an OPTIONAL MATCH + with no matches. + """ + out: list[PropertyType] = [] + for row in rows or []: + if not row or not isinstance(row, dict): + continue + name = row.get("name") + if not name: + continue + out.append( + PropertyType( + name=name, + type=row.get("type") or "STRING", + description=row.get("description"), + ) + ) + return out class OntologyStore: - """Infers the working :py:class:`GraphSchema` from the data graph. + """Persists and loads :py:class:`GraphSchema` in a dedicated FalkorDB graph. - No separate FalkorDB graph: this introspects the data graph directly via - ``db.labels()`` / ``db.relationshipTypes()`` plus per-label sampling of - property keys and types. + Owns its own graph handle, derived from the data-graph connection's + driver. Queries go directly to the FalkorDB driver and bypass the + connection's retry / circuit-breaker — ontology operations are + infrequent, idempotent, and tolerant of a single failure. """ - def __init__(self, connection: FalkorDBConnection) -> None: - self._conn = connection - - async def infer(self, *, sample_size: int = 200) -> GraphSchema: - """Build a :py:class:`GraphSchema` from what is currently in the data graph. + ONTOLOGY_GRAPH_SUFFIX = "__ontology" - ``sample_size`` caps the per-label scan used to discover property keys - and types. Returns an empty schema on any introspection failure. + def __init__(self, connection: FalkorDBConnection, data_graph_name: str) -> None: + self._conn = connection + self._graph_name = f"{data_graph_name}{self.ONTOLOGY_GRAPH_SUFFIX}" + self._graph: Any | None = None + + @property + def graph_name(self) -> str: + return self._graph_name + + def _ensure_graph(self) -> Any: + if self._graph is not None: + return self._graph + self._conn._ensure_client() + driver = self._conn._driver + if driver is None: + raise RuntimeError("FalkorDB driver not initialised on connection") + self._graph = driver.select_graph(self._graph_name) + return self._graph + + async def _query(self, cypher: str, params: dict[str, Any] | None = None) -> Any: + graph = self._ensure_graph() + return await graph.query(cypher, params=params) + + # ── Load ───────────────────────────────────────────────────── + + async def load(self) -> GraphSchema: + """Read the ontology graph and reconstruct a :py:class:`GraphSchema`. + + Returns an empty schema if the ontology graph does not yet exist or + introspection fails. Failure is logged at DEBUG so we don't spam an + unconfigured GraphRAG instance with warnings. """ try: - labels_result = await self._conn.query("CALL db.labels()") - rel_types_result = await self._conn.query("CALL db.relationshipTypes()") + ent_result = await self._query( + "MATCH (e:OntologyEntityType) " + "OPTIONAL MATCH (e)-[:HAS_PROPERTY]->(p:OntologyProperty) " + "RETURN e.label AS label, e.description AS description, " + "collect({name: p.name, type: p.type, description: p.description}) AS properties" + ) + rel_result = await self._query( + "MATCH (r:OntologyRelationType) " + "OPTIONAL MATCH (r)-[:HAS_PROPERTY]->(p:OntologyProperty) " + "RETURN r.label AS label, r.description AS description, " + "r.patterns AS patterns, " + "collect({name: p.name, type: p.type, description: p.description}) AS properties" + ) except Exception as exc: - logger.debug("Ontology inference: labels/types query failed: %s", exc) + logger.debug("Ontology load failed (returning empty schema): %s", exc) return GraphSchema() - labels: list[str] = [ - row[0] - for row in (labels_result.result_set or []) - if row and row[0] and row[0] not in _STRUCTURAL_LABELS - ] - rel_types: list[str] = [ - row[0] - for row in (rel_types_result.result_set or []) - if row and row[0] and row[0] not in _STRUCTURAL_REL_TYPES - ] + ent_rows = getattr(ent_result, "result_set", None) or [] + rel_rows = getattr(rel_result, "result_set", None) or [] + if not isinstance(ent_rows, list): + ent_rows = [] + if not isinstance(rel_rows, list): + rel_rows = [] entities = [ EntityType( - label=label, - properties=await self._properties_for_node(label, sample_size), + label=row[0], + description=row[1], + properties=_props_from_rows(row[2]), ) - for label in labels + for row in ent_rows + if isinstance(row, list) and len(row) >= 3 and row[0] ] - # The unified data model writes every user relation as a ``RELATES`` - # edge whose ``rel_type`` property carries the original label; the - # SDK's structural edges (PART_OF/NEXT_CHUNK/MENTIONED_IN) are excluded - # above. We surface the distinct ``rel_type`` values as RelationTypes - # so the Cypher prompt knows the allowed values, and expose their - # property keys + endpoint patterns. - relations: list[RelationType] = [] - if "RELATES" in rel_types: - relations = await self._infer_relates_subtypes(sample_size) - - return GraphSchema(entities=entities, relations=relations) - - async def _properties_for_node(self, label: str, sample_size: int) -> list[PropertyType]: - try: - result = await self._conn.query( - f"MATCH (n:`{label}`) " - "WITH n LIMIT $limit " - "UNWIND keys(n) AS k " - "WITH k, typeof(n[k]) AS t " - "RETURN k AS key, t AS type, count(*) AS c " - "ORDER BY c DESC", - {"limit": sample_size}, + relations = [ + RelationType( + label=row[0], + description=row[1], + patterns=_decode_patterns(row[2]), + properties=_props_from_rows(row[3]), ) - except Exception as exc: - logger.debug("Ontology inference: properties query failed for %s: %s", label, exc) - return [] - return _props_from_rows(result.result_set) + for row in rel_rows + if isinstance(row, list) and len(row) >= 4 and row[0] + ] + return GraphSchema(entities=entities, relations=relations) - async def _infer_relates_subtypes(self, sample_size: int) -> list[RelationType]: - """Group ``RELATES`` edges by ``rel_type`` and infer per-subtype properties.""" - try: - subtypes_result = await self._conn.query( - "MATCH ()-[r:RELATES]->() " - "WITH r LIMIT $limit " - "WITH DISTINCT r.rel_type AS rel_type " - "WHERE rel_type IS NOT NULL " - "RETURN rel_type", - {"limit": sample_size * 5}, # broader pool to capture rare subtypes - ) - except Exception as exc: - logger.debug("Ontology inference: RELATES subtypes query failed: %s", exc) - return [] - - relations: list[RelationType] = [] - for row in subtypes_result.result_set or []: - subtype = row[0] - if not subtype: - continue - properties = await self._properties_for_relates_subtype(subtype, sample_size) - patterns = await self._patterns_for_relates_subtype(subtype) - relations.append(RelationType(label=subtype, patterns=patterns, properties=properties)) - return relations - - async def _properties_for_relates_subtype( - self, subtype: str, sample_size: int - ) -> list[PropertyType]: - try: - result = await self._conn.query( - "MATCH ()-[r:RELATES {rel_type: $sub}]->() " - "WITH r LIMIT $limit " - "UNWIND keys(r) AS k " - "WITH k, typeof(r[k]) AS t " - "RETURN k AS key, t AS type, count(*) AS c " - "ORDER BY c DESC", - {"sub": subtype, "limit": sample_size}, - ) - except Exception as exc: - logger.debug( - "Ontology inference: relation properties query failed for %s: %s", - subtype, - exc, - ) - return [] - return _props_from_rows(result.result_set) + # ── Register ───────────────────────────────────────────────── - async def _patterns_for_relates_subtype(self, subtype: str) -> list[tuple[str, str]]: - try: - result = await self._conn.query( - "MATCH (a)-[r:RELATES {rel_type: $sub}]->(b) " - "WITH labels(a) AS la, labels(b) AS lb " - "RETURN DISTINCT la, lb LIMIT 25", - {"sub": subtype}, - ) - except Exception as exc: - logger.debug( - "Ontology inference: endpoint patterns query failed for %s: %s", - subtype, - exc, - ) - return [] - patterns: list[tuple[str, str]] = [] - seen: set[tuple[str, str]] = set() - for row in result.result_set or []: - src_labels, tgt_labels = row[0] or [], row[1] or [] - src = next( - (lbl for lbl in src_labels if lbl not in _STRUCTURAL_LABELS), - None, - ) - tgt = next( - (lbl for lbl in tgt_labels if lbl not in _STRUCTURAL_LABELS), - None, - ) - if not src or not tgt: - continue - key = (src, tgt) - if key in seen: - continue - seen.add(key) - patterns.append(key) - return patterns + async def register(self, schema: GraphSchema) -> GraphSchema: + """Merge ``schema`` into the persisted ontology and return the union. + Validates first: if ``schema`` redefines the type of a property + already registered on the same entity/relation label, raises + :py:class:`OntologyContradictionError` before any partial state is + persisted. -def _props_from_rows(rows: list[list[Any]] | None) -> list[PropertyType]: - """Turn ``(key, typeof, count)`` rows into :py:class:`PropertyType` objects. + Additive operations — new entity types, new relations, new properties, + new relation patterns — go through unchanged. + """ + if not schema.entities and not schema.relations: + return await self.load() + + existing = await self.load() + self._check_no_contradictions(existing, schema) + + for et in schema.entities: + await self._upsert_entity_type(et) + for rt in schema.relations: + await self._upsert_relation_type(rt) + + return await self.load() + + @staticmethod + def _check_no_contradictions(existing: GraphSchema, incoming: GraphSchema) -> None: + """Raise :py:class:`OntologyContradictionError` on any type re-declaration.""" + existing_ent_types: dict[tuple[str, str], str] = { + (e.label, p.name): p.type for e in existing.entities for p in e.properties + } + for et in incoming.entities: + for p in et.properties: + prior = existing_ent_types.get((et.label, p.name)) + if prior is not None and prior != p.type: + raise OntologyContradictionError( + f"Property '{et.label}.{p.name}' is already registered as " + f"{prior}; refusing to redefine as {p.type}. The ontology " + f"is additive — drop the data graph and start fresh if you " + f"need to change a property's type." + ) + + existing_rel_types: dict[tuple[str, str], str] = { + (r.label, p.name): p.type for r in existing.relations for p in r.properties + } + for rt in incoming.relations: + for p in rt.properties: + prior = existing_rel_types.get((rt.label, p.name)) + if prior is not None and prior != p.type: + raise OntologyContradictionError( + f"Property '{rt.label}.{p.name}' (on relation) is already " + f"registered as {prior}; refusing to redefine as {p.type}." + ) + + async def _upsert_entity_type(self, et: EntityType) -> None: + await self._query( + "MERGE (e:OntologyEntityType {label: $label}) " + "SET e.description = coalesce($description, e.description)", + {"label": et.label, "description": et.description}, + ) + for prop in et.properties: + await self._upsert_property(et.label, prop, owner_label="OntologyEntityType") + + async def _upsert_relation_type(self, rt: RelationType) -> None: + new_patterns = _encode_patterns(rt.patterns) + result = await self._query( + "MATCH (r:OntologyRelationType {label: $label}) RETURN r.patterns AS patterns", + {"label": rt.label}, + ) + existing_patterns: list[str] = [] + rows = getattr(result, "result_set", None) or [] + if isinstance(rows, list) and rows and isinstance(rows[0], list) and rows[0]: + existing_patterns = list(rows[0][0] or []) + seen: set[str] = set() + merged: list[str] = [] + for s in existing_patterns + new_patterns: + if s not in seen: + seen.add(s) + merged.append(s) + await self._query( + "MERGE (r:OntologyRelationType {label: $label}) " + "SET r.description = coalesce($description, r.description), " + "r.patterns = $patterns", + {"label": rt.label, "description": rt.description, "patterns": merged}, + ) + for prop in rt.properties: + await self._upsert_property(rt.label, prop, owner_label="OntologyRelationType") + + async def _upsert_property( + self, owner_label_value: str, prop: PropertyType, *, owner_label: str + ) -> None: + # Property nodes are keyed by ``(owner_label_kind, owner_label, name)`` + # so two different types can declare the same property name without + # trampling each other's metadata. + owner_alias = "ent" if owner_label == "OntologyEntityType" else "rel" + await self._query( + f"MATCH ({owner_alias}:{owner_label} {{label: $owner}}) " + f"MERGE ({owner_alias})-[:HAS_PROPERTY]->" + f"(p:OntologyProperty {{name: $name, owner: $owner_kind, owner_label: $owner}}) " + "SET p.type = $type, " + "p.description = coalesce($description, p.description)", + { + "owner": owner_label_value, + "owner_kind": owner_label, + "name": prop.name, + "type": prop.type, + "description": prop.description, + }, + ) + + # ── Clear ──────────────────────────────────────────────────── + + async def clear(self) -> None: + """Drop the ontology graph (``GRAPH.DELETE``). Idempotent. + + Called from ``GraphRAG.delete_all()`` so the ontology graph never + outlives the data graph. + """ + self._conn._ensure_client() + from redis.asyncio import Redis - Skips reserved/system keys and unmappable types so they never leak into the - LLM-facing schema. - """ - out: list[PropertyType] = [] - seen: set[str] = set() - for row in rows or []: - if not row or len(row) < 2: - continue - key, raw_type = row[0], row[1] - if not isinstance(key, str) or key in _INFER_SKIP_KEYS or key in seen: - continue - normalized = _normalize_type(raw_type if isinstance(raw_type, str) else None) - if not normalized: - continue - seen.add(key) - out.append(PropertyType(name=key, type=normalized)) - return out + redis: Redis = Redis(connection_pool=self._conn._pool) + try: + await redis.execute_command("GRAPH.DELETE", self._graph_name) + except Exception as exc: + msg = str(exc).lower() + if "empty" in msg or "invalid" in msg or "key" in msg: + logger.debug("Ontology graph '%s' already empty", self._graph_name) + else: + raise + self._graph = None diff --git a/graphrag_sdk/tests/test_ontology_store.py b/graphrag_sdk/tests/test_ontology_store.py index cd466b45..560f4978 100644 --- a/graphrag_sdk/tests/test_ontology_store.py +++ b/graphrag_sdk/tests/test_ontology_store.py @@ -1,186 +1,318 @@ -"""Tests for storage/ontology_store.py — data-graph ontology inference.""" +"""Tests for storage/ontology_store.py — persistent ontology graph. + +The store talks to FalkorDB directly; unit tests here mock the graph handle +through ``FalkorDBConnection``'s ``_driver.select_graph()`` seam. Real-FalkorDB +exercise is left for the integration suite. +""" + from __future__ import annotations +from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphrag_sdk.core.models import ( + EntityType, + GraphSchema, + PropertyType, + RelationType, +) from graphrag_sdk.storage.ontology_store import ( + OntologyContradictionError, OntologyStore, - _normalize_type, + _decode_patterns, + _encode_patterns, _props_from_rows, ) -class _FakeResult: +class _FakeQueryResult: + """Stand-in for FalkorDB's QueryResult.""" + def __init__(self, rows): self.result_set = rows -def _make_connection(handler): - """Wrap an async function (cypher, params) -> _FakeResult into a fake - ``FalkorDBConnection`` that just routes ``query`` to ``handler``.""" - conn = MagicMock() - conn.query = handler - return conn +class _FakeGraph: + """In-memory async ``query()`` substitute. Records calls and serves canned + responses for the load + read-existing-patterns queries.""" + + def __init__(self): + self.calls: list[tuple[str, dict | None]] = [] + self._ent_rows: list[list] = [] + self._rel_rows: list[list] = [] + self._patterns_for_label: dict[str, list[str]] = {} + + def set_load_response(self, entity_rows, relation_rows): + self._ent_rows = entity_rows + self._rel_rows = relation_rows + + def set_existing_patterns(self, label: str, patterns: list[str]): + self._patterns_for_label[label] = patterns + + async def query(self, cypher: str, params: dict | None = None): + self.calls.append((cypher, params)) + if "MATCH (e:OntologyEntityType)" in cypher and "collect" in cypher: + return _FakeQueryResult(self._ent_rows) + if "MATCH (r:OntologyRelationType)" in cypher and "collect" in cypher: + return _FakeQueryResult(self._rel_rows) + if "MATCH (r:OntologyRelationType {label: $label})" in cypher and "patterns" in cypher: + label = (params or {}).get("label", "") + return _FakeQueryResult([[self._patterns_for_label.get(label, [])]]) + return _FakeQueryResult([]) + + +@pytest.fixture +def fake_graph(): + return _FakeGraph() + + +@pytest.fixture +def store_factory(fake_graph): + """Returns a callable producing an ``OntologyStore`` wired to ``fake_graph``.""" + + def _make(data_graph_name: str = "kg") -> OntologyStore: + conn = MagicMock() + conn._ensure_client = MagicMock() + conn._driver = SimpleNamespace(select_graph=MagicMock(return_value=fake_graph)) + return OntologyStore(conn, data_graph_name) + + return _make # ── small helpers ──────────────────────────────────────────────── -class TestNormalizeType: - @pytest.mark.parametrize( - "raw,expected", - [ - ("string", "STRING"), - ("STRING", "STRING"), - ("integer", "INTEGER"), - ("double", "FLOAT"), - ("float", "FLOAT"), - ("boolean", "BOOLEAN"), - ("array", "LIST"), - ("list", "LIST"), - ("point", None), - ("null", None), - (None, None), - ("", None), - ], - ) - def test_matrix(self, raw, expected): - assert _normalize_type(raw) == expected +class TestEncoders: + def test_roundtrip(self): + patterns = [("Person", "Company"), ("Person", "Organization")] + encoded = _encode_patterns(patterns) + assert encoded == ["Person|Company", "Person|Organization"] + assert _decode_patterns(encoded) == patterns + + def test_decode_tolerates_garbage(self): + assert _decode_patterns(None) == [] + assert _decode_patterns(["no-pipe", "a|b"]) == [("a", "b")] class TestPropsFromRows: - def test_skips_reserved_and_unknown_types(self): + def test_filters_null_keyed_rows(self): rows = [ - ["age", "integer", 5], - ["name", "string", 10], # reserved - ["source_chunk_ids", "array", 5], # reserved - ["weird", "point", 1], # unmapped type - ["email", "string", 8], - ["age", "string", 1], # duplicate -> first wins - ] - out = _props_from_rows(rows) - assert [(p.name, p.type) for p in out] == [ - ("age", "INTEGER"), - ("email", "STRING"), + {"name": "age", "type": "INTEGER", "description": None}, + {"name": None, "type": None, "description": None}, # optional-match empty + None, + "not a dict", ] + result = _props_from_rows(rows) + assert [(p.name, p.type) for p in result] == [("age", "INTEGER")] + - def test_empty_input(self): - assert _props_from_rows(None) == [] - assert _props_from_rows([]) == [] +# ── store identity ─────────────────────────────────────────────── -# ── infer() end-to-end (mocked driver) ─────────────────────────── +class TestOntologyStoreGraphName: + def test_suffix(self, store_factory): + store = store_factory("my_kg") + assert store.graph_name == "my_kg__ontology" -class TestOntologyStoreInfer: +# ── register / load ────────────────────────────────────────────── + + +class TestRegister: @pytest.mark.asyncio - async def test_filters_structural_labels_and_edge_types(self): - calls = [] - - async def handler(cypher, params=None): - calls.append(cypher) - if "db.labels()" in cypher: - return _FakeResult( - [["Person"], ["Chunk"], ["Document"], ["__Entity__"]] - ) - if "db.relationshipTypes()" in cypher: - return _FakeResult( - [["PART_OF"], ["NEXT_CHUNK"], ["MENTIONED_IN"]] - ) - return _FakeResult([]) - - store = OntologyStore(_make_connection(handler)) - schema = await store.infer() - assert [e.label for e in schema.entities] == ["Person"] - assert schema.relations == [] + async def test_empty_schema_short_circuits_to_load(self, store_factory, fake_graph): + store = store_factory() + result = await store.register(GraphSchema()) + # No upsert queries — only the two load queries. + upserts = [c for c in fake_graph.calls if "MERGE" in c[0]] + assert upserts == [] + assert isinstance(result, GraphSchema) @pytest.mark.asyncio - async def test_relates_subtypes_are_surfaced_with_patterns_and_props(self): - async def handler(cypher, params=None): - if "db.labels()" in cypher: - return _FakeResult([["Person"], ["Company"], ["Location"]]) - if "db.relationshipTypes()" in cypher: - return _FakeResult([["RELATES"], ["MENTIONED_IN"]]) - if "MATCH (n:`Person`)" in cypher: - return _FakeResult([["age", "integer", 3], ["name", "string", 3]]) - if "MATCH (n:`Company`)" in cypher or "MATCH (n:`Location`)" in cypher: - return _FakeResult([]) - if "DISTINCT r.rel_type" in cypher: - return _FakeResult([["WORKS_AT"], ["LOCATED_IN"]]) - if "labels(a)" in cypher: - sub = (params or {}).get("sub", "") - if sub == "WORKS_AT": - return _FakeResult( - [[["Person"], ["Company", "__Entity__"]]] - ) - if sub == "LOCATED_IN": - return _FakeResult( - [[["Person"], ["Location", "__Entity__"]]] - ) - if "UNWIND keys(r)" in cypher: - sub = (params or {}).get("sub", "") - if sub == "WORKS_AT": - # rel_type is reserved and must be skipped from - # discovered properties. - return _FakeResult( - [["since", "string", 2], ["rel_type", "string", 2]] - ) - return _FakeResult([]) - return _FakeResult([]) - - store = OntologyStore(_make_connection(handler)) - schema = await store.infer() - - assert {e.label for e in schema.entities} == {"Person", "Company", "Location"} - person = next(e for e in schema.entities if e.label == "Person") - assert [(p.name, p.type) for p in person.properties] == [("age", "INTEGER")] + async def test_persists_entity_type_with_properties(self, store_factory, fake_graph): + store = store_factory() + schema = GraphSchema( + entities=[ + EntityType( + label="Person", + description="A human", + properties=[ + PropertyType(name="age", type="INTEGER"), + PropertyType(name="birth_date", type="DATE"), + ], + ), + ], + ) + await store.register(schema) + ent_merges = [c for c in fake_graph.calls if "MERGE (e:OntologyEntityType" in c[0]] + prop_merges = [c for c in fake_graph.calls if "MERGE (ent)-[:HAS_PROPERTY]->" in c[0]] + assert len(ent_merges) == 1 + assert len(prop_merges) == 2 + prop_names = {(c[1] or {}).get("name") for c in prop_merges} + assert prop_names == {"age", "birth_date"} + + @pytest.mark.asyncio + async def test_unions_relation_patterns(self, store_factory, fake_graph): + store = store_factory() + fake_graph.set_existing_patterns("WORKS_AT", ["Person|Company"]) + schema = GraphSchema( + entities=[EntityType(label="Person"), EntityType(label="Org")], + relations=[RelationType(label="WORKS_AT", patterns=[("Person", "Org")])], + ) + await store.register(schema) + rel_set_calls = [ + c + for c in fake_graph.calls + if "MERGE (r:OntologyRelationType {label: $label})" in c[0] and "SET r." in c[0] + ] + assert rel_set_calls, "expected SET on RelationType" + patterns = (rel_set_calls[-1][1] or {})["patterns"] + assert "Person|Company" in patterns + assert "Person|Org" in patterns - rel_by_label = {r.label: r for r in schema.relations} - assert set(rel_by_label) == {"WORKS_AT", "LOCATED_IN"} - # Patterns strip __Entity__ and pick the user-visible label. - assert rel_by_label["WORKS_AT"].patterns == [("Person", "Company")] - # Reserved rel property is suppressed from the inferred schema. - assert [ - (p.name, p.type) for p in rel_by_label["WORKS_AT"].properties - ] == [("since", "STRING")] - assert rel_by_label["LOCATED_IN"].properties == [] +class TestLoad: @pytest.mark.asyncio - async def test_endpoint_pattern_drops_structural_targets(self): - async def handler(cypher, params=None): - if "db.labels()" in cypher: - return _FakeResult([["Person"]]) - if "db.relationshipTypes()" in cypher: - return _FakeResult([["RELATES"]]) - if "MATCH (n:`Person`)" in cypher: - return _FakeResult([]) - if "DISTINCT r.rel_type" in cypher: - return _FakeResult([["KNOWS"]]) - if "labels(a)" in cypher: - # Stale data has both a real and a structural target. - return _FakeResult( - [ - [["Person"], ["Chunk"]], # structural target -> dropped - [["Person"], ["Person"]], # real -> kept - ] - ) - if "UNWIND keys(r)" in cypher: - return _FakeResult([]) - return _FakeResult([]) - - store = OntologyStore(_make_connection(handler)) - schema = await store.infer() - knows = next(r for r in schema.relations if r.label == "KNOWS") - assert knows.patterns == [("Person", "Person")] + async def test_empty_graph_yields_empty_schema(self, store_factory, fake_graph): + store = store_factory() + fake_graph.set_load_response([], []) + schema = await store.load() + assert schema.entities == [] + assert schema.relations == [] @pytest.mark.asyncio - async def test_introspection_failure_returns_empty_schema(self): - async def handler(cypher, params=None): + async def test_introspection_failure_returns_empty(self, store_factory, fake_graph): + async def boom(cypher, params=None): raise RuntimeError("connection blew up") - store = OntologyStore(_make_connection(handler)) - schema = await store.infer() + fake_graph.query = boom + store = store_factory() + schema = await store.load() assert schema.entities == [] assert schema.relations == [] + + @pytest.mark.asyncio + async def test_reconstructs_schema(self, store_factory, fake_graph): + store = store_factory() + fake_graph.set_load_response( + entity_rows=[ + [ + "Person", + "A human", + [ + {"name": "age", "type": "INTEGER", "description": None}, + # OPTIONAL MATCH empty row — must be filtered. + {"name": None, "type": None, "description": None}, + ], + ], + ["Company", None, []], + ], + relation_rows=[ + [ + "WORKS_AT", + "Employment", + ["Person|Company"], + [{"name": "since", "type": "DATE", "description": None}], + ], + ], + ) + schema = await store.load() + assert {e.label for e in schema.entities} == {"Person", "Company"} + person = next(e for e in schema.entities if e.label == "Person") + assert [(p.name, p.type) for p in person.properties] == [("age", "INTEGER")] + works = next(r for r in schema.relations if r.label == "WORKS_AT") + assert works.patterns == [("Person", "Company")] + assert [(p.name, p.type) for p in works.properties] == [("since", "DATE")] + + +# ── contradiction validation ───────────────────────────────────── + + +class TestContradictionDetection: + """Additive-only schema: re-typing an existing property is rejected.""" + + @pytest.mark.asyncio + async def test_compatible_addition_is_accepted(self, store_factory, fake_graph): + store = store_factory() + fake_graph.set_load_response( + entity_rows=[ + [ + "Person", + None, + [{"name": "age", "type": "INTEGER", "description": None}], + ], + ], + relation_rows=[], + ) + incoming = GraphSchema( + entities=[ + EntityType( + label="Person", + properties=[ + PropertyType(name="age", type="INTEGER"), # same as before + PropertyType(name="birth_date", type="DATE"), # new property + ], + ), + ], + ) + await store.register(incoming) # must not raise + + @pytest.mark.asyncio + async def test_redefining_entity_property_type_is_rejected( + self, store_factory, fake_graph + ): + store = store_factory() + fake_graph.set_load_response( + entity_rows=[ + [ + "Person", + None, + [{"name": "age", "type": "INTEGER", "description": None}], + ], + ], + relation_rows=[], + ) + incoming = GraphSchema( + entities=[ + EntityType( + label="Person", + properties=[PropertyType(name="age", type="STRING")], + ), + ], + ) + with pytest.raises(OntologyContradictionError) as exc: + await store.register(incoming) + assert "Person.age" in str(exc.value) + # And no MERGE-on-EntityType happened (validation runs before persistence). + upserts = [c for c in fake_graph.calls if "MERGE (e:OntologyEntityType" in c[0]] + assert upserts == [] + + @pytest.mark.asyncio + async def test_redefining_relation_property_type_is_rejected( + self, store_factory, fake_graph + ): + store = store_factory() + fake_graph.set_load_response( + entity_rows=[], + relation_rows=[ + [ + "WORKS_AT", + None, + [], + [{"name": "since", "type": "DATE", "description": None}], + ], + ], + ) + incoming = GraphSchema( + relations=[ + RelationType( + label="WORKS_AT", + properties=[PropertyType(name="since", type="STRING")], + ), + ], + ) + with pytest.raises(OntologyContradictionError) as exc: + await store.register(incoming) + assert "WORKS_AT.since" in str(exc.value) From ac0e1c4e8738669bcb10f2bfce098435c9b8d808 Mon Sep 17 00:00:00 2001 From: Gal Shubeli Date: Tue, 19 May 2026 18:53:06 +0300 Subject: [PATCH 7/7] feat(ontology): seed default schema when neither user nor graph supply one MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Before, a user passing no schema got an empty ontology graph even though extraction was using DEFAULT_ENTITY_TYPES under the hood. The anchor said "we know nothing" while the LLM was producing Person/Organization/... nodes. Now ``_ensure_ontology_initialized()`` has three branches: - user passed a schema → register it (validate + persist) - ontology graph already populated (prior session) → use as-is - both empty → register DEFAULT_ENTITY_TYPES so the ontology graph faithfully reflects what the extractor will produce After this, ``get_ontology()``, the Cypher-generation prompt, and the extractor all read from the same source. Test fixture: ``mock_conn`` in test_facade.py now stubs the connection's ``_driver.select_graph()`` chain so OntologyStore can open the ontology graph handle against the mock. 825 passed, 23 skipped. Co-Authored-By: Claude Opus 4.7 (1M context) --- graphrag_sdk/src/graphrag_sdk/api/main.py | 26 ++++++++++++++++++----- graphrag_sdk/tests/test_facade.py | 9 ++++++++ 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/graphrag_sdk/src/graphrag_sdk/api/main.py b/graphrag_sdk/src/graphrag_sdk/api/main.py index cc21f838..47ee910f 100644 --- a/graphrag_sdk/src/graphrag_sdk/api/main.py +++ b/graphrag_sdk/src/graphrag_sdk/api/main.py @@ -22,6 +22,7 @@ ChatMessage, DeleteDocumentResult, DocumentInfo, + EntityType, FinalizeResult, GraphSchema, IngestionResult, @@ -35,6 +36,9 @@ SentenceTokenCapChunking, ) from graphrag_sdk.ingestion.extraction_strategies.base import ExtractionStrategy +from graphrag_sdk.ingestion.extraction_strategies.entity_extractors import ( + DEFAULT_ENTITY_TYPES, +) from graphrag_sdk.ingestion.extraction_strategies.graph_extraction import GraphExtraction from graphrag_sdk.ingestion.loaders.base import LoaderStrategy from graphrag_sdk.ingestion.loaders.markdown_loader import MarkdownLoader @@ -251,18 +255,30 @@ async def _ensure_ontology_initialized(self) -> None: """Lazy first-touch: load the persisted ontology and register the user-supplied :py:attr:`schema` into it. - Idempotent. The first async call that touches the ontology (ingest, - get_ontology, retrieval) fires this; subsequent calls are no-ops. - Raises :py:class:`OntologyContradictionError` if ``self.schema`` - re-defines an existing property's type. + Three states: + + - ``self.schema`` is non-empty → register it (validate + persist). + - ``self.schema`` empty but the ontology graph already has content + (previous session, another writer) → use it as-is. + - Both empty (first connection, no user schema) → register the + built-in :py:data:`DEFAULT_ENTITY_TYPES` so the ontology graph + accurately reflects the labels the extractor will produce. + + Idempotent. Raises :py:class:`OntologyContradictionError` if + ``self.schema`` re-defines an existing property's type. """ if self._ontology_initialized: return loaded = await self._ontology_store.load() if self.schema.entities or self.schema.relations: self._global_schema = await self._ontology_store.register(self.schema) - else: + elif loaded.entities or loaded.relations: self._global_schema = loaded + else: + default_schema = GraphSchema( + entities=[EntityType(label=label) for label in DEFAULT_ENTITY_TYPES], + ) + self._global_schema = await self._ontology_store.register(default_schema) if hasattr(self._retrieval_strategy, "_schema"): self._retrieval_strategy._schema = self._global_schema self._ontology_initialized = True diff --git a/graphrag_sdk/tests/test_facade.py b/graphrag_sdk/tests/test_facade.py index 78072b4f..b5b943ea 100644 --- a/graphrag_sdk/tests/test_facade.py +++ b/graphrag_sdk/tests/test_facade.py @@ -35,6 +35,15 @@ def mock_conn(): result_mock.result_set = [] conn.query = AsyncMock(return_value=result_mock) conn.config = ConnectionConfig() + # OntologyStore reaches into the connection's driver to open a handle on + # the paired ``__ontology`` graph. Stub the chain so the + # ontology graph appears as a writable async-query target whose queries + # return empty result sets. + ontology_graph = MagicMock() + ontology_graph.query = AsyncMock(return_value=result_mock) + conn._driver = MagicMock() + conn._driver.select_graph = MagicMock(return_value=ontology_graph) + conn._ensure_client = MagicMock() return conn