diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..7e6c18eb --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,19 @@ +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" + groups: + github-actions: + patterns: + - "*" + + - package-ecosystem: "pip" + directory: "/graphrag_sdk" + schedule: + interval: "weekly" + groups: + python-dependencies: + patterns: + - "*" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1c357263..1fc96e92 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -85,10 +85,8 @@ jobs: done echo "FalkorDB did not become reachable within 20 attempts" >&2 exit 1 - - name: Run incremental-update integration tests - # Only the v1.1.0 invariant suite needs the live database; the - # rest of test_integration.py is import-smoke that already runs - # in the unit job above. + - name: Run real-FalkorDB integration tests run: > python -m pytest -v - tests/test_integration.py::TestIncrementalUpdateInvariants + -m integration + tests/test_integration.py diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index a75d2ea0..ffcac4d7 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -1,11 +1,13 @@ name: Deploy Docs on: + workflow_dispatch: push: branches: [main, staging] paths: - "docs/**" - "mkdocs.yml" + - "README.md" permissions: contents: read diff --git a/.github/workflows/pypi-publish.yaml b/.github/workflows/pypi-publish.yaml index fa83c6ef..9b0edad6 100644 --- a/.github/workflows/pypi-publish.yaml +++ b/.github/workflows/pypi-publish.yaml @@ -32,8 +32,15 @@ jobs: fi - run: pip install build + - run: pip install twine - run: python -m build + - run: twine check dist/* + + - uses: actions/upload-artifact@v4 + with: + name: python-package-distributions + path: graphrag_sdk/dist/ - uses: pypa/gh-action-pypi-publish@release/v1 with: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 3f670cd3..e01da8a4 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -24,10 +24,10 @@ pip install -e "graphrag_sdk[dev]" You will need a running FalkorDB instance for integration work. The easiest way is via Docker: ```bash -docker run -p 6379:6379 falkordb/falkordb +docker compose up -d falkordb ``` -This exposes FalkorDB on the default Redis port (6379). No additional configuration is required for local development. +This exposes FalkorDB on the default Redis port (6379) and the browser UI on port 3000. --- @@ -41,6 +41,14 @@ python -m pytest graphrag_sdk/tests/ -q There are 558 tests covering the ingestion pipeline, the GraphRAG facade, extraction strategies, resolution strategies, retrieval strategies, storage layers, and utilities. All tests use mock providers, so no live LLM or database connection is needed to run them. +Run real-FalkorDB integration tests with: + +```bash +RUN_INTEGRATION=1 python -m pytest graphrag_sdk/tests/test_integration.py -m integration -q +``` + +These tests use scripted local providers, not live LLM APIs. + --- ## 3. Code Style diff --git a/docker-compose.yml b/docker-compose.yml index 867bddf0..63c0baf5 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -3,6 +3,7 @@ services: image: falkordb/falkordb:v4.18.0 ports: - "6379:6379" + - "127.0.0.1:3000:3000" volumes: - falkordb_data:/data healthcheck: @@ -10,6 +11,7 @@ services: interval: 5s timeout: 3s retries: 5 + start_period: 5s volumes: falkordb_data: diff --git a/graphrag_sdk/pyproject.toml b/graphrag_sdk/pyproject.toml index be8597fd..1d45c7dd 100644 --- a/graphrag_sdk/pyproject.toml +++ b/graphrag_sdk/pyproject.toml @@ -97,3 +97,6 @@ plugins = ["pydantic.mypy"] [tool.pytest.ini_options] asyncio_mode = "auto" testpaths = ["tests"] +markers = [ + "integration: tests that require a live FalkorDB instance", +] diff --git a/graphrag_sdk/src/graphrag_sdk/__init__.py b/graphrag_sdk/src/graphrag_sdk/__init__.py index f66ed70a..2b90d246 100644 --- a/graphrag_sdk/src/graphrag_sdk/__init__.py +++ b/graphrag_sdk/src/graphrag_sdk/__init__.py @@ -19,7 +19,11 @@ # ── Core Contracts ─────────────────────────────────────────────── from graphrag_sdk.core.connection import ConnectionConfig, FalkorDBConnection from graphrag_sdk.core.context import Context -from graphrag_sdk.core.exceptions import DocumentNotFoundError, GraphRAGError +from graphrag_sdk.core.exceptions import ( + DocumentNotFoundError, + GraphRAGError, + LatencyBudgetExceededError, +) from graphrag_sdk.core.models import ( ApplyChangesResult, BatchEntry, @@ -134,6 +138,7 @@ "GraphRelationship", "GraphSchema", "IngestionResult", + "LatencyBudgetExceededError", "LLMBatchItem", "LLMInterface", "LiteLLM", diff --git a/graphrag_sdk/src/graphrag_sdk/api/main.py b/graphrag_sdk/src/graphrag_sdk/api/main.py index 684c7a2d..f44a31dd 100644 --- a/graphrag_sdk/src/graphrag_sdk/api/main.py +++ b/graphrag_sdk/src/graphrag_sdk/api/main.py @@ -15,7 +15,12 @@ from graphrag_sdk import __version__ from graphrag_sdk.core.connection import ConnectionConfig, FalkorDBConnection from graphrag_sdk.core.context import Context -from graphrag_sdk.core.exceptions import ConfigError, DatabaseError, DocumentNotFoundError +from graphrag_sdk.core.exceptions import ( + ConfigError, + DatabaseError, + DocumentNotFoundError, + LatencyBudgetExceededError, +) from graphrag_sdk.core.models import ( ApplyChangesResult, BatchEntry, @@ -1446,13 +1451,16 @@ async def retrieve( ctx = Context() ctx.log(f"Retrieve: {question[:80]}...") + ctx.ensure_budget("graph config validation") - await self._validate_graph_config() + await self._validate_graph_config(ctx=ctx) retrieval = strategy or self._retrieval_strategy + ctx.ensure_budget("retrieval strategy search") retriever_result = await retrieval.search(question, ctx) if reranker is not None: + ctx.ensure_budget("retrieval reranking") retriever_result = await reranker.rerank(question, retriever_result, ctx) ctx.log(f"Retrieved {len(retriever_result.items)} context items") @@ -1480,13 +1488,16 @@ def _validate_history( f"history[{i}]: each message must have 'role' and " f"'content' keys, got {sorted(msg.keys())}" ) - try: - validated.append(ChatMessage(role=msg["role"], content=msg["content"])) - except Exception: + role = msg["role"] + content = msg["content"] + if role not in {"system", "user", "assistant"}: raise ValueError( - f"history[{i}]: invalid role '{msg['role']}'. " + f"history[{i}]: invalid role '{role}'. " f"Must be one of: 'system', 'user', 'assistant'" ) + if not isinstance(content, str): + raise ValueError(f"history[{i}]: content must be a string") + validated.append(ChatMessage(role=role, content=content)) else: raise TypeError( f"history[{i}]: expected ChatMessage or dict, got {type(msg).__name__}" @@ -1515,8 +1526,11 @@ async def _rewrite_question_with_history( question=question, ) try: - resp = await self.llm.ainvoke(prompt) + ctx.ensure_budget("question rewrite LLM call") + resp = await self.llm.ainvoke(prompt, timeout=ctx.remaining_budget_seconds) rewritten = (resp.content or "").strip().splitlines()[0].strip() if resp.content else "" + except LatencyBudgetExceededError: + raise except Exception as e: # Broad catch is intentional (see docstring) — but log at WARNING # with full traceback so programming bugs surface in operator @@ -1585,6 +1599,7 @@ async def completion( # Step 1: Optionally rewrite the question for retrieval. retrieval_query = question if validated_history and rewrite_question_with_history: + ctx.ensure_budget("question rewrite") retrieval_query = await self._rewrite_question_with_history( question, validated_history, @@ -1594,6 +1609,7 @@ async def completion( ctx.log(f"Rewrote for retrieval: {retrieval_query[:80]}") # Step 2: Retrieve + rerank (using possibly-rewritten query). + ctx.ensure_budget("completion retrieval") retriever_result = await self.retrieve( retrieval_query, strategy=strategy, @@ -1635,7 +1651,11 @@ async def completion( ChatMessage(role="user", content=final_user_content), ] - llm_response = await self.llm.ainvoke_messages(messages) + ctx.ensure_budget("completion LLM call") + llm_response = await self.llm.ainvoke_messages( + messages, + timeout=ctx.remaining_budget_seconds, + ) result = RagResult( answer=self._clean_answer(llm_response.content), @@ -1675,7 +1695,7 @@ async def _write_graph_config(self) -> None: except Exception: logger.debug("Failed to write graph config node", exc_info=True) - async def _validate_graph_config(self) -> None: + async def _validate_graph_config(self, *, ctx: Context | None = None) -> None: """Check that the current embedder matches the graph's stored config. Two checks, both cached after first run: @@ -1697,6 +1717,8 @@ async def _validate_graph_config(self) -> None: return try: + if ctx is not None: + ctx.ensure_budget("graph config query") result = await self._graph_store.query_raw( "MATCH (c:__GraphRAGConfig__ {id: 'default'}) " "RETURN c.embedding_model, c.embedding_dimension" @@ -1721,6 +1743,8 @@ async def _validate_graph_config(self) -> None: ) except ConfigError: raise + except LatencyBudgetExceededError: + raise except Exception: # Don't mark as validated on transient failures — retry next call. logger.debug("Failed to validate graph config", exc_info=True) @@ -1729,8 +1753,15 @@ async def _validate_graph_config(self) -> None: # Probe the embedder once: confirm it produces vectors of the # configured dimension. Catches user error like # ``embedding_dimension=256`` paired with a 1536-dim model. + if ctx is not None: + ctx.ensure_budget("graph config embedder probe") try: - probe = await self.embedder.aembed_query("dim_check") + probe = await self.embedder.aembed_query( + "dim_check", + timeout=ctx.remaining_budget_seconds if ctx is not None else None, + ) + except LatencyBudgetExceededError: + raise except Exception: # Probe failure is non-fatal — but don't cache a "validated" # state, otherwise a transient outage permanently disables diff --git a/graphrag_sdk/src/graphrag_sdk/core/connection.py b/graphrag_sdk/src/graphrag_sdk/core/connection.py index 74bc48dd..f46efdb0 100644 --- a/graphrag_sdk/src/graphrag_sdk/core/connection.py +++ b/graphrag_sdk/src/graphrag_sdk/core/connection.py @@ -180,6 +180,12 @@ async def query( last_exc = exc # Don't retry non-transient errors (e.g. schema/index conflicts) if self._is_non_transient(exc): + logger.error( + "Non-transient FalkorDB query failure: %s: %s", + type(exc).__name__, + exc, + ) + logger.debug("Non-transient FalkorDB query failure details", exc_info=True) raise await self._breaker.record_failure() logger.warning( @@ -193,7 +199,20 @@ async def query( break base_delay = self.config.retry_delay * (2**attempt) await asyncio.sleep(base_delay * (0.5 + random.random())) - raise last_exc # type: ignore[misc] + logger.error( + "FalkorDB query failed after %d attempts: %s: %s", + self.config.retry_count, + type(last_exc).__name__ if last_exc is not None else "UnknownError", + last_exc, + ) + if last_exc is not None: + logger.debug( + "FalkorDB query failure details", + exc_info=(type(last_exc), last_exc, last_exc.__traceback__), + ) + if last_exc is not None: + raise last_exc + raise RuntimeError("FalkorDB query failed without an exception") # Substrings that indicate a non-transient (permanent) error — # retrying will never succeed. diff --git a/graphrag_sdk/src/graphrag_sdk/core/context.py b/graphrag_sdk/src/graphrag_sdk/core/context.py index 7f93ab0e..8acedbb8 100644 --- a/graphrag_sdk/src/graphrag_sdk/core/context.py +++ b/graphrag_sdk/src/graphrag_sdk/core/context.py @@ -10,6 +10,8 @@ from typing import Any from uuid import uuid4 +from graphrag_sdk.core.exceptions import LatencyBudgetExceededError + logger = logging.getLogger(__name__) @@ -46,12 +48,30 @@ def remaining_budget_ms(self) -> float | None: return None return max(0.0, self.latency_budget_ms - self.elapsed_ms) + @property + def remaining_budget_seconds(self) -> float | None: + """Remaining latency budget in seconds, or None if no budget set.""" + remaining = self.remaining_budget_ms + if remaining is None: + return None + return max(remaining / 1000.0, 1e-9) + @property def budget_exceeded(self) -> bool: """True if the latency budget has been exceeded.""" remaining = self.remaining_budget_ms return remaining is not None and remaining <= 0 + def ensure_budget(self, operation: str) -> None: + """Raise if the latency budget is already exhausted before *operation* starts.""" + if not self.budget_exceeded: + return + budget = self.latency_budget_ms if self.latency_budget_ms is not None else 0.0 + raise LatencyBudgetExceededError( + f"Latency budget exceeded before {operation} " + f"(elapsed={self.elapsed_ms:.1f}ms, budget={budget:.1f}ms)" + ) + def child(self, **overrides: Any) -> Context: """Create a child context inheriting tenant/trace but with optional overrides. diff --git a/graphrag_sdk/src/graphrag_sdk/core/exceptions.py b/graphrag_sdk/src/graphrag_sdk/core/exceptions.py index d8942aeb..99c3c379 100644 --- a/graphrag_sdk/src/graphrag_sdk/core/exceptions.py +++ b/graphrag_sdk/src/graphrag_sdk/core/exceptions.py @@ -10,6 +10,12 @@ class GraphRAGError(Exception): pass +class LatencyBudgetExceededError(GraphRAGError): + """Raised when an operation cannot start within the remaining latency budget.""" + + pass + + # ── Provider Errors ────────────────────────────────────────────── @@ -19,12 +25,24 @@ class LLMError(GraphRAGError): pass +class LLMTimeoutError(LLMError): + """Raised when an LLM provider call exceeds its configured timeout.""" + + pass + + class EmbeddingError(GraphRAGError): """Raised when an embedding provider call fails.""" pass +class EmbeddingTimeoutError(EmbeddingError): + """Raised when an embedding provider call exceeds its configured timeout.""" + + pass + + # ── Ingestion Errors ───────────────────────────────────────────── diff --git a/graphrag_sdk/src/graphrag_sdk/core/providers/_retry.py b/graphrag_sdk/src/graphrag_sdk/core/providers/_retry.py index 13a4305b..a911a14d 100644 --- a/graphrag_sdk/src/graphrag_sdk/core/providers/_retry.py +++ b/graphrag_sdk/src/graphrag_sdk/core/providers/_retry.py @@ -7,6 +7,8 @@ from collections.abc import Awaitable, Callable from typing import Any +from graphrag_sdk.core.exceptions import EmbeddingTimeoutError + logger = logging.getLogger(__name__) _MAX_EXC_SUMMARY_LEN = 200 @@ -65,8 +67,12 @@ def binary_split_retry_sync( """ try: return embed_fn(texts, **kwargs) + except EmbeddingTimeoutError: + raise except Exception as exc: if not is_transient_embedding_error(exc): + logger.error("Non-transient embedding failure: %s", summarize_exception(exc)) + logger.debug("Non-transient embedding failure details", exc_info=True) raise if len(texts) == 1: logger.warning("Embedding failed for text (len=%d): skipped", len(texts[0])) @@ -85,8 +91,12 @@ async def binary_split_retry_async( """Async variant of :func:`binary_split_retry_sync`.""" try: return await embed_fn(texts, **kwargs) + except EmbeddingTimeoutError: + raise except Exception as exc: if not is_transient_embedding_error(exc): + logger.error("Non-transient embedding failure: %s", summarize_exception(exc)) + logger.debug("Non-transient embedding failure details", exc_info=True) raise if len(texts) == 1: logger.warning("Embedding failed for text (len=%d): skipped", len(texts[0])) diff --git a/graphrag_sdk/src/graphrag_sdk/core/providers/_timeout.py b/graphrag_sdk/src/graphrag_sdk/core/providers/_timeout.py new file mode 100644 index 00000000..fd901596 --- /dev/null +++ b/graphrag_sdk/src/graphrag_sdk/core/providers/_timeout.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import asyncio +from typing import Any + +from graphrag_sdk.core.exceptions import EmbeddingTimeoutError, LLMTimeoutError + + +async def wait_for_provider_call( + awaitable: Any, + *, + timeout: float | None, + timeout_error: type[LLMTimeoutError] | type[EmbeddingTimeoutError], + operation: str, +) -> Any: + if timeout is None: + return await awaitable + if timeout <= 0: + close = getattr(awaitable, "close", None) + if close is not None: + close() + raise timeout_error(f"{operation} timed out after {timeout:.3g}s") + try: + return await asyncio.wait_for(awaitable, timeout=timeout) + except (TimeoutError, asyncio.TimeoutError) as exc: + raise timeout_error(f"{operation} timed out after {timeout:.3g}s") from exc + + +def validate_timeout(timeout: float | None) -> None: + if timeout is not None and timeout <= 0: + raise ValueError("timeout must be > 0") diff --git a/graphrag_sdk/src/graphrag_sdk/core/providers/base.py b/graphrag_sdk/src/graphrag_sdk/core/providers/base.py index cd7a31da..a46ace67 100644 --- a/graphrag_sdk/src/graphrag_sdk/core/providers/base.py +++ b/graphrag_sdk/src/graphrag_sdk/core/providers/base.py @@ -13,8 +13,10 @@ from pydantic import BaseModel +from graphrag_sdk.core.exceptions import EmbeddingTimeoutError, LLMTimeoutError from graphrag_sdk.core.models import ChatMessage, LLMResponse from graphrag_sdk.core.providers._retry import summarize_exception +from graphrag_sdk.core.providers._timeout import validate_timeout, wait_for_provider_call logger = logging.getLogger(__name__) @@ -60,17 +62,41 @@ def embed_query(self, text: str, **kwargs: Any) -> list[float]: """Embed a single text string into a float vector.""" ... - async def aembed_query(self, text: str, **kwargs: Any) -> list[float]: + async def aembed_query( + self, + text: str, + *, + timeout: float | None = None, + **kwargs: Any, + ) -> list[float]: """Async variant — defaults to sync-in-thread.""" - return await asyncio.to_thread(self.embed_query, text, **kwargs) + validate_timeout(timeout) + return await wait_for_provider_call( + asyncio.to_thread(self.embed_query, text, **kwargs), + timeout=timeout, + timeout_error=EmbeddingTimeoutError, + operation=f"embedding query with {self.model_name}", + ) def embed_documents(self, texts: list[str], **kwargs: Any) -> list[list[float]]: """Batch embed multiple texts. Default: sequential fallback.""" return [self.embed_query(t, **kwargs) for t in texts] - async def aembed_documents(self, texts: list[str], **kwargs: Any) -> list[list[float]]: + async def aembed_documents( + self, + texts: list[str], + *, + timeout: float | None = None, + **kwargs: Any, + ) -> list[list[float]]: """Async batch embed. Default: sync-in-thread.""" - return await asyncio.to_thread(self.embed_documents, texts, **kwargs) + validate_timeout(timeout) + return await wait_for_provider_call( + asyncio.to_thread(self.embed_documents, texts, **kwargs), + timeout=timeout, + timeout_error=EmbeddingTimeoutError, + operation=f"embedding documents with {self.model_name}", + ) class LLMInterface(ABC): @@ -106,6 +132,7 @@ async def ainvoke( prompt: str, *, max_retries: int = 3, + timeout: float | None = None, **kwargs: Any, ) -> LLMResponse: """Async variant with retry + jittered exponential backoff. @@ -113,10 +140,20 @@ async def ainvoke( Retries on any exception up to ``max_retries`` times with jittered delays between attempts. """ + if max_retries < 1: + raise ValueError("max_retries must be >= 1") + validate_timeout(timeout) last_exc: Exception | None = None for attempt in range(max_retries): try: - return await asyncio.to_thread(self.invoke, prompt, **kwargs) + return await wait_for_provider_call( + asyncio.to_thread(self.invoke, prompt, **kwargs), + timeout=timeout, + timeout_error=LLMTimeoutError, + operation=f"LLM call to {self.model_name}", + ) + except LLMTimeoutError: + raise except Exception as exc: last_exc = exc if attempt < max_retries - 1: @@ -130,6 +167,16 @@ async def ainvoke( ) logger.debug("LLM call failure details", exc_info=True) await asyncio.sleep(delay) + logger.error( + "LLM call failed after %d attempts: %s", + max_retries, + summarize_exception(last_exc) if last_exc is not None else "UnknownError", + ) + if last_exc is not None: + logger.debug( + "LLM call final failure details", + exc_info=(type(last_exc), last_exc, last_exc.__traceback__), + ) raise last_exc # type: ignore[misc] async def ainvoke_messages( @@ -137,6 +184,7 @@ async def ainvoke_messages( messages: list[ChatMessage], *, max_retries: int = 3, + timeout: float | None = None, **kwargs: Any, ) -> LLMResponse: """Invoke the LLM with a list of structured chat messages. @@ -155,16 +203,28 @@ async def ainvoke_messages( Returns: LLMResponse with the model's reply. """ + validate_timeout(timeout) # Default fallback: flatten messages into a single prompt string. parts: list[str] = [] for msg in messages: parts.append(f"{msg.role.capitalize()}: {msg.content}") prompt = "\n\n".join(parts) - return await self.ainvoke(prompt, max_retries=max_retries, **kwargs) - - async def astream(self, prompt: str, **kwargs: Any) -> AsyncIterator[str]: + return await self.ainvoke( + prompt, + max_retries=max_retries, + timeout=timeout, + **kwargs, + ) + + async def astream( + self, + prompt: str, + *, + timeout: float | None = None, + **kwargs: Any, + ) -> AsyncIterator[str]: """Async streaming — default yields the full response as one chunk.""" - resp = await self.ainvoke(prompt, **kwargs) + resp = await self.ainvoke(prompt, timeout=timeout, **kwargs) yield resp.content def invoke_with_model( @@ -187,10 +247,16 @@ async def ainvoke_with_model( response_model: type[BaseModel], *, max_retries: int = 3, + timeout: float | None = None, **kwargs: Any, ) -> BaseModel: """Async structured output invocation with retry.""" - response = await self.ainvoke(prompt, max_retries=max_retries, **kwargs) + response = await self.ainvoke( + prompt, + max_retries=max_retries, + timeout=timeout, + **kwargs, + ) return response_model.model_validate_json(response.content) async def abatch_invoke( @@ -199,6 +265,7 @@ async def abatch_invoke( *, max_concurrency: int | None = None, max_retries: int = 3, + timeout: float | None = None, **kwargs: Any, ) -> list[LLMBatchItem]: """Invoke LLM on multiple prompts concurrently with per-item error handling. @@ -220,9 +287,20 @@ async def abatch_invoke( async def _call(i: int, prompt: str) -> LLMBatchItem: async with sem: try: - resp = await self.ainvoke(prompt, max_retries=max_retries, **kwargs) + resp = await self.ainvoke( + prompt, + max_retries=max_retries, + timeout=timeout, + **kwargs, + ) return LLMBatchItem(index=i, response=resp) except Exception as exc: + logger.error( + "Batch LLM item %d failed: %s", + i, + summarize_exception(exc), + ) + logger.debug("Batch LLM item failure details", exc_info=True) return LLMBatchItem(index=i, error=exc) return list(await asyncio.gather(*[_call(i, p) for i, p in enumerate(prompts)])) diff --git a/graphrag_sdk/src/graphrag_sdk/core/providers/litellm.py b/graphrag_sdk/src/graphrag_sdk/core/providers/litellm.py index 8d5e8331..9faeb51d 100644 --- a/graphrag_sdk/src/graphrag_sdk/core/providers/litellm.py +++ b/graphrag_sdk/src/graphrag_sdk/core/providers/litellm.py @@ -6,14 +6,17 @@ import asyncio import logging +import time from typing import Any +from graphrag_sdk.core.exceptions import EmbeddingTimeoutError, LLMTimeoutError from graphrag_sdk.core.models import ChatMessage, LLMResponse from graphrag_sdk.core.providers._retry import ( binary_split_retry_async, binary_split_retry_sync, summarize_exception, ) +from graphrag_sdk.core.providers._timeout import validate_timeout, wait_for_provider_call from graphrag_sdk.core.providers.base import Embedder, LLMInterface logger = logging.getLogger(__name__) @@ -116,8 +119,12 @@ async def ainvoke( prompt: str, *, max_retries: int = 3, + timeout: float | None = None, **kwargs: Any, ) -> LLMResponse: + if max_retries < 1: + raise ValueError("max_retries must be >= 1") + validate_timeout(timeout) try: import litellm except ImportError: @@ -128,9 +135,16 @@ async def ainvoke( last_exc: Exception | None = None for attempt in range(max_retries): try: - response = await litellm.acompletion(**self._completion_kwargs(prompt, **kwargs)) + response = await wait_for_provider_call( + litellm.acompletion(**self._completion_kwargs(prompt, **kwargs)), + timeout=timeout, + timeout_error=LLMTimeoutError, + operation=f"LiteLLM call to {self.model_name}", + ) content = response.choices[0].message.content or "" return LLMResponse(content=content) + except LLMTimeoutError: + raise except Exception as exc: last_exc = exc if attempt < max_retries - 1: @@ -144,6 +158,16 @@ async def ainvoke( ) logger.debug("LiteLLM call failure details", exc_info=True) await asyncio.sleep(delay) + logger.error( + "LiteLLM call failed after %d attempts: %s", + max_retries, + summarize_exception(last_exc) if last_exc is not None else "UnknownError", + ) + if last_exc is not None: + logger.debug( + "LiteLLM call final failure details", + exc_info=(type(last_exc), last_exc, last_exc.__traceback__), + ) raise last_exc # type: ignore[misc] def _messages_completion_kwargs( @@ -183,11 +207,13 @@ async def ainvoke_messages( messages: list[ChatMessage], *, max_retries: int = 3, + timeout: float | None = None, **kwargs: Any, ) -> LLMResponse: """Native multi-turn completion via LiteLLM.""" if max_retries < 1: raise ValueError("max_retries must be >= 1") + validate_timeout(timeout) try: import litellm except ImportError: @@ -198,11 +224,16 @@ async def ainvoke_messages( last_exc: Exception | None = None for attempt in range(max_retries): try: - response = await litellm.acompletion( - **self._messages_completion_kwargs(messages, **kwargs) + response = await wait_for_provider_call( + litellm.acompletion(**self._messages_completion_kwargs(messages, **kwargs)), + timeout=timeout, + timeout_error=LLMTimeoutError, + operation=f"LiteLLM messages call to {self.model_name}", ) content = response.choices[0].message.content or "" return LLMResponse(content=content) + except LLMTimeoutError: + raise except Exception as exc: last_exc = exc if attempt < max_retries - 1: @@ -216,6 +247,16 @@ async def ainvoke_messages( ) logger.debug("LiteLLM call failure details", exc_info=True) await asyncio.sleep(delay) + logger.error( + "LiteLLM messages call failed after %d attempts: %s", + max_retries, + summarize_exception(last_exc) if last_exc is not None else "UnknownError", + ) + if last_exc is not None: + logger.debug( + "LiteLLM messages call final failure details", + exc_info=(type(last_exc), last_exc, last_exc.__traceback__), + ) raise last_exc # type: ignore[misc] @@ -291,10 +332,22 @@ def _raw_embed_sync(self, texts: list[str], **kwargs: Any) -> list[list[float]]: sorted_data = sorted(response.data, key=lambda x: x["index"]) return [d["embedding"] for d in sorted_data] - async def _raw_embed_async(self, texts: list[str], **kwargs: Any) -> list[list[float]]: + async def _raw_embed_async( + self, + texts: list[str], + *, + timeout: float | None = None, + **kwargs: Any, + ) -> list[list[float]]: """Raw async embed without retry — called by binary_split_retry_async.""" + validate_timeout(timeout) litellm = self._import_litellm() - response = await litellm.aembedding(**self._embedding_kwargs(texts, **kwargs)) + response = await wait_for_provider_call( + litellm.aembedding(**self._embedding_kwargs(texts, **kwargs)), + timeout=timeout, + timeout_error=EmbeddingTimeoutError, + operation=f"LiteLLM embedding batch with {self.model}", + ) sorted_data = sorted(response.data, key=lambda x: x["index"]) return [d["embedding"] for d in sorted_data] @@ -307,16 +360,62 @@ def embed_documents(self, texts: list[str], **kwargs: Any) -> list[list[float]]: results.extend(binary_split_retry_sync(self._raw_embed_sync, batch, **kwargs)) return results - async def aembed_query(self, text: str, **kwargs: Any) -> list[float]: + async def aembed_query( + self, + text: str, + *, + timeout: float | None = None, + **kwargs: Any, + ) -> list[float]: + validate_timeout(timeout) litellm = self._import_litellm() - response = await litellm.aembedding(**self._embedding_kwargs(text, **kwargs)) + response = await wait_for_provider_call( + litellm.aembedding(**self._embedding_kwargs(text, **kwargs)), + timeout=timeout, + timeout_error=EmbeddingTimeoutError, + operation=f"LiteLLM embedding query with {self.model}", + ) return response.data[0]["embedding"] - async def aembed_documents(self, texts: list[str], **kwargs: Any) -> list[list[float]]: + async def aembed_documents( + self, + texts: list[str], + *, + timeout: float | None = None, + **kwargs: Any, + ) -> list[list[float]]: + validate_timeout(timeout) if not texts: return [] + deadline = time.monotonic() + timeout if timeout is not None else None + + def remaining_timeout() -> float | None: + if deadline is None: + return None + remaining = deadline - time.monotonic() + if remaining <= 0: + raise EmbeddingTimeoutError( + f"LiteLLM embedding documents with {self.model} timed out after {timeout:.3g}s" + ) + return remaining + + async def embed_batch_with_deadline( + batch: list[str], **inner_kwargs: Any + ) -> list[list[float]]: + return await self._raw_embed_async( + batch, + timeout=remaining_timeout(), + **inner_kwargs, + ) + results: list[list[float]] = [] for start in range(0, len(texts), self.batch_size): batch = texts[start : start + self.batch_size] - results.extend(await binary_split_retry_async(self._raw_embed_async, batch, **kwargs)) + results.extend( + await binary_split_retry_async( + embed_batch_with_deadline, + batch, + **kwargs, + ) + ) return results diff --git a/graphrag_sdk/src/graphrag_sdk/core/providers/openrouter.py b/graphrag_sdk/src/graphrag_sdk/core/providers/openrouter.py index bb075524..16d5da44 100644 --- a/graphrag_sdk/src/graphrag_sdk/core/providers/openrouter.py +++ b/graphrag_sdk/src/graphrag_sdk/core/providers/openrouter.py @@ -7,14 +7,17 @@ import asyncio import logging import os +import time from typing import Any +from graphrag_sdk.core.exceptions import EmbeddingTimeoutError, LLMTimeoutError from graphrag_sdk.core.models import ChatMessage, LLMResponse from graphrag_sdk.core.providers._retry import ( binary_split_retry_async, binary_split_retry_sync, summarize_exception, ) +from graphrag_sdk.core.providers._timeout import validate_timeout, wait_for_provider_call from graphrag_sdk.core.providers.base import Embedder, LLMInterface logger = logging.getLogger(__name__) @@ -141,8 +144,12 @@ async def ainvoke( prompt: str, *, max_retries: int = 3, + timeout: float | None = None, **kwargs: Any, ) -> LLMResponse: + if max_retries < 1: + raise ValueError("max_retries must be >= 1") + validate_timeout(timeout) client = self._get_async_client() create_kwargs = self._build_create_kwargs( [{"role": "user", "content": prompt}], @@ -151,9 +158,16 @@ async def ainvoke( last_exc: Exception | None = None for attempt in range(max_retries): try: - response = await client.chat.completions.create(**create_kwargs) + response = await wait_for_provider_call( + client.chat.completions.create(**create_kwargs), + timeout=timeout, + timeout_error=LLMTimeoutError, + operation=f"OpenRouter call to {self.model_name}", + ) content = response.choices[0].message.content or "" return LLMResponse(content=content) + except LLMTimeoutError: + raise except Exception as exc: last_exc = exc if attempt < max_retries - 1: @@ -167,6 +181,16 @@ async def ainvoke( ) logger.debug("OpenRouter call failure details", exc_info=True) await asyncio.sleep(delay) + logger.error( + "OpenRouter call failed after %d attempts: %s", + max_retries, + summarize_exception(last_exc) if last_exc is not None else "UnknownError", + ) + if last_exc is not None: + logger.debug( + "OpenRouter call final failure details", + exc_info=(type(last_exc), last_exc, last_exc.__traceback__), + ) raise last_exc # type: ignore[misc] async def ainvoke_messages( @@ -174,11 +198,13 @@ async def ainvoke_messages( messages: list[ChatMessage], *, max_retries: int = 3, + timeout: float | None = None, **kwargs: Any, ) -> LLMResponse: """Native multi-turn completion via OpenRouter.""" if max_retries < 1: raise ValueError("max_retries must be >= 1") + validate_timeout(timeout) client = self._get_async_client() create_kwargs = self._build_create_kwargs( [m.to_dict() for m in messages], @@ -187,9 +213,16 @@ async def ainvoke_messages( last_exc: Exception | None = None for attempt in range(max_retries): try: - response = await client.chat.completions.create(**create_kwargs) + response = await wait_for_provider_call( + client.chat.completions.create(**create_kwargs), + timeout=timeout, + timeout_error=LLMTimeoutError, + operation=f"OpenRouter messages call to {self.model_name}", + ) content = response.choices[0].message.content or "" return LLMResponse(content=content) + except LLMTimeoutError: + raise except Exception as exc: last_exc = exc if attempt < max_retries - 1: @@ -203,6 +236,16 @@ async def ainvoke_messages( ) logger.debug("OpenRouter call failure details", exc_info=True) await asyncio.sleep(delay) + logger.error( + "OpenRouter messages call failed after %d attempts: %s", + max_retries, + summarize_exception(last_exc) if last_exc is not None else "UnknownError", + ) + if last_exc is not None: + logger.debug( + "OpenRouter messages call final failure details", + exc_info=(type(last_exc), last_exc, last_exc.__traceback__), + ) raise last_exc # type: ignore[misc] @@ -280,10 +323,22 @@ def _raw_embed_sync(self, texts: list[str], **kwargs: Any) -> list[list[float]]: sorted_data = sorted(response.data, key=lambda x: x.index) return [d.embedding for d in sorted_data] - async def _raw_embed_async(self, texts: list[str], **kwargs: Any) -> list[list[float]]: + async def _raw_embed_async( + self, + texts: list[str], + *, + timeout: float | None = None, + **kwargs: Any, + ) -> list[list[float]]: """Raw async embed without retry — called by binary_split_retry_async.""" + validate_timeout(timeout) client = self._get_async_client() - response = await client.embeddings.create(model=self.model, input=texts, **kwargs) + response = await wait_for_provider_call( + client.embeddings.create(model=self.model, input=texts, **kwargs), + timeout=timeout, + timeout_error=EmbeddingTimeoutError, + operation=f"OpenRouter embedding batch with {self.model}", + ) sorted_data = sorted(response.data, key=lambda x: x.index) return [d.embedding for d in sorted_data] @@ -296,16 +351,63 @@ def embed_documents(self, texts: list[str], **kwargs: Any) -> list[list[float]]: results.extend(binary_split_retry_sync(self._raw_embed_sync, batch, **kwargs)) return results - async def aembed_query(self, text: str, **kwargs: Any) -> list[float]: + async def aembed_query( + self, + text: str, + *, + timeout: float | None = None, + **kwargs: Any, + ) -> list[float]: + validate_timeout(timeout) client = self._get_async_client() - response = await client.embeddings.create(model=self.model, input=text, **kwargs) + response = await wait_for_provider_call( + client.embeddings.create(model=self.model, input=text, **kwargs), + timeout=timeout, + timeout_error=EmbeddingTimeoutError, + operation=f"OpenRouter embedding query with {self.model}", + ) return response.data[0].embedding - async def aembed_documents(self, texts: list[str], **kwargs: Any) -> list[list[float]]: + async def aembed_documents( + self, + texts: list[str], + *, + timeout: float | None = None, + **kwargs: Any, + ) -> list[list[float]]: + validate_timeout(timeout) if not texts: return [] + deadline = time.monotonic() + timeout if timeout is not None else None + + def remaining_timeout() -> float | None: + if deadline is None: + return None + remaining = deadline - time.monotonic() + if remaining <= 0: + raise EmbeddingTimeoutError( + f"OpenRouter embedding documents with {self.model} timed out after " + f"{timeout:.3g}s" + ) + return remaining + + async def embed_batch_with_deadline( + batch: list[str], **inner_kwargs: Any + ) -> list[list[float]]: + return await self._raw_embed_async( + batch, + timeout=remaining_timeout(), + **inner_kwargs, + ) + results: list[list[float]] = [] for start in range(0, len(texts), self.batch_size): batch = texts[start : start + self.batch_size] - results.extend(await binary_split_retry_async(self._raw_embed_async, batch, **kwargs)) + results.extend( + await binary_split_retry_async( + embed_batch_with_deadline, + batch, + **kwargs, + ) + ) return results diff --git a/graphrag_sdk/src/graphrag_sdk/ingestion/loaders/pdf_loader.py b/graphrag_sdk/src/graphrag_sdk/ingestion/loaders/pdf_loader.py index 72989058..e1a75c5a 100644 --- a/graphrag_sdk/src/graphrag_sdk/ingestion/loaders/pdf_loader.py +++ b/graphrag_sdk/src/graphrag_sdk/ingestion/loaders/pdf_loader.py @@ -56,8 +56,9 @@ def _load_with_pymupdf(self, path: Path) -> DocumentOutput: """Extract text using PyMuPDF with sort=True for table-aware layout.""" import fitz # PyMuPDF - doc = fitz.open(str(path)) + doc = None try: + doc = fitz.open(str(path)) pages: list[str] = [] for page in doc: text = page.get_text(sort=True) @@ -65,9 +66,12 @@ def _load_with_pymupdf(self, path: Path) -> DocumentOutput: pages.append(text) page_count = len(doc) except Exception as exc: + logger.error("Failed to read PDF %s with PyMuPDF: %s", path, exc) + logger.debug("PyMuPDF read failure details", exc_info=True) raise LoaderError(f"Failed to read PDF {path}: {exc}") from exc finally: - doc.close() + if doc is not None: + doc.close() full_text = "\n\n".join(pages) logger.info( @@ -117,4 +121,6 @@ def _load_with_pypdf(self, path: Path) -> DocumentOutput: ), ) except Exception as exc: + logger.error("Failed to read PDF %s with pypdf: %s", path, exc) + logger.debug("pypdf read failure details", exc_info=True) raise LoaderError(f"Failed to read PDF {path}: {exc}") from exc diff --git a/graphrag_sdk/src/graphrag_sdk/ingestion/loaders/text_loader.py b/graphrag_sdk/src/graphrag_sdk/ingestion/loaders/text_loader.py index 56c7e3d6..842c24c8 100644 --- a/graphrag_sdk/src/graphrag_sdk/ingestion/loaders/text_loader.py +++ b/graphrag_sdk/src/graphrag_sdk/ingestion/loaders/text_loader.py @@ -46,4 +46,6 @@ def _load_sync(self, source: str) -> DocumentOutput: ), ) except Exception as exc: + logger.error("Failed to read text file %s: %s", source, exc) + logger.debug("Text file read failure details", exc_info=True) raise LoaderError(f"Failed to read {source}: {exc}") from exc diff --git a/graphrag_sdk/src/graphrag_sdk/ingestion/pipeline.py b/graphrag_sdk/src/graphrag_sdk/ingestion/pipeline.py index 165cbeeb..1a31ed2b 100644 --- a/graphrag_sdk/src/graphrag_sdk/ingestion/pipeline.py +++ b/graphrag_sdk/src/graphrag_sdk/ingestion/pipeline.py @@ -275,6 +275,8 @@ async def _step_index_chunks() -> None: except IngestionError: raise except Exception as exc: + logger.error("Pipeline failed with unexpected error: %s", exc) + logger.debug("Pipeline failure details", exc_info=True) raise IngestionError(f"Pipeline failed: {exc}") from exc async def _build_lexical_graph( diff --git a/graphrag_sdk/src/graphrag_sdk/retrieval/reranking_strategies/cosine.py b/graphrag_sdk/src/graphrag_sdk/retrieval/reranking_strategies/cosine.py index 33e7cd9a..2e5eb539 100644 --- a/graphrag_sdk/src/graphrag_sdk/retrieval/reranking_strategies/cosine.py +++ b/graphrag_sdk/src/graphrag_sdk/retrieval/reranking_strategies/cosine.py @@ -40,7 +40,11 @@ async def rerank( return result texts = [query] + [item.content for item in result.items] - vectors = await self._embedder.aembed_documents(texts) + ctx.ensure_budget("cosine reranker embedding") + vectors = await self._embedder.aembed_documents( + texts, + timeout=ctx.remaining_budget_seconds, + ) query_vec = np.array(vectors[0]) q_norm = np.linalg.norm(query_vec) diff --git a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/base.py b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/base.py index b3669c3a..1dea837a 100644 --- a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/base.py +++ b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/base.py @@ -12,7 +12,7 @@ from typing import Any from graphrag_sdk.core.context import Context -from graphrag_sdk.core.exceptions import RetrieverError +from graphrag_sdk.core.exceptions import LatencyBudgetExceededError, RetrieverError from graphrag_sdk.core.models import RawSearchResult, RetrieverResult, RetrieverResultItem logger = logging.getLogger(__name__) @@ -93,7 +93,11 @@ async def search( except RetrieverError: raise + except LatencyBudgetExceededError: + raise except Exception as exc: + logger.error("Retrieval [%s] failed: %s", self.__class__.__name__, exc) + logger.debug("Retrieval failure details", exc_info=True) raise RetrieverError(f"Retrieval [{self.__class__.__name__}] failed: {exc}") from exc @abstractmethod diff --git a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/chunk_retrieval.py b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/chunk_retrieval.py index e71e4455..8134aafa 100644 --- a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/chunk_retrieval.py +++ b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/chunk_retrieval.py @@ -6,7 +6,11 @@ import logging from typing import Any +from graphrag_sdk.core.context import Context +from graphrag_sdk.core.exceptions import LatencyBudgetExceededError + logger = logging.getLogger(__name__) +_UNBOUNDED = Context() async def retrieve_chunks( @@ -17,6 +21,7 @@ async def retrieve_chunks( llm_kw: list[str], simple_kw: list[str], entity_list: list[tuple[str, dict]], + ctx: Context = _UNBOUNDED, ) -> tuple[dict[str, str], dict[str, str], dict[str, list[float]]]: """4-path chunk retrieval: fulltext + vector + MENTIONED_IN + 2-hop. @@ -37,17 +42,23 @@ def _add(cid: str, text: str, source: str) -> None: fulltext_queries = [query] + llm_kw[:6] + simple_kw[:4] for ft_q in fulltext_queries: try: + ctx.ensure_budget("chunk fulltext search") results = await vector_store.fulltext_search_chunks(ft_q, top_k=5) for c in results: _add(c.get("id", ""), c.get("text", ""), "fulltext") + except LatencyBudgetExceededError: + raise except Exception as exc: logger.debug("Chunk fulltext search failed for query: %s", exc) # Path B: Vector search try: + ctx.ensure_budget("chunk vector search") results = await vector_store.search_chunks(query_vector, top_k=15) for c in results: _add(c.get("id", ""), c.get("text", ""), "vector") + except LatencyBudgetExceededError: + raise except Exception as exc: logger.debug("Chunk vector search failed: %s", exc) @@ -60,6 +71,7 @@ def _add(cid: str, text: str, source: str) -> None: eids_mention = [eid for eid, _ in entity_list[:15]] if eids_mention: try: + ctx.ensure_budget("MENTIONED_IN chunk retrieval") result = await graph_store.query_raw( "UNWIND $eids AS eid " "MATCH (e:__Entity__ {id: eid})-[:MENTIONED_IN]->(c:Chunk) " @@ -75,6 +87,8 @@ def _add(cid: str, text: str, source: str) -> None: cid = row[1] text = row[2] if len(row) > 2 else "" _add(cid, text, "mentioned_in") + except LatencyBudgetExceededError: + raise except Exception as exc: logger.debug("MENTIONED_IN chunk retrieval failed: %s", exc) @@ -82,6 +96,7 @@ def _add(cid: str, text: str, source: str) -> None: eids_2hop_chunk = [eid for eid, _ in entity_list[:10]] if eids_2hop_chunk: try: + ctx.ensure_budget("2-hop chunk retrieval") result = await graph_store.query_raw( "UNWIND $eids AS eid " "MATCH (e:__Entity__ {id: eid})-[:RELATES]-(neighbor:__Entity__)" @@ -94,6 +109,8 @@ def _add(cid: str, text: str, source: str) -> None: cid = row[0] text = row[1] if len(row) > 1 else "" _add(cid, text, "2hop_mentioned") + except LatencyBudgetExceededError: + raise except Exception as exc: logger.debug("2-hop chunk retrieval failed: %s", exc) @@ -102,6 +119,7 @@ def _add(cid: str, text: str, source: str) -> None: missing_ids = list(chunks.keys()) if missing_ids: try: + ctx.ensure_budget("stored chunk embedding fetch") result = await graph_store.query_raw( "UNWIND $ids AS cid " "MATCH (c:Chunk {id: cid}) " @@ -112,6 +130,8 @@ def _add(cid: str, text: str, source: str) -> None: for row in result.result_set: if row[0] and row[1] is not None: embeddings[row[0]] = list(row[1]) + except LatencyBudgetExceededError: + raise except Exception as exc: logger.debug("Stored embedding fetch failed: %s", exc) @@ -121,6 +141,7 @@ def _add(cid: str, text: str, source: str) -> None: async def fetch_chunk_documents( graph_store: Any, chunk_ids: list[str], + ctx: Context = _UNBOUNDED, ) -> dict[str, str]: """Batch-fetch the source document path for each chunk via PART_OF. @@ -137,6 +158,7 @@ async def fetch_chunk_documents( if not chunk_ids: return {} try: + ctx.ensure_budget("chunk source document fetch") result = await graph_store.query_raw( "UNWIND $ids AS cid " "MATCH (d:Document)-[:PART_OF]->(c:Chunk {id: cid}) " @@ -150,6 +172,8 @@ async def fetch_chunk_documents( if cid and path: mapping[cid] = path return mapping + except LatencyBudgetExceededError: + raise except Exception as exc: logger.debug("Document name fetch failed: %s", exc) return {} diff --git a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/cypher_generation.py b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/cypher_generation.py index 0dd3d289..fbb04ab7 100644 --- a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/cypher_generation.py +++ b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/cypher_generation.py @@ -12,7 +12,11 @@ import re from typing import Any +from graphrag_sdk.core.context import Context +from graphrag_sdk.core.exceptions import LatencyBudgetExceededError + logger = logging.getLogger(__name__) +_UNBOUNDED = Context() # ── Valid labels for our graph schema ──────────────────────────── @@ -245,6 +249,7 @@ async def generate_cypher( question: str, *, max_retries: int = 3, + ctx: Context = _UNBOUNDED, ) -> str | None: """Generate a Cypher query from a natural language question. @@ -255,15 +260,19 @@ async def generate_cypher( for attempt in range(max_retries): try: + ctx.ensure_budget("Cypher generation LLM call") if attempt > 0 and last_error: prompt_with_feedback = ( prompt + f"\n\nPrevious attempt failed with error: {last_error}\n" "Remember: no shortestPath, every RETURN column must have a " "unique alias, add LIMIT, keep it simple." ) - response = await llm.ainvoke(prompt_with_feedback) + response = await llm.ainvoke( + prompt_with_feedback, + timeout=ctx.remaining_budget_seconds, + ) else: - response = await llm.ainvoke(prompt) + response = await llm.ainvoke(prompt, timeout=ctx.remaining_budget_seconds) cypher = extract_cypher(response.content) if not cypher: @@ -279,6 +288,8 @@ async def generate_cypher( cypher = _sanitize_cypher(cypher) return cypher + except LatencyBudgetExceededError: + raise except Exception as exc: last_error = str(exc) logger.debug("Cypher generation attempt %d failed: %s", attempt + 1, exc) @@ -293,6 +304,7 @@ async def execute_cypher_retrieval( question: str, *, max_retries: int = 3, + ctx: Context = _UNBOUNDED, ) -> tuple[list[str], dict[str, dict]]: """Full text-to-cypher retrieval: generate -> validate -> execute -> parse. @@ -306,12 +318,15 @@ async def execute_cypher_retrieval( On any failure, returns empty results (silent degradation). """ - cypher = await generate_cypher(llm, question, max_retries=max_retries) + cypher = await generate_cypher(llm, question, max_retries=max_retries, ctx=ctx) if not cypher: return [], {} try: + ctx.ensure_budget("Cypher execution") result = await graph_store.query_raw(cypher) + except LatencyBudgetExceededError: + raise except Exception as exc: logger.debug("Cypher execution failed: %s — query: %s", exc, cypher) return [], {} diff --git a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/entity_discovery.py b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/entity_discovery.py index f1d54786..f55da3ef 100644 --- a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/entity_discovery.py +++ b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/entity_discovery.py @@ -8,7 +8,11 @@ import re from typing import Any +from graphrag_sdk.core.context import Context +from graphrag_sdk.core.exceptions import LatencyBudgetExceededError + logger = logging.getLogger(__name__) +_UNBOUNDED = Context() _ENUMERATION_RE = re.compile( r"\b(every|each|complete list|full list|list all|list of all" @@ -26,6 +30,7 @@ async def search_relates_edges( vector_store: Any, query_vector: list[float], rel_top_k: int = 15, + ctx: Context = _UNBOUNDED, ) -> tuple[list[tuple[str, float]], dict[str, dict]]: """Search RELATES edges by vector similarity. @@ -37,6 +42,7 @@ async def search_relates_edges( fact_strings: list[tuple[str, float]] = [] entities: dict[str, dict] = {} try: + ctx.ensure_budget("RELATES vector search") results = await vector_store.search_relationships(query_vector, top_k=rel_top_k) for rel in results: src = rel.get("src_name", "") @@ -58,6 +64,8 @@ async def search_relates_edges( tgt_id = tgt.strip().lower().replace(" ", "_") if tgt_id not in entities: entities[tgt_id] = {"name": tgt, "description": ""} + except LatencyBudgetExceededError: + raise except Exception as exc: logger.debug("RELATES edge vector search failed: %s", exc) return fact_strings, entities @@ -68,6 +76,7 @@ async def discover_entities( vector_store: Any, llm_kw: list[str], all_keywords: list[str], + ctx: Context = _UNBOUNDED, ) -> tuple[dict[str, dict], dict[str, str]]: """2-path entity discovery. @@ -107,6 +116,7 @@ def _add(eid: str, info: dict, source: str) -> None: # so exact matches land at the head of `found` and survive the # downstream max_entities / result_assembly caps. try: + ctx.ensure_budget("entity exact-name search") result = await graph_store.query_raw( "UNWIND $keywords AS kw " "CALL { " @@ -127,6 +137,8 @@ def _add(eid: str, info: dict, source: str) -> None: }, "cypher_exact", ) + except LatencyBudgetExceededError: + raise except Exception as exc: logger.debug("Entity exact-name search failed: %s", exc) @@ -134,6 +146,7 @@ def _add(eid: str, info: dict, source: str) -> None: # Excludes exact matches (already added in pass a1) so the quota # isn't spent re-fetching them. try: + ctx.ensure_budget("entity contains search") result = await graph_store.query_raw( "UNWIND $keywords AS kw " "CALL { " @@ -157,17 +170,21 @@ def _add(eid: str, info: dict, source: str) -> None: }, "cypher_contains", ) + except LatencyBudgetExceededError: + raise except Exception as exc: logger.debug("Entity CONTAINS search failed: %s", exc) # Path b: Fulltext search on entity index for kw in all_keywords[:6]: try: + ctx.ensure_budget("entity fulltext search") ft_ents = await vector_store.fulltext_search_entities(kw, top_k=3) for ent in ft_ents: eid = ent.get("id", "") if eid: try: + ctx.ensure_budget("entity detail fetch") detail = await graph_store.query_raw( "MATCH (e:__Entity__ {id: $eid}) " "RETURN e.name AS name, e.description AS desc", @@ -183,9 +200,13 @@ def _add(eid: str, info: dict, source: str) -> None: }, "fulltext", ) + except LatencyBudgetExceededError: + raise except Exception: logger.debug("Entity detail fetch failed for %s", eid, exc_info=True) _add(eid, {"name": "", "description": ""}, "fulltext") + except LatencyBudgetExceededError: + raise except Exception as exc: logger.debug("Entity fulltext search failed for '%s': %s", kw, exc) @@ -197,6 +218,7 @@ async def expand_sibling_entities( found_entities: dict[str, dict], found_sources: dict[str, str], max_siblings: int = 20, + ctx: Context = _UNBOUNDED, ) -> int: """Expand discovered entities by finding graph siblings. @@ -216,6 +238,7 @@ async def expand_sibling_entities( added = 0 try: + ctx.ensure_budget("sibling entity expansion") result = await graph_store.query_raw( "MATCH (e:__Entity__) WHERE e.id IN $found_ids " "MATCH (e)-[]-(hub:__Entity__) " @@ -238,6 +261,8 @@ async def expand_sibling_entities( } found_sources[eid] = "sibling_expansion" added += 1 + except LatencyBudgetExceededError: + raise except Exception as exc: logger.debug("Sibling entity expansion failed: %s", exc) diff --git a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/local.py b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/local.py index 08944260..16ac116d 100644 --- a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/local.py +++ b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/local.py @@ -52,9 +52,14 @@ async def _execute( top_k = kwargs.get("top_k", self._top_k) # Step 1: Embed the query - query_vector = await self._embedder.aembed_query(query) + ctx.ensure_budget("LocalRetrieval query embedding") + query_vector = await self._embedder.aembed_query( + query, + timeout=ctx.remaining_budget_seconds, + ) # Step 2: Vector search for matching chunks + ctx.ensure_budget("LocalRetrieval chunk vector search") chunk_results = await self._vector.search_chunks( query_vector=query_vector, top_k=top_k, @@ -74,6 +79,7 @@ async def _execute( if self._include_entities and self._graph: # 1-hop: find entities extracted from this chunk + ctx.ensure_budget("LocalRetrieval entity expansion") entities = await self._graph.get_connected_entities( chunk_id=chunk.get("id", ""), ) diff --git a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/multi_path.py b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/multi_path.py index d7bcbd9b..adde66ca 100644 --- a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/multi_path.py +++ b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/multi_path.py @@ -11,6 +11,7 @@ from typing import Any from graphrag_sdk.core.context import Context +from graphrag_sdk.core.exceptions import LatencyBudgetExceededError from graphrag_sdk.core.models import ( RawSearchResult, RetrieverResult, @@ -45,6 +46,14 @@ logger = logging.getLogger(__name__) +def _unpack_gather_result(result: Any, default: Any) -> Any: + if isinstance(result, LatencyBudgetExceededError): + raise result + if isinstance(result, BaseException): + return default + return result + + class MultiPathRetrieval(RetrievalStrategy): """Multi-path retrieval combining RELATES edge vector search, fulltext, graph traversal, and cosine reranking. @@ -186,33 +195,30 @@ async def _execute( **kwargs: Any, ) -> RawSearchResult: # 1. Extract keywords - simple_kw, llm_kw = await self._extract_keywords(query) + ctx.ensure_budget("MultiPath keyword extraction") + simple_kw, llm_kw = await self._extract_keywords(query, ctx) all_keywords = llm_kw[:8] + simple_kw ctx.log(f"MultiPath [1/9]: {len(all_keywords)} keywords extracted") # 2. Embed question only - query_vector = await self._embedder.aembed_query(query) + ctx.ensure_budget("MultiPath question embedding") + query_vector = await self._embedder.aembed_query( + query, + timeout=ctx.remaining_budget_seconds, + ) # 3. RELATES vector search + Text-to-Cypher (parallel when enabled) if self._enable_cypher: results = await asyncio.gather( - search_relates_edges(self._vector, query_vector, self._rel_top_k), - execute_cypher_retrieval(self._graph, self._llm, query), + search_relates_edges(self._vector, query_vector, self._rel_top_k, ctx=ctx), + execute_cypher_retrieval(self._graph, self._llm, query, ctx=ctx), return_exceptions=True, ) - # Unpack RELATES results - if isinstance(results[0], BaseException): - fact_strings_scored, rel_entities = [], {} - else: - fact_strings_scored, rel_entities = results[0] - # Unpack Cypher results - cypher_facts: list[str] = [] - cypher_entities: dict[str, dict] = {} - if not isinstance(results[1], BaseException): - cypher_facts, cypher_entities = results[1] + fact_strings_scored, rel_entities = _unpack_gather_result(results[0], ([], {})) + cypher_facts, cypher_entities = _unpack_gather_result(results[1], ([], {})) else: fact_strings_scored, rel_entities = await search_relates_edges( - self._vector, query_vector, self._rel_top_k + self._vector, query_vector, self._rel_top_k, ctx=ctx ) cypher_facts, cypher_entities = [], {} @@ -225,10 +231,10 @@ async def _execute( f"{len(fact_strings)} after filtering" + (f", {len(cypher_facts)} cypher results" if cypher_facts else "") ) - + # 4. Entity discovery (2 paths) + merge rel_entities + cypher_entities # 4. Entity discovery (2 paths) + merge rel_entities + cypher_entities found_entities, entity_sources = await discover_entities( - self._graph, self._vector, llm_kw, all_keywords + self._graph, self._vector, llm_kw, all_keywords, ctx=ctx ) for eid, einfo in rel_entities.items(): if eid not in found_entities: @@ -242,20 +248,25 @@ async def _execute( # 4b. Sibling expansion for enumeration queries if is_enumeration_query(query): - n_siblings = await expand_sibling_entities(self._graph, found_entities, entity_sources) + n_siblings = await expand_sibling_entities( + self._graph, + found_entities, + entity_sources, + ctx=ctx, + ) if n_siblings: ctx.log( f"MultiPath [4b/9]: +{n_siblings} sibling entities " f"({len(found_entities)} total)" ) - + # 5. Relationship expansion # 5. Relationship expansion entity_list = list(found_entities.items())[: self._max_entities] relationship_strings = await expand_relationships( - self._graph, entity_list, self._max_relationships + self._graph, entity_list, self._max_relationships, ctx=ctx ) ctx.log(f"MultiPath [5/9]: {len(relationship_strings)} relationships") - + # 6. Chunk retrieval (4 paths) # 6. Chunk retrieval (4 paths) candidate_chunks, chunk_sources, chunk_embeddings = await retrieve_chunks( self._vector, @@ -265,15 +276,20 @@ async def _execute( llm_kw, simple_kw, entity_list, + ctx=ctx, ) ctx.log( f"MultiPath [6/9]: {len(candidate_chunks)} candidate chunks " f"({len(chunk_embeddings)} with stored embeddings)" ) - # 7. Source document names - chunk_doc_map = await fetch_chunk_documents(self._graph, list(candidate_chunks.keys())) - + # 7. Source document names + chunk_doc_map = await fetch_chunk_documents( + self._graph, + list(candidate_chunks.keys()), + ctx=ctx, + ) + # 8. Cosine rerank (uses stored embeddings when available) # 8. Cosine rerank (uses stored embeddings when available) source_passages = await rerank_chunks( self._embedder, @@ -281,6 +297,7 @@ async def _execute( candidate_chunks, self._chunk_top_k, stored_embeddings=chunk_embeddings, + ctx=ctx, ) # Tag with source docs @@ -295,6 +312,7 @@ async def _execute( ctx.log(f"MultiPath [8/9]: {len(source_passages)} passages after rerank") # 9. Detect question type + assemble + ctx.ensure_budget("MultiPath result assembly") q_type_hint = detect_question_type(query) return assemble_raw_result( entity_list, @@ -322,24 +340,32 @@ def _format(self, raw: RawSearchResult) -> RetrieverResult: # -- Internal: keyword extraction (stays in orchestrator) -- - async def _extract_keywords(self, query: str) -> tuple[list[str], list[str]]: + async def _extract_keywords( + self, + query: str, + ctx: Context, + ) -> tuple[list[str], list[str]]: """Extract simple + LLM-based keywords from the query.""" words = re.sub(r"[?.!,;:'\"\-()\[\]]", " ", query.lower()).split() simple = [w for w in words if w not in self._STOP_WORDS and len(w) > 2][:12] llm_kw: list[str] = [] try: + ctx.ensure_budget("MultiPath keyword extraction LLM call") response = await self._llm.ainvoke( "Extract ALL proper nouns, character names, person names, place names, " "book titles, and specific terms from this question. " "Return them comma-separated, nothing else.\n\n" - f"Question: {query}\n\nNames: " + f"Question: {query}\n\nNames: ", + timeout=ctx.remaining_budget_seconds, ) llm_kw = [ k.strip().strip("'\"").rstrip("()").strip() for k in response.content.split(",") if k.strip() and len(k.strip()) > 1 ] + except LatencyBudgetExceededError: + raise except Exception as exc: logger.debug(f"LLM keyword extraction failed: {exc}") @@ -351,7 +377,12 @@ async def _search_relates_edges( self, query_vector: list[float] ) -> tuple[list[tuple[str, float]], dict[str, dict]]: """Backward-compat wrapper — delegates to module function.""" - return await search_relates_edges(self._vector, query_vector, self._rel_top_k) + return await search_relates_edges( + self._vector, + query_vector, + self._rel_top_k, + ctx=Context(), + ) @staticmethod def _detect_question_type(query: str) -> str: diff --git a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/relationship_expansion.py b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/relationship_expansion.py index 57264be8..802fa5c5 100644 --- a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/relationship_expansion.py +++ b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/relationship_expansion.py @@ -6,13 +6,18 @@ import logging from typing import Any +from graphrag_sdk.core.context import Context +from graphrag_sdk.core.exceptions import LatencyBudgetExceededError + logger = logging.getLogger(__name__) +_UNBOUNDED = Context() async def expand_relationships( graph_store: Any, entity_list: list[tuple[str, dict]], max_relationships: int = 20, + ctx: Context = _UNBOUNDED, ) -> list[str]: """1-hop + 2-hop relationship expansion from top entities. @@ -30,6 +35,7 @@ async def expand_relationships( eids_1hop = [eid for eid, _ in entity_list[:15]] if eids_1hop: try: + ctx.ensure_budget("1-hop relationship expansion") result = await graph_store.query_raw( "UNWIND $eids AS eid " "MATCH (a:__Entity__ {id: eid})-[r:RELATES]->(b:__Entity__) " @@ -50,6 +56,8 @@ async def expand_relationships( if fact: line += f": {fact}" relationship_strings.append(line) + except LatencyBudgetExceededError: + raise except Exception as exc: logger.debug("Relationship expansion failed: %s", exc) @@ -57,6 +65,7 @@ async def expand_relationships( eids_2hop = [eid for eid, _ in entity_list[:5]] if eids_2hop: try: + ctx.ensure_budget("2-hop relationship expansion") result = await graph_store.query_raw( "UNWIND $eids AS eid " "MATCH (a:__Entity__ {id: eid})-[r1:RELATES]->(b:__Entity__)" @@ -77,6 +86,8 @@ async def expand_relationships( seen.add(key) line = f"{a_name} —[{r1_type}]→ {b_name} —[{r2_type}]→ {c_name}" relationship_strings.append(line) + except LatencyBudgetExceededError: + raise except Exception as exc: logger.debug("Relationship expansion failed: %s", exc) diff --git a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/result_assembly.py b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/result_assembly.py index 6b2ad9a4..50230e5c 100644 --- a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/result_assembly.py +++ b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/result_assembly.py @@ -8,10 +8,13 @@ import numpy as np +from graphrag_sdk.core.context import Context +from graphrag_sdk.core.exceptions import LatencyBudgetExceededError from graphrag_sdk.core.models import RawSearchResult from graphrag_sdk.core.providers import Embedder logger = logging.getLogger(__name__) +_UNBOUNDED = Context() def cosine_sim(a: list[float], b: list[float]) -> float: @@ -29,6 +32,7 @@ async def rerank_chunks( candidate_chunks: dict[str, str], chunk_top_k: int = 15, stored_embeddings: dict[str, list[float]] | None = None, + ctx: Context = _UNBOUNDED, ) -> list[str]: """Rank candidates by cosine similarity, take top_k. @@ -60,10 +64,16 @@ async def rerank_chunks( # Fallback: re-embed all candidates (coverage too low for fast path) try: - chunk_vectors = await embedder.aembed_documents(chunk_texts) + ctx.ensure_budget("chunk reranking embedding") + chunk_vectors = await embedder.aembed_documents( + chunk_texts, + timeout=ctx.remaining_budget_seconds, + ) scored = [(i, cosine_sim(query_vector, cvec)) for i, cvec in enumerate(chunk_vectors)] scored.sort(key=lambda x: x[1], reverse=True) return [chunk_texts[i] for i, _ in scored[:chunk_top_k]] + except LatencyBudgetExceededError: + raise except Exception as exc: logger.debug("Chunk reranking failed, returning unranked: %s", exc) return chunk_texts[:chunk_top_k] diff --git a/graphrag_sdk/src/graphrag_sdk/storage/graph_store.py b/graphrag_sdk/src/graphrag_sdk/storage/graph_store.py index 3d04a50a..03651c8e 100644 --- a/graphrag_sdk/src/graphrag_sdk/storage/graph_store.py +++ b/graphrag_sdk/src/graphrag_sdk/storage/graph_store.py @@ -906,13 +906,13 @@ def _clean_properties(cls, props: dict[str, Any]) -> dict[str, Any]: for key, value in props.items(): if value is None: continue - if isinstance(value, (str, int, float, bool)): + if isinstance(value, str | int | float | bool): cleaned[key] = cls._sanitize_string(value) if isinstance(value, str) else value elif isinstance(value, list): # FalkorDB supports lists of primitives — filter items filtered: list[str | int | float | bool] = [] for item in value: - if not isinstance(item, (str, int, float, bool)): + if not isinstance(item, str | int | float | bool): continue if isinstance(item, str): filtered.append(cls._sanitize_string(item)) diff --git a/graphrag_sdk/tests/test_chunk_retrieval.py b/graphrag_sdk/tests/test_chunk_retrieval.py index 2b59fb49..a731cf30 100644 --- a/graphrag_sdk/tests/test_chunk_retrieval.py +++ b/graphrag_sdk/tests/test_chunk_retrieval.py @@ -5,7 +5,12 @@ import pytest -from graphrag_sdk.retrieval.strategies.chunk_retrieval import fetch_chunk_documents +from graphrag_sdk.core.context import Context +from graphrag_sdk.core.exceptions import LatencyBudgetExceededError +from graphrag_sdk.retrieval.strategies.chunk_retrieval import ( + fetch_chunk_documents, + retrieve_chunks, +) def _graph_with_rows(rows: list[list]) -> MagicMock: @@ -90,3 +95,59 @@ async def test_query_failure_returns_empty_mapping(self): graph.query_raw = AsyncMock(side_effect=RuntimeError("graph down")) mapping = await fetch_chunk_documents(graph, ["chunk-1"]) assert mapping == {} + + async def test_budget_error_propagates(self): + graph = MagicMock() + graph.query_raw = AsyncMock( + side_effect=LatencyBudgetExceededError("budget exhausted") + ) + + with pytest.raises(LatencyBudgetExceededError, match="budget exhausted"): + await fetch_chunk_documents(graph, ["chunk-1"]) + + +class TestRetrieveChunks: + async def test_budget_error_propagates_from_fulltext_path(self): + vector = MagicMock() + vector.fulltext_search_chunks = AsyncMock( + side_effect=LatencyBudgetExceededError("budget exhausted") + ) + graph = MagicMock() + + with pytest.raises(LatencyBudgetExceededError, match="budget exhausted"): + await retrieve_chunks( + vector, + graph, + "query", + [0.1], + [], + [], + [], + ) + + async def test_budget_checked_between_fulltext_queries(self): + ctx = Context(latency_budget_ms=1000.0) + + async def first_fulltext_exhausts_budget(*args, **kwargs): + ctx.latency_budget_ms = 0.0 + return [] + + vector = MagicMock() + vector.fulltext_search_chunks = AsyncMock(side_effect=first_fulltext_exhausts_budget) + vector.search_chunks = AsyncMock(return_value=[]) + graph = MagicMock() + + with pytest.raises(LatencyBudgetExceededError, match="chunk fulltext search"): + await retrieve_chunks( + vector, + graph, + "query", + [0.1], + ["second-query"], + [], + [], + ctx=ctx, + ) + + assert vector.fulltext_search_chunks.await_count == 1 + vector.search_chunks.assert_not_awaited() diff --git a/graphrag_sdk/tests/test_connection.py b/graphrag_sdk/tests/test_connection.py index 922cff59..b981e660 100644 --- a/graphrag_sdk/tests/test_connection.py +++ b/graphrag_sdk/tests/test_connection.py @@ -105,7 +105,7 @@ async def test_query_retries(self): assert result == "success" assert mock_graph.query.call_count == 3 - async def test_query_exhausts_retries(self): + async def test_query_exhausts_retries(self, caplog): conn = FalkorDBConnection(ConnectionConfig(retry_count=2, retry_delay=0.0)) mock_graph = MagicMock() @@ -113,10 +113,26 @@ async def test_query_exhausts_retries(self): conn._graph = mock_graph conn._driver = MagicMock() - with pytest.raises(Exception, match="always fails"): - await conn.query("MATCH (n) RETURN n") + with caplog.at_level("ERROR", logger="graphrag_sdk.core.connection"): + with pytest.raises(Exception, match="always fails"): + await conn.query("MATCH (n) RETURN n") + assert "FalkorDB query failed after 2 attempts" in caplog.text assert mock_graph.query.call_count == 2 + async def test_query_non_transient_failure_is_typed_and_logged(self, caplog): + conn = FalkorDBConnection(ConnectionConfig(retry_count=3, retry_delay=0.0)) + + mock_graph = MagicMock() + mock_graph.query = AsyncMock(side_effect=Exception("already indexed")) + conn._graph = mock_graph + conn._driver = MagicMock() + + with caplog.at_level("ERROR", logger="graphrag_sdk.core.connection"): + with pytest.raises(Exception, match="already indexed"): + await conn.query("CREATE INDEX idx") + assert "Non-transient FalkorDB query failure" in caplog.text + assert mock_graph.query.call_count == 1 + async def test_query_with_params(self): conn = FalkorDBConnection(ConnectionConfig(retry_count=1)) diff --git a/graphrag_sdk/tests/test_context.py b/graphrag_sdk/tests/test_context.py index 921edd9b..70edbecf 100644 --- a/graphrag_sdk/tests/test_context.py +++ b/graphrag_sdk/tests/test_context.py @@ -6,6 +6,7 @@ import pytest from graphrag_sdk.core.context import Context +from graphrag_sdk.core.exceptions import LatencyBudgetExceededError class TestContextCreation: @@ -66,6 +67,15 @@ def test_budget_exceeded(self): assert ctx.budget_exceeded is True assert ctx.remaining_budget_ms == 0.0 + def test_ensure_budget_raises_when_exhausted(self): + ctx = Context(latency_budget_ms=0.0) + with pytest.raises(LatencyBudgetExceededError, match="before test operation"): + ctx.ensure_budget("test operation") + + def test_ensure_budget_allows_operation_without_budget(self): + ctx = Context() + ctx.ensure_budget("test operation") + class TestChildContext: def test_inherits_parent(self): diff --git a/graphrag_sdk/tests/test_cypher_generation.py b/graphrag_sdk/tests/test_cypher_generation.py index 58694158..9c145e80 100644 --- a/graphrag_sdk/tests/test_cypher_generation.py +++ b/graphrag_sdk/tests/test_cypher_generation.py @@ -136,6 +136,55 @@ def test_removes_path_assignment(self): class TestExecuteCypherRetrieval: + async def test_budget_exhaustion_propagates_before_generation(self): + """Latency budget exhaustion should not be swallowed as generation failure.""" + from unittest.mock import AsyncMock, MagicMock + + from graphrag_sdk.core.context import Context + from graphrag_sdk.core.exceptions import LatencyBudgetExceededError + from graphrag_sdk.retrieval.strategies.cypher_generation import ( + execute_cypher_retrieval, + ) + + mock_llm = MagicMock() + mock_llm.ainvoke = AsyncMock() + mock_graph = MagicMock() + + with pytest.raises(LatencyBudgetExceededError, match="Cypher generation LLM call"): + await execute_cypher_retrieval( + mock_graph, + mock_llm, + "test?", + ctx=Context(latency_budget_ms=0.0), + ) + mock_llm.ainvoke.assert_not_awaited() + + async def test_budget_exhaustion_propagates_before_execution(self): + """Budget expiry during generation should block the Cypher DB query.""" + from unittest.mock import AsyncMock, MagicMock + + from graphrag_sdk.core.context import Context + from graphrag_sdk.core.exceptions import LatencyBudgetExceededError + from graphrag_sdk.core.models import LLMResponse + from graphrag_sdk.retrieval.strategies.cypher_generation import ( + execute_cypher_retrieval, + ) + + ctx = Context(latency_budget_ms=1000.0) + + async def generate_and_exhaust(*args, **kwargs): + ctx.latency_budget_ms = 0.0 + return LLMResponse(content="```cypher\nMATCH (n:Person) RETURN n.name LIMIT 10\n```") + + mock_llm = MagicMock() + mock_llm.ainvoke = AsyncMock(side_effect=generate_and_exhaust) + mock_graph = MagicMock() + mock_graph.query_raw = AsyncMock() + + with pytest.raises(LatencyBudgetExceededError, match="Cypher execution"): + await execute_cypher_retrieval(mock_graph, mock_llm, "test?", ctx=ctx) + mock_graph.query_raw.assert_not_awaited() + async def test_returns_empty_on_generation_failure(self): """When LLM returns garbage, should return empty results.""" from unittest.mock import AsyncMock, MagicMock diff --git a/graphrag_sdk/tests/test_exceptions.py b/graphrag_sdk/tests/test_exceptions.py index 12f84eea..64e98683 100644 --- a/graphrag_sdk/tests/test_exceptions.py +++ b/graphrag_sdk/tests/test_exceptions.py @@ -8,11 +8,14 @@ ConfigError, DatabaseError, EmbeddingError, + EmbeddingTimeoutError, ExtractionError, GraphRAGError, IndexError_, IngestionError, + LatencyBudgetExceededError, LLMError, + LLMTimeoutError, LoaderError, ResolutionError, RetrieverError, @@ -25,14 +28,26 @@ def test_base_exception(self): with pytest.raises(GraphRAGError): raise GraphRAGError("base error") + def test_latency_budget_error_is_graphrag_error(self): + with pytest.raises(GraphRAGError): + raise LatencyBudgetExceededError("budget exceeded") + def test_llm_error_is_graphrag_error(self): with pytest.raises(GraphRAGError): raise LLMError("llm failed") + def test_llm_timeout_error_is_llm_error(self): + with pytest.raises(LLMError): + raise LLMTimeoutError("llm timed out") + def test_embedding_error_is_graphrag_error(self): with pytest.raises(GraphRAGError): raise EmbeddingError("embedding failed") + def test_embedding_timeout_error_is_embedding_error(self): + with pytest.raises(EmbeddingError): + raise EmbeddingTimeoutError("embedding timed out") + def test_ingestion_error_hierarchy(self): assert issubclass(IngestionError, GraphRAGError) assert issubclass(LoaderError, IngestionError) diff --git a/graphrag_sdk/tests/test_facade.py b/graphrag_sdk/tests/test_facade.py index 78072b4f..ff6e833a 100644 --- a/graphrag_sdk/tests/test_facade.py +++ b/graphrag_sdk/tests/test_facade.py @@ -9,7 +9,11 @@ from graphrag_sdk.api.main import GraphRAG from graphrag_sdk.core.connection import ConnectionConfig, FalkorDBConnection from graphrag_sdk.core.context import Context -from graphrag_sdk.core.exceptions import ConfigError, DocumentNotFoundError +from graphrag_sdk.core.exceptions import ( + ConfigError, + DocumentNotFoundError, + LatencyBudgetExceededError, +) from graphrag_sdk.core.models import ( ApplyChangesResult, ChatMessage, @@ -102,6 +106,41 @@ async def _execute(self, query, ctx, **kwargs): g = GraphRAG(connection=mock_conn, llm=llm, embedder=embedder, retrieval_strategy=strategy, embedding_dimension=8) assert g._retrieval_strategy is strategy + async def test_async_context_manager_returns_self_and_closes(self, mock_conn, embedder, llm): + g = GraphRAG(connection=mock_conn, llm=llm, embedder=embedder, embedding_dimension=8) + + async with g as rag: + assert rag is g + + mock_conn.close.assert_awaited_once() + + async def test_async_context_manager_preserves_inner_exception_on_close_failure( + self, mock_conn, embedder, llm, caplog + ): + mock_conn.close = AsyncMock(side_effect=RuntimeError("close failed")) + g = GraphRAG(connection=mock_conn, llm=llm, embedder=embedder, embedding_dimension=8) + + try: + with caplog.at_level("WARNING", logger="graphrag_sdk.api.main"): + with pytest.raises(ValueError, match="inner failure"): + async with g: + raise ValueError("inner failure") + finally: + mock_conn.close.assert_awaited_once() + assert "Error closing connection during __aexit__" in caplog.text + + async def test_async_context_manager_raises_close_failure_without_inner_exception( + self, mock_conn, embedder, llm + ): + mock_conn.close = AsyncMock(side_effect=RuntimeError("close failed")) + g = GraphRAG(connection=mock_conn, llm=llm, embedder=embedder, embedding_dimension=8) + + with pytest.raises(RuntimeError, match="close failed"): + async with g: + pass + + mock_conn.close.assert_awaited_once() + class TestGraphRAGGraphAdmin: """A1: graph admin operations exposed as facade methods.""" @@ -410,6 +449,89 @@ async def rerank(self, query, result, ctx): assert result.items[0].content == "B" assert llm._call_index == 0 + async def test_retrieve_checks_budget_before_config_embedder_probe( + self, mock_conn, embedder + ): + llm = MockLLM(responses=["should not be called"]) + g = GraphRAG(connection=mock_conn, llm=llm, embedder=embedder, embedding_dimension=8) + ctx = Context(latency_budget_ms=1000.0) + + async def config_query_exhausts_budget(*args, **kwargs): + ctx.latency_budget_ms = 0.0 + result = MagicMock() + result.result_set = [] + return result + + g._graph_store.query_raw = AsyncMock(side_effect=config_query_exhausts_budget) + mock_strategy = MagicMock(spec=RetrievalStrategy) + mock_strategy.search = AsyncMock(return_value=RetrieverResult()) + g._retrieval_strategy = mock_strategy + + with pytest.raises(LatencyBudgetExceededError, match="graph config embedder probe"): + await g.retrieve("test", ctx=ctx) + + assert embedder.call_count == 0 + mock_strategy.search.assert_not_awaited() + + async def test_retrieve_checks_budget_before_config_query( + self, mock_conn, embedder, monkeypatch + ): + llm = MockLLM(responses=["should not be called"]) + g = GraphRAG(connection=mock_conn, llm=llm, embedder=embedder, embedding_dimension=8) + g._graph_store.query_raw = AsyncMock() + mock_strategy = MagicMock(spec=RetrievalStrategy) + mock_strategy.search = AsyncMock(return_value=RetrieverResult()) + g._retrieval_strategy = mock_strategy + ctx = Context(latency_budget_ms=1000.0) + + def exhaust_budget(self: Context, operation: str) -> None: + if operation == "graph config query": + raise LatencyBudgetExceededError("budget exhausted before config query") + + monkeypatch.setattr(Context, "ensure_budget", exhaust_budget) + with pytest.raises(LatencyBudgetExceededError, match="config query"): + await g.retrieve("test", ctx=ctx) + + g._graph_store.query_raw.assert_not_awaited() + mock_strategy.search.assert_not_awaited() + + async def test_retrieve_propagates_budget_error_from_config_query( + self, mock_conn, embedder + ): + llm = MockLLM(responses=["should not be called"]) + g = GraphRAG(connection=mock_conn, llm=llm, embedder=embedder, embedding_dimension=8) + g._graph_store.query_raw = AsyncMock( + side_effect=LatencyBudgetExceededError("budget exhausted in config query") + ) + mock_strategy = MagicMock(spec=RetrievalStrategy) + mock_strategy.search = AsyncMock(return_value=RetrieverResult()) + g._retrieval_strategy = mock_strategy + + with pytest.raises(LatencyBudgetExceededError, match="config query"): + await g.retrieve("test", ctx=Context(latency_budget_ms=1000.0)) + + mock_strategy.search.assert_not_awaited() + + async def test_retrieve_propagates_budget_error_from_config_probe( + self, mock_conn, embedder + ): + llm = MockLLM(responses=["should not be called"]) + g = GraphRAG(connection=mock_conn, llm=llm, embedder=embedder, embedding_dimension=8) + result = MagicMock() + result.result_set = [] + g._graph_store.query_raw = AsyncMock(return_value=result) + embedder.aembed_query = AsyncMock( + side_effect=LatencyBudgetExceededError("budget exhausted in config probe") + ) + mock_strategy = MagicMock(spec=RetrievalStrategy) + mock_strategy.search = AsyncMock(return_value=RetrieverResult()) + g._retrieval_strategy = mock_strategy + + with pytest.raises(LatencyBudgetExceededError, match="config probe"): + await g.retrieve("test", ctx=Context(latency_budget_ms=1000.0)) + + mock_strategy.search.assert_not_awaited() + class TestGraphRAGCompletion: async def test_completion_basic(self, mock_conn, embedder): @@ -534,6 +656,19 @@ async def test_completion_history_wrong_type_raises(self, mock_conn, embedder): with pytest.raises(TypeError, match="expected ChatMessage or dict"): await g.completion("test?", history=["not a dict"]) + async def test_completion_history_non_string_content_raises(self, mock_conn, embedder): + """History dict content must be a string.""" + llm = MockLLM(responses=["unused"]) + g = GraphRAG(connection=mock_conn, llm=llm, embedder=embedder, embedding_dimension=8) + mock_strategy = MagicMock(spec=RetrievalStrategy) + mock_strategy.search = AsyncMock( + return_value=RetrieverResult(items=[RetrieverResultItem(content="c")]) + ) + g._retrieval_strategy = mock_strategy + + with pytest.raises(ValueError, match="content must be a string"): + await g.completion("test?", history=[{"role": "user", "content": 123}]) + async def test_completion_no_history_uses_messages_api(self, mock_conn, embedder): """Without history, completion still uses ainvoke_messages (unified path).""" llm = MockLLM(responses=["Single-turn answer."]) @@ -680,6 +815,19 @@ async def test_completion_rewrite_question_enabled(self, mock_conn, embedder): assert result.metadata["retrieval_query"] == "Where did Jane Doe go to college?" assert result.answer == "She attended Stanford University." + async def test_question_rewrite_budget_exhaustion_propagates(self, mock_conn, embedder): + llm = MockLLM(responses=["unused"]) + g = GraphRAG(connection=mock_conn, llm=llm, embedder=embedder, embedding_dimension=8) + + with pytest.raises(LatencyBudgetExceededError, match="question rewrite LLM call"): + await g._rewrite_question_with_history( + "Where did she go?", + [ChatMessage(role="user", content="Who is Jane?")], + ctx=Context(latency_budget_ms=0.0), + ) + + assert llm._call_index == 0 + async def test_completion_rewrite_fallback_on_empty(self, mock_conn, embedder): """If the rewrite LLM returns empty, fall back to the original question.""" llm = MockLLM(responses=["", "Some answer."]) @@ -729,6 +877,40 @@ async def test_completion_custom_prompt_template_with_history(self, mock_conn, e assert "SRC" in final.content assert "What is it?" in final.content + async def test_completion_checks_budget_before_retrieval(self, mock_conn, embedder): + llm = MockLLM(responses=["unused"]) + g = GraphRAG(connection=mock_conn, llm=llm, embedder=embedder, embedding_dimension=8) + mock_strategy = MagicMock(spec=RetrievalStrategy) + mock_strategy.search = AsyncMock( + return_value=RetrieverResult(items=[RetrieverResultItem(content="c")]) + ) + g._retrieval_strategy = mock_strategy + ctx = Context(latency_budget_ms=0.0) + + with pytest.raises(LatencyBudgetExceededError, match="completion retrieval"): + await g.completion("q?", ctx=ctx) + + mock_strategy.search.assert_not_awaited() + assert llm._call_index == 0 + + async def test_completion_checks_budget_before_final_llm(self, mock_conn, embedder): + llm = MockLLM(responses=["unused"]) + g = GraphRAG(connection=mock_conn, llm=llm, embedder=embedder, embedding_dimension=8) + mock_strategy = MagicMock(spec=RetrievalStrategy) + + async def search_and_exhaust(question, ctx): + ctx.latency_budget_ms = 0.0 + return RetrieverResult(items=[RetrieverResultItem(content="c")]) + + mock_strategy.search = AsyncMock(side_effect=search_and_exhaust) + g._retrieval_strategy = mock_strategy + + with pytest.raises(LatencyBudgetExceededError, match="completion LLM call"): + await g.completion("q?", ctx=Context(latency_budget_ms=1000.0)) + + mock_strategy.search.assert_awaited_once() + assert llm._call_index == 0 + class TestGraphRAGCompletionInjectionDefenses: """Verifies the S4 mitigations: context delimiters + close-tag neutralization.""" diff --git a/graphrag_sdk/tests/test_integration.py b/graphrag_sdk/tests/test_integration.py index 45b0102f..222437e7 100644 --- a/graphrag_sdk/tests/test_integration.py +++ b/graphrag_sdk/tests/test_integration.py @@ -237,6 +237,7 @@ async def _entity_count(rag, name: str) -> int: @pytest.mark.asyncio +@pytest.mark.integration class TestIncrementalUpdateInvariants: """v1.1.0 load-bearing correctness: scoped orphan cleanup preserves entities shared across documents. Run against real FalkorDB. diff --git a/graphrag_sdk/tests/test_loaders.py b/graphrag_sdk/tests/test_loaders.py index 46620b08..4af116be 100644 --- a/graphrag_sdk/tests/test_loaders.py +++ b/graphrag_sdk/tests/test_loaders.py @@ -1,6 +1,8 @@ """Tests for ingestion/loaders/ — TextLoader and PdfLoader.""" from __future__ import annotations +import logging + import pytest from graphrag_sdk.core.exceptions import LoaderError @@ -59,6 +61,20 @@ async def test_document_info_has_uid(self, ctx, tmp_path): result = await loader.load(str(file), ctx) assert result.document_info.uid # non-empty UUID + async def test_read_failure_is_logged_and_typed(self, ctx, tmp_path, monkeypatch, caplog): + file = tmp_path / "unreadable.txt" + file.write_text("content") + loader = TextLoader() + + def fail_read_text(*args, **kwargs): + raise PermissionError("denied") + + monkeypatch.setattr("pathlib.Path.read_text", fail_read_text) + with caplog.at_level(logging.ERROR, logger="graphrag_sdk.ingestion.loaders.text_loader"): + with pytest.raises(LoaderError, match="Failed to read"): + await loader.load(str(file), ctx) + assert "Failed to read text file" in caplog.text + class TestPdfLoader: """PDF loader tests — skipped if no PDF backend is installed.""" @@ -136,6 +152,26 @@ def fake_pypdf(self, path): result = await pdf_mod.PdfLoader().load(str(file), ctx) assert result.document_info.metadata["pdf_backend"] == "pypdf" + async def test_pymupdf_open_failure_is_logged_and_typed( + self, ctx, tmp_path, monkeypatch, caplog + ): + from graphrag_sdk.ingestion.loaders import pdf_loader as pdf_mod + + file = tmp_path / "corrupt.pdf" + file.write_bytes(b"%PDF-corrupt") + + class FakeFitz: + @staticmethod + def open(path): + raise RuntimeError("corrupt pdf") + + monkeypatch.setitem(__import__("sys").modules, "fitz", FakeFitz) + with caplog.at_level(logging.ERROR, logger="graphrag_sdk.ingestion.loaders.pdf_loader"): + with pytest.raises(LoaderError, match="Failed to read PDF"): + await pdf_mod.PdfLoader().load(str(file), ctx) + assert "Failed to read PDF" in caplog.text + assert "PyMuPDF" in caplog.text + async def test_raises_when_no_backend_installed(self, ctx, tmp_path, monkeypatch): """With neither backend available, a clear install message is raised.""" from graphrag_sdk.ingestion.loaders import pdf_loader as pdf_mod diff --git a/graphrag_sdk/tests/test_multi_path_retrieval.py b/graphrag_sdk/tests/test_multi_path_retrieval.py index b535730e..5eeabcfd 100644 --- a/graphrag_sdk/tests/test_multi_path_retrieval.py +++ b/graphrag_sdk/tests/test_multi_path_retrieval.py @@ -6,6 +6,7 @@ import pytest from graphrag_sdk.core.context import Context +from graphrag_sdk.core.exceptions import LatencyBudgetExceededError from graphrag_sdk.core.models import ( LLMResponse, RawSearchResult, @@ -14,8 +15,10 @@ from graphrag_sdk.retrieval.strategies.entity_discovery import ( expand_sibling_entities, is_enumeration_query, + search_relates_edges, ) from graphrag_sdk.retrieval.strategies.multi_path import MultiPathRetrieval +from graphrag_sdk.retrieval.strategies.result_assembly import rerank_chunks from .conftest import MockEmbedder, MockLLM @@ -80,6 +83,85 @@ async def test_question_embedding(self, strategy, mp_embedder): await strategy.search("Who is Alice?") assert mp_embedder.call_count > 0 + async def test_latency_budget_checked_before_first_phase( + self, strategy, mp_embedder, mp_llm + ): + ctx = Context(latency_budget_ms=0.0) + + with pytest.raises(LatencyBudgetExceededError, match="keyword extraction"): + await strategy.search("Who is Alice?", ctx) + + assert mp_llm._call_index == 0 + assert mp_embedder.call_count == 0 + + async def test_relates_search_budget_error_propagates(self): + vector = MagicMock() + vector.search_relationships = AsyncMock( + side_effect=LatencyBudgetExceededError("budget exhausted") + ) + + with pytest.raises(LatencyBudgetExceededError, match="budget exhausted"): + await search_relates_edges(vector, [0.1], rel_top_k=3) + + async def test_relates_search_checks_budget_before_vector_call(self): + vector = MagicMock() + vector.search_relationships = AsyncMock(return_value=[]) + + with pytest.raises(LatencyBudgetExceededError, match="RELATES vector search"): + await search_relates_edges( + vector, + [0.1], + rel_top_k=3, + ctx=Context(latency_budget_ms=0.0), + ) + + vector.search_relationships.assert_not_awaited() + + async def test_rerank_budget_error_propagates(self, mp_embedder): + mp_embedder.aembed_documents = AsyncMock( + side_effect=LatencyBudgetExceededError("budget exhausted") + ) + + with pytest.raises(LatencyBudgetExceededError, match="budget exhausted"): + await rerank_chunks( + mp_embedder, + [0.1], + {"c1": "text"}, + stored_embeddings={}, + ) + + async def test_latency_budget_checked_before_question_embedding( + self, strategy, mp_embedder, mp_llm, monkeypatch + ): + ctx = Context(latency_budget_ms=1000.0) + + def exhaust_budget(self: Context, operation: str) -> None: + if operation == "MultiPath question embedding": + raise LatencyBudgetExceededError("budget exhausted before embedding") + + monkeypatch.setattr(Context, "ensure_budget", exhaust_budget) + with pytest.raises(LatencyBudgetExceededError, match="before embedding"): + await strategy.search("Who is Alice?", ctx) + + assert mp_llm._call_index == 1 + assert mp_embedder.call_count == 0 + + async def test_latency_budget_propagates_from_keyword_llm( + self, strategy, mp_embedder, mp_llm, monkeypatch + ): + ctx = Context(latency_budget_ms=1000.0) + + def exhaust_budget(self: Context, operation: str) -> None: + if operation == "MultiPath keyword extraction LLM call": + raise LatencyBudgetExceededError("budget exhausted before keyword LLM") + + monkeypatch.setattr(Context, "ensure_budget", exhaust_budget) + with pytest.raises(LatencyBudgetExceededError, match="keyword LLM"): + await strategy.search("Who is Alice?", ctx) + + assert mp_llm._call_index == 0 + assert mp_embedder.call_count == 0 + async def test_relates_edge_vector_search(self, mp_graph_store, mp_vector_store, mp_embedder, mp_llm): """RELATES edge vector search should be called and return facts + entities.""" mp_vector_store.search_relationships = AsyncMock(return_value=[ diff --git a/graphrag_sdk/tests/test_pipeline.py b/graphrag_sdk/tests/test_pipeline.py index 3945fcf1..321367ae 100644 --- a/graphrag_sdk/tests/test_pipeline.py +++ b/graphrag_sdk/tests/test_pipeline.py @@ -194,7 +194,7 @@ async def extract(self, chunks, schema, ctx): # KNOWS survives; WRONG pruned (+ Alien endpoint gone) assert result.relationships_created == 1 - async def test_pipeline_wraps_exception(self, ctx, mock_graph_store, mock_vector_store): + async def test_pipeline_wraps_exception(self, ctx, mock_graph_store, mock_vector_store, caplog): """Non-IngestionError exceptions get wrapped.""" class FailingLoader(LoaderStrategy): async def load(self, source, ctx): @@ -209,8 +209,10 @@ async def load(self, source, ctx): vector_store=mock_vector_store, schema=GraphSchema(), ) - with pytest.raises(IngestionError, match="Pipeline failed"): - await pipeline.run("test.txt", ctx) + with caplog.at_level("ERROR", logger="graphrag_sdk.ingestion.pipeline"): + with pytest.raises(IngestionError, match="Pipeline failed"): + await pipeline.run("test.txt", ctx) + assert "Pipeline failed with unexpected error" in caplog.text async def test_pipeline_writes_content_hash(self, ctx, mock_graph_store, mock_vector_store): """v1.1.0: Document node carries SHA-256 of the loaded text so diff --git a/graphrag_sdk/tests/test_providers.py b/graphrag_sdk/tests/test_providers.py index 4dfe0caa..f50bb36b 100644 --- a/graphrag_sdk/tests/test_providers.py +++ b/graphrag_sdk/tests/test_providers.py @@ -2,25 +2,25 @@ from __future__ import annotations import asyncio -import json +import time from typing import Any from unittest.mock import AsyncMock, MagicMock, patch import pytest from pydantic import BaseModel +from graphrag_sdk.core.exceptions import EmbeddingTimeoutError, LLMTimeoutError from graphrag_sdk.core.models import LLMResponse from graphrag_sdk.core.providers import ( Embedder, - LLMBatchItem, - LLMInterface, LiteLLM, LiteLLMEmbedder, + LLMBatchItem, + LLMInterface, OpenRouterEmbedder, OpenRouterLLM, ) - # ── Concrete test implementations ────────────────────────────── @@ -56,6 +56,26 @@ async def test_aembed_query_default(self): assert isinstance(result, list) assert result[0] == 5.0 # len("world") + async def test_aembed_query_timeout_raises_typed_error(self): + class SlowEmbedder(SimpleEmbedder): + def embed_query(self, text: str, **kwargs: Any) -> list[float]: + time.sleep(0.05) + return super().embed_query(text, **kwargs) + + emb = SlowEmbedder() + with pytest.raises(EmbeddingTimeoutError, match="timed out"): + await emb.aembed_query("world", timeout=0.001) + + async def test_aembed_documents_timeout_raises_typed_error(self): + class SlowEmbedder(SimpleEmbedder): + def embed_documents(self, texts: list[str], **kwargs: Any) -> list[list[float]]: + time.sleep(0.05) + return super().embed_documents(texts, **kwargs) + + emb = SlowEmbedder() + with pytest.raises(EmbeddingTimeoutError, match="timed out"): + await emb.aembed_documents(["world"], timeout=0.001) + def test_cannot_instantiate_abc(self): with pytest.raises(TypeError): Embedder() # type: ignore[abstract] @@ -81,6 +101,35 @@ async def test_ainvoke_default(self): response = await llm.ainvoke("Async test") assert response.content.startswith("Echo:") + async def test_ainvoke_timeout_raises_typed_error(self): + class SlowLLM(SimpleLLM): + def invoke(self, prompt: str, **kwargs: Any) -> LLMResponse: + time.sleep(0.05) + return super().invoke(prompt, **kwargs) + + llm = SlowLLM() + with pytest.raises(LLMTimeoutError, match="timed out"): + await llm.ainvoke("Async test", timeout=0.001, max_retries=1) + + async def test_ainvoke_rejects_non_positive_timeout(self): + llm = SimpleLLM() + with pytest.raises(ValueError, match="timeout must be > 0"): + await llm.ainvoke("Async test", timeout=0) + + async def test_provider_wait_timeout_zero_raises_typed_error(self): + from graphrag_sdk.core.providers._timeout import wait_for_provider_call + + async def never_called(): + await asyncio.sleep(0) + + with pytest.raises(LLMTimeoutError, match="timed out"): + await wait_for_provider_call( + never_called(), + timeout=0.0, + timeout_error=LLMTimeoutError, + operation="test LLM call", + ) + async def test_invoke_with_model(self): class Result(BaseModel): answer: int @@ -165,6 +214,30 @@ def invoke(self, prompt: str, **kwargs: Any) -> LLMResponse: # Single-line invariant assert "\n" not in msg + async def test_ainvoke_exhaustion_logs_error_without_exception_body(self, caplog): + secret_body = "Authorization: Bearer SECRET_KEY_xyz\nproxy=https://internal" + + class FailingLLM(LLMInterface): + def __init__(self) -> None: + super().__init__(model_name="fail") + + def invoke(self, prompt: str, **kwargs: Any) -> LLMResponse: + raise RuntimeError(f"upstream 500\n{secret_body}") + + import logging + + llm = FailingLLM() + with caplog.at_level(logging.ERROR, logger="graphrag_sdk.core.providers.base"): + with pytest.raises(RuntimeError): + await llm.ainvoke("hi", max_retries=1) + errors = [r for r in caplog.records if r.levelno == logging.ERROR] + assert errors + msg = errors[0].getMessage() + assert "LLM call failed after 1 attempts" in msg + assert "upstream 500" in msg + assert "SECRET_KEY_xyz" not in msg + assert "proxy=" not in msg + class TestSummarizeException: """S6: WARNING-level exception logs must be single-line and length-bounded.""" @@ -387,6 +460,19 @@ async def test_ainvoke(self): assert result.content == "async hi" mock_litellm.acompletion.assert_awaited_once() + @pytest.mark.asyncio + async def test_ainvoke_timeout_raises_typed_error(self): + async def slow_completion(**kwargs): + await asyncio.sleep(0.05) + return _mock_litellm_completion_response("too late") + + mock_litellm = MagicMock() + mock_litellm.acompletion = slow_completion + with patch.dict("sys.modules", {"litellm": mock_litellm}): + llm = LiteLLM(model="azure/gpt-4") + with pytest.raises(LLMTimeoutError, match="timed out"): + await llm.ainvoke("Hello", timeout=0.001, max_retries=1) + @pytest.mark.asyncio async def test_ainvoke_retries_on_failure(self): mock_litellm = MagicMock() @@ -603,6 +689,19 @@ async def test_aembed_query(self): assert result == vec + @pytest.mark.asyncio + async def test_aembed_query_timeout_raises_typed_error(self): + async def slow_embedding(**kwargs): + await asyncio.sleep(0.05) + return _mock_litellm_embedding_response([[0.7, 0.8, 0.9]]) + + mock_litellm = MagicMock() + mock_litellm.aembedding = slow_embedding + with patch.dict("sys.modules", {"litellm": mock_litellm}): + embedder = LiteLLMEmbedder(model="text-embedding-ada-002") + with pytest.raises(EmbeddingTimeoutError, match="timed out"): + await embedder.aembed_query("hello", timeout=0.001) + @pytest.mark.asyncio async def test_aembed_documents(self): mock_litellm = MagicMock() @@ -616,6 +715,38 @@ async def test_aembed_documents(self): assert result == vecs + @pytest.mark.asyncio + async def test_aembed_documents_timeout_is_not_binary_split(self): + async def slow_embedding(**kwargs): + await asyncio.sleep(0.05) + return _mock_litellm_embedding_response([[0.1, 0.2]]) + + mock_litellm = MagicMock() + mock_litellm.aembedding = AsyncMock(side_effect=slow_embedding) + with patch.dict("sys.modules", {"litellm": mock_litellm}): + embedder = LiteLLMEmbedder(model="text-embedding-ada-002") + with pytest.raises(EmbeddingTimeoutError, match="timed out"): + await embedder.aembed_documents(["a", "b"], timeout=0.001) + + assert mock_litellm.aembedding.await_count == 1 + + @pytest.mark.asyncio + async def test_aembed_documents_timeout_is_overall_deadline(self): + async def slow_embedding(**kwargs): + await asyncio.sleep(0.02) + return _mock_litellm_embedding_response([[0.1, 0.2]]) + + mock_litellm = MagicMock() + mock_litellm.aembedding = AsyncMock(side_effect=slow_embedding) + with patch.dict("sys.modules", {"litellm": mock_litellm}): + embedder = LiteLLMEmbedder(model="text-embedding-ada-002", batch_size=1) + started = time.monotonic() + with pytest.raises(EmbeddingTimeoutError, match="timed out"): + await embedder.aembed_documents(["a", "b"], timeout=0.03) + + assert time.monotonic() - started < 0.2 + assert mock_litellm.aembedding.await_count == 2 + def test_import_error(self): with patch.dict("sys.modules", {"litellm": None}): embedder = LiteLLMEmbedder(model="text-embedding-ada-002") @@ -632,7 +763,9 @@ class TestOpenRouterLLM: def test_invoke(self): mock_openai = MagicMock() mock_client = MagicMock() - mock_client.chat.completions.create.return_value = _mock_openai_completion_response("router hi") + mock_client.chat.completions.create.return_value = _mock_openai_completion_response( + "router hi" + ) mock_openai.OpenAI.return_value = mock_client with patch.dict("sys.modules", {"openai": mock_openai}): llm = OpenRouterLLM(model="anthropic/claude-sonnet-4-20250514", api_key="or-key") @@ -659,6 +792,22 @@ async def test_ainvoke(self): assert result.content == "async router" mock_openai.AsyncOpenAI.assert_called_once() + @pytest.mark.asyncio + async def test_ainvoke_timeout_raises_typed_error(self): + mock_openai = MagicMock() + mock_async_client = MagicMock() + + async def slow_create(**kwargs): + await asyncio.sleep(0.05) + return _mock_openai_completion_response("too late") + + mock_async_client.chat.completions.create = slow_create + mock_openai.AsyncOpenAI.return_value = mock_async_client + with patch.dict("sys.modules", {"openai": mock_openai}): + llm = OpenRouterLLM(model="openai/gpt-4", api_key="or-key") + with pytest.raises(LLMTimeoutError, match="timed out"): + await llm.ainvoke("Hello", timeout=0.001, max_retries=1) + def test_env_var_fallback(self): mock_openai = MagicMock() mock_client = MagicMock() @@ -705,7 +854,11 @@ def test_embed_documents_batching(self): ] mock_openai.OpenAI.return_value = mock_client with patch.dict("sys.modules", {"openai": mock_openai}): - embedder = OpenRouterEmbedder(model="openai/text-embedding-ada-002", api_key="k", batch_size=2) + embedder = OpenRouterEmbedder( + model="openai/text-embedding-ada-002", + api_key="k", + batch_size=2, + ) result = embedder.embed_documents(["a", "b", "c"]) assert len(result) == 3 @@ -727,6 +880,22 @@ async def test_aembed_query(self): assert result == vec + @pytest.mark.asyncio + async def test_aembed_query_timeout_raises_typed_error(self): + mock_openai = MagicMock() + mock_async_client = MagicMock() + + async def slow_create(**kwargs): + await asyncio.sleep(0.05) + return _mock_openai_embedding_response([[0.4, 0.5, 0.6]]) + + mock_async_client.embeddings.create = slow_create + mock_openai.AsyncOpenAI.return_value = mock_async_client + with patch.dict("sys.modules", {"openai": mock_openai}): + embedder = OpenRouterEmbedder(model="openai/text-embedding-ada-002", api_key="k") + with pytest.raises(EmbeddingTimeoutError, match="timed out"): + await embedder.aembed_query("hello", timeout=0.001) + def test_import_error(self): with patch.dict("sys.modules", {"openai": None}): embedder = OpenRouterEmbedder(model="text-embedding-ada-002", api_key="k") diff --git a/graphrag_sdk/tests/test_retrieval.py b/graphrag_sdk/tests/test_retrieval.py index d48bc7f1..d5425f10 100644 --- a/graphrag_sdk/tests/test_retrieval.py +++ b/graphrag_sdk/tests/test_retrieval.py @@ -7,7 +7,7 @@ import pytest from graphrag_sdk.core.context import Context -from graphrag_sdk.core.exceptions import RetrieverError +from graphrag_sdk.core.exceptions import LatencyBudgetExceededError, RetrieverError from graphrag_sdk.core.models import RawSearchResult, RetrieverResult, RetrieverResultItem from graphrag_sdk.retrieval.strategies.base import RetrievalStrategy from graphrag_sdk.retrieval.strategies.local import LocalRetrieval @@ -53,10 +53,12 @@ async def test_search_validates_whitespace_query(self, ctx): with pytest.raises(RetrieverError, match="Empty query"): await strategy.search(" ", ctx) - async def test_search_wraps_exception(self, ctx): + async def test_search_wraps_exception(self, ctx, caplog): strategy = StubRetrieval(should_fail=True) - with pytest.raises(RetrieverError, match="failed"): - await strategy.search("valid query", ctx) + with caplog.at_level("ERROR", logger="graphrag_sdk.retrieval.strategies.base"): + with pytest.raises(RetrieverError, match="failed"): + await strategy.search("valid query", ctx) + assert "Retrieval [StubRetrieval] failed" in caplog.text async def test_search_default_context(self): """search() creates default Context if none provided.""" @@ -114,6 +116,16 @@ async def test_local_embeds_query(self, ctx, local_strategy, embedder): await local_strategy.search("test query", ctx) assert embedder.call_count >= 1 + async def test_local_query_embedding_uses_remaining_budget(self, local_strategy, embedder): + embedder.aembed_query = AsyncMock(return_value=[0.1] * 8) + ctx = Context(latency_budget_ms=1000.0) + + await local_strategy.search("test query", ctx) + + timeout = embedder.aembed_query.call_args.kwargs["timeout"] + assert timeout is not None + assert 0 < timeout <= 1.0 + async def test_local_includes_entities(self, ctx, local_strategy): result = await local_strategy.search("test", ctx) # Should include entity info in content @@ -155,3 +167,45 @@ async def test_local_scores_in_items(self, ctx, local_strategy): async def test_local_chunk_id_in_metadata(self, ctx, local_strategy): result = await local_strategy.search("test", ctx) assert result.items[0].metadata["chunk_id"] == "chunk-1" + + async def test_local_budget_checked_before_embedding( + self, local_strategy, embedder, mock_vector_store + ): + with pytest.raises(LatencyBudgetExceededError, match="query embedding"): + await local_strategy.search("test", Context(latency_budget_ms=0.0)) + + assert embedder.call_count == 0 + mock_vector_store.search_chunks.assert_not_awaited() + + async def test_local_budget_checked_before_vector_search( + self, local_strategy, embedder, mock_vector_store, monkeypatch + ): + ctx = Context(latency_budget_ms=1000.0) + + def exhaust_budget(self: Context, operation: str) -> None: + if operation == "LocalRetrieval chunk vector search": + raise LatencyBudgetExceededError("budget exhausted before vector search") + + monkeypatch.setattr(Context, "ensure_budget", exhaust_budget) + with pytest.raises(LatencyBudgetExceededError, match="vector search"): + await local_strategy.search("test", ctx) + + assert embedder.call_count == 1 + mock_vector_store.search_chunks.assert_not_awaited() + + async def test_local_budget_checked_before_entity_expansion( + self, local_strategy, embedder, mock_vector_store, mock_graph_store, monkeypatch + ): + ctx = Context(latency_budget_ms=1000.0) + + def exhaust_budget(self: Context, operation: str) -> None: + if operation == "LocalRetrieval entity expansion": + raise LatencyBudgetExceededError("budget exhausted before entity expansion") + + monkeypatch.setattr(Context, "ensure_budget", exhaust_budget) + with pytest.raises(LatencyBudgetExceededError, match="entity expansion"): + await local_strategy.search("test", ctx) + + assert embedder.call_count == 1 + mock_vector_store.search_chunks.assert_awaited_once() + mock_graph_store.get_connected_entities.assert_not_awaited()