diff --git a/pyproject.toml b/pyproject.toml index 7ff0e75a05b..bed4a5bd065 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ accel = [ complete = ["xarray[accel,etc,io,parallel,viz]"] io = [ "netCDF4>=1.6.0", - "h5netcdf>=1.4.0", + "h5netcdf[h5py]>=1.4.0", "pydap", "scipy>=1.13", "zarr>=2.18", @@ -81,6 +81,7 @@ dev = [ "pytest-timeout", "pytest-xdist", "pytest-asyncio", + "pytz", "ruff>=0.8.0", "sphinx", "sphinx_autosummary_accessors", diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 1732d567503..ab462811e8f 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import base64 import json import os @@ -43,7 +44,7 @@ if TYPE_CHECKING: from xarray.core.dataset import Dataset from xarray.core.datatree import DataTree - from xarray.core.types import ZarrArray, ZarrGroup + from xarray.core.types import ZarrArray, ZarrAsyncArray, ZarrAsyncGroup, ZarrGroup def _get_mappers(*, storage_options, store, chunk_store): @@ -603,6 +604,16 @@ def _put_attrs(zarr_obj, attrs): return zarr_obj +async def _put_attrs_async(zarr_async_obj, attrs): + """Async version of _put_attrs using zarr's asynchronous API.""" + try: + zarr_async_obj.attrs.clear() + await zarr_async_obj.update_attributes(attrs) + except TypeError as e: + raise TypeError("Invalid attribute in Dataset.attrs.") from e + return zarr_async_obj + + class ZarrStore(AbstractWritableDataStore): """Store for reading and writing data via zarr""" @@ -820,6 +831,14 @@ def members(self) -> dict[str, ZarrArray | ZarrGroup]: else: return self._members + async def _members_async( + self, + ) -> dict[str, ZarrAsyncArray | ZarrAsyncGroup | ZarrArray | ZarrGroup]: + if not self._cache_members: + return await self._fetch_members_async() + else: + return self._members + def _fetch_members(self) -> dict[str, ZarrArray | ZarrGroup]: """ Get the arrays and groups defined in the zarr group modelled by this Store @@ -831,6 
+850,12 @@ def _fetch_members(self) -> dict[str, ZarrArray | ZarrGroup]: else: return dict(self.zarr_group.items()) + async def _fetch_members_async(self) -> dict[str, ZarrAsyncArray | ZarrAsyncGroup]: + members = [] + async for key, node in self._zarr_async_group().members(): + members.append((key, node)) + return dict(members) + def array_keys(self) -> tuple[str, ...]: from zarr import Array as ZarrArray @@ -838,6 +863,16 @@ def array_keys(self) -> tuple[str, ...]: key for (key, node) in self.members.items() if isinstance(node, ZarrArray) ) + async def _array_keys_async(self) -> tuple[str, ...]: + from zarr import Array as ZarrArray + from zarr import AsyncArray as ZarrAsyncArray + + return tuple( + key + for (key, node) in (await self._members_async()).items() + if isinstance(node, (ZarrAsyncArray, ZarrArray)) + ) + def arrays(self) -> tuple[tuple[str, ZarrArray], ...]: from zarr import Array as ZarrArray @@ -1070,7 +1105,23 @@ def store( kwargs["zarr_format"] = self.zarr_group.metadata.zarr_format zarr.consolidate_metadata(self.zarr_group.store, **kwargs) + def _zarr_async_group(self) -> ZarrAsyncGroup: + # Using private zarr attribute + return self.zarr_group._async_group + # Equivalent with public API + # return AsyncGroup(self.zarr_group.metadata, self.zarr_group.store_path) + def _open_existing_array(self, *, name) -> ZarrArray: + if _zarr_v3(): + from zarr import Array as ZarrArray + from zarr.api.synchronous import sync + + zarr_async_array = sync(self._open_existing_array_async(name=name)) + return ZarrArray(zarr_async_array) + return self._open_existing_array_sync_v2(name=name) + + def _open_existing_array_sync_v2(self, *, name) -> ZarrArray: + """Sync open of an existing array (zarr v2 only).""" import zarr from zarr import Array as ZarrArray @@ -1094,15 +1145,10 @@ def _open_existing_array(self, *, name) -> ZarrArray: # - The size of dimensions can not be expanded, that would require a call using `append_dim` # which is mutually exclusive with 
`region` empty: dict[str, bool] | dict[str, dict[str, bool]] - if _zarr_v3(): - empty = dict(config={"write_empty_chunks": self._write_empty}) - else: - empty = dict(write_empty_chunks=self._write_empty) + empty = dict(write_empty_chunks=self._write_empty) zarr_array = zarr.open( - store=( - self.zarr_group.store if _zarr_v3() else self.zarr_group.chunk_store - ), + store=self.zarr_group.chunk_store, # TODO: see if zarr should normalize these strings. path="/".join([self.zarr_group.name.rstrip("/"), name]).lstrip("/"), **empty, @@ -1112,9 +1158,77 @@ def _open_existing_array(self, *, name) -> ZarrArray: return cast(ZarrArray, zarr_array) + async def _open_existing_array_async(self, *, name) -> ZarrAsyncArray: + """Async version of _open_existing_array using zarr's asynchronous API.""" + + from zarr import AsyncArray as ZarrAsyncArray + + # TODO: if mode="a", consider overriding the existing variable + # metadata. This would need some case work properly with region + # and append_dim. + if self._write_empty is not None: + # Write to zarr_group.chunk_store instead of zarr_group.store + # See https://github.com/pydata/xarray/pull/8326#discussion_r1365311316 for a longer explanation + # The open_consolidated() enforces a mode of r or r+ + # (and to_zarr with region provided enforces a read mode of r+), + # and this function makes sure the resulting Group has a store of type ConsolidatedMetadataStore + # and a 'normal' Store subtype for chunk_store. + # The exact type depends on if a local path was used, or a URL of some sort, + # but the point is that it's not a read-only ConsolidatedMetadataStore. + # It is safe to write chunk data to the chunk_store because no metadata would be changed by + # to_zarr with the region parameter: + # - Because the write mode is enforced to be r+, no new variables can be added to the store + # (this is also checked and enforced in xarray.backends.api.py::to_zarr()). 
+ # - Existing variables already have their attrs included in the consolidated metadata file. + # - The size of dimensions can not be expanded, that would require a call using `append_dim` + # which is mutually exclusive with `region` + empty: dict[str, bool] | dict[str, dict[str, bool]] + empty = dict(config={"write_empty_chunks": self._write_empty}) + + from zarr.api.asynchronous import open as zarr_async_open + + zarr_async_array = await zarr_async_open( + store=(self.zarr_group.store), + # TODO: see if zarr should normalize these strings. + path="/".join([self.zarr_group.name.rstrip("/"), name]).lstrip("/"), + **empty, + ) + else: + zarr_async_array = await self._zarr_async_group().getitem(name) + + return cast(ZarrAsyncArray, zarr_async_array) + def _create_new_array( self, *, name, shape, dtype, fill_value, encoding, attrs ) -> ZarrArray: + if _zarr_v3(): + from zarr import Array as ZarrArray + from zarr.api.synchronous import sync + + zarr_async_array = sync( + self._create_new_array_async( + name=name, + shape=shape, + dtype=dtype, + fill_value=fill_value, + encoding=encoding, + attrs=attrs, + ) + ) + return ZarrArray(zarr_async_array) + return self._create_new_array_sync_v2( + name=name, + shape=shape, + dtype=dtype, + fill_value=fill_value, + encoding=encoding, + attrs=attrs, + ) + + def _create_new_array_sync_v2( + self, *, name, shape, dtype, fill_value, encoding, attrs + ) -> ZarrArray: + """Sync create of a new array (zarr v2 only).""" if coding.strings.check_vlen_dtype(dtype) is str: dtype = str @@ -1130,14 +1244,6 @@ def _create_new_array( else: encoding["write_empty_chunks"] = self._write_empty - if _zarr_v3(): - # zarr v3 deprecated origin and write_empty_chunks - # instead preferring to pass them via the config argument - encoding["config"] = {} - for c in ("write_empty_chunks", "order"): - if c in encoding: - encoding["config"][c] = encoding.pop(c) - zarr_array = self.zarr_group.create( name, shape=shape, @@ -1148,6 +1254,42 @@ def 
_create_new_array( zarr_array = _put_attrs(zarr_array, attrs) return zarr_array + async def _create_new_array_async( + self, *, name, shape, dtype, fill_value, encoding, attrs + ) -> ZarrAsyncArray: + """Async version of _create_new_array using zarr's asynchronous API.""" + if coding.strings.check_vlen_dtype(dtype) is str: + dtype = str + + if self._write_empty is not None: + if ( + "write_empty_chunks" in encoding + and encoding["write_empty_chunks"] != self._write_empty + ): + raise ValueError( + 'Differing "write_empty_chunks" values in encoding and parameters. ' + f'Got {encoding["write_empty_chunks"] = } and {self._write_empty = }' + ) + else: + encoding["write_empty_chunks"] = self._write_empty + + # zarr v3 deprecated origin and write_empty_chunks + # instead preferring to pass them via the config argument + encoding["config"] = {} + for c in ("write_empty_chunks", "order"): + if c in encoding: + encoding["config"][c] = encoding.pop(c) + + zarr_async_array = await self._zarr_async_group().create_array( + name, + shape=shape, + dtype=dtype, + fill_value=fill_value, + **encoding, + ) + zarr_async_array = await _put_attrs_async(zarr_async_array, attrs) + return zarr_async_array + def set_variables( self, variables: dict[str, Variable], @@ -1171,10 +1313,40 @@ def set_variables( List of dimension names that should be treated as unlimited dimensions. 
""" + if _zarr_v3(): + from zarr.api.synchronous import sync + + writes = sync( + self._set_variables_async( + variables=variables, + check_encoding_set=check_encoding_set, + writer=writer, + unlimited_dims=unlimited_dims, + ) + ) + else: + writes = self._set_variables_sync_v2( + variables=variables, + check_encoding_set=check_encoding_set, + writer=writer, + unlimited_dims=unlimited_dims, + ) + for source, target, region in writes: + writer.add(source, target, region) + def _set_variables_sync_v2( + self, + variables: dict[str, Variable], + check_encoding_set, + writer, + unlimited_dims=None, + ): + """Sync implementation of set_variables (zarr v2 or internal use).""" existing_keys = self.array_keys() is_zarr_v3_format = _zarr_v3() and self.zarr_group.metadata.zarr_format == 3 + writes = [] + for vn, v in variables.items(): name = _encode_variable_name(vn) attrs = v.attrs.copy() @@ -1296,7 +1468,137 @@ def set_variables( attrs=encoded_attrs, ) - writer.add(v.data, zarr_array, region) + writes.append((v.data, zarr_array, region)) + + return writes + + async def _set_variables_async( + self, + variables: dict[str, Variable], + check_encoding_set, + writer, + unlimited_dims=None, + ): + """Async version of set_variables using zarr's asynchronous API.""" + existing_keys = await self._array_keys_async() + is_zarr_v3_format = _zarr_v3() and self.zarr_group.metadata.zarr_format == 3 + + writes = await asyncio.gather( + *[ + self._set_one_variable_async( + vn, + v, + existing_keys=existing_keys, + is_zarr_v3_format=is_zarr_v3_format, + check_encoding_set=check_encoding_set, + ) + for vn, v in variables.items() + ] + ) + return list(writes) + + async def _set_one_variable_async( + self, + vn: str, + v: Variable, + *, + existing_keys: tuple[str, ...], + is_zarr_v3_format: bool, + check_encoding_set, + ): + """Process a single variable for async set_variables; returns (source, target, region) for writer.add.""" + from zarr import Array as ZarrArray + + name = 
_encode_variable_name(vn) + attrs = v.attrs.copy() + dims = v.dims + dtype = v.dtype + shape = v.shape + + if self._use_zarr_fill_value_as_mask: + fill_value = attrs.pop("_FillValue", None) + else: + fill_value = v.encoding.pop("fill_value", None) + if fill_value is None and v.dtype.kind == "f": + fill_value = np.nan + if "_FillValue" in attrs: + fv = attrs.pop("_FillValue") + if fv is not None: + attrs["_FillValue"] = FillValueCoder.encode(fv, dtype) + + if "_FillValue" in v.encoding: + if v.encoding.get("_FillValue") is not None: + raise ValueError("Zarr does not support _FillValue in encoding.") + else: + del v.encoding["_FillValue"] + + zarr_shape = None + write_region = self._write_region if self._write_region is not None else {} + write_region = {dim: write_region.get(dim, slice(None)) for dim in dims} + + if self._mode != "w" and name in existing_keys: + # existing variable: open and optionally resize via zarr async API + zarr_async_array = await self._open_existing_array_async(name=name) + if self._append_dim is not None and self._append_dim in dims: + append_axis = dims.index(self._append_dim) + assert write_region[self._append_dim] == slice(None) + write_region[self._append_dim] = slice( + zarr_async_array.shape[append_axis], None + ) + new_shape = ( + zarr_async_array.shape[:append_axis] + + (zarr_async_array.shape[append_axis] + v.shape[append_axis],) + + zarr_async_array.shape[append_axis + 1 :] + ) + await zarr_async_array.resize(new_shape) + zarr_shape = zarr_async_array.shape + region = tuple(write_region[dim] for dim in dims) + + encoding = extract_zarr_variable_encoding( + v, + raise_on_invalid=vn in check_encoding_set, + name=vn, + zarr_format=3 if is_zarr_v3_format else 2, + ) + effective_write_chunks = encoding.get("shards") or encoding["chunks"] + + if self._align_chunks and isinstance(effective_write_chunks, tuple): + v = grid_rechunk( + v=v, + enc_chunks=effective_write_chunks, + region=region, + ) + + if self._safe_chunks and 
isinstance(effective_write_chunks, tuple): + shape = zarr_shape or v.shape + validate_grid_chunks_alignment( + nd_v_chunks=v.chunks, + enc_chunks=effective_write_chunks, + region=region, + allow_partial_chunks=self._mode != "r+", + name=name, + backend_shape=shape, + ) + + if self._mode == "w" or name not in existing_keys: + encoded_attrs = {k: self.encode_attribute(v) for k, v in attrs.items()} + if is_zarr_v3_format: + encoding["dimension_names"] = dims + else: + encoded_attrs[DIMENSION_KEY] = dims + encoding["overwrite"] = self._mode == "w" + + zarr_async_array = await self._create_new_array_async( + name=name, + dtype=dtype, + shape=shape, + fill_value=fill_value, + encoding=encoding, + attrs=encoded_attrs, + ) + + zarr_array = ZarrArray(zarr_async_array) + return (v.data, zarr_array, region) def sync(self) -> None: pass diff --git a/xarray/core/types.py b/xarray/core/types.py index 69cee210798..351c990bc87 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -62,9 +62,13 @@ ZarrGroup = Any # type: ignore[misc, assignment, unused-ignore] try: # this is V3 only + from zarr import AsyncArray as ZarrAsyncArray + from zarr import AsyncGroup as ZarrAsyncGroup from zarr.storage import StoreLike as ZarrStoreLike except ImportError: ZarrStoreLike = Any # type: ignore[misc, assignment, unused-ignore] + ZarrAsyncArray = Any # type: ignore[misc, assignment, unused-ignore] + ZarrAsyncGroup = Any # type: ignore[misc, assignment, unused-ignore] # Anything that can be coerced to a shape tuple _ShapeLike = Union[SupportsIndex, Sequence[SupportsIndex]]