Skip to content
27 changes: 25 additions & 2 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,9 @@ class FillValueCoder:
"""

@classmethod
def encode(cls, value: int | float | str | bytes, dtype: np.dtype[Any]) -> Any:
def encode(
cls, value: int | float | complex | str | bytes, dtype: np.dtype[Any]
) -> Any:
if dtype.kind in "S":
# byte string, this implies that 'value' must also be `bytes` dtype.
assert isinstance(value, bytes)
Expand All @@ -135,13 +137,28 @@ def encode(cls, value: int | float | str | bytes, dtype: np.dtype[Any]) -> Any:
return int(value)
elif dtype.kind in "f":
return base64.standard_b64encode(struct.pack("<d", float(value))).decode()
elif dtype.kind in "c":
# complex - encode each component as base64, matching float encoding
assert isinstance(value, complex) or np.issubdtype(
type(value), np.complexfloating
)
return [
base64.standard_b64encode(
struct.pack("<d", float(value.real))
).decode(),
base64.standard_b64encode(
struct.pack("<d", float(value.imag))
).decode(),
]
elif dtype.kind in "U":
return str(value)
else:
raise ValueError(f"Failed to encode fill_value. Unsupported dtype {dtype}")

@classmethod
def decode(cls, value: int | float | str | bytes, dtype: str | np.dtype[Any]):
def decode(
cls, value: int | float | str | bytes | list, dtype: str | np.dtype[Any]
):
if dtype == "string":
# zarr V3 string type
return str(value)
Expand All @@ -153,6 +170,12 @@ def decode(cls, value: int | float | str | bytes, dtype: str | np.dtype[Any]):
if np_dtype.kind in "f":
assert isinstance(value, str | bytes)
return struct.unpack("<d", base64.standard_b64decode(value))[0]
elif np_dtype.kind in "c":
# complex - decode each component from base64, matching float decoding
assert isinstance(value, list | tuple) and len(value) == 2
real = struct.unpack("<d", base64.standard_b64decode(value[0]))[0]
imag = struct.unpack("<d", base64.standard_b64decode(value[1]))[0]
return complex(real, imag)
elif np_dtype.kind in "b":
return bool(value)
elif np_dtype.kind in "iu":
Expand Down
12 changes: 12 additions & 0 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -7099,6 +7099,18 @@ def test_encode_zarr_attr_value() -> None:
assert actual3 == expected3


@requires_zarr
@pytest.mark.parametrize("dtype", [complex, np.complex64, np.complex128])
def test_fill_value_coder_complex(dtype) -> None:
"""Test that FillValueCoder round-trips complex fill values."""
from xarray.backends.zarr import FillValueCoder

for value in [dtype(1 + 2j), dtype(-3.5 + 4.5j), dtype(complex("nan+nanj"))]:
encoded = FillValueCoder.encode(value, np.dtype(dtype))
decoded = FillValueCoder.decode(encoded, np.dtype(dtype))
np.testing.assert_equal(np.array(decoded, dtype=dtype), np.array(value))


@requires_zarr
def test_extract_zarr_variable_encoding() -> None:
var = xr.Variable("x", [1, 2])
Expand Down
Loading