Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 48 additions & 35 deletions src/annbatch/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@
V1_ENCODING = {"encoding-type": "annbatch-preshuffled", "encoding-version": "0.1.0"}


def _ds_to_memory(ds: Dataset2D) -> pd.DataFrame:
    """Load a lazy ``Dataset2D`` into memory, indexed by its true index.

    Note that ``ds`` is modified in place before it is loaded: its index is
    replaced and its ``"_index"`` column (if any) is removed.

    Parameters
    ----------
    ds
        The lazy dataset to load.

    Returns
    -------
    The result of ``ds.to_memory()`` after the index swap and column removal.
    """
    # Swap in the true index first so the loaded result carries it.
    true_index = ds.true_index
    ds.index = true_index
    # Drop the "_index" column, when present, so it is not loaded as data.
    has_index_column = "_index" in ds.columns
    if has_index_column:
        del ds["_index"]
    in_memory = ds.to_memory()
    return in_memory


def _default_load_adata[T: zarr.Group | h5py.Group | PathLike[str] | str](x: T) -> ad.AnnData:
adata = ad.experimental.read_lazy(x, load_annotation_index=False)
if not isinstance(x, zarr.Group | h5py.Group):
Expand All @@ -46,14 +53,13 @@ def _default_load_adata[T: zarr.Group | h5py.Group | PathLike[str] | str](x: T)
group = x
# -1 indicates that all of each `obs` column should just be loaded, but this is probably fine since it goes column by column and discards.
# TODO: Bug with empty columns: https://github.com/scverse/anndata/pull/2307
for attr in ["obs", "var"]:
# Only one column at a time will be loaded so we will hopefully pick up the benefit of loading into memory by the cache without having memory pressure.
if len(getattr(adata, attr).columns) > 0:
setattr(adata, attr, ad.experimental.read_elem_lazy(group[attr], chunks=(-1,), use_range_index=True))
for col in getattr(adata, attr).columns:
# Nullables / categoricals have bad performance characteristics when concatenating using dask
if pd.api.types.is_extension_array_dtype(getattr(adata, attr)[col].dtype):
getattr(adata, attr)[col] = getattr(adata, attr)[col].data
# Only one column at a time will be loaded so we will hopefully pick up the benefit of loading into memory by the cache without having memory pressure.
if len(adata.obs.columns) > 0:
adata.obs = ad.experimental.read_elem_lazy(group["obs"], chunks=(-1,), use_range_index=True)
for col in adata.obs.columns:
# Nullables / categoricals have bad performance characteristics when concatenating using dask
if pd.api.types.is_extension_array_dtype(adata.obs[col].dtype):
adata.obs[col] = adata.obs[col].data
return adata


Expand Down Expand Up @@ -347,11 +353,24 @@ def _validate_groupby_columns[T: zarr.Group | h5py.Group | PathLike[str] | str](
def _lazy_load_adata[T: zarr.Group | h5py.Group | PathLike[str] | str](
paths: Iterable[T],
load_adata: Callable[[T], ad.AnnData] = _default_load_adata,
var_subset: Iterable[str] | None = None,
merge: Literal["same", "unique", "first", "only"] | None = None,
):
adatas = []
categoricals_in_all_adatas: dict[str, pd.Index] = {}
for i, path in tqdm(enumerate(paths), total=len(paths), desc="Lazy loading anndatas"):
adata = load_adata(path)
# TODO: File bug/issue in anndata about merging var xarray objects
# Otherwise there is no respect for the merge argument
if isinstance(adata.var, Dataset2D):
adata.var = _ds_to_memory(adata.var)
if adata.raw is not None and isinstance(adata.raw.var, Dataset2D):
adata_raw = adata.raw.to_adata()
adata_raw.var = _ds_to_memory(adata_raw.var)
del adata.raw
adata.raw = adata_raw
if var_subset is not None:
adata = adata[:, adata.var.index.isin(var_subset)]
Comment on lines +375 to +376
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

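I also slipped this in - I don't think there's much point in outer joining genes if we're just going to subset them anyway. It seems like we should just subset first and then outer join. I think this is just the distributive law of sets: intersecting the union of the var sets with the subset, (A ∪ B) ∩ S, gives the same genes as (A ∩ S) ∪ (B ∩ S).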

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See #198 (comment) for a bit more info on a small implication of this, which I think is actually more "correct" anyway

# Track the source file for this given anndata object
adata.obs["src_path"] = pd.Categorical.from_codes(
np.ones((adata.shape[0],), dtype="int") * i, categories=pd.Index([str(p) for p in paths])
Expand All @@ -371,16 +390,12 @@ def _lazy_load_adata[T: zarr.Group | h5py.Group | PathLike[str] | str](
categoricals_in_all_adatas[k] = categoricals_in_all_adatas[k].union(
categorical_cols_in_this_adata[k]
)
# TODO: Probably bug in anndata, need the true index for proper outer joins (can't skirt this with fake indexes, at least not in the mixed-type regime).
# See: https://github.com/scverse/anndata/pull/2299
if isinstance(adata.var, Dataset2D):
adata.var.index = adata.var.true_index
if adata.raw is not None and isinstance(adata.raw.var, Dataset2D):
adata.raw.var.index = adata.raw.var.true_index
raise RuntimeError("No Dataset 2D raw allowed")
adatas.append(adata)
if len(adatas) == 1:
return adatas[0]
adata = ad.concat(adatas, join="outer")
adata = ad.concat(adatas, join="outer", merge=merge)
if len(categoricals_in_all_adatas) > 0:
adata.uns["dataset2d_categoricals_to_convert"] = categoricals_in_all_adatas
return adata
Expand Down Expand Up @@ -464,25 +479,17 @@ def _persist_adata_in_memory(adata: ad.AnnData) -> ad.AnnData:
if isinstance(adata.X, DaskArray):
adata.X = _compute_blockwise(adata.X)
if isinstance(adata.obs, Dataset2D):
adata.obs = adata.obs.to_memory()
# TODO: This is a bug in anndata?
if "_index" in adata.obs.columns:
adata.obs.index = adata.obs["_index"]
del adata.obs["_index"]
adata.obs = _ds_to_memory(adata.obs)
adata = _to_categorical_obs(adata)
if isinstance(adata.var, Dataset2D):
adata.var = adata.var.to_memory()
if "_index" in adata.var.columns:
del adata.var["_index"]
raise RuntimeError("No Dataset2D var should be found")

if adata.raw is not None:
adata_raw = adata.raw.to_adata()
if isinstance(adata_raw.X, DaskArray):
adata_raw.X = _compute_blockwise(adata_raw.X)
if isinstance(adata_raw.var, Dataset2D):
adata_raw.var = adata_raw.var.to_memory()
if "_index" in adata_raw.var.columns:
del adata_raw.var["_index"]
raise RuntimeError("No Dataset2D var should be found")
if isinstance(adata_raw.obs, Dataset2D):
adata_raw.obs = adata_raw.obs.to_memory()
del adata.raw
Expand All @@ -498,8 +505,8 @@ def _persist_adata_in_memory(adata: ad.AnnData) -> ad.AnnData:
if "_index" in elem.columns:
del elem["_index"]
# TODO: Bug in anndata
if "obs" in axis_name:
elem.index = adata.obs_names
if "obs" in axis_name or "var" in axis_name:
elem.index = getattr(adata, f"{axis_name[:-1]}_names")
getattr(adata, axis_name)[k] = elem

return adata.to_memory()
Expand Down Expand Up @@ -604,6 +611,7 @@ def add_adatas(
shuffle_chunk_size: int = 1000,
shuffle: bool = True,
rng: np.random.Generator | None = None,
merge: Literal["same", "unique", "first", "only"] | None = None,
) -> Self:
"""Take AnnData paths and create or add to an on-disk set of AnnData datasets with uniform var spaces at the desired path (with `dataset_size` rows per dataset if running for the first time).

Expand Down Expand Up @@ -657,6 +665,8 @@ def add_adatas(
`(shuffle_chunk_size // dataset_size)` slices will be loaded of size `shuffle_chunk_size`.
rng
Random number generator for shuffling.
merge
var column merge strategy - see :func:`anndata.concat` for more information.

Examples
--------
Expand Down Expand Up @@ -695,6 +705,7 @@ def add_adatas(
"shuffle_chunk_size": shuffle_chunk_size,
"shuffle": shuffle,
"rng": rng,
"merge": merge,
}
if self.is_empty:
self._create_collection(**shared_kwargs, dataset_size=dataset_size, var_subset=var_subset)
Expand All @@ -716,6 +727,7 @@ def _create_collection(
dataset_size: int | str = "20GB",
shuffle_chunk_size: int = 1000,
shuffle: bool = True,
merge: Literal["same", "unique", "first", "only"] | None = None,
rng: np.random.Generator,
) -> None:
"""Take AnnData paths, create an on-disk set of AnnData datasets with uniform var spaces at the desired path with `dataset_size` rows per dataset.
Expand Down Expand Up @@ -762,6 +774,8 @@ def _create_collection(
shuffle_chunk_size
How many contiguous rows to load into memory before shuffling at once.
`(shuffle_chunk_size // dataset_size)` slices will be loaded of size `shuffle_chunk_size`.
merge
var column merge strategy - see :func:`anndata.concat` for more information. This setting is applied when concatenating on-disk datasets together (with input datasets if adding as well).
rng
Random number generator for shuffling.
"""
Expand All @@ -783,7 +797,7 @@ def _create_collection(
"Cannot have a larger slice size than observations per dataset. Reduce `shuffle_chunk_size` or increase `dataset_size`."
)

adata_concat = _lazy_load_adata(adata_paths, load_adata=load_adata)
adata_concat = _lazy_load_adata(adata_paths, load_adata=load_adata, var_subset=var_subset, merge=merge)
adata_concat.obs_names_make_unique()
dataset_size = min(adata_concat.shape[0], dataset_size)
chunks = _create_chunks_for_shuffling(
Expand All @@ -793,14 +807,10 @@ def _create_collection(
shuffle=shuffle,
shuffle_n_obs_per_dataset=dataset_size,
)

if var_subset is None:
var_subset = adata_concat.var_names
for i, chunk in enumerate(tqdm(chunks, desc="Creating dataset collection")):
var_mask = adata_concat.var_names.isin(var_subset)
# np.sort: It's more efficient to access elements sequentially from dask arrays
# The data will be shuffled later on, we just want the elements at this point
adata_chunk = adata_concat[np.sort(chunk), :][:, var_mask].copy()
adata_chunk = adata_concat[np.sort(chunk), :].copy()
adata_chunk = _persist_adata_in_memory(adata_chunk)
if shuffle:
# shuffle adata in memory to break up individual chunks
Expand Down Expand Up @@ -838,6 +848,7 @@ def _add_to_collection(
h5ad_compressor: Literal["gzip", "lzf"] | None = "gzip",
shuffle_chunk_size: int = 1000,
shuffle: bool = True,
merge: Literal["same", "unique", "first", "only"] | None = None,
rng: np.random.Generator,
) -> None:
"""Add anndata files to an existing collection of sharded anndata zarr datasets.
Expand Down Expand Up @@ -869,6 +880,8 @@ def _add_to_collection(
To save memory, the blocks of a dense on-disk store can be sparsified for in-memory processing.
shuffle_chunk_size
How many contiguous rows to load into memory of the input data for pseudo-blockshuffling into the existing datasets.
merge
var column merge strategy - see :func:`anndata.concat` for more information.
shuffle
Whether or not to shuffle when adding. Otherwise, the incoming data will just be split up and appended.
"""
Expand All @@ -879,7 +892,7 @@ def _add_to_collection(
_validate_groupby_columns(adata_paths, load_adata=load_adata, groupby=groupby)
_validate_anndatas_and_maybe_get_bytes_per_row(adata_paths, load_adata=load_adata)
# Check for mismatched keys among the inputs.
adata_concat = _lazy_load_adata(adata_paths, load_adata=load_adata)
adata_concat = _lazy_load_adata(adata_paths, load_adata=load_adata, merge=merge)
if math.ceil(adata_concat.shape[0] / shuffle_chunk_size) < len(self._dataset_keys):
raise ValueError(
f"Use a shuffle size small enough to distribute the input data with {adata_concat.shape[0]} obs across {len(self._dataset_keys)} anndata stores."
Expand Down Expand Up @@ -917,7 +930,7 @@ def validate_load_adata(path_or_group):
subset_adata = _to_categorical_obs(
adata_concat[chunk, :][:, adata_concat.var.index.isin(adata_dataset.var.index)]
)
adata = ad.concat([adata_dataset, subset_adata], join="outer")
adata = ad.concat([adata_dataset, subset_adata], join="outer", merge=merge)
if shuffle:
idxs = rng.permutation(adata.shape[0])
else:
Expand Down
11 changes: 10 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def adata_with_h5_path_different_var_space(
params = getattr(request, "param", {})
n_adatas = params.get("n_adatas", 6)
all_adatas_have_raw = params.get("all_adatas_have_raw", True)
merge = params.get("merge", None)

tmp_path = Path(tmpdir_factory.mktemp("raw_adatas"))
tmp_path = tmp_path / "h5_files"
Expand All @@ -90,7 +91,14 @@ def adata_with_h5_path_different_var_space(
},
index=obs_idx,
),
var=pd.DataFrame(index=var_idx),
var=pd.DataFrame(
index=var_idx,
data={
f"only_{i}": pd.array(range(n), dtype="int64"),
f"partial_share_{i % 3}": pd.array(range(n), dtype="int64"),
"same": pd.array(range(n), dtype="int64"),
},
),
obsm={"arr": np.random.randn(m, 10), "df": pd.DataFrame({"numeric": np.arange(m)}, index=obs_idx)},
varm={"arr": np.random.randn(n, 10), "df": pd.DataFrame({"numeric": np.arange(n)}, index=var_idx)},
)
Expand All @@ -103,6 +111,7 @@ def adata_with_h5_path_different_var_space(
return ad.concat(
[ad.read_h5ad(tmp_path / shard) for shard in sorted(tmp_path.iterdir()) if str(shard).endswith(".h5ad")],
join="outer",
merge=merge,
), tmp_path


Expand Down
27 changes: 18 additions & 9 deletions tests/test_preshuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,42 +175,49 @@ def test_store_creation_default(
"load_adata", [pytest.param(None, id="default_read"), pytest.param(ad.experimental.read_lazy, id="fully_lazy")]
)
@pytest.mark.parametrize("var_subset", [[f"gene_{i}" for i in range(100)], None], ids=["var_subset", "no_subset"])
@pytest.mark.parametrize("merge", ["same", "unique", "first", "only", None])
def test_store_creation(
adata_with_h5_path_different_var_space: tuple[ad.AnnData, Path],
shuffle: bool,
load_adata: Callable[[str], ad.AnnData],
merge: Literal["same", "unique", "first", "only"] | None,
var_subset: list[str] | None,
):
h5_files = sorted(adata_with_h5_path_different_var_space[1].iterdir())
h5_dir_sorted = sorted(adata_with_h5_path_different_var_space[1].iterdir())
h5_files = [adata_with_h5_path_different_var_space[1] / f for f in h5_dir_sorted if str(f).endswith(".h5ad")]
# apply merge
orig_adatas = [ad.read_h5ad(shard) for shard in h5_files]
adata_orig = ad.concat(
orig_adatas,
join="outer",
merge=merge,
)
output_path = (
adata_with_h5_path_different_var_space[1].parent
/ f"zarr_store_creation_test_{shuffle}_{'default_read' if load_adata is None else 'custom_read'}{'_with_var_subset' if var_subset is not None else ''}.zarr"
/ f"zarr_store_creation_test_{shuffle}_{'default_read' if load_adata is None else 'custom_read'}{'_with_var_subset' if var_subset is not None else ''}_{merge}.zarr"
)
collection = DatasetCollection(output_path).add_adatas(
[adata_with_h5_path_different_var_space[1] / f for f in h5_files if str(f).endswith(".h5ad")],
h5_files,
n_obs_per_chunk=5,
shard_size=10,
dataset_size=50,
shuffle_chunk_size=10,
shuffle=shuffle,
**({"load_adata": load_adata} if load_adata is not None else {}),
**({"var_subset": var_subset} if var_subset is not None else {}),
merge=merge,
)
assert not DatasetCollection(output_path).is_empty
assert V1_ENCODING.items() <= zarr.open(output_path).attrs.items()

adata_orig = adata_with_h5_path_different_var_space[0]
# make sure all category dtypes match
adatas_shuffled = [ad.io.read_elem(g) for g in collection]
for adata in adatas_shuffled:
assert adata.obs["label"].dtype == adata_orig.obs["label"].dtype
# subset to var_subset
adata_orig = adata_orig[:, adata_orig.var.index.isin(var_subset) if var_subset is not None else slice(None)]
adata_orig.obs_names_make_unique()
adata = ad.concat(
adatas_shuffled,
join="outer",
)
adata = ad.concat(adatas_shuffled, join="outer", merge="same")
del adata.obs["src_path"]
assert adata.X.shape[0] == adata_orig.X.shape[0]
assert adata.X.shape[1] == adata_orig.X.shape[1]
Expand Down Expand Up @@ -240,7 +247,9 @@ def test_store_creation(
adata.obs["label"] = adata.obs["label"].cat.reorder_categories(adata_orig.obs["label"].dtype.categories)

pd.testing.assert_frame_equal(adata.obs, adata_orig.obs)
pd.testing.assert_frame_equal(adata.var, adata_orig.var)
# TODO: Why is the orig dtype floats instead of ints for certain columns?
# Since it is the wrong one, we can leave this. The on-disk data is correct (int).
pd.testing.assert_frame_equal(adata.var, adata_orig.var, check_dtype=False)
z = zarr.open(output_path / "dataset_0")
# assert chunk behavior (unified n_obs_per_chunk=5 for both sparse and dense)
assert z["obsm"]["arr"].chunks[0] == 5, z["obsm"]["arr"]
Expand Down
Loading