diff --git a/src/annbatch/io.py b/src/annbatch/io.py index b76be75..3ec4d3e 100644 --- a/src/annbatch/io.py +++ b/src/annbatch/io.py @@ -505,6 +505,23 @@ def _persist_adata_in_memory(adata: ad.AnnData) -> ad.AnnData: return adata.to_memory() +def _prepare_adata_for_write( + adata: ad.AnnData, + *, + shuffle: bool, + rng: np.random.Generator, + groupby: str | Iterable[str] | None, +) -> ad.AnnData: + adata = _persist_adata_in_memory(adata) + if shuffle: + # Shuffle after persisting so indexing happens fully in memory. + idxs = rng.permutation(np.arange(adata.shape[0])) + adata = adata[idxs] + if groupby is not None: + adata = _groupby_adata(adata, groupby=groupby) + return adata + + DATASET_PREFIX = "dataset" @@ -801,13 +818,7 @@ def _create_collection( # np.sort: It's more efficient to access elements sequentially from dask arrays # The data will be shuffled later on, we just want the elements at this point adata_chunk = adata_concat[np.sort(chunk), :][:, var_mask].copy() - adata_chunk = _persist_adata_in_memory(adata_chunk) - if shuffle: - # shuffle adata in memory to break up individual chunks - idxs = rng.permutation(np.arange(len(adata_chunk))) - adata_chunk = adata_chunk[idxs] - if groupby is not None: - adata_chunk = _groupby_adata(adata_chunk, groupby=groupby) + adata_chunk = _prepare_adata_for_write(adata_chunk, shuffle=shuffle, rng=rng, groupby=groupby) if isinstance(self._group, zarr.Group): write_sharded( self._group, @@ -918,13 +929,7 @@ def validate_load_adata(path_or_group): adata_concat[chunk, :][:, adata_concat.var.index.isin(adata_dataset.var.index)] ) adata = ad.concat([adata_dataset, subset_adata], join="outer") - if shuffle: - idxs = rng.permutation(adata.shape[0]) - else: - idxs = np.arange(adata.shape[0]) - adata = _persist_adata_in_memory(adata[idxs, :].copy()) - if groupby is not None: - adata = _groupby_adata(adata, groupby=groupby) + adata = _prepare_adata_for_write(adata, shuffle=shuffle, rng=rng, groupby=groupby) if isinstance(self._group, zarr.Group): write_sharded( self._group,