From 79fdf98c9b083c0a800867385c15a34a373f2ddb Mon Sep 17 00:00:00 2001 From: Ajay Chinthalapalli Jayakumar Date: Mon, 20 Apr 2026 15:49:06 -0700 Subject: [PATCH] Add py.typed marker, cap dependency versions, and validate task status transitions --- .gitignore | 3 ++- pyproject.toml | 12 +++++----- src/switchplane/persistence.py | 24 +++++++++++++++++++ src/switchplane/py.typed | 0 tests/test_control_plane.py | 1 + tests/test_persistence.py | 42 ++++++++++++++++++++++++++++++++++ 6 files changed, 75 insertions(+), 7 deletions(-) create mode 100644 src/switchplane/py.typed diff --git a/.gitignore b/.gitignore index 372f407..d37804d 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ build/ # Virtual environment .venv/ +venv/ # uv uv.lock @@ -28,4 +29,4 @@ uv.lock # testing .tmp/* -.coverage \ No newline at end of file +.coverage diff --git a/pyproject.toml b/pyproject.toml index 5755562..b1e407d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,10 +21,10 @@ classifiers = [ dependencies = [ "click>=8.3,<9", "pydantic>=2.0,<3", - "aiosqlite>=0.20", - "langgraph>=1.0", - "prompt_toolkit>=3.0", - "structlog>=24.0", + "aiosqlite>=0.20,<1", + "langgraph>=1.0,<2", + "prompt_toolkit>=3.0,<4", + "structlog>=24.0,<25", ] [project.scripts] @@ -36,8 +36,8 @@ Repository = "https://github.com/salesforce-misc/switchplane" Issues = "https://github.com/salesforce-misc/switchplane/issues" [project.optional-dependencies] -llm = ["langchain-core>=0.3"] -mcp = ["mcp>=1.0", "switchplane[llm]"] +llm = ["langchain-core>=0.3,<2"] +mcp = ["mcp>=1.0,<2", "switchplane[llm]"] test = ["pytest>=7.0", "pytest-asyncio>=0.23", "pytest-cov>=4.0", "pytest-xdist>=3.5", "ruff>=0.9"] [tool.hatch.build.targets.wheel] diff --git a/src/switchplane/persistence.py b/src/switchplane/persistence.py index ff395bb..209ec62 100644 --- a/src/switchplane/persistence.py +++ b/src/switchplane/persistence.py @@ -10,6 +10,20 @@ from switchplane.agent import AgentRecord from switchplane.task import TaskRecord, TaskStatus +# Valid status transitions. Keys are the current status; values are the set of +# statuses that may follow. Any transition not listed here is rejected. +_VALID_TRANSITIONS: dict[TaskStatus, frozenset[TaskStatus]] = { + TaskStatus.PENDING: frozenset({TaskStatus.RUNNING, TaskStatus.CANCELLED, TaskStatus.FAILED}), + TaskStatus.RUNNING: frozenset( + {TaskStatus.RUNNING, TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED, TaskStatus.INTERRUPTED} + ), + TaskStatus.INTERRUPTED: frozenset({TaskStatus.RUNNING, TaskStatus.FAILED, TaskStatus.CANCELLED}), + # Terminal states can only transition to PENDING (resume flow). + TaskStatus.COMPLETED: frozenset({TaskStatus.PENDING}), + TaskStatus.FAILED: frozenset({TaskStatus.PENDING}), + TaskStatus.CANCELLED: frozenset({TaskStatus.PENDING}), +} + class Store: """Async SQLite store for control plane persistence.""" @@ -119,6 +133,16 @@ async def update_task(self, task_id: str, **fields: Any) -> None: if not self._db: raise RuntimeError("Store not initialized") + if "status" in fields and isinstance(fields["status"], TaskStatus): + new_status: TaskStatus = fields["status"] + cursor = await self._db.execute("SELECT status FROM tasks WHERE task_id = ?", (task_id,)) + row = await cursor.fetchone() + if row: + current_status = TaskStatus(row[0]) + allowed = _VALID_TRANSITIONS.get(current_status, frozenset()) + if new_status not in allowed: + raise ValueError(f"Invalid task status transition: {current_status} → {new_status}") + # Build the SET clause dynamically set_parts = [] values = [] diff --git a/src/switchplane/py.typed b/src/switchplane/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_control_plane.py b/tests/test_control_plane.py index d92663d..f701cfd 100644 --- a/tests/test_control_plane.py +++ b/tests/test_control_plane.py @@ -355,6 +355,7 @@ async def test_clears(self, cp): }, ) task_id = submit.result["task_id"] + await cp.store.update_task(task_id, status=TaskStatus.RUNNING) await cp.store.update_task(task_id, status=TaskStatus.COMPLETED) resp = await _request(cp, "clear_tasks") diff --git a/tests/test_persistence.py b/tests/test_persistence.py index b6908b5..1f1ccda 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -82,6 +82,7 @@ async def test_update_status(self, store): @pytest.mark.asyncio async def test_update_result(self, store): await store.create_task(_make_task()) + await store.update_task("t1", status=TaskStatus.RUNNING) await store.update_task("t1", status=TaskStatus.COMPLETED, result_json='{"answer": 42}') result = await store.get_task("t1") assert result.status == TaskStatus.COMPLETED @@ -118,6 +119,47 @@ async def test_list_by_status(self, store): assert len(running) == 1 +class TestStatusTransitions: + @pytest.mark.asyncio + async def test_valid_pending_to_running(self, store): + await store.create_task(_make_task()) + await store.update_task("t1", status=TaskStatus.RUNNING) + result = await store.get_task("t1") + assert result.status == TaskStatus.RUNNING + + @pytest.mark.asyncio + async def test_valid_running_to_completed(self, store): + await store.create_task(_make_task(status=TaskStatus.RUNNING)) + await store.update_task("t1", status=TaskStatus.COMPLETED) + result = await store.get_task("t1") + assert result.status == TaskStatus.COMPLETED + + @pytest.mark.asyncio + async def test_valid_terminal_to_pending_resume(self, store): + for terminal in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED): + await store.create_task(_make_task(task_id=terminal.value, status=terminal)) + await store.update_task(terminal.value, status=TaskStatus.PENDING) + result = await store.get_task(terminal.value) + assert result.status == TaskStatus.PENDING + + @pytest.mark.asyncio + async def test_invalid_completed_to_running(self, store): + await store.create_task(_make_task(status=TaskStatus.COMPLETED)) + with pytest.raises(ValueError, match="Invalid task status transition"): + await store.update_task("t1", status=TaskStatus.RUNNING) + + @pytest.mark.asyncio + async def test_invalid_pending_to_completed(self, store): + await store.create_task(_make_task()) + with pytest.raises(ValueError, match="Invalid task status transition"): + await store.update_task("t1", status=TaskStatus.COMPLETED) + + @pytest.mark.asyncio + async def test_nonexistent_task_skips_validation(self, store): + # update_task on a missing task_id silently no-ops (existing behaviour) + await store.update_task("ghost", status=TaskStatus.RUNNING) + + class TestAgentCRUD: @pytest.mark.asyncio async def test_upsert_and_get(self, store):