diff --git a/pyproject.toml b/pyproject.toml index a63cac86..79dafe73 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -150,7 +150,7 @@ lint.per-file-ignores."docs/*" = [ "I" ] lint.per-file-ignores."tests/*" = [ "D" ] lint.pydocstyle.convention = "numpy" -[tool.pytest.ini_option] +[tool.pytest.ini_options] markers = [ "gpu: mark test to run on GPU" ] testpaths = [ "tests" ] xfail_strict = true diff --git a/src/annbatch/__init__.py b/src/annbatch/__init__.py index 39180c0b..7fa9ed2e 100644 --- a/src/annbatch/__init__.py +++ b/src/annbatch/__init__.py @@ -5,6 +5,7 @@ from . import abc, types from .io import DatasetCollection, write_sharded from .loader import Loader +from .samplers._categorical_sampler import CategoricalSampler from .samplers._chunk_sampler import ChunkSampler __version__ = version("annbatch") @@ -14,6 +15,7 @@ "DatasetCollection", "types", "write_sharded", + "CategoricalSampler", "ChunkSampler", "abc", ] diff --git a/src/annbatch/samplers/__init__.py b/src/annbatch/samplers/__init__.py index 9f92bbf0..e963e650 100644 --- a/src/annbatch/samplers/__init__.py +++ b/src/annbatch/samplers/__init__.py @@ -1,5 +1,7 @@ +from ._categorical_sampler import CategoricalSampler from ._chunk_sampler import ChunkSampler __all__ = [ + "CategoricalSampler", "ChunkSampler", ] diff --git a/src/annbatch/samplers/_categorical_sampler.py b/src/annbatch/samplers/_categorical_sampler.py new file mode 100644 index 00000000..f7b224c2 --- /dev/null +++ b/src/annbatch/samplers/_categorical_sampler.py @@ -0,0 +1,323 @@ +"""Categorical sampler for group-stratified data access.""" + +from __future__ import annotations + +from itertools import batched +from typing import TYPE_CHECKING + +import numpy as np +import pandas as pd + +from annbatch.abc import Sampler +from annbatch.samplers._chunk_sampler import ChunkSampler +from annbatch.samplers._utils import is_in_worker +from annbatch.utils import check_lt_1 + +if TYPE_CHECKING: + from collections.abc import Iterator, Sequence + + from annbatch.types import LoadRequest + + +class CategoricalSampler(Sampler): + """Categorical sampler for group-stratified batched data access. + + This sampler ensures each batch contains observations from a single category/group. + It iterates through all categories, yielding all batches exactly once per epoch. + The batch order is shuffled across categories, but each individual batch contains + observations from only one category. + + The sampler assumes data is sorted by category, with boundaries provided as slices. + For convenience, use :meth:`from_pandas` to construct from a pandas Categorical. + + Parameters + ---------- + category_boundaries + A sequence of slices defining the boundaries for each category. + Each slice represents a contiguous range of observations belonging to one category. + Data must be sorted by category before using this sampler. + Number of categories must be greater than 1 and all boundaries must be in increasing order. + batch_size + Number of observations per batch. + chunk_size + Size of each chunk i.e. the range of each chunk yielded. + shuffle + Whether to shuffle chunk and index order within each category. + preload_nchunks + Number of chunks to load per iteration. + rng + Random number generator for shuffling. + + Notes + ----- + This sampler does not support multiple workers. Using it with a DataLoader + that has `num_workers > 0` will raise an error. + + Examples + -------- + Using boundaries directly: + + >>> boundaries = [slice(0, 100), slice(100, 250), slice(250, 400)] + >>> sampler = CategoricalSampler( + ... category_boundaries=boundaries, + ... batch_size=32, + ... chunk_size=64, + ... preload_nchunks=4, + ... ) + + Using from_pandas for convenience: + + >>> import pandas as pd + >>> categories = pd.Categorical(["A", "A", "B", "B", "B", "C"]) + >>> sampler = CategoricalSampler.from_pandas( + ... categories, + ... batch_size=32, + ... chunk_size=64, + ... preload_nchunks=4, + ... ) + """ + + _category_samplers: list[ChunkSampler] + _rng: np.random.Generator + + def __init__( + self, + category_boundaries: Sequence[slice], + chunk_size: int, + preload_nchunks: int, + batch_size: int, + *, + shuffle: bool = False, + rng: np.random.Generator | None = None, + ): + check_lt_1([len(category_boundaries)], ["Number of categories"]) + if batch_size < chunk_size: + raise ValueError( + f"batch_size ({batch_size}) cannot be less than chunk_size ({chunk_size}) because each batch must be from one category." + ) + self._validate_boundaries(category_boundaries) + self._rng = rng or np.random.default_rng() + + child_rngs = self._rng.spawn(len(category_boundaries)) + + # Create a ChunkSampler for each category, using its boundary as the mask + # Always use drop_last=True internally + # also compute the number of batches for each category + self._n_batches_per_category = [ + int((boundary.stop - boundary.start) // batch_size) for boundary in category_boundaries + ] + self._category_samplers = [ + ChunkSampler( + chunk_size=chunk_size, + preload_nchunks=preload_nchunks, + batch_size=batch_size, + mask=boundary, + shuffle=shuffle, + drop_last=True, + rng=child_rng, + ) + for boundary, child_rng in zip(category_boundaries, child_rngs, strict=True) + ] + + self._batch_size = batch_size + self._chunk_size = chunk_size + self._preload_nchunks = preload_nchunks + self._shuffle = shuffle + + def _validate_boundaries(self, category_boundaries: Sequence[slice]) -> None: + for i, boundary in enumerate(category_boundaries): + if not isinstance(boundary, slice): + raise TypeError(f"Expected slice for boundary {i}, got {type(boundary)}") + if boundary.step is not None and boundary.step != 1: + raise ValueError(f"Boundary {i} must have step=1 or None, got {boundary.step}") + if boundary.start is None or boundary.stop is None: + raise ValueError(f"Boundary {i} must have explicit start and stop") + if boundary.start >= boundary.stop: + raise ValueError(f"Boundary {i} must have start < stop, got {boundary}") + if i == 0 and boundary.start != 0: + raise ValueError(f"First boundary must start at 0, got {boundary.start}") + if i > 0 and boundary.start != category_boundaries[i - 1].stop: + raise ValueError( + f"Boundaries must be contiguous: boundary {i} starts at {boundary.start} " + f"but boundary {i - 1} ends at {category_boundaries[i - 1].stop}" + ) + + @staticmethod + def _boundaries_from_pandas(categorical: pd.Categorical | pd.Series) -> list[slice]: + """Compute category boundaries from a pandas Categorical or Series. + + Parameters + ---------- + categorical + A pandas Categorical or Series with categorical dtype. + Data must be sorted by category. + + Returns + ------- + list[slice] + Boundaries for each category as slices. + + Raises + ------ + ValueError + If the data is not sorted by category or is empty. + TypeError + If the input is not a Categorical or categorical Series. + """ + if isinstance(categorical, pd.Series): + if not isinstance(categorical.dtype, pd.CategoricalDtype): + raise TypeError(f"Expected categorical Series, got {categorical.dtype}") + categorical = categorical.cat + elif not isinstance(categorical, pd.Categorical): + raise TypeError(f"Expected pandas.Categorical or categorical Series, got {type(categorical)}") + + codes = categorical.codes + n_obs = len(codes) + + if n_obs == 0: + raise ValueError("Cannot create sampler from empty categorical") + + # Check if sorted by finding where codes decrease + if np.any(np.diff(codes) < 0): + raise ValueError( + "Data must be sorted by category. Use df.sort_values('category_column') before creating the sampler." + ) + + # Compute boundaries by finding where codes change + # We need to handle the case where some categories might be empty + change_points = np.where(np.diff(codes) != 0)[0] + 1 + starts = np.concatenate([[0], change_points]) + stops = np.concatenate([change_points, [n_obs]]) + + return [slice(int(start), int(stop)) for start, stop in zip(starts, stops, strict=True)] + + @classmethod + def from_pandas( + cls, + categorical: pd.Categorical | pd.Series, + chunk_size: int, + preload_nchunks: int, + batch_size: int, + *, + shuffle: bool = False, + rng: np.random.Generator | None = None, + ) -> CategoricalSampler: + """Create a CategoricalSampler from a pandas Categorical or Series. + + The data is assumed to be sorted by category. This method computes the + boundaries for each category based on where values change. + + Parameters + ---------- + categorical + A pandas Categorical or Series with categorical dtype. + Data must be sorted by category. + chunk_size + Size of each chunk. + preload_nchunks + Number of chunks to load per iteration. + batch_size + Number of observations per batch. + shuffle + Whether to shuffle chunk and index order within each category. + rng + Random number generator for shuffling. + + Returns + ------- + CategoricalSampler + A sampler configured with boundaries derived from the categorical. + + Raises + ------ + ValueError + If the data is not sorted by category. + TypeError + If the input is not a Categorical or categorical Series. + + Examples + -------- + >>> import pandas as pd + >>> # Data must be sorted by category + >>> obs_cat = pd.Categorical(["A", "A", "A", "B", "B", "C", "C", "C", "C"]) + >>> sampler = CategoricalSampler.from_pandas( + ... obs_cat, + ... batch_size=2, + ... chunk_size=4, + ... preload_nchunks=2, + ... ) + """ + boundaries = cls._boundaries_from_pandas(categorical) + + return cls( + category_boundaries=boundaries, + chunk_size=chunk_size, + preload_nchunks=preload_nchunks, + batch_size=batch_size, + shuffle=shuffle, + rng=rng, + ) + + @property + def batch_size(self) -> int: + return self._category_samplers[0].batch_size + + @property + def shuffle(self) -> bool: + return self._category_samplers[0].shuffle + + @property + def n_categories(self) -> int: + """The number of categories in this sampler.""" + return len(self._category_samplers) + + def validate(self, n_obs: int) -> None: + """Validate the sampler configuration against the loader's n_obs. + + Parameters + ---------- + n_obs + The total number of observations in the loader. + + Raises + ------ + ValueError + If the sampler configuration is invalid for the given n_obs. + """ + # Validate each category sampler + for sampler in self._category_samplers: + sampler.validate(n_obs) + + # Check for worker usage - CategoricalSampler doesn't support workers + if is_in_worker(): + raise ValueError("CategoricalSampler does not support multiple workers.") + + @staticmethod + def _iter_batches( + sampler: ChunkSampler, n_obs: int, chunks_per_batch: int + ) -> Iterator[tuple[list[slice], np.ndarray]]: + """Yield per batch given a sampler.""" + for load_request in sampler._sample(n_obs): + chunks = load_request["chunks"] + yield from batched(chunks, chunks_per_batch) + + def _sample(self, n_obs: int) -> Iterator[LoadRequest]: + batches_per_load = int((self._preload_nchunks * self._chunk_size) // self._batch_size) + chunks_per_batch = int(self._batch_size / self._chunk_size) + batch_generators = [self._iter_batches(sampler, n_obs, chunks_per_batch) for sampler in self._category_samplers] + # simulate the category order + category_order = np.concatenate([np.full(n, i) for i, n in enumerate(self._n_batches_per_category)]) + if self._shuffle: + self._rng.shuffle(category_order) + + # pre-allocate and reshape to batches_per_load x batch_size + # so that we can shuffle with numpy all at once + batch_indices = np.arange(batches_per_load * self._batch_size).reshape(batches_per_load, self._batch_size) + + for cat_idxs in batched(category_order, batches_per_load): + if self._shuffle: + batch_indices = self._rng.permuted(batch_indices, axis=1) + yield { + "chunks": [chunk for cat_idx in cat_idxs for chunk in next(batch_generators[cat_idx])], + "splits": list(batch_indices[: len(cat_idxs)]), + } diff --git a/src/annbatch/samplers/_chunk_sampler.py b/src/annbatch/samplers/_chunk_sampler.py index a785e0bb..86a68f20 100644 --- a/src/annbatch/samplers/_chunk_sampler.py +++ b/src/annbatch/samplers/_chunk_sampler.py @@ -9,6 +9,7 @@ import numpy as np from annbatch.abc import Sampler +from annbatch.samplers._utils import validate_batch_size from annbatch.utils import check_lt_1, split_given_size if TYPE_CHECKING: @@ -69,18 +70,8 @@ def __init__( raise ValueError("mask.start must be < mask.stop when mask.stop is specified") check_lt_1([chunk_size, preload_nchunks], ["Chunk size", "Preloaded chunks"]) - preload_size = chunk_size * preload_nchunks + validate_batch_size(batch_size, chunk_size, preload_nchunks) - if batch_size > preload_size: - raise ValueError( - "batch_size cannot exceed chunk_size * preload_nchunks. " - f"Got batch_size={batch_size}, but max is {preload_size}." - ) - if preload_size % batch_size != 0: - raise ValueError( - "chunk_size * preload_nchunks must be divisible by batch_size. " - f"Got {preload_size} % {batch_size} = {preload_size % batch_size}." - ) self._rng = rng or np.random.default_rng() self._batch_size, self._chunk_size, self._shuffle = batch_size, chunk_size, shuffle self._preload_nchunks, self._mask, self._drop_last = ( @@ -127,7 +118,7 @@ def _get_worker_handle(self) -> WorkerHandle | None: from annbatch.utils import WorkerHandle if get_worker_info() is not None: - worker_handle = WorkerHandle() + worker_handle = WorkerHandle(self._rng) # Worker mode validation - only check when there are multiple workers # With batch_size=1, every batch is exactly 1 item, so no partial batches exist if ( @@ -146,23 +137,24 @@ def _sample(self, n_obs: int) -> Iterator[LoadRequest]: # Create chunk indices for possible shuffling and worker sharding chunk_indices = np.arange(math.ceil((stop - start) / self._chunk_size)) if self._shuffle: - if worker_handle is None: - self._rng.shuffle(chunk_indices) - else: - worker_handle.shuffle(chunk_indices) + # Use sampler's RNG for chunk ordering - same across all workers + self._rng.shuffle(chunk_indices) chunks = self._compute_chunks(chunk_indices, start, stop) # Worker sharding: each worker gets a disjoint subset of chunks + if self._shuffle: + self._rng.shuffle(chunks) if worker_handle is not None: - chunks = worker_handle.get_part_for_worker(chunks) + chunks = np.array_split(chunks, worker_handle.num_workers)[worker_handle.worker_id] # Set up the iterator for chunks and the batch indices for splits in_memory_size = self._chunk_size * self._preload_nchunks chunks_per_request = split_given_size(chunks, self._preload_nchunks) batch_indices = np.arange(in_memory_size) split_batch_indices = split_given_size(batch_indices, self._batch_size) + batch_rng = worker_handle.rng if worker_handle is not None else self._rng for request_chunks in chunks_per_request[:-1]: if self._shuffle: # Avoid copies using in-place shuffling since `self._shuffle` should not change mid-training - self._rng.shuffle(batch_indices) + batch_rng.shuffle(batch_indices) split_batch_indices = split_given_size(batch_indices, self._batch_size) yield {"chunks": request_chunks, "splits": split_batch_indices} # On the last yield, drop the last uneven batch and create new batch_indices since the in-memory size of this last yield could be divisible by batch_size but smaller than preload_nslices * slice_size @@ -174,10 +166,10 @@ def _sample(self, n_obs: int) -> Iterator[LoadRequest]: if total_obs_in_last_batch < self._batch_size: return total_obs_in_last_batch -= total_obs_in_last_batch % self._batch_size - batch_indices = split_given_size( - (self._rng.permutation if self._shuffle else np.arange)(total_obs_in_last_batch), - self._batch_size, + indices = ( + batch_rng.permutation(total_obs_in_last_batch) if self._shuffle else np.arange(total_obs_in_last_batch) ) + batch_indices = split_given_size(indices, self._batch_size) yield {"chunks": final_chunks, "splits": batch_indices} def _compute_chunks(self, chunk_indices: np.ndarray, start: int, stop: int) -> list[slice]: diff --git a/src/annbatch/samplers/_utils.py b/src/annbatch/samplers/_utils.py new file mode 100644 index 00000000..6ce6ae1a --- /dev/null +++ b/src/annbatch/samplers/_utils.py @@ -0,0 +1,78 @@ +"""Utility functions for samplers.""" + +from __future__ import annotations + +from importlib.util import find_spec +from typing import TYPE_CHECKING + +import numpy as np + +if TYPE_CHECKING: + from annbatch.utils import WorkerHandle + + +def is_in_worker() -> bool: + """Check if currently running inside a torch DataLoader worker. + + Returns + ------- + bool + True if inside a DataLoader worker, False otherwise. + """ + if find_spec("torch"): + from torch.utils.data import get_worker_info + + return get_worker_info() is not None + return False + + +def get_worker_handle(rng: np.random.Generator) -> WorkerHandle | None: + """Get a WorkerHandle if running inside a torch DataLoader worker. + + Parameters + ---------- + rng + The RNG to spawn worker-specific RNGs from. + + Returns + ------- + WorkerHandle | None + A WorkerHandle if inside a DataLoader worker, None otherwise. + """ + if is_in_worker(): + from annbatch.utils import WorkerHandle + + return WorkerHandle(rng) + return None + + +def validate_batch_size(batch_size: int, chunk_size: int, preload_nchunks: int) -> None: + """Validate batch_size against chunk_size and preload_nchunks constraints. + + Parameters + ---------- + batch_size + Number of observations per batch. + chunk_size + Size of each chunk. + preload_nchunks + Number of chunks to load per iteration. + + Raises + ------ + ValueError + If batch_size exceeds the total loaded size (chunk_size * preload_nchunks) + or if the total loaded size is not divisible by batch_size. + """ + preload_size = chunk_size * preload_nchunks + + if batch_size > preload_size: + raise ValueError( + "batch_size cannot exceed chunk_size * preload_nchunks. " + f"Got batch_size={batch_size}, but max is {preload_size}." + ) + if preload_size % batch_size != 0: + raise ValueError( + "chunk_size * preload_nchunks must be divisible by batch_size. " + f"Got {preload_size} % {batch_size} = {preload_size % batch_size}." + ) diff --git a/src/annbatch/utils.py b/src/annbatch/utils.py index 45c7ea2e..eaadb844 100644 --- a/src/annbatch/utils.py +++ b/src/annbatch/utils.py @@ -4,8 +4,7 @@ import itertools import warnings from dataclasses import dataclass -from functools import cached_property, wraps -from importlib.util import find_spec +from functools import wraps from typing import TYPE_CHECKING, Concatenate, Protocol import anndata as ad @@ -102,59 +101,56 @@ def __iter__(self): total += gap -class WorkerHandle: # noqa: D101 - @cached_property - def _worker_info(self): - if find_spec("torch"): - from torch.utils.data import get_worker_info +def _spawn_worker_rng( + rng: np.random.Generator | None, + worker_id: int, +) -> np.random.Generator: + """Create a worker-specific RNG using the sequence-of-integers seeding pattern. - return get_worker_info() - return None + Uses NumPy's recommended approach for multi-process RNG. See: + https://numpy.org/doc/stable/reference/random/parallel.html#sequence-of-integer-seeds + """ + if rng is not None: + root_seed = rng.integers(np.iinfo(np.int64).max) + return np.random.default_rng([worker_id, root_seed]) + else: + return np.random.default_rng() + + +class WorkerHandle: + """Handle for torch DataLoader worker context. + + This class should only be instantiated inside a torch DataLoader worker process + (i.e., when `torch.utils.data.get_worker_info()` returns a non-None value). + It provides worker-specific RNGs. + + Parameters + ---------- + rng + The RNG to derive worker-specific RNGs from. If None, uses a fresh unseeded RNG. + + """ + + def __init__(self, rng: np.random.Generator | None = None): + from torch.utils.data import get_worker_info + + self._worker_info = get_worker_info() + self._rng = _spawn_worker_rng(rng, self._worker_info.id) + + @property + def rng(self) -> np.random.Generator: + """Return the RNG for the current worker.""" + return self._rng @property def num_workers(self) -> int: - """Return the number of workers, or 1 if not in a worker context.""" - if self._worker_info is None: - return 1 + """Return the number of workers.""" return self._worker_info.num_workers - @cached_property - def _rng(self): - if self._worker_info is None: - return np.random.default_rng() - else: - # This is used for the _get_chunks function - # Use the same seed for all workers that the resulting splits are the same across workers - # torch default seed is `base_seed + worker_id`. Hence, subtract worker_id to get the base seed - return np.random.default_rng(self._worker_info.seed - self._worker_info.id) - - def shuffle(self, obj: np.typing.ArrayLike) -> None: - """Perform in-place shuffle. - - Parameters - ---------- - obj - The object to be shuffled - """ - self._rng.shuffle(obj) - - def get_part_for_worker(self, obj: np.ndarray) -> np.ndarray: - """Get a chunk of an incoming array accordnig to the current worker id. - - Parameters - ---------- - obj - Incoming array - - Returns - ------- - A evenly split part of the ray corresponding to how many workers there are. - """ - if self._worker_info is None: - return obj - num_workers, worker_id = self._worker_info.num_workers, self._worker_info.id - chunks_split = np.array_split(obj, num_workers) - return chunks_split[worker_id] + @property + def worker_id(self) -> int: + """Worker ID.""" + return self._worker_info.id def check_lt_1(vals: list[int], obs: list[str]) -> None: diff --git a/tests/test_categorical_sampler.py b/tests/test_categorical_sampler.py new file mode 100644 index 00000000..6746e924 --- /dev/null +++ b/tests/test_categorical_sampler.py @@ -0,0 +1,440 @@ +"""Tests for CategoricalSampler.""" + +from __future__ import annotations + +import numpy as np +import pandas as pd +import pytest + +from annbatch import CategoricalSampler + + +def collect_all_indices(sampler, n_obs): + """Helper to collect all indices from sampler, organized by load request.""" + all_indices = [] + for load_request in sampler.sample(n_obs): + indices_in_request = [] + for chunk in load_request["chunks"]: + indices_in_request.extend(range(chunk.start, chunk.stop)) + all_indices.append(indices_in_request) + return all_indices + + +def collect_flat_indices(sampler, n_obs): + """Helper to collect all indices flattened from splits (the actual batch indices).""" + indices = [] + for load_request in sampler.sample(n_obs): + # Build chunk indices mapping (indices into concatenated chunk data -> original indices) + chunk_indices = [] + for chunk in load_request["chunks"]: + chunk_indices.extend(range(chunk.start, chunk.stop)) + # Collect actual batch indices from splits + for split in load_request["splits"]: + for idx in split: + indices.append(chunk_indices[idx]) + return indices + + +# ============================================================================= +# Basic construction tests +# ============================================================================= + + +def test_basic_construction(): + """Test basic CategoricalSampler construction.""" + boundaries = [slice(0, 100), slice(100, 200), slice(200, 300)] + sampler = CategoricalSampler( + category_boundaries=boundaries, + batch_size=20, + chunk_size=20, + preload_nchunks=2, + ) + assert sampler.batch_size == 20 + assert sampler.n_categories == 3 + assert sampler.shuffle is False + + +def test_from_pandas_categorical(): + """Test construction from pandas Categorical.""" + # Create sorted categorical data + categories = pd.Categorical(["A"] * 50 + ["B"] * 30 + ["C"] * 20) + sampler = CategoricalSampler.from_pandas( + categories, + batch_size=10, + chunk_size=10, + preload_nchunks=2, + ) + assert sampler.n_categories == 3 + + +def test_from_pandas_series(): + """Test construction from pandas Series with categorical dtype.""" + series = pd.Series(pd.Categorical(["X"] * 40 + ["Y"] * 60)) + sampler = CategoricalSampler.from_pandas( + series, + batch_size=10, + chunk_size=10, + preload_nchunks=2, + ) + assert sampler.n_categories == 2 + + +def test_from_pandas_unsorted_raises(): + """Test that unsorted data raises ValueError.""" + categories = pd.Categorical(["A", "B", "A", "B"]) # Not sorted + with pytest.raises(ValueError, match="Data must be sorted"): + CategoricalSampler.from_pandas( + categories, + batch_size=2, + chunk_size=2, + preload_nchunks=1, + ) + + +def test_from_pandas_non_categorical_raises(): + """Test that non-categorical Series raises TypeError.""" + series = pd.Series(["A", "B", "C"]) # Not categorical + with pytest.raises(TypeError, match="Expected categorical"): + CategoricalSampler.from_pandas( + series, + batch_size=2, + chunk_size=2, + preload_nchunks=1, + ) + + +def test_from_pandas_empty_raises(): + """Test that empty categorical raises ValueError.""" + categories = pd.Categorical([]) + with pytest.raises(ValueError, match="empty"): + CategoricalSampler.from_pandas( + categories, + batch_size=2, + chunk_size=2, + preload_nchunks=1, + ) + + +# ============================================================================= +# Boundary validation tests +# ============================================================================= + + +@pytest.mark.parametrize( + "boundaries,error_match", + [ + pytest.param([slice(0, 10), slice(10, 5)], "start < stop", id="start_gte_stop"), + pytest.param([slice(0, 10, 2)], "step=1", id="step_not_one"), + pytest.param([slice(None, 10)], "explicit start and stop", id="none_start"), + pytest.param([slice(0, None)], "explicit start and stop", id="none_stop"), + pytest.param(["not a slice"], "Expected slice", id="not_a_slice"), + pytest.param([slice(5, 15)], "must start at 0", id="not_starting_at_zero"), + pytest.param([slice(0, 10), slice(15, 25)], "contiguous", id="gap_between_boundaries"), + ], +) +def test_invalid_boundary_raises(boundaries, error_match): + """Test that invalid boundaries raise appropriate errors.""" + with pytest.raises((ValueError, TypeError), match=error_match): + CategoricalSampler( + category_boundaries=boundaries, + batch_size=10, + chunk_size=10, + preload_nchunks=1, + ) + + +def test_empty_boundaries_raises(): + """Test that empty boundaries list raises ValueError.""" + with pytest.raises(ValueError): + CategoricalSampler( + category_boundaries=[], + batch_size=10, + chunk_size=10, + preload_nchunks=1, + ) + + +# ============================================================================= +# Coverage tests +# ============================================================================= + + +@pytest.mark.parametrize( + "category_sizes,chunk_size,preload_nchunks,batch_size", + [ + pytest.param([100, 100, 100], 10, 2, 10, id="equal_categories"), + pytest.param([50, 150, 100], 10, 2, 10, id="unequal_categories"), + pytest.param([30, 30, 30], 5, 3, 5, id="small_categories"), + pytest.param([200], 25, 2, 25, id="single_category"), + pytest.param([10, 20, 30, 40], 5, 1, 5, id="many_categories"), + ], +) +def test_coverage_all_indices(category_sizes, chunk_size, preload_nchunks, batch_size): + """Test that sampler covers all indices exactly once.""" + # Build boundaries from category sizes + boundaries = [] + start = 0 + for size in category_sizes: + boundaries.append(slice(start, start + size)) + start += size + n_obs = sum(category_sizes) + + sampler = CategoricalSampler( + category_boundaries=boundaries, + batch_size=batch_size, + chunk_size=chunk_size, + preload_nchunks=preload_nchunks, + shuffle=False, + ) + + all_indices = collect_flat_indices(sampler, n_obs) + assert set(all_indices) == set(range(n_obs)), "Should cover all indices" + assert len(all_indices) == n_obs, "Should cover each index exactly once" + + +def _get_category_for_index(index: int, boundaries: list[slice]) -> int: + """Helper to find which category an index belongs to.""" + for i, boundary in enumerate(boundaries): + if boundary.start <= index < boundary.stop: + return i + raise ValueError(f"Index {index} not in any category boundary") + + +@pytest.mark.parametrize("shuffle", [False, True]) +@pytest.mark.parametrize( + "boundaries", + [ + pytest.param([slice(0, 100), slice(100, 200), slice(200, 300)], id="equal_categories"), + pytest.param([slice(0, 50), slice(50, 150), slice(150, 300)], id="unequal_categories"), + pytest.param([slice(0, 30), slice(30, 60), slice(60, 90), slice(90, 120)], id="many_categories"), + ], +) +def test_each_split_from_single_category(boundaries, shuffle): + """Test that each split (batch) within a load request is from a single category. + + Note: The CategoricalSampler combines batches from multiple categories into + a single load request for efficiency, but each split within that request + should only contain indices from a single category. + """ + n_obs = boundaries[-1].stop + + sampler = CategoricalSampler( + category_boundaries=boundaries, + batch_size=10, + chunk_size=10, + preload_nchunks=2, + shuffle=shuffle, + rng=np.random.default_rng(42), + ) + + for load_request in sampler.sample(n_obs): + chunks = load_request["chunks"] + if len(chunks) == 0: + continue + + # Build mapping from concatenated chunk index to original index + chunk_indices = [] + for chunk in chunks: + chunk_indices.extend(range(chunk.start, chunk.stop)) + + # Verify each split contains indices from only one category + for split in load_request["splits"]: + if len(split) == 0: + continue + + # Get the category of the first index in this split + first_original_idx = chunk_indices[split[0]] + expected_category = _get_category_for_index(first_original_idx, boundaries) + + # Verify all indices in this split belong to the same category + for idx in split: + original_idx = chunk_indices[idx] + split_category = _get_category_for_index(original_idx, boundaries) + assert split_category == expected_category, ( + f"Split index {idx} (original {original_idx}) belongs to category {split_category}, " + f"but expected category {expected_category}" + ) + + +# ============================================================================= +# Shuffle tests +# ============================================================================= + + +def test_shuffle_changes_order(): + """Test that shuffling changes the order of indices within categories.""" + boundaries = [slice(0, 100), slice(100, 200)] + n_obs = 200 + + sampler_no_shuffle = CategoricalSampler( + category_boundaries=boundaries, + batch_size=10, + chunk_size=10, + preload_nchunks=2, + shuffle=False, + ) + + sampler_shuffle = CategoricalSampler( + category_boundaries=boundaries, + batch_size=10, + chunk_size=10, + preload_nchunks=2, + shuffle=True, + rng=np.random.default_rng(42), + ) + + indices_no_shuffle = collect_flat_indices(sampler_no_shuffle, n_obs) + indices_shuffle = collect_flat_indices(sampler_shuffle, n_obs) + + # Both should cover same indices + assert set(indices_no_shuffle) == set(indices_shuffle) + + # But order should differ + assert indices_no_shuffle != indices_shuffle + + +# ============================================================================= +# Validation tests +# ============================================================================= + + +def test_validate_boundary_exceeds_n_obs(): + """Test validation fails when boundary exceeds n_obs.""" + boundaries = [slice(0, 100), slice(100, 300)] # Second boundary goes to 300 + sampler = CategoricalSampler( + category_boundaries=boundaries, + batch_size=10, + chunk_size=10, + preload_nchunks=2, + ) + + with pytest.raises(ValueError, match="exceeds loader n_obs"): + sampler.validate(n_obs=200) # n_obs is only 200 + + +def test_validate_passes_for_valid_config(): + """Test validation passes for valid configuration.""" + boundaries = [slice(0, 100), slice(100, 200)] + sampler = CategoricalSampler( + category_boundaries=boundaries, + batch_size=10, + chunk_size=10, + preload_nchunks=2, + ) + # Should not raise + sampler.validate(n_obs=200) + + +# ============================================================================= +# Batch size validation tests +# ============================================================================= + + +def test_invalid_batch_size_raises(): + """Test that batch_size < chunk_size raises ValueError.""" + with pytest.raises(ValueError, match="cannot be less than chunk_size"): + CategoricalSampler( + category_boundaries=[slice(0, 100)], + batch_size=5, + chunk_size=10, + preload_nchunks=2, + ) + + +# ============================================================================= +# Drop last tests +# ============================================================================= + + +def test_drop_last_enforced(): + """Test that incomplete batches are always dropped (drop_last is enforced).""" + # 45 obs, batch_size 10 -> should get 4 complete batches (40 obs) + boundaries = [slice(0, 45)] + sampler = CategoricalSampler( + category_boundaries=boundaries, + batch_size=10, + chunk_size=10, + preload_nchunks=2, + ) + + total_obs = 0 + for load_request in sampler.sample(45): + for split in load_request["splits"]: + total_obs += len(split) + + assert total_obs == 40, "should drop incomplete batch" + + +# ============================================================================= +# Splits structure tests +# ============================================================================= + + +def test_splits_have_correct_batch_size(): + """Test that splits have correct batch sizes (all complete batches).""" + boundaries = [slice(0, 100)] + sampler = CategoricalSampler( + category_boundaries=boundaries, + batch_size=10, + chunk_size=10, + preload_nchunks=2, + ) + + for load_request in sampler.sample(100): + splits = load_request["splits"] + # All splits should have exactly batch_size elements (drop_last is enforced) + for split in splits: + assert len(split) == 10 + + +# ============================================================================= +# Integration with from_pandas +# ============================================================================= + + +def test_from_pandas_integration(): + """Test full integration with from_pandas and sampling.""" + # Simulate sorted obs column + n_obs = 150 + categories = pd.Categorical(["celltype_A"] * 50 + ["celltype_B"] * 70 + ["celltype_C"] * 30) + + sampler = CategoricalSampler.from_pandas( + categories, + batch_size=10, + chunk_size=10, + preload_nchunks=2, + shuffle=True, + rng=np.random.default_rng(123), + ) + + all_indices = collect_flat_indices(sampler, n_obs) + assert set(all_indices) == set(range(n_obs)) + assert len(all_indices) == n_obs + + +# ============================================================================= +# Reproducibility tests +# ============================================================================= + + +def test_rng_reproducibility(): + """Test that same RNG seed gives same results.""" + boundaries = [slice(0, 100), slice(100, 200)] + + def get_indices(seed): + sampler = CategoricalSampler( + category_boundaries=boundaries, + batch_size=10, + chunk_size=10, + preload_nchunks=2, + shuffle=True, + rng=np.random.default_rng(seed), + ) + return collect_flat_indices(sampler, 200) + + indices1 = get_indices(42) + indices2 = get_indices(42) + indices3 = get_indices(99) + + assert indices1 == indices2, "Same seed should give same results" + assert indices1 != indices3, "Different seeds should give different results" diff --git a/tests/test_sampler.py b/tests/test_sampler.py index 05299cc0..e60193f3 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -7,6 +7,7 @@ from annbatch import ChunkSampler from annbatch.abc import Sampler +from annbatch.utils import _spawn_worker_rng # TODO(selmanozleyen): Check for the validation within the _get_worker_handle method. Mock worker handle wouldn't make sense # but overall one must also think about how validation can't be independent of the worker handle. @@ -26,28 +27,30 @@ def collect_indices(sampler, n_obs): class MockWorkerHandle: """Simulates torch worker context for testing without actual DataLoader.""" - def __init__(self, worker_id: int, num_workers: int, seed: int = 42): + def __init__(self, worker_id: int, num_workers: int, rng: np.random.Generator | None = None): self.worker_id = worker_id self._num_workers = num_workers - self._rng = np.random.default_rng(seed) + self._rng = _spawn_worker_rng(rng, worker_id) + + @property + def rng(self) -> np.random.Generator: + return self._rng @property def num_workers(self) -> int: return self._num_workers - def shuffle(self, obj): - self._rng.shuffle(obj) - - def get_part_for_worker(self, obj: np.ndarray) -> np.ndarray: - return np.array_split(obj, self._num_workers)[self.worker_id] - class ChunkSamplerWithMockWorkerHandle(ChunkSampler): - def set_worker_handle(self, worker_handle: MockWorkerHandle): - self.worker_handle = worker_handle + def set_mock_worker_info(self, worker_id: int, num_workers: int): + """Set mock worker info. The RNG will be derived from sampler's _rng.""" + self._mock_worker_id = worker_id + self._mock_num_workers = num_workers def _get_worker_handle(self) -> MockWorkerHandle | None: - return self.worker_handle + if hasattr(self, "_mock_worker_id"): + return MockWorkerHandle(self._mock_worker_id, self._mock_num_workers, self._rng) + return None # ============================================================================= @@ -175,7 +178,6 @@ def test_workers_cover_full_dataset_without_overlap( """Test workers cover full dataset without overlap. Also checks if there are empty splits in any of the load requests.""" all_worker_indices = [] for worker_id in range(num_workers): - worker_handle = MockWorkerHandle(worker_id, num_workers) sampler = ChunkSamplerWithMockWorkerHandle( mask=slice(0, None), batch_size=batch_size, @@ -183,7 +185,7 @@ def test_workers_cover_full_dataset_without_overlap( preload_nchunks=preload_nchunks, drop_last=drop_last, ) - sampler.set_worker_handle(worker_handle) + sampler.set_mock_worker_info(worker_id, num_workers) all_worker_indices.append(collect_indices(sampler, n_obs)) # All workers should have disjoint chunks @@ -195,6 +197,40 @@ def test_workers_cover_full_dataset_without_overlap( assert set().union(*all_worker_indices) == set(range(n_obs)) +def test_batch_shuffle_is_reproducible(): + """Test that batch shuffling is reproducible when using ChunkSampler directly.""" + n_obs, chunk_size, preload_nchunks, batch_size = 100, 10, 2, 5 + seed = 42 + + def collect_splits(sampler): + all_splits = [] + for load_request in sampler.sample(n_obs): + for split in load_request["splits"]: + all_splits.append(split.tolist()) + return all_splits + + # Run twice with same seed - should get identical batch ordering + sampler1 = ChunkSampler( + chunk_size=chunk_size, + preload_nchunks=preload_nchunks, + batch_size=batch_size, + shuffle=True, + rng=np.random.default_rng(seed), + ) + splits1 = collect_splits(sampler1) + + sampler2 = ChunkSampler( + chunk_size=chunk_size, + preload_nchunks=preload_nchunks, + batch_size=batch_size, + shuffle=True, + rng=np.random.default_rng(seed), + ) + splits2 = collect_splits(sampler2) + + assert splits1 == splits2, "Batch shuffling should be reproducible with same seed" + + # ============================================================================= # Validation tests # =============================================================================