Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ History

X.Y.Z (YYYY-MM-DD)
------------------
* Add rechunk by size utility function (:pr:`284`)
* Add experimental fragments functionality (:pr:`282`)
* Run CI weekly on Monday @ 2h30 am UTC (:pr:`288`)
* Update minio server and client versions (:pr:`287`)
Expand Down
Empty file.
112 changes: 112 additions & 0 deletions daskms/experimental/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import pytest
import numpy as np
import dask.array as da
from itertools import product, combinations
from daskms.experimental.zarr import xds_to_zarr
from daskms.experimental.utils import rechunk_by_size

# Skip this whole test module when xarray cannot be imported.
xarray = pytest.importorskip("xarray")

# 2**31 bytes (2 GiB): the per-chunk byte limit the tests below rechunk against.
ZARR_MAX_CHUNK = 2 ** (32 - 1)


@pytest.fixture(scope="function")
def dataset():
    """Return an xarray.Dataset whose data variables are single-chunk arrays.

    ``dv0`` (complex64) and ``dv1`` (float32) are cubes spanning all three
    coordinates; the cube is sized so that ``dv0`` holds roughly 1.2 * 2**31
    bytes in its single chunk. ``dv2`` is a 1-D integer array along the
    first coordinate only.
    """
    ndim = 3

    def get_large_shape(ndim, dtype, max_size=2**31, exceed=1.2):
        # Side length (rounded up) of an ndim-cube of ``dtype`` elements
        # whose total size is ``exceed * max_size`` bytes.
        n_elements = (exceed * max_size) / dtype().itemsize
        side = n_elements ** (1 / ndim)
        return (int(np.ceil(side)),) * ndim

    large_shape = get_large_shape(ndim, np.complex64)
    coord_names = [f"coord{i}" for i in range(ndim)]

    # chunks=-1 places each array in exactly one chunk.
    arrays = {
        "dv0": da.zeros(large_shape, dtype=np.complex64, chunks=-1),
        "dv1": da.zeros(large_shape, dtype=np.float32, chunks=-1),
        "dv2": da.zeros(large_shape[0], dtype=int, chunks=-1),
    }

    # Each variable is labelled with the leading coordinates matching its rank.
    data_vars = {
        name: (coord_names[: arr.ndim], arr) for name, arr in arrays.items()
    }

    coords = {}
    for cn, ds in zip(coord_names, large_shape):
        coords[cn] = (cn, range(ds))

    return xarray.Dataset(data_vars, coords=coords)


def test_error_before_rechunk(dataset, tmp_path_factory):
    """Original motivating case: writing oversized chunks to zarr must fail."""

    zarr_path = tmp_path_factory.mktemp("datasets") / "dataset.zarr"
    expected_message = r"Column .* has a chunk of"

    with pytest.raises(ValueError, match=expected_message):
        xds_to_zarr(dataset, zarr_path)


def test_error_after_rechunk(dataset, tmp_path_factory):
    """Check that rechunking solves the original motivating case."""

    zarr_path = tmp_path_factory.mktemp("datasets") / "dataset.zarr"
    rechunked = rechunk_by_size(dataset)

    # Passing means this call completes without raising.
    xds_to_zarr(rechunked, zarr_path)


@pytest.mark.parametrize("max_chunk_mem", [2**28, 2**29, 2**30])
def test_rechunk(dataset, max_chunk_mem):
    """Every block of every data variable must be smaller than the target size."""

    rechunked = rechunk_by_size(dataset, max_chunk_mem=max_chunk_mem)

    for dv in rechunked.data_vars.values():
        blocks = dv.data.blocks
        # Visit each block index of the variable's dask array.
        for block_index in product(*(range(n) for n in blocks.shape)):
            assert blocks[block_index].nbytes < max_chunk_mem, (
                f"Data variable {dv.name} contains chunks which exceed the "
                f"maximum per chunk memory size of {max_chunk_mem}."
            )


@pytest.mark.parametrize(
    "unchunked_dims",
    [
        *combinations(["coord0", "coord1", "coord2"], 2),
        *combinations(["coord0", "coord1", "coord2"], 1),
    ],
)
def test_rechunk_with_unchunkable_axis(dataset, unchunked_dims):
    """Check that rechunking works when some dimensions must not be chunked.

    Each parameter is a tuple holding one or two dimension names. The test
    verifies both that every resulting block is below ``ZARR_MAX_CHUNK`` and
    that the requested dimensions really are left as a single chunk.
    """

    # Expand the tuple into a set of names. Wrapping the tuple itself in
    # braces would yield a one-element set containing a tuple, which matches
    # no dimension name and silently disables the constraint.
    unchunked = set(unchunked_dims)

    dataset = rechunk_by_size(
        dataset, max_chunk_mem=ZARR_MAX_CHUNK, unchunked_dims=unchunked
    )

    for dv in dataset.data_vars.values():
        itr = product(*map(range, dv.data.blocks.shape))
        assert all(dv.data.blocks[i].nbytes < ZARR_MAX_CHUNK for i in itr), (
            f"Data variable {dv.name} contains chunks which exceed the "
            f"maximum per chunk memory size of {ZARR_MAX_CHUNK}."
        )

        # Only dimensions actually present on this variable can be checked.
        chunks_by_dim = dict(zip(dv.dims, dv.data.chunks))
        split_dims = sorted(
            d for d in unchunked if d in chunks_by_dim and len(chunks_by_dim[d]) != 1
        )
        assert not split_dims, (
            f"Data variable {dv.name} was chunked along {split_dims}, which "
            f"were requested to remain unchunked."
        )


def test_rechunk_impossible(dataset):
    """Forbidding chunking along every dimension must raise a clear error."""

    every_dim = {"coord0", "coord1", "coord2"}
    expected_message = "Target chunk size could not be"

    with pytest.raises(ValueError, match=expected_message):
        rechunk_by_size(
            dataset, max_chunk_mem=ZARR_MAX_CHUNK, unchunked_dims=every_dim
        )


def test_rechunk_if_required(dataset):
    """With only_when_needed set, existing chunks of 100 per coordinate are kept."""

    # Pre-chunk every coordinate dimension to 100 elements.
    small_chunks = dict.fromkeys(dataset.coords, 100)
    dataset = dataset.chunk(small_chunks)

    rechunked_dataset = rechunk_by_size(dataset, only_when_needed=True)

    chunks_unchanged = rechunked_dataset.chunks == dataset.chunks
    assert chunks_unchanged, (
        "rechunk_by_size has altered chunk sizes even though input dataset "
        "did not require rechunking."
    )
100 changes: 100 additions & 0 deletions daskms/experimental/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from pathlib import Path
from itertools import product

import dask
import dask.array as da
Expand Down Expand Up @@ -116,3 +117,102 @@ def store_path_split(store):
raise RuntimeError(f"len(parts) {len(parts)} not in (1, 2)")

return store.parent / name, subtable


def largest_chunk(arr):
    """Return the largest ``nbytes`` found among the blocks of ``arr``.

    ``arr`` must expose a ``blocks`` accessor with a ``shape`` and tuple
    indexing; every block index is visited.
    """
    block_view = arr.blocks
    block_indices = product(*(range(n) for n in block_view.shape))
    return max(block_view[idx].nbytes for idx in block_indices)


def rechunk_by_size(
    dataset, max_chunk_mem=2**31 - 1, unchunked_dims=None, only_when_needed=False
):
    """
    Given an xarray.Dataset, rechunk it such that chunking is uniform and
    consistent in all dimensions and all chunks are smaller than a specified
    size in bytes.

    Only data variables and coordinates backed by dask arrays are considered
    when choosing chunk sizes.

    Parameters
    ----------
    dataset : xarray.Dataset
        A dataset containing datavars and coords.
    max_chunk_mem : int
        Target maximum chunk size in bytes.
    unchunked_dims: None or set
        A set of dimensions which should not be chunked.
    only_when_needed: bool
        If set, only rechunk if existing chunks violate max_chunk_mem.

    Returns
    -------
    rechunked_dataset : xarray.Dataset
        Dataset with appropriate chunking. When only_when_needed is set and
        no existing chunk exceeds max_chunk_mem, a copy of the input dataset.

    Raises
    ------
    ValueError
        If, for some array, the unchunked dimensions alone hold more elements
        than fit in max_chunk_mem.
    """

    def _rechunk(data_array, unchunked_dims):
        """Return a {dim: chunk size} mapping for a single array."""
        dims = set(data_array.dims)
        # Ignore requested dimensions which this array does not have.
        unchunked_dims = unchunked_dims & dims
        chunked_dims = dims - unchunked_dims

        n_dim = len(dims)
        n_chunked_dim = len(chunked_dims)

        dim_sizes = data_array.sizes

        # The maximum number of array elements in the chunk.
        max_n_ele = max_chunk_mem // data_array.dtype.itemsize
        # The number of elements associated with unchunkable dimensions.
        # NOTE: np.prod rather than the np.product alias, which was removed
        # in NumPy 2.0.
        fixed_n_ele = np.prod([dim_sizes[k] for k in unchunked_dims])

        if fixed_n_ele > max_n_ele:
            raise ValueError(
                f"Target chunk size could not be reached in rechunk_by_size. "
                f"Unchunkable dimensions were: {unchunked_dims}."
            )

        chunk_dict = {k: dim_sizes[k] for k in unchunked_dims}

        if n_chunked_dim == 0:  # No chunking but still less than target size.
            return chunk_dict

        # Spread the remaining element budget evenly over the chunked
        # dimensions; int() truncates so the budget is never exceeded.
        ideal_chunk = int(np.power(max_n_ele / fixed_n_ele, 1 / n_chunked_dim))

        chunk_dict.update({k: ideal_chunk for k in chunked_dims})

        # Dimensions no larger than their proposed chunk are effectively
        # unchunked, which frees budget for the remaining dimensions.
        new_unchunked_dims = {k for k in dims if chunk_dict[k] >= dim_sizes[k]}

        if len(new_unchunked_dims) == n_dim:
            # The whole array fits in one chunk: only constrain the
            # dimensions which were explicitly required to be unchunked.
            return {k: dim_sizes[k] for k in unchunked_dims}
        elif new_unchunked_dims != unchunked_dims:
            # Redistribute the budget with the enlarged unchunked set.
            return _rechunk(data_array, new_unchunked_dims)
        else:
            return chunk_dict

    # Figure out chunking from the largest arrays to the smallest. NOTE:
    # Using nbytes may be unreliable for object arrays.
    dvs_and_coords = [*dataset.data_vars.values(), *dataset.coords.values()]
    dvs_and_coords = [d for d in dvs_and_coords if isinstance(d.data, da.Array)]
    dvs_and_coords = sorted(dvs_and_coords, key=lambda arr: arr.data.nbytes)

    if only_when_needed:
        largest_chunks = [largest_chunk(dc.data) for dc in dvs_and_coords]
        if not any(lc > max_chunk_mem for lc in largest_chunks):
            return dataset.copy()

    requested_unchunked = unchunked_dims or set()
    chunk_dims = {}

    for data_array in dvs_and_coords[::-1]:  # From largest to smallest.
        chunk_update = _rechunk(data_array, requested_unchunked)

        # Keep the smallest chunk size proposed so far for each dimension so
        # that a single chunking satisfies every array.
        for dim, size in chunk_update.items():
            chunk_dims[dim] = min(size, chunk_dims.get(dim, size))

    rechunked_dataset = dataset.chunk(chunk_dims)

    return rechunked_dataset
4 changes: 3 additions & 1 deletion daskms/experimental/zarr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,9 @@ def create_array(ds_group, column, column_schema, schema_chunks, coordinate=Fals
raise ValueError(
f"Column {column} has a chunk of "
f"dimension {zchunks} that will exceed "
f"zarr's 2GiB chunk limit"
f"zarr's 2GiB chunk limit. Consider calling "
f"daskms.experimental.utils.rechunk_by_size "
f"prior to writing."
)

array = ds_group.require_dataset(
Expand Down