Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ build/

# Virtual environment
.venv/
venv/

# uv
uv.lock
Expand All @@ -28,4 +29,4 @@ uv.lock

# testing
.tmp/*
.coverage
.coverage
12 changes: 6 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ classifiers = [
dependencies = [
"click>=8.3,<9",
"pydantic>=2.0,<3",
"aiosqlite>=0.20",
"langgraph>=1.0",
"prompt_toolkit>=3.0",
"structlog>=24.0",
"aiosqlite>=0.20,<1",
"langgraph>=1.0,<2",
"prompt_toolkit>=3.0,<4",
"structlog>=24.0,<25",
]

[project.scripts]
Expand All @@ -36,8 +36,8 @@ Repository = "https://github.com/salesforce-misc/switchplane"
Issues = "https://github.com/salesforce-misc/switchplane/issues"

[project.optional-dependencies]
llm = ["langchain-core>=0.3"]
mcp = ["mcp>=1.0", "switchplane[llm]"]
llm = ["langchain-core>=0.3,<2"]
mcp = ["mcp>=1.0,<2", "switchplane[llm]"]
test = ["pytest>=7.0", "pytest-asyncio>=0.23", "pytest-cov>=4.0", "pytest-xdist>=3.5", "ruff>=0.9"]

[tool.hatch.build.targets.wheel]
Expand Down
24 changes: 24 additions & 0 deletions src/switchplane/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,20 @@
from switchplane.agent import AgentRecord
from switchplane.task import TaskRecord, TaskStatus

# Valid status transitions. Keys are the current status; values are the set of
# statuses that may follow. Any transition not listed here is rejected.
_VALID_TRANSITIONS: dict[TaskStatus, frozenset[TaskStatus]] = {
TaskStatus.PENDING: frozenset({TaskStatus.RUNNING, TaskStatus.CANCELLED, TaskStatus.FAILED}),
TaskStatus.RUNNING: frozenset(
{TaskStatus.RUNNING, TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED, TaskStatus.INTERRUPTED}
),
TaskStatus.INTERRUPTED: frozenset({TaskStatus.RUNNING, TaskStatus.FAILED, TaskStatus.CANCELLED}),
# Terminal states can only transition to PENDING (resume flow).
TaskStatus.COMPLETED: frozenset({TaskStatus.PENDING}),
TaskStatus.FAILED: frozenset({TaskStatus.PENDING}),
TaskStatus.CANCELLED: frozenset({TaskStatus.PENDING}),
}


Comment on lines +13 to 27
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This hardcodes a state machine directly in the persistence layer — every app built on Switchplane gets the same rigid transitions with no way to override or opt out.

Different apps may have legitimate reasons for different flows (e.g., synchronous tasks that skip RUNNING, custom intermediate states, domain-specific lifecycle rules). Transition policy is an application concern, not a framework storage concern.

If we want to offer transition validation as a convenience, it should be opt-in or configurable at the app/task level — not baked into Store.update_task() where it silently constrains every consumer.

class Store:
"""Async SQLite store for control plane persistence."""
Expand Down Expand Up @@ -119,6 +133,16 @@ async def update_task(self, task_id: str, **fields: Any) -> None:
if not self._db:
raise RuntimeError("Store not initialized")

if "status" in fields and isinstance(fields["status"], TaskStatus):
new_status: TaskStatus = fields["status"]
cursor = await self._db.execute("SELECT status FROM tasks WHERE task_id = ?", (task_id,))
row = await cursor.fetchone()
if row:
current_status = TaskStatus(row[0])
allowed = _VALID_TRANSITIONS.get(current_status, frozenset())
if new_status not in allowed:
raise ValueError(f"Invalid task status transition: {current_status} → {new_status}")

# Build the SET clause dynamically
set_parts = []
values = []
Expand Down
Empty file added src/switchplane/py.typed
Empty file.
1 change: 1 addition & 0 deletions tests/test_control_plane.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ async def test_clears(self, cp):
},
)
task_id = submit.result["task_id"]
await cp.store.update_task(task_id, status=TaskStatus.RUNNING)
await cp.store.update_task(task_id, status=TaskStatus.COMPLETED)

resp = await _request(cp, "clear_tasks")
Expand Down
42 changes: 42 additions & 0 deletions tests/test_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ async def test_update_status(self, store):
@pytest.mark.asyncio
async def test_update_result(self, store):
await store.create_task(_make_task())
await store.update_task("t1", status=TaskStatus.RUNNING)
await store.update_task("t1", status=TaskStatus.COMPLETED, result_json='{"answer": 42}')
result = await store.get_task("t1")
assert result.status == TaskStatus.COMPLETED
Expand Down Expand Up @@ -118,6 +119,47 @@ async def test_list_by_status(self, store):
assert len(running) == 1


class TestStatusTransitions:
@pytest.mark.asyncio
async def test_valid_pending_to_running(self, store):
await store.create_task(_make_task())
await store.update_task("t1", status=TaskStatus.RUNNING)
result = await store.get_task("t1")
assert result.status == TaskStatus.RUNNING

@pytest.mark.asyncio
async def test_valid_running_to_completed(self, store):
await store.create_task(_make_task(status=TaskStatus.RUNNING))
await store.update_task("t1", status=TaskStatus.COMPLETED)
result = await store.get_task("t1")
assert result.status == TaskStatus.COMPLETED

@pytest.mark.asyncio
async def test_valid_terminal_to_pending_resume(self, store):
for terminal in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED):
await store.create_task(_make_task(task_id=terminal.value, status=terminal))
await store.update_task(terminal.value, status=TaskStatus.PENDING)
result = await store.get_task(terminal.value)
assert result.status == TaskStatus.PENDING

@pytest.mark.asyncio
async def test_invalid_completed_to_running(self, store):
await store.create_task(_make_task(status=TaskStatus.COMPLETED))
with pytest.raises(ValueError, match="Invalid task status transition"):
await store.update_task("t1", status=TaskStatus.RUNNING)

@pytest.mark.asyncio
async def test_invalid_pending_to_completed(self, store):
await store.create_task(_make_task())
with pytest.raises(ValueError, match="Invalid task status transition"):
await store.update_task("t1", status=TaskStatus.COMPLETED)

@pytest.mark.asyncio
async def test_nonexistent_task_skips_validation(self, store):
# update_task on a missing task_id silently no-ops (existing behaviour)
await store.update_task("ghost", status=TaskStatus.RUNNING)


class TestAgentCRUD:
@pytest.mark.asyncio
async def test_upsert_and_get(self, store):
Expand Down
Loading