diff --git a/changes/3997.feature.md b/changes/3997.feature.md new file mode 100644 index 0000000000..b534d0d41b --- /dev/null +++ b/changes/3997.feature.md @@ -0,0 +1 @@ +Added `zarr.run`, a public function for running zarr's asynchronous API (`AsyncArray`, `AsyncGroup`, and their methods) from synchronous code. It runs a coroutine to completion on a zarr-managed event loop and returns the result, providing downstream libraries a supported bridge from async to sync without managing their own event loop. If called from within a running event loop, it raises `RuntimeError` (matching `asyncio.run`); in that case, `await` the async API directly. diff --git a/docs/api/zarr/index.md b/docs/api/zarr/index.md index 194edb7913..811a8eddc8 100644 --- a/docs/api/zarr/index.md +++ b/docs/api/zarr/index.md @@ -40,6 +40,11 @@ Zarr-Python provides both synchronous and asynchronous APIs: - **[Async API](./api/asynchronous.md)** - Asynchronous operations for concurrent access - **[Sync API](./api/synchronous.md)** - Synchronous operations for simple usage +To run the asynchronous API from synchronous code, use `zarr.run`, which +executes a coroutine to completion on a Zarr-managed event loop and returns +its result. It raises `RuntimeError` if called from within a running event +loop (in that case, `await` the async API directly instead). 
+ ## Abstract Base Classes The ABC module defines interfaces for extending Zarr: diff --git a/src/zarr/__init__.py b/src/zarr/__init__.py index cdf3840c3b..dc1a657de6 100644 --- a/src/zarr/__init__.py +++ b/src/zarr/__init__.py @@ -37,6 +37,7 @@ from zarr.core.array import Array, AsyncArray from zarr.core.config import config from zarr.core.group import AsyncGroup, Group +from zarr.core.sync import run # in case setuptools scm screw up and find version to be 0.0.0 assert not __version__.startswith("0.0.0") @@ -173,6 +174,7 @@ def set_format(log_format: str) -> None: "open_group", "open_like", "print_debug_info", + "run", "save", "save_array", "save_group", diff --git a/src/zarr/core/sync.py b/src/zarr/core/sync.py index 260d4ad841..cbdbe17fc0 100644 --- a/src/zarr/core/sync.py +++ b/src/zarr/core/sync.py @@ -32,7 +32,11 @@ class SyncError(Exception): - pass + """Internal: raised by `sync` when called from within a running event loop. + + The public `run` wrapper translates this to `RuntimeError` (matching + `asyncio.run`); downstream code should not catch `SyncError` directly. + """ def _get_lock() -> threading.Lock: @@ -160,6 +164,75 @@ def sync[T]( return return_result +def run[T]( + coro: Coroutine[Any, Any, T], + *, + loop: asyncio.AbstractEventLoop | None = None, + timeout: float | None = None, +) -> T: + """Run a coroutine to completion on a zarr-managed event loop and return its result. + + This is the supported bridge for running zarr's asynchronous API + (`AsyncArray`, `AsyncGroup`, and their methods) from synchronous code. It + runs `coro` on an event loop managed by zarr, blocking the calling thread + until the coroutine finishes, then returns its result (or re-raises any + exception the coroutine raised). + + The intended use is composing async zarr operations and running them + concurrently from a synchronous program. 
+ To run several operations + concurrently, build the composition inside a coroutine so that `gather` + binds to zarr's loop rather than to the calling thread, which has no running loop:: + + import asyncio + import zarr + + group = zarr.run(zarr.api.asynchronous.open_group(store)) + + async def read_all(keys): + return await asyncio.gather(*(group.getitem(key) for key in keys)) + + results = zarr.run(read_all(["a", "b"])) + + Parameters + ---------- + coro : Coroutine + The coroutine to run. + loop : asyncio.AbstractEventLoop, optional + The event loop to run `coro` on. If `None` (the default), zarr's + managed loop is used. Most callers should leave this unset. + timeout : float, optional + Maximum number of seconds to wait for `coro` to finish. If `None` + (the default), waits indefinitely. Raises `TimeoutError` if exceeded. + + Returns + ------- + The value returned by `coro`. + + Raises + ------ + RuntimeError + If called from within a running event loop on the calling thread. In + an async context, `await` the async API directly instead. (This + mirrors `asyncio.run`, which raises `RuntimeError` for the same + misuse.) + TimeoutError + If `coro` does not finish within `timeout` seconds. + + Notes + ----- + `zarr.run` commits to this signature and to running coroutines on a + zarr-managed loop. It does not commit to *how* that loop is managed (a + process-global loop, a dedicated thread, etc.); those internals may change. + """ + try: + return sync(coro, loop=loop, timeout=timeout) + except SyncError as e: + raise RuntimeError( + "zarr.run cannot be called from within a running event loop. " + "In an async context, await the async API directly." 
+ ) from e + + def _get_loop() -> asyncio.AbstractEventLoop: """Create or return the default fsspec IO loop diff --git a/tests/test_sync.py b/tests/test_sync.py index c5eadb0f4f..8b896ba8eb 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -13,6 +13,7 @@ _get_loop, cleanup_resources, loop, + run, sync, ) @@ -163,3 +164,97 @@ def test_cleanup_resources_idempotent() -> None: _get_executor() # trigger resource creation (iothread, loop, thread-pool) cleanup_resources() cleanup_resources() + + +# --- public `zarr.run` API --------------------------------------------------- +# `zarr.run` is the supported public bridge for running async zarr operations +# from synchronous code. It is a thin wrapper over the internal `sync`; these +# tests pin the public contract independently of the internal function. + + +def test_run_returns_coroutine_result() -> None: + """`zarr.run` returns the value the coroutine resolves to.""" + foo = AsyncMock(return_value="foo") + assert run(foo()) == "foo" + foo.assert_awaited_once() + + +def test_run_is_public() -> None: + """`run` is exported at the top level and is the same object as the internal one.""" + assert zarr.run is run + assert "run" in zarr.__all__ + + +def test_run_propagates_exception() -> None: + """An exception raised inside the coroutine propagates to the caller.""" + foo = AsyncMock(side_effect=ValueError("foo-bar")) + with pytest.raises(ValueError, match="foo-bar"): + run(foo()) + foo.assert_awaited_once() + + +def test_run_timeout() -> None: + """`zarr.run` raises `TimeoutError` if the coroutine exceeds `timeout`.""" + duration = 0.02 + + async def foo() -> None: + await asyncio.sleep(duration) + + with pytest.raises(asyncio.TimeoutError): + run(foo(), timeout=duration / 10) + + +@pytest.mark.filterwarnings("ignore:coroutine.*was never awaited") +def test_run_raises_runtimeerror_inside_running_loop() -> None: + """Calling `zarr.run` from within a running loop raises `RuntimeError`. 
+ + This mirrors `asyncio.run`'s behavior for the same misuse, and hides the + internal `SyncError` from the public surface. + """ + + def inner() -> str: + # plain (not async) on purpose: an un-awaited inner coroutine would be + # garbage-collected during a later test and surface as a spurious + # "coroutine was never awaited" failure. Mirrors the internal-`sync` + # test above. + return "inner" + + async def outer() -> str: + return run(inner()) # type: ignore[arg-type] + + with pytest.raises(RuntimeError): + run(outer()) + + +def test_run_inside_running_loop_does_not_leak_syncerror() -> None: + """The internal `SyncError` is not surfaced to callers of `zarr.run`.""" + + def inner() -> str: + return "inner" + + async def outer() -> str: + return run(inner()) # type: ignore[arg-type] + + with pytest.raises(RuntimeError) as excinfo: + run(outer()) + # SyncError is preserved as the cause but is not the raised type. + assert not isinstance(excinfo.value, SyncError) + assert isinstance(excinfo.value.__cause__, SyncError) + + +def test_run_composes_with_gather() -> None: + """The headline downstream pattern: run several coroutines concurrently. + + The `gather` is constructed inside a coroutine so it binds to zarr's loop + rather than the calling thread (which has no running loop). + """ + + async def double(x: int) -> int: + await asyncio.sleep(0) + return x * 2 + + async def run_all() -> list[int]: + return await asyncio.gather(*(double(i) for i in range(5))) + + results = run(run_all()) + assert results == [0, 2, 4, 6, 8]