class Sampler(ABC):
    """Base sampler class.

    Samplers control how data is batched and loaded from the underlying datasets.
    Subclasses implement :meth:`validate` and :meth:`_sample`; consumers call
    :meth:`sample`, which validates first and then delegates to :meth:`_sample`.
    """

    def sample(self, n_obs: int) -> Iterator[LoadRequest]:
        """Sample load requests given the total number of observations.

        Validation runs eagerly, i.e. when ``sample`` is called and not when the
        returned iterator is first advanced, so an invalid configuration fails at
        the call site.

        Parameters
        ----------
        n_obs
            The total number of observations available.

        Returns
        -------
        An iterator over load requests for batching data.

        Raises
        ------
        ValueError
            If :meth:`validate` rejects the sampler configuration for `n_obs`.
        """
        # A generator body (``yield from``) would postpone this check until the
        # first ``next()``; validate before handing out the iterator instead.
        self.validate(n_obs)
        # ``iter`` keeps supporting `_sample` implementations that return any iterable.
        return iter(self._sample(n_obs))

    @abstractmethod
    def validate(self, n_obs: int) -> None:
        """Validate the sampler configuration against the given n_obs.

        This method is called at the start of each `sample()` call.
        Override this method to add custom validation for sampler parameters.

        Parameters
        ----------
        n_obs
            The total number of observations in the loader.

        Raises
        ------
        ValueError
            If the sampler configuration is invalid for the given n_obs.
        """

    @abstractmethod
    def _sample(self, n_obs: int) -> Iterator[LoadRequest]:
        """Implementation of the sample method.

        This method is called by the sample method to perform the actual sampling after
        validation has passed.

        Parameters
        ----------
        n_obs
            The total number of observations available.

        Yields
        ------
        LoadRequest
            Load requests for batching data.
        """
+ The obs size (i.e., axis 0) of contiguous array data to fetch. Mutually exclusive with `batch_sampler`. Defaults to 512. preload_nchunks - The number of chunks of contiguous array data to fetch. + The number of chunks of contiguous array data to fetch. Mutually exclusive with `batch_sampler`. Defaults to 32. shuffle - Whether or not to shuffle the data. + Whether or not to shuffle the data. Mutually exclusive with `batch_sampler`. Defaults to False. + batch_size + Batch size to yield from the dataset. Mutually exclusive with `batch_sampler`. Defaults to 1. + drop_last + Set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. + If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. + Leave as False when using in conjunction with a :class:`torch.utils.data.DataLoader`. + Mutually exclusive with `batch_sampler`. Defaults to False. return_index Whether or not to yield the index on each iteration. - batch_size - Batch size to yield from the dataset. preload_to_gpu Whether or not to use cupy for non-io array operations like vstack and indexing once the data is in memory internally. This option entails greater GPU memory usage, but is faster at least for sparse operations. :func:`torch.vstack` does not support CSR sparse matrices, hence the current use of cupy internally. Setting this to `False` is advisable when using the :class:`torch.utils.data.DataLoader` wrapper or potentially with dense data. - For top performance, this should be used in conjuction with `to_torch` and then :meth:`torch.Tensor.to_dense` if you wish to denseify. - drop_last - Set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. - If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. - Leave as False when using in conjunction with a :class:`torch.utils.data.DataLoader`. 
+ For top performance, this should be used in conjuction with `to_torch` and then :meth:`torch.Tensor.to_dense` if you wish to densify. to_torch Whether to return `torch.Tensor` as the output. Data transferred should be 0-copy independent of source, and transfer to cuda when applicable is non-blocking. @@ -111,56 +116,68 @@ class Loader[ do_fit(batch) """ + _COMMON_SAMPLER_ARGS = { + "chunk_size": 512, + "preload_nchunks": 32, + "batch_size": 1, + "shuffle": False, + "drop_last": False, + } + # TODO(selmanozleyen): these should be also presented in the documentation + # but this is not ideal since they are hardcoded into the docstrings + # maybe we should make _COMMON_SAMPLER_ARGS a public class field? + _train_datasets: list[BackingArray] _obs: list[pd.DataFrame] | None = None _return_index: bool = False - _batch_size: int = 1 _shapes: list[tuple[int, int]] _preload_to_gpu: bool = True - _drop_last: bool = False _to_torch: bool = True - _shuffle: bool - _preload_nchunks: int - _worker_handle: WorkerHandle - _chunk_size: int _dataset_elem_cache: dict[int, CSRDatasetElems] + _batch_sampler: Sampler def __init__( self, *, - chunk_size: int = 512, - preload_nchunks: int = 32, - shuffle: bool = True, + batch_sampler: Sampler | None = None, + chunk_size: int | None = None, + preload_nchunks: int | None = None, + shuffle: bool | None = None, return_index: bool = False, - batch_size: int = 1, + batch_size: int | None = None, preload_to_gpu: bool = find_spec("cupy") is not None, - drop_last: bool = False, + drop_last: bool | None = None, to_torch: bool = find_spec("torch") is not None, ): - check_lt_1( - [ - chunk_size, - preload_nchunks, - ], - ["Chunk size", "Preload chunks"], - ) - if batch_size > (chunk_size * preload_nchunks): - raise NotImplementedError( - "Cannot yield batches bigger than the iterated in-memory size i.e., batch_size > (chunk_size * preload_nchunks)." 
- ) + sampler_args = { + "chunk_size": chunk_size, + "preload_nchunks": preload_nchunks, + "batch_size": batch_size, + "shuffle": shuffle, + "drop_last": drop_last, + } + if batch_sampler is not None: + if any(v is not None for v in sampler_args.values()): + provided_args = [name for name, val in sampler_args.items() if val is not None] + raise ValueError( + f"Cannot specify {', '.join(provided_args)} when providing a custom sampler. " + "These parameters are controlled by the sampler." + ) + self._batch_sampler = batch_sampler + else: + sampler_args_processed = { + k: (v if v is not None else Loader._COMMON_SAMPLER_ARGS[k]) for k, v in sampler_args.items() + } + self._batch_sampler = ChunkSampler(**sampler_args_processed) + if to_torch and not find_spec("torch"): - raise ImportError("Could not find torch dependeny. Try `pip install torch`.") + raise ImportError("Could not find torch dependency. Try `pip install torch`.") if preload_to_gpu and not find_spec("cupy"): raise ImportError("Follow the directions at https://docs.cupy.dev/en/stable/install.html to install cupy.") + self._return_index = return_index - self._batch_size = batch_size self._preload_to_gpu = preload_to_gpu self._to_torch = to_torch - self._drop_last = drop_last - self._chunk_size = chunk_size - self._preload_nchunks = preload_nchunks - self._shuffle = shuffle - self._worker_handle = WorkerHandle() self._train_datasets = [] self._shapes = [] self._dataset_elem_cache = {} @@ -223,8 +240,20 @@ def n_var(self) -> int: ------- The number of variables. """ + if len(self._shapes) == 0: + raise ValueError("No datasets added yet") return self._shapes[0][1] + @property + def batch_sampler(self) -> Sampler: + """The sampler used to generate batches. + + Returns + ------- + The sampler. 
def _prepare_dataset_and_obs(self, adata: ad.AnnData) -> tuple[BackingArray, pd.DataFrame | None]:
    """Extract the backing data matrix and the obs table from an AnnData object.

    Parameters
    ----------
    adata
        Object whose ``X`` is the data matrix and whose ``obs`` holds the annotations.

    Returns
    -------
    The data matrix and its obs; obs is ``None`` when the obs table has no columns.

    Raises
    ------
    TypeError
        If ``adata.X`` is not one of the types listed in ``BackingArray_T``.
    """
    x, annotations = adata.X, adata.obs
    # An obs table without columns carries nothing to yield, so it is reported as None.
    has_columns = len(annotations.columns) != 0
    accepted = BackingArray_T.__value__
    if isinstance(x, accepted):
        return cast("BackingArray", x), (annotations if has_columns else None)
    raise TypeError(f"Found {type(x)} but only {accepted} are usable")
@@ -299,9 +335,10 @@ def add_datasets(self, datasets: list[BackingArray], obs: list[pd.DataFrame] | N if obs is None: obs = [None] * len(datasets) for ds, o in zip(datasets, obs, strict=True): - self.add_dataset(ds, o) + self._add_dataset_unchecked(ds, o) return self + @validate_sampler def add_dataset(self, dataset: BackingArray, obs: pd.DataFrame | None = None) -> Self: """Append a dataset to this dataset. @@ -312,6 +349,10 @@ def add_dataset(self, dataset: BackingArray, obs: pd.DataFrame | None = None) -> obs :class:`~pandas.DataFrame` obs, generally from :attr:`anndata.AnnData.obs`. """ + self._add_dataset_unchecked(dataset, obs) + return self + + def _add_dataset_unchecked(self, dataset: BackingArray, obs: pd.DataFrame | None = None) -> Self: if len(self._train_datasets) > 0: if self._obs is None and obs is not None: raise ValueError( @@ -389,6 +430,8 @@ def _slices_to_slices_with_array_index( ) -> OrderedDict[int, list[slice]]: """Given a list of slices, give the lookup between on-disk datasets and slices relative to that dataset. + In the codebase we use slice and chunk interchangeably. Not to be confused with the zarr chunking/sharding terminology. + Parameters ---------- slices @@ -401,8 +444,8 @@ def _slices_to_slices_with_array_index( A lookup between the dataset and its indexing slices, ordered by keys. 
""" dataset_index_to_slices: defaultdict[int, list[slice]] = defaultdict(list) - for slice in slices: - for relative_obs_indices in self._get_relative_obs_indices(slice, use_original_space=use_original_space): + for slice_ in slices: + for relative_obs_indices in self._get_relative_obs_indices(slice_, use_original_space=use_original_space): dataset_index_to_slices[relative_obs_indices[1]] += [relative_obs_indices[0]] keys = sorted(dataset_index_to_slices.keys()) dataset_index_to_slices_sorted = OrderedDict() @@ -410,19 +453,6 @@ def _slices_to_slices_with_array_index( dataset_index_to_slices_sorted[k] = dataset_index_to_slices[k] return dataset_index_to_slices_sorted - def _get_chunks(self, chunk_size: int) -> np.ndarray: - """Get a potentially shuffled list of chunk ids, accounting for the fact that this dataset might be inside a worker. - - Returns - ------- - A :class:`numpy.ndarray` of chunk ids. - """ - chunks = np.arange(math.ceil(self.n_obs / chunk_size)) - if self._shuffle: - self._worker_handle.shuffle(chunks) - - return self._worker_handle.get_part_for_worker(chunks) - @singledispatchmethod async def _fetch_data(self, dataset: ZarrArray | CSRDatasetElems, slices: list[slice]) -> InputInMemoryArray: """Fetch data from an on-disk store. @@ -593,105 +623,63 @@ def __iter__( [len(self._train_datasets), self.n_obs], ["Number of datasets", "Number of observations"], ) - # In order to handle data returned where (chunk_size * preload_nchunks) mod batch_size != 0 - # we must keep track of the leftover data. 
def _accumulate_chunks(self, chunks: list[InputInMemoryArray]) -> OutputInMemoryArray_T:
    """Convert fetched chunks to the output array format and stack them row-wise.

    Parameters
    ----------
    chunks
        The fetched pieces, either ``CSRContainer`` objects or array-likes.

    Returns
    -------
    A single array produced by ``vstack`` over the converted chunks.
    """

    def _convert(chunk):
        # Anything that is not a CSR container goes straight through ``asarray``.
        if not isinstance(chunk, CSRContainer):
            return self._np_module.asarray(chunk)
        components = tuple(self._np_module.asarray(elem) for elem in chunk.elems)
        # dtype is forced to float64 whenever data is preloaded to the GPU
        # (presumably a sparse-dtype limitation of the GPU backend -- TODO confirm).
        target_dtype = "float64" if self._preload_to_gpu else chunk.dtype
        return self._sp_module.csr_matrix(components, shape=chunk.shape, dtype=target_dtype)

    converted = [_convert(chunk) for chunk in chunks]
    # The stacking module follows the dataset type rather than the individual chunks.
    is_sparse = issubclass(self.dataset_type, ad.abc.CSRDataset)
    stacker = self._sp_module if is_sparse else np
    return stacker.vstack(converted)
class ChunkSampler(Sampler):
    """Chunk-based sampler for batched data access.

    Parameters
    ----------
    chunk_size
        Size of each chunk i.e. the range of each chunk yielded.
    preload_nchunks
        Number of chunks to load per iteration.
    batch_size
        Number of observations per batch. Must be at least 1, must not exceed
        ``chunk_size * preload_nchunks`` and must divide it evenly.
    mask
        A slice defining the observation range to sample from (start:stop).
        Only a step of 1 (or ``None``) is accepted.
    shuffle
        Whether to shuffle chunk and index order.
    drop_last
        Whether to drop the last incomplete batch.
    rng
        Random number generator for shuffling. Outside of a torch ``DataLoader`` worker it
        drives both the chunk order and the in-memory index order, so a seeded generator
        gives a reproducible sequence. Defaults to a fresh :func:`numpy.random.default_rng`.
    """

    _batch_size: int
    _chunk_size: int
    _shuffle: bool
    _preload_nchunks: int
    _mask: slice
    _drop_last: bool
    _rng: np.random.Generator

    def __init__(
        self,
        chunk_size: int,
        preload_nchunks: int,
        batch_size: int,
        *,
        mask: slice | None = None,
        shuffle: bool = False,
        drop_last: bool = False,
        rng: np.random.Generator | None = None,
    ):
        if mask is None:
            mask = slice(0, None)
        if mask.step is not None and mask.step != 1:
            raise ValueError(f"mask.step must be 1, but got {mask.step}")
        start, stop = mask.start or 0, mask.stop
        if start < 0:
            raise ValueError("mask.start must be >= 0")
        if stop is not None and start >= stop:
            raise ValueError("mask.start must be < mask.stop when mask.stop is specified")

        check_lt_1([chunk_size, preload_nchunks], ["Chunk size", "Preloaded chunks"])
        # Reject non-positive batch sizes explicitly; otherwise batch_size=0 would
        # surface as a ZeroDivisionError from the divisibility check below.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, but got {batch_size}")
        preload_size = chunk_size * preload_nchunks

        if batch_size > preload_size:
            raise ValueError(
                "batch_size cannot exceed chunk_size * preload_nchunks. "
                f"Got batch_size={batch_size}, but max is {preload_size}."
            )
        if preload_size % batch_size != 0:
            raise ValueError(
                "chunk_size * preload_nchunks must be divisible by batch_size. "
                f"Got {preload_size} % {batch_size} = {preload_size % batch_size}."
            )
        self._rng = rng if rng is not None else np.random.default_rng()
        self._batch_size = batch_size
        self._chunk_size = chunk_size
        self._shuffle = shuffle
        self._preload_nchunks = preload_nchunks
        # Store the normalized mask: start is always an int, stop may still be None.
        self._mask = slice(start, stop)
        self._drop_last = drop_last

    def validate(self, n_obs: int) -> None:
        """Validate the sampler configuration against the loader's n_obs.

        Parameters
        ----------
        n_obs
            The total number of observations in the loader.

        Raises
        ------
        ValueError
            If the sampler configuration is invalid for the given n_obs.
        """
        # An open-ended mask (stop is None) resolves to the end of the data.
        start, stop = self._mask.start or 0, self._mask.stop or n_obs
        if stop > n_obs:
            raise ValueError(
                f"Sampler mask.stop ({stop}) exceeds loader n_obs ({n_obs}). "
                "The sampler range must be within the loader's observations."
            )
        if start >= stop:
            raise ValueError(f"Sampler mask.start ({start}) must be < mask.stop ({stop}).")

    def _get_worker_handle(self) -> WorkerHandle | None:
        """Return a ``WorkerHandle`` when running inside a torch DataLoader worker, else ``None``.

        Raises
        ------
        ValueError
            If there are multiple workers, ``drop_last`` is False and ``batch_size`` is not 1.
        """
        worker_handle = None
        if find_spec("torch"):
            # Imported lazily so torch stays an optional dependency.
            from torch.utils.data import get_worker_info

            from annbatch.utils import WorkerHandle

            if get_worker_info() is not None:
                worker_handle = WorkerHandle()
        # Worker mode validation - only check when there are multiple workers
        # With batch_size=1, every batch is exactly 1 item, so no partial batches exist
        if (
            worker_handle is not None
            and worker_handle.num_workers > 1
            and not self._drop_last
            and self._batch_size != 1
        ):
            raise ValueError("When using DataLoader with multiple workers drop_last=False is not supported.")
        return worker_handle

    def _sample(self, n_obs: int) -> Iterator[LoadRequest]:
        """Yield load requests covering the masked observation range.

        Every request except the last covers ``preload_nchunks`` chunks; the last one
        covers the remaining chunks. No request with zero observations to emit is yielded.

        Parameters
        ----------
        n_obs
            The total number of observations available.

        Yields
        ------
        LoadRequest
            The chunks to load and how the loaded rows are split into batches.
        """
        worker_handle = self._get_worker_handle()
        start, stop = self._mask.start or 0, self._mask.stop or n_obs
        # Compute chunks directly from resolved mask range
        # Create chunk indices for possible shuffling and worker sharding
        chunk_indices = np.arange(math.ceil((stop - start) / self._chunk_size))
        if self._shuffle:
            if worker_handle is None:
                self._rng.shuffle(chunk_indices)
            else:
                worker_handle.shuffle(chunk_indices)
        chunks = self._compute_chunks(chunk_indices, start, stop)
        # Worker sharding: each worker gets a disjoint subset of chunks
        if worker_handle is not None:
            chunks = worker_handle.get_part_for_worker(chunks)
        if len(chunks) == 0:
            # Nothing to load (e.g. a worker that received no chunks): emit no request
            # rather than one with empty chunks and an empty split.
            return
        # Outside of workers the configured generator drives the in-memory shuffle as well,
        # so a seeded sampler is reproducible end to end. Inside workers fresh entropy is
        # kept so that copies of `self._rng` do not replay the same permutations.
        batch_rng = self._rng if worker_handle is None else np.random.default_rng()
        # Set up the iterator for chunks and the batch indices for splits
        in_memory_size = self._chunk_size * self._preload_nchunks
        chunks_per_batch = split_given_size(chunks, self._preload_nchunks)
        batch_indices = np.arange(in_memory_size)
        split_batch_indices = split_given_size(batch_indices, self._batch_size)
        for batch_chunks in chunks_per_batch[:-1]:
            if self._shuffle:
                # Avoid copies using in-place shuffling since `self._shuffle` should not change mid-training
                batch_rng.shuffle(batch_indices)
                split_batch_indices = split_given_size(batch_indices, self._batch_size)
            yield {"chunks": batch_chunks, "splits": split_batch_indices}
        # The last request gets its own indices: its in-memory size may be smaller than
        # preload_nchunks * chunk_size, and with `drop_last` the uneven tail is removed.
        final_chunks = chunks_per_batch[-1]
        total_obs_in_last_batch = int(sum(s.stop - s.start for s in final_chunks))
        if self._drop_last:
            total_obs_in_last_batch -= total_obs_in_last_batch % self._batch_size
        if total_obs_in_last_batch == 0:
            # Everything left was dropped; yielding here would make the consumer
            # produce a batch with zero rows.
            return
        make_indices = batch_rng.permutation if self._shuffle else np.arange
        final_splits = split_given_size(make_indices(total_obs_in_last_batch), self._batch_size)
        yield {"chunks": final_chunks, "splits": final_splits}

    def _compute_chunks(self, chunk_indices: np.ndarray, start: int, stop: int) -> list[slice]:
        """Compute the chunk slices for ``[start, stop)`` in the order given by `chunk_indices`.

        If the range is not divisible by the chunk size, exactly one chunk is shorter.
        The short chunk is the one whose index comes last in `chunk_indices`, so it is
        always the final slice returned, while its position inside the data follows the
        (possibly shuffled) order instead of always being the tail of the range.

        Parameters
        ----------
        chunk_indices
            Non-empty permutation of ``range(n_chunks)`` giving the output order.
        start
            First observation of the range.
        stop
            End (exclusive) of the range.

        Returns
        -------
        One slice per entry of `chunk_indices`; together they tile ``[start, stop)``.
        """
        n_chunks, pivot_index = len(chunk_indices), chunk_indices[-1]
        # offsets[0] is the range start and offsets[i + 1] the length of chunk i;
        # the cumulative sum turns these lengths into chunk boundaries.
        offsets = np.ones(n_chunks + 1, dtype=int) * self._chunk_size
        offsets[0] = start
        offsets[pivot_index + 1] = incomplete if (incomplete := (stop - start) % self._chunk_size) else self._chunk_size
        offsets = np.cumsum(offsets)
        starts, stops = offsets[:-1][chunk_indices], offsets[1:][chunk_indices]
        return [slice(int(s), int(e)) for s, e in zip(starts, stops, strict=True)]
+ """ + + chunks: list[slice] + splits: list[np.ndarray] + + class LoaderOutput[OutputInMemoryArray: OutputInMemoryArray_T](TypedDict): """The output of the loader, the "data matrix" with its obs, optional, and index, also optional.""" diff --git a/src/annbatch/utils.py b/src/annbatch/utils.py index 7ba876ac..0bf27cf8 100644 --- a/src/annbatch/utils.py +++ b/src/annbatch/utils.py @@ -1,11 +1,11 @@ from __future__ import annotations +import inspect import warnings from dataclasses import dataclass -from functools import cached_property +from functools import cached_property, wraps from importlib.util import find_spec -from itertools import islice -from typing import TYPE_CHECKING, Protocol +from typing import TYPE_CHECKING, Concatenate, Protocol import anndata as ad import numpy as np @@ -15,11 +15,42 @@ from .compat import CupyArray, CupyCSRMatrix, Tensor if TYPE_CHECKING: - from collections.abc import Generator, Iterable + from collections.abc import Callable + from annbatch.loader import Loader from annbatch.types import OutputInMemoryArray_T +def validate_sampler[**Param, RetType]( + method: Callable[Concatenate[Loader, Param], RetType], +) -> Callable[Concatenate[Loader, Param], RetType]: + """Decorator that validates n_obs before modifying state. + + Expects the first positional argument to be either: + - A single object with a `.shape` attribute + - A list of objects with `.shape` attributes + + The total n_obs is computed as sum of shape[0] values for a list of objects or the shape[0] value for a single object. 
+ """ + sig = inspect.signature(method) + if len(sig.parameters) < 2: + raise ValueError("validate_sampler decorator expects at least two positional arguments after 'self'") + first_param_name = list(sig.parameters.keys())[1] + + @wraps(method) + def wrapper(self: Loader, *args: Param.args, **kwargs: Param.kwargs) -> RetType: + if len(args) > 0: + first_arg = args[0] + else: + first_arg = kwargs[first_param_name] + + n_obs = sum(item.shape[0] for item in first_arg) if isinstance(first_arg, list) else first_arg.shape[0] + self.batch_sampler.validate(n_obs) + return method(self, *args, **kwargs) + + return wrapper + + def split_given_size(a: np.ndarray, size: int) -> list[np.ndarray]: """Wrapper around `np.split` to split up an array into `size` chunks""" return np.split(a, np.arange(size, len(a), size)) @@ -34,14 +65,6 @@ class CSRContainer: dtype: np.dtype -def _batched[T](iterable: Iterable[T], n: int) -> Generator[list[T], None, None]: - if n < 1: - raise ValueError("n must be >= 1") - it = iter(iterable) - while batch := list(islice(it, n)): - yield batch - - # TODO: make this part of the public zarr or zarrs-python API. # We can do chunk coalescing in zarrs based on integer arrays, so I think # there would make sense with ezclump or similar. 
@@ -75,6 +98,13 @@ def _worker_info(self): return get_worker_info() return None + @property + def num_workers(self) -> int: + """Return the number of workers, or 1 if not in a worker context.""" + if self._worker_info is None: + return 1 + return self._worker_info.num_workers + @cached_property def _rng(self): if self._worker_info is None: diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 6afdd340..b6210786 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -12,7 +12,8 @@ import scipy.sparse as sp import zarr -from annbatch import Loader +from annbatch import ChunkSampler, Loader +from annbatch.abc import Sampler try: from cupy import ndarray as CupyArray @@ -146,13 +147,6 @@ def concat(datas: list[Data | ad.AnnData]) -> ListData | list[ad.AnnData]: 50, preload_to_gpu, ], # batch size equal to in-memory size loading - [ - 10, - 5, - open_func, - 14, - preload_to_gpu, - ], # batch size does not divide in memory size evenly ] ] ], @@ -265,7 +259,7 @@ def test_to_torch( shuffle=False, chunk_size=5, preload_nchunks=10, - batch_size=42, + batch_size=25, preload_to_gpu=preload_to_gpu, return_index=True, to_torch=True, @@ -276,14 +270,16 @@ def test_to_torch( @pytest.mark.parametrize("drop_last", [True, False], ids=["drop", "kept"]) def test_drop_last(adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path], drop_last: bool): - # batch_size guaranteed to have leftovers to drop - batch_size = 42 + # batch_size guaranteed to have last batch to drop + chunk_size = 14 + preload_nchunks = 3 + batch_size = 21 zarr_path = next(adata_with_zarr_path_same_var_space[1].glob("*.zarr")) adata = ad.read_zarr(zarr_path) ds = Loader( shuffle=False, - chunk_size=5, - preload_nchunks=10, + chunk_size=chunk_size, + preload_nchunks=preload_nchunks, batch_size=batch_size, preload_to_gpu=False, return_index=True, @@ -297,12 +293,13 @@ def test_drop_last(adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path], batches += [batch["X"]] indices += 
def test_add_dataset_validation_failure_preserves_state(adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path]):
    """A rejected ``add_dataset`` call must leave the loader's bookkeeping untouched."""

    class FailOnSecondValidateSampler(Sampler):
        """Sampler whose ``validate`` accepts the first call and rejects every later one."""

        def __init__(self):
            self._validate_count = 0

        def validate(self, n_obs: int) -> None:
            self._validate_count += 1
            if self._validate_count <= 1:
                return
            raise ValueError("Validation failed on second add")

        @property
        def batch_size(self) -> int:
            return 10

        @property
        def worker_handle(self):
            return None

        def _sample(self, n_obs: int, worker_handle=None):
            # Never yields anything; sampling is irrelevant for this test.
            yield from []

    zarr_paths = list(adata_with_zarr_path_same_var_space[1].glob("*.zarr"))
    first_data = open_dense(zarr_paths[0])
    second_data = open_dense(zarr_paths[1])

    loader = Loader(batch_sampler=FailOnSecondValidateSampler(), preload_to_gpu=False, to_torch=False)

    # The first add passes validation.
    loader.add_dataset(**first_data)

    # Snapshot the state the failing add must not touch.
    shapes_before = loader._shapes.copy()
    n_datasets_before = len(loader._train_datasets)

    # The second add must be rejected by validation before any state changes.
    with pytest.raises(ValueError, match="Validation failed on second add"):
        loader.add_dataset(**second_data)

    # Nothing may have been appended by the failed call.
    assert len(loader._train_datasets) == n_datasets_before
    assert loader._shapes == shapes_before
@pytest.mark.parametrize("kwarg", [{"chunk_size": 10}, {"batch_size": 10}])
def test_cannot_provide_batch_sampler_with_sampler_args(kwarg):
    """Passing a custom ``batch_sampler`` together with sampler arguments must raise."""
    custom_sampler = ChunkSampler(mask=slice(0, 50), batch_size=5, chunk_size=10, preload_nchunks=2)
    loader_kwargs = {"batch_sampler": custom_sampler, "preload_to_gpu": False, "to_torch": False, **kwarg}
    # The constructor itself is expected to reject the conflicting configuration.
    with pytest.raises(ValueError, match="Cannot specify.*when providing a custom sampler"):
        Loader(**loader_kwargs)
class MockWorkerHandle:
    """Stand-in for a torch worker context so tests can run without a DataLoader.

    Parameters
    ----------
    worker_id
        Index of the simulated worker.
    num_workers
        Total number of simulated workers.
    seed
        Seed for the private random generator used by :meth:`shuffle`.
    """

    def __init__(self, worker_id: int, num_workers: int, seed: int = 42):
        self._rng = np.random.default_rng(seed)
        self._num_workers = num_workers
        self.worker_id = worker_id

    @property
    def num_workers(self) -> int:
        """Total number of simulated workers."""
        return self._num_workers

    def shuffle(self, obj):
        """Shuffle ``obj`` in place with this handle's seeded generator."""
        self._rng.shuffle(obj)

    def get_part_for_worker(self, obj: np.ndarray) -> np.ndarray:
        """Return the piece of ``obj`` assigned to this worker.

        ``obj`` is split into ``num_workers`` parts with ``np.array_split``
        and the part at position ``worker_id`` is returned.
        """
        parts = np.array_split(obj, self._num_workers)
        return parts[self.worker_id]
def test_batch_sizes_match_expected_pattern():
    """Check chunk and batch sizes of every load request, including the short tail.

    With 103 observations, chunks of 10 and 2 chunks per load request, six
    load requests are expected. All but the last hold full chunks and full
    batches; the last one holds the 3 leftover observations as a single
    chunk and a single batch.
    """
    n_obs, chunk_size, preload_nchunks, batch_size = 103, 10, 2, 5
    # The last chunk is incomplete and is also the last batch in the load request.
    expected_last_chunk_size = 3
    expected_last_batch_size = 3
    expected_last_num_splits = 1
    expected_num_load_requests = 6
    sampler = ChunkSampler(
        mask=slice(0, None),
        batch_size=batch_size,
        chunk_size=chunk_size,
        preload_nchunks=preload_nchunks,
    )

    all_requests = list(sampler.sample(n_obs))
    assert len(all_requests) == expected_num_load_requests
    for req_idx, load_request in enumerate(all_requests[:-1]):
        # Messages are single strings so a failure prints readable text, not a tuple.
        assert all(chunk.stop - chunk.start == chunk_size for chunk in load_request["chunks"]), (
            f"chunk size mismatch at request {req_idx}: chunks: {load_request['chunks']}"
        )
        assert all(len(split) == batch_size for split in load_request["splits"]), (
            f"batch size mismatch at request {req_idx}:splits: {load_request['splits']}"
        )
    last_request = all_requests[-1]
    assert len(last_request["splits"]) == expected_last_num_splits, "last request num splits mismatch"
    assert all(chunk.stop - chunk.start == expected_last_chunk_size for chunk in last_request["chunks"]), (
        f"last request chunk size mismatch: chunks: {last_request['chunks']}"
    )
    assert all(len(split) == expected_last_batch_size for split in last_request["splits"]), (
        f"last request batch size mismatch: splits: {last_request['splits']}"
    )
@pytest.mark.parametrize(
    "mask,n_obs,error_match",
    [
        pytest.param(slice(0, 100), 100, None, id="valid_config"),
        pytest.param(slice(0, 200), 100, "mask.stop.*exceeds loader n_obs", id="stop_exceeds_n_obs"),
    ],
)
def test_validate(mask, n_obs, error_match):
    """``validate`` passes silently for a valid mask and raises ``ValueError`` otherwise."""
    sampler = ChunkSampler(mask=mask, batch_size=5, chunk_size=10, preload_nchunks=2)

    # No expected error pattern: validation must simply succeed.
    if not error_match:
        sampler.validate(n_obs=n_obs)
        return

    with pytest.raises(ValueError, match=error_match):
        sampler.validate(n_obs=n_obs)
@pytest.mark.parametrize(
    "n_obs_values,expected_ranges",
    [
        pytest.param([50, 100], [range(50), range(100)], id="increase_changes_result"),
        pytest.param([100, 100], [range(100), range(100)], id="same_gives_same_coverage"),
    ],
)
def test_n_obs_coverage(n_obs_values, expected_ranges):
    """Sampling the same sampler with successive ``n_obs`` values covers each expected range."""
    sampler = ChunkSampler(mask=slice(0, None), batch_size=5, chunk_size=10, preload_nchunks=2, shuffle=False)

    # Run every sampling pass first, reusing one sampler, then compare.
    observed = []
    for n_obs in n_obs_values:
        observed.append(collect_indices(sampler, n_obs))

    for got, want in zip(observed, expected_ranges, strict=True):
        assert got == list(want), f"result: {got} != expected: {want}"