diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index d9279dc2de9..998bbf76bab 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -1,6 +1,8 @@ from __future__ import annotations import base64 +import functools +import importlib.util import json import os import struct @@ -46,6 +48,19 @@ from xarray.core.types import ZarrArray, ZarrGroup +@functools.cache +def _has_unified_chunk_grid() -> bool: + """Check if zarr has the unified ChunkGrid with is_regular support. + + Defers the actual import so zarr stays lazy at module load time. + """ + if importlib.util.find_spec("zarr.core.chunk_grids") is None: + return False + from zarr.core.chunk_grids import ChunkGrid + + return hasattr(ChunkGrid, "is_regular") + + def _get_mappers(*, storage_options, store, chunk_store): # expand str and path-like arguments store = _normalize_path(store) @@ -333,7 +348,7 @@ async def async_getitem(self, key): ) -def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name): +def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, zarr_format): """ Given encoding chunks (possibly None or []) and variable chunks (possibly None or []). @@ -355,18 +370,24 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name): # while dask chunks can be variable sized # https://dask.pydata.org/en/latest/array-design.html#chunks if var_chunks and not enc_chunks: + if zarr_format == 3 and _has_unified_chunk_grid(): + return tuple(var_chunks) + if any(len(set(chunks[:-1])) > 1 for chunks in var_chunks): raise ValueError( - "Zarr requires uniform chunk sizes except for final chunk. " + "Zarr v2 requires uniform chunk sizes except for the final chunk. " f"Variable named {name!r} has incompatible dask chunks: {var_chunks!r}. " - "Consider rechunking using `chunk()`." + "Consider rechunking using `chunk()`, or switching to the " + "zarr v3 format with zarr-python>=3.2." ) if any((chunks[0] < chunks[-1]) for chunks in var_chunks): raise ValueError( - "Final chunk of Zarr array must be the same size or smaller " - f"than the first. Variable named {name!r} has incompatible Dask chunks {var_chunks!r}." - "Consider either rechunking using `chunk()` or instead deleting " - "or modifying `encoding['chunks']`." + "The final chunk of a Zarr v2 array or a Zarr v3 array without the " + "rectilinear chunks extension must be the same size or smaller " + f"than the first. Variable named {name!r} has incompatible Dask " + f"chunks {var_chunks!r}. " + "Consider switching to Zarr v3 with the rectilinear chunks extension, " + "rechunking using `chunk()` or deleting or modifying `encoding['chunks']`." ) # return the first chunk for each dimension return tuple(chunk[0] for chunk in var_chunks) @@ -389,8 +410,17 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name): var_chunks, ndim, name, + zarr_format, ) + # Rectilinear chunks: each element is a sequence of per-chunk edge lengths + if ( + zarr_format == 3 + and _has_unified_chunk_grid() + and any(not isinstance(x, int) for x in enc_chunks_tuple) + ): + return enc_chunks_tuple + for x in enc_chunks_tuple: if not isinstance(x, int): raise TypeError( @@ -532,6 +562,7 @@ def extract_zarr_variable_encoding( var_chunks=variable.chunks, ndim=variable.ndim, name=name, + zarr_format=zarr_format, ) if _zarr_v3() and chunks is None: chunks = "auto" @@ -910,9 +941,25 @@ def open_store_variable(self, name): ) attributes = dict(attributes) + if _has_unified_chunk_grid(): + from zarr.core.metadata.v3 import RectilinearChunkGrid, RegularChunkGrid + + chunk_grid = zarr_array.metadata.chunk_grid + if isinstance(chunk_grid, RegularChunkGrid): + chunks = chunk_grid.chunk_shape + elif isinstance(chunk_grid, RectilinearChunkGrid): + chunks = chunk_grid.chunk_shapes + else: + chunks = tuple(zarr_array.chunks) + preferred_chunks = dict(zip(dimensions, chunks, strict=True)) + else: + # Fallback for older zarr-python without unified chunk grid + chunks = tuple(zarr_array.chunks) + preferred_chunks = dict(zip(dimensions, chunks, strict=True)) + encoding = { - "chunks": zarr_array.chunks, - "preferred_chunks": dict(zip(dimensions, zarr_array.chunks, strict=True)), + "chunks": chunks, + "preferred_chunks": preferred_chunks, } if _zarr_v3(): @@ -1297,14 +1344,20 @@ def set_variables( # parallel writes. See https://github.com/pydata/xarray/issues/10831 effective_write_chunks = encoding.get("shards") or encoding["chunks"] - if self._align_chunks and isinstance(effective_write_chunks, tuple): + # Rectilinear chunks are tuples-of-tuples — align_chunks and + # safe_chunks validation only apply to regular (flat tuple) chunks. + _is_regular_chunks = isinstance(effective_write_chunks, tuple) and all( + isinstance(c, int) for c in effective_write_chunks + ) + + if self._align_chunks and _is_regular_chunks: v = grid_rechunk( v=v, enc_chunks=effective_write_chunks, region=region, ) - if self._safe_chunks and isinstance(effective_write_chunks, tuple): + if self._safe_chunks and _is_regular_chunks: # the hard case # DESIGN CHOICE: do not allow multiple dask chunks on a single zarr chunk # (or shard, when sharding is enabled) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index e42bfc2cd9f..1c185a8dd02 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -7299,6 +7299,54 @@ def test_extract_zarr_variable_encoding() -> None: ) +@requires_zarr_v3 +@requires_dask +def test_rectilinear_chunks_encoding_roundtrip(tmp_path: Path) -> None: + """Rectilinear chunk sizes in encoding are passed through to zarr v3.""" + + import zarr + + if not backends.zarr._has_unified_chunk_grid(): + pytest.skip("zarr does not have unified ChunkGrid support") + + chunk_sizes = [10, 20, 30] + data = np.arange(60, dtype="float32") + ds = xr.Dataset({"var": xr.Variable("x", data)}).chunk({"x": tuple(chunk_sizes)}) + + store_path = tmp_path / "rectilinear.zarr" + encoding = {"var": {"chunks": [chunk_sizes]}} + + with zarr.config.set({"array.rectilinear_chunks": True}): + ds.to_zarr(store_path, zarr_format=3, mode="w", encoding=encoding) + + roundtrip = xr.open_zarr(store_path, zarr_format=3) + assert roundtrip.chunks["x"] == tuple(chunk_sizes) + np.testing.assert_array_equal(roundtrip["var"].values, data) + + +@requires_zarr_v3 +@requires_dask +def test_rectilinear_chunks_no_encoding(tmp_path: Path) -> None: + """Variable dask chunks are written as rectilinear when no encoding is given.""" + import zarr + + if not backends.zarr._has_unified_chunk_grid(): + pytest.skip("zarr does not have unified ChunkGrid support") + + chunk_sizes = [15, 25, 20] + data = np.arange(60, dtype="float32") + ds = xr.Dataset({"var": xr.Variable("x", data)}).chunk({"x": tuple(chunk_sizes)}) + + store_path = tmp_path / "rectilinear_no_enc.zarr" + + with zarr.config.set({"array.rectilinear_chunks": True}): + ds.to_zarr(store_path, zarr_format=3, mode="w") + + roundtrip = xr.open_zarr(store_path, zarr_format=3) + assert roundtrip.chunks["x"] == tuple(chunk_sizes) + np.testing.assert_array_equal(roundtrip["var"].values, data) + + @requires_zarr @requires_fsspec @pytest.mark.filterwarnings("ignore:deallocating CachingFileManager")