Skip to content
Draft
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pixi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ pyarrow = { version = "*", index = "https://pypi.anaconda.org/scientific-python-

dask = { git = "https://github.com/dask/dask" }
distributed = { git = "https://github.com/dask/distributed" }
zarr = { git = "https://github.com/zarr-developers/zarr-python" }
zarr = { git = "https://github.com/maxrjones/zarr-python", branch = "poc/unified-chunk-grid" }
Copy link
Copy Markdown
Contributor

@dcherian dcherian Apr 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
zarr = { git = "https://github.com/maxrjones/zarr-python", branch = "poc/unified-chunk-grid" }

Now that it's on main, we can apply the run-upstream label (which I will do now)

numcodecs = { git = "https://github.com/zarr-developers/numcodecs" }
cftime = { git = "https://github.com/Unidata/cftime" }
# packaging = { git = "https://github.com/pypa/packaging"} #? Pixi warns if this is enabled
Expand Down
75 changes: 64 additions & 11 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import base64
import functools
import importlib.util
import json
import os
import struct
Expand Down Expand Up @@ -46,6 +48,19 @@
from xarray.core.types import ZarrArray, ZarrGroup


@functools.cache
def _has_unified_chunk_grid() -> bool:
"""Check if zarr has the unified ChunkGrid with is_regular support.

Defers the actual import so zarr stays lazy at module load time.
"""
if importlib.util.find_spec("zarr.core.chunk_grids") is None:
return False
from zarr.core.chunk_grids import ChunkGrid

return hasattr(ChunkGrid, "is_regular")


def _get_mappers(*, storage_options, store, chunk_store):
# expand str and path-like arguments
store = _normalize_path(store)
Expand Down Expand Up @@ -284,7 +299,7 @@ async def async_getitem(self, key):
)


def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name):
def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, zarr_format):
"""
Given encoding chunks (possibly None or []) and variable chunks
(possibly None or []).
Expand All @@ -306,18 +321,24 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name):
# while dask chunks can be variable sized
# https://dask.pydata.org/en/latest/array-design.html#chunks
if var_chunks and not enc_chunks:
if zarr_format == 3 and _has_unified_chunk_grid():
return tuple(var_chunks)

if any(len(set(chunks[:-1])) > 1 for chunks in var_chunks):
raise ValueError(
"Zarr requires uniform chunk sizes except for final chunk. "
"Zarr v2 requires uniform chunk sizes except for final chunk. "
f"Variable named {name!r} has incompatible dask chunks: {var_chunks!r}. "
"Consider rechunking using `chunk()`."
"Consider rechunking using `chunk()`, or switching to the "
"zarr v3 format with zarr-python>=3.2."
)
if any((chunks[0] < chunks[-1]) for chunks in var_chunks):
raise ValueError(
"Final chunk of Zarr array must be the same size or smaller "
f"than the first. Variable named {name!r} has incompatible Dask chunks {var_chunks!r}."
"Consider either rechunking using `chunk()` or instead deleting "
"or modifying `encoding['chunks']`."
"Final chunk of a Zarr v2 array or a Zarr v3 array without the "
"rectilinear chunks extension must be the same size or smaller "
f"than the first. Variable named {name!r} has incompatible Dask "
f"chunks {var_chunks!r}. "
"Consider switching to Zarr v3 with the rectilinear chunks extension, "
"rechunking using `chunk()` or deleting or modifying `encoding['chunks']`."
)
# return the first chunk for each dimension
return tuple(chunk[0] for chunk in var_chunks)
Expand All @@ -340,8 +361,17 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name):
var_chunks,
ndim,
name,
zarr_format,
)

# Rectilinear chunks: each element is a sequence of per-chunk edge lengths
if (
zarr_format == 3
and _has_unified_chunk_grid()
and any(not isinstance(x, int) for x in enc_chunks_tuple)
):
return enc_chunks_tuple

for x in enc_chunks_tuple:
if not isinstance(x, int):
raise TypeError(
Expand Down Expand Up @@ -483,6 +513,7 @@ def extract_zarr_variable_encoding(
var_chunks=variable.chunks,
ndim=variable.ndim,
name=name,
zarr_format=zarr_format,
)
if _zarr_v3() and chunks is None:
chunks = "auto"
Expand Down Expand Up @@ -861,9 +892,25 @@ def open_store_variable(self, name):
)
attributes = dict(attributes)

if _has_unified_chunk_grid():
from zarr.core.metadata.v3 import RectilinearChunkGrid, RegularChunkGrid

chunk_grid = zarr_array.metadata.chunk_grid
if isinstance(chunk_grid, RegularChunkGrid):
chunks = chunk_grid.chunk_shape
elif isinstance(chunk_grid, RectilinearChunkGrid):
chunks = chunk_grid.chunk_shapes
else:
chunks = tuple(zarr_array.chunks)
preferred_chunks = dict(zip(dimensions, chunks, strict=True))
else:
# Fallback for older zarr-python without unified chunk grid
chunks = tuple(zarr_array.chunks)
preferred_chunks = dict(zip(dimensions, chunks, strict=True))

encoding = {
"chunks": zarr_array.chunks,
"preferred_chunks": dict(zip(dimensions, zarr_array.chunks, strict=True)),
"chunks": chunks,
"preferred_chunks": preferred_chunks,
}

if _zarr_v3():
Expand Down Expand Up @@ -1248,14 +1295,20 @@ def set_variables(
# parallel writes. See https://github.com/pydata/xarray/issues/10831
effective_write_chunks = encoding.get("shards") or encoding["chunks"]

if self._align_chunks and isinstance(effective_write_chunks, tuple):
# Rectilinear chunks are tuples-of-tuples — align_chunks and
# safe_chunks validation only apply to regular (flat tuple) chunks.
_is_regular_chunks = isinstance(effective_write_chunks, tuple) and all(
isinstance(c, int) for c in effective_write_chunks
)

if self._align_chunks and _is_regular_chunks:
v = grid_rechunk(
v=v,
enc_chunks=effective_write_chunks,
region=region,
)

if self._safe_chunks and isinstance(effective_write_chunks, tuple):
if self._safe_chunks and _is_regular_chunks:
# the hard case
# DESIGN CHOICE: do not allow multiple dask chunks on a single zarr chunk
# (or shard, when sharding is enabled)
Expand Down
48 changes: 48 additions & 0 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -7247,6 +7247,54 @@ def test_extract_zarr_variable_encoding() -> None:
)


@requires_zarr_v3
@requires_dask
def test_rectilinear_chunks_encoding_roundtrip(tmp_path: Path) -> None:
    """Rectilinear chunk sizes in encoding are passed through to zarr v3."""

    import zarr

    if not backends.zarr._has_unified_chunk_grid():
        pytest.skip("zarr does not have unified ChunkGrid support")

    # Uneven per-chunk edge lengths along "x" — only representable via the
    # rectilinear chunk grid.
    sizes = (10, 20, 30)
    values = np.arange(60, dtype="float32")
    dataset = xr.Dataset({"var": xr.Variable("x", values)}).chunk({"x": sizes})

    target = tmp_path / "rectilinear.zarr"
    # encoding["chunks"] is a sequence per dimension: one entry of edge lengths.
    var_encoding = {"var": {"chunks": [list(sizes)]}}

    with zarr.config.set({"array.rectilinear_chunks": True}):
        dataset.to_zarr(target, zarr_format=3, mode="w", encoding=var_encoding)

    reopened = xr.open_zarr(target, zarr_format=3)
    assert reopened.chunks["x"] == sizes
    np.testing.assert_array_equal(reopened["var"].values, values)


@requires_zarr_v3
@requires_dask
def test_rectilinear_chunks_no_encoding(tmp_path: Path) -> None:
    """Variable dask chunks are written as rectilinear when no encoding is given."""
    import zarr

    if not backends.zarr._has_unified_chunk_grid():
        pytest.skip("zarr does not have unified ChunkGrid support")

    # Non-uniform dask chunks with no explicit encoding — the backend must
    # carry them through to the store as-is.
    sizes = (15, 25, 20)
    values = np.arange(60, dtype="float32")
    dataset = xr.Dataset({"var": xr.Variable("x", values)}).chunk({"x": sizes})

    target = tmp_path / "rectilinear_no_enc.zarr"

    with zarr.config.set({"array.rectilinear_chunks": True}):
        dataset.to_zarr(target, zarr_format=3, mode="w")

    reopened = xr.open_zarr(target, zarr_format=3)
    assert reopened.chunks["x"] == sizes
    np.testing.assert_array_equal(reopened["var"].values, values)


@requires_zarr
@requires_fsspec
@pytest.mark.filterwarnings("ignore:deallocating CachingFileManager")
Expand Down
Loading