From fcbeb21882ee09498047466fc015656fea9b405a Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Wed, 4 Feb 2026 15:13:23 +0100 Subject: [PATCH 01/16] wip --- src/annbatch/__init__.py | 2 + src/annbatch/samplers/__init__.py | 2 + src/annbatch/samplers/_categorical_sampler.py | 305 ++++++++++++ src/annbatch/samplers/_chunk_sampler.py | 14 +- src/annbatch/samplers/_utils.py | 34 ++ tests/test_categorical_sampler.py | 470 ++++++++++++++++++ 6 files changed, 815 insertions(+), 12 deletions(-) create mode 100644 src/annbatch/samplers/_categorical_sampler.py create mode 100644 src/annbatch/samplers/_utils.py create mode 100644 tests/test_categorical_sampler.py diff --git a/src/annbatch/__init__.py b/src/annbatch/__init__.py index 39180c0b..7fa9ed2e 100644 --- a/src/annbatch/__init__.py +++ b/src/annbatch/__init__.py @@ -5,6 +5,7 @@ from . import abc, types from .io import DatasetCollection, write_sharded from .loader import Loader +from .samplers._categorical_sampler import CategoricalSampler from .samplers._chunk_sampler import ChunkSampler __version__ = version("annbatch") @@ -14,6 +15,7 @@ "DatasetCollection", "types", "write_sharded", + "CategoricalSampler", "ChunkSampler", "abc", ] diff --git a/src/annbatch/samplers/__init__.py b/src/annbatch/samplers/__init__.py index 9f92bbf0..e963e650 100644 --- a/src/annbatch/samplers/__init__.py +++ b/src/annbatch/samplers/__init__.py @@ -1,5 +1,7 @@ +from ._categorical_sampler import CategoricalSampler from ._chunk_sampler import ChunkSampler __all__ = [ + "CategoricalSampler", "ChunkSampler", ] diff --git a/src/annbatch/samplers/_categorical_sampler.py b/src/annbatch/samplers/_categorical_sampler.py new file mode 100644 index 00000000..3d49333c --- /dev/null +++ b/src/annbatch/samplers/_categorical_sampler.py @@ -0,0 +1,305 @@ +"""Categorical sampler for group-stratified data access.""" + +from __future__ import annotations + +from importlib.util import find_spec +from typing import TYPE_CHECKING + +import numpy as np +import 
pandas as pd + +from annbatch.abc import Sampler +from annbatch.samplers._chunk_sampler import ChunkSampler +from annbatch.utils import check_lt_1 + +if TYPE_CHECKING: + from collections.abc import Iterator, Sequence + + from annbatch.types import LoadRequest + + +class CategoricalSampler(Sampler): + """Categorical sampler for group-stratified batched data access. + + This sampler ensures each batch contains observations from a single category/group. + It samples from categories proportionally to their size, yielding batches where + all observations belong to the same category. + + The sampler assumes data is sorted by category, with boundaries provided as slices. + For convenience, use :meth:`from_pandas` to construct from a pandas Categorical. + + Parameters + ---------- + category_boundaries + A sequence of slices defining the boundaries for each category. + Each slice represents a contiguous range of observations belonging to one category. + Data must be sorted by category before using this sampler. + At least one category is required; boundaries must start at 0 and be contiguous, in increasing order. + batch_size + Number of observations per batch. + chunk_size + Size of each chunk i.e. the range of each chunk yielded. + shuffle + Whether to shuffle chunk and index order within each category. + preload_nchunks + Number of chunks to load per iteration. + drop_last + Whether to drop the last incomplete batch per category. + rng + Random number generator for shuffling. + + Notes + ----- + This sampler does not support multiple workers. Using it with a DataLoader + that has `num_workers > 0` will raise an error. + + Examples + -------- + Using boundaries directly: + + >>> boundaries = [slice(0, 100), slice(100, 250), slice(250, 400)] + >>> sampler = CategoricalSampler( + ... category_boundaries=boundaries, + ... batch_size=32, + ... chunk_size=64, + ... preload_nchunks=4, + ...
) + + Using from_pandas for convenience: + + >>> import pandas as pd + >>> categories = pd.Categorical(["A", "A", "B", "B", "B", "C"]) + >>> sampler = CategoricalSampler.from_pandas( + ... categories, + ... batch_size=32, + ... chunk_size=64, + ... preload_nchunks=4, + ... ) + """ + + _category_samplers: list[ChunkSampler] + _rng: np.random.Generator + + def __init__( + self, + category_boundaries: Sequence[slice], + chunk_size: int, + preload_nchunks: int, + batch_size: int, + *, + shuffle: bool = False, + drop_last: bool = False, + rng: np.random.Generator | None = None, + ): + check_lt_1([len(category_boundaries)], ["Number of categories"]) + + for i, boundary in enumerate(category_boundaries): + if not isinstance(boundary, slice): + raise TypeError(f"Expected slice for boundary {i}, got {type(boundary)}") + if boundary.step is not None and boundary.step != 1: + raise ValueError(f"Boundary {i} must have step=1 or None, got {boundary.step}") + if boundary.start is None or boundary.stop is None: + raise ValueError(f"Boundary {i} must have explicit start and stop") + if boundary.start >= boundary.stop: + raise ValueError(f"Boundary {i} must have start < stop, got {boundary}") + if i == 0 and boundary.start != 0: + raise ValueError(f"First boundary must start at 0, got {boundary.start}") + if i > 0 and boundary.start != category_boundaries[i - 1].stop: + raise ValueError( + f"Boundaries must be contiguous: boundary {i} starts at {boundary.start} " + f"but boundary {i - 1} ends at {category_boundaries[i - 1].stop}" + ) + + self._rng = rng or np.random.default_rng() + + child_rngs = self._rng.spawn(len(category_boundaries)) + + # Create a ChunkSampler for each category, using its boundary as the mask + self._category_samplers = [ + ChunkSampler( + chunk_size=chunk_size, + preload_nchunks=preload_nchunks, + batch_size=batch_size, + mask=boundary, + shuffle=shuffle, + drop_last=drop_last, + rng=child_rng, + ) + for boundary, child_rng in zip(category_boundaries, 
child_rngs, strict=True) + ] + + @classmethod + def from_pandas( + cls, + categorical: pd.Categorical | pd.Series, + chunk_size: int, + preload_nchunks: int, + batch_size: int, + *, + shuffle: bool = False, + drop_last: bool = False, + rng: np.random.Generator | None = None, + ) -> CategoricalSampler: + """Create a CategoricalSampler from a pandas Categorical or Series. + + The data is assumed to be sorted by category. This method computes the + boundaries for each category based on where values change. + + Parameters + ---------- + categorical + A pandas Categorical or Series with categorical dtype. + Data must be sorted by category. + chunk_size + Size of each chunk. + preload_nchunks + Number of chunks to load per iteration. + batch_size + Number of observations per batch. + shuffle + Whether to shuffle chunk and index order within each category. + drop_last + Whether to drop the last incomplete batch per category. + rng + Random number generator for shuffling. + + Returns + ------- + CategoricalSampler + A sampler configured with boundaries derived from the categorical. + + Raises + ------ + ValueError + If the data is not sorted by category. + TypeError + If the input is not a Categorical or categorical Series. + + Examples + -------- + >>> import pandas as pd + >>> # Data must be sorted by category + >>> obs_cat = pd.Categorical(["A", "A", "A", "B", "B", "C", "C", "C", "C"]) + >>> sampler = CategoricalSampler.from_pandas( + ... obs_cat, + ... batch_size=2, + ... chunk_size=4, + ... preload_nchunks=2, + ... 
) + """ + if isinstance(categorical, pd.Series): + if not isinstance(categorical.dtype, pd.CategoricalDtype): + raise TypeError(f"Expected categorical Series, got {categorical.dtype}") + categorical = categorical.cat + elif not isinstance(categorical, pd.Categorical): + raise TypeError(f"Expected pandas.Categorical or categorical Series, got {type(categorical)}") + + codes = categorical.codes + n_obs = len(codes) + + if n_obs == 0: + raise ValueError("Cannot create sampler from empty categorical") + + # Check if sorted by finding where codes decrease + if np.any(np.diff(codes) < 0): + raise ValueError( + "Data must be sorted by category. Use df.sort_values('category_column') before creating the sampler." + ) + + # Compute boundaries by finding where codes change + # We need to handle the case where some categories might be empty + change_points = np.where(np.diff(codes) != 0)[0] + 1 + starts = np.concatenate([[0], change_points]) + stops = np.concatenate([change_points, [n_obs]]) + + boundaries = [slice(int(start), int(stop)) for start, stop in zip(starts, stops, strict=True)] + + return cls( + category_boundaries=boundaries, + chunk_size=chunk_size, + preload_nchunks=preload_nchunks, + batch_size=batch_size, + shuffle=shuffle, + drop_last=drop_last, + rng=rng, + ) + + @property + def batch_size(self) -> int: + return self._category_samplers[0].batch_size + + @property + def shuffle(self) -> bool: + return self._category_samplers[0].shuffle + + @property + def n_categories(self) -> int: + """The number of categories in this sampler.""" + return len(self._category_samplers) + + @property + def category_sizes(self) -> list[int]: + """The size (number of observations) for each category.""" + return [s._mask.stop - s._mask.start for s in self._category_samplers] + + def validate(self, n_obs: int) -> None: + """Validate the sampler configuration against the loader's n_obs. + + Parameters + ---------- + n_obs + The total number of observations in the loader. 
+ + Raises + ------ + ValueError + If the sampler configuration is invalid for the given n_obs. + """ + # Validate each category sampler + for sampler in self._category_samplers: + sampler.validate(n_obs) + + # Check for worker usage - CategoricalSampler doesn't support workers + if find_spec("torch"): + from torch.utils.data import get_worker_info + + if get_worker_info() is not None: + raise ValueError( + "CategoricalSampler does not support multiple workers. Use num_workers=0 in your DataLoader." + ) + + def _sample(self, n_obs: int) -> Iterator[LoadRequest]: + """Sample load requests, ensuring each batch is from a single category. + + The sampling strategy: + 1. Create iterators for each category's ChunkSampler + 2. Sample categories proportionally to their remaining size + 3. Yield load requests from sampled category until exhausted + """ + # Create iterators for each category sampler + category_iterators: list[Iterator[LoadRequest] | None] = [ + sampler._sample(n_obs) for sampler in self._category_samplers + ] + # Track remaining sizes (approximate by category size, decremented as we yield) + remaining_sizes = np.array(self.category_sizes, dtype=float) + + # Keep sampling until all categories are exhausted + while np.any(remaining_sizes > 0): + # Sample category with probability proportional to remaining size + probs = remaining_sizes / remaining_sizes.sum() + cat_idx = self._rng.choice(len(category_iterators), p=probs) + + iterator = category_iterators[cat_idx] + if iterator is None: + remaining_sizes[cat_idx] = 0 + continue + + try: + load_request = next(iterator) + # Update remaining size based on what we just yielded + yielded_obs = sum(chunk.stop - chunk.start for chunk in load_request["chunks"]) + remaining_sizes[cat_idx] = max(0, remaining_sizes[cat_idx] - yielded_obs) + yield load_request + except StopIteration: + category_iterators[cat_idx] = None + remaining_sizes[cat_idx] = 0 diff --git a/src/annbatch/samplers/_chunk_sampler.py 
b/src/annbatch/samplers/_chunk_sampler.py index 17caf410..79b88239 100644 --- a/src/annbatch/samplers/_chunk_sampler.py +++ b/src/annbatch/samplers/_chunk_sampler.py @@ -9,6 +9,7 @@ import numpy as np from annbatch.abc import Sampler +from annbatch.samplers._utils import validate_batch_size from annbatch.utils import check_lt_1, split_given_size if TYPE_CHECKING: @@ -69,18 +70,7 @@ def __init__( raise ValueError("mask.start must be < mask.stop when mask.stop is specified") check_lt_1([chunk_size, preload_nchunks], ["Chunk size", "Preloaded chunks"]) - preload_size = chunk_size * preload_nchunks - - if batch_size > preload_size: - raise ValueError( - "batch_size cannot exceed chunk_size * preload_nchunks. " - f"Got batch_size={batch_size}, but max is {preload_size}." - ) - if preload_size % batch_size != 0: - raise ValueError( - "chunk_size * preload_nchunks must be divisible by batch_size. " - f"Got {preload_size} % {batch_size} = {preload_size % batch_size}." - ) + validate_batch_size(batch_size, chunk_size, preload_nchunks) self._rng = rng or np.random.default_rng() self._batch_size, self._chunk_size, self._shuffle = batch_size, chunk_size, shuffle self._preload_nchunks, self._mask, self._drop_last = ( diff --git a/src/annbatch/samplers/_utils.py b/src/annbatch/samplers/_utils.py new file mode 100644 index 00000000..7d9f1acd --- /dev/null +++ b/src/annbatch/samplers/_utils.py @@ -0,0 +1,34 @@ +"""Utility functions for samplers.""" + +from __future__ import annotations + + +def validate_batch_size(batch_size: int, chunk_size: int, preload_nchunks: int) -> None: + """Validate batch_size against chunk_size and preload_nchunks constraints. + + Parameters + ---------- + batch_size + Number of observations per batch. + chunk_size + Size of each chunk. + preload_nchunks + Number of chunks to preload. + + Raises + ------ + ValueError + If batch_size exceeds preload_size or preload_size is not divisible by batch_size. 
+ """ + preload_size = chunk_size * preload_nchunks + + if batch_size > preload_size: + raise ValueError( + "batch_size cannot exceed chunk_size * preload_nchunks. " + f"Got batch_size={batch_size}, but max is {preload_size}." + ) + if preload_size % batch_size != 0: + raise ValueError( + "chunk_size * preload_nchunks must be divisible by batch_size. " + f"Got {preload_size} % {batch_size} = {preload_size % batch_size}." + ) diff --git a/tests/test_categorical_sampler.py b/tests/test_categorical_sampler.py new file mode 100644 index 00000000..45d22ba2 --- /dev/null +++ b/tests/test_categorical_sampler.py @@ -0,0 +1,470 @@ +"""Tests for CategoricalSampler.""" + +from __future__ import annotations + +import numpy as np +import pandas as pd +import pytest + +from annbatch import CategoricalSampler + + +def collect_all_indices(sampler, n_obs): + """Helper to collect all indices from sampler, organized by load request.""" + all_indices = [] + for load_request in sampler.sample(n_obs): + indices_in_request = [] + for chunk in load_request["chunks"]: + indices_in_request.extend(range(chunk.start, chunk.stop)) + all_indices.append(indices_in_request) + return all_indices + + +def collect_flat_indices(sampler, n_obs): + """Helper to collect all indices flattened.""" + indices = [] + for load_request in sampler.sample(n_obs): + for chunk in load_request["chunks"]: + indices.extend(range(chunk.start, chunk.stop)) + return indices + + +# ============================================================================= +# Basic construction tests +# ============================================================================= + + +def test_basic_construction(): + """Test basic CategoricalSampler construction.""" + boundaries = [slice(0, 100), slice(100, 200), slice(200, 300)] + sampler = CategoricalSampler( + category_boundaries=boundaries, + batch_size=10, + chunk_size=20, + preload_nchunks=2, + ) + assert sampler.batch_size == 10 + assert sampler.n_categories == 3 + assert 
sampler.category_sizes == [100, 100, 100] + assert sampler.shuffle is False + + +def test_from_pandas_categorical(): + """Test construction from pandas Categorical.""" + # Create sorted categorical data + categories = pd.Categorical(["A"] * 50 + ["B"] * 30 + ["C"] * 20) + sampler = CategoricalSampler.from_pandas( + categories, + batch_size=10, + chunk_size=20, + preload_nchunks=2, + ) + assert sampler.n_categories == 3 + assert sampler.category_sizes == [50, 30, 20] + + +def test_from_pandas_series(): + """Test construction from pandas Series with categorical dtype.""" + series = pd.Series(pd.Categorical(["X"] * 40 + ["Y"] * 60)) + sampler = CategoricalSampler.from_pandas( + series, + batch_size=10, + chunk_size=20, + preload_nchunks=2, + ) + assert sampler.n_categories == 2 + assert sampler.category_sizes == [40, 60] + + +def test_from_pandas_unsorted_raises(): + """Test that unsorted data raises ValueError.""" + categories = pd.Categorical(["A", "B", "A", "B"]) # Not sorted + with pytest.raises(ValueError, match="Data must be sorted"): + CategoricalSampler.from_pandas( + categories, + batch_size=2, + chunk_size=4, + preload_nchunks=1, + ) + + +def test_from_pandas_non_categorical_raises(): + """Test that non-categorical Series raises TypeError.""" + series = pd.Series(["A", "B", "C"]) # Not categorical + with pytest.raises(TypeError, match="Expected categorical"): + CategoricalSampler.from_pandas( + series, + batch_size=2, + chunk_size=4, + preload_nchunks=1, + ) + + +def test_from_pandas_empty_raises(): + """Test that empty categorical raises ValueError.""" + categories = pd.Categorical([]) + with pytest.raises(ValueError, match="empty"): + CategoricalSampler.from_pandas( + categories, + batch_size=2, + chunk_size=4, + preload_nchunks=1, + ) + + +# ============================================================================= +# Boundary validation tests +# ============================================================================= + + 
+@pytest.mark.parametrize( + "boundaries,error_match", + [ + pytest.param([slice(0, 10), slice(10, 5)], "start < stop", id="start_gte_stop"), + pytest.param([slice(0, 10, 2)], "step=1", id="step_not_one"), + pytest.param([slice(None, 10)], "explicit start and stop", id="none_start"), + pytest.param([slice(0, None)], "explicit start and stop", id="none_stop"), + pytest.param(["not a slice"], "Expected slice", id="not_a_slice"), + pytest.param([slice(5, 15)], "must start at 0", id="not_starting_at_zero"), + pytest.param([slice(0, 10), slice(15, 25)], "contiguous", id="gap_between_boundaries"), + ], +) +def test_invalid_boundary_raises(boundaries, error_match): + """Test that invalid boundaries raise appropriate errors.""" + with pytest.raises((ValueError, TypeError), match=error_match): + CategoricalSampler( + category_boundaries=boundaries, + batch_size=5, + chunk_size=10, + preload_nchunks=1, + ) + + +def test_empty_boundaries_raises(): + """Test that empty boundaries list raises ValueError.""" + with pytest.raises(ValueError): + CategoricalSampler( + category_boundaries=[], + batch_size=5, + chunk_size=10, + preload_nchunks=1, + ) + + +# ============================================================================= +# Coverage tests +# ============================================================================= + + +@pytest.mark.parametrize( + "category_sizes,chunk_size,preload_nchunks,batch_size", + [ + pytest.param([100, 100, 100], 20, 2, 10, id="equal_categories"), + pytest.param([50, 150, 100], 25, 2, 10, id="unequal_categories"), + pytest.param([30, 30, 30], 10, 3, 5, id="small_categories"), + pytest.param([200], 50, 2, 25, id="single_category"), + pytest.param([10, 20, 30, 40], 10, 1, 5, id="many_categories"), + ], +) +def test_coverage_all_indices(category_sizes, chunk_size, preload_nchunks, batch_size): + """Test that sampler covers all indices exactly once.""" + # Build boundaries from category sizes + boundaries = [] + start = 0 + for size in 
category_sizes: + boundaries.append(slice(start, start + size)) + start += size + n_obs = sum(category_sizes) + + sampler = CategoricalSampler( + category_boundaries=boundaries, + batch_size=batch_size, + chunk_size=chunk_size, + preload_nchunks=preload_nchunks, + shuffle=False, + ) + + all_indices = collect_flat_indices(sampler, n_obs) + assert set(all_indices) == set(range(n_obs)), "Should cover all indices" + assert len(all_indices) == n_obs, "Should cover each index exactly once" + + +def _get_category_for_index(index: int, boundaries: list[slice]) -> int: + """Helper to find which category an index belongs to.""" + for i, boundary in enumerate(boundaries): + if boundary.start <= index < boundary.stop: + return i + raise ValueError(f"Index {index} not in any category boundary") + + +@pytest.mark.parametrize("shuffle", [False, True]) +@pytest.mark.parametrize( + "boundaries", + [ + pytest.param([slice(0, 100), slice(100, 200), slice(200, 300)], id="equal_categories"), + pytest.param([slice(0, 50), slice(50, 150), slice(150, 300)], id="unequal_categories"), + pytest.param([slice(0, 30), slice(30, 60), slice(60, 90), slice(90, 120)], id="many_categories"), + ], +) +def test_each_load_request_from_single_category(boundaries, shuffle): + """Test that each load request contains chunks from a single category.""" + n_obs = boundaries[-1].stop + + sampler = CategoricalSampler( + category_boundaries=boundaries, + batch_size=10, + chunk_size=20, + preload_nchunks=2, + shuffle=shuffle, + rng=np.random.default_rng(42), + ) + + for load_request in sampler.sample(n_obs): + # Check all chunks are from the same category + chunks = load_request["chunks"] + if len(chunks) == 0: + continue + + # Find which category this load request belongs to based on first chunk + first_chunk_start = chunks[0].start + category_idx = _get_category_for_index(first_chunk_start, boundaries) + category_boundary = boundaries[category_idx] + + # Verify ALL chunks are within the same category boundary + 
for chunk in chunks: + assert chunk.start >= category_boundary.start, ( + f"Chunk start {chunk.start} outside category {category_idx} boundary {category_boundary}" + ) + assert chunk.stop <= category_boundary.stop, ( + f"Chunk stop {chunk.stop} outside category {category_idx} boundary {category_boundary}" + ) + + # Also verify all indices in splits map back to the same category + # The splits are indices into the concatenated chunks, so we need to map them back + chunk_indices = [] + for chunk in chunks: + chunk_indices.extend(range(chunk.start, chunk.stop)) + + for split in load_request["splits"]: + for idx in split: + original_idx = chunk_indices[idx] + split_category = _get_category_for_index(original_idx, boundaries) + assert split_category == category_idx, ( + f"Split index {idx} (original {original_idx}) belongs to category {split_category}, " + f"but load request is for category {category_idx}" + ) + + +# ============================================================================= +# Shuffle tests +# ============================================================================= + + +def test_shuffle_changes_order(): + """Test that shuffling changes the order of indices within categories.""" + boundaries = [slice(0, 100), slice(100, 200)] + n_obs = 200 + + sampler_no_shuffle = CategoricalSampler( + category_boundaries=boundaries, + batch_size=10, + chunk_size=20, + preload_nchunks=2, + shuffle=False, + ) + + sampler_shuffle = CategoricalSampler( + category_boundaries=boundaries, + batch_size=10, + chunk_size=20, + preload_nchunks=2, + shuffle=True, + rng=np.random.default_rng(42), + ) + + indices_no_shuffle = collect_flat_indices(sampler_no_shuffle, n_obs) + indices_shuffle = collect_flat_indices(sampler_shuffle, n_obs) + + # Both should cover same indices + assert set(indices_no_shuffle) == set(indices_shuffle) + + # But order should differ + assert indices_no_shuffle != indices_shuffle + + +# 
============================================================================= +# Validation tests +# ============================================================================= + + +def test_validate_boundary_exceeds_n_obs(): + """Test validation fails when boundary exceeds n_obs.""" + boundaries = [slice(0, 100), slice(100, 300)] # Second boundary goes to 300 + sampler = CategoricalSampler( + category_boundaries=boundaries, + batch_size=10, + chunk_size=20, + preload_nchunks=2, + ) + + with pytest.raises(ValueError, match="exceeds loader n_obs"): + sampler.validate(n_obs=200) # n_obs is only 200 + + +def test_validate_passes_for_valid_config(): + """Test validation passes for valid configuration.""" + boundaries = [slice(0, 100), slice(100, 200)] + sampler = CategoricalSampler( + category_boundaries=boundaries, + batch_size=10, + chunk_size=20, + preload_nchunks=2, + ) + # Should not raise + sampler.validate(n_obs=200) + + +# ============================================================================= +# Batch size validation tests +# ============================================================================= + + +@pytest.mark.parametrize( + "batch_size,chunk_size,preload_nchunks,error_match", + [ + pytest.param(100, 10, 2, "batch_size cannot exceed", id="batch_exceeds_preload"), + pytest.param(7, 10, 2, "must be divisible by batch_size", id="not_divisible"), + ], +) +def test_invalid_batch_size_raises(batch_size, chunk_size, preload_nchunks, error_match): + """Test that invalid batch_size configurations raise ValueError.""" + with pytest.raises(ValueError, match=error_match): + CategoricalSampler( + category_boundaries=[slice(0, 100)], + batch_size=batch_size, + chunk_size=chunk_size, + preload_nchunks=preload_nchunks, + ) + + +# ============================================================================= +# Drop last tests +# ============================================================================= + + +def test_drop_last_removes_incomplete_batches(): 
+ """Test that drop_last removes incomplete batches.""" + # 45 obs, batch_size 10 -> should get 4 complete batches (40 obs) with drop_last=True + boundaries = [slice(0, 45)] + sampler = CategoricalSampler( + category_boundaries=boundaries, + batch_size=10, + chunk_size=20, + preload_nchunks=2, + drop_last=True, + ) + + total_obs = 0 + for load_request in sampler.sample(45): + for split in load_request["splits"]: + total_obs += len(split) + + assert total_obs == 40, "drop_last should remove incomplete batch" + + +def test_no_drop_last_keeps_incomplete_batches(): + """Test that drop_last=False keeps incomplete batches.""" + boundaries = [slice(0, 45)] + sampler = CategoricalSampler( + category_boundaries=boundaries, + batch_size=10, + chunk_size=20, + preload_nchunks=2, + drop_last=False, + ) + + total_obs = 0 + for load_request in sampler.sample(45): + for split in load_request["splits"]: + total_obs += len(split) + + assert total_obs == 45, "should keep all observations" + + +# ============================================================================= +# Splits structure tests +# ============================================================================= + + +def test_splits_have_correct_batch_size(): + """Test that splits have correct batch sizes.""" + boundaries = [slice(0, 100)] + sampler = CategoricalSampler( + category_boundaries=boundaries, + batch_size=10, + chunk_size=20, + preload_nchunks=2, + drop_last=False, + ) + + for load_request in sampler.sample(100): + splits = load_request["splits"] + # All splits except possibly the last should have batch_size elements + for split in splits[:-1]: + assert len(split) == 10 + # Last split should have 1 to batch_size elements + if splits: + assert 1 <= len(splits[-1]) <= 10 + + +# ============================================================================= +# Integration with from_pandas +# ============================================================================= + + +def test_from_pandas_integration(): + 
"""Test full integration with from_pandas and sampling.""" + # Simulate sorted obs column + n_obs = 150 + categories = pd.Categorical(["celltype_A"] * 50 + ["celltype_B"] * 70 + ["celltype_C"] * 30) + + sampler = CategoricalSampler.from_pandas( + categories, + batch_size=10, + chunk_size=25, + preload_nchunks=2, + shuffle=True, + rng=np.random.default_rng(123), + ) + + all_indices = collect_flat_indices(sampler, n_obs) + assert set(all_indices) == set(range(n_obs)) + assert len(all_indices) == n_obs + + +# ============================================================================= +# Reproducibility tests +# ============================================================================= + + +def test_rng_reproducibility(): + """Test that same RNG seed gives same results.""" + boundaries = [slice(0, 100), slice(100, 200)] + + def get_indices(seed): + sampler = CategoricalSampler( + category_boundaries=boundaries, + batch_size=10, + chunk_size=20, + preload_nchunks=2, + shuffle=True, + rng=np.random.default_rng(seed), + ) + return collect_flat_indices(sampler, 200) + + indices1 = get_indices(42) + indices2 = get_indices(42) + indices3 = get_indices(99) + + assert indices1 == indices2, "Same seed should give same results" + assert indices1 != indices3, "Different seeds should give different results" From 15fa94d8f1eb9265ff854abbe75f17a9f7b2a4ab Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Wed, 4 Feb 2026 16:39:12 +0100 Subject: [PATCH 02/16] wip --- src/annbatch/samplers/_categorical_sampler.py | 120 ++++++++++++------ tests/test_categorical_sampler.py | 88 +++++-------- tests/test_sampler.py | 38 +++--- 3 files changed, 137 insertions(+), 109 deletions(-) diff --git a/src/annbatch/samplers/_categorical_sampler.py b/src/annbatch/samplers/_categorical_sampler.py index 3d49333c..a075fb20 100644 --- a/src/annbatch/samplers/_categorical_sampler.py +++ b/src/annbatch/samplers/_categorical_sampler.py @@ -43,8 +43,6 @@ class CategoricalSampler(Sampler): Whether to 
shuffle chunk and index order within each category. preload_nchunks Number of chunks to load per iteration. - drop_last - Whether to drop the last incomplete batch per category. rng Random number generator for shuffling. @@ -88,7 +86,6 @@ def __init__( batch_size: int, *, shuffle: bool = False, - drop_last: bool = False, rng: np.random.Generator | None = None, ): check_lt_1([len(category_boundaries)], ["Number of categories"]) @@ -115,6 +112,7 @@ def __init__( child_rngs = self._rng.spawn(len(category_boundaries)) # Create a ChunkSampler for each category, using its boundary as the mask + # Always use drop_last=True internally self._category_samplers = [ ChunkSampler( chunk_size=chunk_size, @@ -122,7 +120,7 @@ def __init__( batch_size=batch_size, mask=boundary, shuffle=shuffle, - drop_last=drop_last, + drop_last=True, rng=child_rng, ) for boundary, child_rng in zip(category_boundaries, child_rngs, strict=True) @@ -137,7 +135,6 @@ def from_pandas( batch_size: int, *, shuffle: bool = False, - drop_last: bool = False, rng: np.random.Generator | None = None, ) -> CategoricalSampler: """Create a CategoricalSampler from a pandas Categorical or Series. @@ -158,8 +155,6 @@ def from_pandas( Number of observations per batch. shuffle Whether to shuffle chunk and index order within each category. - drop_last - Whether to drop the last incomplete batch per category. rng Random number generator for shuffling. @@ -220,7 +215,6 @@ def from_pandas( preload_nchunks=preload_nchunks, batch_size=batch_size, shuffle=shuffle, - drop_last=drop_last, rng=rng, ) @@ -272,34 +266,84 @@ def _sample(self, n_obs: int) -> Iterator[LoadRequest]: """Sample load requests, ensuring each batch is from a single category. The sampling strategy: - 1. Create iterators for each category's ChunkSampler - 2. Sample categories proportionally to their remaining size - 3. Yield load requests from sampled category until exhausted + 1. Collect all load requests from each category sampler + 2. 
Flatten into individual batches: (chunk_group_id, batch_index_within_group) + 3. Shuffle the batch order across categories + 4. Group batches by their chunk_group_id and yield combined load requests """ - # Create iterators for each category sampler - category_iterators: list[Iterator[LoadRequest] | None] = [ - sampler._sample(n_obs) for sampler in self._category_samplers - ] - # Track remaining sizes (approximate by category size, decremented as we yield) - remaining_sizes = np.array(self.category_sizes, dtype=float) - - # Keep sampling until all categories are exhausted - while np.any(remaining_sizes > 0): - # Sample category with probability proportional to remaining size - probs = remaining_sizes / remaining_sizes.sum() - cat_idx = self._rng.choice(len(category_iterators), p=probs) - - iterator = category_iterators[cat_idx] - if iterator is None: - remaining_sizes[cat_idx] = 0 - continue - - try: - load_request = next(iterator) - # Update remaining size based on what we just yielded - yielded_obs = sum(chunk.stop - chunk.start for chunk in load_request["chunks"]) - remaining_sizes[cat_idx] = max(0, remaining_sizes[cat_idx] - yielded_obs) - yield load_request - except StopIteration: - category_iterators[cat_idx] = None - remaining_sizes[cat_idx] = 0 + batch_size = self._category_samplers[0]._batch_size + + # Collect all chunk groups: list of (chunks, n_batches) + # chunk_group_id is the index into this list + all_chunk_groups: list[tuple[list[slice], int]] = [] + + for sampler in self._category_samplers: + for load_request in sampler._sample(n_obs): + chunks = list(load_request["chunks"]) + # Count only non-empty splits (drop_last may produce empty final split) + n_batches = sum(1 for s in load_request["splits"] if len(s) > 0) + if n_batches > 0: + all_chunk_groups.append((chunks, n_batches)) + + if not all_chunk_groups: + return + + # Flatten into individual batches: (chunk_group_id, batch_index_within_group) + all_batches: list[tuple[int, int]] = [] + + for 
group_id, (chunks, n_batches) in enumerate(all_chunk_groups): + for batch_idx in range(n_batches): + all_batches.append((group_id, batch_idx)) + + if not all_batches: + return + + # Shuffle the batch order + batch_order = np.arange(len(all_batches)) + self._rng.shuffle(batch_order) + + # Group batches that share the same chunk_group_id together + # Yield one load request per unique set of chunk groups + batches_per_load = len(self._category_samplers) + + for i in range(0, len(batch_order), batches_per_load): + selected_batch_indices = batch_order[i : i + batches_per_load] + + # Collect unique chunk groups needed for this load request + # Map: chunk_group_id -> (chunks, list of batch indices within that group) + groups_in_load: dict[int, list[int]] = {} + for batch_idx in selected_batch_indices: + group_id, batch_num = all_batches[batch_idx] + if group_id not in groups_in_load: + groups_in_load[group_id] = [] + groups_in_load[group_id].append(batch_num) + + # Build combined load request + combined_chunks: list[slice] = [] + combined_splits: list[np.ndarray] = [] + + # Track offset for each chunk group in the combined data + group_offsets: dict[int, int] = {} + current_offset = 0 + + # First pass: add chunks and compute offsets + for group_id in groups_in_load: + chunks, _ = all_chunk_groups[group_id] + group_offsets[group_id] = current_offset + combined_chunks.extend(chunks) + current_offset += sum(c.stop - c.start for c in chunks) + + # Second pass: create splits + for batch_idx in selected_batch_indices: + group_id, batch_num = all_batches[batch_idx] + offset = group_offsets[group_id] + + # Create split indices + start_idx = batch_num * batch_size + end_idx = start_idx + batch_size + split_indices = np.arange(start_idx, end_idx) + offset + if self._category_samplers[0]._shuffle: + self._rng.shuffle(split_indices) + combined_splits.append(split_indices) + + yield {"chunks": combined_chunks, "splits": combined_splits} diff --git a/tests/test_categorical_sampler.py 
b/tests/test_categorical_sampler.py index 45d22ba2..6e690cee 100644 --- a/tests/test_categorical_sampler.py +++ b/tests/test_categorical_sampler.py @@ -21,11 +21,17 @@ def collect_all_indices(sampler, n_obs): def collect_flat_indices(sampler, n_obs): - """Helper to collect all indices flattened.""" + """Helper to collect all indices flattened from splits (the actual batch indices).""" indices = [] for load_request in sampler.sample(n_obs): + # Build chunk indices mapping (indices into concatenated chunk data -> original indices) + chunk_indices = [] for chunk in load_request["chunks"]: - indices.extend(range(chunk.start, chunk.stop)) + chunk_indices.extend(range(chunk.start, chunk.stop)) + # Collect actual batch indices from splits + for split in load_request["splits"]: + for idx in split: + indices.append(chunk_indices[idx]) return indices @@ -206,8 +212,13 @@ def _get_category_for_index(index: int, boundaries: list[slice]) -> int: pytest.param([slice(0, 30), slice(30, 60), slice(60, 90), slice(90, 120)], id="many_categories"), ], ) -def test_each_load_request_from_single_category(boundaries, shuffle): - """Test that each load request contains chunks from a single category.""" +def test_each_split_from_single_category(boundaries, shuffle): + """Test that each split (batch) within a load request is from a single category. + + Note: The CategoricalSampler combines batches from multiple categories into + a single load request for efficiency, but each split within that request + should only contain indices from a single category. 
+ """ n_obs = boundaries[-1].stop sampler = CategoricalSampler( @@ -220,38 +231,31 @@ def test_each_load_request_from_single_category(boundaries, shuffle): ) for load_request in sampler.sample(n_obs): - # Check all chunks are from the same category chunks = load_request["chunks"] if len(chunks) == 0: continue - # Find which category this load request belongs to based on first chunk - first_chunk_start = chunks[0].start - category_idx = _get_category_for_index(first_chunk_start, boundaries) - category_boundary = boundaries[category_idx] - - # Verify ALL chunks are within the same category boundary - for chunk in chunks: - assert chunk.start >= category_boundary.start, ( - f"Chunk start {chunk.start} outside category {category_idx} boundary {category_boundary}" - ) - assert chunk.stop <= category_boundary.stop, ( - f"Chunk stop {chunk.stop} outside category {category_idx} boundary {category_boundary}" - ) - - # Also verify all indices in splits map back to the same category - # The splits are indices into the concatenated chunks, so we need to map them back + # Build mapping from concatenated chunk index to original index chunk_indices = [] for chunk in chunks: chunk_indices.extend(range(chunk.start, chunk.stop)) + # Verify each split contains indices from only one category for split in load_request["splits"]: + if len(split) == 0: + continue + + # Get the category of the first index in this split + first_original_idx = chunk_indices[split[0]] + expected_category = _get_category_for_index(first_original_idx, boundaries) + + # Verify all indices in this split belong to the same category for idx in split: original_idx = chunk_indices[idx] split_category = _get_category_for_index(original_idx, boundaries) - assert split_category == category_idx, ( + assert split_category == expected_category, ( f"Split index {idx} (original {original_idx}) belongs to category {split_category}, " - f"but load request is for category {category_idx}" + f"but expected category 
{expected_category}" ) @@ -352,35 +356,15 @@ def test_invalid_batch_size_raises(batch_size, chunk_size, preload_nchunks, erro # ============================================================================= -def test_drop_last_removes_incomplete_batches(): - """Test that drop_last removes incomplete batches.""" - # 45 obs, batch_size 10 -> should get 4 complete batches (40 obs) with drop_last=True - boundaries = [slice(0, 45)] - sampler = CategoricalSampler( - category_boundaries=boundaries, - batch_size=10, - chunk_size=20, - preload_nchunks=2, - drop_last=True, - ) - - total_obs = 0 - for load_request in sampler.sample(45): - for split in load_request["splits"]: - total_obs += len(split) - - assert total_obs == 40, "drop_last should remove incomplete batch" - - -def test_no_drop_last_keeps_incomplete_batches(): - """Test that drop_last=False keeps incomplete batches.""" +def test_drop_last_enforced(): + """Test that incomplete batches are always dropped (drop_last is enforced).""" + # 45 obs, batch_size 10 -> should get 4 complete batches (40 obs) boundaries = [slice(0, 45)] sampler = CategoricalSampler( category_boundaries=boundaries, batch_size=10, chunk_size=20, preload_nchunks=2, - drop_last=False, ) total_obs = 0 @@ -388,7 +372,7 @@ def test_no_drop_last_keeps_incomplete_batches(): for split in load_request["splits"]: total_obs += len(split) - assert total_obs == 45, "should keep all observations" + assert total_obs == 40, "should drop incomplete batch" # ============================================================================= @@ -397,24 +381,20 @@ def test_no_drop_last_keeps_incomplete_batches(): def test_splits_have_correct_batch_size(): - """Test that splits have correct batch sizes.""" + """Test that splits have correct batch sizes (all complete batches).""" boundaries = [slice(0, 100)] sampler = CategoricalSampler( category_boundaries=boundaries, batch_size=10, chunk_size=20, preload_nchunks=2, - drop_last=False, ) for load_request in 
sampler.sample(100): splits = load_request["splits"] - # All splits except possibly the last should have batch_size elements - for split in splits[:-1]: + # All splits should have exactly batch_size elements (drop_last is enforced) + for split in splits: assert len(split) == 10 - # Last split should have 1 to batch_size elements - if splits: - assert 1 <= len(splits[-1]) <= 10 # ============================================================================= diff --git a/tests/test_sampler.py b/tests/test_sampler.py index 2e43cfb9..d08ca20b 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -56,32 +56,35 @@ def _get_worker_handle(self) -> MockWorkerHandle | None: @pytest.mark.parametrize( - "n_obs,chunk_size,start,stop,batch_size,preload_nchunks,shuffle", + "n_obs,chunk_size,start,stop,batch_size,preload_nchunks,shuffle,drop_last", [ # Basic full dataset - pytest.param(100, 10, None, None, 5, 2, False, id="full_dataset"), + pytest.param(100, 10, None, None, 5, 2, False, False, id="full_dataset"), # mask.start only - pytest.param(100, 10, 30, None, 5, 2, False, id="start_at_chunk_boundary"), - pytest.param(100, 10, 35, None, 5, 2, False, id="start_not_at_chunk_boundary"), - pytest.param(120, 12, 90, None, 3, 1, False, id="start_near_end"), - pytest.param(100, 10, 20, None, 5, 2, False, id="start_mask_stop_none"), + pytest.param(100, 10, 30, None, 5, 2, False, False, id="start_at_chunk_boundary"), + pytest.param(100, 10, 35, None, 5, 2, False, False, id="start_not_at_chunk_boundary"), + pytest.param(120, 12, 90, None, 3, 1, False, False, id="start_near_end"), + pytest.param(100, 10, 20, None, 5, 2, False, False, id="start_mask_stop_none"), # mask.stop only - pytest.param(50, 10, None, 50, 5, 2, False, id="stop_at_chunk_boundary"), - pytest.param(47, 10, None, 47, 5, 2, False, id="stop_not_at_chunk_boundary"), + pytest.param(50, 10, None, 50, 5, 2, False, False, id="stop_at_chunk_boundary"), + pytest.param(47, 10, None, 47, 5, 2, False, False, 
id="stop_not_at_chunk_boundary"), # Both bounds - pytest.param(60, 10, 20, 60, 5, 2, False, id="both_at_chunk_boundaries"), - pytest.param(67, 10, 23, 67, 5, 2, False, id="both_not_at_chunk_boundaries"), - pytest.param(28, 10, 22, 28, 2, 1, False, id="single_chunk_span"), - pytest.param(100, 10, 15, 85, 5, 2, False, id="both_non_aligned"), - pytest.param(100, 10, 20, 80, 5, 2, False, id="both_aligned"), + pytest.param(60, 10, 20, 60, 5, 2, False, False, id="both_at_chunk_boundaries"), + pytest.param(67, 10, 23, 67, 5, 2, False, False, id="both_not_at_chunk_boundaries"), + pytest.param(28, 10, 22, 28, 2, 1, False, False, id="single_chunk_span"), + pytest.param(100, 10, 15, 85, 5, 2, False, False, id="both_non_aligned"), + pytest.param(100, 10, 20, 80, 5, 2, False, False, id="both_aligned"), # Edge cases - pytest.param(100, 10, 95, 100, 10, 1, False, id="very_small_mask"), + pytest.param(100, 10, 95, 100, 10, 1, False, False, id="very_small_mask"), # With shuffle - pytest.param(100, 10, 30, None, 5, 2, True, id="shuffle_with_start"), - pytest.param(75, 10, 25, 75, 5, 2, True, id="shuffle_with_both_bounds"), + pytest.param(100, 10, 30, None, 5, 2, True, False, id="shuffle_with_start"), + pytest.param(75, 10, 25, 75, 5, 2, True, False, id="shuffle_with_both_bounds"), + # drop_last edge cases: remainder less than batch_size + pytest.param(45, 20, None, None, 10, 2, False, True, id="drop_last_remainder_less_than_batch"), + pytest.param(5, 20, None, None, 10, 2, False, True, id="drop_last_total_less_than_batch"), ], ) -def test_mask_coverage(n_obs, chunk_size, start, stop, batch_size, preload_nchunks, shuffle): +def test_mask_coverage(n_obs, chunk_size, start, stop, batch_size, preload_nchunks, shuffle, drop_last): """Test sampler covers exactly the expected range, and ordering is correct when not shuffled.""" sampler = ChunkSampler( mask=slice(start, stop), @@ -89,6 +92,7 @@ def test_mask_coverage(n_obs, chunk_size, start, stop, batch_size, preload_nchun 
chunk_size=chunk_size, preload_nchunks=preload_nchunks, shuffle=shuffle, + drop_last=drop_last, rng=np.random.default_rng(42) if shuffle else None, ) From 9694e50d16b56813395dcb6eaa28a06066f3cd5b Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Wed, 4 Feb 2026 17:06:48 +0100 Subject: [PATCH 03/16] update the tests --- tests/test_sampler.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_sampler.py b/tests/test_sampler.py index d08ca20b..05299cc0 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -98,6 +98,10 @@ def test_mask_coverage(n_obs, chunk_size, start, stop, batch_size, preload_nchun expected_start = start if start is not None else 0 expected_stop = stop if stop is not None else n_obs + if drop_last: + # With drop_last, only complete batches are yielded + total_obs = expected_stop - expected_start + expected_stop = expected_start + (total_obs // batch_size) * batch_size expected_indices = list(range(expected_start, expected_stop)) all_indices = collect_indices(sampler, n_obs) From d02c51a2cf840b58729d7eb0e076fa52dfdd0beb Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Wed, 4 Feb 2026 17:08:14 +0100 Subject: [PATCH 04/16] fix --- src/annbatch/samplers/_chunk_sampler.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/annbatch/samplers/_chunk_sampler.py b/src/annbatch/samplers/_chunk_sampler.py index 79b88239..61221bf8 100644 --- a/src/annbatch/samplers/_chunk_sampler.py +++ b/src/annbatch/samplers/_chunk_sampler.py @@ -146,20 +146,23 @@ def _sample(self, n_obs: int) -> Iterator[LoadRequest]: chunks = worker_handle.get_part_for_worker(chunks) # Set up the iterator for chunks and the batch indices for splits in_memory_size = self._chunk_size * self._preload_nchunks - chunks_per_batch = split_given_size(chunks, self._preload_nchunks) + chunks_per_request = split_given_size(chunks, self._preload_nchunks) batch_indices = np.arange(in_memory_size) split_batch_indices = 
split_given_size(batch_indices, self._batch_size) - for batch_chunks in chunks_per_batch[:-1]: + for request_chunks in chunks_per_request[:-1]: if self._shuffle: # Avoid copies using in-place shuffling since `self._shuffle` should not change mid-training np.random.default_rng().shuffle(batch_indices) split_batch_indices = split_given_size(batch_indices, self._batch_size) - yield {"chunks": batch_chunks, "splits": split_batch_indices} + yield {"chunks": request_chunks, "splits": split_batch_indices} # On the last yield, drop the last uneven batch and create new batch_indices since the in-memory size of this last yield could be divisible by batch_size but smaller than preload_nslices * slice_size - final_chunks = chunks_per_batch[-1] + final_chunks = chunks_per_request[-1] total_obs_in_last_batch = int(sum(s.stop - s.start for s in final_chunks)) if self._drop_last: total_obs_in_last_batch -= total_obs_in_last_batch % self._batch_size + # Skip yielding if there are no observations (can happen with drop_last=True and last request is empty) + if total_obs_in_last_batch == 0: + return batch_indices = split_given_size( (np.random.default_rng().permutation if self._shuffle else np.arange)(total_obs_in_last_batch), self._batch_size, From cbeccd9a9a1ac93449484f4cd4dcbeecf58fd857 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Selman=20=C3=96zleyen?= <32667648+selmanozleyen@users.noreply.github.com> Date: Wed, 4 Feb 2026 19:22:19 +0100 Subject: [PATCH 05/16] Update src/annbatch/samplers/_chunk_sampler.py Co-authored-by: Ilan Gold --- src/annbatch/samplers/_chunk_sampler.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/annbatch/samplers/_chunk_sampler.py b/src/annbatch/samplers/_chunk_sampler.py index 61221bf8..18e55502 100644 --- a/src/annbatch/samplers/_chunk_sampler.py +++ b/src/annbatch/samplers/_chunk_sampler.py @@ -158,11 +158,13 @@ def _sample(self, n_obs: int) -> Iterator[LoadRequest]: # On the last yield, drop the last uneven batch and create 
new batch_indices since the in-memory size of this last yield could be divisible by batch_size but smaller than preload_nslices * slice_size final_chunks = chunks_per_request[-1] total_obs_in_last_batch = int(sum(s.stop - s.start for s in final_chunks)) + if total_obs_in_last_batch == 0: # pragma: no cover + raise RuntimeError("Last batch was found to have no observations. Please open an issue.") if self._drop_last: total_obs_in_last_batch -= total_obs_in_last_batch % self._batch_size # Skip yielding if there are no observations (can happen with drop_last=True and last request is empty) - if total_obs_in_last_batch == 0: - return + if total_obs_in_last_batch == 0: + return batch_indices = split_given_size( (np.random.default_rng().permutation if self._shuffle else np.arange)(total_obs_in_last_batch), self._batch_size, From ad72866bb140eb9f855bbb614fbf23cb9dba34fd Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Wed, 4 Feb 2026 20:09:41 +0100 Subject: [PATCH 06/16] rename to [tool.pytest.ini_options] --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a63cac86..79dafe73 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -150,7 +150,7 @@ lint.per-file-ignores."docs/*" = [ "I" ] lint.per-file-ignores."tests/*" = [ "D" ] lint.pydocstyle.convention = "numpy" -[tool.pytest.ini_option] +[tool.pytest.ini_options] markers = [ "gpu: mark test to run on GPU" ] testpaths = [ "tests" ] xfail_strict = true From 0548b7c0fd4dde133dbd643682cdb022da378d65 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Wed, 4 Feb 2026 21:17:19 +0100 Subject: [PATCH 07/16] wip poc --- src/annbatch/__init__.py | 3 +- src/annbatch/samplers/__init__.py | 3 +- src/annbatch/samplers/_categorical_sampler.py | 337 ++++++++++++++++-- src/annbatch/samplers/_chunk_sampler.py | 33 +- src/annbatch/utils.py | 61 +++- tests/test_categorical_sampler.py | 267 +++++++++++++- tests/test_dataset.py | 70 ++++ 7 files changed, 728 insertions(+), 
46 deletions(-) diff --git a/src/annbatch/__init__.py b/src/annbatch/__init__.py index 7fa9ed2e..26075709 100644 --- a/src/annbatch/__init__.py +++ b/src/annbatch/__init__.py @@ -5,7 +5,7 @@ from . import abc, types from .io import DatasetCollection, write_sharded from .loader import Loader -from .samplers._categorical_sampler import CategoricalSampler +from .samplers._categorical_sampler import CategoricalSampler, StratifiedCategoricalSampler from .samplers._chunk_sampler import ChunkSampler __version__ = version("annbatch") @@ -17,5 +17,6 @@ "write_sharded", "CategoricalSampler", "ChunkSampler", + "StratifiedCategoricalSampler", "abc", ] diff --git a/src/annbatch/samplers/__init__.py b/src/annbatch/samplers/__init__.py index e963e650..d73829bc 100644 --- a/src/annbatch/samplers/__init__.py +++ b/src/annbatch/samplers/__init__.py @@ -1,7 +1,8 @@ -from ._categorical_sampler import CategoricalSampler +from ._categorical_sampler import CategoricalSampler, StratifiedCategoricalSampler from ._chunk_sampler import ChunkSampler __all__ = [ "CategoricalSampler", "ChunkSampler", + "StratifiedCategoricalSampler", ] diff --git a/src/annbatch/samplers/_categorical_sampler.py b/src/annbatch/samplers/_categorical_sampler.py index a075fb20..212a0ef9 100644 --- a/src/annbatch/samplers/_categorical_sampler.py +++ b/src/annbatch/samplers/_categorical_sampler.py @@ -10,7 +10,7 @@ from annbatch.abc import Sampler from annbatch.samplers._chunk_sampler import ChunkSampler -from annbatch.utils import check_lt_1 +from annbatch.utils import WorkerHandle, check_lt_1 if TYPE_CHECKING: from collections.abc import Iterator, Sequence @@ -126,6 +126,55 @@ def __init__( for boundary, child_rng in zip(category_boundaries, child_rngs, strict=True) ] + @staticmethod + def _boundaries_from_pandas(categorical: pd.Categorical | pd.Series) -> list[slice]: + """Compute category boundaries from a pandas Categorical or Series. 
+ + Parameters + ---------- + categorical + A pandas Categorical or Series with categorical dtype. + Data must be sorted by category. + + Returns + ------- + list[slice] + Boundaries for each category as slices. + + Raises + ------ + ValueError + If the data is not sorted by category or is empty. + TypeError + If the input is not a Categorical or categorical Series. + """ + if isinstance(categorical, pd.Series): + if not isinstance(categorical.dtype, pd.CategoricalDtype): + raise TypeError(f"Expected categorical Series, got {categorical.dtype}") + categorical = categorical.cat + elif not isinstance(categorical, pd.Categorical): + raise TypeError(f"Expected pandas.Categorical or categorical Series, got {type(categorical)}") + + codes = categorical.codes + n_obs = len(codes) + + if n_obs == 0: + raise ValueError("Cannot create sampler from empty categorical") + + # Check if sorted by finding where codes decrease + if np.any(np.diff(codes) < 0): + raise ValueError( + "Data must be sorted by category. Use df.sort_values('category_column') before creating the sampler." + ) + + # Compute boundaries by finding where codes change + # We need to handle the case where some categories might be empty + change_points = np.where(np.diff(codes) != 0)[0] + 1 + starts = np.concatenate([[0], change_points]) + stops = np.concatenate([change_points, [n_obs]]) + + return [slice(int(start), int(stop)) for start, stop in zip(starts, stops, strict=True)] + @classmethod def from_pandas( cls, @@ -182,32 +231,7 @@ def from_pandas( ... preload_nchunks=2, ... 
) """ - if isinstance(categorical, pd.Series): - if not isinstance(categorical.dtype, pd.CategoricalDtype): - raise TypeError(f"Expected categorical Series, got {categorical.dtype}") - categorical = categorical.cat - elif not isinstance(categorical, pd.Categorical): - raise TypeError(f"Expected pandas.Categorical or categorical Series, got {type(categorical)}") - - codes = categorical.codes - n_obs = len(codes) - - if n_obs == 0: - raise ValueError("Cannot create sampler from empty categorical") - - # Check if sorted by finding where codes decrease - if np.any(np.diff(codes) < 0): - raise ValueError( - "Data must be sorted by category. Use df.sort_values('category_column') before creating the sampler." - ) - - # Compute boundaries by finding where codes change - # We need to handle the case where some categories might be empty - change_points = np.where(np.diff(codes) != 0)[0] + 1 - starts = np.concatenate([[0], change_points]) - stops = np.concatenate([change_points, [n_obs]]) - - boundaries = [slice(int(start), int(stop)) for start, stop in zip(starts, stops, strict=True)] + boundaries = cls._boundaries_from_pandas(categorical) return cls( category_boundaries=boundaries, @@ -347,3 +371,262 @@ def _sample(self, n_obs: int) -> Iterator[LoadRequest]: combined_splits.append(split_indices) yield {"chunks": combined_chunks, "splits": combined_splits} + + +class StratifiedCategoricalSampler(CategoricalSampler): + """Stratified sampler with uniform default weights and multi-worker support. + + Samples categories according to weights (uniform by default), yielding + n_yields batches total. Supports multi-worker DataLoaders by splitting + n_yields across workers. 
+ + Unlike :class:`CategoricalSampler`, this sampler: + - Samples with replacement (categories reset when exhausted) + - Supports multi-worker DataLoaders + - Uses uniform weights by default (not proportional to category size) + + Parameters + ---------- + category_boundaries + A sequence of slices defining the boundaries for each category. + Each slice represents a contiguous range of observations belonging to one category. + Data must be sorted by category before using this sampler. + chunk_size + Size of each chunk i.e. the range of each chunk yielded. + preload_nchunks + Number of chunks to load per iteration. + batch_size + Number of observations per batch. + n_yields + Total number of batches to yield (split across workers if num_workers > 1). + weights + Sampling weights per category. Default is uniform (equal probability). + Use ``weights=sampler.category_sizes`` for size-proportional sampling. + shuffle + Whether to shuffle chunk and index order within each category. + rng + Random number generator for shuffling. + + Examples + -------- + >>> boundaries = [slice(0, 100), slice(100, 250), slice(250, 400)] + >>> sampler = StratifiedCategoricalSampler( + ... category_boundaries=boundaries, + ... batch_size=32, + ... chunk_size=64, + ... preload_nchunks=4, + ... n_yields=1000, + ... ) + + Using custom weights (e.g., upsample rare categories): + + >>> sampler = StratifiedCategoricalSampler( + ... category_boundaries=boundaries, + ... batch_size=32, + ... chunk_size=64, + ... preload_nchunks=4, + ... n_yields=1000, + ... weights=[1.0, 2.0, 3.0], # Category 2 sampled 3x as often as category 0 + ... 
) + """ + + _n_yields: int + _weights: np.ndarray + _seed_seq: np.random.SeedSequence + + def __init__( + self, + category_boundaries: Sequence[slice], + chunk_size: int, + preload_nchunks: int, + batch_size: int, + n_yields: int, + weights: Sequence[float] | None = None, + *, + shuffle: bool = False, + rng: np.random.Generator | None = None, + ): + # Create seed sequence for spawning worker-specific RNGs + if rng is None: + self._seed_seq = np.random.SeedSequence() + rng = np.random.default_rng(self._seed_seq.spawn(1)[0]) + else: + self._seed_seq = np.random.SeedSequence(rng.integers(2**63)) + + super().__init__( + category_boundaries, + chunk_size, + preload_nchunks, + batch_size, + shuffle=shuffle, + rng=rng, + ) + + # Validate n_yields + if n_yields < 1: + raise ValueError("n_yields must be >= 1") + self._n_yields = n_yields + + # Handle weights (uniform by default) + if weights is None: + self._weights = np.ones(self.n_categories, dtype=float) + else: + if len(weights) != self.n_categories: + raise ValueError(f"weights length ({len(weights)}) must match n_categories ({self.n_categories})") + weights = np.asarray(weights, dtype=float) + if np.any(weights < 0): + raise ValueError("weights must be non-negative") + if weights.sum() == 0: + raise ValueError("weights must not sum to zero") + self._weights = weights + + @property + def n_yields(self) -> int: + """Total number of batches to yield.""" + return self._n_yields + + @property + def weights(self) -> np.ndarray: + """Sampling weights for each category (not normalized).""" + return self._weights.copy() + + @property + def probabilities(self) -> np.ndarray: + """Normalized sampling probabilities for each category.""" + return self._weights / self._weights.sum() + + @classmethod + def from_pandas( + cls, + categorical: pd.Categorical | pd.Series, + chunk_size: int, + preload_nchunks: int, + batch_size: int, + n_yields: int, + weights: Sequence[float] | None = None, + *, + shuffle: bool = False, + rng: 
np.random.Generator | None = None, + ) -> StratifiedCategoricalSampler: + """Create a StratifiedCategoricalSampler from a pandas Categorical or Series. + + The data is assumed to be sorted by category. This method computes the + boundaries for each category based on where values change. + + Parameters + ---------- + categorical + A pandas Categorical or Series with categorical dtype. + Data must be sorted by category. + chunk_size + Size of each chunk. + preload_nchunks + Number of chunks to load per iteration. + batch_size + Number of observations per batch. + n_yields + Total number of batches to yield. + weights + Sampling weights per category. Default is uniform (equal probability). + shuffle + Whether to shuffle chunk and index order within each category. + rng + Random number generator for shuffling. + + Returns + ------- + StratifiedCategoricalSampler + A sampler configured with boundaries derived from the categorical. + """ + boundaries = cls._boundaries_from_pandas(categorical) + + return cls( + category_boundaries=boundaries, + chunk_size=chunk_size, + preload_nchunks=preload_nchunks, + batch_size=batch_size, + n_yields=n_yields, + weights=weights, + shuffle=shuffle, + rng=rng, + ) + + def validate(self, n_obs: int) -> None: + """Validate the sampler configuration against the loader's n_obs. + + Unlike CategoricalSampler, this sampler supports multi-worker DataLoaders. + + Parameters + ---------- + n_obs + The total number of observations in the loader. + + Raises + ------ + ValueError + If the sampler configuration is invalid for the given n_obs. 
+ """ + # Validate category samplers (skip parent's worker check) + for sampler in self._category_samplers: + sampler.validate(n_obs) + # NOTE: Multi-worker IS supported for stratified (unlike parent CategoricalSampler) + + def _get_worker_handle(self) -> WorkerHandle | None: + """Get WorkerHandle with seed_seq for worker-specific RNGs.""" + if find_spec("torch"): + from torch.utils.data import get_worker_info + + if get_worker_info() is not None: + return WorkerHandle(seed_seq=self._seed_seq) + return None + + def _sample(self, n_obs: int) -> Iterator[LoadRequest]: + """Sample load requests using stratified sampling with replacement. + + Categories are sampled according to weights (uniform by default). + When a category is exhausted, its iterator is reset (sampling with replacement). + """ + # Get worker handle (contains worker_rng) + worker_handle = self._get_worker_handle() + + if worker_handle is not None: + worker_id = worker_handle.worker_id + num_workers = worker_handle.num_workers + # Split n_yields across workers + worker_n_yields = self._n_yields // num_workers + if worker_id < (self._n_yields % num_workers): + worker_n_yields += 1 + # Use worker-specific RNG from handle + worker_rng = worker_handle.worker_rng + else: + worker_n_yields = self._n_yields + worker_rng = self._rng + + if worker_n_yields == 0: + return + + probs = self.probabilities + category_iters: list[Iterator[LoadRequest] | None] = [None] * self.n_categories + yields_so_far = 0 + + while yields_so_far < worker_n_yields: + # Sample category using worker RNG + cat_idx = int(worker_rng.choice(self.n_categories, p=probs)) + + # Get/reset iterator for this category + if category_iters[cat_idx] is None: + category_iters[cat_idx] = iter(self._category_samplers[cat_idx]._sample(n_obs)) + + try: + load_request = next(category_iters[cat_idx]) + except StopIteration: + # Reset iterator (sample with replacement) + category_iters[cat_idx] = iter(self._category_samplers[cat_idx]._sample(n_obs)) + 
load_request = next(category_iters[cat_idx]) + + # Yield individual batches from this load request + for split in load_request["splits"]: + if yields_so_far >= worker_n_yields: + return + yield {"chunks": load_request["chunks"], "splits": [split]} + yields_so_far += 1 diff --git a/src/annbatch/samplers/_chunk_sampler.py b/src/annbatch/samplers/_chunk_sampler.py index 18e55502..30454888 100644 --- a/src/annbatch/samplers/_chunk_sampler.py +++ b/src/annbatch/samplers/_chunk_sampler.py @@ -47,6 +47,7 @@ class ChunkSampler(Sampler): _mask: slice _drop_last: bool _rng: np.random.Generator + _seed_seq: np.random.SeedSequence def __init__( self, @@ -71,7 +72,16 @@ def __init__( check_lt_1([chunk_size, preload_nchunks], ["Chunk size", "Preloaded chunks"]) validate_batch_size(batch_size, chunk_size, preload_nchunks) - self._rng = rng or np.random.default_rng() + + # Store seed sequence for spawning worker-specific RNGs + if rng is None: + self._seed_seq = np.random.SeedSequence() + self._rng = np.random.default_rng(self._seed_seq) + else: + # Create seed sequence from provided RNG for worker spawning + self._seed_seq = np.random.SeedSequence(rng.integers(2**63)) + self._rng = rng + self._batch_size, self._chunk_size, self._shuffle = batch_size, chunk_size, shuffle self._preload_nchunks, self._mask, self._drop_last = ( preload_nchunks, @@ -117,7 +127,8 @@ def _get_worker_handle(self) -> WorkerHandle | None: from annbatch.utils import WorkerHandle if get_worker_info() is not None: - worker_handle = WorkerHandle() + # Pass seed_seq so WorkerHandle can spawn worker-specific RNGs + worker_handle = WorkerHandle(seed_seq=self._seed_seq) # Worker mode validation - only check when there are multiple workers # With batch_size=1, every batch is exactly 1 item, so no partial batches exist if ( @@ -136,14 +147,15 @@ def _sample(self, n_obs: int) -> Iterator[LoadRequest]: # Create chunk indices for possible shuffling and worker sharding chunk_indices = np.arange(math.ceil((stop - start) 
/ self._chunk_size)) if self._shuffle: - if worker_handle is None: - self._rng.shuffle(chunk_indices) - else: - worker_handle.shuffle(chunk_indices) + self._rng.shuffle(chunk_indices) chunks = self._compute_chunks(chunk_indices, start, stop) # Worker sharding: each worker gets a disjoint subset of chunks if worker_handle is not None: chunks = worker_handle.get_part_for_worker(chunks) + + # Batch shuffling: use worker-specific RNG (different per worker) + batch_rng = worker_handle.worker_rng if worker_handle is not None else self._rng + # Set up the iterator for chunks and the batch indices for splits in_memory_size = self._chunk_size * self._preload_nchunks chunks_per_request = split_given_size(chunks, self._preload_nchunks) @@ -151,22 +163,21 @@ def _sample(self, n_obs: int) -> Iterator[LoadRequest]: split_batch_indices = split_given_size(batch_indices, self._batch_size) for request_chunks in chunks_per_request[:-1]: if self._shuffle: - # Avoid copies using in-place shuffling since `self._shuffle` should not change mid-training - np.random.default_rng().shuffle(batch_indices) + batch_rng.shuffle(batch_indices) split_batch_indices = split_given_size(batch_indices, self._batch_size) yield {"chunks": request_chunks, "splits": split_batch_indices} # On the last yield, drop the last uneven batch and create new batch_indices since the in-memory size of this last yield could be divisible by batch_size but smaller than preload_nslices * slice_size final_chunks = chunks_per_request[-1] total_obs_in_last_batch = int(sum(s.stop - s.start for s in final_chunks)) - if total_obs_in_last_batch == 0: # pragma: no cover + if total_obs_in_last_batch == 0: # pragma: no cover raise RuntimeError("Last batch was found to have no observations. 
Please open an issue.") if self._drop_last: total_obs_in_last_batch -= total_obs_in_last_batch % self._batch_size - # Skip yielding if there are no observations (can happen with drop_last=True and last request is empty) + # Skip yielding if there are no observations (can happen with drop_last=True and last request is empty) if total_obs_in_last_batch == 0: return batch_indices = split_given_size( - (np.random.default_rng().permutation if self._shuffle else np.arange)(total_obs_in_last_batch), + (batch_rng.permutation if self._shuffle else np.arange)(total_obs_in_last_batch), self._batch_size, ) yield {"chunks": final_chunks, "splits": batch_indices} diff --git a/src/annbatch/utils.py b/src/annbatch/utils.py index 45c7ea2e..279a8b85 100644 --- a/src/annbatch/utils.py +++ b/src/annbatch/utils.py @@ -103,6 +103,20 @@ def __iter__(self): class WorkerHandle: # noqa: D101 + _seed_seq: np.random.SeedSequence | None + + def __init__(self, seed_seq: np.random.SeedSequence | None = None): + """Initialize WorkerHandle with optional SeedSequence for worker-specific RNGs. + + Parameters + ---------- + seed_seq + Optional SeedSequence for spawning worker-specific RNGs. + If provided, worker_rng will spawn from this sequence. + If None, worker_rng uses torch's worker seed (unique per worker). 
+ """ + self._seed_seq = seed_seq + @cached_property def _worker_info(self): if find_spec("torch"): @@ -118,25 +132,62 @@ def num_workers(self) -> int: return 1 return self._worker_info.num_workers + @property + def worker_id(self) -> int: + """Return the current worker ID, or 0 if not in a worker context.""" + if self._worker_info is None: + return 0 + return self._worker_info.id + @cached_property - def _rng(self): + def _shared_rng(self) -> np.random.Generator: + """RNG with same seed across workers - for chunk ordering.""" if self._worker_info is None: return np.random.default_rng() else: - # This is used for the _get_chunks function - # Use the same seed for all workers that the resulting splits are the same across workers + # Use the same seed for all workers so chunk ordering is deterministic # torch default seed is `base_seed + worker_id`. Hence, subtract worker_id to get the base seed return np.random.default_rng(self._worker_info.seed - self._worker_info.id) + @cached_property + def worker_rng(self) -> np.random.Generator: + """RNG unique to this worker - for batch shuffling and stratified sampling.""" + if self._worker_info is None: + if self._seed_seq is None: + return np.random.default_rng() + return np.random.default_rng(self._seed_seq) + else: + if self._seed_seq is None: + # Use torch's worker seed (already unique per worker) + return np.random.default_rng(self._worker_info.seed) + else: + # Spawn from provided seed sequence + spawned = self._seed_seq.spawn(self._worker_info.num_workers) + return np.random.default_rng(spawned[self._worker_info.id]) + def shuffle(self, obj: np.typing.ArrayLike) -> None: - """Perform in-place shuffle. + """Perform in-place shuffle using shared RNG (same across workers). + + Use this for chunk ordering where all workers need the same order. 
+ + Parameters + ---------- + obj + The object to be shuffled + """ + self._shared_rng.shuffle(obj) + + def shuffle_worker(self, obj: np.typing.ArrayLike) -> None: + """Perform in-place shuffle using worker-specific RNG. + + Use this for batch shuffling where each worker should shuffle differently. Parameters ---------- obj The object to be shuffled """ - self._rng.shuffle(obj) + self.worker_rng.shuffle(obj) def get_part_for_worker(self, obj: np.ndarray) -> np.ndarray: """Get a chunk of an incoming array accordnig to the current worker id. diff --git a/tests/test_categorical_sampler.py b/tests/test_categorical_sampler.py index 6e690cee..0e62fae2 100644 --- a/tests/test_categorical_sampler.py +++ b/tests/test_categorical_sampler.py @@ -6,7 +6,7 @@ import pandas as pd import pytest -from annbatch import CategoricalSampler +from annbatch import CategoricalSampler, StratifiedCategoricalSampler def collect_all_indices(sampler, n_obs): @@ -448,3 +448,268 @@ def get_indices(seed): assert indices1 == indices2, "Same seed should give same results" assert indices1 != indices3, "Different seeds should give different results" + + +# ============================================================================= +# StratifiedCategoricalSampler tests +# ============================================================================= + + +def test_stratified_basic_construction(): + """Test basic StratifiedCategoricalSampler construction.""" + boundaries = [slice(0, 100), slice(100, 200), slice(200, 300)] + sampler = StratifiedCategoricalSampler( + category_boundaries=boundaries, + batch_size=10, + chunk_size=20, + preload_nchunks=2, + n_yields=50, + ) + assert sampler.batch_size == 10 + assert sampler.n_categories == 3 + assert sampler.n_yields == 50 + assert sampler.shuffle is False + # Default weights are uniform + np.testing.assert_array_equal(sampler.weights, [1.0, 1.0, 1.0]) + + +def test_stratified_custom_weights(): + """Test StratifiedCategoricalSampler with custom weights.""" 
+ boundaries = [slice(0, 100), slice(100, 200), slice(200, 300)] + sampler = StratifiedCategoricalSampler( + category_boundaries=boundaries, + batch_size=10, + chunk_size=20, + preload_nchunks=2, + n_yields=50, + weights=[1.0, 2.0, 3.0], + ) + np.testing.assert_array_equal(sampler.weights, [1.0, 2.0, 3.0]) + np.testing.assert_array_almost_equal(sampler.probabilities, [1 / 6, 2 / 6, 3 / 6]) + + +def test_stratified_n_yields_count(): + """Test that exactly n_yields batches are yielded.""" + boundaries = [slice(0, 100), slice(100, 200)] + n_yields = 25 + + sampler = StratifiedCategoricalSampler( + category_boundaries=boundaries, + batch_size=10, + chunk_size=20, + preload_nchunks=2, + n_yields=n_yields, + rng=np.random.default_rng(42), + ) + + batch_count = 0 + for load_request in sampler.sample(200): + batch_count += len(load_request["splits"]) + + assert batch_count == n_yields + + +def test_stratified_n_yields_invalid(): + """Test that n_yields < 1 raises ValueError.""" + boundaries = [slice(0, 100)] + with pytest.raises(ValueError, match="n_yields must be >= 1"): + StratifiedCategoricalSampler( + category_boundaries=boundaries, + batch_size=10, + chunk_size=20, + preload_nchunks=2, + n_yields=0, + ) + + +def test_stratified_weights_validation(): + """Test weight validation errors.""" + boundaries = [slice(0, 100), slice(100, 200)] + + # Wrong length + with pytest.raises(ValueError, match="weights length"): + StratifiedCategoricalSampler( + category_boundaries=boundaries, + batch_size=10, + chunk_size=20, + preload_nchunks=2, + n_yields=10, + weights=[1.0], # Only 1 weight for 2 categories + ) + + # Negative weights + with pytest.raises(ValueError, match="non-negative"): + StratifiedCategoricalSampler( + category_boundaries=boundaries, + batch_size=10, + chunk_size=20, + preload_nchunks=2, + n_yields=10, + weights=[1.0, -1.0], + ) + + # Zero sum + with pytest.raises(ValueError, match="not sum to zero"): + StratifiedCategoricalSampler( + 
category_boundaries=boundaries, + batch_size=10, + chunk_size=20, + preload_nchunks=2, + n_yields=10, + weights=[0.0, 0.0], + ) + + +def test_stratified_replacement(): + """Test that categories are reset when exhausted (sampling with replacement).""" + # Small category with only 2 complete batches possible + boundaries = [slice(0, 20)] # 20 obs, batch_size=10, drop_last=True -> 2 batches + n_yields = 10 # Request more than available + + sampler = StratifiedCategoricalSampler( + category_boundaries=boundaries, + batch_size=10, + chunk_size=20, + preload_nchunks=1, + n_yields=n_yields, + rng=np.random.default_rng(42), + ) + + batch_count = 0 + for load_request in sampler.sample(20): + batch_count += len(load_request["splits"]) + + # Should still yield n_yields batches due to replacement + assert batch_count == n_yields + + +def test_stratified_each_batch_single_category(): + """Test that each batch in stratified sampling is from a single category.""" + boundaries = [slice(0, 100), slice(100, 200), slice(200, 300)] + n_obs = 300 + + sampler = StratifiedCategoricalSampler( + category_boundaries=boundaries, + batch_size=10, + chunk_size=20, + preload_nchunks=2, + n_yields=50, + shuffle=True, + rng=np.random.default_rng(42), + ) + + for load_request in sampler.sample(n_obs): + chunks = load_request["chunks"] + if len(chunks) == 0: + continue + + # Build mapping from concatenated chunk index to original index + chunk_indices = [] + for chunk in chunks: + chunk_indices.extend(range(chunk.start, chunk.stop)) + + # Verify each split contains indices from only one category + for split in load_request["splits"]: + if len(split) == 0: + continue + + first_original_idx = chunk_indices[split[0]] + expected_category = _get_category_for_index(first_original_idx, boundaries) + + for idx in split: + original_idx = chunk_indices[idx] + split_category = _get_category_for_index(original_idx, boundaries) + assert split_category == expected_category + + +def 
test_stratified_rng_reproducibility(): + """Test that same RNG seed gives same results for stratified sampler.""" + boundaries = [slice(0, 100), slice(100, 200)] + + def get_batches(seed): + sampler = StratifiedCategoricalSampler( + category_boundaries=boundaries, + batch_size=10, + chunk_size=20, + preload_nchunks=2, + n_yields=20, + shuffle=True, + rng=np.random.default_rng(seed), + ) + return collect_flat_indices(sampler, 200) + + indices1 = get_batches(42) + indices2 = get_batches(42) + indices3 = get_batches(99) + + assert indices1 == indices2, "Same seed should give same results" + assert indices1 != indices3, "Different seeds should give different results" + + +def test_stratified_from_pandas(): + """Test StratifiedCategoricalSampler.from_pandas construction.""" + categories = pd.Categorical(["A"] * 50 + ["B"] * 30 + ["C"] * 20) + + sampler = StratifiedCategoricalSampler.from_pandas( + categories, + batch_size=10, + chunk_size=20, + preload_nchunks=2, + n_yields=30, + ) + + assert sampler.n_categories == 3 + assert sampler.category_sizes == [50, 30, 20] + assert sampler.n_yields == 30 + + +def test_stratified_from_pandas_with_weights(): + """Test StratifiedCategoricalSampler.from_pandas with custom weights.""" + categories = pd.Categorical(["A"] * 50 + ["B"] * 30 + ["C"] * 20) + + sampler = StratifiedCategoricalSampler.from_pandas( + categories, + batch_size=10, + chunk_size=20, + preload_nchunks=2, + n_yields=30, + weights=[3.0, 2.0, 1.0], + ) + + np.testing.assert_array_equal(sampler.weights, [3.0, 2.0, 1.0]) + + +def test_stratified_uniform_weights_distribution(): + """Test that uniform weights sample categories roughly equally.""" + boundaries = [slice(0, 100), slice(100, 200), slice(200, 300)] + n_obs = 300 + n_yields = 300 # Large number for statistical significance + + sampler = StratifiedCategoricalSampler( + category_boundaries=boundaries, + batch_size=10, + chunk_size=20, + preload_nchunks=2, + n_yields=n_yields, + rng=np.random.default_rng(42), + 
) + + # Count batches per category + category_counts = [0, 0, 0] + for load_request in sampler.sample(n_obs): + chunks = load_request["chunks"] + chunk_indices = [] + for chunk in chunks: + chunk_indices.extend(range(chunk.start, chunk.stop)) + + for split in load_request["splits"]: + if len(split) > 0: + first_idx = chunk_indices[split[0]] + cat = _get_category_for_index(first_idx, boundaries) + category_counts[cat] += 1 + + # With uniform weights, each category should get roughly 1/3 of batches + # Allow 30% tolerance for randomness + expected = n_yields / 3 + for count in category_counts: + assert abs(count - expected) < expected * 0.3, f"Category count {count} too far from expected {expected}" diff --git a/tests/test_dataset.py b/tests/test_dataset.py index d812c4d2..dcd1482e 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -602,3 +602,73 @@ def test_cannot_provide_batch_sampler_with_sampler_args(kwarg): chunk_sampler = ChunkSampler(mask=slice(0, 50), batch_size=5, chunk_size=10, preload_nchunks=2) with pytest.raises(ValueError, match="Cannot specify.*when providing a custom sampler"): Loader(batch_sampler=chunk_sampler, preload_to_gpu=False, to_torch=False, **kwarg) + + +def test_stratified_categorical_sampler_with_loader( + adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path], +): + """Test StratifiedCategoricalSampler integration with Loader.
+ + Verifies: + - Correct number of batches yielded (n_yields) + - Each batch contains indices from a single category + - Data integrity (indices map to correct data) + """ + from annbatch import StratifiedCategoricalSampler + + paths = list(adata_with_zarr_path_same_var_space[1].glob("*.zarr")) + datas = [open_sparse(p) for p in paths] + + # Calculate total n_obs + total_n_obs = sum(d["dataset"].shape[0] for d in datas) + + # Create category boundaries (split into 3 roughly equal categories) + cat_size = total_n_obs // 3 + boundaries = [ + slice(0, cat_size), + slice(cat_size, 2 * cat_size), + slice(2 * cat_size, total_n_obs), + ] + + n_yields = 15 + batch_size = 10 + + sampler = StratifiedCategoricalSampler( + category_boundaries=boundaries, + batch_size=batch_size, + chunk_size=20, + preload_nchunks=2, + n_yields=n_yields, + rng=np.random.default_rng(42), + ) + + loader = Loader(batch_sampler=sampler, preload_to_gpu=False, to_torch=False, return_index=True) + loader.add_datasets(**concat(datas)) + + # Collect all batches and verify + batch_count = 0 + all_indices = [] + + for batch in loader: + batch_count += 1 + indices = batch["index"] + all_indices.append(indices) + + # Verify batch size + assert len(indices) == batch_size + + # Verify all indices in batch are from the same category + categories_in_batch = set() + for idx in indices: + for i, boundary in enumerate(boundaries): + if boundary.start <= idx < boundary.stop: + categories_in_batch.add(i) + break + assert len(categories_in_batch) == 1, f"Batch contains indices from multiple categories: {categories_in_batch}" + + # Verify correct number of batches + assert batch_count == n_yields, f"Expected {n_yields} batches, got {batch_count}" + + # Verify all indices are valid + all_indices_flat = np.concatenate(all_indices) + assert all(0 <= idx < total_n_obs for idx in all_indices_flat), "Invalid indices in output" From 23dde3b00cd17b3e366f1da3571b1bd03ecb2669 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Selman=20=C3=96zleyen?= <32667648+selmanozleyen@users.noreply.github.com> Date: Thu, 5 Feb 2026 13:06:15 +0100 Subject: [PATCH 08/16] Fix: ChunkSampler yields empty splits when remainder is less than batch_size and drop_last (#132) * add failing tests * update the tests * fix * Update src/annbatch/samplers/_chunk_sampler.py Co-authored-by: Ilan Gold * Update src/annbatch/samplers/_chunk_sampler.py --------- Co-authored-by: Ilan Gold --- src/annbatch/samplers/_chunk_sampler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/annbatch/samplers/_chunk_sampler.py b/src/annbatch/samplers/_chunk_sampler.py index 30454888..6543617d 100644 --- a/src/annbatch/samplers/_chunk_sampler.py +++ b/src/annbatch/samplers/_chunk_sampler.py @@ -172,6 +172,8 @@ def _sample(self, n_obs: int) -> Iterator[LoadRequest]: if total_obs_in_last_batch == 0: # pragma: no cover raise RuntimeError("Last batch was found to have no observations. Please open an issue.") if self._drop_last: + if total_obs_in_last_batch < self._batch_size: + return total_obs_in_last_batch -= total_obs_in_last_batch % self._batch_size # Skip yielding if there are no observations (can happen with drop_last=True and last request is empty) if total_obs_in_last_batch == 0: From a22e9166b15bed2d6cf8c443790a3bb7558e82d9 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Thu, 5 Feb 2026 15:20:24 +0100 Subject: [PATCH 09/16] merge the branch w rng handling --- src/annbatch/samplers/_categorical_sampler.py | 20 ++- src/annbatch/samplers/_chunk_sampler.py | 16 +-- src/annbatch/utils.py | 116 ++++++------------ tests/test_sampler.py | 67 ++++++++-- 4 files changed, 104 insertions(+), 115 deletions(-) diff --git a/src/annbatch/samplers/_categorical_sampler.py b/src/annbatch/samplers/_categorical_sampler.py index 212a0ef9..1a4a3f5a 100644 --- a/src/annbatch/samplers/_categorical_sampler.py +++ b/src/annbatch/samplers/_categorical_sampler.py @@ -10,12 +10,13 @@ from annbatch.abc import Sampler from 
annbatch.samplers._chunk_sampler import ChunkSampler -from annbatch.utils import WorkerHandle, check_lt_1 +from annbatch.utils import check_lt_1 if TYPE_CHECKING: from collections.abc import Iterator, Sequence from annbatch.types import LoadRequest + from annbatch.utils import WorkerHandle class CategoricalSampler(Sampler): @@ -432,7 +433,6 @@ class StratifiedCategoricalSampler(CategoricalSampler): _n_yields: int _weights: np.ndarray - _seed_seq: np.random.SeedSequence def __init__( self, @@ -446,13 +446,6 @@ def __init__( shuffle: bool = False, rng: np.random.Generator | None = None, ): - # Create seed sequence for spawning worker-specific RNGs - if rng is None: - self._seed_seq = np.random.SeedSequence() - rng = np.random.default_rng(self._seed_seq.spawn(1)[0]) - else: - self._seed_seq = np.random.SeedSequence(rng.integers(2**63)) - super().__init__( category_boundaries, chunk_size, @@ -572,12 +565,14 @@ def validate(self, n_obs: int) -> None: # NOTE: Multi-worker IS supported for stratified (unlike parent CategoricalSampler) def _get_worker_handle(self) -> WorkerHandle | None: - """Get WorkerHandle with seed_seq for worker-specific RNGs.""" + """Get WorkerHandle for worker-specific RNGs.""" if find_spec("torch"): from torch.utils.data import get_worker_info + from annbatch.utils import WorkerHandle + if get_worker_info() is not None: - return WorkerHandle(seed_seq=self._seed_seq) + return WorkerHandle(self._rng) return None def _sample(self, n_obs: int) -> Iterator[LoadRequest]: @@ -586,7 +581,6 @@ def _sample(self, n_obs: int) -> Iterator[LoadRequest]: Categories are sampled according to weights (uniform by default). When a category is exhausted, its iterator is reset (sampling with replacement). 
""" - # Get worker handle (contains worker_rng) worker_handle = self._get_worker_handle() if worker_handle is not None: @@ -597,7 +591,7 @@ def _sample(self, n_obs: int) -> Iterator[LoadRequest]: if worker_id < (self._n_yields % num_workers): worker_n_yields += 1 # Use worker-specific RNG from handle - worker_rng = worker_handle.worker_rng + worker_rng = worker_handle.rng else: worker_n_yields = self._n_yields worker_rng = self._rng diff --git a/src/annbatch/samplers/_chunk_sampler.py b/src/annbatch/samplers/_chunk_sampler.py index 6543617d..c706431e 100644 --- a/src/annbatch/samplers/_chunk_sampler.py +++ b/src/annbatch/samplers/_chunk_sampler.py @@ -47,7 +47,6 @@ class ChunkSampler(Sampler): _mask: slice _drop_last: bool _rng: np.random.Generator - _seed_seq: np.random.SeedSequence def __init__( self, @@ -73,15 +72,7 @@ def __init__( check_lt_1([chunk_size, preload_nchunks], ["Chunk size", "Preloaded chunks"]) validate_batch_size(batch_size, chunk_size, preload_nchunks) - # Store seed sequence for spawning worker-specific RNGs - if rng is None: - self._seed_seq = np.random.SeedSequence() - self._rng = np.random.default_rng(self._seed_seq) - else: - # Create seed sequence from provided RNG for worker spawning - self._seed_seq = np.random.SeedSequence(rng.integers(2**63)) - self._rng = rng - + self._rng = rng or np.random.default_rng() self._batch_size, self._chunk_size, self._shuffle = batch_size, chunk_size, shuffle self._preload_nchunks, self._mask, self._drop_last = ( preload_nchunks, @@ -127,8 +118,7 @@ def _get_worker_handle(self) -> WorkerHandle | None: from annbatch.utils import WorkerHandle if get_worker_info() is not None: - # Pass seed_seq so WorkerHandle can spawn worker-specific RNGs - worker_handle = WorkerHandle(seed_seq=self._seed_seq) + worker_handle = WorkerHandle(self._rng) # Worker mode validation - only check when there are multiple workers # With batch_size=1, every batch is exactly 1 item, so no partial batches exist if ( @@ -154,7 +144,7 @@ 
def _sample(self, n_obs: int) -> Iterator[LoadRequest]: chunks = worker_handle.get_part_for_worker(chunks) # Batch shuffling: use worker-specific RNG (different per worker) - batch_rng = worker_handle.worker_rng if worker_handle is not None else self._rng + batch_rng = worker_handle.rng if worker_handle is not None else self._rng # Set up the iterator for chunks and the batch indices for splits in_memory_size = self._chunk_size * self._preload_nchunks diff --git a/src/annbatch/utils.py b/src/annbatch/utils.py index 279a8b85..fb1076e6 100644 --- a/src/annbatch/utils.py +++ b/src/annbatch/utils.py @@ -4,7 +4,7 @@ import itertools import warnings from dataclasses import dataclass -from functools import cached_property, wraps +from functools import wraps from importlib.util import find_spec from typing import TYPE_CHECKING, Concatenate, Protocol @@ -102,95 +102,55 @@ def __iter__(self): total += gap -class WorkerHandle: # noqa: D101 - _seed_seq: np.random.SeedSequence | None +class WorkerHandle: + """Handle for torch DataLoader worker context. - def __init__(self, seed_seq: np.random.SeedSequence | None = None): - """Initialize WorkerHandle with optional SeedSequence for worker-specific RNGs. + This class should only be instantiated inside a torch DataLoader worker process + (i.e., when `torch.utils.data.get_worker_info()` returns a non-None value). + It provides worker-specific RNG and partitioning utilities. - Parameters - ---------- - seed_seq - Optional SeedSequence for spawning worker-specific RNGs. - If provided, worker_rng will spawn from this sequence. - If None, worker_rng uses torch's worker seed (unique per worker). - """ - self._seed_seq = seed_seq + Parameters + ---------- + rng + The RNG to spawn worker-specific RNGs from. If provided, uses its bit_generator's + seed sequence to spawn independent streams for each worker. If None, falls back + to using torch's worker seed. 
+ + The RNG is created using `SeedSequence.spawn()` to ensure each worker has an + independent but reproducible random stream, following numpy's recommended + pattern for parallel random number generation. + """ - @cached_property - def _worker_info(self): - if find_spec("torch"): - from torch.utils.data import get_worker_info + def __init__(self, rng: np.random.Generator | None = None): + """Initialize WorkerHandle. Must be called from within a torch DataLoader worker.""" + from torch.utils.data import get_worker_info - return get_worker_info() - return None + self._worker_info = get_worker_info() + # Each worker gets its own RNG spawned from the sampler's RNG for reproducible batch shuffling + if rng is not None: + bit_generators = rng.bit_generator.spawn(self._worker_info.num_workers) + else: + seq = np.random.SeedSequence(self._worker_info.seed).spawn(self._worker_info.num_workers) + bit_generators = seq + self._rng = np.random.default_rng(bit_generators[self._worker_info.id]) + + @property + def rng(self) -> np.random.Generator: + """Return the RNG for the current worker.""" + return self._rng @property def num_workers(self) -> int: - """Return the number of workers, or 1 if not in a worker context.""" - if self._worker_info is None: - return 1 + """Return the number of workers.""" return self._worker_info.num_workers @property def worker_id(self) -> int: - """Return the current worker ID, or 0 if not in a worker context.""" - if self._worker_info is None: - return 0 + """Return the current worker ID.""" return self._worker_info.id - @cached_property - def _shared_rng(self) -> np.random.Generator: - """RNG with same seed across workers - for chunk ordering.""" - if self._worker_info is None: - return np.random.default_rng() - else: - # Use the same seed for all workers so chunk ordering is deterministic - # torch default seed is `base_seed + worker_id`. 
Hence, subtract worker_id to get the base seed - return np.random.default_rng(self._worker_info.seed - self._worker_info.id) - - @cached_property - def worker_rng(self) -> np.random.Generator: - """RNG unique to this worker - for batch shuffling and stratified sampling.""" - if self._worker_info is None: - if self._seed_seq is None: - return np.random.default_rng() - return np.random.default_rng(self._seed_seq) - else: - if self._seed_seq is None: - # Use torch's worker seed (already unique per worker) - return np.random.default_rng(self._worker_info.seed) - else: - # Spawn from provided seed sequence - spawned = self._seed_seq.spawn(self._worker_info.num_workers) - return np.random.default_rng(spawned[self._worker_info.id]) - - def shuffle(self, obj: np.typing.ArrayLike) -> None: - """Perform in-place shuffle using shared RNG (same across workers). - - Use this for chunk ordering where all workers need the same order. - - Parameters - ---------- - obj - The object to be shuffled - """ - self._shared_rng.shuffle(obj) - - def shuffle_worker(self, obj: np.typing.ArrayLike) -> None: - """Perform in-place shuffle using worker-specific RNG. - - Use this for batch shuffling where each worker should shuffle differently. - - Parameters - ---------- - obj - The object to be shuffled - """ - self.worker_rng.shuffle(obj) - def get_part_for_worker(self, obj: np.ndarray) -> np.ndarray: - """Get a chunk of an incoming array accordnig to the current worker id. + """Get a chunk of an incoming array according to the current worker id. Parameters ---------- @@ -199,10 +159,8 @@ def get_part_for_worker(self, obj: np.ndarray) -> np.ndarray: Returns ------- - A evenly split part of the ray corresponding to how many workers there are. + An evenly split part of the array corresponding to this worker. 
""" - if self._worker_info is None: - return obj num_workers, worker_id = self._worker_info.num_workers, self._worker_info.id chunks_split = np.array_split(obj, num_workers) return chunks_split[worker_id] diff --git a/tests/test_sampler.py b/tests/test_sampler.py index 05299cc0..f6d7d7c5 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -26,28 +26,38 @@ def collect_indices(sampler, n_obs): class MockWorkerHandle: """Simulates torch worker context for testing without actual DataLoader.""" - def __init__(self, worker_id: int, num_workers: int, seed: int = 42): + def __init__(self, worker_id: int, num_workers: int, rng: np.random.Generator | None = None): self.worker_id = worker_id self._num_workers = num_workers - self._rng = np.random.default_rng(seed) + # Each worker gets its own RNG spawned from the sampler's RNG (mirrors real WorkerHandle) + if rng is not None: + bit_generators = rng.bit_generator.spawn(num_workers) + else: + bit_generators = np.random.SeedSequence(42).spawn(num_workers) + self._rng = np.random.default_rng(bit_generators[worker_id]) + + @property + def rng(self) -> np.random.Generator: + return self._rng @property def num_workers(self) -> int: return self._num_workers - def shuffle(self, obj): - self._rng.shuffle(obj) - def get_part_for_worker(self, obj: np.ndarray) -> np.ndarray: return np.array_split(obj, self._num_workers)[self.worker_id] class ChunkSamplerWithMockWorkerHandle(ChunkSampler): - def set_worker_handle(self, worker_handle: MockWorkerHandle): - self.worker_handle = worker_handle + def set_mock_worker_info(self, worker_id: int, num_workers: int): + """Set mock worker info. 
The RNG will be derived from sampler's _rng.""" + self._mock_worker_id = worker_id + self._mock_num_workers = num_workers def _get_worker_handle(self) -> MockWorkerHandle | None: - return self.worker_handle + if hasattr(self, "_mock_worker_id"): + return MockWorkerHandle(self._mock_worker_id, self._mock_num_workers, self._rng) + return None # ============================================================================= @@ -175,7 +185,6 @@ def test_workers_cover_full_dataset_without_overlap( """Test workers cover full dataset without overlap. Also checks if there are empty splits in any of the load requests.""" all_worker_indices = [] for worker_id in range(num_workers): - worker_handle = MockWorkerHandle(worker_id, num_workers) sampler = ChunkSamplerWithMockWorkerHandle( mask=slice(0, None), batch_size=batch_size, @@ -183,7 +192,7 @@ def test_workers_cover_full_dataset_without_overlap( preload_nchunks=preload_nchunks, drop_last=drop_last, ) - sampler.set_worker_handle(worker_handle) + sampler.set_mock_worker_info(worker_id, num_workers) all_worker_indices.append(collect_indices(sampler, n_obs)) # All workers should have disjoint chunks @@ -195,6 +204,44 @@ def test_workers_cover_full_dataset_without_overlap( assert set().union(*all_worker_indices) == set(range(n_obs)) +def test_batch_shuffle_is_reproducible(): + """Test that batch shuffling is reproducible when using ChunkSampler directly. + + This test verifies batch shuffling uses the sampler's seeded `_rng` instead of + an unseeded `np.random.default_rng()`, making it reproducible. 
+ """ + n_obs, chunk_size, preload_nchunks, batch_size = 100, 10, 2, 5 + seed = 42 + + def collect_splits(sampler): + all_splits = [] + for load_request in sampler.sample(n_obs): + for split in load_request["splits"]: + all_splits.append(split.tolist()) + return all_splits + + # Run twice with same seed - should get identical batch ordering + sampler1 = ChunkSampler( + chunk_size=chunk_size, + preload_nchunks=preload_nchunks, + batch_size=batch_size, + shuffle=True, + rng=np.random.default_rng(seed), + ) + splits1 = collect_splits(sampler1) + + sampler2 = ChunkSampler( + chunk_size=chunk_size, + preload_nchunks=preload_nchunks, + batch_size=batch_size, + shuffle=True, + rng=np.random.default_rng(seed), + ) + splits2 = collect_splits(sampler2) + + assert splits1 == splits2, "Batch shuffling should be reproducible with same seed" + + # ============================================================================= # Validation tests # ============================================================================= From 02bfae235084f26a8ac23ef1177461484dc3b51f Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Thu, 5 Feb 2026 15:32:23 +0100 Subject: [PATCH 10/16] code dedup --- src/annbatch/samplers/_categorical_sampler.py | 34 +++++++-------- src/annbatch/samplers/_chunk_sampler.py | 12 +----- src/annbatch/samplers/_utils.py | 43 +++++++++++++++++++ 3 files changed, 60 insertions(+), 29 deletions(-) diff --git a/src/annbatch/samplers/_categorical_sampler.py b/src/annbatch/samplers/_categorical_sampler.py index 1a4a3f5a..bf1643f3 100644 --- a/src/annbatch/samplers/_categorical_sampler.py +++ b/src/annbatch/samplers/_categorical_sampler.py @@ -2,7 +2,6 @@ from __future__ import annotations -from importlib.util import find_spec from typing import TYPE_CHECKING import numpy as np @@ -10,6 +9,7 @@ from annbatch.abc import Sampler from annbatch.samplers._chunk_sampler import ChunkSampler +from annbatch.samplers._utils import get_worker_handle, is_in_worker from annbatch.utils 
import check_lt_1 if TYPE_CHECKING: @@ -279,13 +279,10 @@ def validate(self, n_obs: int) -> None: sampler.validate(n_obs) # Check for worker usage - CategoricalSampler doesn't support workers - if find_spec("torch"): - from torch.utils.data import get_worker_info - - if get_worker_info() is not None: - raise ValueError( - "CategoricalSampler does not support multiple workers. Use num_workers=0 in your DataLoader." - ) + if is_in_worker(): + raise ValueError( + "CategoricalSampler does not support multiple workers. Use num_workers=0 in your DataLoader." + ) def _sample(self, n_obs: int) -> Iterator[LoadRequest]: """Sample load requests, ensuring each batch is from a single category. @@ -503,8 +500,8 @@ def from_pandas( ) -> StratifiedCategoricalSampler: """Create a StratifiedCategoricalSampler from a pandas Categorical or Series. - The data is assumed to be sorted by category. This method computes the - boundaries for each category based on where values change. + This extends :meth:`CategoricalSampler.from_pandas` with additional + parameters for stratified sampling. Parameters ---------- @@ -530,9 +527,15 @@ def from_pandas( ------- StratifiedCategoricalSampler A sampler configured with boundaries derived from the categorical. + + Raises + ------ + ValueError + If the data is not sorted by category. + TypeError + If the input is not a Categorical or categorical Series. 
""" boundaries = cls._boundaries_from_pandas(categorical) - return cls( category_boundaries=boundaries, chunk_size=chunk_size, @@ -566,14 +569,7 @@ def validate(self, n_obs: int) -> None: def _get_worker_handle(self) -> WorkerHandle | None: """Get WorkerHandle for worker-specific RNGs.""" - if find_spec("torch"): - from torch.utils.data import get_worker_info - - from annbatch.utils import WorkerHandle - - if get_worker_info() is not None: - return WorkerHandle(self._rng) - return None + return get_worker_handle(self._rng) def _sample(self, n_obs: int) -> Iterator[LoadRequest]: """Sample load requests using stratified sampling with replacement. diff --git a/src/annbatch/samplers/_chunk_sampler.py b/src/annbatch/samplers/_chunk_sampler.py index c706431e..1795c3db 100644 --- a/src/annbatch/samplers/_chunk_sampler.py +++ b/src/annbatch/samplers/_chunk_sampler.py @@ -3,13 +3,12 @@ from __future__ import annotations import math -from importlib.util import find_spec from typing import TYPE_CHECKING import numpy as np from annbatch.abc import Sampler -from annbatch.samplers._utils import validate_batch_size +from annbatch.samplers._utils import get_worker_handle, validate_batch_size from annbatch.utils import check_lt_1, split_given_size if TYPE_CHECKING: @@ -111,14 +110,7 @@ def validate(self, n_obs: int) -> None: raise ValueError(f"Sampler mask.start ({start}) must be < mask.stop ({stop}).") def _get_worker_handle(self) -> WorkerHandle | None: - worker_handle = None - if find_spec("torch"): - from torch.utils.data import get_worker_info - - from annbatch.utils import WorkerHandle - - if get_worker_info() is not None: - worker_handle = WorkerHandle(self._rng) + worker_handle = get_worker_handle(self._rng) # Worker mode validation - only check when there are multiple workers # With batch_size=1, every batch is exactly 1 item, so no partial batches exist if ( diff --git a/src/annbatch/samplers/_utils.py b/src/annbatch/samplers/_utils.py index 7d9f1acd..7778faca 100644 --- 
a/src/annbatch/samplers/_utils.py +++ b/src/annbatch/samplers/_utils.py @@ -2,6 +2,49 @@ from __future__ import annotations +from importlib.util import find_spec +from typing import TYPE_CHECKING + +import numpy as np + +if TYPE_CHECKING: + from annbatch.utils import WorkerHandle + + +def is_in_worker() -> bool: + """Check if currently running inside a torch DataLoader worker. + + Returns + ------- + bool + True if inside a DataLoader worker, False otherwise. + """ + if find_spec("torch"): + from torch.utils.data import get_worker_info + + return get_worker_info() is not None + return False + + +def get_worker_handle(rng: np.random.Generator) -> WorkerHandle | None: + """Get a WorkerHandle if running inside a torch DataLoader worker. + + Parameters + ---------- + rng + The RNG to spawn worker-specific RNGs from. + + Returns + ------- + WorkerHandle | None + A WorkerHandle if inside a DataLoader worker, None otherwise. + """ + if is_in_worker(): + from annbatch.utils import WorkerHandle + + return WorkerHandle(rng) + return None + def validate_batch_size(batch_size: int, chunk_size: int, preload_nchunks: int) -> None: """Validate batch_size against chunk_size and preload_nchunks constraints. From cf56db2c53e88b68568b37d09bdfed2531707b8b Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Thu, 5 Feb 2026 15:36:07 +0100 Subject: [PATCH 11/16] accurate docstrings --- src/annbatch/samplers/_categorical_sampler.py | 23 +++++++++++-------- src/annbatch/samplers/_utils.py | 5 ++-- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/src/annbatch/samplers/_categorical_sampler.py b/src/annbatch/samplers/_categorical_sampler.py index bf1643f3..69006a90 100644 --- a/src/annbatch/samplers/_categorical_sampler.py +++ b/src/annbatch/samplers/_categorical_sampler.py @@ -23,8 +23,9 @@ class CategoricalSampler(Sampler): """Categorical sampler for group-stratified batched data access. This sampler ensures each batch contains observations from a single category/group. 
- It samples from categories proportionally to their size, yielding batches where - all observations belong to the same category. + It iterates through all categories, yielding all batches exactly once per epoch. + The batch order is shuffled across categories, but each individual batch contains + observations from only one category. The sampler assumes data is sorted by category, with boundaries provided as slices. For convenience, use :meth:`from_pandas` to construct from a pandas Categorical. @@ -288,10 +289,11 @@ def _sample(self, n_obs: int) -> Iterator[LoadRequest]: """Sample load requests, ensuring each batch is from a single category. The sampling strategy: - 1. Collect all load requests from each category sampler + 1. Collect all load requests from each category sampler into chunk groups 2. Flatten into individual batches: (chunk_group_id, batch_index_within_group) - 3. Shuffle the batch order across categories - 4. Group batches by their chunk_group_id and yield combined load requests + 3. Shuffle the batch order across all categories + 4. Group n_categories batches together per load request and + combine chunks from the selected batches into a single load request """ batch_size = self._category_samplers[0]._batch_size @@ -313,7 +315,7 @@ def _sample(self, n_obs: int) -> Iterator[LoadRequest]: # Flatten into individual batches: (chunk_group_id, batch_index_within_group) all_batches: list[tuple[int, int]] = [] - for group_id, (chunks, n_batches) in enumerate(all_chunk_groups): + for group_id, (_, n_batches) in enumerate(all_chunk_groups): for batch_idx in range(n_batches): all_batches.append((group_id, batch_idx)) @@ -372,16 +374,17 @@ def _sample(self, n_obs: int) -> Iterator[LoadRequest]: class StratifiedCategoricalSampler(CategoricalSampler): - """Stratified sampler with uniform default weights and multi-worker support. + """Stratified categorical sampler with configurable weights and multi-worker support. 
Samples categories according to weights (uniform by default), yielding - n_yields batches total. Supports multi-worker DataLoaders by splitting - n_yields across workers. + exactly ``n_yields`` batches total. Supports multi-worker DataLoaders by splitting + ``n_yields`` across workers. Unlike :class:`CategoricalSampler`, this sampler: + - Yields a fixed number of batches (``n_yields``) rather than exhausting all data - Samples with replacement (categories reset when exhausted) - Supports multi-worker DataLoaders - - Uses uniform weights by default (not proportional to category size) + - Allows configurable sampling weights (uniform by default) Parameters ---------- diff --git a/src/annbatch/samplers/_utils.py b/src/annbatch/samplers/_utils.py index 7778faca..6ce6ae1a 100644 --- a/src/annbatch/samplers/_utils.py +++ b/src/annbatch/samplers/_utils.py @@ -56,12 +56,13 @@ def validate_batch_size(batch_size: int, chunk_size: int, preload_nchunks: int) chunk_size Size of each chunk. preload_nchunks - Number of chunks to preload. + Number of chunks to load per iteration. Raises ------ ValueError - If batch_size exceeds preload_size or preload_size is not divisible by batch_size. + If batch_size exceeds the total loaded size (chunk_size * preload_nchunks) + or if the total loaded size is not divisible by batch_size. 
""" preload_size = chunk_size * preload_nchunks From 583847182ba6c671aafc6fd83483bb4e617023c7 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Thu, 5 Feb 2026 16:21:09 +0100 Subject: [PATCH 12/16] base algo first step --- src/annbatch/samplers/_categorical_sampler.py | 145 +++++++----------- 1 file changed, 59 insertions(+), 86 deletions(-) diff --git a/src/annbatch/samplers/_categorical_sampler.py b/src/annbatch/samplers/_categorical_sampler.py index 69006a90..6fb68f3e 100644 --- a/src/annbatch/samplers/_categorical_sampler.py +++ b/src/annbatch/samplers/_categorical_sampler.py @@ -2,6 +2,7 @@ from __future__ import annotations +from itertools import batched from typing import TYPE_CHECKING import numpy as np @@ -115,6 +116,10 @@ def __init__( # Create a ChunkSampler for each category, using its boundary as the mask # Always use drop_last=True internally + # also compute the number of batches for each category + self._n_batches_per_category = [ + int((boundary.stop - boundary.start) // batch_size) for boundary in category_boundaries + ] self._category_samplers = [ ChunkSampler( chunk_size=chunk_size, @@ -128,6 +133,11 @@ def __init__( for boundary, child_rng in zip(category_boundaries, child_rngs, strict=True) ] + self._batch_size = batch_size + self._chunk_size = chunk_size + self._preload_nchunks = preload_nchunks + self._shuffle = shuffle + @staticmethod def _boundaries_from_pandas(categorical: pd.Categorical | pd.Series) -> list[slice]: """Compute category boundaries from a pandas Categorical or Series. @@ -257,11 +267,6 @@ def n_categories(self) -> int: """The number of categories in this sampler.""" return len(self._category_samplers) - @property - def category_sizes(self) -> list[int]: - """The size (number of observations) for each category.""" - return [s._mask.stop - s._mask.start for s in self._category_samplers] - def validate(self, n_obs: int) -> None: """Validate the sampler configuration against the loader's n_obs. 
@@ -285,92 +290,60 @@ def validate(self, n_obs: int) -> None: "CategoricalSampler does not support multiple workers. Use num_workers=0 in your DataLoader." ) - def _sample(self, n_obs: int) -> Iterator[LoadRequest]: - """Sample load requests, ensuring each batch is from a single category. - - The sampling strategy: - 1. Collect all load requests from each category sampler into chunk groups - 2. Flatten into individual batches: (chunk_group_id, batch_index_within_group) - 3. Shuffle the batch order across all categories - 4. Group n_categories batches together per load request and - combine chunks from the selected batches into a single load request - """ - batch_size = self._category_samplers[0]._batch_size + @staticmethod + def _iter_batches( + sampler: ChunkSampler, n_obs: int, batches_per_load: int + ) -> Iterator[tuple[list[slice], np.ndarray]]: + """Yield (chunks, split) for each batch from a sampler. - # Collect all chunk groups: list of (chunks, n_batches) - # chunk_group_id is the index into this list - all_chunk_groups: list[tuple[list[slice], int]] = [] + Each yielded batch contains only the chunks needed for that specific batch, + with split indices adjusted to index into those chunks. - for sampler in self._category_samplers: - for load_request in sampler._sample(n_obs): - chunks = list(load_request["chunks"]) - # Count only non-empty splits (drop_last may produce empty final split) - n_batches = sum(1 for s in load_request["splits"] if len(s) > 0) - if n_batches > 0: - all_chunk_groups.append((chunks, n_batches)) - - if not all_chunk_groups: - return - - # Flatten into individual batches: (chunk_group_id, batch_index_within_group) - all_batches: list[tuple[int, int]] = [] + Parameters + ---------- + sampler + The ChunkSampler to iterate over. + n_obs + Total number of observations. 
- for group_id, (_, n_batches) in enumerate(all_chunk_groups): - for batch_idx in range(n_batches): - all_batches.append((group_id, batch_idx)) + Yields + ------ + tuple[list[slice], np.ndarray] + (chunks, split) where chunks are the minimal chunks needed and + split contains indices into the concatenated chunks. + """ + for load_request in sampler._sample(n_obs): + chunks = load_request["chunks"] + yield from batched(chunks, batches_per_load) - if not all_batches: - return + def _sample(self, n_obs: int) -> Iterator[LoadRequest]: + """Sample load requests, ensuring each batch is from a single category. - # Shuffle the batch order - batch_order = np.arange(len(all_batches)) - self._rng.shuffle(batch_order) - - # Group batches that share the same chunk_group_id together - # Yield one load request per unique set of chunk groups - batches_per_load = len(self._category_samplers) - - for i in range(0, len(batch_order), batches_per_load): - selected_batch_indices = batch_order[i : i + batches_per_load] - - # Collect unique chunk groups needed for this load request - # Map: chunk_group_id -> (chunks, list of batch indices within that group) - groups_in_load: dict[int, list[int]] = {} - for batch_idx in selected_batch_indices: - group_id, batch_num = all_batches[batch_idx] - if group_id not in groups_in_load: - groups_in_load[group_id] = [] - groups_in_load[group_id].append(batch_num) - - # Build combined load request - combined_chunks: list[slice] = [] - combined_splits: list[np.ndarray] = [] - - # Track offset for each chunk group in the combined data - group_offsets: dict[int, int] = {} - current_offset = 0 - - # First pass: add chunks and compute offsets - for group_id in groups_in_load: - chunks, _ = all_chunk_groups[group_id] - group_offsets[group_id] = current_offset - combined_chunks.extend(chunks) - current_offset += sum(c.stop - c.start for c in chunks) - - # Second pass: create splits - for batch_idx in selected_batch_indices: - group_id, batch_num = 
all_batches[batch_idx] - offset = group_offsets[group_id] - - # Create split indices - start_idx = batch_num * batch_size - end_idx = start_idx + batch_size - split_indices = np.arange(start_idx, end_idx) + offset - if self._category_samplers[0]._shuffle: - self._rng.shuffle(split_indices) - combined_splits.append(split_indices) - - yield {"chunks": combined_chunks, "splits": combined_splits} + The sampling strategy: + 1. Collect all batches from each category sampler + 2. Shuffle the batch order across all categories + 3. Group n_categories batches together per load request + """ + batches_per_load = int((self._preload_nchunks * self._chunk_size) // self._batch_size) + batch_generators = [self._iter_batches(sampler, n_obs, batches_per_load) for sampler in self._category_samplers] + # simulate the category order + category_order = np.concatenate([np.full(n, i) for i, n in enumerate(self._n_batches_per_category)]) + if self._shuffle: + self._rng.shuffle(category_order) + + # Pre-allocate batch indices array for in-place shuffling + batch_indices = np.arange(self._batch_size) + + for cat_idxs in batched(category_order, batches_per_load): + chunks = [chunk for cat_idx in cat_idxs for chunk in next(batch_generators[cat_idx])] + # Create splits: one per batch, with offset based on batch position + splits = [] + for batch_num in range(len(cat_idxs)): + if self._shuffle: + self._rng.shuffle(batch_indices) + offset = batch_num * self._batch_size + splits.append(batch_indices.copy() + offset) + yield {"chunks": chunks, "splits": splits} class StratifiedCategoricalSampler(CategoricalSampler): From f39957a4ecbbccbeddde963cd68fd7385f9a694a Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Thu, 5 Feb 2026 19:01:00 +0100 Subject: [PATCH 13/16] core algo ready --- src/annbatch/samplers/_categorical_sampler.py | 18 ++-- tests/test_categorical_sampler.py | 97 ++++++++----------- 2 files changed, 54 insertions(+), 61 deletions(-) diff --git 
a/src/annbatch/samplers/_categorical_sampler.py b/src/annbatch/samplers/_categorical_sampler.py index 6fb68f3e..30b6d8eb 100644 --- a/src/annbatch/samplers/_categorical_sampler.py +++ b/src/annbatch/samplers/_categorical_sampler.py @@ -92,6 +92,10 @@ def __init__( rng: np.random.Generator | None = None, ): check_lt_1([len(category_boundaries)], ["Number of categories"]) + if batch_size < chunk_size: + raise ValueError( + f"batch_size ({batch_size}) cannot be less than chunk_size ({chunk_size}) because each batch must be from one category." + ) for i, boundary in enumerate(category_boundaries): if not isinstance(boundary, slice): @@ -292,7 +296,7 @@ def validate(self, n_obs: int) -> None: @staticmethod def _iter_batches( - sampler: ChunkSampler, n_obs: int, batches_per_load: int + sampler: ChunkSampler, n_obs: int, chunks_per_batch: int ) -> Iterator[tuple[list[slice], np.ndarray]]: """Yield (chunks, split) for each batch from a sampler. @@ -314,7 +318,7 @@ def _iter_batches( """ for load_request in sampler._sample(n_obs): chunks = load_request["chunks"] - yield from batched(chunks, batches_per_load) + yield from batched(chunks, chunks_per_batch) def _sample(self, n_obs: int) -> Iterator[LoadRequest]: """Sample load requests, ensuring each batch is from a single category. @@ -325,14 +329,15 @@ def _sample(self, n_obs: int) -> Iterator[LoadRequest]: 3. 
Group n_categories batches together per load request """ batches_per_load = int((self._preload_nchunks * self._chunk_size) // self._batch_size) - batch_generators = [self._iter_batches(sampler, n_obs, batches_per_load) for sampler in self._category_samplers] + chunks_per_batch = int(self._batch_size / self._chunk_size) + batch_generators = [self._iter_batches(sampler, n_obs, chunks_per_batch) for sampler in self._category_samplers] # simulate the category order category_order = np.concatenate([np.full(n, i) for i, n in enumerate(self._n_batches_per_category)]) if self._shuffle: self._rng.shuffle(category_order) # Pre-allocate batch indices array for in-place shuffling - batch_indices = np.arange(self._batch_size) + batch_indices = [np.arange(self._batch_size) + i * self._batch_size for i in range(batches_per_load)] for cat_idxs in batched(category_order, batches_per_load): chunks = [chunk for cat_idx in cat_idxs for chunk in next(batch_generators[cat_idx])] @@ -340,9 +345,8 @@ def _sample(self, n_obs: int) -> Iterator[LoadRequest]: splits = [] for batch_num in range(len(cat_idxs)): if self._shuffle: - self._rng.shuffle(batch_indices) - offset = batch_num * self._batch_size - splits.append(batch_indices.copy() + offset) + self._rng.shuffle(batch_indices[batch_num]) + splits.append(batch_indices[batch_num]) yield {"chunks": chunks, "splits": splits} diff --git a/tests/test_categorical_sampler.py b/tests/test_categorical_sampler.py index 0e62fae2..c456f09a 100644 --- a/tests/test_categorical_sampler.py +++ b/tests/test_categorical_sampler.py @@ -45,13 +45,12 @@ def test_basic_construction(): boundaries = [slice(0, 100), slice(100, 200), slice(200, 300)] sampler = CategoricalSampler( category_boundaries=boundaries, - batch_size=10, + batch_size=20, chunk_size=20, preload_nchunks=2, ) - assert sampler.batch_size == 10 + assert sampler.batch_size == 20 assert sampler.n_categories == 3 - assert sampler.category_sizes == [100, 100, 100] assert sampler.shuffle is False @@ 
-62,11 +61,10 @@ def test_from_pandas_categorical(): sampler = CategoricalSampler.from_pandas( categories, batch_size=10, - chunk_size=20, + chunk_size=10, preload_nchunks=2, ) assert sampler.n_categories == 3 - assert sampler.category_sizes == [50, 30, 20] def test_from_pandas_series(): @@ -75,11 +73,10 @@ def test_from_pandas_series(): sampler = CategoricalSampler.from_pandas( series, batch_size=10, - chunk_size=20, + chunk_size=10, preload_nchunks=2, ) assert sampler.n_categories == 2 - assert sampler.category_sizes == [40, 60] def test_from_pandas_unsorted_raises(): @@ -89,7 +86,7 @@ def test_from_pandas_unsorted_raises(): CategoricalSampler.from_pandas( categories, batch_size=2, - chunk_size=4, + chunk_size=2, preload_nchunks=1, ) @@ -101,7 +98,7 @@ def test_from_pandas_non_categorical_raises(): CategoricalSampler.from_pandas( series, batch_size=2, - chunk_size=4, + chunk_size=2, preload_nchunks=1, ) @@ -113,7 +110,7 @@ def test_from_pandas_empty_raises(): CategoricalSampler.from_pandas( categories, batch_size=2, - chunk_size=4, + chunk_size=2, preload_nchunks=1, ) @@ -140,7 +137,7 @@ def test_invalid_boundary_raises(boundaries, error_match): with pytest.raises((ValueError, TypeError), match=error_match): CategoricalSampler( category_boundaries=boundaries, - batch_size=5, + batch_size=10, chunk_size=10, preload_nchunks=1, ) @@ -151,7 +148,7 @@ def test_empty_boundaries_raises(): with pytest.raises(ValueError): CategoricalSampler( category_boundaries=[], - batch_size=5, + batch_size=10, chunk_size=10, preload_nchunks=1, ) @@ -165,11 +162,11 @@ def test_empty_boundaries_raises(): @pytest.mark.parametrize( "category_sizes,chunk_size,preload_nchunks,batch_size", [ - pytest.param([100, 100, 100], 20, 2, 10, id="equal_categories"), - pytest.param([50, 150, 100], 25, 2, 10, id="unequal_categories"), - pytest.param([30, 30, 30], 10, 3, 5, id="small_categories"), - pytest.param([200], 50, 2, 25, id="single_category"), - pytest.param([10, 20, 30, 40], 10, 1, 5, 
id="many_categories"), + pytest.param([100, 100, 100], 10, 2, 10, id="equal_categories"), + pytest.param([50, 150, 100], 10, 2, 10, id="unequal_categories"), + pytest.param([30, 30, 30], 5, 3, 5, id="small_categories"), + pytest.param([200], 25, 2, 25, id="single_category"), + pytest.param([10, 20, 30, 40], 5, 1, 5, id="many_categories"), ], ) def test_coverage_all_indices(category_sizes, chunk_size, preload_nchunks, batch_size): @@ -224,7 +221,7 @@ def test_each_split_from_single_category(boundaries, shuffle): sampler = CategoricalSampler( category_boundaries=boundaries, batch_size=10, - chunk_size=20, + chunk_size=10, preload_nchunks=2, shuffle=shuffle, rng=np.random.default_rng(42), @@ -272,7 +269,7 @@ def test_shuffle_changes_order(): sampler_no_shuffle = CategoricalSampler( category_boundaries=boundaries, batch_size=10, - chunk_size=20, + chunk_size=10, preload_nchunks=2, shuffle=False, ) @@ -280,7 +277,7 @@ def test_shuffle_changes_order(): sampler_shuffle = CategoricalSampler( category_boundaries=boundaries, batch_size=10, - chunk_size=20, + chunk_size=10, preload_nchunks=2, shuffle=True, rng=np.random.default_rng(42), @@ -307,7 +304,7 @@ def test_validate_boundary_exceeds_n_obs(): sampler = CategoricalSampler( category_boundaries=boundaries, batch_size=10, - chunk_size=20, + chunk_size=10, preload_nchunks=2, ) @@ -321,7 +318,7 @@ def test_validate_passes_for_valid_config(): sampler = CategoricalSampler( category_boundaries=boundaries, batch_size=10, - chunk_size=20, + chunk_size=10, preload_nchunks=2, ) # Should not raise @@ -333,21 +330,14 @@ def test_validate_passes_for_valid_config(): # ============================================================================= -@pytest.mark.parametrize( - "batch_size,chunk_size,preload_nchunks,error_match", - [ - pytest.param(100, 10, 2, "batch_size cannot exceed", id="batch_exceeds_preload"), - pytest.param(7, 10, 2, "must be divisible by batch_size", id="not_divisible"), - ], -) -def 
test_invalid_batch_size_raises(batch_size, chunk_size, preload_nchunks, error_match): - """Test that invalid batch_size configurations raise ValueError.""" - with pytest.raises(ValueError, match=error_match): +def test_invalid_batch_size_raises(): + """Test that batch_size < chunk_size raises ValueError.""" + with pytest.raises(ValueError, match="cannot be less than chunk_size"): CategoricalSampler( category_boundaries=[slice(0, 100)], - batch_size=batch_size, - chunk_size=chunk_size, - preload_nchunks=preload_nchunks, + batch_size=5, + chunk_size=10, + preload_nchunks=2, ) @@ -363,7 +353,7 @@ def test_drop_last_enforced(): sampler = CategoricalSampler( category_boundaries=boundaries, batch_size=10, - chunk_size=20, + chunk_size=10, preload_nchunks=2, ) @@ -386,7 +376,7 @@ def test_splits_have_correct_batch_size(): sampler = CategoricalSampler( category_boundaries=boundaries, batch_size=10, - chunk_size=20, + chunk_size=10, preload_nchunks=2, ) @@ -411,7 +401,7 @@ def test_from_pandas_integration(): sampler = CategoricalSampler.from_pandas( categories, batch_size=10, - chunk_size=25, + chunk_size=10, preload_nchunks=2, shuffle=True, rng=np.random.default_rng(123), @@ -435,7 +425,7 @@ def get_indices(seed): sampler = CategoricalSampler( category_boundaries=boundaries, batch_size=10, - chunk_size=20, + chunk_size=10, preload_nchunks=2, shuffle=True, rng=np.random.default_rng(seed), @@ -461,7 +451,7 @@ def test_stratified_basic_construction(): sampler = StratifiedCategoricalSampler( category_boundaries=boundaries, batch_size=10, - chunk_size=20, + chunk_size=10, preload_nchunks=2, n_yields=50, ) @@ -479,7 +469,7 @@ def test_stratified_custom_weights(): sampler = StratifiedCategoricalSampler( category_boundaries=boundaries, batch_size=10, - chunk_size=20, + chunk_size=10, preload_nchunks=2, n_yields=50, weights=[1.0, 2.0, 3.0], @@ -496,7 +486,7 @@ def test_stratified_n_yields_count(): sampler = StratifiedCategoricalSampler( category_boundaries=boundaries, 
batch_size=10, - chunk_size=20, + chunk_size=10, preload_nchunks=2, n_yields=n_yields, rng=np.random.default_rng(42), @@ -516,7 +506,7 @@ def test_stratified_n_yields_invalid(): StratifiedCategoricalSampler( category_boundaries=boundaries, batch_size=10, - chunk_size=20, + chunk_size=10, preload_nchunks=2, n_yields=0, ) @@ -531,7 +521,7 @@ def test_stratified_weights_validation(): StratifiedCategoricalSampler( category_boundaries=boundaries, batch_size=10, - chunk_size=20, + chunk_size=10, preload_nchunks=2, n_yields=10, weights=[1.0], # Only 1 weight for 2 categories @@ -542,7 +532,7 @@ def test_stratified_weights_validation(): StratifiedCategoricalSampler( category_boundaries=boundaries, batch_size=10, - chunk_size=20, + chunk_size=10, preload_nchunks=2, n_yields=10, weights=[1.0, -1.0], @@ -553,7 +543,7 @@ def test_stratified_weights_validation(): StratifiedCategoricalSampler( category_boundaries=boundaries, batch_size=10, - chunk_size=20, + chunk_size=10, preload_nchunks=2, n_yields=10, weights=[0.0, 0.0], @@ -569,8 +559,8 @@ def test_stratified_replacement(): sampler = StratifiedCategoricalSampler( category_boundaries=boundaries, batch_size=10, - chunk_size=20, - preload_nchunks=1, + chunk_size=10, + preload_nchunks=2, n_yields=n_yields, rng=np.random.default_rng(42), ) @@ -591,7 +581,7 @@ def test_stratified_each_batch_single_category(): sampler = StratifiedCategoricalSampler( category_boundaries=boundaries, batch_size=10, - chunk_size=20, + chunk_size=10, preload_nchunks=2, n_yields=50, shuffle=True, @@ -630,7 +620,7 @@ def get_batches(seed): sampler = StratifiedCategoricalSampler( category_boundaries=boundaries, batch_size=10, - chunk_size=20, + chunk_size=10, preload_nchunks=2, n_yields=20, shuffle=True, @@ -653,13 +643,12 @@ def test_stratified_from_pandas(): sampler = StratifiedCategoricalSampler.from_pandas( categories, batch_size=10, - chunk_size=20, + chunk_size=10, preload_nchunks=2, n_yields=30, ) assert sampler.n_categories == 3 - assert 
sampler.category_sizes == [50, 30, 20] assert sampler.n_yields == 30 @@ -670,7 +659,7 @@ def test_stratified_from_pandas_with_weights(): sampler = StratifiedCategoricalSampler.from_pandas( categories, batch_size=10, - chunk_size=20, + chunk_size=10, preload_nchunks=2, n_yields=30, weights=[3.0, 2.0, 1.0], @@ -688,7 +677,7 @@ def test_stratified_uniform_weights_distribution(): sampler = StratifiedCategoricalSampler( category_boundaries=boundaries, batch_size=10, - chunk_size=20, + chunk_size=10, preload_nchunks=2, n_yields=n_yields, rng=np.random.default_rng(42), From 450e3d0e30fcbfcb37afc7006c78118714c76cbc Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Fri, 6 Feb 2026 10:54:32 +0100 Subject: [PATCH 14/16] Squashed commit of the following: MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit commit 845053eac573526c66e7e85c7ffdeeb08901269a Author: selmanozleyen Date: Fri Feb 6 10:32:33 2026 +0100 blunder fix forgot to add worker_id property commit 1e71f0d4d439d4a2666d0ffa22126b5658279e56 Author: selmanozleyen Date: Fri Feb 6 10:27:40 2026 +0100 fix commit 585898d86df0ea946850a2ef9b262ef1a6ab1b96 Merge: 4df0a95 c9a5971 Author: Selman Özleyen <32667648+selmanozleyen@users.noreply.github.com> Date: Fri Feb 6 10:10:45 2026 +0100 Merge branch 'main' into fix/rng-handling commit 4df0a9563d1751fc138ecf3a1e181ae57997df91 Author: selmanozleyen Date: Fri Feb 6 10:04:07 2026 +0100 accurate docstrings commit e3e57395e61d5e3c97d36b7bd9cbaa9d2e72e597 Author: selmanozleyen Date: Fri Feb 6 10:01:24 2026 +0100 remove redundant comments commit ffb8041fa1baa785c8bb9cae8acdd02bd5e364ac Author: selmanozleyen Date: Fri Feb 6 09:59:57 2026 +0100 remove comment commit 2f24aee927bca497e8bbbb5affe599c07f1b0252 Merge: a23f01e 8cc3df0 Author: Selman Özleyen <32667648+selmanozleyen@users.noreply.github.com> Date: Fri Feb 6 09:58:33 2026 +0100 Merge branch 'main' into fix/rng-handling commit a23f01eb4d3f3d9886f9a9631aae06b78d841fe3 Author: selmanozleyen 
Date: Fri Feb 6 09:57:29 2026 +0100 update _spawn_worker_rng commit dafba5f75376a29c021f8a7fc7a7df0979665b77 Author: selmanozleyen Date: Fri Feb 6 09:34:56 2026 +0100 dedup and update commit 25d59082347e5de8e65d8ca17c92be0fb65a2602 Author: selmanozleyen Date: Fri Feb 6 09:31:26 2026 +0100 remove get_part_for_worker commit 628362502151c845165b991e33b1333c5368909f Author: Selman Özleyen <32667648+selmanozleyen@users.noreply.github.com> Date: Fri Feb 6 09:23:14 2026 +0100 Update tests/test_sampler.py Co-authored-by: Philipp A. commit bbc5951305b210abfef607c5c50e2e469028b2e2 Author: selmanozleyen Date: Thu Feb 5 14:48:45 2026 +0100 workerhandle takes rng commit 31f46719e7dd997e6924bf6c208b0f209d5f5cde Author: selmanozleyen Date: Thu Feb 5 13:55:29 2026 +0100 update mockworker so it passes all the tests expcept repr commit 1bcbb8887e5042c52e0c21dc3ad327c1a1a15bc1 Author: selmanozleyen Date: Thu Feb 5 13:54:54 2026 +0100 add unit test which fails in main but passes here commit f893c2b3bd902483e75dc2f29b64f61fda864233 Author: selmanozleyen Date: Thu Feb 5 13:44:24 2026 +0100 fix for rng handling --- src/annbatch/samplers/_chunk_sampler.py | 31 +++++++++------ src/annbatch/utils.py | 53 ++++++++++--------------- tests/test_sampler.py | 17 ++------ 3 files changed, 41 insertions(+), 60 deletions(-) diff --git a/src/annbatch/samplers/_chunk_sampler.py b/src/annbatch/samplers/_chunk_sampler.py index 1795c3db..11dfc7bc 100644 --- a/src/annbatch/samplers/_chunk_sampler.py +++ b/src/annbatch/samplers/_chunk_sampler.py @@ -8,7 +8,7 @@ import numpy as np from annbatch.abc import Sampler -from annbatch.samplers._utils import get_worker_handle, validate_batch_size +from annbatch.samplers._utils import validate_batch_size from annbatch.utils import check_lt_1, split_given_size if TYPE_CHECKING: @@ -110,7 +110,14 @@ def validate(self, n_obs: int) -> None: raise ValueError(f"Sampler mask.start ({start}) must be < mask.stop ({stop}).") def _get_worker_handle(self) -> WorkerHandle | 
None: - worker_handle = get_worker_handle(self._rng) + worker_handle = None + if find_spec("torch"): + from torch.utils.data import get_worker_info + + from annbatch.utils import WorkerHandle + + if get_worker_info() is not None: + worker_handle = WorkerHandle(self._rng) # Worker mode validation - only check when there are multiple workers # With batch_size=1, every batch is exactly 1 item, so no partial batches exist if ( @@ -129,22 +136,23 @@ def _sample(self, n_obs: int) -> Iterator[LoadRequest]: # Create chunk indices for possible shuffling and worker sharding chunk_indices = np.arange(math.ceil((stop - start) / self._chunk_size)) if self._shuffle: + # Use sampler's RNG for chunk ordering - same across all workers self._rng.shuffle(chunk_indices) chunks = self._compute_chunks(chunk_indices, start, stop) # Worker sharding: each worker gets a disjoint subset of chunks + if self._shuffle: + self._rng.shuffle(chunks) if worker_handle is not None: - chunks = worker_handle.get_part_for_worker(chunks) - - # Batch shuffling: use worker-specific RNG (different per worker) - batch_rng = worker_handle.rng if worker_handle is not None else self._rng - + chunks = np.array_split(chunks, worker_handle.num_workers)[worker_handle.worker_id] # Set up the iterator for chunks and the batch indices for splits in_memory_size = self._chunk_size * self._preload_nchunks chunks_per_request = split_given_size(chunks, self._preload_nchunks) batch_indices = np.arange(in_memory_size) split_batch_indices = split_given_size(batch_indices, self._batch_size) + batch_rng = worker_handle.rng if worker_handle is not None else self._rng for request_chunks in chunks_per_request[:-1]: if self._shuffle: + # Avoid copies using in-place shuffling since `self._shuffle` should not change mid-training batch_rng.shuffle(batch_indices) split_batch_indices = split_given_size(batch_indices, self._batch_size) yield {"chunks": request_chunks, "splits": split_batch_indices} @@ -157,13 +165,10 @@ def _sample(self, 
n_obs: int) -> Iterator[LoadRequest]: if total_obs_in_last_batch < self._batch_size: return total_obs_in_last_batch -= total_obs_in_last_batch % self._batch_size - # Skip yielding if there are no observations (can happen with drop_last=True and last request is empty) - if total_obs_in_last_batch == 0: - return - batch_indices = split_given_size( - (batch_rng.permutation if self._shuffle else np.arange)(total_obs_in_last_batch), - self._batch_size, + indices = ( + batch_rng.permutation(total_obs_in_last_batch) if self._shuffle else np.arange(total_obs_in_last_batch) ) + batch_indices = split_given_size(indices, self._batch_size) yield {"chunks": final_chunks, "splits": batch_indices} def _compute_chunks(self, chunk_indices: np.ndarray, start: int, stop: int) -> list[slice]: diff --git a/src/annbatch/utils.py b/src/annbatch/utils.py index fb1076e6..eaadb844 100644 --- a/src/annbatch/utils.py +++ b/src/annbatch/utils.py @@ -5,7 +5,6 @@ import warnings from dataclasses import dataclass from functools import wraps -from importlib.util import find_spec from typing import TYPE_CHECKING, Concatenate, Protocol import anndata as ad @@ -102,37 +101,41 @@ def __iter__(self): total += gap +def _spawn_worker_rng( + rng: np.random.Generator | None, + worker_id: int, +) -> np.random.Generator: + """Create a worker-specific RNG using the sequence-of-integers seeding pattern. + + Uses NumPy's recommended approach for multi-process RNG. See: + https://numpy.org/doc/stable/reference/random/parallel.html#sequence-of-integer-seeds + """ + if rng is not None: + root_seed = rng.integers(np.iinfo(np.int64).max) + return np.random.default_rng([worker_id, root_seed]) + else: + return np.random.default_rng() + + class WorkerHandle: """Handle for torch DataLoader worker context. This class should only be instantiated inside a torch DataLoader worker process (i.e., when `torch.utils.data.get_worker_info()` returns a non-None value). - It provides worker-specific RNG and partitioning utilities. 
+ It provides worker-specific RNGs. Parameters ---------- rng - The RNG to spawn worker-specific RNGs from. If provided, uses its bit_generator's - seed sequence to spawn independent streams for each worker. If None, falls back - to using torch's worker seed. + The RNG to derive worker-specific RNGs from. If None, uses a fresh unseeded RNG. - The RNG is created using `SeedSequence.spawn()` to ensure each worker has an - independent but reproducible random stream, following numpy's recommended - pattern for parallel random number generation. """ def __init__(self, rng: np.random.Generator | None = None): - """Initialize WorkerHandle. Must be called from within a torch DataLoader worker.""" from torch.utils.data import get_worker_info self._worker_info = get_worker_info() - # Each worker gets its own RNG spawned from the sampler's RNG for reproducible batch shuffling - if rng is not None: - bit_generators = rng.bit_generator.spawn(self._worker_info.num_workers) - else: - seq = np.random.SeedSequence(self._worker_info.seed).spawn(self._worker_info.num_workers) - bit_generators = seq - self._rng = np.random.default_rng(bit_generators[self._worker_info.id]) + self._rng = _spawn_worker_rng(rng, self._worker_info.id) @property def rng(self) -> np.random.Generator: @@ -146,25 +149,9 @@ def num_workers(self) -> int: @property def worker_id(self) -> int: - """Return the current worker ID.""" + """Worker ID.""" return self._worker_info.id - def get_part_for_worker(self, obj: np.ndarray) -> np.ndarray: - """Get a chunk of an incoming array according to the current worker id. - - Parameters - ---------- - obj - Incoming array - - Returns - ------- - An evenly split part of the array corresponding to this worker. 
- """ - num_workers, worker_id = self._worker_info.num_workers, self._worker_info.id - chunks_split = np.array_split(obj, num_workers) - return chunks_split[worker_id] - def check_lt_1(vals: list[int], obs: list[str]) -> None: """Raise a ValueError if any of the values are less than one. diff --git a/tests/test_sampler.py b/tests/test_sampler.py index f6d7d7c5..e60193f3 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -7,6 +7,7 @@ from annbatch import ChunkSampler from annbatch.abc import Sampler +from annbatch.utils import _spawn_worker_rng # TODO(selmanozleyen): Check for the validation within the _get_worker_handle method. Mock worker handle wouldn't make sense # but overall one must also think about how validation can't be independent of the worker handle. @@ -29,12 +30,7 @@ class MockWorkerHandle: def __init__(self, worker_id: int, num_workers: int, rng: np.random.Generator | None = None): self.worker_id = worker_id self._num_workers = num_workers - # Each worker gets its own RNG spawned from the sampler's RNG (mirrors real WorkerHandle) - if rng is not None: - bit_generators = rng.bit_generator.spawn(num_workers) - else: - bit_generators = np.random.SeedSequence(42).spawn(num_workers) - self._rng = np.random.default_rng(bit_generators[worker_id]) + self._rng = _spawn_worker_rng(rng, worker_id) @property def rng(self) -> np.random.Generator: @@ -44,9 +40,6 @@ def rng(self) -> np.random.Generator: def num_workers(self) -> int: return self._num_workers - def get_part_for_worker(self, obj: np.ndarray) -> np.ndarray: - return np.array_split(obj, self._num_workers)[self.worker_id] - class ChunkSamplerWithMockWorkerHandle(ChunkSampler): def set_mock_worker_info(self, worker_id: int, num_workers: int): @@ -205,11 +198,7 @@ def test_workers_cover_full_dataset_without_overlap( def test_batch_shuffle_is_reproducible(): - """Test that batch shuffling is reproducible when using ChunkSampler directly. 
- - This test verifies batch shuffling uses the sampler's seeded `_rng` instead of - an unseeded `np.random.default_rng()`, making it reproducible. - """ + """Test that batch shuffling is reproducible when using ChunkSampler directly.""" n_obs, chunk_size, preload_nchunks, batch_size = 100, 10, 2, 5 seed = 42 From 667c737a21cb5477ed74109064193de3111acd7f Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Fri, 6 Feb 2026 10:58:23 +0100 Subject: [PATCH 15/16] add find spec --- src/annbatch/samplers/_chunk_sampler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/annbatch/samplers/_chunk_sampler.py b/src/annbatch/samplers/_chunk_sampler.py index 11dfc7bc..86a68f20 100644 --- a/src/annbatch/samplers/_chunk_sampler.py +++ b/src/annbatch/samplers/_chunk_sampler.py @@ -3,6 +3,7 @@ from __future__ import annotations import math +from importlib.util import find_spec from typing import TYPE_CHECKING import numpy as np From 6e20323ed0ebfc6a4d2aaf37499c52a16af90a44 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Fri, 6 Feb 2026 11:46:32 +0100 Subject: [PATCH 16/16] remove stratified sampler to start --- src/annbatch/__init__.py | 3 +- src/annbatch/samplers/__init__.py | 3 +- src/annbatch/samplers/_categorical_sampler.py | 341 ++---------------- tests/test_categorical_sampler.py | 266 +------------- 4 files changed, 34 insertions(+), 579 deletions(-) diff --git a/src/annbatch/__init__.py b/src/annbatch/__init__.py index 26075709..7fa9ed2e 100644 --- a/src/annbatch/__init__.py +++ b/src/annbatch/__init__.py @@ -5,7 +5,7 @@ from . 
import abc, types from .io import DatasetCollection, write_sharded from .loader import Loader -from .samplers._categorical_sampler import CategoricalSampler, StratifiedCategoricalSampler +from .samplers._categorical_sampler import CategoricalSampler from .samplers._chunk_sampler import ChunkSampler __version__ = version("annbatch") @@ -17,6 +17,5 @@ "write_sharded", "CategoricalSampler", "ChunkSampler", - "StratifiedCategoricalSampler", "abc", ] diff --git a/src/annbatch/samplers/__init__.py b/src/annbatch/samplers/__init__.py index d73829bc..e963e650 100644 --- a/src/annbatch/samplers/__init__.py +++ b/src/annbatch/samplers/__init__.py @@ -1,8 +1,7 @@ -from ._categorical_sampler import CategoricalSampler, StratifiedCategoricalSampler +from ._categorical_sampler import CategoricalSampler from ._chunk_sampler import ChunkSampler __all__ = [ "CategoricalSampler", "ChunkSampler", - "StratifiedCategoricalSampler", ] diff --git a/src/annbatch/samplers/_categorical_sampler.py b/src/annbatch/samplers/_categorical_sampler.py index 30b6d8eb..f7b224c2 100644 --- a/src/annbatch/samplers/_categorical_sampler.py +++ b/src/annbatch/samplers/_categorical_sampler.py @@ -10,14 +10,13 @@ from annbatch.abc import Sampler from annbatch.samplers._chunk_sampler import ChunkSampler -from annbatch.samplers._utils import get_worker_handle, is_in_worker +from annbatch.samplers._utils import is_in_worker from annbatch.utils import check_lt_1 if TYPE_CHECKING: from collections.abc import Iterator, Sequence from annbatch.types import LoadRequest - from annbatch.utils import WorkerHandle class CategoricalSampler(Sampler): @@ -96,24 +95,7 @@ def __init__( raise ValueError( f"batch_size ({batch_size}) cannot be less than chunk_size ({chunk_size}) because each batch must be from one category." 
) - - for i, boundary in enumerate(category_boundaries): - if not isinstance(boundary, slice): - raise TypeError(f"Expected slice for boundary {i}, got {type(boundary)}") - if boundary.step is not None and boundary.step != 1: - raise ValueError(f"Boundary {i} must have step=1 or None, got {boundary.step}") - if boundary.start is None or boundary.stop is None: - raise ValueError(f"Boundary {i} must have explicit start and stop") - if boundary.start >= boundary.stop: - raise ValueError(f"Boundary {i} must have start < stop, got {boundary}") - if i == 0 and boundary.start != 0: - raise ValueError(f"First boundary must start at 0, got {boundary.start}") - if i > 0 and boundary.start != category_boundaries[i - 1].stop: - raise ValueError( - f"Boundaries must be contiguous: boundary {i} starts at {boundary.start} " - f"but boundary {i - 1} ends at {category_boundaries[i - 1].stop}" - ) - + self._validate_boundaries(category_boundaries) self._rng = rng or np.random.default_rng() child_rngs = self._rng.spawn(len(category_boundaries)) @@ -142,6 +124,24 @@ def __init__( self._preload_nchunks = preload_nchunks self._shuffle = shuffle + def _validate_boundaries(self, category_boundaries: Sequence[slice]) -> None: + for i, boundary in enumerate(category_boundaries): + if not isinstance(boundary, slice): + raise TypeError(f"Expected slice for boundary {i}, got {type(boundary)}") + if boundary.step is not None and boundary.step != 1: + raise ValueError(f"Boundary {i} must have step=1 or None, got {boundary.step}") + if boundary.start is None or boundary.stop is None: + raise ValueError(f"Boundary {i} must have explicit start and stop") + if boundary.start >= boundary.stop: + raise ValueError(f"Boundary {i} must have start < stop, got {boundary}") + if i == 0 and boundary.start != 0: + raise ValueError(f"First boundary must start at 0, got {boundary.start}") + if i > 0 and boundary.start != category_boundaries[i - 1].stop: + raise ValueError( + f"Boundaries must be contiguous: 
boundary {i} starts at {boundary.start} " + f"but boundary {i - 1} ends at {category_boundaries[i - 1].stop}" + ) + @staticmethod def _boundaries_from_pandas(categorical: pd.Categorical | pd.Series) -> list[slice]: """Compute category boundaries from a pandas Categorical or Series. @@ -290,44 +290,18 @@ def validate(self, n_obs: int) -> None: # Check for worker usage - CategoricalSampler doesn't support workers if is_in_worker(): - raise ValueError( - "CategoricalSampler does not support multiple workers. Use num_workers=0 in your DataLoader." - ) + raise ValueError("CategoricalSampler does not support multiple workers.") @staticmethod def _iter_batches( sampler: ChunkSampler, n_obs: int, chunks_per_batch: int ) -> Iterator[tuple[list[slice], np.ndarray]]: - """Yield (chunks, split) for each batch from a sampler. - - Each yielded batch contains only the chunks needed for that specific batch, - with split indices adjusted to index into those chunks. - - Parameters - ---------- - sampler - The ChunkSampler to iterate over. - n_obs - Total number of observations. - - Yields - ------ - tuple[list[slice], np.ndarray] - (chunks, split) where chunks are the minimal chunks needed and - split contains indices into the concatenated chunks. - """ + """Yield per batch given a sampler.""" for load_request in sampler._sample(n_obs): chunks = load_request["chunks"] yield from batched(chunks, chunks_per_batch) def _sample(self, n_obs: int) -> Iterator[LoadRequest]: - """Sample load requests, ensuring each batch is from a single category. - - The sampling strategy: - 1. Collect all batches from each category sampler - 2. Shuffle the batch order across all categories - 3. 
Group n_categories batches together per load request - """ batches_per_load = int((self._preload_nchunks * self._chunk_size) // self._batch_size) chunks_per_batch = int(self._batch_size / self._chunk_size) batch_generators = [self._iter_batches(sampler, n_obs, chunks_per_batch) for sampler in self._category_samplers] @@ -336,267 +310,14 @@ def _sample(self, n_obs: int) -> Iterator[LoadRequest]: if self._shuffle: self._rng.shuffle(category_order) - # Pre-allocate batch indices array for in-place shuffling - batch_indices = [np.arange(self._batch_size) + i * self._batch_size for i in range(batches_per_load)] + # pre-allocate and reshape to batches_per_load x batch_size + # so that we can shuffle with numpy all at once + batch_indices = np.arange(batches_per_load * self._batch_size).reshape(batches_per_load, self._batch_size) for cat_idxs in batched(category_order, batches_per_load): - chunks = [chunk for cat_idx in cat_idxs for chunk in next(batch_generators[cat_idx])] - # Create splits: one per batch, with offset based on batch position - splits = [] - for batch_num in range(len(cat_idxs)): - if self._shuffle: - self._rng.shuffle(batch_indices[batch_num]) - splits.append(batch_indices[batch_num]) - yield {"chunks": chunks, "splits": splits} - - -class StratifiedCategoricalSampler(CategoricalSampler): - """Stratified categorical sampler with configurable weights and multi-worker support. - - Samples categories according to weights (uniform by default), yielding - exactly ``n_yields`` batches total. Supports multi-worker DataLoaders by splitting - ``n_yields`` across workers. 
- - Unlike :class:`CategoricalSampler`, this sampler: - - Yields a fixed number of batches (``n_yields``) rather than exhausting all data - - Samples with replacement (categories reset when exhausted) - - Supports multi-worker DataLoaders - - Allows configurable sampling weights (uniform by default) - - Parameters - ---------- - category_boundaries - A sequence of slices defining the boundaries for each category. - Each slice represents a contiguous range of observations belonging to one category. - Data must be sorted by category before using this sampler. - chunk_size - Size of each chunk i.e. the range of each chunk yielded. - preload_nchunks - Number of chunks to load per iteration. - batch_size - Number of observations per batch. - n_yields - Total number of batches to yield (split across workers if num_workers > 1). - weights - Sampling weights per category. Default is uniform (equal probability). - Use ``weights=sampler.category_sizes`` for size-proportional sampling. - shuffle - Whether to shuffle chunk and index order within each category. - rng - Random number generator for shuffling. - - Examples - -------- - >>> boundaries = [slice(0, 100), slice(100, 250), slice(250, 400)] - >>> sampler = StratifiedCategoricalSampler( - ... category_boundaries=boundaries, - ... batch_size=32, - ... chunk_size=64, - ... preload_nchunks=4, - ... n_yields=1000, - ... ) - - Using custom weights (e.g., upsample rare categories): - - >>> sampler = StratifiedCategoricalSampler( - ... category_boundaries=boundaries, - ... batch_size=32, - ... chunk_size=64, - ... preload_nchunks=4, - ... n_yields=1000, - ... weights=[1.0, 2.0, 3.0], # Category 2 sampled 3x as often as category 0 - ... 
) - """ - - _n_yields: int - _weights: np.ndarray - - def __init__( - self, - category_boundaries: Sequence[slice], - chunk_size: int, - preload_nchunks: int, - batch_size: int, - n_yields: int, - weights: Sequence[float] | None = None, - *, - shuffle: bool = False, - rng: np.random.Generator | None = None, - ): - super().__init__( - category_boundaries, - chunk_size, - preload_nchunks, - batch_size, - shuffle=shuffle, - rng=rng, - ) - - # Validate n_yields - if n_yields < 1: - raise ValueError("n_yields must be >= 1") - self._n_yields = n_yields - - # Handle weights (uniform by default) - if weights is None: - self._weights = np.ones(self.n_categories, dtype=float) - else: - if len(weights) != self.n_categories: - raise ValueError(f"weights length ({len(weights)}) must match n_categories ({self.n_categories})") - weights = np.asarray(weights, dtype=float) - if np.any(weights < 0): - raise ValueError("weights must be non-negative") - if weights.sum() == 0: - raise ValueError("weights must not sum to zero") - self._weights = weights - - @property - def n_yields(self) -> int: - """Total number of batches to yield.""" - return self._n_yields - - @property - def weights(self) -> np.ndarray: - """Sampling weights for each category (not normalized).""" - return self._weights.copy() - - @property - def probabilities(self) -> np.ndarray: - """Normalized sampling probabilities for each category.""" - return self._weights / self._weights.sum() - - @classmethod - def from_pandas( - cls, - categorical: pd.Categorical | pd.Series, - chunk_size: int, - preload_nchunks: int, - batch_size: int, - n_yields: int, - weights: Sequence[float] | None = None, - *, - shuffle: bool = False, - rng: np.random.Generator | None = None, - ) -> StratifiedCategoricalSampler: - """Create a StratifiedCategoricalSampler from a pandas Categorical or Series. - - This extends :meth:`CategoricalSampler.from_pandas` with additional - parameters for stratified sampling. 
- - Parameters - ---------- - categorical - A pandas Categorical or Series with categorical dtype. - Data must be sorted by category. - chunk_size - Size of each chunk. - preload_nchunks - Number of chunks to load per iteration. - batch_size - Number of observations per batch. - n_yields - Total number of batches to yield. - weights - Sampling weights per category. Default is uniform (equal probability). - shuffle - Whether to shuffle chunk and index order within each category. - rng - Random number generator for shuffling. - - Returns - ------- - StratifiedCategoricalSampler - A sampler configured with boundaries derived from the categorical. - - Raises - ------ - ValueError - If the data is not sorted by category. - TypeError - If the input is not a Categorical or categorical Series. - """ - boundaries = cls._boundaries_from_pandas(categorical) - return cls( - category_boundaries=boundaries, - chunk_size=chunk_size, - preload_nchunks=preload_nchunks, - batch_size=batch_size, - n_yields=n_yields, - weights=weights, - shuffle=shuffle, - rng=rng, - ) - - def validate(self, n_obs: int) -> None: - """Validate the sampler configuration against the loader's n_obs. - - Unlike CategoricalSampler, this sampler supports multi-worker DataLoaders. - - Parameters - ---------- - n_obs - The total number of observations in the loader. - - Raises - ------ - ValueError - If the sampler configuration is invalid for the given n_obs. - """ - # Validate category samplers (skip parent's worker check) - for sampler in self._category_samplers: - sampler.validate(n_obs) - # NOTE: Multi-worker IS supported for stratified (unlike parent CategoricalSampler) - - def _get_worker_handle(self) -> WorkerHandle | None: - """Get WorkerHandle for worker-specific RNGs.""" - return get_worker_handle(self._rng) - - def _sample(self, n_obs: int) -> Iterator[LoadRequest]: - """Sample load requests using stratified sampling with replacement. 
- - Categories are sampled according to weights (uniform by default). - When a category is exhausted, its iterator is reset (sampling with replacement). - """ - worker_handle = self._get_worker_handle() - - if worker_handle is not None: - worker_id = worker_handle.worker_id - num_workers = worker_handle.num_workers - # Split n_yields across workers - worker_n_yields = self._n_yields // num_workers - if worker_id < (self._n_yields % num_workers): - worker_n_yields += 1 - # Use worker-specific RNG from handle - worker_rng = worker_handle.rng - else: - worker_n_yields = self._n_yields - worker_rng = self._rng - - if worker_n_yields == 0: - return - - probs = self.probabilities - category_iters: list[Iterator[LoadRequest] | None] = [None] * self.n_categories - yields_so_far = 0 - - while yields_so_far < worker_n_yields: - # Sample category using worker RNG - cat_idx = int(worker_rng.choice(self.n_categories, p=probs)) - - # Get/reset iterator for this category - if category_iters[cat_idx] is None: - category_iters[cat_idx] = iter(self._category_samplers[cat_idx]._sample(n_obs)) - - try: - load_request = next(category_iters[cat_idx]) - except StopIteration: - # Reset iterator (sample with replacement) - category_iters[cat_idx] = iter(self._category_samplers[cat_idx]._sample(n_obs)) - load_request = next(category_iters[cat_idx]) - - # Yield individual batches from this load request - for split in load_request["splits"]: - if yields_so_far >= worker_n_yields: - return - yield {"chunks": load_request["chunks"], "splits": [split]} - yields_so_far += 1 + if self._shuffle: + batch_indices = self._rng.permuted(batch_indices, axis=1) + yield { + "chunks": [chunk for cat_idx in cat_idxs for chunk in next(batch_generators[cat_idx])], + "splits": list(batch_indices[: len(cat_idxs)]), + } diff --git a/tests/test_categorical_sampler.py b/tests/test_categorical_sampler.py index c456f09a..6746e924 100644 --- a/tests/test_categorical_sampler.py +++ b/tests/test_categorical_sampler.py 
@@ -6,7 +6,7 @@ import pandas as pd import pytest -from annbatch import CategoricalSampler, StratifiedCategoricalSampler +from annbatch import CategoricalSampler def collect_all_indices(sampler, n_obs): @@ -438,267 +438,3 @@ def get_indices(seed): assert indices1 == indices2, "Same seed should give same results" assert indices1 != indices3, "Different seeds should give different results" - - -# ============================================================================= -# StratifiedCategoricalSampler tests -# ============================================================================= - - -def test_stratified_basic_construction(): - """Test basic StratifiedCategoricalSampler construction.""" - boundaries = [slice(0, 100), slice(100, 200), slice(200, 300)] - sampler = StratifiedCategoricalSampler( - category_boundaries=boundaries, - batch_size=10, - chunk_size=10, - preload_nchunks=2, - n_yields=50, - ) - assert sampler.batch_size == 10 - assert sampler.n_categories == 3 - assert sampler.n_yields == 50 - assert sampler.shuffle is False - # Default weights are uniform - np.testing.assert_array_equal(sampler.weights, [1.0, 1.0, 1.0]) - - -def test_stratified_custom_weights(): - """Test StratifiedCategoricalSampler with custom weights.""" - boundaries = [slice(0, 100), slice(100, 200), slice(200, 300)] - sampler = StratifiedCategoricalSampler( - category_boundaries=boundaries, - batch_size=10, - chunk_size=10, - preload_nchunks=2, - n_yields=50, - weights=[1.0, 2.0, 3.0], - ) - np.testing.assert_array_equal(sampler.weights, [1.0, 2.0, 3.0]) - np.testing.assert_array_almost_equal(sampler.probabilities, [1 / 6, 2 / 6, 3 / 6]) - - -def test_stratified_n_yields_count(): - """Test that exactly n_yields batches are yielded.""" - boundaries = [slice(0, 100), slice(100, 200)] - n_yields = 25 - - sampler = StratifiedCategoricalSampler( - category_boundaries=boundaries, - batch_size=10, - chunk_size=10, - preload_nchunks=2, - n_yields=n_yields, - 
rng=np.random.default_rng(42), - ) - - batch_count = 0 - for load_request in sampler.sample(200): - batch_count += len(load_request["splits"]) - - assert batch_count == n_yields - - -def test_stratified_n_yields_invalid(): - """Test that n_yields < 1 raises ValueError.""" - boundaries = [slice(0, 100)] - with pytest.raises(ValueError, match="n_yields must be >= 1"): - StratifiedCategoricalSampler( - category_boundaries=boundaries, - batch_size=10, - chunk_size=10, - preload_nchunks=2, - n_yields=0, - ) - - -def test_stratified_weights_validation(): - """Test weight validation errors.""" - boundaries = [slice(0, 100), slice(100, 200)] - - # Wrong length - with pytest.raises(ValueError, match="weights length"): - StratifiedCategoricalSampler( - category_boundaries=boundaries, - batch_size=10, - chunk_size=10, - preload_nchunks=2, - n_yields=10, - weights=[1.0], # Only 1 weight for 2 categories - ) - - # Negative weights - with pytest.raises(ValueError, match="non-negative"): - StratifiedCategoricalSampler( - category_boundaries=boundaries, - batch_size=10, - chunk_size=10, - preload_nchunks=2, - n_yields=10, - weights=[1.0, -1.0], - ) - - # Zero sum - with pytest.raises(ValueError, match="not sum to zero"): - StratifiedCategoricalSampler( - category_boundaries=boundaries, - batch_size=10, - chunk_size=10, - preload_nchunks=2, - n_yields=10, - weights=[0.0, 0.0], - ) - - -def test_stratified_replacement(): - """Test that categories are reset when exhausted (sampling with replacement).""" - # Small category with only 2 complete batches possible - boundaries = [slice(0, 20)] # 20 obs, batch_size=10, drop_last=True -> 2 batches - n_yields = 10 # Request more than available - - sampler = StratifiedCategoricalSampler( - category_boundaries=boundaries, - batch_size=10, - chunk_size=10, - preload_nchunks=2, - n_yields=n_yields, - rng=np.random.default_rng(42), - ) - - batch_count = 0 - for load_request in sampler.sample(20): - batch_count += len(load_request["splits"]) - - # 
Should still yield n_yields batches due to replacement - assert batch_count == n_yields - - -def test_stratified_each_batch_single_category(): - """Test that each batch in stratified sampling is from a single category.""" - boundaries = [slice(0, 100), slice(100, 200), slice(200, 300)] - n_obs = 300 - - sampler = StratifiedCategoricalSampler( - category_boundaries=boundaries, - batch_size=10, - chunk_size=10, - preload_nchunks=2, - n_yields=50, - shuffle=True, - rng=np.random.default_rng(42), - ) - - for load_request in sampler.sample(n_obs): - chunks = load_request["chunks"] - if len(chunks) == 0: - continue - - # Build mapping from concatenated chunk index to original index - chunk_indices = [] - for chunk in chunks: - chunk_indices.extend(range(chunk.start, chunk.stop)) - - # Verify each split contains indices from only one category - for split in load_request["splits"]: - if len(split) == 0: - continue - - first_original_idx = chunk_indices[split[0]] - expected_category = _get_category_for_index(first_original_idx, boundaries) - - for idx in split: - original_idx = chunk_indices[idx] - split_category = _get_category_for_index(original_idx, boundaries) - assert split_category == expected_category - - -def test_stratified_rng_reproducibility(): - """Test that same RNG seed gives same results for stratified sampler.""" - boundaries = [slice(0, 100), slice(100, 200)] - - def get_batches(seed): - sampler = StratifiedCategoricalSampler( - category_boundaries=boundaries, - batch_size=10, - chunk_size=10, - preload_nchunks=2, - n_yields=20, - shuffle=True, - rng=np.random.default_rng(seed), - ) - return collect_flat_indices(sampler, 200) - - indices1 = get_batches(42) - indices2 = get_batches(42) - indices3 = get_batches(99) - - assert indices1 == indices2, "Same seed should give same results" - assert indices1 != indices3, "Different seeds should give different results" - - -def test_stratified_from_pandas(): - """Test StratifiedCategoricalSampler.from_pandas 
construction.""" - categories = pd.Categorical(["A"] * 50 + ["B"] * 30 + ["C"] * 20) - - sampler = StratifiedCategoricalSampler.from_pandas( - categories, - batch_size=10, - chunk_size=10, - preload_nchunks=2, - n_yields=30, - ) - - assert sampler.n_categories == 3 - assert sampler.n_yields == 30 - - -def test_stratified_from_pandas_with_weights(): - """Test StratifiedCategoricalSampler.from_pandas with custom weights.""" - categories = pd.Categorical(["A"] * 50 + ["B"] * 30 + ["C"] * 20) - - sampler = StratifiedCategoricalSampler.from_pandas( - categories, - batch_size=10, - chunk_size=10, - preload_nchunks=2, - n_yields=30, - weights=[3.0, 2.0, 1.0], - ) - - np.testing.assert_array_equal(sampler.weights, [3.0, 2.0, 1.0]) - - -def test_stratified_uniform_weights_distribution(): - """Test that uniform weights sample categories roughly equally.""" - boundaries = [slice(0, 100), slice(100, 200), slice(200, 300)] - n_obs = 300 - n_yields = 300 # Large number for statistical significance - - sampler = StratifiedCategoricalSampler( - category_boundaries=boundaries, - batch_size=10, - chunk_size=10, - preload_nchunks=2, - n_yields=n_yields, - rng=np.random.default_rng(42), - ) - - # Count batches per category - category_counts = [0, 0, 0] - for load_request in sampler.sample(n_obs): - chunks = load_request["chunks"] - chunk_indices = [] - for chunk in chunks: - chunk_indices.extend(range(chunk.start, chunk.stop)) - - for split in load_request["splits"]: - if len(split) > 0: - first_idx = chunk_indices[split[0]] - cat = _get_category_for_index(first_idx, boundaries) - category_counts[cat] += 1 - - # With uniform weights, each category should get roughly 1/3 of batches - # Allow 20% tolerance for randomness - expected = n_yields / 3 - for count in category_counts: - assert abs(count - expected) < expected * 0.3, f"Category count {count} too far from expected {expected}"