diff --git a/.github/workflows/test-gpu.yaml b/.github/workflows/test-gpu.yaml index c60c2f9d..c5ffff45 100644 --- a/.github/workflows/test-gpu.yaml +++ b/.github/workflows/test-gpu.yaml @@ -41,7 +41,7 @@ jobs: runs-on: "cirun-aws-gpu--${{ github.run_id }}" strategy: matrix: - extras: ["torch,cupy-cuda12", "torch", "cupy-cuda12"] + extras: ["torch,cupy-cuda12", "torch", "cupy-cuda12", "jax-cuda12"] # Setting a timeout of 30 minutes, as the AWS costs money # At time of writing, a typical run takes about 5 minutes timeout-minutes: 30 diff --git a/README.md b/README.md index b97dc95e..6780b5cc 100644 --- a/README.md +++ b/README.md @@ -131,7 +131,7 @@ with ad.settings.override(remove_unused_categories=False): batch_size=4096, chunk_size=32, preload_nchunks=256, - to_torch=True + to="torch" ) # `use_collection` automatically uses the on-disk `X` and full `obs` in the `Loader` # but the `load_adata` arg can override this behavior diff --git a/docs/notebooks/example.ipynb b/docs/notebooks/example.ipynb index b5a61d1c..ae95984a 100644 --- a/docs/notebooks/example.ipynb +++ b/docs/notebooks/example.ipynb @@ -261,7 +261,7 @@ " preload_nchunks=32, # Number of chunks to preload + shuffle - default settings should work well\n", " # If True, preloaded chunks are moved to GPU memory via `cupy`, which can put more pressure on GPU memory but will accelerate loading ~20%\n", " preload_to_gpu=False,\n", - " to_torch=True,\n", + " to=\"torch\",\n", ")\n", "\n", "# Add in the shuffled data that should be used for training.\n", diff --git a/pyproject.toml b/pyproject.toml index 3df46cbd..9147c014 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,13 @@ optional-dependencies.doc = [ "sphinxcontrib-bibtex>=1", "sphinxext-opengraph", ] +optional-dependencies.jax = [ "jax" ] +optional-dependencies.jax-cuda12 = [ + "jax[cuda12]", +] +optional-dependencies.jax-cuda13 = [ + "jax[cuda13]", +] optional-dependencies.test = [ "annbatch[zarrs]", "coverage", @@ -88,7 +95,7 @@ envs.docs.scripts.clean = "git clean -fdX -- {args:docs}" envs.hatch-test.features = [ "test" ] envs.hatch-test.python = "3.14" envs.hatch-test.matrix = [ - { deps = [ "min-low", "pre", "torch", "min-high" ] }, + { deps = [ "min-low", "pre", "torch", "min-high", "jax" ] }, ] # If the matrix variable `deps` is set to "pre", # set the environment variable `UV_PRERELEASE` to "allow". @@ -106,6 +113,7 @@ envs.hatch-test.overrides.matrix.deps.python = [ ] envs.hatch-test.overrides.matrix.deps.features = [ { if = [ "torch" ], value = "torch" }, + { if = [ "jax" ], value = "jax" }, ] [tool.ruff] diff --git a/src/annbatch/compat.py b/src/annbatch/compat.py index 3a7b5918..987c67a6 100644 --- a/src/annbatch/compat.py +++ b/src/annbatch/compat.py @@ -19,3 +19,10 @@ from cupyx.scipy.sparse import csr_matrix as CupyCSRMatrix else: CupyCSRMatrix = type("csr_matrix", (), {"__module__": "cupyx.scipy.sparse"}) + +if TYPE_CHECKING or find_spec("jax"): + from jax import Array as JaxArray + from jax.experimental.sparse import CSR as JAXCsrMatrix +else: + JAXCsrMatrix = type("CSR", (), {"__module__": "jax.experimental.sparse"}) + JaxArray = type("Array", (), {"__module__": "jax"}) diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index 2cca7615..2966a991 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -5,8 +5,9 @@ from functools import singledispatchmethod from importlib.metadata import version from importlib.util import find_spec -from itertools import accumulate, chain, pairwise +from itertools import accumulate from typing import TYPE_CHECKING, Literal, NamedTuple, Self, cast +from warnings import warn import anndata as ad import numpy as np @@ -18,19 +19,17 @@ from zarr import Array as ZarrArray from annbatch.samplers import RandomSampler, SequentialSampler -from annbatch.types import BackingArray_T, InputInMemoryArray_T, LoaderOutput, OutputInMemoryArray_T +from annbatch.types import BackingArray_T, Dense_T, LoaderOutput, OutputInMemoryArray_T from annbatch.utils import ( CSRContainer, MultiBasicIndexer, check_lt_1, check_var_shapes, - interval_indexer_from_slices, load_x_and_obs_and_var, - to_torch, validate_sampler, ) -from .compat import IterableDataset +from .compat import CupyArray, IterableDataset, JaxArray, Tensor if TYPE_CHECKING: from collections.abc import Callable, Iterator @@ -42,8 +41,7 @@ # TODO: remove after sphinx 9 - myst compat BackingArray = BackingArray_T OutputInMemoryArray = OutputInMemoryArray_T - InputInMemoryArray = InputInMemoryArray_T - + Dense = Dense_T type concat_strategies = Literal["concat-shuffle", "shuffle-concat"] @@ -55,19 +53,7 @@ class CSRDatasetElems(NamedTuple): data: zarr.AsyncArray -def _cupy_dtype(dtype: np.dtype) -> np.dtype: - if dtype in {np.dtype("float32"), np.dtype("float64"), np.dtype("bool")}: - return dtype - if dtype.itemsize < 4: - return np.dtype("float32") - return np.dtype("float64") - - -class Loader[ - BackingArray: BackingArray_T, - InputInMemoryArray: InputInMemoryArray_T, - OutputInMemoryArray: OutputInMemoryArray_T, -](IterableDataset): +class Loader[BackingArray: BackingArray_T, OutputInMemoryArray: OutputInMemoryArray_T, Dense: Dense_T](IterableDataset): """A loader for on-disk data anndata stores. This loader batches together slice requests to the underlying stores to achieve higher performance. @@ -119,6 +105,10 @@ class Loader[ Data transferred should be 0-copy independent of source, and transfer to cuda when applicable is non-blocking. Defaults to True if `torch` is installed. concat_strategy + .. deprecated:: 0.1.4 + We now write directly from disk to the in-memory buffer from which data is yielded. + This has optimal memory and compute performance obviating the need for this argument. + It will be removed in the next minor release. The strategy for how in-memory, preloaded data should be concatenated and yielded. With `concat-shuffle`, preloaded data is concatenated and then subsetted/shuffled (higher memory usage, but faster, at least for sparse data) With `shuffle-concat`, preloaded data is first shuffled/subsetted chunk-by-chunk and then concatenated (lower memory usage, potentially faster for dense data) @@ -156,10 +146,9 @@ class Loader[ _return_index: bool = False _shapes: list[tuple[int, int]] _preload_to_gpu: bool = True - _to_torch: bool = True + _to: Literal["torch", "cupy", "jax"] | None = None _dataset_elem_cache: dict[int, CSRDatasetElems] _batch_sampler: Sampler - _concat_strategy: None | concat_strategies = None _dataset_intervals: pd.IntervalIndex | None = None _collection_added: bool = False @@ -172,12 +161,52 @@ def __init__( shuffle: bool | None = None, return_index: bool = False, batch_size: int | None = None, - preload_to_gpu: bool = find_spec("cupy") is not None, + preload_to_gpu: bool | None = None, drop_last: bool | None = None, - to_torch: bool = find_spec("torch") is not None, + to_torch: bool | None = None, concat_strategy: None | concat_strategies = None, rng: np.random.Generator | None = None, + to: Literal["torch", "cupy", "jax"] | None = None, ): + if concat_strategy is not None: + warn( + "concat_strategy has no effect and will be removed in an upcoming release thanks to writing directly to output buffers.", + DeprecationWarning, + stacklevel=2, + ) + if preload_to_gpu is None or preload_to_gpu: + if to is None: + if find_spec("cupy") is None: + if preload_to_gpu: + raise ImportError( + "Install cupy using our extras to recover old behavior of automatic cupy usage. Otherwise, don't use this argument without to." + ) + to = None + else: + warn( + "preload_to_gpu will no longer automatically put the array on the gpu.", + FutureWarning, + stacklevel=2, + ) + to = "cupy" + if to_torch is not None: + warn( + 'to_torch will be replaced by `to = "torch"`.', + DeprecationWarning, + stacklevel=2, + ) + to_torch = "torch" if to_torch else None + else: + to_torch = "torch" if find_spec("torch") is not None else None + if to is not None: + if find_spec(to) is not None: + self._to = to + else: + raise ImportError( + f"Try `pip install annbatch[{to}-cudaXX]` (with XX replaced by your cuda version if cuda is required, see our `pyproject.toml`). Found {to} argument to `Loader`." + ) + else: + self._to = to_torch # args that are passed after resolving defaults core_sampler_args = { "chunk_size": chunk_size, @@ -205,10 +234,6 @@ def __init__( ) else: self._batch_sampler = SequentialSampler(**resolved_core_args) - if to_torch and not find_spec("torch"): - raise ImportError("Could not find torch dependency. Try `pip install torch`.") - if preload_to_gpu and not find_spec("cupy"): - raise ImportError("Follow the directions at https://docs.cupy.dev/en/stable/install.html to install cupy.") self._return_index = return_index self._preload_to_gpu = preload_to_gpu @@ -216,7 +241,6 @@ def __init__( self._train_datasets = [] self._shapes = [] self._dataset_elem_cache = {} - self._concat_strategy = concat_strategy def __len__(self) -> int: return self._batch_sampler.n_iters(self.n_obs) @@ -447,7 +471,7 @@ def _add_dataset_unchecked( ) if not isinstance(dataset, BackingArray_T.__value__): raise TypeError(f"Cannot add dataset of type {type(dataset)}") - if (is_sparse := isinstance(dataset, ad.abc.CSRDataset)) and not dataset.backend == "zarr": + if isinstance(dataset, ad.abc.CSRDataset) and not dataset.backend == "zarr": raise TypeError( "Cannot add CSRDataset backed by h5ad at the moment: see https://github.com/zarr-developers/VirtualiZarr/pull/790" ) @@ -471,11 +495,6 @@ def _add_dataset_unchecked( "All datasets must have identical var DataFrames. " "The var of the new dataset does not match the existing var." ) - if self._concat_strategy is None: - if is_sparse: - self._concat_strategy = "concat-shuffle" - else: - self._concat_strategy = "shuffle-concat" self._update_dataset_intervals() return self @@ -559,40 +578,117 @@ def _slices_to_slices_with_array_index( dataset_index_to_slices_sorted[k] = dataset_index_to_slices[k] return dataset_index_to_slices_sorted - def _get_kwargs_for_zarr_fetching(self, z: zarr.Array, indexer_shape: tuple[int, ...]) -> dict: - buffer_prototype = zarr.core.buffer.default_buffer_prototype() - kwargs = {"prototype": buffer_prototype} - if self._preload_to_gpu: - import cupyx as cpx + def _allocate_out[OutT: CSRContainer | Dense]( + self, dataset_index_to_slices: OrderedDict[int, list[slice]] + ) -> tuple[OutT, OutT]: + """Preallocate a single contiguous output buffer covering all datasets and slices. - kwargs["out"] = buffer_prototype.nd_buffer(cpx.empty_pinned(indexer_shape, z.dtype)) - return kwargs + For sparse data the buffer is a :class:`~annbatch.utils.CSRContainer` whose ``data`` + and ``indices`` arrays span the total number of non-zeros (derived from the cached + ``indptr``) and whose ``indptr`` array spans the total number of rows + 1. + For dense data it is a plain :class:`numpy.ndarray` of shape + ``(total_rows, n_var)``. + + Must be called after :meth:`_ensure_sparse_cache` for sparse datasets. + """ + total_rows = sum(s.stop - s.start for slices in dataset_index_to_slices.values() for s in slices) + + def _alloc(shape: tuple[int, ...], dtype: np.dtype) -> Dense: + match self._to: + case "torch": + import torch + + # lol how do people put up with this? + for dtype_str in [ + "uint8", + "uint16", + "uint32", + "uint64", + "int8", + "int16", + "int32", + "int64", + "float16", + "float32", + "float64", + ]: + if str(dtype) == dtype_str: + res = torch.empty(shape, dtype=getattr(torch, dtype_str), device="cpu", pin_memory=True) + return res + raise TypeError(f"Unsupported dtype for conversion to torch: {dtype}") + case "cupy": + import cupyx as cpx + + return cpx.empty_pinned(shape, dtype) + case None | "jax": + # I can't figure out if jax supports this or not, but probably not since their memory is immutable. + return np.empty(shape, dtype) + + if issubclass(self.dataset_type, ad.abc.CSRDataset): + total_nnz = sum( + int(self._dataset_elem_cache[idx].indptr[s.stop] - self._dataset_elem_cache[idx].indptr[s.start]) + for idx, slices in dataset_index_to_slices.items() + for s in slices + ) + first_idx = next(iter(dataset_index_to_slices)) + data_dtype = self._dataset_elem_cache[first_idx].data.dtype + indices_dtype = self._dataset_elem_cache[first_idx].indices.dtype + indptr_dtype = self._dataset_elem_cache[first_idx].indptr.dtype + return_container = CSRContainer( + elems=( + _alloc((total_nnz,), data_dtype), + _alloc((total_nnz,), indices_dtype), + np.empty(total_rows + 1, dtype=indptr_dtype), + ), + shape=(total_rows, self.n_var), + dtype=data_dtype, + ) + # Can't roundtrip pinned memory in torch! + write_container = CSRContainer( + elems=return_container.elems + if self._to != "torch" + else tuple(e.numpy() for e in return_container.elems), + shape=(total_rows, self.n_var), + dtype=data_dtype, + ) + return return_container, write_container + else: + first_idx = next(iter(dataset_index_to_slices)) + dtype = self._train_datasets[first_idx].dtype + shape_res = self._train_datasets[first_idx].shape[1:] + return_container = _alloc((total_rows, *shape_res), dtype) + return return_container, return_container if self._to != "torch" else return_container.numpy() @singledispatchmethod - async def _fetch_data(self, dataset: ZarrArray | CSRDatasetElems, slices: list[slice]) -> InputInMemoryArray: - """Fetch data from an on-disk store. + async def _fetch_data( + self, + dataset: ZarrArray | CSRDatasetElems, + slices: list[slice], + out: CSRContainer | JaxArray | np.ndarray | CupyArray | Tensor, + ) -> None: + """Fetch data from an on-disk store into a preallocated buffer. Parameters ---------- dataset The underlying store. slices - The slices to fetch - - Returns - ------- - The sparse or dense fetched data. + The slices to fetch. + out + Preallocated buffer to write into — a contiguous view of the full + output buffer allocated by :meth:`_allocate_out`. Raises ------ NotImplementedError - If the dataset is not recognized. + If the dataset type is not recognised. """ raise NotImplementedError(f"Cannot fetch data for type {type(dataset)}") @_fetch_data.register - async def _fetch_data_dense(self, dataset: ZarrArray, slices: list[slice]) -> np.ndarray: - print(Version(version("zarr")) <= Version("3.1.6")) + async def _fetch_data_dense( + self, dataset: ZarrArray, slices: list[slice], out: JaxArray | np.ndarray | CupyArray | Tensor + ) -> None: indexer = MultiBasicIndexer( [ zarr.core.indexing.BasicIndexer( @@ -605,13 +701,12 @@ async def _fetch_data_dense(self, dataset: ZarrArray, slices: list[slice]) -> np for s in slices ] ) - res = cast( - "np.ndarray", - await dataset._async_array._get_selection( - indexer, **self._get_kwargs_for_zarr_fetching(dataset, indexer.shape) - ), + buffer_prototype = zarr.core.buffer.default_buffer_prototype() + await dataset._async_array._get_selection( + indexer, + prototype=buffer_prototype, + out=buffer_prototype.nd_buffer(out), ) - return res async def _create_sparse_elems(self, idx: int) -> CSRDatasetElems: """Fetch the in-memory indptr, and backed indices and data for a given dataset index. @@ -672,12 +767,12 @@ async def _fetch_data_sparse( self, dataset: CSRDatasetElems, slices: list[slice], - ) -> CSRContainer: + out: CSRContainer, + ) -> None: # See https://github.com/scverse/anndata/blob/361325fc621887bf4f381e9412b150fcff599ff7/src/anndata/_core/sparse_dataset.py#L272-L295 # for the inspiration of this function. indptr, indices, data = dataset - indptr_indices = [indptr[slice(s.start, s.stop + 1)] for s in slices] - indptr_limits = [slice(i[0], i[-1]) for i in indptr_indices] + indptr_limits = [slice(int(indptr[s.start]), int(indptr[s.stop])) for s in slices] indexer = MultiBasicIndexer( [ zarr.core.indexing.BasicIndexer( @@ -691,61 +786,105 @@ async def _fetch_data_sparse( ] ) - data_np, indices_np = await asyncio.gather( - *( - z._get_selection(indexer, **self._get_kwargs_for_zarr_fetching(z, indexer.shape)) - for z in [data, indices] - ) - ) - gaps = (s1.start - s0.stop for s0, s1 in pairwise(indptr_limits)) - offsets = accumulate(chain([indptr_limits[0].start], gaps)) - start_indptr = indptr_indices[0] - next(offsets) - if len(slices) < 2: # there is only one slice so no need to concatenate - return CSRContainer( - elems=(data_np, indices_np, start_indptr), - shape=(start_indptr.shape[0] - 1, self.n_var), - dtype=data_np.dtype, - ) - end_indptr = np.concatenate([s[1:] - o for s, o in zip(indptr_indices[1:], offsets, strict=True)]) - indptr_np = np.concatenate([start_indptr, end_indptr]) - return CSRContainer( - elems=(data_np, indices_np, indptr_np), - shape=(indptr_np.shape[0] - 1, self.n_var), - dtype=data_np.dtype, + buffer_prototype = zarr.core.buffer.default_buffer_prototype() + await asyncio.gather( + data._get_selection( + indexer, + prototype=buffer_prototype, + out=buffer_prototype.nd_buffer(out.elems[0]), + ), + indices._get_selection( + indexer, + prototype=buffer_prototype, + out=buffer_prototype.nd_buffer(out.elems[1]), + ), ) async def _index_datasets( self, dataset_index_to_slices: OrderedDict[int, list[slice]], - ) -> list[InputInMemoryArray]: - """Helper function meant to encapsulate asynchronous calls so that we can use the same event loop as zarr. + ) -> CSRContainer | Dense: + """Preallocate one output buffer, dispatch concurrent fetches into per-dataset views, then return the buffer. Parameters ---------- dataset_index_to_slices A lookup of the list-placement index of a dataset to the request slices. - fetch_data - The function to do the fetching for a given slice-dataset index pair. """ - tasks = [] - if is_sparse := issubclass(self.dataset_type, ad.abc.CSRDataset): + is_sparse = issubclass(self.dataset_type, ad.abc.CSRDataset) + if is_sparse: await self._ensure_sparse_cache() - for dataset_idx in dataset_index_to_slices.keys(): + + return_container, write_container = self._allocate_out(dataset_index_to_slices) + + tasks = [] + row_offset = 0 + nnz_offset = 0 + + for dataset_idx, slices in dataset_index_to_slices.items(): + nrows = sum(s.stop - s.start for s in slices) + if is_sparse: + cached_indptr = self._dataset_elem_cache[dataset_idx].indptr + nnnz = sum(int(cached_indptr[s.stop] - cached_indptr[s.start]) for s in slices) + out_view: CSRContainer | np.ndarray = CSRContainer( + elems=( + write_container.elems[0][nnz_offset : nnz_offset + nnnz], + write_container.elems[1][nnz_offset : nnz_offset + nnnz], + write_container.elems[2][row_offset : row_offset + nrows + 1], + ), + shape=(nrows, self.n_var), + dtype=write_container.dtype, + ) + nnz_offset += nnnz + else: + out_view = write_container[row_offset : row_offset + nrows] + tasks.append( self._fetch_data( self._get_elem_from_cache(dataset_idx) if is_sparse else self._train_datasets[dataset_idx], - dataset_index_to_slices[dataset_idx], + slices, + out_view, + ) + ) + row_offset += nrows + + await asyncio.gather(*tasks) + + if is_sparse: + running_nnz = 0 + row_pos = 0 + write_container.elems[2][0] = 0 + for dataset_idx, slices in dataset_index_to_slices.items(): + cached_indptr = self._dataset_elem_cache[dataset_idx].indptr + for s in slices: + nrows_s = s.stop - s.start + write_container.elems[2][row_pos + 1 : row_pos + nrows_s + 1] = ( + cached_indptr[s.start + 1 : s.stop + 1] - cached_indptr[s.start] + running_nnz + ) + running_nnz += int(cached_indptr[s.stop] - cached_indptr[s.start]) + row_pos += nrows_s + if self._to == "jax": + import jax.numpy as jnp + + return_container = ( + CSRContainer( + elems=tuple(jnp.array(e, copy=False) for e in return_container.elems), + dtype=return_container.dtype, + shape=return_container.shape, ) + if isinstance(return_container, CSRContainer) + else jnp.array(return_container, copy=False) ) - return await asyncio.gather(*tasks) + return return_container def __iter__( self, ) -> Iterator[LoaderOutput[OutputInMemoryArray]]: """Iterate over the on-disk datasets. - Data is fetched from `N` on-disk anndata objects, returning `N` blocks which are then either concatenated immediately and then yieled/shuffled, or subsetted to shuffled subsets and then concatenated/yielded. - See `concat_strategy` initialization argument for more information. + Data for all requested datasets is fetched concurrently into a single preallocated + buffer, converted to the output format once, and then yielded as direct row-index + subsets — no vstack or intermediate concatenation is required. Yields ------ @@ -755,74 +894,40 @@ def __iter__( [len(self._train_datasets), self.n_obs], ["Number of datasets", "Number of observations"], ) - mod = self._sp_module if issubclass(self.dataset_type, ad.abc.CSRDataset) else np + is_sparse = issubclass(self.dataset_type, ad.abc.CSRDataset) for load_request in self._batch_sampler.sample(self.n_obs): chunks_to_load = load_request["chunks"] splits = load_request["splits"] - # Sampler yields a list of slices that sum to batch_size dataset_index_to_slices = self._slices_to_slices_with_array_index(chunks_to_load, use_original_space=False) - # Fetch the data over slices - chunks: list[InputInMemoryArray] = zsync.sync(self._index_datasets(dataset_index_to_slices)) - in_memory_data = self._accumulate_chunks(chunks) - # Accumulate labels and indices if possible + + raw_out: CSRContainer | Dense = zsync.sync(self._index_datasets(dataset_index_to_slices)) + + in_memory_data = raw_out.to(self._to) if isinstance(raw_out, CSRContainer) else raw_out + if self._preload_to_gpu: + match self._to: + case "torch": + in_memory_data = in_memory_data.cuda(non_blocking=True) + case "jax": + import jax + + gpu = jax.devices("gpu")[0] + in_memory_data = jax.device_put(in_memory_data, device=gpu) + concatenated_obs: None | pd.DataFrame = self._maybe_accumulate_obs(dataset_index_to_slices) in_memory_indices: None | np.ndarray = self._maybe_accumulate_indices(chunks_to_load) - if self._concat_strategy == "concat-shuffle": - in_memory_data = mod.vstack(in_memory_data) - for split in splits: - data = in_memory_data[split] - yield { - "X": data if not self._to_torch else to_torch(data, self._preload_to_gpu), - "obs": concatenated_obs.iloc[split] if concatenated_obs is not None else None, - "var": self._var, - "index": in_memory_indices[split] if in_memory_indices is not None else None, - } - elif self._concat_strategy == "shuffle-concat": - # An IntervalIndexer with start-stop bounds of each chunk's dataset - dataset_interval_indexer = interval_indexer_from_slices(dataset_index_to_slices.values()) - for split in splits: - sorted_split = np.sort(split) - # Get the index of the dataset for the given split relative to the in-memory data - dataset_locs = dataset_interval_indexer.get_indexer_for(sorted_split) - # Get the left bound of that dataset relative to the in-memory data - offsets = dataset_interval_indexer.left[dataset_locs] - # Stack the chunks in dataset order, offseting each split by its dataset's leftmost in-memory bound - data = mod.vstack( - [ - chunk[sorted_split[dataset_locs == i] - offsets[dataset_locs == i]] - for i, chunk in enumerate(in_memory_data) - ] - ) - yield { - "X": data if not self._to_torch else to_torch(data, self._preload_to_gpu), - "obs": concatenated_obs.iloc[sorted_split] if concatenated_obs is not None else None, - "var": self._var, - "index": in_memory_indices[sorted_split] if in_memory_indices is not None else None, - } - else: # pragma: no cover - raise RuntimeError( - f"Found unrecognized concatenation strategy at iteration time {self._concat_strategy}. Please open an issue" - ) + for split in splits: + data = in_memory_data[split] + yield { + "X": data, + "obs": concatenated_obs.iloc[split] if concatenated_obs is not None else None, + "var": self._var, + "index": in_memory_indices[split] if in_memory_indices is not None else None, + } + # https://github.com/cupy/cupy/issues/9625 - if self._preload_to_gpu and issubclass(self.dataset_type, ad.abc.CSRDataset): + if self._preload_to_gpu and is_sparse: self._np_module.get_default_memory_pool().free_all_blocks() - def _accumulate_chunks(self, chunks: list[InputInMemoryArray]) -> list[OutputInMemoryArray_T]: - """Convert fetched chunks to output array format (CSR or ndarray).""" - result: list[OutputInMemoryArray_T] = [] - for chunk in chunks: - if isinstance(chunk, CSRContainer): - result.append( - self._sp_module.csr_matrix( - tuple(self._np_module.asarray(e) for e in chunk.elems), - shape=chunk.shape, - dtype=_cupy_dtype(chunk.dtype) if self._preload_to_gpu else chunk.dtype, - ) - ) - else: - result.append(self._np_module.asarray(chunk)) - return result - def _maybe_accumulate_obs(self, dataset_index_to_slices: OrderedDict[int, list[slice]]) -> pd.DataFrame | None: """Gather obs labels for the loaded slices if possible.""" if self._obs is None: diff --git a/src/annbatch/types.py b/src/annbatch/types.py index 1f0c22ab..a120b050 100644 --- a/src/annbatch/types.py +++ b/src/annbatch/types.py @@ -8,12 +8,11 @@ from scipy import sparse as sp from zarr import Array as ZarrArray -from .compat import CupyArray, CupyCSRMatrix, Tensor -from .utils import CSRContainer +from .compat import CupyArray, CupyCSRMatrix, JaxArray, JAXCsrMatrix, Tensor type BackingArray_T = ad.abc.CSRDataset | ZarrArray -type InputInMemoryArray_T = CSRContainer | np.ndarray -type OutputInMemoryArray_T = sp.csr_matrix | np.ndarray | CupyCSRMatrix | CupyArray | Tensor +type OutputInMemoryArray_T = sp.csr_matrix | np.ndarray | CupyCSRMatrix | CupyArray | Tensor | JaxArray | JAXCsrMatrix +type Dense_T = JaxArray | np.ndarray | CupyArray | Tensor class LoadRequest(TypedDict): diff --git a/src/annbatch/utils.py b/src/annbatch/utils.py index c02864f3..311e4da5 100644 --- a/src/annbatch/utils.py +++ b/src/annbatch/utils.py @@ -5,7 +5,7 @@ import warnings from dataclasses import dataclass from functools import wraps -from typing import TYPE_CHECKING, Concatenate, Protocol +from typing import TYPE_CHECKING, Concatenate, Literal, Protocol import anndata as ad import numpy as np @@ -13,7 +13,7 @@ import scipy as sp import zarr -from .compat import CupyArray, CupyCSRMatrix, Tensor +from .compat import CupyArray, CupyCSRMatrix, JaxArray, JAXCsrMatrix, Tensor if TYPE_CHECKING: from collections.abc import Callable, Iterable @@ -68,14 +68,60 @@ def interval_indexer_from_slices(slices: Iterable[slice]) -> pd.IntervalIndex: ) +def _cupy_dtype(dtype: np.dtype) -> np.dtype: + if dtype in {np.dtype("float32"), np.dtype("float64"), np.dtype("bool")}: + return dtype + if dtype.itemsize < 4: + return np.dtype("float32") + return np.dtype("float64") + + @dataclass -class CSRContainer: +class CSRContainer[DenseT: JaxArray | CupyArray | Tensor | np.ndarray]: """A low-cost container for moving around the buffers of a CSR object""" - elems: tuple[np.ndarray, np.ndarray, np.ndarray] + elems: tuple[DenseT, DenseT, DenseT] shape: tuple[int, int] dtype: np.dtype + def to( + self, target: Literal["jax", "torch", "cupy"] | None + ) -> Tensor | CupyCSRMatrix | JAXCsrMatrix | sp.sparse.csr_matrix: + """Put the CSR matrix into the correct output data structure. + + Parameters + ---------- + target + String target + + Returns + ------- + The sparse matrix. + """ + match target: + case "torch": + import torch + + return torch.sparse_csr_tensor( + self.elems[2], self.elems[1], self.elems[0], size=self.shape, dtype=self.dtype + ) + case "cupy" | None | "jax": + if target == "cupy": + csr_matrix = CupyCSRMatrix + elif target is None: + from scipy.sparse import csr_matrix + else: + csr_matrix = JAXCsrMatrix + return csr_matrix( + self.elems, + shape=self.shape, + **( + {"dtype": _cupy_dtype(self.dtype) if target == "cupy" else self.dtype} + if target != "jax" + else {} + ), + ) + # TODO: make this part of the public zarr or zarrs-python API. # We can do chunk coalescing in zarrs based on integer arrays, so I think diff --git a/tests/conftest.py b/tests/conftest.py index c1a76f8e..b005abad 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ from __future__ import annotations import random +import subprocess from pathlib import Path from typing import TYPE_CHECKING @@ -120,3 +121,17 @@ def simple_collection( shuffle_chunk_size=10, ) return ad.concat([ad.io.read_elem(ds) for ds in collection], join="outer"), collection + + +def pytest_itemcollected(item: pytest.Item) -> None: + """Define behavior of pytest.mark.{gpu,array_api}.""" + is_marked = len(list(item.iter_markers(name="gpu"))) > 0 + if is_marked: + try: + has_gpu = ( + subprocess.run(["nvidia-smi"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL).returncode == 0 + ) + except FileNotFoundError: + has_gpu = False + if not has_gpu: + item.add_marker(pytest.mark.skip()) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 2871b1f4..8b6a70a9 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1,9 +1,9 @@ from __future__ import annotations import math +from contextlib import nullcontext from importlib.util import find_spec -from types import NoneType -from typing import TYPE_CHECKING, TypedDict +from typing import TYPE_CHECKING, Literal, TypedDict import anndata as ad import h5py @@ -15,23 +15,18 @@ from annbatch import Loader, write_sharded from annbatch.abc import Sampler +from annbatch.compat import CupyArray, CupyCSRMatrix, JaxArray, JAXCsrMatrix, Tensor from annbatch.samplers import SequentialSampler -try: - from cupy import ndarray as CupyArray - from cupyx.scipy.sparse import csr_matrix as CupyCSRMatrix -except ImportError: - CupyCSRMatrix = NoneType - CupyArray = NoneType - if TYPE_CHECKING: from collections.abc import Callable from pathlib import Path from annbatch.io import DatasetCollection -skip_if_no_cupy = pytest.mark.skipif(find_spec("cupy") is None, reason="Can't test for preload_to_gpu without cupy") +skip_if_no_cupy = pytest.mark.skipif(find_spec("cupy") is None, reason="Need cupy installed") skip_if_no_torch = pytest.mark.skipif(find_spec("torch") is None, reason="Need torch installed.") +skip_if_no_jax = pytest.mark.skipif(find_spec("jax") is None, reason="Need jax installed.") class Data(TypedDict): @@ -103,16 +98,14 @@ def concat(datas: list[Data | ad.AnnData]) -> ListData | list[ad.AnnData]: "gen_loader", [ pytest.param( - lambda collection, shuffle, use_zarrs, chunk_size=chunk_size, preload_nchunks=preload_nchunks, open_func=open_func, batch_size=batch_size, preload_to_gpu=preload_to_gpu, concat_strategy=concat_strategy: ( + lambda collection, shuffle, use_zarrs, chunk_size=chunk_size, preload_nchunks=preload_nchunks, open_func=open_func, batch_size=batch_size, to=to: ( Loader( shuffle=shuffle, chunk_size=chunk_size, preload_nchunks=preload_nchunks, return_index=True, batch_size=batch_size, - preload_to_gpu=preload_to_gpu, - to_torch=False, - concat_strategy=concat_strategy, + to=to, ).use_collection( collection, **( @@ -122,47 +115,20 @@ def concat(datas: list[Data | ad.AnnData]) -> ListData | list[ad.AnnData]: ), ) ), - id=f"chunk_size={chunk_size}-preload_nchunks={preload_nchunks}-open_func={open_func.__name__[5:] if open_func is not None else 'None'}-batch_size={batch_size}{'-cupy' if preload_to_gpu else ''}-concat_strategy={concat_strategy}", # type: ignore[attr-defined] - marks=[skip_if_no_cupy, pytest.mark.gpu] if preload_to_gpu else [], + id=f"chunk_size={chunk_size}-preload_nchunks={preload_nchunks}-open_func={open_func.__name__[5:] if open_func is not None else 'None'}-batch_size={batch_size}-to={to if to is not None else 'cpu'}", # type: ignore[attr-defined] + marks=[skip_if_no_torch] + if to == "torch" + else ([skip_if_no_jax] if to == "jax" else ([skip_if_no_cupy] if to == "cupy" else [])), ) - for chunk_size, preload_nchunks, open_func, batch_size, preload_to_gpu, concat_strategy in [ + for chunk_size, preload_nchunks, open_func, batch_size, to in [ elem - for preload_to_gpu in [True, False] - for concat_strategy in ["concat-shuffle", "shuffle-concat"] + for to in [None, "torch", "jax", "cupy"] for open_func in [open_sparse, open_dense, None] for elem in [ - [ - 1, - 5, - open_func, - 1, - preload_to_gpu, - concat_strategy, - ], # singleton chunk size - [ - 5, - 1, - open_func, - 1, - preload_to_gpu, - concat_strategy, - ], # singleton preload - [ - 10, - 5, - open_func, - 5, - preload_to_gpu, - concat_strategy, - ], # batch size divides total in memory size evenly - [ - 10, - 5, - open_func, - 50, - preload_to_gpu, - concat_strategy, - ], # batch size equal to in-memory size loading + [1, 5, open_func, 1, to], # singleton chunk size + [5, 1, open_func, 1, to], # singleton preload + [10, 5, open_func, 5, to], # batch size divides total in memory size evenly + [10, 5, open_func, 50, to], # batch size equal to in-memory size loading ] ] ], @@ -191,7 +157,26 @@ def test_store_load_dataset( n_elems += x.shape[0] # Check feature dimension assert x.shape[1] == 100 - batches += [x.get() if isinstance(x, CupyCSRMatrix | CupyArray) else x] + if isinstance(x, CupyCSRMatrix | CupyArray): + batches += [x.get()] + elif isinstance(x, JAXCsrMatrix): + data = np.array(x.data) + indices = np.array(x.indices) + indptr = np.array(x.indptr) + shape = x.shape + batches += [sp.csr_matrix((data, indices, indptr), shape=shape)] + elif isinstance(x, JaxArray): + batches += [np.array(x)] + elif isinstance(x, Tensor): + if is_dense: + batches += [x.numpy()] + else: + crow = x.crow_indices().numpy() + col = x.col_indices().numpy() + data = x.values().numpy() + batches += [sp.csr_matrix((data, col, crow), shape=x.shape)] + else: + batches += [x] if label is not None: obs += [label] if var is not None: @@ -245,7 +230,7 @@ def test_zarr_store_errors_lt_1(gen_loader, adata_with_zarr_path_same_var_space: def test_bad_adata_X_type(adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path]): data = open_dense(next(adata_with_zarr_path_same_var_space[1].glob("*.zarr"))) data["dataset"] = data["dataset"][...] - ds = Loader(shuffle=True, chunk_size=10, preload_nchunks=10, preload_to_gpu=False, to_torch=False) + ds = Loader(shuffle=True, chunk_size=10, preload_nchunks=10, preload_to_gpu=False, to=None) with pytest.raises(TypeError, match="Cannot add"): ds.add_dataset(**data) @@ -257,27 +242,29 @@ def test_use_collection_twice(simple_collection: tuple[ad.AnnData, DatasetCollec ds.use_collection(simple_collection[1]) -@pytest.mark.gpu -@skip_if_no_torch @pytest.mark.parametrize( "preload_to_gpu", [ - pytest.param( - True, - marks=skip_if_no_cupy, - ), + pytest.param(True, marks=pytest.mark.gpu), False, ], ids=["preload_to_gpu", "dont_preload_to_gpu"], ) +@pytest.mark.parametrize( + "to", + [ + pytest.param("jax", marks=skip_if_no_jax), + pytest.param("cupy", marks=skip_if_no_cupy), + pytest.param("torch", marks=skip_if_no_torch), + ], +) @pytest.mark.parametrize("open_func", [open_sparse, open_dense]) -def test_to_torch( +def test_to_gpu( adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path], open_func: Callable[[Path], Data], preload_to_gpu: bool, + to: Literal["jax", "cupy", "torch"], ): - import torch - # batch_size guaranteed to have leftovers to drop ds = Loader( shuffle=False, @@ -286,10 +273,31 @@ def test_to_torch( batch_size=25, preload_to_gpu=preload_to_gpu, return_index=True, - to_torch=True, + to=to, ) ds.add_dataset(**open_func(next(adata_with_zarr_path_same_var_space[1].glob("*.zarr")))) - assert isinstance(next(iter(ds))["X"], torch.Tensor) + x = next(iter(ds))["X"] + match to: + case "torch": + import torch + + assert isinstance(x, torch.Tensor) + if preload_to_gpu: + assert x.is_cuda + case "jax": + import jax + from jax import Array + from jax.experimental.sparse import CSR + + assert isinstance(x, Array if open_func is open_dense else CSR) + if preload_to_gpu: + dev = jax.devices("gpu")[0] + assert x.device() == dev + case "cupy": + from cupy import ndarray + from cupyx.sparse import csr_matrix + + assert isinstance(x, ndarray if open_func is open_dense else csr_matrix) @pytest.mark.parametrize("drop_last", [True, False], ids=["drop", "kept"]) @@ -308,7 +316,7 @@ def test_drop_last(adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path], preload_to_gpu=False, return_index=True, drop_last=drop_last, - to_torch=False, + to=None, ) ds.add_dataset(**open_sparse(zarr_path)) batches = [] @@ -342,7 +350,7 @@ def test_len( shuffle=False, batch_size=batch_size, preload_to_gpu=False, - to_torch=False, + to=None, drop_last=drop_last, ) loader.add_dataset(**data) @@ -357,7 +365,7 @@ def test_len( def test_bad_adata_X_hdf5(adata_with_h5_path_different_var_space: tuple[ad.AnnData, Path]): with h5py.File(next(adata_with_h5_path_different_var_space[1].glob("*.h5ad"))) as f: data = ad.io.sparse_dataset(f["X"]) - ds = Loader(shuffle=True, chunk_size=10, preload_nchunks=10, preload_to_gpu=False, to_torch=False) + ds = Loader(shuffle=True, chunk_size=10, preload_nchunks=10, preload_to_gpu=False, to=None) with pytest.raises(TypeError, match="Cannot add"): ds.add_dataset(data) @@ -414,19 +422,25 @@ def test_torch_multiprocess_dataloading_zarr( @pytest.mark.parametrize( - "preload_to_gpu", [False, pytest.param(True, marks=[pytest.mark.gpu, skip_if_no_cupy])], ids=["no_cupy", "cupy"] + "to", + [ + pytest.param(None, id="cpu"), + pytest.param("torch", marks=[skip_if_no_torch]), + pytest.param("jax", marks=[skip_if_no_jax]), + pytest.param("cupy", marks=[skip_if_no_cupy, pytest.mark.gpu]), + ], ) -@pytest.mark.parametrize("to_torch", [False, pytest.param(True, marks=[skip_if_no_torch])], ids=["no_torch", "torch"]) def test_3d( - adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path], use_zarrs: bool, preload_to_gpu: bool, to_torch: bool + adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path], + use_zarrs: bool, + to: Literal["torch", "cupy", "jax"] | None, ): ds = Loader( chunk_size=10, preload_nchunks=4, shuffle=True, return_index=True, - preload_to_gpu=preload_to_gpu, - to_torch=to_torch, + to=to, ) ds.add_datasets( **concat([open_3d(p, use_zarrs=use_zarrs) for p in adata_with_zarr_path_same_var_space[1].glob("*.zarr")]) @@ -436,43 +450,55 @@ def test_3d( x_list, idx_list = [], [] for batch in ds: x, idxs = batch["X"], batch["index"] - if preload_to_gpu and not to_torch: + if to == "cupy": import cupy as cp assert isinstance(x, cp.ndarray) x = x.get() - if to_torch: + if to == "torch": import torch assert isinstance(x, torch.Tensor) x = x.cpu().numpy() + elif to == "jax": + import jax + + assert isinstance(x, jax.Array) + x = np.array(x) x_list.append(x) idx_list.append(idxs.ravel()) x = np.vstack(x_list) idxs = np.concatenate(idx_list) - assert np.array_equal(x[np.argsort(idxs)], x_ref) + np.testing.assert_almost_equal(x[np.argsort(idxs)], x_ref) @pytest.mark.skipif( find_spec("cupy") is not None, reason="Can't test for preload_to_gpu True ImportError with cupy installed" ) def test_no_cupy(): - with pytest.raises( - ImportError, match=r"Follow the directions at https://docs.cupy.dev/en/stable/install.html to install cupy." - ): - Loader(chunk_size=10, preload_nchunks=4, preload_to_gpu=True, to_torch=False) + with pytest.raises(ImportError, match=r"Install cupy using our."): + Loader(chunk_size=10, preload_nchunks=4, preload_to_gpu=True, to=None) -@pytest.mark.skipif( - find_spec("torch") is not None, reason="Can't test for to_torch True ImportError with torch installed" +@pytest.mark.parametrize( + "to", + [ + pytest.param( + to, + marks=pytest.mark.skipif( + find_spec(to) is not None, reason=f'Can\'t test for no to="{to}" with {to} installed' + ), + ) + for to in ["torch", "cupy", "jax"] + ], ) -def test_no_torch(): - with pytest.raises(ImportError, match=r"Try `pip install torch`."): - Loader(chunk_size=10, preload_nchunks=4, to_torch=True, preload_to_gpu=False) +def test_no_to_installed(to): + with pytest.raises(ImportError, match=r"Try `pip install annbatch\["): + Loader(chunk_size=10, preload_nchunks=4, to=to, preload_to_gpu=False) -def get_default_dense() -> type: +def legacy_get_default_dense() -> type: if find_spec("torch"): from torch import Tensor as expected_dense else: @@ -480,7 +506,7 @@ def get_default_dense() -> type: return expected_dense -def get_default_sparse() -> type: +def legacy_get_default_sparse() -> type: if find_spec("cupy"): from cupyx.scipy.sparse import csr_matrix as expected_sparse else: @@ -493,18 +519,18 @@ def get_default_sparse() -> type: @pytest.mark.parametrize( ("expected_cls", "kwargs"), ( - pytest.param(get_default_dense(), {"preload_to_gpu": False}, id="torch"), - pytest.param(get_default_sparse(), {"to_torch": False}, id="cupy"), + pytest.param(legacy_get_default_dense(), {"preload_to_gpu": False}, id="torch"), + pytest.param(legacy_get_default_sparse(), {"to_torch": False}, id="cupy"), ), ) -def test_default_data_structures( +def test_legeacy_default_data_structures( adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path], expected_cls: type, kwargs: dict ): # format is a smoke test for sparse - ds = Loader( - chunk_size=10, preload_nchunks=4, batch_size=20, shuffle=True, return_index=False, **kwargs - ).add_dataset( - **(open_sparse if issubclass(expected_cls, get_default_sparse()) else open_dense)( + with pytest.warns(DeprecationWarning) if "to_torch" in kwargs else nullcontext(): + ds = Loader(chunk_size=10, preload_nchunks=4, batch_size=20, shuffle=True, return_index=False, **kwargs) + ds.add_dataset( + **(open_sparse if issubclass(expected_cls, legacy_get_default_sparse()) else open_dense)( list(adata_with_zarr_path_same_var_space[1].iterdir())[0] ) ) @@ -581,13 +607,13 @@ def test_mismatched_var_raises_error(tmp_path: Path, subtests): ("dtype_in", "expected"), [(np.int16, np.float32), (np.int32, np.float64), (np.float32, np.float32), (np.float64, np.float64)], ) -def test_preload_dtype(tmp_path: Path, dtype_in: np.dtype, expected: np.dtype): +def test_legacy_preload_dtype(tmp_path: Path, dtype_in: np.dtype, expected: np.dtype): z = zarr.open(tmp_path / "foo.zarr") write_sharded(z, ad.AnnData(X=sp.random(100, 10, dtype=dtype_in, format="csr", rng=np.random.default_rng()))) adata = ad.AnnData(X=ad.io.sparse_dataset(z["X"])) - loader = Loader(preload_to_gpu=True, batch_size=10, chunk_size=10, preload_nchunks=2, to_torch=False).add_adata( - adata - ) + with pytest.warns(DeprecationWarning): + loader = Loader(preload_to_gpu=True, batch_size=10, chunk_size=10, preload_nchunks=2, to_torch=False) + loader.add_adata(adata) assert next(iter(loader))["X"].dtype == expected @@ -628,7 +654,7 @@ def _sample(self, n_obs: int, worker_handle=None): data2 = open_dense(paths[1]) sampler = FailOnSecondValidateSampler() - loader = Loader(batch_sampler=sampler, preload_to_gpu=False, to_torch=False) + loader = Loader(batch_sampler=sampler, preload_to_gpu=False, to=None) # First add succeeds loader.add_dataset(**data1) @@ -667,7 +693,7 @@ def test_given_batch_sampler_samples_subset_of_combined_datasets( preload_nchunks=2, ) - loader = Loader(batch_sampler=sampler, preload_to_gpu=False, to_torch=False, return_index=True) + loader = Loader(batch_sampler=sampler, preload_to_gpu=False, to=None, return_index=True) loader.add_datasets(**concat(datas)) # Collect all yielded indices @@ -687,16 +713,12 @@ def test_cannot_provide_batch_sampler_with_sampler_args(kwarg): """Test that providing batch_sampler with sampler args raises in constructor.""" chunk_sampler = SequentialSampler(mask=slice(0, 50), batch_size=5, chunk_size=10, preload_nchunks=2) with pytest.raises(ValueError, match="Cannot specify.*when providing a custom sampler"): - Loader(batch_sampler=chunk_sampler, preload_to_gpu=False, to_torch=False, **kwarg) + Loader(batch_sampler=chunk_sampler, preload_to_gpu=False, to=None, **kwarg) def test_rng(simple_collection: tuple[ad.AnnData, DatasetCollection]): - ds1 = Loader( - chunk_size=10, preload_nchunks=4, batch_size=20, shuffle=True, rng=np.random.default_rng(0), to_torch=False - ) - ds2 = Loader( - chunk_size=10, preload_nchunks=4, batch_size=20, shuffle=True, rng=np.random.default_rng(0), to_torch=False - ) + ds1 = Loader(chunk_size=10, preload_nchunks=4, batch_size=20, shuffle=True, rng=np.random.default_rng(0), to=None) + ds2 = Loader(chunk_size=10, preload_nchunks=4, batch_size=20, shuffle=True, rng=np.random.default_rng(0), to=None) ds1.use_collection( simple_collection[1], )