From a6f7fe805cd232862bc369b695bd7032ee1b79e8 Mon Sep 17 00:00:00 2001 From: Naseem Ali <34807727+Naseem77@users.noreply.github.com> Date: Mon, 18 May 2026 16:54:56 +0300 Subject: [PATCH 1/8] Enforce provider async timeouts Add typed timeout errors for LLM and embedding calls and wrap async provider operations with asyncio.wait_for. Cover base, LiteLLM, and OpenRouter async paths with regression tests.\n\nCo-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../src/graphrag_sdk/core/exceptions.py | 12 ++ .../src/graphrag_sdk/core/providers/_retry.py | 6 + .../src/graphrag_sdk/core/providers/base.py | 92 +++++++++++-- .../graphrag_sdk/core/providers/litellm.py | 82 ++++++++++-- .../graphrag_sdk/core/providers/openrouter.py | 82 ++++++++++-- graphrag_sdk/tests/test_exceptions.py | 10 ++ graphrag_sdk/tests/test_providers.py | 126 +++++++++++++++++- 7 files changed, 377 insertions(+), 33 deletions(-) diff --git a/graphrag_sdk/src/graphrag_sdk/core/exceptions.py b/graphrag_sdk/src/graphrag_sdk/core/exceptions.py index d8942aeb..2e98a384 100644 --- a/graphrag_sdk/src/graphrag_sdk/core/exceptions.py +++ b/graphrag_sdk/src/graphrag_sdk/core/exceptions.py @@ -19,12 +19,24 @@ class LLMError(GraphRAGError): pass +class LLMTimeoutError(LLMError): + """Raised when an LLM provider call exceeds its configured timeout.""" + + pass + + class EmbeddingError(GraphRAGError): """Raised when an embedding provider call fails.""" pass +class EmbeddingTimeoutError(EmbeddingError): + """Raised when an embedding provider call exceeds its configured timeout.""" + + pass + + # ── Ingestion Errors ───────────────────────────────────────────── diff --git a/graphrag_sdk/src/graphrag_sdk/core/providers/_retry.py b/graphrag_sdk/src/graphrag_sdk/core/providers/_retry.py index 13a4305b..4e2ad026 100644 --- a/graphrag_sdk/src/graphrag_sdk/core/providers/_retry.py +++ b/graphrag_sdk/src/graphrag_sdk/core/providers/_retry.py @@ -7,6 +7,8 @@ from collections.abc import Awaitable, Callable from typing import Any +from graphrag_sdk.core.exceptions import EmbeddingTimeoutError + logger = logging.getLogger(__name__) _MAX_EXC_SUMMARY_LEN = 200 @@ -65,6 +67,8 @@ def binary_split_retry_sync( """ try: return embed_fn(texts, **kwargs) + except EmbeddingTimeoutError: + raise except Exception as exc: if not is_transient_embedding_error(exc): raise @@ -85,6 +89,8 @@ async def binary_split_retry_async( """Async variant of :func:`binary_split_retry_sync`.""" try: return await embed_fn(texts, **kwargs) + except EmbeddingTimeoutError: + raise except Exception as exc: if not is_transient_embedding_error(exc): raise diff --git a/graphrag_sdk/src/graphrag_sdk/core/providers/base.py b/graphrag_sdk/src/graphrag_sdk/core/providers/base.py index cd7a31da..53dacc60 100644 --- a/graphrag_sdk/src/graphrag_sdk/core/providers/base.py +++ b/graphrag_sdk/src/graphrag_sdk/core/providers/base.py @@ -13,12 +13,35 @@ from pydantic import BaseModel +from graphrag_sdk.core.exceptions import EmbeddingTimeoutError, LLMTimeoutError from graphrag_sdk.core.models import ChatMessage, LLMResponse from graphrag_sdk.core.providers._retry import summarize_exception logger = logging.getLogger(__name__) +async def _wait_for_provider_call( + awaitable: Any, + *, + timeout: float | None, + timeout_error: type[LLMTimeoutError] | type[EmbeddingTimeoutError], + operation: str, +) -> Any: + if timeout is None: + return await awaitable + if timeout <= 0: + raise ValueError("timeout must be > 0") + try: + return await asyncio.wait_for(awaitable, timeout=timeout) + except (TimeoutError, asyncio.TimeoutError) as exc: + raise timeout_error(f"{operation} timed out after {timeout:.3g}s") from exc + + +def _validate_timeout(timeout: float | None) -> None: + if timeout is not None and timeout <= 0: + raise ValueError("timeout must be > 0") + + @dataclass class LLMBatchItem: """Result of one item in a batch invocation.""" @@ -60,17 +83,41 @@ def embed_query(self, text: str, **kwargs: Any) -> list[float]: """Embed a single text string into a float vector.""" ... - async def aembed_query(self, text: str, **kwargs: Any) -> list[float]: + async def aembed_query( + self, + text: str, + *, + timeout: float | None = None, + **kwargs: Any, + ) -> list[float]: """Async variant — defaults to sync-in-thread.""" - return await asyncio.to_thread(self.embed_query, text, **kwargs) + _validate_timeout(timeout) + return await _wait_for_provider_call( + asyncio.to_thread(self.embed_query, text, **kwargs), + timeout=timeout, + timeout_error=EmbeddingTimeoutError, + operation=f"embedding query with {self.model_name}", + ) def embed_documents(self, texts: list[str], **kwargs: Any) -> list[list[float]]: """Batch embed multiple texts. Default: sequential fallback.""" return [self.embed_query(t, **kwargs) for t in texts] - async def aembed_documents(self, texts: list[str], **kwargs: Any) -> list[list[float]]: + async def aembed_documents( + self, + texts: list[str], + *, + timeout: float | None = None, + **kwargs: Any, + ) -> list[list[float]]: """Async batch embed. Default: sync-in-thread.""" - return await asyncio.to_thread(self.embed_documents, texts, **kwargs) + _validate_timeout(timeout) + return await _wait_for_provider_call( + asyncio.to_thread(self.embed_documents, texts, **kwargs), + timeout=timeout, + timeout_error=EmbeddingTimeoutError, + operation=f"embedding documents with {self.model_name}", + ) class LLMInterface(ABC): @@ -106,6 +153,7 @@ async def ainvoke( prompt: str, *, max_retries: int = 3, + timeout: float | None = None, **kwargs: Any, ) -> LLMResponse: """Async variant with retry + jittered exponential backoff. @@ -113,10 +161,20 @@ async def ainvoke( Retries on any exception up to ``max_retries`` times with jittered delays between attempts. """ + if max_retries < 1: + raise ValueError("max_retries must be >= 1") + _validate_timeout(timeout) last_exc: Exception | None = None for attempt in range(max_retries): try: - return await asyncio.to_thread(self.invoke, prompt, **kwargs) + return await _wait_for_provider_call( + asyncio.to_thread(self.invoke, prompt, **kwargs), + timeout=timeout, + timeout_error=LLMTimeoutError, + operation=f"LLM call to {self.model_name}", + ) + except LLMTimeoutError: + raise except Exception as exc: last_exc = exc if attempt < max_retries - 1: @@ -137,6 +195,7 @@ async def ainvoke_messages( messages: list[ChatMessage], *, max_retries: int = 3, + timeout: float | None = None, **kwargs: Any, ) -> LLMResponse: """Invoke the LLM with a list of structured chat messages. @@ -160,7 +219,12 @@ async def ainvoke_messages( for msg in messages: parts.append(f"{msg.role.capitalize()}: {msg.content}") prompt = "\n\n".join(parts) - return await self.ainvoke(prompt, max_retries=max_retries, **kwargs) + return await self.ainvoke( + prompt, + max_retries=max_retries, + timeout=timeout, + **kwargs, + ) async def astream(self, prompt: str, **kwargs: Any) -> AsyncIterator[str]: """Async streaming — default yields the full response as one chunk.""" @@ -187,10 +251,16 @@ async def ainvoke_with_model( response_model: type[BaseModel], *, max_retries: int = 3, + timeout: float | None = None, **kwargs: Any, ) -> BaseModel: """Async structured output invocation with retry.""" - response = await self.ainvoke(prompt, max_retries=max_retries, **kwargs) + response = await self.ainvoke( + prompt, + max_retries=max_retries, + timeout=timeout, + **kwargs, + ) return response_model.model_validate_json(response.content) async def abatch_invoke( @@ -199,6 +269,7 @@ async def abatch_invoke( *, max_concurrency: int | None = None, max_retries: int = 3, + timeout: float | None = None, **kwargs: Any, ) -> list[LLMBatchItem]: """Invoke LLM on multiple prompts concurrently with per-item error handling. @@ -220,7 +291,12 @@ async def abatch_invoke( async def _call(i: int, prompt: str) -> LLMBatchItem: async with sem: try: - resp = await self.ainvoke(prompt, max_retries=max_retries, **kwargs) + resp = await self.ainvoke( + prompt, + max_retries=max_retries, + timeout=timeout, + **kwargs, + ) return LLMBatchItem(index=i, response=resp) except Exception as exc: return LLMBatchItem(index=i, error=exc) diff --git a/graphrag_sdk/src/graphrag_sdk/core/providers/litellm.py b/graphrag_sdk/src/graphrag_sdk/core/providers/litellm.py index 8d5e8331..6360ec3c 100644 --- a/graphrag_sdk/src/graphrag_sdk/core/providers/litellm.py +++ b/graphrag_sdk/src/graphrag_sdk/core/providers/litellm.py @@ -8,13 +8,19 @@ import logging from typing import Any +from graphrag_sdk.core.exceptions import EmbeddingTimeoutError, LLMTimeoutError from graphrag_sdk.core.models import ChatMessage, LLMResponse from graphrag_sdk.core.providers._retry import ( binary_split_retry_async, binary_split_retry_sync, summarize_exception, ) -from graphrag_sdk.core.providers.base import Embedder, LLMInterface +from graphrag_sdk.core.providers.base import ( + Embedder, + LLMInterface, + _validate_timeout, + _wait_for_provider_call, +) logger = logging.getLogger(__name__) @@ -116,8 +122,12 @@ async def ainvoke( prompt: str, *, max_retries: int = 3, + timeout: float | None = None, **kwargs: Any, ) -> LLMResponse: + if max_retries < 1: + raise ValueError("max_retries must be >= 1") + _validate_timeout(timeout) try: import litellm except ImportError: @@ -128,9 +138,16 @@ async def ainvoke( last_exc: Exception | None = None for attempt in range(max_retries): try: - response = await litellm.acompletion(**self._completion_kwargs(prompt, **kwargs)) + response = await _wait_for_provider_call( + litellm.acompletion(**self._completion_kwargs(prompt, **kwargs)), + timeout=timeout, + timeout_error=LLMTimeoutError, + operation=f"LiteLLM call to {self.model_name}", + ) content = response.choices[0].message.content or "" return LLMResponse(content=content) + except LLMTimeoutError: + raise except Exception as exc: last_exc = exc if attempt < max_retries - 1: @@ -183,11 +200,13 @@ async def ainvoke_messages( messages: list[ChatMessage], *, max_retries: int = 3, + timeout: float | None = None, **kwargs: Any, ) -> LLMResponse: """Native multi-turn completion via LiteLLM.""" if max_retries < 1: raise ValueError("max_retries must be >= 1") + _validate_timeout(timeout) try: import litellm except ImportError: @@ -198,11 +217,16 @@ async def ainvoke_messages( last_exc: Exception | None = None for attempt in range(max_retries): try: - response = await litellm.acompletion( - **self._messages_completion_kwargs(messages, **kwargs) + response = await _wait_for_provider_call( + litellm.acompletion(**self._messages_completion_kwargs(messages, **kwargs)), + timeout=timeout, + timeout_error=LLMTimeoutError, + operation=f"LiteLLM messages call to {self.model_name}", ) content = response.choices[0].message.content or "" return LLMResponse(content=content) + except LLMTimeoutError: + raise except Exception as exc: last_exc = exc if attempt < max_retries - 1: @@ -291,10 +315,22 @@ def _raw_embed_sync(self, texts: list[str], **kwargs: Any) -> list[list[float]]: sorted_data = sorted(response.data, key=lambda x: x["index"]) return [d["embedding"] for d in sorted_data] - async def _raw_embed_async(self, texts: list[str], **kwargs: Any) -> list[list[float]]: + async def _raw_embed_async( + self, + texts: list[str], + *, + timeout: float | None = None, + **kwargs: Any, + ) -> list[list[float]]: """Raw async embed without retry — called by binary_split_retry_async.""" + _validate_timeout(timeout) litellm = self._import_litellm() - response = await litellm.aembedding(**self._embedding_kwargs(texts, **kwargs)) + response = await _wait_for_provider_call( + litellm.aembedding(**self._embedding_kwargs(texts, **kwargs)), + timeout=timeout, + timeout_error=EmbeddingTimeoutError, + operation=f"LiteLLM embedding batch with {self.model}", + ) sorted_data = sorted(response.data, key=lambda x: x["index"]) return [d["embedding"] for d in sorted_data] @@ -307,16 +343,42 @@ def embed_documents(self, texts: list[str], **kwargs: Any) -> list[list[float]]: results.extend(binary_split_retry_sync(self._raw_embed_sync, batch, **kwargs)) return results - async def aembed_query(self, text: str, **kwargs: Any) -> list[float]: + async def aembed_query( + self, + text: str, + *, + timeout: float | None = None, + **kwargs: Any, + ) -> list[float]: + _validate_timeout(timeout) litellm = self._import_litellm() - response = await litellm.aembedding(**self._embedding_kwargs(text, **kwargs)) + response = await _wait_for_provider_call( + litellm.aembedding(**self._embedding_kwargs(text, **kwargs)), + timeout=timeout, + timeout_error=EmbeddingTimeoutError, + operation=f"LiteLLM embedding query with {self.model}", + ) return response.data[0]["embedding"] - async def aembed_documents(self, texts: list[str], **kwargs: Any) -> list[list[float]]: + async def aembed_documents( + self, + texts: list[str], + *, + timeout: float | None = None, + **kwargs: Any, + ) -> list[list[float]]: + _validate_timeout(timeout) if not texts: return [] results: list[list[float]] = [] for start in range(0, len(texts), self.batch_size): batch = texts[start : start + self.batch_size] - results.extend(await binary_split_retry_async(self._raw_embed_async, batch, **kwargs)) + results.extend( + await binary_split_retry_async( + self._raw_embed_async, + batch, + timeout=timeout, + **kwargs, + ) + ) return results diff --git a/graphrag_sdk/src/graphrag_sdk/core/providers/openrouter.py b/graphrag_sdk/src/graphrag_sdk/core/providers/openrouter.py index bb075524..42b64b3b 100644 --- a/graphrag_sdk/src/graphrag_sdk/core/providers/openrouter.py +++ b/graphrag_sdk/src/graphrag_sdk/core/providers/openrouter.py @@ -9,13 +9,19 @@ import os from typing import Any +from graphrag_sdk.core.exceptions import EmbeddingTimeoutError, LLMTimeoutError from graphrag_sdk.core.models import ChatMessage, LLMResponse from graphrag_sdk.core.providers._retry import ( binary_split_retry_async, binary_split_retry_sync, summarize_exception, ) -from graphrag_sdk.core.providers.base import Embedder, LLMInterface +from graphrag_sdk.core.providers.base import ( + Embedder, + LLMInterface, + _validate_timeout, + _wait_for_provider_call, +) logger = logging.getLogger(__name__) @@ -141,8 +147,12 @@ async def ainvoke( prompt: str, *, max_retries: int = 3, + timeout: float | None = None, **kwargs: Any, ) -> LLMResponse: + if max_retries < 1: + raise ValueError("max_retries must be >= 1") + _validate_timeout(timeout) client = self._get_async_client() create_kwargs = self._build_create_kwargs( [{"role": "user", "content": prompt}], @@ -151,9 +161,16 @@ async def ainvoke( last_exc: Exception | None = None for attempt in range(max_retries): try: - response = await client.chat.completions.create(**create_kwargs) + response = await _wait_for_provider_call( + client.chat.completions.create(**create_kwargs), + timeout=timeout, + timeout_error=LLMTimeoutError, + operation=f"OpenRouter call to {self.model_name}", + ) content = response.choices[0].message.content or "" return LLMResponse(content=content) + except LLMTimeoutError: + raise except Exception as exc: last_exc = exc if attempt < max_retries - 1: @@ -174,11 +191,13 @@ async def ainvoke_messages( messages: list[ChatMessage], *, max_retries: int = 3, + timeout: float | None = None, **kwargs: Any, ) -> LLMResponse: """Native multi-turn completion via OpenRouter.""" if max_retries < 1: raise ValueError("max_retries must be >= 1") + _validate_timeout(timeout) client = self._get_async_client() create_kwargs = self._build_create_kwargs( [m.to_dict() for m in messages], @@ -187,9 +206,16 @@ async def ainvoke_messages( last_exc: Exception | None = None for attempt in range(max_retries): try: - response = await client.chat.completions.create(**create_kwargs) + response = await _wait_for_provider_call( + client.chat.completions.create(**create_kwargs), + timeout=timeout, + timeout_error=LLMTimeoutError, + operation=f"OpenRouter messages call to {self.model_name}", + ) content = response.choices[0].message.content or "" return LLMResponse(content=content) + except LLMTimeoutError: + raise except Exception as exc: last_exc = exc if attempt < max_retries - 1: @@ -280,10 +306,22 @@ def _raw_embed_sync(self, texts: list[str], **kwargs: Any) -> list[list[float]]: sorted_data = sorted(response.data, key=lambda x: x.index) return [d.embedding for d in sorted_data] - async def _raw_embed_async(self, texts: list[str], **kwargs: Any) -> list[list[float]]: + async def _raw_embed_async( + self, + texts: list[str], + *, + timeout: float | None = None, + **kwargs: Any, + ) -> list[list[float]]: """Raw async embed without retry — called by binary_split_retry_async.""" + _validate_timeout(timeout) client = self._get_async_client() - response = await client.embeddings.create(model=self.model, input=texts, **kwargs) + response = await _wait_for_provider_call( + client.embeddings.create(model=self.model, input=texts, **kwargs), + timeout=timeout, + timeout_error=EmbeddingTimeoutError, + operation=f"OpenRouter embedding batch with {self.model}", + ) sorted_data = sorted(response.data, key=lambda x: x.index) return [d.embedding for d in sorted_data] @@ -296,16 +334,42 @@ def embed_documents(self, texts: list[str], **kwargs: Any) -> list[list[float]]: results.extend(binary_split_retry_sync(self._raw_embed_sync, batch, **kwargs)) return results - async def aembed_query(self, text: str, **kwargs: Any) -> list[float]: + async def aembed_query( + self, + text: str, + *, + timeout: float | None = None, + **kwargs: Any, + ) -> list[float]: + _validate_timeout(timeout) client = self._get_async_client() - response = await client.embeddings.create(model=self.model, input=text, **kwargs) + response = await _wait_for_provider_call( + client.embeddings.create(model=self.model, input=text, **kwargs), + timeout=timeout, + timeout_error=EmbeddingTimeoutError, + operation=f"OpenRouter embedding query with {self.model}", + ) return response.data[0].embedding - async def aembed_documents(self, texts: list[str], **kwargs: Any) -> list[list[float]]: + async def aembed_documents( + self, + texts: list[str], + *, + timeout: float | None = None, + **kwargs: Any, + ) -> list[list[float]]: + _validate_timeout(timeout) if not texts: return [] results: list[list[float]] = [] for start in range(0, len(texts), self.batch_size): batch = texts[start : start + self.batch_size] - results.extend(await binary_split_retry_async(self._raw_embed_async, batch, **kwargs)) + results.extend( + await binary_split_retry_async( + self._raw_embed_async, + batch, + timeout=timeout, + **kwargs, + ) + ) return results diff --git a/graphrag_sdk/tests/test_exceptions.py b/graphrag_sdk/tests/test_exceptions.py index 12f84eea..56358156 100644 --- a/graphrag_sdk/tests/test_exceptions.py +++ b/graphrag_sdk/tests/test_exceptions.py @@ -8,11 +8,13 @@ ConfigError, DatabaseError, EmbeddingError, + EmbeddingTimeoutError, ExtractionError, GraphRAGError, IndexError_, IngestionError, LLMError, + LLMTimeoutError, LoaderError, ResolutionError, RetrieverError, @@ -29,10 +31,18 @@ def test_llm_error_is_graphrag_error(self): with pytest.raises(GraphRAGError): raise LLMError("llm failed") + def test_llm_timeout_error_is_llm_error(self): + with pytest.raises(LLMError): + raise LLMTimeoutError("llm timed out") + def test_embedding_error_is_graphrag_error(self): with pytest.raises(GraphRAGError): raise EmbeddingError("embedding failed") + def test_embedding_timeout_error_is_embedding_error(self): + with pytest.raises(EmbeddingError): + raise EmbeddingTimeoutError("embedding timed out") + def test_ingestion_error_hierarchy(self): assert issubclass(IngestionError, GraphRAGError) assert issubclass(LoaderError, IngestionError) diff --git a/graphrag_sdk/tests/test_providers.py b/graphrag_sdk/tests/test_providers.py index 4dfe0caa..a2476ef5 100644 --- a/graphrag_sdk/tests/test_providers.py +++ b/graphrag_sdk/tests/test_providers.py @@ -2,25 +2,25 @@ from __future__ import annotations import asyncio -import json +import time from typing import Any from unittest.mock import AsyncMock, MagicMock, patch import pytest from pydantic import BaseModel +from graphrag_sdk.core.exceptions import EmbeddingTimeoutError, LLMTimeoutError from graphrag_sdk.core.models import LLMResponse from graphrag_sdk.core.providers import ( Embedder, - LLMBatchItem, - LLMInterface, LiteLLM, LiteLLMEmbedder, + LLMBatchItem, + LLMInterface, OpenRouterEmbedder, OpenRouterLLM, ) - # ── Concrete test implementations ────────────────────────────── @@ -56,6 +56,26 @@ async def test_aembed_query_default(self): assert isinstance(result, list) assert result[0] == 5.0 # len("world") + async def test_aembed_query_timeout_raises_typed_error(self): + class SlowEmbedder(SimpleEmbedder): + def embed_query(self, text: str, **kwargs: Any) -> list[float]: + time.sleep(0.05) + return super().embed_query(text, **kwargs) + + emb = SlowEmbedder() + with pytest.raises(EmbeddingTimeoutError, match="timed out"): + await emb.aembed_query("world", timeout=0.001) + + async def test_aembed_documents_timeout_raises_typed_error(self): + class SlowEmbedder(SimpleEmbedder): + def embed_documents(self, texts: list[str], **kwargs: Any) -> list[list[float]]: + time.sleep(0.05) + return super().embed_documents(texts, **kwargs) + + emb = SlowEmbedder() + with pytest.raises(EmbeddingTimeoutError, match="timed out"): + await emb.aembed_documents(["world"], timeout=0.001) + def test_cannot_instantiate_abc(self): with pytest.raises(TypeError): Embedder() # type: ignore[abstract] @@ -81,6 +101,21 @@ async def test_ainvoke_default(self): response = await llm.ainvoke("Async test") assert response.content.startswith("Echo:") + async def test_ainvoke_timeout_raises_typed_error(self): + class SlowLLM(SimpleLLM): + def invoke(self, prompt: str, **kwargs: Any) -> LLMResponse: + time.sleep(0.05) + return super().invoke(prompt, **kwargs) + + llm = SlowLLM() + with pytest.raises(LLMTimeoutError, match="timed out"): + await llm.ainvoke("Async test", timeout=0.001, max_retries=1) + + async def test_ainvoke_rejects_non_positive_timeout(self): + llm = SimpleLLM() + with pytest.raises(ValueError, match="timeout must be > 0"): + await llm.ainvoke("Async test", timeout=0) + async def test_invoke_with_model(self): class Result(BaseModel): answer: int @@ -387,6 +422,19 @@ async def test_ainvoke(self): assert result.content == "async hi" mock_litellm.acompletion.assert_awaited_once() + @pytest.mark.asyncio + async def test_ainvoke_timeout_raises_typed_error(self): + async def slow_completion(**kwargs): + await asyncio.sleep(0.05) + return _mock_litellm_completion_response("too late") + + mock_litellm = MagicMock() + mock_litellm.acompletion = slow_completion + with patch.dict("sys.modules", {"litellm": mock_litellm}): + llm = LiteLLM(model="azure/gpt-4") + with pytest.raises(LLMTimeoutError, match="timed out"): + await llm.ainvoke("Hello", timeout=0.001, max_retries=1) + @pytest.mark.asyncio async def test_ainvoke_retries_on_failure(self): mock_litellm = MagicMock() @@ -603,6 +651,19 @@ async def test_aembed_query(self): assert result == vec + @pytest.mark.asyncio + async def test_aembed_query_timeout_raises_typed_error(self): + async def slow_embedding(**kwargs): + await asyncio.sleep(0.05) + return _mock_litellm_embedding_response([[0.7, 0.8, 0.9]]) + + mock_litellm = MagicMock() + mock_litellm.aembedding = slow_embedding + with patch.dict("sys.modules", {"litellm": mock_litellm}): + embedder = LiteLLMEmbedder(model="text-embedding-ada-002") + with pytest.raises(EmbeddingTimeoutError, match="timed out"): + await embedder.aembed_query("hello", timeout=0.001) + @pytest.mark.asyncio async def test_aembed_documents(self): mock_litellm = MagicMock() @@ -616,6 +677,21 @@ async def test_aembed_documents(self): assert result == vecs + @pytest.mark.asyncio + async def test_aembed_documents_timeout_is_not_binary_split(self): + async def slow_embedding(**kwargs): + await asyncio.sleep(0.05) + return _mock_litellm_embedding_response([[0.1, 0.2]]) + + mock_litellm = MagicMock() + mock_litellm.aembedding = AsyncMock(side_effect=slow_embedding) + with patch.dict("sys.modules", {"litellm": mock_litellm}): + embedder = LiteLLMEmbedder(model="text-embedding-ada-002") + with pytest.raises(EmbeddingTimeoutError, match="timed out"): + await embedder.aembed_documents(["a", "b"], timeout=0.001) + + assert mock_litellm.aembedding.await_count == 1 + def test_import_error(self): with patch.dict("sys.modules", {"litellm": None}): embedder = LiteLLMEmbedder(model="text-embedding-ada-002") @@ -632,7 +708,9 @@ class TestOpenRouterLLM: def test_invoke(self): mock_openai = MagicMock() mock_client = MagicMock() - mock_client.chat.completions.create.return_value = _mock_openai_completion_response("router hi") + mock_client.chat.completions.create.return_value = _mock_openai_completion_response( + "router hi" + ) mock_openai.OpenAI.return_value = mock_client with patch.dict("sys.modules", {"openai": mock_openai}): llm = OpenRouterLLM(model="anthropic/claude-sonnet-4-20250514", api_key="or-key") @@ -659,6 +737,22 @@ async def test_ainvoke(self): assert result.content == "async router" mock_openai.AsyncOpenAI.assert_called_once() + @pytest.mark.asyncio + async def test_ainvoke_timeout_raises_typed_error(self): + mock_openai = MagicMock() + mock_async_client = MagicMock() + + async def slow_create(**kwargs): + await asyncio.sleep(0.05) + return _mock_openai_completion_response("too late") + + mock_async_client.chat.completions.create = slow_create + mock_openai.AsyncOpenAI.return_value = mock_async_client + with patch.dict("sys.modules", {"openai": mock_openai}): + llm = OpenRouterLLM(model="openai/gpt-4", api_key="or-key") + with pytest.raises(LLMTimeoutError, match="timed out"): + await llm.ainvoke("Hello", timeout=0.001, max_retries=1) + def test_env_var_fallback(self): mock_openai = MagicMock() mock_client = MagicMock() @@ -705,7 +799,11 @@ def test_embed_documents_batching(self): ] mock_openai.OpenAI.return_value = mock_client with patch.dict("sys.modules", {"openai": mock_openai}): - embedder = OpenRouterEmbedder(model="openai/text-embedding-ada-002", api_key="k", batch_size=2) + embedder = OpenRouterEmbedder( + model="openai/text-embedding-ada-002", + api_key="k", + batch_size=2, + ) result = embedder.embed_documents(["a", "b", "c"]) assert len(result) == 3 @@ -727,6 +825,22 @@ async def test_aembed_query(self): assert result == vec + @pytest.mark.asyncio + async def test_aembed_query_timeout_raises_typed_error(self): + mock_openai = MagicMock() + mock_async_client = MagicMock() + + async def slow_create(**kwargs): + await asyncio.sleep(0.05) + return _mock_openai_embedding_response([[0.4, 0.5, 0.6]]) + + mock_async_client.embeddings.create = slow_create + mock_openai.AsyncOpenAI.return_value = mock_async_client + with patch.dict("sys.modules", {"openai": mock_openai}): + embedder = OpenRouterEmbedder(model="openai/text-embedding-ada-002", api_key="k") + with pytest.raises(EmbeddingTimeoutError, match="timed out"): + await embedder.aembed_query("hello", timeout=0.001) + def test_import_error(self): with patch.dict("sys.modules", {"openai": None}): embedder = OpenRouterEmbedder(model="text-embedding-ada-002", api_key="k") From 7d65089cc3a7ab384c7978956ec97144807dba2f Mon Sep 17 00:00:00 2001 From: Naseem Ali <34807727+Naseem77@users.noreply.github.com> Date: Mon, 18 May 2026 17:17:49 +0300 Subject: [PATCH 2/8] Improve broad exception visibility Add typed wrapping and error-level logging around high-risk broad exception paths while preserving debug tracebacks. Cover connection, provider, loader, pipeline, retrieval, and history validation error behavior.\n\nCo-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- graphrag_sdk/src/graphrag_sdk/api/main.py | 11 +++--- .../src/graphrag_sdk/core/connection.py | 23 ++++++++++-- .../src/graphrag_sdk/core/providers/_retry.py | 4 +++ .../src/graphrag_sdk/core/providers/base.py | 16 +++++++++ .../graphrag_sdk/core/providers/litellm.py | 20 +++++++++++ .../graphrag_sdk/core/providers/openrouter.py | 20 +++++++++++ .../ingestion/loaders/pdf_loader.py | 10 ++++-- .../ingestion/loaders/text_loader.py | 2 ++ .../src/graphrag_sdk/ingestion/pipeline.py | 2 ++ .../graphrag_sdk/retrieval/strategies/base.py | 2 ++ graphrag_sdk/tests/test_connection.py | 23 ++++++++++-- graphrag_sdk/tests/test_facade.py | 13 +++++++ graphrag_sdk/tests/test_loaders.py | 36 +++++++++++++++++++ graphrag_sdk/tests/test_pipeline.py | 8 +++-- graphrag_sdk/tests/test_providers.py | 24 +++++++++++++ graphrag_sdk/tests/test_retrieval.py | 8 +++-- 16 files changed, 205 insertions(+), 17 deletions(-) diff --git a/graphrag_sdk/src/graphrag_sdk/api/main.py b/graphrag_sdk/src/graphrag_sdk/api/main.py index 684c7a2d..df3d44e9 100644 --- a/graphrag_sdk/src/graphrag_sdk/api/main.py +++ b/graphrag_sdk/src/graphrag_sdk/api/main.py @@ -1480,13 +1480,16 @@ def _validate_history( f"history[{i}]: each message must have 'role' and " f"'content' keys, got {sorted(msg.keys())}" ) - try: - validated.append(ChatMessage(role=msg["role"], content=msg["content"])) - except Exception: + role = msg["role"] + content = msg["content"] + if role not in {"system", "user", "assistant"}: raise ValueError( - f"history[{i}]: invalid role '{msg['role']}'. " + f"history[{i}]: invalid role '{role}'. " f"Must be one of: 'system', 'user', 'assistant'" ) + if not isinstance(content, str): + raise ValueError(f"history[{i}]: content must be a string") + validated.append(ChatMessage(role=role, content=content)) else: raise TypeError( f"history[{i}]: expected ChatMessage or dict, got {type(msg).__name__}" diff --git a/graphrag_sdk/src/graphrag_sdk/core/connection.py b/graphrag_sdk/src/graphrag_sdk/core/connection.py index 74bc48dd..489af699 100644 --- a/graphrag_sdk/src/graphrag_sdk/core/connection.py +++ b/graphrag_sdk/src/graphrag_sdk/core/connection.py @@ -11,6 +11,8 @@ from typing import Any from urllib.parse import urlparse +from graphrag_sdk.core.exceptions import DatabaseError + logger = logging.getLogger(__name__) @@ -180,7 +182,13 @@ async def query( last_exc = exc # Don't retry non-transient errors (e.g. schema/index conflicts) if self._is_non_transient(exc): - raise + logger.error( + "Non-transient FalkorDB query failure: %s: %s", + type(exc).__name__, + exc, + ) + logger.debug("Non-transient FalkorDB query failure details", exc_info=True) + raise DatabaseError(f"FalkorDB query failed: {exc}") from exc await self._breaker.record_failure() logger.warning( "Query attempt %d/%d failed: %s", @@ -193,7 +201,18 @@ async def query( break base_delay = self.config.retry_delay * (2**attempt) await asyncio.sleep(base_delay * (0.5 + random.random())) - raise last_exc # type: ignore[misc] + logger.error( + "FalkorDB query failed after %d attempts: %s: %s", + self.config.retry_count, + type(last_exc).__name__ if last_exc is not None else "UnknownError", + last_exc, + ) + if last_exc is not None: + logger.debug( + "FalkorDB query failure details", + exc_info=(type(last_exc), last_exc, last_exc.__traceback__), + ) + raise DatabaseError(f"FalkorDB query failed: {last_exc}") from last_exc # Substrings that indicate a non-transient (permanent) error — # retrying will never succeed. diff --git a/graphrag_sdk/src/graphrag_sdk/core/providers/_retry.py b/graphrag_sdk/src/graphrag_sdk/core/providers/_retry.py index 4e2ad026..a911a14d 100644 --- a/graphrag_sdk/src/graphrag_sdk/core/providers/_retry.py +++ b/graphrag_sdk/src/graphrag_sdk/core/providers/_retry.py @@ -71,6 +71,8 @@ def binary_split_retry_sync( raise except Exception as exc: if not is_transient_embedding_error(exc): + logger.error("Non-transient embedding failure: %s", summarize_exception(exc)) + logger.debug("Non-transient embedding failure details", exc_info=True) raise if len(texts) == 1: logger.warning("Embedding failed for text (len=%d): skipped", len(texts[0])) @@ -93,6 +95,8 @@ async def binary_split_retry_async( raise except Exception as exc: if not is_transient_embedding_error(exc): + logger.error("Non-transient embedding failure: %s", summarize_exception(exc)) + logger.debug("Non-transient embedding failure details", exc_info=True) raise if len(texts) == 1: logger.warning("Embedding failed for text (len=%d): skipped", len(texts[0])) diff --git a/graphrag_sdk/src/graphrag_sdk/core/providers/base.py b/graphrag_sdk/src/graphrag_sdk/core/providers/base.py index 53dacc60..82c0b680 100644 --- a/graphrag_sdk/src/graphrag_sdk/core/providers/base.py +++ b/graphrag_sdk/src/graphrag_sdk/core/providers/base.py @@ -188,6 +188,16 @@ async def ainvoke( ) logger.debug("LLM call failure details", exc_info=True) await asyncio.sleep(delay) + logger.error( + "LLM call failed after %d attempts: %s", + max_retries, + summarize_exception(last_exc) if last_exc is not None else "UnknownError", + ) + if last_exc is not None: + logger.debug( + "LLM call final failure details", + exc_info=(type(last_exc), last_exc, last_exc.__traceback__), + ) raise last_exc # type: ignore[misc] async def ainvoke_messages( @@ -299,6 +309,12 @@ async def _call(i: int, prompt: str) -> LLMBatchItem: ) return LLMBatchItem(index=i, response=resp) except Exception as exc: + logger.error( + "Batch LLM item %d failed: %s", + i, + summarize_exception(exc), + ) + logger.debug("Batch LLM item failure details", exc_info=True) return LLMBatchItem(index=i, error=exc) return list(await asyncio.gather(*[_call(i, p) for i, p in enumerate(prompts)])) diff --git a/graphrag_sdk/src/graphrag_sdk/core/providers/litellm.py b/graphrag_sdk/src/graphrag_sdk/core/providers/litellm.py index 6360ec3c..35637290 100644 --- a/graphrag_sdk/src/graphrag_sdk/core/providers/litellm.py +++ b/graphrag_sdk/src/graphrag_sdk/core/providers/litellm.py @@ -161,6 +161,16 @@ async def ainvoke( ) logger.debug("LiteLLM call failure details", exc_info=True) await asyncio.sleep(delay) + logger.error( + "LiteLLM call failed after %d attempts: %s", + max_retries, + summarize_exception(last_exc) if last_exc is not None else "UnknownError", + ) + if last_exc is not None: + logger.debug( + "LiteLLM call final failure details", + exc_info=(type(last_exc), last_exc, last_exc.__traceback__), + ) raise last_exc # type: ignore[misc] def _messages_completion_kwargs( @@ -240,6 +250,16 @@ async def ainvoke_messages( ) logger.debug("LiteLLM call failure details", exc_info=True) await asyncio.sleep(delay) + logger.error( + "LiteLLM messages call failed after %d attempts: %s", + max_retries, + summarize_exception(last_exc) if last_exc is not None else "UnknownError", + ) + if last_exc is not None: + logger.debug( + "LiteLLM messages call final failure details", + exc_info=(type(last_exc), last_exc, last_exc.__traceback__), + ) raise last_exc # type: ignore[misc] diff --git a/graphrag_sdk/src/graphrag_sdk/core/providers/openrouter.py b/graphrag_sdk/src/graphrag_sdk/core/providers/openrouter.py index 42b64b3b..f4f2a1e1 100644 --- a/graphrag_sdk/src/graphrag_sdk/core/providers/openrouter.py +++ b/graphrag_sdk/src/graphrag_sdk/core/providers/openrouter.py @@ -184,6 +184,16 @@ async def ainvoke( ) logger.debug("OpenRouter call failure details", exc_info=True) await asyncio.sleep(delay) + logger.error( + "OpenRouter call failed after %d attempts: %s", + max_retries, + summarize_exception(last_exc) if last_exc is not None else "UnknownError", + ) + if last_exc is not None: + logger.debug( + "OpenRouter call final failure details", + exc_info=(type(last_exc), last_exc, last_exc.__traceback__), + ) raise last_exc # type: ignore[misc] async def ainvoke_messages( @@ -229,6 +239,16 @@ async def ainvoke_messages( ) logger.debug("OpenRouter call failure details", exc_info=True) await asyncio.sleep(delay) + logger.error( + "OpenRouter messages call failed after %d attempts: %s", + max_retries, + summarize_exception(last_exc) if last_exc is not None else "UnknownError", + ) + if last_exc is not None: + logger.debug( + "OpenRouter messages call final failure details", + exc_info=(type(last_exc), last_exc, last_exc.__traceback__), + ) raise last_exc # type: ignore[misc] diff --git a/graphrag_sdk/src/graphrag_sdk/ingestion/loaders/pdf_loader.py b/graphrag_sdk/src/graphrag_sdk/ingestion/loaders/pdf_loader.py index 72989058..e1a75c5a 100644 --- a/graphrag_sdk/src/graphrag_sdk/ingestion/loaders/pdf_loader.py +++ b/graphrag_sdk/src/graphrag_sdk/ingestion/loaders/pdf_loader.py @@ -56,8 +56,9 @@ def _load_with_pymupdf(self, path: Path) -> DocumentOutput: """Extract text using PyMuPDF with sort=True for table-aware layout.""" import fitz # PyMuPDF - doc = fitz.open(str(path)) + doc = None try: + doc = fitz.open(str(path)) pages: list[str] = [] for page in doc: text = page.get_text(sort=True) @@ -65,9 +66,12 @@ def _load_with_pymupdf(self, path: Path) -> DocumentOutput: pages.append(text) page_count = len(doc) except Exception as exc: + logger.error("Failed to read PDF %s with PyMuPDF: %s", path, exc) + logger.debug("PyMuPDF read failure details", exc_info=True) raise LoaderError(f"Failed to read PDF {path}: {exc}") from exc finally: - doc.close() + if doc is not None: + doc.close() full_text = "\n\n".join(pages) logger.info( @@ -117,4 +121,6 @@ def _load_with_pypdf(self, path: Path) -> DocumentOutput: ), ) except Exception as exc: + logger.error("Failed to read PDF %s with pypdf: %s", path, exc) + logger.debug("pypdf read failure details", exc_info=True) raise LoaderError(f"Failed to read PDF {path}: {exc}") from exc diff --git a/graphrag_sdk/src/graphrag_sdk/ingestion/loaders/text_loader.py b/graphrag_sdk/src/graphrag_sdk/ingestion/loaders/text_loader.py index 56c7e3d6..842c24c8 100644 --- a/graphrag_sdk/src/graphrag_sdk/ingestion/loaders/text_loader.py +++ b/graphrag_sdk/src/graphrag_sdk/ingestion/loaders/text_loader.py @@ -46,4 +46,6 @@ def _load_sync(self, source: str) -> DocumentOutput: ), ) except Exception as exc: + logger.error("Failed to read text file %s: %s", source, exc) + logger.debug("Text file read failure details", exc_info=True) raise LoaderError(f"Failed to read {source}: {exc}") from exc diff --git a/graphrag_sdk/src/graphrag_sdk/ingestion/pipeline.py b/graphrag_sdk/src/graphrag_sdk/ingestion/pipeline.py index 165cbeeb..1a31ed2b 100644 --- a/graphrag_sdk/src/graphrag_sdk/ingestion/pipeline.py +++ b/graphrag_sdk/src/graphrag_sdk/ingestion/pipeline.py @@ -275,6 +275,8 @@ async def _step_index_chunks() -> None: except IngestionError: raise except Exception as exc: + logger.error("Pipeline failed with unexpected error: %s", exc) + logger.debug("Pipeline failure details", exc_info=True) raise IngestionError(f"Pipeline failed: {exc}") from exc async def _build_lexical_graph( diff --git a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/base.py b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/base.py index b3669c3a..7cce3850 100644 --- a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/base.py +++ b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/base.py @@ -94,6 +94,8 @@ async def search( except RetrieverError: raise except Exception as exc: + logger.error("Retrieval [%s] failed: %s", self.__class__.__name__, exc) + logger.debug("Retrieval failure details", exc_info=True) raise RetrieverError(f"Retrieval [{self.__class__.__name__}] failed: {exc}") from exc @abstractmethod diff --git a/graphrag_sdk/tests/test_connection.py b/graphrag_sdk/tests/test_connection.py index 922cff59..cfdb1abe 100644 --- a/graphrag_sdk/tests/test_connection.py +++ b/graphrag_sdk/tests/test_connection.py @@ -6,6 +6,7 @@ import pytest from graphrag_sdk.core.connection import ConnectionConfig, FalkorDBConnection +from graphrag_sdk.core.exceptions import DatabaseError class TestConnectionConfig: @@ -105,7 +106,7 @@ async def test_query_retries(self): assert result == "success" assert mock_graph.query.call_count == 3 - async def test_query_exhausts_retries(self): + async def test_query_exhausts_retries(self, caplog): conn = FalkorDBConnection(ConnectionConfig(retry_count=2, retry_delay=0.0)) mock_graph = MagicMock() @@ -113,10 +114,26 @@ async def test_query_exhausts_retries(self): conn._graph = mock_graph conn._driver = MagicMock() - with pytest.raises(Exception, match="always fails"): - await conn.query("MATCH (n) RETURN n") + with caplog.at_level("ERROR", logger="graphrag_sdk.core.connection"): + with pytest.raises(DatabaseError, match="always fails"): + await conn.query("MATCH (n) RETURN n") + assert "FalkorDB query failed after 2 attempts" in caplog.text assert mock_graph.query.call_count == 2 + async def test_query_non_transient_failure_is_typed_and_logged(self, caplog): + conn = FalkorDBConnection(ConnectionConfig(retry_count=3, retry_delay=0.0)) + + mock_graph = MagicMock() + mock_graph.query = AsyncMock(side_effect=Exception("already indexed")) + conn._graph = mock_graph + conn._driver = MagicMock() + + with caplog.at_level("ERROR", logger="graphrag_sdk.core.connection"): + with pytest.raises(DatabaseError, match="already indexed"): + await conn.query("CREATE INDEX idx") + assert "Non-transient FalkorDB query failure" in caplog.text + assert mock_graph.query.call_count == 1 + async def test_query_with_params(self): conn = FalkorDBConnection(ConnectionConfig(retry_count=1)) diff --git a/graphrag_sdk/tests/test_facade.py b/graphrag_sdk/tests/test_facade.py index 78072b4f..b4a178aa 100644 --- a/graphrag_sdk/tests/test_facade.py +++ b/graphrag_sdk/tests/test_facade.py @@ -534,6 +534,19 @@ async def test_completion_history_wrong_type_raises(self, mock_conn, embedder): with pytest.raises(TypeError, match="expected ChatMessage or dict"): await g.completion("test?", history=["not a dict"]) + async def test_completion_history_non_string_content_raises(self, mock_conn, embedder): + """History dict content must be a string.""" + llm = MockLLM(responses=["unused"]) + g = GraphRAG(connection=mock_conn, llm=llm, embedder=embedder, embedding_dimension=8) + mock_strategy = MagicMock(spec=RetrievalStrategy) + mock_strategy.search = AsyncMock( + return_value=RetrieverResult(items=[RetrieverResultItem(content="c")]) + ) + g._retrieval_strategy = mock_strategy + + with pytest.raises(ValueError, match="content must be a string"): + await g.completion("test?", history=[{"role": "user", "content": 123}]) + async def test_completion_no_history_uses_messages_api(self, mock_conn, embedder): """Without history, completion still uses ainvoke_messages (unified path).""" llm = MockLLM(responses=["Single-turn answer."]) diff --git a/graphrag_sdk/tests/test_loaders.py b/graphrag_sdk/tests/test_loaders.py index 46620b08..4af116be 100644 --- a/graphrag_sdk/tests/test_loaders.py +++ b/graphrag_sdk/tests/test_loaders.py @@ -1,6 +1,8 @@ """Tests for ingestion/loaders/ — TextLoader and PdfLoader.""" from __future__ import annotations +import logging + import pytest from graphrag_sdk.core.exceptions import LoaderError @@ -59,6 +61,20 @@ async def test_document_info_has_uid(self, ctx, tmp_path): result = await loader.load(str(file), ctx) assert result.document_info.uid # non-empty UUID + async def test_read_failure_is_logged_and_typed(self, ctx, tmp_path, monkeypatch, caplog): + file = tmp_path / "unreadable.txt" + file.write_text("content") + loader = TextLoader() + + def fail_read_text(*args, **kwargs): + raise PermissionError("denied") + + monkeypatch.setattr("pathlib.Path.read_text", fail_read_text) + with caplog.at_level(logging.ERROR, logger="graphrag_sdk.ingestion.loaders.text_loader"): + with pytest.raises(LoaderError, match="Failed to read"): + await loader.load(str(file), ctx) + assert "Failed to read text file" in caplog.text + class TestPdfLoader: """PDF loader tests — skipped if no PDF backend is installed.""" @@ -136,6 +152,26 @@ def fake_pypdf(self, path): result = await pdf_mod.PdfLoader().load(str(file), ctx) assert result.document_info.metadata["pdf_backend"] == "pypdf" + async def test_pymupdf_open_failure_is_logged_and_typed( + self, ctx, tmp_path, monkeypatch, caplog + ): + from graphrag_sdk.ingestion.loaders import pdf_loader as pdf_mod + + file = tmp_path / "corrupt.pdf" + file.write_bytes(b"%PDF-corrupt") + + class FakeFitz: + @staticmethod + def open(path): + raise RuntimeError("corrupt pdf") + + monkeypatch.setitem(__import__("sys").modules, "fitz", FakeFitz) + with caplog.at_level(logging.ERROR, logger="graphrag_sdk.ingestion.loaders.pdf_loader"): + with pytest.raises(LoaderError, match="Failed to read PDF"): + await pdf_mod.PdfLoader().load(str(file), ctx) + assert "Failed to read PDF" in caplog.text + assert "PyMuPDF" in caplog.text + async def test_raises_when_no_backend_installed(self, ctx, tmp_path, monkeypatch): """With neither backend available, a clear install message is raised.""" from graphrag_sdk.ingestion.loaders import pdf_loader as pdf_mod diff --git a/graphrag_sdk/tests/test_pipeline.py b/graphrag_sdk/tests/test_pipeline.py index 3945fcf1..321367ae 100644 --- a/graphrag_sdk/tests/test_pipeline.py +++ b/graphrag_sdk/tests/test_pipeline.py @@ -194,7 +194,7 @@ async def extract(self, chunks, schema, ctx): # KNOWS survives; WRONG pruned (+ Alien endpoint gone) assert result.relationships_created == 1 - async def test_pipeline_wraps_exception(self, ctx, mock_graph_store, mock_vector_store): + async def test_pipeline_wraps_exception(self, ctx, mock_graph_store, mock_vector_store, caplog): """Non-IngestionError exceptions get wrapped.""" class FailingLoader(LoaderStrategy): async def load(self, source, ctx): @@ -209,8 +209,10 @@ async def load(self, source, ctx): vector_store=mock_vector_store, schema=GraphSchema(), ) - with pytest.raises(IngestionError, match="Pipeline failed"): - await pipeline.run("test.txt", ctx) + with caplog.at_level("ERROR", logger="graphrag_sdk.ingestion.pipeline"): + with pytest.raises(IngestionError, match="Pipeline failed"): + await pipeline.run("test.txt", ctx) + assert "Pipeline failed with unexpected error" in caplog.text async def test_pipeline_writes_content_hash(self, ctx, mock_graph_store, mock_vector_store): """v1.1.0: Document node carries SHA-256 of the loaded text so diff --git a/graphrag_sdk/tests/test_providers.py b/graphrag_sdk/tests/test_providers.py index a2476ef5..cea4fa2e 100644 --- a/graphrag_sdk/tests/test_providers.py +++ b/graphrag_sdk/tests/test_providers.py @@ -200,6 +200,30 @@ def invoke(self, prompt: str, **kwargs: Any) -> LLMResponse: # Single-line invariant assert "\n" not in msg + async def test_ainvoke_exhaustion_logs_error_without_exception_body(self, caplog): + secret_body = "Authorization: Bearer SECRET_KEY_xyz\nproxy=https://internal" + + class FailingLLM(LLMInterface): + def __init__(self) -> None: + super().__init__(model_name="fail") + + def invoke(self, prompt: str, **kwargs: Any) -> LLMResponse: + raise RuntimeError(f"upstream 500\n{secret_body}") + + import logging + + llm = FailingLLM() + with caplog.at_level(logging.ERROR, logger="graphrag_sdk.core.providers.base"): + with pytest.raises(RuntimeError): + await llm.ainvoke("hi", max_retries=1) + errors = [r for r in caplog.records if r.levelno == logging.ERROR] + assert errors + msg = errors[0].getMessage() + assert "LLM call failed after 1 attempts" in msg + assert "upstream 500" in msg + assert "SECRET_KEY_xyz" not in msg + assert "proxy=" not in msg + class TestSummarizeException: """S6: WARNING-level exception logs must be single-line and length-bounded.""" diff --git a/graphrag_sdk/tests/test_retrieval.py b/graphrag_sdk/tests/test_retrieval.py index d48bc7f1..45389fa6 100644 --- a/graphrag_sdk/tests/test_retrieval.py +++ b/graphrag_sdk/tests/test_retrieval.py @@ -53,10 +53,12 @@ async def test_search_validates_whitespace_query(self, ctx): with pytest.raises(RetrieverError, match="Empty query"): await strategy.search(" ", ctx) - async def test_search_wraps_exception(self, ctx): + async def test_search_wraps_exception(self, ctx, caplog): strategy = StubRetrieval(should_fail=True) - with pytest.raises(RetrieverError, match="failed"): - await strategy.search("valid query", ctx) + with caplog.at_level("ERROR", logger="graphrag_sdk.retrieval.strategies.base"): + with pytest.raises(RetrieverError, match="failed"): + await strategy.search("valid query", ctx) + assert "Retrieval [StubRetrieval] failed" in caplog.text async def test_search_default_context(self): """search() creates default Context if none provided.""" From 7e7ec1db20c25338b0301a7211f8251847a6a1d4 Mon Sep 17 00:00:00 2001 From: Naseem Ali <34807727+Naseem77@users.noreply.github.com> Date: Mon, 18 May 2026 17:19:59 +0300 Subject: [PATCH 3/8] Cover GraphRAG async context cleanup Add regression tests for async context manager cleanup, close failure propagation, and inner-exception preservation.\n\nCo-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- graphrag_sdk/tests/test_facade.py | 34 +++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/graphrag_sdk/tests/test_facade.py b/graphrag_sdk/tests/test_facade.py index b4a178aa..38fb4376 100644 --- a/graphrag_sdk/tests/test_facade.py +++ b/graphrag_sdk/tests/test_facade.py @@ -102,6 +102,40 @@ async def _execute(self, query, ctx, **kwargs): g = GraphRAG(connection=mock_conn, llm=llm, embedder=embedder, retrieval_strategy=strategy, embedding_dimension=8) assert g._retrieval_strategy is strategy + async def test_async_context_manager_returns_self_and_closes(self, mock_conn, embedder, llm): + g = GraphRAG(connection=mock_conn, llm=llm, embedder=embedder, embedding_dimension=8) + + async with g as rag: + assert rag is g + + mock_conn.close.assert_awaited_once() + + async def test_async_context_manager_preserves_inner_exception_on_close_failure( + self, mock_conn, embedder, llm, caplog + ): + mock_conn.close = AsyncMock(side_effect=RuntimeError("close failed")) + g = GraphRAG(connection=mock_conn, llm=llm, embedder=embedder, embedding_dimension=8) + + with caplog.at_level("WARNING", logger="graphrag_sdk.api.main"): + with pytest.raises(ValueError, match="inner failure"): + async with g: + raise ValueError("inner failure") + + mock_conn.close.assert_awaited_once() + assert "Error closing connection during __aexit__" in caplog.text + + async def test_async_context_manager_raises_close_failure_without_inner_exception( + self, mock_conn, embedder, llm + ): + mock_conn.close = AsyncMock(side_effect=RuntimeError("close failed")) + g = GraphRAG(connection=mock_conn, llm=llm, embedder=embedder, embedding_dimension=8) + + with pytest.raises(RuntimeError, match="close failed"): + async with g: + pass + + mock_conn.close.assert_awaited_once() + class TestGraphRAGGraphAdmin: """A1: graph admin operations exposed as facade methods.""" From 8013163702fa0cc2cfd94564141bb3e912c2589b Mon Sep 17 00:00:00 2001 From: Naseem Ali <34807727+Naseem77@users.noreply.github.com> Date: Mon, 18 May 2026 17:52:59 +0300 Subject: [PATCH 4/8] Enforce retrieval latency budgets Add a typed latency budget error and enforce Context budgets before retrieval phases, helper I/O, graph config probes, Cypher calls, and completion LLM calls. Cover propagation and phase gating with regression tests.\n\nCo-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- graphrag_sdk/src/graphrag_sdk/__init__.py | 7 +- graphrag_sdk/src/graphrag_sdk/api/main.py | 28 +++- graphrag_sdk/src/graphrag_sdk/core/context.py | 12 ++ .../src/graphrag_sdk/core/exceptions.py | 6 + .../graphrag_sdk/retrieval/strategies/base.py | 4 +- .../retrieval/strategies/chunk_retrieval.py | 29 ++++ .../retrieval/strategies/cypher_generation.py | 15 +- .../retrieval/strategies/entity_discovery.py | 30 ++++ .../retrieval/strategies/local.py | 3 + .../retrieval/strategies/multi_path.py | 52 +++++-- .../strategies/relationship_expansion.py | 12 ++ .../retrieval/strategies/result_assembly.py | 7 + graphrag_sdk/tests/test_chunk_retrieval.py | 63 +++++++- graphrag_sdk/tests/test_context.py | 10 ++ graphrag_sdk/tests/test_cypher_generation.py | 49 +++++++ graphrag_sdk/tests/test_exceptions.py | 5 + graphrag_sdk/tests/test_facade.py | 134 +++++++++++++++++- .../tests/test_multi_path_retrieval.py | 82 +++++++++++ graphrag_sdk/tests/test_retrieval.py | 44 +++++- 19 files changed, 574 insertions(+), 18 deletions(-) diff --git a/graphrag_sdk/src/graphrag_sdk/__init__.py b/graphrag_sdk/src/graphrag_sdk/__init__.py index f66ed70a..2b90d246 100644 --- a/graphrag_sdk/src/graphrag_sdk/__init__.py +++ b/graphrag_sdk/src/graphrag_sdk/__init__.py @@ -19,7 +19,11 @@ # ── Core Contracts ─────────────────────────────────────────────── from graphrag_sdk.core.connection import ConnectionConfig, FalkorDBConnection from graphrag_sdk.core.context import Context -from graphrag_sdk.core.exceptions import DocumentNotFoundError, GraphRAGError +from graphrag_sdk.core.exceptions import ( + DocumentNotFoundError, + GraphRAGError, + LatencyBudgetExceededError, +) from graphrag_sdk.core.models import ( ApplyChangesResult, BatchEntry, @@ -134,6 +138,7 @@ "GraphRelationship", "GraphSchema", "IngestionResult", + "LatencyBudgetExceededError", "LLMBatchItem", "LLMInterface", "LiteLLM", diff --git a/graphrag_sdk/src/graphrag_sdk/api/main.py b/graphrag_sdk/src/graphrag_sdk/api/main.py index df3d44e9..5c132ec5 100644 --- a/graphrag_sdk/src/graphrag_sdk/api/main.py +++ b/graphrag_sdk/src/graphrag_sdk/api/main.py @@ -15,7 +15,12 @@ from graphrag_sdk import __version__ from graphrag_sdk.core.connection import ConnectionConfig, FalkorDBConnection from graphrag_sdk.core.context import Context -from graphrag_sdk.core.exceptions import ConfigError, DatabaseError, DocumentNotFoundError +from graphrag_sdk.core.exceptions import ( + ConfigError, + DatabaseError, + DocumentNotFoundError, + LatencyBudgetExceededError, +) from graphrag_sdk.core.models import ( ApplyChangesResult, BatchEntry, @@ -1446,13 +1451,16 @@ async def retrieve( ctx = Context() ctx.log(f"Retrieve: {question[:80]}...") + ctx.ensure_budget("graph config validation") - await self._validate_graph_config() + await self._validate_graph_config(ctx=ctx) retrieval = strategy or self._retrieval_strategy + ctx.ensure_budget("retrieval strategy search") retriever_result = await retrieval.search(question, ctx) if reranker is not None: + ctx.ensure_budget("retrieval reranking") retriever_result = await reranker.rerank(question, retriever_result, ctx) ctx.log(f"Retrieved {len(retriever_result.items)} context items") @@ -1518,8 +1526,11 @@ async def _rewrite_question_with_history( question=question, ) try: + ctx.ensure_budget("question rewrite LLM call") resp = await self.llm.ainvoke(prompt) rewritten = (resp.content or "").strip().splitlines()[0].strip() if resp.content else "" + except LatencyBudgetExceededError: + raise except Exception as e: # Broad catch is intentional (see docstring) — but log at WARNING # with full traceback so programming bugs surface in operator @@ -1588,6 +1599,7 @@ async def completion( # Step 1: Optionally rewrite the question for retrieval. retrieval_query = question if validated_history and rewrite_question_with_history: + ctx.ensure_budget("question rewrite") retrieval_query = await self._rewrite_question_with_history( question, validated_history, @@ -1597,6 +1609,7 @@ async def completion( ctx.log(f"Rewrote for retrieval: {retrieval_query[:80]}") # Step 2: Retrieve + rerank (using possibly-rewritten query). + ctx.ensure_budget("completion retrieval") retriever_result = await self.retrieve( retrieval_query, strategy=strategy, @@ -1638,6 +1651,7 @@ async def completion( ChatMessage(role="user", content=final_user_content), ] + ctx.ensure_budget("completion LLM call") llm_response = await self.llm.ainvoke_messages(messages) result = RagResult( @@ -1678,7 +1692,7 @@ async def _write_graph_config(self) -> None: except Exception: logger.debug("Failed to write graph config node", exc_info=True) - async def _validate_graph_config(self) -> None: + async def _validate_graph_config(self, ctx: Context | None = None) -> None: """Check that the current embedder matches the graph's stored config. Two checks, both cached after first run: @@ -1700,6 +1714,8 @@ async def _validate_graph_config(self) -> None: return try: + if ctx is not None: + ctx.ensure_budget("graph config query") result = await self._graph_store.query_raw( "MATCH (c:__GraphRAGConfig__ {id: 'default'}) " "RETURN c.embedding_model, c.embedding_dimension" @@ -1724,6 +1740,8 @@ async def _validate_graph_config(self) -> None: ) except ConfigError: raise + except LatencyBudgetExceededError: + raise except Exception: # Don't mark as validated on transient failures — retry next call. logger.debug("Failed to validate graph config", exc_info=True) @@ -1732,8 +1750,12 @@ async def _validate_graph_config(self) -> None: # Probe the embedder once: confirm it produces vectors of the # configured dimension. Catches user error like # ``embedding_dimension=256`` paired with a 1536-dim model. + if ctx is not None: + ctx.ensure_budget("graph config embedder probe") try: probe = await self.embedder.aembed_query("dim_check") + except LatencyBudgetExceededError: + raise except Exception: # Probe failure is non-fatal — but don't cache a "validated" # state, otherwise a transient outage permanently disables diff --git a/graphrag_sdk/src/graphrag_sdk/core/context.py b/graphrag_sdk/src/graphrag_sdk/core/context.py index 7f93ab0e..97bdd920 100644 --- a/graphrag_sdk/src/graphrag_sdk/core/context.py +++ b/graphrag_sdk/src/graphrag_sdk/core/context.py @@ -10,6 +10,8 @@ from typing import Any from uuid import uuid4 +from graphrag_sdk.core.exceptions import LatencyBudgetExceededError + logger = logging.getLogger(__name__) @@ -52,6 +54,16 @@ def budget_exceeded(self) -> bool: remaining = self.remaining_budget_ms return remaining is not None and remaining <= 0 + def ensure_budget(self, operation: str) -> None: + """Raise if the latency budget is already exhausted before *operation* starts.""" + if not self.budget_exceeded: + return + budget = self.latency_budget_ms if self.latency_budget_ms is not None else 0.0 + raise LatencyBudgetExceededError( + f"Latency budget exceeded before {operation} " + f"(elapsed={self.elapsed_ms:.1f}ms, budget={budget:.1f}ms)" + ) + def child(self, **overrides: Any) -> Context: """Create a child context inheriting tenant/trace but with optional overrides. diff --git a/graphrag_sdk/src/graphrag_sdk/core/exceptions.py b/graphrag_sdk/src/graphrag_sdk/core/exceptions.py index 2e98a384..99c3c379 100644 --- a/graphrag_sdk/src/graphrag_sdk/core/exceptions.py +++ b/graphrag_sdk/src/graphrag_sdk/core/exceptions.py @@ -10,6 +10,12 @@ class GraphRAGError(Exception): pass +class LatencyBudgetExceededError(GraphRAGError): + """Raised when an operation cannot start within the remaining latency budget.""" + + pass + + # ── Provider Errors ────────────────────────────────────────────── diff --git a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/base.py b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/base.py index 7cce3850..1dea837a 100644 --- a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/base.py +++ b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/base.py @@ -12,7 +12,7 @@ from typing import Any from graphrag_sdk.core.context import Context -from graphrag_sdk.core.exceptions import RetrieverError +from graphrag_sdk.core.exceptions import LatencyBudgetExceededError, RetrieverError from graphrag_sdk.core.models import RawSearchResult, RetrieverResult, RetrieverResultItem logger = logging.getLogger(__name__) @@ -93,6 +93,8 @@ async def search( except RetrieverError: raise + except LatencyBudgetExceededError: + raise except Exception as exc: logger.error("Retrieval [%s] failed: %s", self.__class__.__name__, exc) logger.debug("Retrieval failure details", exc_info=True) diff --git a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/chunk_retrieval.py b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/chunk_retrieval.py index e71e4455..2d4f5c5f 100644 --- a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/chunk_retrieval.py +++ b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/chunk_retrieval.py @@ -6,6 +6,9 @@ import logging from typing import Any +from graphrag_sdk.core.context import Context +from graphrag_sdk.core.exceptions import LatencyBudgetExceededError + logger = logging.getLogger(__name__) @@ -17,6 +20,7 @@ async def retrieve_chunks( llm_kw: list[str], simple_kw: list[str], entity_list: list[tuple[str, dict]], + ctx: Context | None = None, ) -> tuple[dict[str, str], dict[str, str], dict[str, list[float]]]: """4-path chunk retrieval: fulltext + vector + MENTIONED_IN + 2-hop. @@ -37,17 +41,25 @@ def _add(cid: str, text: str, source: str) -> None: fulltext_queries = [query] + llm_kw[:6] + simple_kw[:4] for ft_q in fulltext_queries: try: + if ctx is not None: + ctx.ensure_budget("chunk fulltext search") results = await vector_store.fulltext_search_chunks(ft_q, top_k=5) for c in results: _add(c.get("id", ""), c.get("text", ""), "fulltext") + except LatencyBudgetExceededError: + raise except Exception as exc: logger.debug("Chunk fulltext search failed for query: %s", exc) # Path B: Vector search try: + if ctx is not None: + ctx.ensure_budget("chunk vector search") results = await vector_store.search_chunks(query_vector, top_k=15) for c in results: _add(c.get("id", ""), c.get("text", ""), "vector") + except LatencyBudgetExceededError: + raise except Exception as exc: logger.debug("Chunk vector search failed: %s", exc) @@ -60,6 +72,8 @@ def _add(cid: str, text: str, source: str) -> None: eids_mention = [eid for eid, _ in entity_list[:15]] if eids_mention: try: + if ctx is not None: + ctx.ensure_budget("MENTIONED_IN chunk retrieval") result = await graph_store.query_raw( "UNWIND $eids AS eid " "MATCH (e:__Entity__ {id: eid})-[:MENTIONED_IN]->(c:Chunk) " @@ -75,6 +89,8 @@ def _add(cid: str, text: str, source: str) -> None: cid = row[1] text = row[2] if len(row) > 2 else "" _add(cid, text, "mentioned_in") + except LatencyBudgetExceededError: + raise except Exception as exc: logger.debug("MENTIONED_IN chunk retrieval failed: %s", exc) @@ -82,6 +98,8 @@ def _add(cid: str, text: str, source: str) -> None: eids_2hop_chunk = [eid for eid, _ in entity_list[:10]] if eids_2hop_chunk: try: + if ctx is not None: + ctx.ensure_budget("2-hop chunk retrieval") result = await graph_store.query_raw( "UNWIND $eids AS eid " "MATCH (e:__Entity__ {id: eid})-[:RELATES]-(neighbor:__Entity__)" @@ -94,6 +112,8 @@ def _add(cid: str, text: str, source: str) -> None: cid = row[0] text = row[1] if len(row) > 1 else "" _add(cid, text, "2hop_mentioned") + except LatencyBudgetExceededError: + raise except Exception as exc: logger.debug("2-hop chunk retrieval failed: %s", exc) @@ -102,6 +122,8 @@ def _add(cid: str, text: str, source: str) -> None: missing_ids = list(chunks.keys()) if missing_ids: try: + if ctx is not None: + ctx.ensure_budget("stored chunk embedding fetch") result = await graph_store.query_raw( "UNWIND $ids AS cid " "MATCH (c:Chunk {id: cid}) " @@ -112,6 +134,8 @@ def _add(cid: str, text: str, source: str) -> None: for row in result.result_set: if row[0] and row[1] is not None: embeddings[row[0]] = list(row[1]) + except LatencyBudgetExceededError: + raise except Exception as exc: logger.debug("Stored embedding fetch failed: %s", exc) @@ -121,6 +145,7 @@ def _add(cid: str, text: str, source: str) -> None: async def fetch_chunk_documents( graph_store: Any, chunk_ids: list[str], + ctx: Context | None = None, ) -> dict[str, str]: """Batch-fetch the source document path for each chunk via PART_OF. @@ -137,6 +162,8 @@ async def fetch_chunk_documents( if not chunk_ids: return {} try: + if ctx is not None: + ctx.ensure_budget("chunk source document fetch") result = await graph_store.query_raw( "UNWIND $ids AS cid " "MATCH (d:Document)-[:PART_OF]->(c:Chunk {id: cid}) " @@ -150,6 +177,8 @@ async def fetch_chunk_documents( if cid and path: mapping[cid] = path return mapping + except LatencyBudgetExceededError: + raise except Exception as exc: logger.debug("Document name fetch failed: %s", exc) return {} diff --git a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/cypher_generation.py b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/cypher_generation.py index 0dd3d289..e337741f 100644 --- a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/cypher_generation.py +++ b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/cypher_generation.py @@ -12,6 +12,9 @@ import re from typing import Any +from graphrag_sdk.core.context import Context +from graphrag_sdk.core.exceptions import LatencyBudgetExceededError + logger = logging.getLogger(__name__) # ── Valid labels for our graph schema ──────────────────────────── @@ -245,6 +248,7 @@ async def generate_cypher( question: str, *, max_retries: int = 3, + ctx: Context | None = None, ) -> str | None: """Generate a Cypher query from a natural language question. @@ -255,6 +259,8 @@ async def generate_cypher( for attempt in range(max_retries): try: + if ctx is not None: + ctx.ensure_budget("Cypher generation LLM call") if attempt > 0 and last_error: prompt_with_feedback = ( prompt + f"\n\nPrevious attempt failed with error: {last_error}\n" @@ -279,6 +285,8 @@ async def generate_cypher( cypher = _sanitize_cypher(cypher) return cypher + except LatencyBudgetExceededError: + raise except Exception as exc: last_error = str(exc) logger.debug("Cypher generation attempt %d failed: %s", attempt + 1, exc) @@ -293,6 +301,7 @@ async def execute_cypher_retrieval( question: str, *, max_retries: int = 3, + ctx: Context | None = None, ) -> tuple[list[str], dict[str, dict]]: """Full text-to-cypher retrieval: generate -> validate -> execute -> parse. @@ -306,12 +315,16 @@ async def execute_cypher_retrieval( On any failure, returns empty results (silent degradation). """ - cypher = await generate_cypher(llm, question, max_retries=max_retries) + cypher = await generate_cypher(llm, question, max_retries=max_retries, ctx=ctx) if not cypher: return [], {} try: + if ctx is not None: + ctx.ensure_budget("Cypher execution") result = await graph_store.query_raw(cypher) + except LatencyBudgetExceededError: + raise except Exception as exc: logger.debug("Cypher execution failed: %s — query: %s", exc, cypher) return [], {} diff --git a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/entity_discovery.py b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/entity_discovery.py index f1d54786..7882903c 100644 --- a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/entity_discovery.py +++ b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/entity_discovery.py @@ -8,6 +8,9 @@ import re from typing import Any +from graphrag_sdk.core.context import Context +from graphrag_sdk.core.exceptions import LatencyBudgetExceededError + logger = logging.getLogger(__name__) _ENUMERATION_RE = re.compile( @@ -26,6 +29,7 @@ async def search_relates_edges( vector_store: Any, query_vector: list[float], rel_top_k: int = 15, + ctx: Context | None = None, ) -> tuple[list[tuple[str, float]], dict[str, dict]]: """Search RELATES edges by vector similarity. @@ -37,6 +41,8 @@ async def search_relates_edges( fact_strings: list[tuple[str, float]] = [] entities: dict[str, dict] = {} try: + if ctx is not None: + ctx.ensure_budget("RELATES vector search") results = await vector_store.search_relationships(query_vector, top_k=rel_top_k) for rel in results: src = rel.get("src_name", "") @@ -58,6 +64,8 @@ async def search_relates_edges( tgt_id = tgt.strip().lower().replace(" ", "_") if tgt_id not in entities: entities[tgt_id] = {"name": tgt, "description": ""} + except LatencyBudgetExceededError: + raise except Exception as exc: logger.debug("RELATES edge vector search failed: %s", exc) return fact_strings, entities @@ -68,6 +76,7 @@ async def discover_entities( vector_store: Any, llm_kw: list[str], all_keywords: list[str], + ctx: Context | None = None, ) -> tuple[dict[str, dict], dict[str, str]]: """2-path entity discovery. @@ -107,6 +116,8 @@ def _add(eid: str, info: dict, source: str) -> None: # so exact matches land at the head of `found` and survive the # downstream max_entities / result_assembly caps. try: + if ctx is not None: + ctx.ensure_budget("entity exact-name search") result = await graph_store.query_raw( "UNWIND $keywords AS kw " "CALL { " @@ -127,6 +138,8 @@ def _add(eid: str, info: dict, source: str) -> None: }, "cypher_exact", ) + except LatencyBudgetExceededError: + raise except Exception as exc: logger.debug("Entity exact-name search failed: %s", exc) @@ -134,6 +147,8 @@ def _add(eid: str, info: dict, source: str) -> None: # Excludes exact matches (already added in pass a1) so the quota # isn't spent re-fetching them. try: + if ctx is not None: + ctx.ensure_budget("entity contains search") result = await graph_store.query_raw( "UNWIND $keywords AS kw " "CALL { " @@ -157,17 +172,23 @@ def _add(eid: str, info: dict, source: str) -> None: }, "cypher_contains", ) + except LatencyBudgetExceededError: + raise except Exception as exc: logger.debug("Entity CONTAINS search failed: %s", exc) # Path b: Fulltext search on entity index for kw in all_keywords[:6]: try: + if ctx is not None: + ctx.ensure_budget("entity fulltext search") ft_ents = await vector_store.fulltext_search_entities(kw, top_k=3) for ent in ft_ents: eid = ent.get("id", "") if eid: try: + if ctx is not None: + ctx.ensure_budget("entity detail fetch") detail = await graph_store.query_raw( "MATCH (e:__Entity__ {id: $eid}) " "RETURN e.name AS name, e.description AS desc", @@ -183,9 +204,13 @@ def _add(eid: str, info: dict, source: str) -> None: }, "fulltext", ) + except LatencyBudgetExceededError: + raise except Exception: logger.debug("Entity detail fetch failed for %s", eid, exc_info=True) _add(eid, {"name": "", "description": ""}, "fulltext") + except LatencyBudgetExceededError: + raise except Exception as exc: logger.debug("Entity fulltext search failed for '%s': %s", kw, exc) @@ -197,6 +222,7 @@ async def expand_sibling_entities( found_entities: dict[str, dict], found_sources: dict[str, str], max_siblings: int = 20, + ctx: Context | None = None, ) -> int: """Expand discovered entities by finding graph siblings. @@ -216,6 +242,8 @@ async def expand_sibling_entities( added = 0 try: + if ctx is not None: + ctx.ensure_budget("sibling entity expansion") result = await graph_store.query_raw( "MATCH (e:__Entity__) WHERE e.id IN $found_ids " "MATCH (e)-[]-(hub:__Entity__) " @@ -238,6 +266,8 @@ async def expand_sibling_entities( } found_sources[eid] = "sibling_expansion" added += 1 + except LatencyBudgetExceededError: + raise except Exception as exc: logger.debug("Sibling entity expansion failed: %s", exc) diff --git a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/local.py b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/local.py index 08944260..e433f46f 100644 --- a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/local.py +++ b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/local.py @@ -52,9 +52,11 @@ async def _execute( top_k = kwargs.get("top_k", self._top_k) # Step 1: Embed the query + ctx.ensure_budget("LocalRetrieval query embedding") query_vector = await self._embedder.aembed_query(query) # Step 2: Vector search for matching chunks + ctx.ensure_budget("LocalRetrieval chunk vector search") chunk_results = await self._vector.search_chunks( query_vector=query_vector, top_k=top_k, @@ -74,6 +76,7 @@ async def _execute( if self._include_entities and self._graph: # 1-hop: find entities extracted from this chunk + ctx.ensure_budget("LocalRetrieval entity expansion") entities = await self._graph.get_connected_entities( chunk_id=chunk.get("id", ""), ) diff --git a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/multi_path.py b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/multi_path.py index d7bcbd9b..a63c45aa 100644 --- a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/multi_path.py +++ b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/multi_path.py @@ -11,6 +11,7 @@ from typing import Any from graphrag_sdk.core.context import Context +from graphrag_sdk.core.exceptions import LatencyBudgetExceededError from graphrag_sdk.core.models import ( RawSearchResult, RetrieverResult, @@ -186,21 +187,26 @@ async def _execute( **kwargs: Any, ) -> RawSearchResult: # 1. Extract keywords - simple_kw, llm_kw = await self._extract_keywords(query) + ctx.ensure_budget("MultiPath keyword extraction") + simple_kw, llm_kw = await self._extract_keywords(query, ctx) all_keywords = llm_kw[:8] + simple_kw ctx.log(f"MultiPath [1/9]: {len(all_keywords)} keywords extracted") # 2. Embed question only + ctx.ensure_budget("MultiPath question embedding") query_vector = await self._embedder.aembed_query(query) # 3. RELATES vector search + Text-to-Cypher (parallel when enabled) + ctx.ensure_budget("MultiPath relationship vector search") if self._enable_cypher: results = await asyncio.gather( - search_relates_edges(self._vector, query_vector, self._rel_top_k), - execute_cypher_retrieval(self._graph, self._llm, query), + search_relates_edges(self._vector, query_vector, self._rel_top_k, ctx=ctx), + execute_cypher_retrieval(self._graph, self._llm, query, ctx=ctx), return_exceptions=True, ) # Unpack RELATES results + if isinstance(results[0], LatencyBudgetExceededError): + raise results[0] if isinstance(results[0], BaseException): fact_strings_scored, rel_entities = [], {} else: @@ -208,11 +214,13 @@ async def _execute( # Unpack Cypher results cypher_facts: list[str] = [] cypher_entities: dict[str, dict] = {} + if isinstance(results[1], LatencyBudgetExceededError): + raise results[1] if not isinstance(results[1], BaseException): cypher_facts, cypher_entities = results[1] else: fact_strings_scored, rel_entities = await search_relates_edges( - self._vector, query_vector, self._rel_top_k + self._vector, query_vector, self._rel_top_k, ctx=ctx ) cypher_facts, cypher_entities = [], {} @@ -227,8 +235,9 @@ async def _execute( ) # 4. Entity discovery (2 paths) + merge rel_entities + cypher_entities + ctx.ensure_budget("MultiPath entity discovery") found_entities, entity_sources = await discover_entities( - self._graph, self._vector, llm_kw, all_keywords + self._graph, self._vector, llm_kw, all_keywords, ctx=ctx ) for eid, einfo in rel_entities.items(): if eid not in found_entities: @@ -242,7 +251,13 @@ async def _execute( # 4b. Sibling expansion for enumeration queries if is_enumeration_query(query): - n_siblings = await expand_sibling_entities(self._graph, found_entities, entity_sources) + ctx.ensure_budget("MultiPath sibling entity expansion") + n_siblings = await expand_sibling_entities( + self._graph, + found_entities, + entity_sources, + ctx=ctx, + ) if n_siblings: ctx.log( f"MultiPath [4b/9]: +{n_siblings} sibling entities " @@ -250,13 +265,15 @@ async def _execute( ) # 5. Relationship expansion + ctx.ensure_budget("MultiPath relationship expansion") entity_list = list(found_entities.items())[: self._max_entities] relationship_strings = await expand_relationships( - self._graph, entity_list, self._max_relationships + self._graph, entity_list, self._max_relationships, ctx=ctx ) ctx.log(f"MultiPath [5/9]: {len(relationship_strings)} relationships") # 6. Chunk retrieval (4 paths) + ctx.ensure_budget("MultiPath chunk retrieval") candidate_chunks, chunk_sources, chunk_embeddings = await retrieve_chunks( self._vector, self._graph, @@ -265,6 +282,7 @@ async def _execute( llm_kw, simple_kw, entity_list, + ctx=ctx, ) ctx.log( f"MultiPath [6/9]: {len(candidate_chunks)} candidate chunks " @@ -272,15 +290,22 @@ async def _execute( ) # 7. Source document names - chunk_doc_map = await fetch_chunk_documents(self._graph, list(candidate_chunks.keys())) + ctx.ensure_budget("MultiPath source document fetch") + chunk_doc_map = await fetch_chunk_documents( + self._graph, + list(candidate_chunks.keys()), + ctx=ctx, + ) # 8. Cosine rerank (uses stored embeddings when available) + ctx.ensure_budget("MultiPath chunk reranking") source_passages = await rerank_chunks( self._embedder, query_vector, candidate_chunks, self._chunk_top_k, stored_embeddings=chunk_embeddings, + ctx=ctx, ) # Tag with source docs @@ -295,6 +320,7 @@ async def _execute( ctx.log(f"MultiPath [8/9]: {len(source_passages)} passages after rerank") # 9. Detect question type + assemble + ctx.ensure_budget("MultiPath result assembly") q_type_hint = detect_question_type(query) return assemble_raw_result( entity_list, @@ -322,13 +348,19 @@ def _format(self, raw: RawSearchResult) -> RetrieverResult: # -- Internal: keyword extraction (stays in orchestrator) -- - async def _extract_keywords(self, query: str) -> tuple[list[str], list[str]]: + async def _extract_keywords( + self, + query: str, + ctx: Context | None = None, + ) -> tuple[list[str], list[str]]: """Extract simple + LLM-based keywords from the query.""" words = re.sub(r"[?.!,;:'\"\-()\[\]]", " ", query.lower()).split() simple = [w for w in words if w not in self._STOP_WORDS and len(w) > 2][:12] llm_kw: list[str] = [] try: + if ctx is not None: + ctx.ensure_budget("MultiPath keyword extraction LLM call") response = await self._llm.ainvoke( "Extract ALL proper nouns, character names, person names, place names, " "book titles, and specific terms from this question. " @@ -340,6 +372,8 @@ async def _extract_keywords(self, query: str) -> tuple[list[str], list[str]]: for k in response.content.split(",") if k.strip() and len(k.strip()) > 1 ] + except LatencyBudgetExceededError: + raise except Exception as exc: logger.debug(f"LLM keyword extraction failed: {exc}") diff --git a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/relationship_expansion.py b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/relationship_expansion.py index 57264be8..ff83e4d6 100644 --- a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/relationship_expansion.py +++ b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/relationship_expansion.py @@ -6,6 +6,9 @@ import logging from typing import Any +from graphrag_sdk.core.context import Context +from graphrag_sdk.core.exceptions import LatencyBudgetExceededError + logger = logging.getLogger(__name__) @@ -13,6 +16,7 @@ async def expand_relationships( graph_store: Any, entity_list: list[tuple[str, dict]], max_relationships: int = 20, + ctx: Context | None = None, ) -> list[str]: """1-hop + 2-hop relationship expansion from top entities. @@ -30,6 +34,8 @@ async def expand_relationships( eids_1hop = [eid for eid, _ in entity_list[:15]] if eids_1hop: try: + if ctx is not None: + ctx.ensure_budget("1-hop relationship expansion") result = await graph_store.query_raw( "UNWIND $eids AS eid " "MATCH (a:__Entity__ {id: eid})-[r:RELATES]->(b:__Entity__) " @@ -50,6 +56,8 @@ async def expand_relationships( if fact: line += f": {fact}" relationship_strings.append(line) + except LatencyBudgetExceededError: + raise except Exception as exc: logger.debug("Relationship expansion failed: %s", exc) @@ -57,6 +65,8 @@ async def expand_relationships( eids_2hop = [eid for eid, _ in entity_list[:5]] if eids_2hop: try: + if ctx is not None: + ctx.ensure_budget("2-hop relationship expansion") result = await graph_store.query_raw( "UNWIND $eids AS eid " "MATCH (a:__Entity__ {id: eid})-[r1:RELATES]->(b:__Entity__)" @@ -77,6 +87,8 @@ async def expand_relationships( seen.add(key) line = f"{a_name} —[{r1_type}]→ {b_name} —[{r2_type}]→ {c_name}" relationship_strings.append(line) + except LatencyBudgetExceededError: + raise except Exception as exc: logger.debug("Relationship expansion failed: %s", exc) diff --git a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/result_assembly.py b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/result_assembly.py index 6b2ad9a4..793c2af4 100644 --- a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/result_assembly.py +++ b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/result_assembly.py @@ -8,6 +8,8 @@ import numpy as np +from graphrag_sdk.core.context import Context +from graphrag_sdk.core.exceptions import LatencyBudgetExceededError from graphrag_sdk.core.models import RawSearchResult from graphrag_sdk.core.providers import Embedder @@ -29,6 +31,7 @@ async def rerank_chunks( candidate_chunks: dict[str, str], chunk_top_k: int = 15, stored_embeddings: dict[str, list[float]] | None = None, + ctx: Context | None = None, ) -> list[str]: """Rank candidates by cosine similarity, take top_k. @@ -60,10 +63,14 @@ async def rerank_chunks( # Fallback: re-embed all candidates (coverage too low for fast path) try: + if ctx is not None: + ctx.ensure_budget("chunk reranking embedding") chunk_vectors = await embedder.aembed_documents(chunk_texts) scored = [(i, cosine_sim(query_vector, cvec)) for i, cvec in enumerate(chunk_vectors)] scored.sort(key=lambda x: x[1], reverse=True) return [chunk_texts[i] for i, _ in scored[:chunk_top_k]] + except LatencyBudgetExceededError: + raise except Exception as exc: logger.debug("Chunk reranking failed, returning unranked: %s", exc) return chunk_texts[:chunk_top_k] diff --git a/graphrag_sdk/tests/test_chunk_retrieval.py b/graphrag_sdk/tests/test_chunk_retrieval.py index 2b59fb49..a731cf30 100644 --- a/graphrag_sdk/tests/test_chunk_retrieval.py +++ b/graphrag_sdk/tests/test_chunk_retrieval.py @@ -5,7 +5,12 @@ import pytest -from graphrag_sdk.retrieval.strategies.chunk_retrieval import fetch_chunk_documents +from graphrag_sdk.core.context import Context +from graphrag_sdk.core.exceptions import LatencyBudgetExceededError +from graphrag_sdk.retrieval.strategies.chunk_retrieval import ( + fetch_chunk_documents, + retrieve_chunks, +) def _graph_with_rows(rows: list[list]) -> MagicMock: @@ -90,3 +95,59 @@ async def test_query_failure_returns_empty_mapping(self): graph.query_raw = AsyncMock(side_effect=RuntimeError("graph down")) mapping = await fetch_chunk_documents(graph, ["chunk-1"]) assert mapping == {} + + async def test_budget_error_propagates(self): + graph = MagicMock() + graph.query_raw = AsyncMock( + side_effect=LatencyBudgetExceededError("budget exhausted") + ) + + with pytest.raises(LatencyBudgetExceededError, match="budget exhausted"): + await fetch_chunk_documents(graph, ["chunk-1"]) + + +class TestRetrieveChunks: + async def test_budget_error_propagates_from_fulltext_path(self): + vector = MagicMock() + vector.fulltext_search_chunks = AsyncMock( + side_effect=LatencyBudgetExceededError("budget exhausted") + ) + graph = MagicMock() + + with pytest.raises(LatencyBudgetExceededError, match="budget exhausted"): + await retrieve_chunks( + vector, + graph, + "query", + [0.1], + [], + [], + [], + ) + + async def test_budget_checked_between_fulltext_queries(self): + ctx = Context(latency_budget_ms=1000.0) + + async def first_fulltext_exhausts_budget(*args, **kwargs): + ctx.latency_budget_ms = 0.0 + return [] + + vector = MagicMock() + vector.fulltext_search_chunks = AsyncMock(side_effect=first_fulltext_exhausts_budget) + vector.search_chunks = AsyncMock(return_value=[]) + graph = MagicMock() + + with pytest.raises(LatencyBudgetExceededError, match="chunk fulltext search"): + await retrieve_chunks( + vector, + graph, + "query", + [0.1], + ["second-query"], + [], + [], + ctx=ctx, + ) + + assert vector.fulltext_search_chunks.await_count == 1 + vector.search_chunks.assert_not_awaited() diff --git a/graphrag_sdk/tests/test_context.py b/graphrag_sdk/tests/test_context.py index 921edd9b..70edbecf 100644 --- a/graphrag_sdk/tests/test_context.py +++ b/graphrag_sdk/tests/test_context.py @@ -6,6 +6,7 @@ import pytest from graphrag_sdk.core.context import Context +from graphrag_sdk.core.exceptions import LatencyBudgetExceededError class TestContextCreation: @@ -66,6 +67,15 @@ def test_budget_exceeded(self): assert ctx.budget_exceeded is True assert ctx.remaining_budget_ms == 0.0 + def test_ensure_budget_raises_when_exhausted(self): + ctx = Context(latency_budget_ms=0.0) + with pytest.raises(LatencyBudgetExceededError, match="before test operation"): + ctx.ensure_budget("test operation") + + def test_ensure_budget_allows_operation_without_budget(self): + ctx = Context() + ctx.ensure_budget("test operation") + class TestChildContext: def test_inherits_parent(self): diff --git a/graphrag_sdk/tests/test_cypher_generation.py b/graphrag_sdk/tests/test_cypher_generation.py index 58694158..9c145e80 100644 --- a/graphrag_sdk/tests/test_cypher_generation.py +++ b/graphrag_sdk/tests/test_cypher_generation.py @@ -136,6 +136,55 @@ def test_removes_path_assignment(self): class TestExecuteCypherRetrieval: + async def test_budget_exhaustion_propagates_before_generation(self): + """Latency budget exhaustion should not be swallowed as generation failure.""" + from unittest.mock import AsyncMock, MagicMock + + from graphrag_sdk.core.context import Context + from graphrag_sdk.core.exceptions import LatencyBudgetExceededError + from graphrag_sdk.retrieval.strategies.cypher_generation import ( + execute_cypher_retrieval, + ) + + mock_llm = MagicMock() + mock_llm.ainvoke = AsyncMock() + mock_graph = MagicMock() + + with pytest.raises(LatencyBudgetExceededError, match="Cypher generation LLM call"): + await execute_cypher_retrieval( + mock_graph, + mock_llm, + "test?", + ctx=Context(latency_budget_ms=0.0), + ) + mock_llm.ainvoke.assert_not_awaited() + + async def test_budget_exhaustion_propagates_before_execution(self): + """Budget expiry during generation should block the Cypher DB query.""" + from unittest.mock import AsyncMock, MagicMock + + from graphrag_sdk.core.context import Context + from graphrag_sdk.core.exceptions import LatencyBudgetExceededError + from graphrag_sdk.core.models import LLMResponse + from graphrag_sdk.retrieval.strategies.cypher_generation import ( + execute_cypher_retrieval, + ) + + ctx = Context(latency_budget_ms=1000.0) + + async def generate_and_exhaust(*args, **kwargs): + ctx.latency_budget_ms = 0.0 + return LLMResponse(content="```cypher\nMATCH (n:Person) RETURN n.name LIMIT 10\n```") + + mock_llm = MagicMock() + mock_llm.ainvoke = AsyncMock(side_effect=generate_and_exhaust) + mock_graph = MagicMock() + mock_graph.query_raw = AsyncMock() + + with pytest.raises(LatencyBudgetExceededError, match="Cypher execution"): + await execute_cypher_retrieval(mock_graph, mock_llm, "test?", ctx=ctx) + mock_graph.query_raw.assert_not_awaited() + async def test_returns_empty_on_generation_failure(self): """When LLM returns garbage, should return empty results.""" from unittest.mock import AsyncMock, MagicMock diff --git a/graphrag_sdk/tests/test_exceptions.py b/graphrag_sdk/tests/test_exceptions.py index 56358156..64e98683 100644 --- a/graphrag_sdk/tests/test_exceptions.py +++ b/graphrag_sdk/tests/test_exceptions.py @@ -13,6 +13,7 @@ GraphRAGError, IndexError_, IngestionError, + LatencyBudgetExceededError, LLMError, LLMTimeoutError, LoaderError, @@ -27,6 +28,10 @@ def test_base_exception(self): with pytest.raises(GraphRAGError): raise GraphRAGError("base error") + def test_latency_budget_error_is_graphrag_error(self): + with pytest.raises(GraphRAGError): + raise LatencyBudgetExceededError("budget exceeded") + def test_llm_error_is_graphrag_error(self): with pytest.raises(GraphRAGError): raise LLMError("llm failed") diff --git a/graphrag_sdk/tests/test_facade.py b/graphrag_sdk/tests/test_facade.py index 38fb4376..b73d6b91 100644 --- a/graphrag_sdk/tests/test_facade.py +++ b/graphrag_sdk/tests/test_facade.py @@ -9,7 +9,11 @@ from graphrag_sdk.api.main import GraphRAG from graphrag_sdk.core.connection import ConnectionConfig, FalkorDBConnection from graphrag_sdk.core.context import Context -from graphrag_sdk.core.exceptions import ConfigError, DocumentNotFoundError +from graphrag_sdk.core.exceptions import ( + ConfigError, + DocumentNotFoundError, + LatencyBudgetExceededError, +) from graphrag_sdk.core.models import ( ApplyChangesResult, ChatMessage, @@ -444,6 +448,87 @@ async def rerank(self, query, result, ctx): assert result.items[0].content == "B" assert llm._call_index == 0 + async def test_retrieve_checks_budget_before_config_embedder_probe( + self, mock_conn, embedder + ): + llm = MockLLM(responses=["should not be called"]) + g = GraphRAG(connection=mock_conn, llm=llm, embedder=embedder, embedding_dimension=8) + ctx = Context(latency_budget_ms=1000.0) + + async def config_query_exhausts_budget(*args, **kwargs): + ctx.latency_budget_ms = 0.0 + result = MagicMock() + result.result_set = [] + return result + + g._graph_store.query_raw = AsyncMock(side_effect=config_query_exhausts_budget) + mock_strategy = MagicMock(spec=RetrievalStrategy) + mock_strategy.search = AsyncMock(return_value=RetrieverResult()) + g._retrieval_strategy = mock_strategy + + with pytest.raises(LatencyBudgetExceededError, match="graph config embedder probe"): + await g.retrieve("test", ctx=ctx) + + assert embedder.call_count == 0 + mock_strategy.search.assert_not_awaited() + + async def test_retrieve_checks_budget_before_config_query(self, mock_conn, embedder): + llm = MockLLM(responses=["should not be called"]) + g = GraphRAG(connection=mock_conn, llm=llm, embedder=embedder, embedding_dimension=8) + g._graph_store.query_raw = AsyncMock() + mock_strategy = MagicMock(spec=RetrievalStrategy) + mock_strategy.search = AsyncMock(return_value=RetrieverResult()) + g._retrieval_strategy = mock_strategy + ctx = Context(latency_budget_ms=1000.0) + + def exhaust_budget(operation: str) -> None: + if operation == "graph config query": + raise LatencyBudgetExceededError("budget exhausted before config query") + + ctx.ensure_budget = exhaust_budget # type: ignore[method-assign] + with pytest.raises(LatencyBudgetExceededError, match="config query"): + await g.retrieve("test", ctx=ctx) + + g._graph_store.query_raw.assert_not_awaited() + mock_strategy.search.assert_not_awaited() + + async def test_retrieve_propagates_budget_error_from_config_query( + self, mock_conn, embedder + ): + llm = MockLLM(responses=["should not be called"]) + g = GraphRAG(connection=mock_conn, llm=llm, embedder=embedder, embedding_dimension=8) + g._graph_store.query_raw = AsyncMock( + side_effect=LatencyBudgetExceededError("budget exhausted in config query") + ) + mock_strategy = MagicMock(spec=RetrievalStrategy) + mock_strategy.search = AsyncMock(return_value=RetrieverResult()) + g._retrieval_strategy = mock_strategy + + with pytest.raises(LatencyBudgetExceededError, match="config query"): + await g.retrieve("test", ctx=Context(latency_budget_ms=1000.0)) + + mock_strategy.search.assert_not_awaited() + + async def test_retrieve_propagates_budget_error_from_config_probe( + self, mock_conn, embedder + ): + llm = MockLLM(responses=["should not be called"]) + g = GraphRAG(connection=mock_conn, llm=llm, embedder=embedder, embedding_dimension=8) + result = MagicMock() + result.result_set = [] + g._graph_store.query_raw = AsyncMock(return_value=result) + embedder.aembed_query = AsyncMock( + side_effect=LatencyBudgetExceededError("budget exhausted in config probe") + ) + mock_strategy = MagicMock(spec=RetrievalStrategy) + mock_strategy.search = AsyncMock(return_value=RetrieverResult()) + g._retrieval_strategy = mock_strategy + + with pytest.raises(LatencyBudgetExceededError, match="config probe"): + await g.retrieve("test", ctx=Context(latency_budget_ms=1000.0)) + + mock_strategy.search.assert_not_awaited() + class TestGraphRAGCompletion: async def test_completion_basic(self, mock_conn, embedder): @@ -727,6 +812,19 @@ async def test_completion_rewrite_question_enabled(self, mock_conn, embedder): assert result.metadata["retrieval_query"] == "Where did Jane Doe go to college?" assert result.answer == "She attended Stanford University." + async def test_question_rewrite_budget_exhaustion_propagates(self, mock_conn, embedder): + llm = MockLLM(responses=["unused"]) + g = GraphRAG(connection=mock_conn, llm=llm, embedder=embedder, embedding_dimension=8) + + with pytest.raises(LatencyBudgetExceededError, match="question rewrite LLM call"): + await g._rewrite_question_with_history( + "Where did she go?", + [ChatMessage(role="user", content="Who is Jane?")], + ctx=Context(latency_budget_ms=0.0), + ) + + assert llm._call_index == 0 + async def test_completion_rewrite_fallback_on_empty(self, mock_conn, embedder): """If the rewrite LLM returns empty, fall back to the original question.""" llm = MockLLM(responses=["", "Some answer."]) @@ -776,6 +874,40 @@ async def test_completion_custom_prompt_template_with_history(self, mock_conn, e assert "SRC" in final.content assert "What is it?" in final.content + async def test_completion_checks_budget_before_retrieval(self, mock_conn, embedder): + llm = MockLLM(responses=["unused"]) + g = GraphRAG(connection=mock_conn, llm=llm, embedder=embedder, embedding_dimension=8) + mock_strategy = MagicMock(spec=RetrievalStrategy) + mock_strategy.search = AsyncMock( + return_value=RetrieverResult(items=[RetrieverResultItem(content="c")]) + ) + g._retrieval_strategy = mock_strategy + ctx = Context(latency_budget_ms=0.0) + + with pytest.raises(LatencyBudgetExceededError, match="completion retrieval"): + await g.completion("q?", ctx=ctx) + + mock_strategy.search.assert_not_awaited() + assert llm._call_index == 0 + + async def test_completion_checks_budget_before_final_llm(self, mock_conn, embedder): + llm = MockLLM(responses=["unused"]) + g = GraphRAG(connection=mock_conn, llm=llm, embedder=embedder, embedding_dimension=8) + mock_strategy = MagicMock(spec=RetrievalStrategy) + + async def search_and_exhaust(question, ctx): + ctx.latency_budget_ms = 0.0 + return RetrieverResult(items=[RetrieverResultItem(content="c")]) + + mock_strategy.search = AsyncMock(side_effect=search_and_exhaust) + g._retrieval_strategy = mock_strategy + + with pytest.raises(LatencyBudgetExceededError, match="completion LLM call"): + await g.completion("q?", ctx=Context(latency_budget_ms=1000.0)) + + mock_strategy.search.assert_awaited_once() + assert llm._call_index == 0 + class TestGraphRAGCompletionInjectionDefenses: """Verifies the S4 mitigations: context delimiters + close-tag neutralization.""" diff --git a/graphrag_sdk/tests/test_multi_path_retrieval.py b/graphrag_sdk/tests/test_multi_path_retrieval.py index b535730e..8c162cd5 100644 --- a/graphrag_sdk/tests/test_multi_path_retrieval.py +++ b/graphrag_sdk/tests/test_multi_path_retrieval.py @@ -6,6 +6,7 @@ import pytest from graphrag_sdk.core.context import Context +from graphrag_sdk.core.exceptions import LatencyBudgetExceededError from graphrag_sdk.core.models import ( LLMResponse, RawSearchResult, @@ -14,8 +15,10 @@ from graphrag_sdk.retrieval.strategies.entity_discovery import ( expand_sibling_entities, is_enumeration_query, + search_relates_edges, ) from graphrag_sdk.retrieval.strategies.multi_path import MultiPathRetrieval +from graphrag_sdk.retrieval.strategies.result_assembly import rerank_chunks from .conftest import MockEmbedder, MockLLM @@ -80,6 +83,85 @@ async def test_question_embedding(self, strategy, mp_embedder): await strategy.search("Who is Alice?") assert mp_embedder.call_count > 0 + async def test_latency_budget_checked_before_first_phase( + self, strategy, mp_embedder, mp_llm + ): + ctx = Context(latency_budget_ms=0.0) + + with pytest.raises(LatencyBudgetExceededError, match="keyword extraction"): + await strategy.search("Who is Alice?", ctx) + + assert mp_llm._call_index == 0 + assert mp_embedder.call_count == 0 + + async def test_relates_search_budget_error_propagates(self): + vector = MagicMock() + vector.search_relationships = AsyncMock( + side_effect=LatencyBudgetExceededError("budget exhausted") + ) + + with pytest.raises(LatencyBudgetExceededError, match="budget exhausted"): + await search_relates_edges(vector, [0.1], rel_top_k=3) + + async def test_relates_search_checks_budget_before_vector_call(self): + vector = MagicMock() + vector.search_relationships = AsyncMock(return_value=[]) + + with pytest.raises(LatencyBudgetExceededError, match="RELATES vector search"): + await search_relates_edges( + vector, + [0.1], + rel_top_k=3, + ctx=Context(latency_budget_ms=0.0), + ) + + vector.search_relationships.assert_not_awaited() + + async def test_rerank_budget_error_propagates(self, mp_embedder): + mp_embedder.aembed_documents = AsyncMock( + side_effect=LatencyBudgetExceededError("budget exhausted") + ) + + with pytest.raises(LatencyBudgetExceededError, match="budget exhausted"): + await rerank_chunks( + mp_embedder, + [0.1], + {"c1": "text"}, + stored_embeddings={}, + ) + + async def test_latency_budget_checked_before_question_embedding( + self, strategy, mp_embedder, mp_llm + ): + ctx = Context(latency_budget_ms=1000.0) + + def exhaust_budget(operation: str) -> None: + if operation == "MultiPath question embedding": + raise LatencyBudgetExceededError("budget exhausted before embedding") + + ctx.ensure_budget = exhaust_budget # type: ignore[method-assign] + with pytest.raises(LatencyBudgetExceededError, match="before embedding"): + await strategy.search("Who is Alice?", ctx) + + assert mp_llm._call_index == 1 + assert mp_embedder.call_count == 0 + + async def test_latency_budget_propagates_from_keyword_llm( + self, strategy, mp_embedder, mp_llm + ): + ctx = Context(latency_budget_ms=1000.0) + + def exhaust_budget(operation: str) -> None: + if operation == "MultiPath keyword extraction LLM call": + raise LatencyBudgetExceededError("budget exhausted before keyword LLM") + + ctx.ensure_budget = exhaust_budget # type: ignore[method-assign] + with pytest.raises(LatencyBudgetExceededError, match="keyword LLM"): + await strategy.search("Who is Alice?", ctx) + + assert mp_llm._call_index == 0 + assert mp_embedder.call_count == 0 + async def test_relates_edge_vector_search(self, mp_graph_store, mp_vector_store, mp_embedder, mp_llm): """RELATES edge vector search should be called and return facts + entities.""" mp_vector_store.search_relationships = AsyncMock(return_value=[ diff --git a/graphrag_sdk/tests/test_retrieval.py b/graphrag_sdk/tests/test_retrieval.py index 45389fa6..74dfa836 100644 --- a/graphrag_sdk/tests/test_retrieval.py +++ b/graphrag_sdk/tests/test_retrieval.py @@ -7,7 +7,7 @@ import pytest from graphrag_sdk.core.context import Context -from graphrag_sdk.core.exceptions import RetrieverError +from graphrag_sdk.core.exceptions import LatencyBudgetExceededError, RetrieverError from graphrag_sdk.core.models import RawSearchResult, RetrieverResult, RetrieverResultItem from graphrag_sdk.retrieval.strategies.base import RetrievalStrategy from graphrag_sdk.retrieval.strategies.local import LocalRetrieval @@ -157,3 +157,45 @@ async def test_local_scores_in_items(self, ctx, local_strategy): async def test_local_chunk_id_in_metadata(self, ctx, local_strategy): result = await local_strategy.search("test", ctx) assert result.items[0].metadata["chunk_id"] == "chunk-1" + + async def test_local_budget_checked_before_embedding( + self, local_strategy, embedder, mock_vector_store + ): + with pytest.raises(LatencyBudgetExceededError, match="query embedding"): + await local_strategy.search("test", Context(latency_budget_ms=0.0)) + + assert embedder.call_count == 0 + mock_vector_store.search_chunks.assert_not_awaited() + + async def test_local_budget_checked_before_vector_search( + self, local_strategy, embedder, mock_vector_store + ): + ctx = Context(latency_budget_ms=1000.0) + + def exhaust_budget(operation: str) -> None: + if operation == "LocalRetrieval chunk vector search": + raise LatencyBudgetExceededError("budget exhausted before vector search") + + ctx.ensure_budget = exhaust_budget # type: ignore[method-assign] + with pytest.raises(LatencyBudgetExceededError, match="vector search"): + await local_strategy.search("test", ctx) + + assert embedder.call_count == 1 + mock_vector_store.search_chunks.assert_not_awaited() + + async def test_local_budget_checked_before_entity_expansion( + self, local_strategy, embedder, mock_vector_store, mock_graph_store + ): + ctx = Context(latency_budget_ms=1000.0) + + def exhaust_budget(operation: str) -> None: + if operation == "LocalRetrieval entity expansion": + raise LatencyBudgetExceededError("budget exhausted before entity expansion") + + ctx.ensure_budget = exhaust_budget # type: ignore[method-assign] + with pytest.raises(LatencyBudgetExceededError, match="entity expansion"): + await local_strategy.search("test", ctx) + + assert embedder.call_count == 1 + mock_vector_store.search_chunks.assert_awaited_once() + mock_graph_store.get_connected_entities.assert_not_awaited() From cd760dc3862f7a722a62e44ffd847c9444b461db Mon Sep 17 00:00:00 2001 From: Naseem Ali <34807727+Naseem77@users.noreply.github.com> Date: Mon, 18 May 2026 17:55:31 +0300 Subject: [PATCH 5/8] Mark real FalkorDB integration tests Add an explicit integration marker, run marked real-FalkorDB tests in CI, document docker-compose usage, and expose the FalkorDB browser port in local compose.\n\nCo-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .github/workflows/ci.yml | 8 +++----- CONTRIBUTING.md | 12 ++++++++++-- docker-compose.yml | 2 ++ graphrag_sdk/pyproject.toml | 3 +++ graphrag_sdk/tests/test_integration.py | 1 + 5 files changed, 19 insertions(+), 7 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1c357263..1fc96e92 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -85,10 +85,8 @@ jobs: done echo "FalkorDB did not become reachable within 20 attempts" >&2 exit 1 - - name: Run incremental-update integration tests - # Only the v1.1.0 invariant suite needs the live database; the - # rest of test_integration.py is import-smoke that already runs - # in the unit job above. + - name: Run real-FalkorDB integration tests run: > python -m pytest -v - tests/test_integration.py::TestIncrementalUpdateInvariants + -m integration + tests/test_integration.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 3f670cd3..e01da8a4 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -24,10 +24,10 @@ pip install -e "graphrag_sdk[dev]" You will need a running FalkorDB instance for integration work. The easiest way is via Docker: ```bash -docker run -p 6379:6379 falkordb/falkordb +docker compose up -d falkordb ``` -This exposes FalkorDB on the default Redis port (6379). No additional configuration is required for local development. +This exposes FalkorDB on the default Redis port (6379) and the browser UI on port 3000. --- @@ -41,6 +41,14 @@ python -m pytest graphrag_sdk/tests/ -q There are 558 tests covering the ingestion pipeline, the GraphRAG facade, extraction strategies, resolution strategies, retrieval strategies, storage layers, and utilities. All tests use mock providers, so no live LLM or database connection is needed to run them. +Run real-FalkorDB integration tests with: + +```bash +RUN_INTEGRATION=1 python -m pytest graphrag_sdk/tests/test_integration.py -m integration -q +``` + +These tests use scripted local providers, not live LLM APIs. + --- ## 3. Code Style diff --git a/docker-compose.yml b/docker-compose.yml index 867bddf0..258720f1 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -3,6 +3,7 @@ services: image: falkordb/falkordb:v4.18.0 ports: - "6379:6379" + - "3000:3000" volumes: - falkordb_data:/data healthcheck: @@ -10,6 +11,7 @@ services: interval: 5s timeout: 3s retries: 5 + start_period: 5s volumes: falkordb_data: diff --git a/graphrag_sdk/pyproject.toml b/graphrag_sdk/pyproject.toml index be8597fd..1d45c7dd 100644 --- a/graphrag_sdk/pyproject.toml +++ b/graphrag_sdk/pyproject.toml @@ -97,3 +97,6 @@ plugins = ["pydantic.mypy"] [tool.pytest.ini_options] asyncio_mode = "auto" testpaths = ["tests"] +markers = [ + "integration: tests that require a live FalkorDB instance", +] diff --git a/graphrag_sdk/tests/test_integration.py b/graphrag_sdk/tests/test_integration.py index 45b0102f..222437e7 100644 --- a/graphrag_sdk/tests/test_integration.py +++ b/graphrag_sdk/tests/test_integration.py @@ -237,6 +237,7 @@ async def _entity_count(rag, name: str) -> int: @pytest.mark.asyncio +@pytest.mark.integration class TestIncrementalUpdateInvariants: """v1.1.0 load-bearing correctness: scoped orphan cleanup preserves entities shared across documents. Run against real FalkorDB. From 8002adaa07cb0b6a2fd42f73395984e0692b1cc2 Mon Sep 17 00:00:00 2001 From: Naseem Ali <34807727+Naseem77@users.noreply.github.com> Date: Mon, 18 May 2026 17:57:52 +0300 Subject: [PATCH 6/8] Tighten release automation Validate Python distributions before trusted PyPI publishing, upload release artifacts, enable manual docs deploys, and add Dependabot coverage for actions and Python dependencies.\n\nCo-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .github/dependabot.yml | 19 +++++++++++++++++++ .github/workflows/docs.yml | 2 ++ .github/workflows/pypi-publish.yaml | 7 +++++++ 3 files changed, 28 insertions(+) create mode 100644 .github/dependabot.yml diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..7e6c18eb --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,19 @@ +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" + groups: + github-actions: + patterns: + - "*" + + - package-ecosystem: "pip" + directory: "/graphrag_sdk" + schedule: + interval: "weekly" + groups: + python-dependencies: + patterns: + - "*" diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index a75d2ea0..ffcac4d7 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -1,11 +1,13 @@ name: Deploy Docs on: + workflow_dispatch: push: branches: [main, staging] paths: - "docs/**" - "mkdocs.yml" + - "README.md" permissions: contents: read diff --git a/.github/workflows/pypi-publish.yaml b/.github/workflows/pypi-publish.yaml index fa83c6ef..9b0edad6 100644 --- a/.github/workflows/pypi-publish.yaml +++ b/.github/workflows/pypi-publish.yaml @@ -32,8 +32,15 @@ jobs: fi - run: pip install build + - run: pip install twine - run: python -m build + - run: twine check dist/* + + - uses: actions/upload-artifact@v4 + with: + name: python-package-distributions + path: graphrag_sdk/dist/ - uses: pypa/gh-action-pypi-publish@release/v1 with: From 9dfa72bc0916bbf1303b25945a599249b05c2510 Mon Sep 17 00:00:00 2001 From: Naseem Ali <34807727+Naseem77@users.noreply.github.com> Date: Mon, 18 May 2026 19:57:30 +0300 Subject: [PATCH 7/8] Address facade context manager review warning Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- graphrag_sdk/tests/test_facade.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/graphrag_sdk/tests/test_facade.py b/graphrag_sdk/tests/test_facade.py index b73d6b91..35d766d0 100644 --- a/graphrag_sdk/tests/test_facade.py +++ b/graphrag_sdk/tests/test_facade.py @@ -120,12 +120,13 @@ async def test_async_context_manager_preserves_inner_exception_on_close_failure( mock_conn.close = AsyncMock(side_effect=RuntimeError("close failed")) g = GraphRAG(connection=mock_conn, llm=llm, embedder=embedder, embedding_dimension=8) - with caplog.at_level("WARNING", logger="graphrag_sdk.api.main"): - with pytest.raises(ValueError, match="inner failure"): - async with g: - raise ValueError("inner failure") - - mock_conn.close.assert_awaited_once() + try: + with caplog.at_level("WARNING", logger="graphrag_sdk.api.main"): + with pytest.raises(ValueError, match="inner failure"): + async with g: + raise ValueError("inner failure") + finally: + mock_conn.close.assert_awaited_once() assert "Error closing connection during __aexit__" in caplog.text async def test_async_context_manager_raises_close_failure_without_inner_exception( From 3eca0b45c126089e0efd47bd57feadd6fd58bbdf Mon Sep 17 00:00:00 2001 From: Naseem Ali <34807727+Naseem77@users.noreply.github.com> Date: Mon, 18 May 2026 20:14:00 +0300 Subject: [PATCH 8/8] Address remaining facade test reachability warning Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- graphrag_sdk/tests/test_facade.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphrag_sdk/tests/test_facade.py b/graphrag_sdk/tests/test_facade.py index 35d766d0..0333ad2d 100644 --- a/graphrag_sdk/tests/test_facade.py +++ b/graphrag_sdk/tests/test_facade.py @@ -127,7 +127,7 @@ async def test_async_context_manager_preserves_inner_exception_on_close_failure( raise ValueError("inner failure") finally: mock_conn.close.assert_awaited_once() - assert "Error closing connection during __aexit__" in caplog.text + assert "Error closing connection during __aexit__" in caplog.text async def test_async_context_manager_raises_close_failure_without_inner_exception( self, mock_conn, embedder, llm