diff --git a/graphrag_sdk/src/graphrag_sdk/api/main.py b/graphrag_sdk/src/graphrag_sdk/api/main.py index 684c7a2d..47ee910f 100644 --- a/graphrag_sdk/src/graphrag_sdk/api/main.py +++ b/graphrag_sdk/src/graphrag_sdk/api/main.py @@ -22,6 +22,7 @@ ChatMessage, DeleteDocumentResult, DocumentInfo, + EntityType, FinalizeResult, GraphSchema, IngestionResult, @@ -35,6 +36,9 @@ SentenceTokenCapChunking, ) from graphrag_sdk.ingestion.extraction_strategies.base import ExtractionStrategy +from graphrag_sdk.ingestion.extraction_strategies.entity_extractors import ( + DEFAULT_ENTITY_TYPES, +) from graphrag_sdk.ingestion.extraction_strategies.graph_extraction import GraphExtraction from graphrag_sdk.ingestion.loaders.base import LoaderStrategy from graphrag_sdk.ingestion.loaders.markdown_loader import MarkdownLoader @@ -48,6 +52,7 @@ from graphrag_sdk.retrieval.strategies.multi_path import MultiPathRetrieval from graphrag_sdk.storage.deduplicator import EntityDeduplicator from graphrag_sdk.storage.graph_store import GraphStore +from graphrag_sdk.storage.ontology_store import OntologyStore from graphrag_sdk.storage.vector_store import VectorStore logger = logging.getLogger(__name__) @@ -194,12 +199,23 @@ def __init__( # Deduplication engine self._deduplicator = EntityDeduplicator(self._graph_store, self.embedder) + # Persistent ontology graph (``__ontology``). Always-on, + # always the anchor: ``self.schema`` is registered into it on first + # connection, and ``get_ontology()`` / retrieval always read from it. + self._ontology_store = OntologyStore(self._conn, self._conn.config.graph_name) + # Lazy-init flag; the first async call that needs the schema fires + # ``_ensure_ontology_initialized()`` to load + register the user's schema. + self._ontology_initialized = False + # Working schema used by retrieval; populated by ``_ensure_ontology_initialized()``. + self._global_schema: GraphSchema = self.schema + # Default retrieval strategy self._retrieval_strategy = retrieval_strategy or MultiPathRetrieval( graph_store=self._graph_store, vector_store=self._vector_store, embedder=self.embedder, llm=self.llm, + schema=self._global_schema, ) # -- Async context manager ------------------------------------------- @@ -233,6 +249,76 @@ async def close(self) -> None: """Close the underlying database connection.""" await self._conn.close() + # ── Ontology ───────────────────────────────────────────────── + + async def _ensure_ontology_initialized(self) -> None: + """Lazy first-touch: load the persisted ontology and register the + user-supplied :py:attr:`schema` into it. + + Three states: + + - ``self.schema`` is non-empty → register it (validate + persist). + - ``self.schema`` empty but the ontology graph already has content + (previous session, another writer) → use it as-is. + - Both empty (first connection, no user schema) → register the + built-in :py:data:`DEFAULT_ENTITY_TYPES` so the ontology graph + accurately reflects the labels the extractor will produce. + + Idempotent. Raises :py:class:`OntologyContradictionError` if + ``self.schema`` re-defines an existing property's type. + """ + if self._ontology_initialized: + return + loaded = await self._ontology_store.load() + if self.schema.entities or self.schema.relations: + self._global_schema = await self._ontology_store.register(self.schema) + elif loaded.entities or loaded.relations: + self._global_schema = loaded + else: + default_schema = GraphSchema( + entities=[EntityType(label=label) for label in DEFAULT_ENTITY_TYPES], + ) + self._global_schema = await self._ontology_store.register(default_schema) + if hasattr(self._retrieval_strategy, "_schema"): + self._retrieval_strategy._schema = self._global_schema + self._ontology_initialized = True + + async def get_ontology(self) -> GraphSchema: + """Return the persisted global ontology. + + Reads from the ontology graph — the single source of truth. On the + first call, also registers any ``schema`` passed to :py:class:`GraphRAG` + into the ontology graph (validating no type contradictions with what's + already persisted). + """ + await self._ensure_ontology_initialized() + # Always reflect the latest persisted state in case another writer + # has registered new types since our last ensure_initialized. + loaded = await self._ontology_store.load() + self._global_schema = loaded + if hasattr(self._retrieval_strategy, "_schema"): + self._retrieval_strategy._schema = self._global_schema + return self._global_schema + + async def refresh_ontology(self) -> GraphSchema: + """Reload the global ontology and propagate it to the retrieval path. + + Call explicitly when another process has registered new schema and + you want the next retrieval to see it without re-ingesting first. + """ + return await self.get_ontology() + + async def save_ontology(self, path: str, *, indent: int = 2) -> None: + """Write the current global ontology to ``path`` as JSON. + + Bridges the persisted ontology graph to a versionable JSON artifact: + ``rag.save_ontology("ontology.json")``, hand-edit / version-control + it, then load with ``GraphSchema.from_file("ontology.json")`` on the + next run. The ontology graph remains the canonical copy. + """ + ontology = await self.get_ontology() + ontology.save_to_file(path, indent=indent) + # ── Graph admin ────────────────────────────────────────────── async def get_statistics(self) -> dict[str, Any]: @@ -244,20 +330,29 @@ async def get_statistics(self) -> dict[str, Any]: return await self._graph_store.get_statistics() async def delete_all(self) -> None: - """Drop the entire knowledge graph. + """Drop the entire knowledge graph (data + ontology). - Irreversible. Removes all nodes, relationships, and indexes - managed by this ``GraphRAG`` instance. Also invalidates the - cached config and index flags so a follow-up ``ingest()`` on - the same instance re-runs validation and re-creates indexes - instead of trusting stale state. + Irreversible. Removes all nodes, relationships, and indexes managed + by this ``GraphRAG`` instance, plus the paired ontology graph at + ``__ontology``. Also invalidates the cached config and + index flags so a follow-up ``ingest()`` on the same instance re-runs + validation, re-creates indexes, and re-registers the user's schema. """ await self._graph_store.delete_all() + # Drop the ontology graph alongside the data graph so the two never + # outlive each other; a fresh ingest re-registers self.schema from + # scratch via _ensure_ontology_initialized(). + try: + await self._ontology_store.clear() + except Exception as exc: + logger.warning("Ontology graph clear failed during delete_all (continuing): %s", exc) # Indexes were dropped along with the graph; force re-creation # on the next ensure_indices() call. self._vector_store._indices_ensured = False # The __GraphRAGConfig__ node is gone too; re-validate next time. self._config_validated = False + # Force re-registration of self.schema next call. + self._ontology_initialized = False # ── Ingestion ──────────────────────────────────────────────── @@ -532,6 +627,11 @@ async def _ingest_single( doc_info = DocumentInfo(uid=resolved_id, path=path_for_node) + # Load the persisted ontology and register self.schema before the + # pipeline runs — contradictions surface early, before any expensive + # extraction work is done. + await self._ensure_ontology_initialized() + pipeline = IngestionPipeline( loader=loader or TextLoader(), chunker=chunker or SentenceTokenCapChunking(), @@ -539,7 +639,7 @@ async def _ingest_single( resolver=resolver or ExactMatchResolution(), graph_store=self._graph_store, vector_store=self._vector_store, - schema=self.schema, + schema=self._global_schema, ) result = await pipeline.run(source, ctx, text=text, document_info=doc_info) diff --git a/graphrag_sdk/src/graphrag_sdk/core/models.py b/graphrag_sdk/src/graphrag_sdk/core/models.py index 8d88619b..cba79f84 100644 --- a/graphrag_sdk/src/graphrag_sdk/core/models.py +++ b/graphrag_sdk/src/graphrag_sdk/core/models.py @@ -136,13 +136,44 @@ class DocumentRecord(DataModel): # ── Schema Types ───────────────────────────────────────────────── +_PROPERTY_TYPES: frozenset[str] = frozenset( + {"STRING", "INTEGER", "FLOAT", "BOOLEAN", "DATE", "LIST"} +) + +# Property names reserved by the SDK; user schemas may not declare them. +RESERVED_PROPERTY_NAMES: frozenset[str] = frozenset( + { + "name", + "type", + "description", + "source_chunk_ids", + "spans", + "rel_type", + "fact", + "src_name", + "tgt_name", + "id", + } +) + + class PropertyType(DataModel): """A property definition for a node or relationship type.""" name: str type: str = "STRING" # STRING, INTEGER, FLOAT, BOOLEAN, DATE, LIST description: str | None = None - required: bool = False + + @model_validator(mode="after") + def _normalize_type(self) -> PropertyType: + normalized = (self.type or "STRING").strip().upper() + if normalized not in _PROPERTY_TYPES: + raise ValueError( + f"PropertyType '{self.name}' has unsupported type " + f"{self.type!r}. Allowed: {sorted(_PROPERTY_TYPES)}" + ) + self.type = normalized + return self class EntityType(DataModel): @@ -185,6 +216,7 @@ class RelationType(DataModel): label: str description: str | None = None patterns: list[tuple[str, str]] = Field(default_factory=list) + properties: list[PropertyType] = Field(default_factory=list) # Identity is by label only — two RelationType instances with the same # label but different patterns compare/hash equal. Schemas are expected @@ -209,30 +241,122 @@ class GraphSchema(DataModel): relations: list[RelationType] = Field(default_factory=list) @model_validator(mode="after") - def _warn_on_undeclared_pattern_labels(self) -> GraphSchema: - """Warn when a ``RelationType.patterns`` references undeclared entity labels. - - Catches typos like ``("Persn", "Company")`` at config time, before any - extraction has run. We warn rather than raise: open-schema setups may - legitimately reference labels not (yet) listed in ``entities``. + def _validate_schema(self) -> GraphSchema: + """Validate the schema at config time. + + - Rejects user-declared property names that collide with SDK-reserved + keys (``name``, ``description``, ``source_chunk_ids``, ``spans``, + ``rel_type``, ``fact``, ``src_name``, ``tgt_name``, ``id``, ``type``). + - Warns when a ``RelationType.patterns`` references undeclared entity + labels (typo guard) — open-schema setups may legitimately reference + labels not yet listed in ``entities``, so this is a warning, not an + error. """ - if not self.entities: - return self - declared = {e.label for e in self.entities} - for rel in self.relations: - for src, tgt in rel.patterns: - missing = [lbl for lbl in (src, tgt) if lbl not in declared] - if missing: - logger.warning( - "RelationType '%s' pattern (%s, %s) references " - "entity label(s) not declared in schema.entities: %s", - rel.label, - src, - tgt, - ", ".join(missing), + for et in self.entities: + for prop in et.properties: + if prop.name in RESERVED_PROPERTY_NAMES: + raise ValueError( + f"EntityType '{et.label}' declares property " + f"'{prop.name}', which is reserved by the SDK. " + f"Reserved names: {sorted(RESERVED_PROPERTY_NAMES)}" + ) + for rt in self.relations: + for prop in rt.properties: + if prop.name in RESERVED_PROPERTY_NAMES: + raise ValueError( + f"RelationType '{rt.label}' declares property " + f"'{prop.name}', which is reserved by the SDK. " + f"Reserved names: {sorted(RESERVED_PROPERTY_NAMES)}" ) + + if self.entities: + declared = {e.label for e in self.entities} + for rel in self.relations: + for src, tgt in rel.patterns: + missing = [lbl for lbl in (src, tgt) if lbl not in declared] + if missing: + logger.warning( + "RelationType '%s' pattern (%s, %s) references " + "entity label(s) not declared in schema.entities: %s", + rel.label, + src, + tgt, + ", ".join(missing), + ) return self + @classmethod + def from_file(cls, path: str) -> GraphSchema: + """Load a ``GraphSchema`` from a JSON file. + + The schema-as-config workflow: keep the canonical schema in a JSON + file under version control, load it into the SDK with one call. See + :py:meth:`save_to_file` for the reverse direction. + """ + from pathlib import Path + + return cls.model_validate_json(Path(path).read_text(encoding="utf-8")) + + def save_to_file(self, path: str, *, indent: int = 2) -> None: + """Write this schema to ``path`` as JSON (overwrites existing files).""" + from pathlib import Path + + Path(path).write_text(self.model_dump_json(indent=indent), encoding="utf-8") + + def merge(self, other: GraphSchema) -> GraphSchema: + """Return a new ``GraphSchema`` that is the union of ``self`` and ``other``. + + - Entity / relation types are unioned by ``label``. + - For each type, ``properties`` are unioned by ``name``. When the same + property name appears in both, the incoming type/description overrides + (last-write-wins). + - For relations, ``patterns`` are unioned (order-preserving, deduped). + """ + + def _merge_props( + existing: list[PropertyType], incoming: list[PropertyType] + ) -> list[PropertyType]: + by_name: dict[str, PropertyType] = {p.name: p for p in existing} + for p in incoming: + by_name[p.name] = p + return list(by_name.values()) + + ent_by_label: dict[str, EntityType] = {e.label: e for e in self.entities} + for e in other.entities: + if e.label in ent_by_label: + cur = ent_by_label[e.label] + ent_by_label[e.label] = EntityType( + label=cur.label, + description=e.description or cur.description, + properties=_merge_props(cur.properties, e.properties), + ) + else: + ent_by_label[e.label] = e + + rel_by_label: dict[str, RelationType] = {r.label: r for r in self.relations} + for r in other.relations: + if r.label in rel_by_label: + cur = rel_by_label[r.label] + seen: set[tuple[str, str]] = set() + merged_patterns: list[tuple[str, str]] = [] + for pat in list(cur.patterns) + list(r.patterns): + if pat not in seen: + seen.add(pat) + merged_patterns.append(pat) + rel_by_label[r.label] = RelationType( + label=cur.label, + description=r.description or cur.description, + patterns=merged_patterns, + properties=_merge_props(cur.properties, r.properties), + ) + else: + rel_by_label[r.label] = r + + return GraphSchema( + entities=list(ent_by_label.values()), + relations=list(rel_by_label.values()), + ) + # ── Extraction / Resolution Output Types ───────────────────────── @@ -254,6 +378,7 @@ class ExtractedEntity(DataModel): type: str description: str = "" source_chunk_ids: list[str] = Field(default_factory=list) + attributes: dict[str, Any] = Field(default_factory=dict) class ExtractedRelation(DataModel): @@ -266,6 +391,7 @@ class ExtractedRelation(DataModel): description: str = "" weight: float = 1.0 source_chunk_ids: list[str] = Field(default_factory=list) + attributes: dict[str, Any] = Field(default_factory=dict) class EntityMention(DataModel): diff --git a/graphrag_sdk/src/graphrag_sdk/ingestion/extraction_strategies/graph_extraction.py b/graphrag_sdk/src/graphrag_sdk/ingestion/extraction_strategies/graph_extraction.py index 09a56121..7ec67b4f 100644 --- a/graphrag_sdk/src/graphrag_sdk/ingestion/extraction_strategies/graph_extraction.py +++ b/graphrag_sdk/src/graphrag_sdk/ingestion/extraction_strategies/graph_extraction.py @@ -7,6 +7,7 @@ import asyncio import json import logging +from datetime import date, datetime from typing import Any from graphrag_sdk.core.context import Context @@ -18,6 +19,7 @@ GraphNode, GraphRelationship, GraphSchema, + PropertyType, RelationType, TextChunks, ) @@ -48,6 +50,7 @@ "## Entity Types\n" "{entity_types}\n\n" "{relation_patterns}" + "{attribute_schema_block}" "## Pre-extracted Entities\n" "{entities_json}\n\n" "## Text\n" @@ -74,13 +77,161 @@ "sentence for this relationship starts.\n" "- span_end: the character offset where the evidence sentence ends.\n\n" "Return ONLY a JSON object with two arrays:\n" + "{json_example}\n\n" + "Return ONLY valid JSON, nothing else." +) + +_DEFAULT_JSON_EXAMPLE = ( '{{"entities": [{{"name": "...", "type": "...", "description": "..."}}], ' '"relationships": [{{"source": "...", "target": "...", "type": "...", ' - '"description": "...", "span_start": 0, "span_end": 50}}]}}\n\n' - "Return ONLY valid JSON, nothing else." + '"description": "...", "span_start": 0, "span_end": 50}}]}}' +) + +_JSON_EXAMPLE_WITH_ATTRS = ( + '{{"entities": [{{"name": "...", "type": "...", "description": "...", ' + '"attributes": {{}}}}], ' + '"relationships": [{{"source": "...", "target": "...", "type": "...", ' + '"description": "...", "span_start": 0, "span_end": 50, ' + '"attributes": {{}}}}]}}' ) +def _format_property_for_prompt(prop: PropertyType) -> str: + desc = f" — {prop.description}" if prop.description else "" + return f" - {prop.name} ({prop.type}){desc}" + + +def _render_attribute_schema_block(schema: GraphSchema) -> str: + """Return the prompt block listing declared entity/relation attributes. + + Returns ``""`` when no entity or relation type declares any property — this + keeps the prompt identical to today for property-less schemas (no token + drift, no quality regression for the existing extraction path). + """ + ent_lines: list[str] = [] + for et in schema.entities: + if not et.properties: + continue + ent_lines.append(f"- {et.label}:") + ent_lines.extend(_format_property_for_prompt(p) for p in et.properties) + + rel_lines: list[str] = [] + for rt in schema.relations: + if not rt.properties: + continue + rel_lines.append(f"- {rt.label}:") + rel_lines.extend(_format_property_for_prompt(p) for p in rt.properties) + + if not ent_lines and not rel_lines: + return "" + + parts = [ + "## Attribute extraction", + "For each verified entity and each extracted relationship, also extract any " + "of the declared properties below. Use null when the property is not stated " + "in the text. Coerce to the declared type (INTEGER, FLOAT, BOOLEAN, DATE in " + "ISO 8601, STRING, LIST). Place extracted values inside an `attributes` " + "object on each entity / relationship.", + ] + if ent_lines: + parts.append("### Entity attributes") + parts.extend(ent_lines) + if rel_lines: + parts.append("### Relation attributes") + parts.extend(rel_lines) + return "\n".join(parts) + "\n\n" + + +def _schema_has_attributes(schema: GraphSchema) -> bool: + return any(et.properties for et in schema.entities) or any( + rt.properties for rt in schema.relations + ) + + +def _coerce_attribute_value(value: Any, prop_type: str) -> tuple[bool, Any]: + """Coerce ``value`` to ``prop_type``. Returns ``(ok, coerced)``. + + ``ok=False`` means the value should be dropped (incl. ``None``). + """ + if value is None: + return False, None + pt = prop_type.upper() + try: + if pt == "STRING": + s = str(value).strip() + return (True, s) if s else (False, None) + if pt == "INTEGER": + if isinstance(value, bool): + return False, None + return True, int(float(value)) + if pt == "FLOAT": + if isinstance(value, bool): + return False, None + return True, float(value) + if pt == "BOOLEAN": + if isinstance(value, bool): + return True, value + if isinstance(value, (int, float)): + return True, bool(value) + if isinstance(value, str): + v = value.strip().lower() + if v in {"true", "yes", "y", "1"}: + return True, True + if v in {"false", "no", "n", "0"}: + return True, False + return False, None + if pt == "DATE": + if isinstance(value, (date, datetime)): + return True, (value.date() if isinstance(value, datetime) else value).isoformat() + if isinstance(value, str): + v = value.strip() + if not v: + return False, None + # Accept either pure date or full ISO 8601 + try: + parsed = date.fromisoformat(v) + return True, parsed.isoformat() + except ValueError: + try: + parsed_dt = datetime.fromisoformat(v.replace("Z", "+00:00")) + return True, parsed_dt.date().isoformat() + except ValueError: + return False, None + return False, None + if pt == "LIST": + if isinstance(value, dict): + return False, None + if isinstance(value, (list, tuple, set)): + return True, list(value) + return True, [value] + except (TypeError, ValueError): + return False, None + return False, None + + +def _coerce_attributes( + raw: dict[str, Any] | None, + declared: dict[str, PropertyType], +) -> dict[str, Any]: + """Apply per-type coercion against declared properties. + + Every declared property name appears in the result. Values that are absent + from ``raw`` or fail to coerce are recorded as ``None`` — never drop the + enclosing record. Downstream storage strips ``None`` before writing, so + these turn into "key missing" on the graph node, which is the right + behavior for retrieval (``WHERE p.age > N`` naturally excludes them). + """ + coerced: dict[str, Any] = {} + raw = raw or {} + for prop_name, prop in declared.items(): + if prop_name not in raw or raw[prop_name] is None: + coerced[prop_name] = None + continue + ok, val = _coerce_attribute_value(raw[prop_name], prop.type) + coerced[prop_name] = val if ok else None + return coerced + + def _optional_extras(obj: Any) -> dict[str, Any]: """Extract optional spans/confidence from a Pydantic extra-allow object.""" extra: dict[str, Any] = {} @@ -306,12 +457,15 @@ async def _step1(text: str, chunk_uid: str) -> list[ExtractedEntity]: entities_json = json.dumps( [{"name": e.name, "type": e.type, "description": e.description} for e in ents] ) + has_attrs = _schema_has_attributes(schema) prompt = VERIFY_EXTRACT_RELS_PROMPT.format( entity_types=_format_entity_types(entity_types, entity_type_descs), relation_patterns=_format_relation_patterns(schema.relations), + attribute_schema_block=_render_attribute_schema_block(schema), relationship_type_instruction=_relationship_type_instruction(schema.relations), entities_json=entities_json, text=text, + json_example=_JSON_EXAMPLE_WITH_ATTRS if has_attrs else _DEFAULT_JSON_EXAMPLE, ) step2_prompts.append(prompt) step2_indices.append(i) @@ -340,7 +494,7 @@ async def _step1(text: str, chunk_uid: str) -> list[ExtractedEntity]: assert item.response is not None verified_ents, rels = self._parse_step2_response( - item.response.content, entity_types, chunk.uid + item.response.content, entity_types, chunk.uid, schema ) if verified_ents: # Carry over spans/confidence from step 1 entities @@ -436,8 +590,16 @@ def _parse_step2_response( content: str, entity_types: list[str], source_chunk_id: str, + schema: GraphSchema | None = None, ) -> tuple[list[ExtractedEntity], list[ExtractedRelation]]: - """Parse the step 2 LLM response (verified entities + relationships).""" + """Parse the step 2 LLM response (verified entities + relationships). + + When ``schema`` declares attributes for an entity / relation type, every + declared attribute appears in the record's ``attributes`` dict, with + ``None`` for values the LLM didn't supply or couldn't coerce. Records + are never dropped — downstream storage strips ``None`` so the graph + sees "key missing" for the unfilled slots. + """ text = _strip_markdown_fences(content) try: @@ -449,6 +611,16 @@ def _parse_step2_response( if not isinstance(data, dict): return [], [] + ent_props_by_label: dict[str, dict[str, PropertyType]] = {} + rel_props_by_label: dict[str, dict[str, PropertyType]] = {} + if schema is not None: + for et in schema.entities: + if et.properties: + ent_props_by_label[et.label] = {p.name: p for p in et.properties} + for rt in schema.relations: + if rt.properties: + rel_props_by_label[rt.label] = {p.name: p for p in rt.properties} + # Parse entities entities: list[ExtractedEntity] = [] for item in data.get("entities", []): @@ -463,12 +635,20 @@ def _parse_step2_response( etype = label_for_type(raw_type, entity_types) description = str(item.get("description", "")).strip() + declared = ent_props_by_label.get(etype, {}) + attributes: dict[str, Any] = {} + if declared: + _raw = item.get("attributes") + raw_attrs = _raw if isinstance(_raw, dict) else {} + attributes = _coerce_attributes(raw_attrs, declared) + entities.append( ExtractedEntity( name=name, type=etype, description=description, source_chunk_ids=[source_chunk_id], + attributes=attributes, ) ) @@ -492,6 +672,13 @@ def _parse_step2_response( if spans: extra["spans"] = spans + declared_rel = rel_props_by_label.get(rel_type, {}) + attributes_rel: dict[str, Any] = {} + if declared_rel: + _raw = item.get("attributes") + raw_attrs = _raw if isinstance(_raw, dict) else {} + attributes_rel = _coerce_attributes(raw_attrs, declared_rel) + relations.append( ExtractedRelation( source=source, @@ -499,6 +686,7 @@ def _parse_step2_response( type=rel_type, description=description, source_chunk_ids=[source_chunk_id], + attributes=attributes_rel, **extra, ) ) @@ -538,12 +726,17 @@ def _aggregate_entities( for chunk_id, offsets in ent_spans.items(): existing_spans.setdefault(chunk_id, []).extend(offsets) existing.spans = existing_spans # type: ignore[attr-defined] + # Merge attributes: last-write-wins per key, matching FalkorDB's + # ``SET n += props`` semantics at the storage layer. + if ent.attributes: + existing.attributes.update(ent.attributes) else: seen[key] = ExtractedEntity( name=ent.name, type=ent.type, description=ent.description, source_chunk_ids=list(ent.source_chunk_ids), + attributes=dict(ent.attributes), **_optional_extras(ent), ) return list(seen.values()) @@ -577,6 +770,8 @@ def _aggregate_relations( for chunk_id, offsets in rel_spans.items(): existing_spans.setdefault(chunk_id, []).extend(offsets) existing.spans = existing_spans # type: ignore[attr-defined] + if rel.attributes: + existing.attributes.update(rel.attributes) else: seen[key] = ExtractedRelation( source=rel.source, @@ -584,6 +779,7 @@ def _aggregate_relations( type=rel.type, description=rel.description, source_chunk_ids=list(rel.source_chunk_ids), + attributes=dict(rel.attributes), **_optional_extras(rel), ) return list(seen.values()) @@ -609,6 +805,10 @@ def _entities_to_nodes( spans = getattr(ent, "spans", None) if spans: props["spans"] = spans + # Merge schema-declared attributes. Reserved-name collisions are + # rejected at schema-validation time, so update() is safe. + if ent.attributes: + props.update(ent.attributes) nodes.append( GraphNode( id=node_id, @@ -644,6 +844,8 @@ def _relations_to_relationships( spans = getattr(rel, "spans", None) if spans: props["spans"] = spans + if rel.attributes: + props.update(rel.attributes) relationships.append( GraphRelationship( start_node_id=compute_entity_id(rel.source, src_type), diff --git a/graphrag_sdk/src/graphrag_sdk/ingestion/pipeline.py b/graphrag_sdk/src/graphrag_sdk/ingestion/pipeline.py index 165cbeeb..98d5d99f 100644 --- a/graphrag_sdk/src/graphrag_sdk/ingestion/pipeline.py +++ b/graphrag_sdk/src/graphrag_sdk/ingestion/pipeline.py @@ -188,6 +188,8 @@ async def run( # Step 5: Prune against schema ctx.log("Step 5/9: Pruning against schema") graph_data = self._prune(graph_data, self.schema) + graph_data = self._validate_attributes(graph_data, self.schema) + graph_data = self._filter_quality(graph_data) # Step 6: Resolve duplicate entities ctx.log("Step 6/9: Resolving duplicates") @@ -449,6 +451,77 @@ def _prune(self, graph_data: GraphData, schema: GraphSchema) -> GraphData: extracted_relations=graph_data.extracted_relations, ) + def _validate_attributes(self, graph_data: GraphData, schema: GraphSchema) -> GraphData: + """Strip undeclared attribute keys from nodes / relationships. + + - Keys not declared in the schema and not in the SDK-reserved set are + dropped (debug-logged). + - Records are never dropped. Missing values stay missing on the graph, + which is the correct null semantics for retrieval queries + (``WHERE p.age > N`` naturally excludes nodes without ``age``). + - Skips ``"Unknown"`` nodes (preserved by :py:meth:`_prune`). + """ + if not schema.entities and not schema.relations: + return graph_data + + from graphrag_sdk.core.models import RESERVED_PROPERTY_NAMES + + ent_declared: dict[str, dict[str, Any]] = { + e.label: {p.name: p for p in e.properties} for e in schema.entities + } + rel_declared: dict[str, dict[str, Any]] = { + r.label: {p.name: p for p in r.properties} for r in schema.relations + } + + kept_nodes: list[GraphNode] = [] + unknown_dropped_by_label: dict[str, int] = {} + for node in graph_data.nodes: + if node.label == "Unknown" or node.label not in ent_declared: + kept_nodes.append(node) + continue + declared = ent_declared[node.label] + unknown_keys = [ + k + for k in list(node.properties.keys()) + if k not in declared and k not in RESERVED_PROPERTY_NAMES + ] + for k in unknown_keys: + node.properties.pop(k, None) + unknown_dropped_by_label[node.label] = ( + unknown_dropped_by_label.get(node.label, 0) + 1 + ) + kept_nodes.append(node) + + for label, count in unknown_dropped_by_label.items(): + logger.debug( + "Dropped %d undeclared attribute key(s) on %s nodes", + count, + label, + ) + + kept_rels: list[GraphRelationship] = [] + for rel in graph_data.relationships: + rel_label = rel.properties.get("rel_type", rel.type) + if rel_label not in rel_declared: + kept_rels.append(rel) + continue + declared = rel_declared[rel_label] + for k in [ + k + for k in list(rel.properties.keys()) + if k not in declared and k not in RESERVED_PROPERTY_NAMES + ]: + rel.properties.pop(k, None) + kept_rels.append(rel) + + return GraphData( + nodes=kept_nodes, + relationships=kept_rels, + mentions=graph_data.mentions, + extracted_entities=graph_data.extracted_entities, + extracted_relations=graph_data.extracted_relations, + ) + def _filter_quality(self, graph_data: GraphData) -> GraphData: """Remove nodes with empty IDs or labels, and dangling relationships.""" valid_nodes = [n for n in graph_data.nodes if n.id and n.label] diff --git a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/cypher_generation.py b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/cypher_generation.py index 0dd3d289..e5e6aca4 100644 --- a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/cypher_generation.py +++ b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/cypher_generation.py @@ -12,6 +12,8 @@ import re from typing import Any +from graphrag_sdk.core.models import GraphSchema + logger = logging.getLogger(__name__) # ── Valid labels for our graph schema ──────────────────────────── @@ -36,31 +38,133 @@ _ALL_LABELS = _ENTITY_LABELS | _STRUCTURAL_LABELS +_RESERVED_ENTITY_PROPS = ( + ("name", "STRING", "entity name"), + ("description", "STRING", "entity description"), +) +_RESERVED_REL_PROPS = ( + ("rel_type", "STRING", "original relation type as a string"), + ("fact", "STRING", "evidence text for the relation"), + ("src_name", "STRING", "source entity name"), + ("tgt_name", "STRING", "target entity name"), +) + +_NUMERIC_TYPES = frozenset({"INTEGER", "FLOAT"}) + _WRITE_KEYWORDS = re.compile( r"\b(CREATE|DELETE|DETACH|SET|REMOVE|MERGE|DROP|CALL\s+db\.idx)\b", re.IGNORECASE, ) + +def _labels_from_schema(schema: GraphSchema | None) -> frozenset[str]: + """Return the entity labels declared in ``schema``, falling back to the + historical hardcoded set when the schema is empty (open-schema mode). + """ + if schema is None or not schema.entities: + return _ENTITY_LABELS + return frozenset(e.label for e in schema.entities) + + +def render_schema_block(schema: GraphSchema | None) -> str: + """Render a Markdown schema block listing declared entity / relation + properties, derived from the live ``GraphSchema``. + + Mirrors LangChain's ``Neo4jGraph.get_schema()`` and LlamaIndex's + ``Neo4jPropertyGraphStore.get_schema_str()``. Always emits the reserved + SDK keys (``name``, ``description`` on entities; ``rel_type``, ``fact``, + ``src_name``, ``tgt_name`` on RELATES) so the LLM still knows about them + even when the schema declares no custom properties. + """ + labels: list[str] + rel_labels: list[str] + ent_props: dict[str, list[tuple[str, str, str | None]]] = {} + rel_props: dict[str, list[tuple[str, str, str | None]]] = {} + + if schema is not None and schema.entities: + labels = [e.label for e in schema.entities] + for e in schema.entities: + ent_props[e.label] = [(p.name, p.type, p.description) for p in e.properties] + else: + labels = sorted(_ENTITY_LABELS) + + if schema is not None and schema.relations: + rel_labels = [r.label for r in schema.relations] + for r in schema.relations: + rel_props[r.label] = [(p.name, p.type, p.description) for p in r.properties] + else: + rel_labels = [] + + lines: list[str] = [] + lines.append("### Entity node labels (all entities also carry `__Entity__`):") + for label in labels: + lines.append(f"- {label}") + lines.append(" Properties:") + for name, typ, desc in _RESERVED_ENTITY_PROPS: + lines.append(f" - {name} ({typ}) — {desc}") + for name, typ, desc in ent_props.get(label, []): + d = f" — {desc}" if desc else "" + lines.append(f" - {name} ({typ}){d}") + + lines.append("") + lines.append("### Edge types:") + lines.append("- RELATES: connects any entity to any entity.") + lines.append(" Properties:") + for name, typ, desc in _RESERVED_REL_PROPS: + lines.append(f" - {name} ({typ}) — {desc}") + union_rel_props: dict[str, tuple[str, str | None]] = {} + for label in rel_labels: + for name, typ, desc in rel_props.get(label, []): + if name not in union_rel_props: + union_rel_props[name] = (typ, desc) + for name, (typ, desc) in union_rel_props.items(): + d = f" — {desc}" if desc else "" + lines.append(f" - {name} ({typ}){d} # declared on RELATES via rel_type filters") + if rel_labels: + lines.append(" Allowed `rel_type` values: " + ", ".join(rel_labels)) + lines.append("- MENTIONED_IN: connects entity to Chunk node (provenance)") + lines.append("- PART_OF: connects Document to Chunk") + lines.append("- NEXT_CHUNK: connects Chunk to next sequential Chunk") + return "\n".join(lines) + + +def _render_attribute_examples(schema: GraphSchema | None) -> str: + """Synthesize one filter example per declared numeric attribute. + + Helps the LLM learn that custom numeric properties exist and can be used + in ``WHERE`` / ``ORDER BY`` / aggregations. Returns ``""`` when no + numeric attributes are declared. + """ + if schema is None: + return "" + examples: list[str] = [] + for et in schema.entities: + for p in et.properties: + if p.type in _NUMERIC_TYPES and len(examples) < 2: + var = et.label[0].lower() + examples.append( + f'Question: "Which {et.label} has the highest {p.name}?"\n' + f"```cypher\n" + f"MATCH ({var}:{et.label})\n" + f"WHERE {var}.{p.name} IS NOT NULL\n" + f"RETURN {var}.name AS name, {var}.{p.name} AS {p.name}\n" + f"ORDER BY {var}.{p.name} DESC\n" + f"LIMIT 10\n" + f"```" + ) + if not examples: + return "" + return "\n\n" + "\n\n".join(examples) + + # ── Schema prompt ──────────────────────────────────────────────── -SCHEMA_PROMPT = """\ +_SCHEMA_PROMPT_TEMPLATE = """\ You are a Cypher query generator for a FalkorDB graph database. ## Graph Schema -### Entity node labels (all entities also carry the label `__Entity__`): -Person, Organization, Technology, Product, Location, Date, Event, Concept, Law, Dataset, Method - -### Entity node properties: -- name (string) — entity name -- description (string) — entity description - -### Edge types: -- RELATES: connects any entity to any entity. - Properties: rel_type (string), fact (string — evidence text), src_name (string), tgt_name (string) -- MENTIONED_IN: connects entity to Chunk node (provenance) -- PART_OF: connects Document to Chunk -- NEXT_CHUNK: connects Chunk to next sequential Chunk +{schema_block} ## FalkorDB-specific rules (CRITICAL — violating these causes execution errors): 1. Do NOT use shortestPath() or allShortestPaths() — FalkorDB returns @@ -127,7 +231,7 @@ MATCH (o:Organization)-[r:RELATES]-(t:Technology) RETURN o.name AS organization, t.name AS technology, r.rel_type AS relation, r.fact AS evidence LIMIT 20 -``` +```{attribute_examples} ## Your task @@ -139,6 +243,27 @@ """ +def build_schema_prompt(schema: GraphSchema | None, question: str) -> str: + """Build the full Cypher generation prompt for ``question`` from ``schema``. + + When ``schema`` is empty, the prompt falls back to the historical + hardcoded label set and matches today's behavior bit-for-bit aside from + the new schema block formatting. + """ + return _SCHEMA_PROMPT_TEMPLATE.format( + schema_block=render_schema_block(schema), + attribute_examples=_render_attribute_examples(schema), + question=question, + ) + + +# Backwards-compatible alias for callers that import SCHEMA_PROMPT. +# It exposes the template form (still expects ``{schema_block}``, +# ``{attribute_examples}``, and ``{question}`` placeholders) — direct +# ``.format(question=...)`` callers should migrate to ``build_schema_prompt``. +SCHEMA_PROMPT = _SCHEMA_PROMPT_TEMPLATE + + # ── Cypher extraction ──────────────────────────────────────────── @@ -184,12 +309,16 @@ def _sanitize_cypher(cypher: str) -> str: # ── Cypher validation ──────────────────────────────────────────── -def validate_cypher(cypher: str) -> list[str]: +def validate_cypher(cypher: str, schema: GraphSchema | None = None) -> list[str]: """Validate generated Cypher for safety and correctness. Uses an allowlist approach: the query must start with a read-only keyword, and dangerous constructs are explicitly rejected. + When ``schema`` is provided, label validation uses the labels declared + in the schema (plus structural labels); otherwise it falls back to the + historical hardcoded label set. + Returns list of error strings; empty list means valid. """ errors: list[str] = [] @@ -229,9 +358,10 @@ def validate_cypher(cypher: str) -> list[str]: errors.append("Missing RETURN clause") # Check referenced labels exist in schema + allowed_labels = _labels_from_schema(schema) | _STRUCTURAL_LABELS label_pattern = re.findall(r"\((?:\w+)?:(\w+)", cypher_norm) for label in label_pattern: - if label not in _ALL_LABELS: + if label not in allowed_labels: errors.append(f"Unknown label: {label}") return errors @@ -244,13 +374,17 @@ async def generate_cypher( llm: Any, question: str, *, + schema: GraphSchema | None = None, max_retries: int = 3, ) -> str | None: """Generate a Cypher query from a natural language question. + When ``schema`` is provided, the prompt and validator both use the + declared labels and properties. + Returns the Cypher string, or None if all retries fail. """ - prompt = SCHEMA_PROMPT.format(question=question) + prompt = build_schema_prompt(schema, question) last_error = "" for attempt in range(max_retries): @@ -270,7 +404,7 @@ async def generate_cypher( last_error = "Empty query generated" continue - errors = validate_cypher(cypher) + errors = validate_cypher(cypher, schema) if errors: last_error = "; ".join(errors) continue @@ -292,6 +426,7 @@ async def execute_cypher_retrieval( llm: Any, question: str, *, + schema: GraphSchema | None = None, max_retries: int = 3, ) -> tuple[list[str], dict[str, dict]]: """Full text-to-cypher retrieval: generate -> validate -> execute -> parse. @@ -306,7 +441,7 @@ async def execute_cypher_retrieval( On any failure, returns empty results (silent degradation). """ - cypher = await generate_cypher(llm, question, max_retries=max_retries) + cypher = await generate_cypher(llm, question, schema=schema, max_retries=max_retries) if not cypher: return [], {} diff --git a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/multi_path.py b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/multi_path.py index d7bcbd9b..7397c587 100644 --- a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/multi_path.py +++ b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/multi_path.py @@ -166,6 +166,7 @@ def __init__( rel_top_k: int = 15, # Matched to chunk_top_k for balanced fact/passage coverage keyword_limit: int = 10, # Fulltext search keyword budget from query decomposition enable_cypher: bool = False, # Text-to-Cypher path (experimental, off by default) + schema: Any | None = None, # GraphSchema; forwarded to Cypher generation when enable_cypher ) -> None: super().__init__(graph_store=graph_store, vector_store=vector_store) self._embedder = embedder @@ -176,6 +177,7 @@ def __init__( self._rel_top_k = rel_top_k self._keyword_limit = keyword_limit self._enable_cypher = enable_cypher + self._schema = schema # -- Template Method hook -- @@ -197,7 +199,7 @@ async def _execute( if self._enable_cypher: results = await asyncio.gather( search_relates_edges(self._vector, query_vector, self._rel_top_k), - execute_cypher_retrieval(self._graph, self._llm, query), + execute_cypher_retrieval(self._graph, self._llm, query, schema=self._schema), return_exceptions=True, ) # Unpack RELATES results diff --git a/graphrag_sdk/src/graphrag_sdk/storage/__init__.py b/graphrag_sdk/src/graphrag_sdk/storage/__init__.py index 610abe39..7f33d890 100644 --- a/graphrag_sdk/src/graphrag_sdk/storage/__init__.py +++ b/graphrag_sdk/src/graphrag_sdk/storage/__init__.py @@ -3,6 +3,7 @@ from graphrag_sdk.storage.deduplicator import EntityDeduplicator from graphrag_sdk.storage.graph_store import GraphStore +from graphrag_sdk.storage.ontology_store import OntologyStore from graphrag_sdk.storage.vector_store import VectorStore -__all__ = ["EntityDeduplicator", "GraphStore", "VectorStore"] +__all__ = ["EntityDeduplicator", "GraphStore", "OntologyStore", "VectorStore"] diff --git a/graphrag_sdk/src/graphrag_sdk/storage/ontology_store.py b/graphrag_sdk/src/graphrag_sdk/storage/ontology_store.py new file mode 100644 index 00000000..81342639 --- /dev/null +++ b/graphrag_sdk/src/graphrag_sdk/storage/ontology_store.py @@ -0,0 +1,304 @@ +"""Persistent ontology storage in a dedicated FalkorDB graph. + +The ontology lives in a separate FalkorDB graph named ``__ontology`` +and is the **anchor** for the working schema: + +- Always-on: every :py:class:`GraphRAG` has exactly one ontology graph, + created lazily on first use, dropped on ``delete_all()``. +- Single source of truth: retrieval, ``get_ontology()``, and any cross-process + worker all read from the same graph. +- Additive only: :py:meth:`register` validates incoming schema against what's + already persisted and refuses **type contradictions** on existing properties. + New entity types, relation types, properties, and relation patterns are all + welcome. Re-typing an existing property is not. + +Users who want a curated, declarative schema (descriptions, future flags, +properties not yet observed in the data) supply a ``schema`` to ``GraphRAG``; +it gets registered into the ontology graph on first connection. JSON +import/export via :py:meth:`GraphSchema.save_to_file` / ``from_file`` is a +review / version-control bridge — the ontology graph is the canonical copy. +""" + +from __future__ import annotations + +import logging +from typing import Any + +from graphrag_sdk.core.connection import FalkorDBConnection +from graphrag_sdk.core.models import ( + EntityType, + GraphSchema, + PropertyType, + RelationType, +) + +logger = logging.getLogger(__name__) + + +class OntologyContradictionError(ValueError): + """Raised when an incoming schema redefines an existing property's type. + + The ontology is additive: new labels, properties, and relation patterns are + welcome, but re-typing a property already registered on a label is + explicitly rejected so downstream Cypher queries don't break silently. + """ + + +def _encode_patterns(patterns: list[tuple[str, str]]) -> list[str]: + return [f"{src}|{tgt}" for src, tgt in patterns] + + +def _decode_patterns(encoded: list[str] | None) -> list[tuple[str, str]]: + out: list[tuple[str, str]] = [] + for s in encoded or []: + if not isinstance(s, str) or "|" not in s: + continue + src, tgt = s.split("|", 1) + out.append((src, tgt)) + return out + + +def _props_from_rows(rows: list[Any] | None) -> list[PropertyType]: + """Reconstruct PropertyType objects from a ``collect(...)`` query result. + + Filters out the null-keyed dict FalkorDB returns for an OPTIONAL MATCH + with no matches. + """ + out: list[PropertyType] = [] + for row in rows or []: + if not row or not isinstance(row, dict): + continue + name = row.get("name") + if not name: + continue + out.append( + PropertyType( + name=name, + type=row.get("type") or "STRING", + description=row.get("description"), + ) + ) + return out + + +class OntologyStore: + """Persists and loads :py:class:`GraphSchema` in a dedicated FalkorDB graph. + + Owns its own graph handle, derived from the data-graph connection's + driver. Queries go directly to the FalkorDB driver and bypass the + connection's retry / circuit-breaker — ontology operations are + infrequent, idempotent, and tolerant of a single failure. + """ + + ONTOLOGY_GRAPH_SUFFIX = "__ontology" + + def __init__(self, connection: FalkorDBConnection, data_graph_name: str) -> None: + self._conn = connection + self._graph_name = f"{data_graph_name}{self.ONTOLOGY_GRAPH_SUFFIX}" + self._graph: Any | None = None + + @property + def graph_name(self) -> str: + return self._graph_name + + def _ensure_graph(self) -> Any: + if self._graph is not None: + return self._graph + self._conn._ensure_client() + driver = self._conn._driver + if driver is None: + raise RuntimeError("FalkorDB driver not initialised on connection") + self._graph = driver.select_graph(self._graph_name) + return self._graph + + async def _query(self, cypher: str, params: dict[str, Any] | None = None) -> Any: + graph = self._ensure_graph() + return await graph.query(cypher, params=params) + + # ── Load ───────────────────────────────────────────────────── + + async def load(self) -> GraphSchema: + """Read the ontology graph and reconstruct a :py:class:`GraphSchema`. + + Returns an empty schema if the ontology graph does not yet exist or + introspection fails. Failure is logged at DEBUG so we don't spam an + unconfigured GraphRAG instance with warnings. + """ + try: + ent_result = await self._query( + "MATCH (e:OntologyEntityType) " + "OPTIONAL MATCH (e)-[:HAS_PROPERTY]->(p:OntologyProperty) " + "RETURN e.label AS label, e.description AS description, " + "collect({name: p.name, type: p.type, description: p.description}) AS properties" + ) + rel_result = await self._query( + "MATCH (r:OntologyRelationType) " + "OPTIONAL MATCH (r)-[:HAS_PROPERTY]->(p:OntologyProperty) " + "RETURN r.label AS label, r.description AS description, " + "r.patterns AS patterns, " + "collect({name: p.name, type: p.type, description: p.description}) AS properties" + ) + except Exception as exc: + logger.debug("Ontology load failed (returning empty schema): %s", exc) + return GraphSchema() + + ent_rows = getattr(ent_result, "result_set", None) or [] + rel_rows = getattr(rel_result, "result_set", None) or [] + if not isinstance(ent_rows, list): + ent_rows = [] + if not isinstance(rel_rows, list): + rel_rows = [] + + entities = [ + EntityType( + label=row[0], + description=row[1], + properties=_props_from_rows(row[2]), + ) + for row in ent_rows + if isinstance(row, list) and len(row) >= 3 and row[0] + ] + relations = [ + RelationType( + label=row[0], + description=row[1], + patterns=_decode_patterns(row[2]), + properties=_props_from_rows(row[3]), + ) + for row in rel_rows + if isinstance(row, list) and len(row) >= 4 and row[0] + ] + return GraphSchema(entities=entities, relations=relations) + + # ── Register ───────────────────────────────────────────────── + + async def register(self, schema: GraphSchema) -> GraphSchema: + """Merge ``schema`` into the persisted ontology and return the union. + + Validates first: if ``schema`` redefines the type of a property + already registered on the same entity/relation label, raises + :py:class:`OntologyContradictionError` before any partial state is + persisted. + + Additive operations — new entity types, new relations, new properties, + new relation patterns — go through unchanged. + """ + if not schema.entities and not schema.relations: + return await self.load() + + existing = await self.load() + self._check_no_contradictions(existing, schema) + + for et in schema.entities: + await self._upsert_entity_type(et) + for rt in schema.relations: + await self._upsert_relation_type(rt) + + return await self.load() + + @staticmethod + def _check_no_contradictions(existing: GraphSchema, incoming: GraphSchema) -> None: + """Raise :py:class:`OntologyContradictionError` on any type re-declaration.""" + existing_ent_types: dict[tuple[str, str], str] = { + (e.label, p.name): p.type for e in existing.entities for p in e.properties + } + for et in incoming.entities: + for p in et.properties: + prior = existing_ent_types.get((et.label, p.name)) + if prior is not None and prior != p.type: + raise OntologyContradictionError( + f"Property '{et.label}.{p.name}' is already registered as " + f"{prior}; refusing to redefine as {p.type}. The ontology " + f"is additive — drop the data graph and start fresh if you " + f"need to change a property's type." + ) + + existing_rel_types: dict[tuple[str, str], str] = { + (r.label, p.name): p.type for r in existing.relations for p in r.properties + } + for rt in incoming.relations: + for p in rt.properties: + prior = existing_rel_types.get((rt.label, p.name)) + if prior is not None and prior != p.type: + raise OntologyContradictionError( + f"Property '{rt.label}.{p.name}' (on relation) is already " + f"registered as {prior}; refusing to redefine as {p.type}." + ) + + async def _upsert_entity_type(self, et: EntityType) -> None: + await self._query( + "MERGE (e:OntologyEntityType {label: $label}) " + "SET e.description = coalesce($description, e.description)", + {"label": et.label, "description": et.description}, + ) + for prop in et.properties: + await self._upsert_property(et.label, prop, owner_label="OntologyEntityType") + + async def _upsert_relation_type(self, rt: RelationType) -> None: + new_patterns = _encode_patterns(rt.patterns) + result = await self._query( + "MATCH (r:OntologyRelationType {label: $label}) RETURN r.patterns AS patterns", + {"label": rt.label}, + ) + existing_patterns: list[str] = [] + rows = getattr(result, "result_set", None) or [] + if isinstance(rows, list) and rows and isinstance(rows[0], list) and rows[0]: + existing_patterns = list(rows[0][0] or []) + seen: set[str] = set() + merged: list[str] = [] + for s in existing_patterns + new_patterns: + if s not in seen: + seen.add(s) + merged.append(s) + await self._query( + "MERGE (r:OntologyRelationType {label: $label}) " + "SET r.description = coalesce($description, r.description), " + "r.patterns = $patterns", + {"label": rt.label, "description": rt.description, "patterns": merged}, + ) + for prop in rt.properties: + await self._upsert_property(rt.label, prop, owner_label="OntologyRelationType") + + async def _upsert_property( + self, owner_label_value: str, prop: PropertyType, *, owner_label: str + ) -> None: + # Property nodes are keyed by ``(owner_label_kind, owner_label, name)`` + # so two different types can declare the same property name without + # trampling each other's metadata. + owner_alias = "ent" if owner_label == "OntologyEntityType" else "rel" + await self._query( + f"MATCH ({owner_alias}:{owner_label} {{label: $owner}}) " + f"MERGE ({owner_alias})-[:HAS_PROPERTY]->" + f"(p:OntologyProperty {{name: $name, owner: $owner_kind, owner_label: $owner}}) " + "SET p.type = $type, " + "p.description = coalesce($description, p.description)", + { + "owner": owner_label_value, + "owner_kind": owner_label, + "name": prop.name, + "type": prop.type, + "description": prop.description, + }, + ) + + # ── Clear ──────────────────────────────────────────────────── + + async def clear(self) -> None: + """Drop the ontology graph (``GRAPH.DELETE``). Idempotent. + + Called from ``GraphRAG.delete_all()`` so the ontology graph never + outlives the data graph. + """ + self._conn._ensure_client() + from redis.asyncio import Redis + + redis: Redis = Redis(connection_pool=self._conn._pool) + try: + await redis.execute_command("GRAPH.DELETE", self._graph_name) + except Exception as exc: + msg = str(exc).lower() + if "empty" in msg or "invalid" in msg or "key" in msg: + logger.debug("Ontology graph '%s' already empty", self._graph_name) + else: + raise + self._graph = None diff --git a/graphrag_sdk/tests/test_attribute_prompt.py b/graphrag_sdk/tests/test_attribute_prompt.py new file mode 100644 index 00000000..905a1fbc --- /dev/null +++ b/graphrag_sdk/tests/test_attribute_prompt.py @@ -0,0 +1,191 @@ +"""Tests for the attribute-aware prompt and coercion helpers in +``graph_extraction.py``.""" +from __future__ import annotations + +import pytest + +from graphrag_sdk.core.models import ( + EntityType, + GraphSchema, + PropertyType, + RelationType, +) +from graphrag_sdk.ingestion.extraction_strategies.graph_extraction import ( + VERIFY_EXTRACT_RELS_PROMPT, + _DEFAULT_JSON_EXAMPLE, + _JSON_EXAMPLE_WITH_ATTRS, + _coerce_attribute_value, + _coerce_attributes, + _render_attribute_schema_block, + _schema_has_attributes, +) + + +# ── _render_attribute_schema_block ────────────────────────────── + + +class TestRenderAttributeSchemaBlock: + def test_empty_schema_renders_empty(self): + assert _render_attribute_schema_block(GraphSchema()) == "" + + def test_schema_with_no_attributes_renders_empty(self): + s = GraphSchema(entities=[EntityType(label="Person")]) + assert _render_attribute_schema_block(s) == "" + + def test_includes_only_types_with_declared_properties(self): + s = GraphSchema( + entities=[ + EntityType( + label="Person", + properties=[PropertyType(name="age", type="INTEGER")], + ), + EntityType(label="Company"), # no properties; should not appear + ], + ) + block = _render_attribute_schema_block(s) + assert "Person" in block + assert "age (INTEGER)" in block + # Bare Company without a colon shouldn't appear; the only "Company" + # in the rendered output would be as part of an entity bullet. + assert "- Company:" not in block + + def test_renders_relation_attributes(self): + s = GraphSchema( + entities=[EntityType(label="Person"), EntityType(label="Company")], + relations=[ + RelationType( + label="WORKS_AT", + properties=[PropertyType(name="since", type="DATE")], + ), + ], + ) + block = _render_attribute_schema_block(s) + assert "Relation attributes" in block + assert "WORKS_AT" in block + assert "since (DATE)" in block + + +class TestSchemaHasAttributes: + def test_empty_schema(self): + assert _schema_has_attributes(GraphSchema()) is False + + def test_only_entity_attrs(self): + s = GraphSchema( + entities=[ + EntityType( + label="Person", + properties=[PropertyType(name="age", type="INTEGER")], + ) + ] + ) + assert _schema_has_attributes(s) is True + + def test_only_relation_attrs(self): + s = GraphSchema( + relations=[ + RelationType( + label="WORKS_AT", + properties=[PropertyType(name="since", type="DATE")], + ) + ] + ) + assert _schema_has_attributes(s) is True + + +# ── _coerce_attribute_value ────────────────────────────────────── + + +class TestCoerceAttributeValue: + @pytest.mark.parametrize( + "value,prop_type,expected", + [ + ("56", "INTEGER", (True, 56)), + (56.7, "INTEGER", (True, 56)), + ("abc", "INTEGER", (False, None)), + (None, "INTEGER", (False, None)), + ("1867-11-07", "DATE", (True, "1867-11-07")), + ("1867-11-07T12:00:00", "DATE", (True, "1867-11-07")), + ("not a date", "DATE", (False, None)), + ("yes", "BOOLEAN", (True, True)), + ("False", "BOOLEAN", (True, False)), + ("maybe", "BOOLEAN", (False, None)), + (1, "BOOLEAN", (True, True)), + ("hello", "STRING", (True, "hello")), + (" ", "STRING", (False, None)), + (42, "STRING", (True, "42")), + (["a", "b"], "LIST", (True, ["a", "b"])), + ("solo", "LIST", (True, ["solo"])), + ({"bad": True}, "LIST", (False, None)), + ], + ) + def test_matrix(self, value, prop_type, expected): + assert _coerce_attribute_value(value, prop_type) == expected + + +class TestCoerceAttributes: + def test_every_declared_property_appears_in_result(self): + """Result shape is uniform: every declared key is present, with the + coerced value or None. Never drop the enclosing record.""" + declared = { + "age": PropertyType(name="age", type="INTEGER"), + "birth_date": PropertyType(name="birth_date", type="DATE"), + "nickname": PropertyType(name="nickname", type="STRING"), + } + result = _coerce_attributes( + {"age": "56", "birth_date": "1867-11-07"}, declared + ) + assert result == {"age": 56, "birth_date": "1867-11-07", "nickname": None} + + def test_uncoercible_value_becomes_none(self): + declared = {"age": PropertyType(name="age", type="INTEGER")} + assert _coerce_attributes({"age": "abc"}, declared) == {"age": None} + + def test_undeclared_keys_are_ignored(self): + declared = {"age": PropertyType(name="age", type="INTEGER")} + result = _coerce_attributes({"age": 56, "unknown": "x"}, declared) + assert result == {"age": 56} + + +# ── prompt template integration ────────────────────────────────── + + +class TestPromptTemplate: + def test_property_less_schema_keeps_block_empty(self): + s = GraphSchema(entities=[EntityType(label="Person")]) + prompt = VERIFY_EXTRACT_RELS_PROMPT.format( + entity_types="Person", + relation_patterns="", + attribute_schema_block=_render_attribute_schema_block(s), + relationship_type_instruction="- type: ...\n", + entities_json="[]", + text="...", + json_example=_DEFAULT_JSON_EXAMPLE, + ) + # No attribute-extraction section should appear. + assert "## Attribute extraction" not in prompt + # Original entity-shape example must be intact. + assert '"description": "..."' in prompt + + def test_schema_with_attributes_injects_section_and_example(self): + s = GraphSchema( + entities=[ + EntityType( + label="Person", + properties=[PropertyType(name="age", type="INTEGER")], + ), + ] + ) + prompt = VERIFY_EXTRACT_RELS_PROMPT.format( + entity_types="Person", + relation_patterns="", + attribute_schema_block=_render_attribute_schema_block(s), + relationship_type_instruction="- type: ...\n", + entities_json="[]", + text="...", + json_example=_JSON_EXAMPLE_WITH_ATTRS, + ) + assert "## Attribute extraction" in prompt + assert "age (INTEGER)" in prompt + # The example output schema must mention attributes so the LLM + # knows to include the field. + assert '"attributes"' in prompt diff --git a/graphrag_sdk/tests/test_cypher_generation.py b/graphrag_sdk/tests/test_cypher_generation.py index 58694158..2942287a 100644 --- a/graphrag_sdk/tests/test_cypher_generation.py +++ b/graphrag_sdk/tests/test_cypher_generation.py @@ -197,3 +197,95 @@ async def test_parses_result_rows(self): assert "Alice" in facts[0] assert "alice" in entities assert "bob" in entities + + +# ── Schema-aware prompt + validator ────────────────────────────── + + +class TestRenderSchemaBlock: + def test_open_schema_falls_back_to_historic_labels(self): + from graphrag_sdk.core.models import GraphSchema + from graphrag_sdk.retrieval.strategies.cypher_generation import ( + render_schema_block, + ) + block = render_schema_block(GraphSchema()) + assert "- Person" in block + assert "name (STRING)" in block + assert "rel_type (STRING)" in block + + def test_schema_block_includes_declared_attributes(self): + from graphrag_sdk.core.models import ( + EntityType, + GraphSchema, + PropertyType, + RelationType, + ) + from graphrag_sdk.retrieval.strategies.cypher_generation import ( + render_schema_block, + ) + s = GraphSchema( + entities=[ + EntityType( + label="Person", + properties=[ + PropertyType(name="age", type="INTEGER", description="years") + ], + ), + EntityType(label="Company"), + ], + relations=[ + RelationType( + label="WORKS_AT", + patterns=[("Person", "Company")], + properties=[PropertyType(name="since", type="DATE")], + ), + ], + ) + block = render_schema_block(s) + assert "age (INTEGER)" in block + assert "since (DATE)" in block + assert "WORKS_AT" in block + + +class TestBuildSchemaPrompt: + def test_includes_question_and_schema(self): + from graphrag_sdk.core.models import ( + EntityType, + GraphSchema, + PropertyType, + ) + from graphrag_sdk.retrieval.strategies.cypher_generation import ( + build_schema_prompt, + ) + s = GraphSchema( + entities=[ + EntityType( + label="Person", + properties=[PropertyType(name="age", type="INTEGER")], + ) + ] + ) + prompt = build_schema_prompt(s, "Who is older than 30?") + assert "Who is older than 30?" in prompt + assert "age (INTEGER)" in prompt + assert ".age" in prompt # synthesized numeric-filter example + + +class TestValidateCypherWithSchema: + def test_unknown_label_flagged_when_schema_provided(self): + from graphrag_sdk.core.models import EntityType, GraphSchema + s = GraphSchema(entities=[EntityType(label="Person")]) + errors = validate_cypher("MATCH (x:Bogus) RETURN x LIMIT 10", s) + assert any("Unknown label: Bogus" in e for e in errors) + + def test_declared_label_accepted(self): + from graphrag_sdk.core.models import EntityType, GraphSchema + s = GraphSchema(entities=[EntityType(label="Customer")]) + errors = validate_cypher( + "MATCH (c:Customer) RETURN c.name LIMIT 10", s + ) + assert errors == [] + + def test_no_schema_falls_back_to_historic_labels(self): + errors = validate_cypher("MATCH (p:Person) RETURN p LIMIT 10") + assert errors == [] diff --git a/graphrag_sdk/tests/test_facade.py b/graphrag_sdk/tests/test_facade.py index 78072b4f..b5b943ea 100644 --- a/graphrag_sdk/tests/test_facade.py +++ b/graphrag_sdk/tests/test_facade.py @@ -35,6 +35,15 @@ def mock_conn(): result_mock.result_set = [] conn.query = AsyncMock(return_value=result_mock) conn.config = ConnectionConfig() + # OntologyStore reaches into the connection's driver to open a handle on + # the paired ``__ontology`` graph. Stub the chain so the + # ontology graph appears as a writable async-query target whose queries + # return empty result sets. + ontology_graph = MagicMock() + ontology_graph.query = AsyncMock(return_value=result_mock) + conn._driver = MagicMock() + conn._driver.select_graph = MagicMock(return_value=ontology_graph) + conn._ensure_client = MagicMock() return conn diff --git a/graphrag_sdk/tests/test_graph_extraction.py b/graphrag_sdk/tests/test_graph_extraction.py index a7c77937..7041ecbc 100644 --- a/graphrag_sdk/tests/test_graph_extraction.py +++ b/graphrag_sdk/tests/test_graph_extraction.py @@ -712,3 +712,127 @@ def invoke(self, prompt, **kwargs): step2_prompt = captured_prompts[1] assert "## Allowed Relationships" in step2_prompt assert "WORKS_AT (Person \u2192 Company): Works at" in step2_prompt + + +class TestGraphExtractionSchemaAttributes: + """Coverage for schema-declared attributes flowing through extraction.""" + + def test_parse_step2_coerces_declared_attributes(self): + from graphrag_sdk.core.models import PropertyType + content = json.dumps({ + "entities": [ + { + "name": "Marie Curie", + "type": "Person", + "description": "Scientist", + "attributes": {"age": "56", "birth_date": "1867-11-07"}, + }, + ], + "relationships": [], + }) + schema = GraphSchema( + entities=[ + EntityType( + label="Person", + properties=[ + PropertyType(name="age", type="INTEGER"), + PropertyType(name="birth_date", type="DATE"), + ], + ) + ] + ) + ents, _ = GraphExtraction._parse_step2_response( + content, ["Person"], "c1", schema + ) + assert len(ents) == 1 + assert ents[0].attributes == {"age": 56, "birth_date": "1867-11-07"} + + def test_parse_step2_missing_attribute_becomes_none(self): + """Missing values are represented as ``None`` in attributes — the + entity is NEVER dropped. Storage strips ``None`` before writing so + the graph sees "key missing", which is the right null semantics.""" + from graphrag_sdk.core.models import PropertyType + content = json.dumps({ + "entities": [ + { + "name": "Marie Curie", + "type": "Person", + "description": "", + "attributes": {"age": 56}, + }, + ], + "relationships": [], + }) + schema = GraphSchema( + entities=[ + EntityType( + label="Person", + properties=[ + PropertyType(name="age", type="INTEGER"), + PropertyType(name="birth_date", type="DATE"), + ], + ) + ] + ) + ents, _ = GraphExtraction._parse_step2_response( + content, ["Person"], "c1", schema + ) + assert len(ents) == 1 + assert ents[0].attributes == {"age": 56, "birth_date": None} + + def test_aggregator_carries_attributes_with_last_write_wins(self): + from graphrag_sdk.core.models import ExtractedEntity + e1 = ExtractedEntity( + name="Marie Curie", + type="Person", + description="", + source_chunk_ids=["c1"], + attributes={"age": 56, "country": "France"}, + ) + e2 = ExtractedEntity( + name="Marie Curie", + type="Person", + description="", + source_chunk_ids=["c2"], + attributes={"age": 58, "birth_year": 1867}, + ) + merged = GraphExtraction._aggregate_entities([e1, e2]) + assert len(merged) == 1 + # `age` is overwritten last-write-wins; the other keys are unioned. + assert merged[0].attributes == { + "age": 58, + "country": "France", + "birth_year": 1867, + } + + def test_attributes_merged_into_node_properties(self): + from graphrag_sdk.core.models import ExtractedEntity + ent = ExtractedEntity( + name="Marie Curie", + type="Person", + description="d", + source_chunk_ids=["c1"], + attributes={"age": 56, "birth_date": "1867-11-07"}, + ) + nodes = GraphExtraction._entities_to_nodes([ent]) + assert nodes[0].properties["age"] == 56 + assert nodes[0].properties["birth_date"] == "1867-11-07" + # Reserved keys still present. + assert nodes[0].properties["name"] == "Marie Curie" + + def test_property_less_schema_keeps_attributes_empty(self): + """Deterministic-extractor regression: schema declares no + properties \u2192 records have empty attributes \u2192 storage gets only the + reserved keys.""" + content = json.dumps({ + "entities": [ + {"name": "Alice", "type": "Person", "description": "A person"}, + ], + "relationships": [], + }) + schema = GraphSchema(entities=[EntityType(label="Person")]) + ents, _ = GraphExtraction._parse_step2_response( + content, ["Person"], "c1", schema + ) + assert len(ents) == 1 + assert ents[0].attributes == {} diff --git a/graphrag_sdk/tests/test_models.py b/graphrag_sdk/tests/test_models.py index b4fdd022..7a177024 100644 --- a/graphrag_sdk/tests/test_models.py +++ b/graphrag_sdk/tests/test_models.py @@ -136,10 +136,118 @@ def test_document_output(self): class TestSchemaTypes: def test_property_type(self): - pt = PropertyType(name="age", type="INT", required=True) + pt = PropertyType(name="age", type="INTEGER") assert pt.name == "age" - assert pt.type == "INT" - assert pt.required is True + assert pt.type == "INTEGER" + + def test_property_type_normalizes_case(self): + pt = PropertyType(name="age", type="integer") + assert pt.type == "INTEGER" + + def test_property_type_rejects_unknown_type(self): + with pytest.raises(ValidationError): + PropertyType(name="age", type="OBJECT") + + def test_relation_type_properties_default_empty(self): + rt = RelationType(label="WORKS_AT") + assert rt.properties == [] + + def test_relation_type_with_properties(self): + rt = RelationType( + label="WORKS_AT", + properties=[PropertyType(name="since", type="DATE")], + ) + assert len(rt.properties) == 1 + assert rt.properties[0].type == "DATE" + + def test_graph_schema_rejects_reserved_entity_property_name(self): + with pytest.raises(ValidationError): + GraphSchema( + entities=[ + EntityType( + label="Person", + properties=[PropertyType(name="name", type="STRING")], + ) + ] + ) + + def test_graph_schema_rejects_reserved_relation_property_name(self): + with pytest.raises(ValidationError): + GraphSchema( + relations=[ + RelationType( + label="WORKS_AT", + properties=[PropertyType(name="rel_type", type="STRING")], + ) + ] + ) + + def test_graph_schema_roundtrips_through_json_file(self, tmp_path): + schema = GraphSchema( + entities=[ + EntityType( + label="Person", + description="A human", + properties=[ + PropertyType(name="age", type="INTEGER"), + PropertyType(name="birth_date", type="DATE"), + ], + ) + ], + relations=[ + RelationType( + label="WORKS_AT", + patterns=[("Person", "Company")], + properties=[PropertyType(name="since", type="DATE")], + ), + ], + ) + path = tmp_path / "ontology.json" + schema.save_to_file(str(path)) + # File is real JSON and round-trippable. + loaded = GraphSchema.from_file(str(path)) + assert loaded == schema + # Indentation default is set so files are diff-friendly. + text = path.read_text() + assert "\n" in text + assert " " in text + + def test_graph_schema_merge_unions_entities_relations_and_properties(self): + a = GraphSchema( + entities=[ + EntityType( + label="Person", + properties=[PropertyType(name="age", type="INTEGER")], + ) + ], + relations=[ + RelationType(label="WORKS_AT", patterns=[("Person", "Company")]), + ], + ) + b = GraphSchema( + entities=[ + EntityType( + label="Person", + properties=[PropertyType(name="dob", type="DATE")], + ), + EntityType(label="Company"), + ], + relations=[ + RelationType( + label="WORKS_AT", + patterns=[("Person", "Organization")], + properties=[PropertyType(name="since", type="DATE")], + ), + ], + ) + merged = a.merge(b) + assert {e.label for e in merged.entities} == {"Person", "Company"} + person = next(e for e in merged.entities if e.label == "Person") + assert {p.name for p in person.properties} == {"age", "dob"} + works = next(r for r in merged.relations if r.label == "WORKS_AT") + assert ("Person", "Company") in works.patterns + assert ("Person", "Organization") in works.patterns + assert {p.name for p in works.properties} == {"since"} def test_entity_type(self): et = EntityType(label="Person", description="A human") diff --git a/graphrag_sdk/tests/test_ontology_store.py b/graphrag_sdk/tests/test_ontology_store.py new file mode 100644 index 00000000..560f4978 --- /dev/null +++ b/graphrag_sdk/tests/test_ontology_store.py @@ -0,0 +1,318 @@ +"""Tests for storage/ontology_store.py — persistent ontology graph. + +The store talks to FalkorDB directly; unit tests here mock the graph handle +through ``FalkorDBConnection``'s ``_driver.select_graph()`` seam. Real-FalkorDB +exercise is left for the integration suite. +""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from graphrag_sdk.core.models import ( + EntityType, + GraphSchema, + PropertyType, + RelationType, +) +from graphrag_sdk.storage.ontology_store import ( + OntologyContradictionError, + OntologyStore, + _decode_patterns, + _encode_patterns, + _props_from_rows, +) + + +class _FakeQueryResult: + """Stand-in for FalkorDB's QueryResult.""" + + def __init__(self, rows): + self.result_set = rows + + +class _FakeGraph: + """In-memory async ``query()`` substitute. Records calls and serves canned + responses for the load + read-existing-patterns queries.""" + + def __init__(self): + self.calls: list[tuple[str, dict | None]] = [] + self._ent_rows: list[list] = [] + self._rel_rows: list[list] = [] + self._patterns_for_label: dict[str, list[str]] = {} + + def set_load_response(self, entity_rows, relation_rows): + self._ent_rows = entity_rows + self._rel_rows = relation_rows + + def set_existing_patterns(self, label: str, patterns: list[str]): + self._patterns_for_label[label] = patterns + + async def query(self, cypher: str, params: dict | None = None): + self.calls.append((cypher, params)) + if "MATCH (e:OntologyEntityType)" in cypher and "collect" in cypher: + return _FakeQueryResult(self._ent_rows) + if "MATCH (r:OntologyRelationType)" in cypher and "collect" in cypher: + return _FakeQueryResult(self._rel_rows) + if "MATCH (r:OntologyRelationType {label: $label})" in cypher and "patterns" in cypher: + label = (params or {}).get("label", "") + return _FakeQueryResult([[self._patterns_for_label.get(label, [])]]) + return _FakeQueryResult([]) + + +@pytest.fixture +def fake_graph(): + return _FakeGraph() + + +@pytest.fixture +def store_factory(fake_graph): + """Returns a callable producing an ``OntologyStore`` wired to ``fake_graph``.""" + + def _make(data_graph_name: str = "kg") -> OntologyStore: + conn = MagicMock() + conn._ensure_client = MagicMock() + conn._driver = SimpleNamespace(select_graph=MagicMock(return_value=fake_graph)) + return OntologyStore(conn, data_graph_name) + + return _make + + +# ── small helpers ──────────────────────────────────────────────── + + +class TestEncoders: + def test_roundtrip(self): + patterns = [("Person", "Company"), ("Person", "Organization")] + encoded = _encode_patterns(patterns) + assert encoded == ["Person|Company", "Person|Organization"] + assert _decode_patterns(encoded) == patterns + + def test_decode_tolerates_garbage(self): + assert _decode_patterns(None) == [] + assert _decode_patterns(["no-pipe", "a|b"]) == [("a", "b")] + + +class TestPropsFromRows: + def test_filters_null_keyed_rows(self): + rows = [ + {"name": "age", "type": "INTEGER", "description": None}, + {"name": None, "type": None, "description": None}, # optional-match empty + None, + "not a dict", + ] + result = _props_from_rows(rows) + assert [(p.name, p.type) for p in result] == [("age", "INTEGER")] + + +# ── store identity ─────────────────────────────────────────────── + + +class TestOntologyStoreGraphName: + def test_suffix(self, store_factory): + store = store_factory("my_kg") + assert store.graph_name == "my_kg__ontology" + + +# ── register / load ────────────────────────────────────────────── + + +class TestRegister: + @pytest.mark.asyncio + async def test_empty_schema_short_circuits_to_load(self, store_factory, fake_graph): + store = store_factory() + result = await store.register(GraphSchema()) + # No upsert queries — only the two load queries. + upserts = [c for c in fake_graph.calls if "MERGE" in c[0]] + assert upserts == [] + assert isinstance(result, GraphSchema) + + @pytest.mark.asyncio + async def test_persists_entity_type_with_properties(self, store_factory, fake_graph): + store = store_factory() + schema = GraphSchema( + entities=[ + EntityType( + label="Person", + description="A human", + properties=[ + PropertyType(name="age", type="INTEGER"), + PropertyType(name="birth_date", type="DATE"), + ], + ), + ], + ) + await store.register(schema) + ent_merges = [c for c in fake_graph.calls if "MERGE (e:OntologyEntityType" in c[0]] + prop_merges = [c for c in fake_graph.calls if "MERGE (ent)-[:HAS_PROPERTY]->" in c[0]] + assert len(ent_merges) == 1 + assert len(prop_merges) == 2 + prop_names = {(c[1] or {}).get("name") for c in prop_merges} + assert prop_names == {"age", "birth_date"} + + @pytest.mark.asyncio + async def test_unions_relation_patterns(self, store_factory, fake_graph): + store = store_factory() + fake_graph.set_existing_patterns("WORKS_AT", ["Person|Company"]) + schema = GraphSchema( + entities=[EntityType(label="Person"), EntityType(label="Org")], + relations=[RelationType(label="WORKS_AT", patterns=[("Person", "Org")])], + ) + await store.register(schema) + rel_set_calls = [ + c + for c in fake_graph.calls + if "MERGE (r:OntologyRelationType {label: $label})" in c[0] and "SET r." in c[0] + ] + assert rel_set_calls, "expected SET on RelationType" + patterns = (rel_set_calls[-1][1] or {})["patterns"] + assert "Person|Company" in patterns + assert "Person|Org" in patterns + + +class TestLoad: + @pytest.mark.asyncio + async def test_empty_graph_yields_empty_schema(self, store_factory, fake_graph): + store = store_factory() + fake_graph.set_load_response([], []) + schema = await store.load() + assert schema.entities == [] + assert schema.relations == [] + + @pytest.mark.asyncio + async def test_introspection_failure_returns_empty(self, store_factory, fake_graph): + async def boom(cypher, params=None): + raise RuntimeError("connection blew up") + + fake_graph.query = boom + store = store_factory() + schema = await store.load() + assert schema.entities == [] + assert schema.relations == [] + + @pytest.mark.asyncio + async def test_reconstructs_schema(self, store_factory, fake_graph): + store = store_factory() + fake_graph.set_load_response( + entity_rows=[ + [ + "Person", + "A human", + [ + {"name": "age", "type": "INTEGER", "description": None}, + # OPTIONAL MATCH empty row — must be filtered. + {"name": None, "type": None, "description": None}, + ], + ], + ["Company", None, []], + ], + relation_rows=[ + [ + "WORKS_AT", + "Employment", + ["Person|Company"], + [{"name": "since", "type": "DATE", "description": None}], + ], + ], + ) + schema = await store.load() + assert {e.label for e in schema.entities} == {"Person", "Company"} + person = next(e for e in schema.entities if e.label == "Person") + assert [(p.name, p.type) for p in person.properties] == [("age", "INTEGER")] + works = next(r for r in schema.relations if r.label == "WORKS_AT") + assert works.patterns == [("Person", "Company")] + assert [(p.name, p.type) for p in works.properties] == [("since", "DATE")] + + +# ── contradiction validation ───────────────────────────────────── + + +class TestContradictionDetection: + """Additive-only schema: re-typing an existing property is rejected.""" + + @pytest.mark.asyncio + async def test_compatible_addition_is_accepted(self, store_factory, fake_graph): + store = store_factory() + fake_graph.set_load_response( + entity_rows=[ + [ + "Person", + None, + [{"name": "age", "type": "INTEGER", "description": None}], + ], + ], + relation_rows=[], + ) + incoming = GraphSchema( + entities=[ + EntityType( + label="Person", + properties=[ + PropertyType(name="age", type="INTEGER"), # same as before + PropertyType(name="birth_date", type="DATE"), # new property + ], + ), + ], + ) + await store.register(incoming) # must not raise + + @pytest.mark.asyncio + async def test_redefining_entity_property_type_is_rejected( + self, store_factory, fake_graph + ): + store = store_factory() + fake_graph.set_load_response( + entity_rows=[ + [ + "Person", + None, + [{"name": "age", "type": "INTEGER", "description": None}], + ], + ], + relation_rows=[], + ) + incoming = GraphSchema( + entities=[ + EntityType( + label="Person", + properties=[PropertyType(name="age", type="STRING")], + ), + ], + ) + with pytest.raises(OntologyContradictionError) as exc: + await store.register(incoming) + assert "Person.age" in str(exc.value) + # And no MERGE-on-EntityType happened (validation runs before persistence). + upserts = [c for c in fake_graph.calls if "MERGE (e:OntologyEntityType" in c[0]] + assert upserts == [] + + @pytest.mark.asyncio + async def test_redefining_relation_property_type_is_rejected( + self, store_factory, fake_graph + ): + store = store_factory() + fake_graph.set_load_response( + entity_rows=[], + relation_rows=[ + [ + "WORKS_AT", + None, + [], + [{"name": "since", "type": "DATE", "description": None}], + ], + ], + ) + incoming = GraphSchema( + relations=[ + RelationType( + label="WORKS_AT", + properties=[PropertyType(name="since", type="STRING")], + ), + ], + ) + with pytest.raises(OntologyContradictionError) as exc: + await store.register(incoming) + assert "WORKS_AT.since" in str(exc.value) diff --git a/graphrag_sdk/tests/test_pipeline.py b/graphrag_sdk/tests/test_pipeline.py index 3945fcf1..8e020c07 100644 --- a/graphrag_sdk/tests/test_pipeline.py +++ b/graphrag_sdk/tests/test_pipeline.py @@ -694,3 +694,142 @@ def test_prune_pattern_mismatch_sample_is_bounded(self, caplog): # Total count is reported, but the sampled list does not contain 50 entries. assert "Pruned 50" in msg assert msg.count("('Company', 'Person')") <= 3 + + +class TestValidateAttributes: + """Schema-attribute validation pass on top of ``_prune``.""" + + def _pipeline(self): + # We only exercise the validator; bypass construction by mocking the + # required strategy collaborators. + return IngestionPipeline( + loader=MagicMock(spec=LoaderStrategy), + chunker=MagicMock(spec=ChunkingStrategy), + extractor=MagicMock(spec=ExtractionStrategy), + resolver=MagicMock(spec=ResolutionStrategy), + graph_store=MagicMock(), + vector_store=MagicMock(), + schema=GraphSchema(), + ) + + def test_empty_schema_is_noop(self): + from graphrag_sdk.core.models import PropertyType + pipe = self._pipeline() + gd = GraphData( + nodes=[GraphNode(id="1", label="Foo", properties={"a": 1})], + ) + out = pipe._validate_attributes(gd, GraphSchema()) + assert len(out.nodes) == 1 + assert out.nodes[0].properties == {"a": 1} + + def test_unknown_attributes_dropped_on_declared_label(self): + from graphrag_sdk.core.models import PropertyType + pipe = self._pipeline() + schema = GraphSchema( + entities=[ + EntityType( + label="Person", + properties=[PropertyType(name="age", type="INTEGER")], + ) + ] + ) + node = GraphNode( + id="p1", + label="Person", + properties={ + "name": "Marie", + "age": 56, + "ssn": "123-45-6789", + }, + ) + out = pipe._validate_attributes(GraphData(nodes=[node]), schema) + assert len(out.nodes) == 1 + kept = out.nodes[0].properties + assert "ssn" not in kept + assert kept["age"] == 56 + # Reserved keys preserved. + assert kept["name"] == "Marie" + + def test_missing_attribute_keeps_node(self): + """Missing declared attributes don't drop the node — they stay + missing on the graph, which is the right null semantics for + retrieval (``WHERE p.age > N`` naturally excludes them).""" + from graphrag_sdk.core.models import PropertyType + pipe = self._pipeline() + schema = GraphSchema( + entities=[ + EntityType( + label="Person", + properties=[ + PropertyType(name="age", type="INTEGER"), + PropertyType(name="birth_date", type="DATE"), + ], + ) + ] + ) + nodes = [ + GraphNode( + id="full", + label="Person", + properties={"name": "Marie", "age": 56, "birth_date": "1867-11-07"}, + ), + GraphNode( + id="partial", + label="Person", + properties={"name": "Pierre", "age": 39}, + ), + ] + out = pipe._validate_attributes(GraphData(nodes=nodes), schema) + assert sorted(n.id for n in out.nodes) == ["full", "partial"] + + def test_unknown_node_label_passes_through(self): + """``"Unknown"`` typed nodes (preserved by ``_prune``) must not be + validated against any declared properties.""" + from graphrag_sdk.core.models import PropertyType + pipe = self._pipeline() + schema = GraphSchema( + entities=[ + EntityType( + label="Person", + properties=[PropertyType(name="age", type="INTEGER")], + ) + ] + ) + nodes = [GraphNode(id="u", label="Unknown", properties={"name": "x"})] + out = pipe._validate_attributes(GraphData(nodes=nodes), schema) + assert len(out.nodes) == 1 + assert out.nodes[0].id == "u" + + def test_undeclared_relationship_attribute_keys_are_stripped(self): + from graphrag_sdk.core.models import PropertyType + pipe = self._pipeline() + schema = GraphSchema( + entities=[EntityType(label="Person"), EntityType(label="Company")], + relations=[ + RelationType( + label="WORKS_AT", + properties=[PropertyType(name="since", type="DATE")], + ) + ], + ) + rels = [ + GraphRelationship( + start_node_id="a", + end_node_id="b", + type="RELATES", + properties={ + "rel_type": "WORKS_AT", + "fact": "f1", + "since": "2020-01-01", + "leaked_key": "noise", + }, + ), + ] + out = pipe._validate_attributes(GraphData(relationships=rels), schema) + assert len(out.relationships) == 1 + props = out.relationships[0].properties + assert "leaked_key" not in props + assert props["since"] == "2020-01-01" + # Reserved keys still preserved. + assert props["rel_type"] == "WORKS_AT" + assert props["fact"] == "f1"