diff --git a/doc/whats-new.rst b/doc/whats-new.rst index ab96aff4d1a..78085a21eb9 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -16,6 +16,9 @@ New Features - Added ``inherit='all_coords'`` option to :py:meth:`DataTree.to_dataset` to inherit all parent coordinates, not just indexed ones (:issue:`10812`, :pull:`11230`). +- Added ``max_concurrency`` parameter to :py:func:`open_datatree` to control + the maximum number of concurrent I/O operations when opening groups in parallel + with the Zarr backend (:pull:`10742`). By `Alfonso Ladino `_. Breaking Changes @@ -337,6 +340,9 @@ Documentation Performance ~~~~~~~~~~~ +- Improve performance of :py:func:`open_datatree` for zarr stores by using async/concurrent + loading of groups and indexes (:pull:`10742`). + By `Alfonso Ladino `_. - Add a fastpath to the backend plugin system for standard engines (:issue:`10178`, :pull:`10937`). By `Sam Levang `_. - Optimize :py:class:`~xarray.coding.variables.CFMaskCoder` decoder (:pull:`11105`). diff --git a/xarray/backends/api.py b/xarray/backends/api.py index fd992f3e5d8..a1f2da54fae 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -65,6 +65,7 @@ NestedSequence, T_Chunks, ) + from xarray.core.variable import Variable T_NetcdfEngine = Literal["netcdf4", "scipy", "h5netcdf"] T_Engine = Union[ @@ -349,7 +350,47 @@ def _datatree_from_backend_datatree( _protect_datatree_variables_inplace(backend_tree, cache) if create_default_indexes: - tree = backend_tree.map_over_datasets(_maybe_create_default_indexes) + _use_zarr_async = False + if engine == "zarr": + from xarray.backends.zarr import _zarr_v3 + + _use_zarr_async = _zarr_v3() + + if _use_zarr_async: + from zarr.core.sync import sync as zarr_sync + + async def create_indexes_async() -> dict[str, Dataset]: + import asyncio + from concurrent.futures import ThreadPoolExecutor + + executor = ThreadPoolExecutor( + max_workers=10, thread_name_prefix="xarray-idx" + ) + try: + results: dict[str, Dataset] = {} + + async def _create_index_for_node( + path: str, ds: Dataset + ) -> tuple[str, Dataset]: + return path, await _maybe_create_default_indexes_async( + ds, executor=executor + ) + + tasks = [ + _create_index_for_node(path, node.dataset) + for path, [node] in group_subtrees(backend_tree) + ] + for fut in asyncio.as_completed(tasks): + path, ds = await fut + results[path] = ds + return results + finally: + executor.shutdown(wait=True, cancel_futures=True) + + results = zarr_sync(create_indexes_async()) + tree = DataTree.from_dict(results, name=backend_tree.name) + else: + tree = backend_tree.map_over_datasets(_maybe_create_default_indexes) else: tree = backend_tree if chunks is not None: @@ -386,6 +427,33 @@ def _datatree_from_backend_datatree( return tree +async def _maybe_create_default_indexes_async(ds: Dataset, executor=None) -> Dataset: + import asyncio + + to_index_names = [ + name + for name, coord in ds.coords.items() + if coord.dims == (name,) and name not in ds.xindexes + ] + + if not to_index_names: + return ds + + loop = asyncio.get_running_loop() + + async def load_var(var: Variable) -> Variable: + try: + return await var.load_async() + except NotImplementedError: + return await loop.run_in_executor(executor, var.load) + + await asyncio.gather(*[load_var(ds.variables[name]) for name in to_index_names]) + + variables = {name: ds.variables[name] for name in to_index_names} + new_coords = Coordinates(variables) + return ds.assign_coords(new_coords) + + def open_dataset( filename_or_obj: T_PathFileOrDataStore, *, @@ -882,6 +950,7 @@ def open_datatree( chunked_array_type: str | None = None, from_array_kwargs: dict[str, Any] | None = None, backend_kwargs: dict[str, Any] | None = None, + max_concurrency: int | None = None, **kwargs, ) -> DataTree: """ @@ -1014,6 +1083,13 @@ def open_datatree( chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg. For example if :py:func:`dask.array.Array` objects are used for chunking, additional kwargs will be passed to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon. + max_concurrency : int, optional + Maximum number of concurrent I/O operations when opening groups in + parallel. This limits the number of groups that are loaded simultaneously. + Useful for controlling resource usage with large datatrees or stores + that may have limitations on concurrent access (e.g., icechunk). + Only used by backends that support parallel loading (currently Zarr v3). + If None (default), the backend uses its default value (typically 10). backend_kwargs: dict Additional keyword arguments passed on to the engine open function, equivalent to `**kwargs`. @@ -1021,8 +1097,12 @@ def open_datatree( Additional keyword arguments passed on to the engine open function. For example: - - 'group': path to the group in the given file to open as the root group as - a str. + - 'group': path to the group in the given file to open as the root + group as a str. If the string contains glob metacharacters + (``*``, ``?``, ``[``), it is interpreted as a pattern and only + groups whose paths match are loaded (along with their ancestors). + For example, ``group="*/sweep_0"`` loads every ``sweep_0`` one + level deep while skipping sibling groups. - 'lock': resource lock to use when reading data from disk. Only relevant when using dask or another form of parallelism. By default, appropriate locks are chosen to safely read and write files with the @@ -1074,6 +1154,9 @@ def open_datatree( ) overwrite_encoded_chunks = kwargs.pop("overwrite_encoded_chunks", None) + if max_concurrency is not None: + kwargs["max_concurrency"] = max_concurrency + backend_tree = backend.open_datatree( filename_or_obj, drop_variables=drop_variables, @@ -1265,8 +1348,12 @@ def open_groups( Additional keyword arguments passed on to the engine open function. For example: - - 'group': path to the group in the given file to open as the root group as - a str. + - 'group': path to the group in the given file to open as the root + group as a str. If the string contains glob metacharacters + (``*``, ``?``, ``[``), it is interpreted as a pattern and only + groups whose paths match are loaded (along with their ancestors). + For example, ``group="*/sweep_0"`` loads every ``sweep_0`` one + level deep while skipping sibling groups. - 'lock': resource lock to use when reading data from disk. Only relevant when using dask or another form of parallelism. By default, appropriate locks are chosen to safely read and write files with the diff --git a/xarray/backends/common.py b/xarray/backends/common.py index 753c777649f..665fcb4363b 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -248,6 +248,37 @@ def _iter_nc_groups(root, parent="/"): yield from _iter_nc_groups(group, parent=gpath) +def _is_glob_pattern(pattern: str) -> bool: + return any(c in pattern for c in "*?[") + + +def _filter_group_paths(group_paths: Iterable[str], pattern: str) -> list[str]: + from xarray.core.treenode import NodePath + + matched: set[str] = {"/"} + for path in group_paths: + np_ = NodePath(path) + if np_.match(pattern): + matched.add(path) + for parent in np_.parents: + p = str(parent) + if p: + matched.add(p) + + return [p for p in group_paths if p in matched] + + +def _resolve_group_and_filter( + group: str | None, + all_group_paths: list[str], +) -> tuple[str | None, list[str]]: + if group is None: + return None, all_group_paths + if _is_glob_pattern(group): + return None, _filter_group_paths(all_group_paths, group) + return group, all_group_paths + + def find_root_and_group(ds): """Find the root and group name of a netCDF4/h5netcdf dataset.""" hierarchy = () diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index 041aa49cf76..ab9ce6f2181 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -655,7 +655,11 @@ def open_groups_as_dict( open_kwargs: dict[str, Any] | None = None, **kwargs, ) -> dict[str, Dataset]: - from xarray.backends.common import _iter_nc_groups + from xarray.backends.common import ( + _is_glob_pattern, + _iter_nc_groups, + _resolve_group_and_filter, + ) from xarray.core.treenode import NodePath from xarray.core.utils import close_on_error @@ -664,10 +668,12 @@ def open_groups_as_dict( emit_phony_dims_warning, phony_dims = _check_phony_dims(phony_dims) filename_or_obj = _normalize_filename_or_obj(filename_or_obj) + + effective_group = None if (group and _is_glob_pattern(group)) else group store = H5NetCDFStore.open( filename_or_obj, format=format, - group=group, + group=effective_group, lock=lock, invalid_netcdf=invalid_netcdf, phony_dims=phony_dims, @@ -678,15 +684,17 @@ def open_groups_as_dict( open_kwargs=open_kwargs, ) - # Check for a group and make it a parent if it exists - if group: - parent = NodePath("/") / NodePath(group) + if effective_group: + parent = NodePath("/") / NodePath(effective_group) else: parent = NodePath("/") manager = store._manager + all_group_paths = list(_iter_nc_groups(store.ds, parent=parent)) + _, filtered_paths = _resolve_group_and_filter(group, all_group_paths) + groups_dict = {} - for path_group in _iter_nc_groups(store.ds, parent=parent): + for path_group in filtered_paths: group_store = H5NetCDFStore(manager, group=path_group, **kwargs) store_entrypoint = StoreBackendEntrypoint() with close_on_error(group_store): @@ -701,7 +709,7 @@ def open_groups_as_dict( decode_timedelta=decode_timedelta, ) - if group: + if effective_group: group_name = str(NodePath(path_group).relative_to(parent)) else: group_name = str(NodePath(path_group)) diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 39dedd139c0..277f8c91a92 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -859,13 +859,19 @@ def open_groups_as_dict( autoclose=False, **kwargs, ) -> dict[str, Dataset]: - from xarray.backends.common import _iter_nc_groups + from xarray.backends.common import ( + _is_glob_pattern, + _iter_nc_groups, + _resolve_group_and_filter, + ) from xarray.core.treenode import NodePath filename_or_obj = _normalize_path(filename_or_obj) + + effective_group = None if (group and _is_glob_pattern(group)) else group store = NetCDF4DataStore.open( filename_or_obj, - group=group, + group=effective_group, format=format, clobber=clobber, diskless=diskless, @@ -875,15 +881,17 @@ def open_groups_as_dict( autoclose=autoclose, ) - # Check for a group and make it a parent if it exists - if group: - parent = NodePath("/") / NodePath(group) + if effective_group: + parent = NodePath("/") / NodePath(effective_group) else: parent = NodePath("/") manager = store._manager + all_group_paths = list(_iter_nc_groups(store.ds, parent=parent)) + _, filtered_paths = _resolve_group_and_filter(group, all_group_paths) + groups_dict = {} - for path_group in _iter_nc_groups(store.ds, parent=parent): + for path_group in filtered_paths: group_store = NetCDF4DataStore(manager, group=path_group, **kwargs) store_entrypoint = StoreBackendEntrypoint() with close_on_error(group_store): @@ -897,7 +905,7 @@ def open_groups_as_dict( use_cftime=use_cftime, decode_timedelta=decode_timedelta, ) - if group: + if effective_group: group_name = str(NodePath(path_group).relative_to(parent)) else: group_name = str(NodePath(path_group)) diff --git a/xarray/backends/store.py b/xarray/backends/store.py index 2c3cd42ae92..2e9d8838443 100644 --- a/xarray/backends/store.py +++ b/xarray/backends/store.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from collections.abc import Iterable from typing import TYPE_CHECKING @@ -72,5 +73,37 @@ def open_dataset( return ds + async def open_dataset_async( + self, + filename_or_obj: T_PathFileOrDataStore, + *, + mask_and_scale=True, + decode_times=True, + concat_characters=True, + decode_coords=True, + drop_variables: str | Iterable[str] | None = None, + set_indexes: bool = True, + use_cftime=None, + decode_timedelta=None, + ) -> Dataset: + """Async version of open_dataset. + + Offloads the entire open_dataset operation to a thread to avoid blocking + the event loop. This is necessary because decode_cf_variables can trigger + data reads (e.g., for time decoding) which may use synchronous I/O. + """ + return await asyncio.to_thread( + self.open_dataset, + filename_or_obj, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + set_indexes=set_indexes, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + BACKEND_ENTRYPOINTS["store"] = (None, StoreBackendEntrypoint) diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 6681673025c..2e196ec17cb 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -1,10 +1,12 @@ from __future__ import annotations +import asyncio import base64 import json import os import struct from collections.abc import Hashable, Iterable, Mapping +from concurrent.futures import ThreadPoolExecutor from typing import TYPE_CHECKING, Any, Literal, Self, cast import numpy as np @@ -624,6 +626,70 @@ class ZarrStore(AbstractWritableDataStore): "zarr_group", ) + @classmethod + def _create_stores_from_members( + cls, + group_members, + mode, + consolidate_on_close, + append_dim, + write_region, + safe_chunks, + write_empty, + close_store_on_close, + use_zarr_fill_value_as_mask, + align_chunks, + cache_members, + ): + """Create ZarrStore instances from group members dict. + + Parameters + ---------- + group_members : dict[str, zarr.Group] + Dictionary mapping group paths to zarr Group instances. + mode : str + Access mode. + consolidate_on_close : bool + Whether to consolidate on close. + append_dim : str | None + Dimension to append along. + write_region : dict | None + Region to write to. + safe_chunks : bool + Whether to check chunk compatibility. + write_empty : bool | None + Whether to write empty chunks. + close_store_on_close : bool + Whether to close store on close. + use_zarr_fill_value_as_mask : bool + Whether to use zarr fill value as mask. + align_chunks : bool + Whether to align chunks. + cache_members : bool + Whether to cache group members. + + Returns + ------- + dict[str, ZarrStore] + Dictionary mapping group paths to ZarrStore instances. + """ + return { + grp: cls( + group_store, + mode, + consolidate_on_close, + append_dim, + write_region, + safe_chunks, + write_empty, + close_store_on_close, + use_zarr_fill_value_as_mask, + align_chunks=align_chunks, + cache_members=cache_members, + ) + for grp, group_store in group_members.items() + } + @classmethod def open_store( cls, @@ -664,20 +730,125 @@ def open_store( zarr_format=zarr_format, ) - from zarr import Group - - group_members: dict[str, Group] = {} group_paths = list(_iter_zarr_groups(zarr_group, parent=group)) - for path in group_paths: - if path == group: - group_members[path] = zarr_group - else: - rel_path = path.removeprefix(f"{group}/") - group_members[path] = zarr_group[rel_path.removeprefix("/")] + group_members = _build_group_members(zarr_group, group_paths, group) - out = { - group: cls( - group_store, + return cls._create_stores_from_members( + group_members, + mode, + consolidate_on_close, + append_dim, + write_region, + safe_chunks, + write_empty, + close_store_on_close, + use_zarr_fill_value_as_mask, + align_chunks, + cache_members, + ) + + @classmethod + async def open_store_async( + cls, + store, + mode: ZarrWriteModes = "r", + synchronizer=None, + group=None, + consolidated=False, + consolidate_on_close=False, + chunk_store=None, + storage_options=None, + append_dim=None, + write_region=None, + safe_chunks=True, + align_chunks=False, + zarr_version=None, + zarr_format=None, + use_zarr_fill_value_as_mask=None, + write_empty: bool | None = None, + cache_members: bool = False, + ): + """Async version of open_store using flat group discovery. + + This method uses store.list() to discover all groups in a single + async call, which is significantly faster than recursive traversal + for stores that support listing (like icechunk). + + Parameters + ---------- + store : MutableMapping, str, or path-like + Store or path to directory in file system or name of zip file. + mode : {"r", "r+", "a", "a-", "w", "w-"}, optional + Access mode for the store. + synchronizer : object, optional + Zarr synchronizer. + group : str, optional + Group path. + consolidated : bool, optional + Whether to use consolidated metadata. + consolidate_on_close : bool, optional + Whether to consolidate metadata on close. + chunk_store : MutableMapping, optional + Separate chunk store. + storage_options : dict, optional + Storage options for fsspec. + append_dim : str, optional + Dimension to append along. + write_region : dict, optional + Region to write to. + safe_chunks : bool, optional + Whether to check chunk compatibility. + align_chunks : bool, optional + Whether to align chunks. + zarr_version : int, optional + Deprecated, use zarr_format. + zarr_format : int, optional + Zarr format version (2 or 3). + use_zarr_fill_value_as_mask : bool, optional + Whether to use zarr fill value as mask. + write_empty : bool, optional + Whether to write empty chunks. + cache_members : bool, optional + Whether to cache group members. Defaults to False for async + to avoid synchronous member fetching overhead. + + Returns + ------- + dict[str, ZarrStore] + Dictionary mapping group paths to ZarrStore instances. + """ + loop = asyncio.get_running_loop() + + # Run on executor to avoid reentrant sync() deadlock + ( + zarr_group, + consolidate_on_close, + close_store_on_close, + use_zarr_fill_value_as_mask, + ) = await loop.run_in_executor( + None, + lambda: _get_open_params( + store=store, + mode=mode, + synchronizer=synchronizer, + group=group, + consolidated=consolidated, + consolidate_on_close=consolidate_on_close, + chunk_store=chunk_store, + storage_options=storage_options, + zarr_version=zarr_version, + use_zarr_fill_value_as_mask=use_zarr_fill_value_as_mask, + zarr_format=zarr_format, + ), + ) + + group_paths = await _iter_zarr_groups_async(zarr_group, parent=group) + + # Run on executor to avoid reentrant sync() deadlock + def _build_and_create(): + group_members = _build_group_members(zarr_group, group_paths, group) + return cls._create_stores_from_members( + group_members, mode, consolidate_on_close, append_dim, @@ -686,12 +857,11 @@ def open_store( write_empty, close_store_on_close, use_zarr_fill_value_as_mask, - align_chunks=align_chunks, - cache_members=cache_members, + align_chunks, + cache_members, ) - for group, group_store in group_members.items() - } - return out + + return await loop.run_in_executor(None, _build_and_create) @classmethod def open_group( @@ -1728,27 +1898,234 @@ def open_datatree( storage_options=None, zarr_version=None, zarr_format=None, + max_concurrency: int | None = None, ) -> DataTree: + from xarray.backends.common import _is_glob_pattern, _resolve_group_and_filter + filename_or_obj = _normalize_path(filename_or_obj) - groups_dict = self.open_groups_as_dict( - filename_or_obj=filename_or_obj, - mask_and_scale=mask_and_scale, - decode_times=decode_times, - concat_characters=concat_characters, - decode_coords=decode_coords, - drop_variables=drop_variables, - use_cftime=use_cftime, - decode_timedelta=decode_timedelta, - group=group, - mode=mode, - synchronizer=synchronizer, - consolidated=consolidated, - chunk_store=chunk_store, - storage_options=storage_options, - zarr_version=zarr_version, - zarr_format=zarr_format, + + effective_group = None if (group and _is_glob_pattern(group)) else group + if effective_group: + parent = str(NodePath("/") / NodePath(effective_group)) + else: + parent = str(NodePath("/")) + + # Use async path for zarr v3 (concurrent loading), sync path for zarr v2 + if _zarr_v3(): + from zarr.core.sync import sync as zarr_sync + + # Sync call — safe outside zarr's IO loop + ( + zarr_group, + consolidate_on_close, + close_store_on_close, + use_zarr_fill_value_as_mask, + ) = _get_open_params( + store=filename_or_obj, + mode=mode, + synchronizer=synchronizer, + group=parent, + consolidated=consolidated, + consolidate_on_close=False, + chunk_store=chunk_store, + storage_options=storage_options, + zarr_version=zarr_version, + use_zarr_fill_value_as_mask=None, + zarr_format=zarr_format, + ) + + return zarr_sync( + self._open_datatree_from_stores_async( + zarr_group=zarr_group, + parent=parent, + group=group, + mode=mode, + consolidate_on_close=consolidate_on_close, + close_store_on_close=close_store_on_close, + use_zarr_fill_value_as_mask=use_zarr_fill_value_as_mask, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + max_concurrency=max_concurrency, + ) + ) + else: + # Fallback for zarr v2: sequential loading + stores = ZarrStore.open_store( + filename_or_obj, + group=parent, + mode=mode, + synchronizer=synchronizer, + consolidated=consolidated, + consolidate_on_close=False, + chunk_store=chunk_store, + storage_options=storage_options, + zarr_version=zarr_version, + zarr_format=zarr_format, + ) + all_paths = list(stores.keys()) + _, filtered_paths = _resolve_group_and_filter(group, all_paths) + groups_dict = {} + for path_group in filtered_paths: + store = stores[path_group] + store_entrypoint = StoreBackendEntrypoint() + with close_on_error(store): + group_ds = store_entrypoint.open_dataset( + store, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + if effective_group: + group_name = str(NodePath(path_group).relative_to(parent)) + else: + group_name = str(NodePath(path_group)) + groups_dict[group_name] = group_ds + return datatree_from_dict_with_io_cleanup(groups_dict) + + async def _open_datatree_from_stores_async( + self, + zarr_group, + parent: str, + group: str | None, + *, + mode, + consolidate_on_close, + close_store_on_close, + use_zarr_fill_value_as_mask, + mask_and_scale=True, + decode_times=True, + concat_characters=True, + decode_coords=True, + drop_variables: str | Iterable[str] | None = None, + use_cftime=None, + decode_timedelta=None, + max_concurrency: int | None = None, + ) -> DataTree: + """Open datatree using native async for I/O, threads only for CPU decode.""" + if max_concurrency is None: + max_concurrency = 10 + + loop = asyncio.get_running_loop() + executor = ThreadPoolExecutor( + max_workers=max_concurrency, thread_name_prefix="xarray-cpu" ) + from zarr import Array as ZarrSyncArray + from zarr import Group as ZarrSyncGroup + from zarr.core.group import AsyncGroup as ZarrAsyncGroup + + async_root = zarr_group._async_group + parent_nodepath = NodePath(parent) + + # Phase 1: Walk tree, collect groups + per-group members in one async pass. + # This replaces both _iter_zarr_groups_async and per-group _fetch_members. + group_async: dict[str, ZarrAsyncGroup] = { + str(parent_nodepath): async_root, + } + group_children: dict[str, dict] = {str(parent_nodepath): {}} + + async for rel_path, member in async_root.members(max_depth=None): + full_path = str(parent_nodepath / rel_path) + + if isinstance(member, ZarrAsyncGroup): + group_async[full_path] = member + group_children[full_path] = {} + + parts = rel_path.rsplit("/", 1) + child_name = parts[-1] + parent_rel = parts[0] if len(parts) > 1 else "" + parent_path = ( + str(parent_nodepath / parent_rel) + if parent_rel + else str(parent_nodepath) + ) + if parent_path in group_children: + group_children[parent_path][child_name] = member + + # Filter groups when glob pattern is used + from xarray.backends.common import _resolve_group_and_filter + + effective_group, filtered_paths = _resolve_group_and_filter( + group, list(group_async.keys()) + ) + filtered_set = set(filtered_paths) + group_async = {k: v for k, v in group_async.items() if k in filtered_set} + group_children = {k: v for k, v in group_children.items() if k in filtered_set} + + # Phase 2: Open each group — wrap async objects, run CPU decode in threads. + async def open_one(path_group: str) -> tuple[str, Dataset]: + async_grp = group_async[path_group] + children = group_children.get(path_group, {}) + + def _cpu_open(): + sync_group = ZarrSyncGroup(async_grp) + sync_members = { + name: ( + ZarrSyncGroup(child) + if isinstance(child, ZarrAsyncGroup) + else ZarrSyncArray(child) + ) + for name, child in children.items() + } + + store = ZarrStore( + sync_group, + mode, + consolidate_on_close, + append_dim=None, + write_region=None, + safe_chunks=True, + write_empty=None, + close_store_on_close=close_store_on_close, + use_zarr_fill_value_as_mask=use_zarr_fill_value_as_mask, + align_chunks=False, + cache_members=False, + ) + store._members = sync_members + store._cache_members = True + + store_entrypoint = StoreBackendEntrypoint() + with close_on_error(store): + return store_entrypoint.open_dataset( + store, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + + ds = await loop.run_in_executor(executor, _cpu_open) + if effective_group: + group_name = str(NodePath(path_group).relative_to(parent)) + else: + group_name = str(NodePath(path_group)) + return group_name, ds + + groups_dict: dict[str, Dataset] = {} + + async def collect_result(path_group: str) -> None: + group_name, ds = await open_one(path_group) + groups_dict[group_name] = ds + + try: + async with asyncio.TaskGroup() as tg: + for path_group in sorted(group_async.keys()): + tg.create_task(collect_result(path_group)) + finally: + executor.shutdown(wait=True, cancel_futures=True) + return datatree_from_dict_with_io_cleanup(groups_dict) def open_groups_as_dict( @@ -1771,11 +2148,13 @@ def open_groups_as_dict( zarr_version=None, zarr_format=None, ) -> dict[str, Dataset]: + from xarray.backends.common import _is_glob_pattern, _resolve_group_and_filter + filename_or_obj = _normalize_path(filename_or_obj) - # Check for a group and make it a parent if it exists - if group: - parent = str(NodePath("/") / NodePath(group)) + effective_group = None if (group and _is_glob_pattern(group)) else group + if effective_group: + parent = str(NodePath("/") / NodePath(effective_group)) else: parent = str(NodePath("/")) @@ -1792,8 +2171,11 @@ def open_groups_as_dict( zarr_format=zarr_format, ) + _, filtered_paths = _resolve_group_and_filter(group, list(stores.keys())) + groups_dict = {} - for path_group, store in stores.items(): + for path_group in filtered_paths: + store = stores[path_group] store_entrypoint = StoreBackendEntrypoint() with close_on_error(store): @@ -1807,13 +2189,122 @@ def open_groups_as_dict( use_cftime=use_cftime, decode_timedelta=decode_timedelta, ) - if group: + if effective_group: group_name = str(NodePath(path_group).relative_to(parent)) else: group_name = str(NodePath(path_group)) groups_dict[group_name] = group_ds return groups_dict + async def open_groups_as_dict_async( + self, + filename_or_obj: T_PathFileOrDataStore, + *, + mask_and_scale=True, + decode_times=True, + concat_characters=True, + decode_coords=True, + drop_variables: str | Iterable[str] | None = None, + use_cftime=None, + decode_timedelta=None, + group: str | None = None, + mode="r", + synchronizer=None, + consolidated=None, + chunk_store=None, + storage_options=None, + zarr_version=None, + zarr_format=None, + ) -> dict[str, Dataset]: + """Asynchronously open each group into a Dataset concurrently. + + This mirrors open_groups_as_dict but parallelizes per-group Dataset opening, + which can significantly reduce latency on high-RTT object stores. + """ + from xarray.backends.common import _is_glob_pattern, _resolve_group_and_filter + + filename_or_obj = _normalize_path(filename_or_obj) + + effective_group = None if (group and _is_glob_pattern(group)) else group + if effective_group: + parent = str(NodePath("/") / NodePath(effective_group)) + else: + parent = str(NodePath("/")) + + stores = await ZarrStore.open_store_async( + filename_or_obj, + group=parent, + mode=mode, + synchronizer=synchronizer, + consolidated=consolidated, + consolidate_on_close=False, + chunk_store=chunk_store, + storage_options=storage_options, + zarr_version=zarr_version, + zarr_format=zarr_format, + ) + + _, filtered_paths = _resolve_group_and_filter(group, list(stores.keys())) + filtered_set = set(filtered_paths) + + loop = asyncio.get_running_loop() + max_workers = min(len(stores), 10) if stores else 1 + executor = ThreadPoolExecutor( + max_workers=max_workers, thread_name_prefix="xarray" + ) + + async def open_one(path_group: str, store) -> tuple[str, Dataset]: + store_entrypoint = StoreBackendEntrypoint() + + def _load_sync(): + with close_on_error(store): + return store_entrypoint.open_dataset( + store, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + + ds = await loop.run_in_executor(executor, _load_sync) + if effective_group: + group_name = str(NodePath(path_group).relative_to(parent)) + else: + group_name = str(NodePath(path_group)) + return group_name, ds + + try: + tasks = [ + open_one(path_group, store) + for path_group, store in stores.items() + if path_group in filtered_set + ] + results = await asyncio.gather(*tasks) + finally: + executor.shutdown(wait=True, cancel_futures=True) + return dict(results) + + +def _build_group_members( + zarr_group: ZarrGroup, + group_paths: list[str], + parent: str | None, +) -> dict[str, ZarrGroup]: + parent = parent or "/" + group_members: dict[str, ZarrGroup] = {} + + for path in group_paths: + if path == parent: + group_members[path] = zarr_group + else: + rel_path = path.removeprefix(f"{parent}/").removeprefix("/") + group_members[path] = cast("ZarrGroup", zarr_group[rel_path]) + + return group_members + def _iter_zarr_groups(root: ZarrGroup, parent: str = "/") -> Iterable[str]: parent_nodepath = NodePath(parent) @@ -1823,6 +2314,24 @@ def _iter_zarr_groups(root: ZarrGroup, parent: str = "/") -> Iterable[str]: yield from _iter_zarr_groups(group, parent=str(gpath)) +async def _iter_zarr_groups_async(root: ZarrGroup, parent: str = "/") -> list[str]: + try: + from zarr.core.group import AsyncGroup + except (ImportError, ModuleNotFoundError): + # zarr v2: no async group support, fall back to sync + return list(_iter_zarr_groups(root, parent=parent)) + + parent_nodepath = NodePath(parent) + group_paths = [str(parent_nodepath)] + + async_group = root._async_group + async for name, member in async_group.members(max_depth=None): + if isinstance(member, AsyncGroup): + group_paths.append(str(parent_nodepath / name)) + + return sorted(group_paths) + + def _get_open_params( store, mode, diff --git a/xarray/tests/test_backends_datatree.py b/xarray/tests/test_backends_datatree.py index 32f224e89a6..1ae00546616 100644 --- a/xarray/tests/test_backends_datatree.py +++ b/xarray/tests/test_backends_datatree.py @@ -377,6 +377,79 @@ def test_open_datatree_specific_group(self, tmpdir, simple_datatree) -> None: assert subgroup_tree.root.parent is None assert_equal(subgroup_tree, expected_subtree) + def test_open_datatree_group_glob(self, tmpdir) -> None: + original_dt = DataTree.from_dict( + { + "/": xr.Dataset({"root_var": 1}), + "/A": xr.Dataset({"a_var": 2}), + "/A/sweep_0": xr.Dataset({"data": ("x", [1, 2])}), + "/A/sweep_1": xr.Dataset({"data": ("x", [3, 4])}), + "/B": xr.Dataset({"b_var": 3}), + "/B/sweep_0": xr.Dataset({"data": ("x", [5, 6])}), + } + ) + filepath = tmpdir / "glob_test.nc" + original_dt.to_netcdf(filepath, engine=self.engine) + + with open_datatree(filepath, group="*/sweep_0", engine=self.engine) as tree: + paths = {node.path for node in tree.subtree} + assert "/A/sweep_0" in paths + assert "/B/sweep_0" in paths + assert "/A/sweep_1" not in paths + + def test_open_datatree_group_glob_no_match(self, tmpdir) -> None: + original_dt = DataTree.from_dict( + { + "/": xr.Dataset({"root_var": 1}), + "/A": xr.Dataset({"a_var": 2}), + } + ) + filepath = tmpdir / "glob_nomatch.nc" + original_dt.to_netcdf(filepath, engine=self.engine) + + with open_datatree(filepath, group="*/nonexistent", engine=self.engine) as tree: + paths = {node.path for node in tree.subtree} + assert paths == {"/"} + + def test_open_datatree_group_glob_preserves_data(self, tmpdir) -> None: + original_dt = DataTree.from_dict( + { + "/": xr.Dataset({"root_var": 1}), + "/A": xr.Dataset({"a_var": 2}), + "/A/sweep_0": xr.Dataset({"data": ("x", [1, 2])}), + } + ) + filepath = tmpdir / "glob_data.nc" + original_dt.to_netcdf(filepath, engine=self.engine) + + with open_datatree(filepath, group="*/sweep_0", engine=self.engine) as tree: + assert tree["/A"].dataset["a_var"].item() == 2 + np.testing.assert_array_equal( + tree["/A/sweep_0"].dataset["data"].values, [1, 2] + ) + + def test_open_groups_group_glob(self, tmpdir) -> None: + original_dt = DataTree.from_dict( + { + "/": xr.Dataset({"root_var": 1}), + "/A": xr.Dataset({"a_var": 2}), + "/A/sweep_0": xr.Dataset({"data": ("x", [1, 2])}), + "/A/sweep_1": xr.Dataset({"data": ("x", [3, 4])}), + } + ) + filepath = tmpdir / "glob_groups.nc" + original_dt.to_netcdf(filepath, engine=self.engine) + + groups = open_groups(filepath, group="*/sweep_0", engine=self.engine) + try: + assert "/" in groups + assert "/A" in groups + assert "/A/sweep_0" in groups + assert "/A/sweep_1" not in groups + finally: + for ds in groups.values(): + ds.close() + @requires_h5netcdf_or_netCDF4 class TestGenericNetCDFIO(NetCDFIOBase): @@ -1025,6 +1098,62 @@ def test_open_datatree_specific_group( assert subgroup_tree.root.parent is None assert_equal(subgroup_tree, expected_subtree) + def test_open_datatree_group_glob(self, tmpdir, zarr_format) -> None: + original_dt = DataTree.from_dict( + { + "/": xr.Dataset({"root_var": 1}), + "/A": xr.Dataset({"a_var": 2}), + "/A/sweep_0": xr.Dataset({"data": ("x", [1, 2])}), + "/A/sweep_1": xr.Dataset({"data": ("x", [3, 4])}), + "/B": xr.Dataset({"b_var": 3}), + "/B/sweep_0": xr.Dataset({"data": ("x", [5, 6])}), + } + ) + filepath = str(tmpdir / "glob_test.zarr") + original_dt.to_zarr(filepath, zarr_format=zarr_format) + + with open_datatree(filepath, group="*/sweep_0", engine=self.engine) as tree: + paths = {node.path for node in tree.subtree} + assert "/A/sweep_0" in paths + assert "/B/sweep_0" in paths + assert "/A/sweep_1" not in paths + + def test_open_datatree_group_glob_no_match(self, tmpdir, zarr_format) -> None: + original_dt = DataTree.from_dict( + { + "/": xr.Dataset({"root_var": 1}), + "/A": xr.Dataset({"a_var": 2}), + } + ) + filepath = str(tmpdir / "glob_nomatch.zarr") + original_dt.to_zarr(filepath, zarr_format=zarr_format) + + with open_datatree(filepath, group="*/nonexistent", engine=self.engine) as tree: + paths = {node.path for node in tree.subtree} + assert paths == {"/"} + + def test_open_groups_group_glob(self, tmpdir, zarr_format) -> None: + original_dt = DataTree.from_dict( + { + "/": xr.Dataset({"root_var": 1}), + "/A": xr.Dataset({"a_var": 2}), + "/A/sweep_0": xr.Dataset({"data": ("x", [1, 2])}), + "/A/sweep_1": xr.Dataset({"data": ("x", [3, 4])}), + } + ) + filepath = str(tmpdir / "glob_groups.zarr") + original_dt.to_zarr(filepath, zarr_format=zarr_format) + + groups = open_groups(filepath, group="*/sweep_0", engine=self.engine) + try: + assert "/" in groups + assert "/A" in groups + assert "/A/sweep_0" in groups + assert "/A/sweep_1" not in groups + finally: + for ds in groups.values(): + ds.close() + @requires_dask def test_open_groups_chunks(self, tmpdir, zarr_format) -> None: """Test `open_groups` with chunks on a zarr store.""" @@ -1142,3 +1271,52 @@ def test_zarr_engine_recognised(self, tmpdir, zarr_format) -> None: with open_datatree(filepath) as roundtrip_dt: assert_identical(original_dt, roundtrip_dt) + + +class TestGlobPatternUtilities: + def test_is_glob_pattern(self) -> None: + from xarray.backends.common import _is_glob_pattern + + assert _is_glob_pattern("*/sweep_0") + assert _is_glob_pattern("VCP-34/sweep_[01]") + assert _is_glob_pattern("sweep_?") + assert not _is_glob_pattern("VCP-34") + assert not _is_glob_pattern("/group/subgroup") + + def test_filter_group_paths(self) -> None: + from xarray.backends.common import _filter_group_paths + + paths = ["/", "/A", "/A/sweep_0", "/A/sweep_1", "/B", "/B/sweep_0"] + result = _filter_group_paths(paths, "*/sweep_0") + assert result == ["/", "/A", "/A/sweep_0", "/B", "/B/sweep_0"] + + def test_filter_group_paths_no_match(self) -> None: + from xarray.backends.common import _filter_group_paths + + paths = ["/", "/A", "/B"] + result = _filter_group_paths(paths, "*/nonexistent") + assert result == ["/"] + + def test_resolve_group_and_filter_none(self) -> None: + from xarray.backends.common import _resolve_group_and_filter + + paths = ["/", "/A"] + effective, filtered = _resolve_group_and_filter(None, paths) + assert effective is None + assert filtered == paths + + def test_resolve_group_and_filter_literal(self) -> None: + from xarray.backends.common import _resolve_group_and_filter + + paths = ["/", "/A"] + effective, filtered = _resolve_group_and_filter("A", paths) + assert effective == "A" + assert filtered == paths + + def test_resolve_group_and_filter_glob(self) -> None: + from xarray.backends.common import _resolve_group_and_filter + + paths = ["/", "/A", "/A/sweep_0", "/A/sweep_1", "/B", "/B/sweep_0"] + effective, filtered = _resolve_group_and_filter("*/sweep_0", paths) + assert effective is None + assert filtered == ["/", "/A", "/A/sweep_0", "/B", "/B/sweep_0"] diff --git a/xarray/tests/test_backends_zarr_async.py b/xarray/tests/test_backends_zarr_async.py new file mode 100644 index 00000000000..cea1128ca92 --- /dev/null +++ b/xarray/tests/test_backends_zarr_async.py @@ -0,0 +1,259 @@ +"""Tests for internal asynchronous zarr group loading functionality.""" + +from __future__ import annotations + +import contextlib + +import numpy as np +import pytest + +import xarray as xr +from xarray.backends.api import _maybe_create_default_indexes_async +from xarray.backends.store import StoreBackendEntrypoint +from xarray.backends.zarr import ZarrBackendEntrypoint +from xarray.testing import assert_equal +from xarray.tests import ( + has_zarr_v3, + parametrize_zarr_format, + requires_zarr, + requires_zarr_v3, +) + +if has_zarr_v3: + from zarr.storage import MemoryStore + + +def create_dataset_with_coordinates(n_coords=5): + """Create a dataset with coordinate variables to trigger indexing.""" + coords = {} + for i in range(n_coords): + coords[f"coord_{i}"] = (f"coord_{i}", np.arange(3)) + + coord_names = list(coords.keys()) + data_vars = {} + + if len(coord_names) >= 2: + data_vars["temperature"] = (coord_names[:2], np.random.random((3, 3))) + if len(coord_names) >= 1: + data_vars["pressure"] = (coord_names[:1], np.random.random(3)) + + data_vars["simple"] = ([], np.array(42.0)) + + ds = xr.Dataset(data_vars=data_vars, coords=coords) + return ds + + +def create_test_datatree(n_groups=3, coords_per_group=5): + """Create a DataTree for testing with multiple groups.""" + root_ds = create_dataset_with_coordinates(coords_per_group) + tree_dict = {"/": root_ds} + + for i in range(n_groups): + group_name = f"/group_{i:03d}" + group_ds = create_dataset_with_coordinates(n_coords=coords_per_group) + tree_dict[group_name] = group_ds + + tree = xr.DataTree.from_dict(tree_dict) + return tree + + +@requires_zarr +class TestAsyncZarrGroupLoading: + """Tests for internal asynchronous zarr group loading functionality.""" + + @contextlib.contextmanager + def create_zarr_store(self): + """Create a zarr target for testing.""" + if has_zarr_v3: + with MemoryStore() as store: + yield store + else: + from zarr.storage import MemoryStore as V2MemoryStore + + store = V2MemoryStore() + yield store + + @parametrize_zarr_format + def test_sync_datatree_roundtrip_with_async_optimization(self, zarr_format): + """Test that sync open_datatree with internal async optimization preserves data integrity.""" + dtree = create_test_datatree(n_groups=3, coords_per_group=4) + + with self.create_zarr_store() as store: + dtree.to_zarr(store, mode="w", consolidated=False, zarr_format=zarr_format) + + # Use sync open_datatree which internally uses async for zarr v3 + dtree_loaded = xr.open_datatree( + store, + consolidated=False, + zarr_format=zarr_format, + create_default_indexes=True, + engine="zarr", + ) + assert_equal(dtree, dtree_loaded) + + @pytest.mark.asyncio + @requires_zarr + @parametrize_zarr_format + async def test_backend_open_groups_async_equivalence(self, zarr_format): + """Backend async group opening returns the same groups and datasets as sync.""" + dtree = create_test_datatree(n_groups=3, coords_per_group=4) + backend = ZarrBackendEntrypoint() + + with self.create_zarr_store() as store: + dtree.to_zarr(store, mode="w", consolidated=False, zarr_format=zarr_format) + + groups_sync = backend.open_groups_as_dict( + store, + consolidated=False, + zarr_format=zarr_format, + ) + + groups_async = await backend.open_groups_as_dict_async( + store, + consolidated=False, + zarr_format=zarr_format, + ) + + assert set(groups_sync.keys()) == set(groups_async.keys()) + for k in list(groups_sync.keys())[:2]: + assert_equal(groups_sync[k], groups_async[k]) + + @pytest.mark.asyncio + async def test_maybe_create_default_indexes_async_no_coords_needing_indexes(self): + """Test _maybe_create_default_indexes_async with no coordinates needing indexes.""" + ds = xr.Dataset( + { + "temperature": (("x", "y"), np.random.random((3, 4))), + } + ) + + result = await _maybe_create_default_indexes_async(ds) + assert_equal(ds, result) + assert len(result.xindexes) == 0 + + @pytest.mark.asyncio + async def test_maybe_create_default_indexes_async_creates_indexes(self): + """Test _maybe_create_default_indexes_async creates indexes for coordinate variables.""" + coords = {"time": ("time", np.arange(5)), "x": ("x", np.arange(3))} + data_vars = { + "temperature": (("time", "x"), np.random.random((5, 3))), + } + ds = xr.Dataset(data_vars, coords) + ds_no_indexes = ds.drop_indexes(["time", "x"]) + + assert len(ds_no_indexes.xindexes) == 0 + + result = await _maybe_create_default_indexes_async(ds_no_indexes) + + assert "time" in result.xindexes + assert "x" in result.xindexes + assert len(result.xindexes) == 2 + + @pytest.mark.asyncio + async def test_maybe_create_default_indexes_async_partial_indexes(self): + """Test with mix of coords that need indexes and those that don't.""" + coords = { + "time": ("time", np.arange(5)), + "x": ("x", np.arange(3)), + } + data_vars = { + "temperature": (("time", "x"), np.random.random((5, 3))), + } + ds = xr.Dataset(data_vars=data_vars, coords=coords) + ds_partial = ds.drop_indexes(["x"]) + + assert "time" in ds_partial.xindexes + assert "x" not in ds_partial.xindexes + + result = await _maybe_create_default_indexes_async(ds_partial) + + assert "time" in result.xindexes + assert "x" in result.xindexes + + @pytest.mark.asyncio + async def test_maybe_create_default_indexes_async_all_indexes_exist(self): + """Test that function returns original dataset when all coords already have indexes.""" + ds = create_dataset_with_coordinates(n_coords=2) + + assert len(ds.xindexes) > 0 + + result = await _maybe_create_default_indexes_async(ds) + assert result is ds # Same object returned + + @parametrize_zarr_format + def test_sync_open_datatree_source_encoding(self, zarr_format): + """Test that open_datatree sets source encoding correctly.""" + import os + import tempfile + + dtree = create_test_datatree(n_groups=2, coords_per_group=2) + + with tempfile.TemporaryDirectory() as tmpdir: + store_path = os.path.join(tmpdir, "test.zarr") + dtree.to_zarr( + store_path, mode="w", consolidated=False, zarr_format=zarr_format + ) + + dtree_loaded = xr.open_datatree( + store_path, + consolidated=False, + zarr_format=zarr_format, + engine="zarr", + ) + assert "source" in dtree_loaded.encoding + # Normalize paths for cross-platform comparison + source = os.path.normpath(dtree_loaded.encoding["source"]) + expected = os.path.normpath(store_path) + assert expected in source + + @requires_zarr_v3 + @parametrize_zarr_format + def test_sync_open_datatree_uses_async_internally(self, zarr_format): + """Test that sync open_datatree uses async index creation for zarr v3.""" + from unittest.mock import patch + + dtree = create_test_datatree(n_groups=2, coords_per_group=3) + + with self.create_zarr_store() as store: + dtree.to_zarr(store, mode="w", consolidated=False, zarr_format=zarr_format) + + # Patch the async index creation function to verify it's called + with patch( + "xarray.backends.api._maybe_create_default_indexes_async", + wraps=_maybe_create_default_indexes_async, + ) as mock_async: + dtree_loaded = xr.open_datatree( + store, + consolidated=False, + zarr_format=zarr_format, + create_default_indexes=True, + engine="zarr", + ) + + # For zarr v3, the async function should be called + assert mock_async.call_count > 0 + assert_equal(dtree, dtree_loaded) + + @pytest.mark.asyncio + @requires_zarr_v3 + @parametrize_zarr_format + async def test_store_backend_open_dataset_async_equivalence(self, zarr_format): + """Test that StoreBackendEntrypoint.open_dataset_async returns same result as sync.""" + from xarray.backends.zarr import ZarrStore + + ds = create_dataset_with_coordinates(n_coords=3) + + with self.create_zarr_store() as store: + ds.to_zarr(store, mode="w", consolidated=False, zarr_format=zarr_format) + + zarr_store = ZarrStore.open_group( + store, + consolidated=False, + zarr_format=zarr_format, + ) + + store_entrypoint = StoreBackendEntrypoint() + ds_sync = store_entrypoint.open_dataset(zarr_store) + ds_async = await store_entrypoint.open_dataset_async(zarr_store) + + assert_equal(ds_sync, ds_async)