feat(accelerator): numba colocalization (pearson/manders/rwc/overlap)#60

Draft
timtreis wants to merge 14 commits into
feat/bzyx-shapefrom
feat/numba-coloc

@timtreis timtreis commented Jun 3, 2026

Collaborator

Numba backends for the colocalization correlation features — pearson, manders_fold, rwc, and the ride-along overlap. On the to_bzyx base (#59), sibling to #56/#57/#58.

Approach

One grouped pair-flatten + one fused per-object kernel (coloc_per_object) replace the reference's per-object (N,H,W) boolean-stack + scipy.ndimage passes. Label prep is a single np.bincount (labels_to_offsets) feeding a single-scatter flatten. Every feature is a function of the per-object value vectors only, so the kernels never branch on 2D/3D and the (1,Y,X)-vs-(H,W) divergence that affects the intensity backend cannot occur here.
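
The label prep and single-scatter flatten described above can be sketched in plain NumPy and Python. This is an illustration only: the function names carry a `_sketch` suffix because the real `labels_to_offsets` and `flatten_pairs_grouped` signatures are not shown in this PR text, and the real flatten is a numba kernel rather than a Python loop.

```python
import numpy as np

def labels_to_offsets_sketch(labels):
    """Hypothetical stand-in: derive (lut, n, offsets) from one bincount."""
    counts = np.bincount(labels.ravel())
    counts[0] = 0                          # background is not an object
    present = np.flatnonzero(counts)       # labels that occur, ascending
    n = present.size
    lut = np.full(counts.size, -1, dtype=np.int64)
    lut[present] = np.arange(n)            # label -> dense object index
    offsets = np.zeros(n + 1, dtype=np.int64)
    offsets[1:] = np.cumsum(counts[present])   # block start of each object
    return lut, n, offsets

def flatten_pairs_grouped_sketch(pix1, pix2, labels, lut, offsets):
    """Scatter both channels into per-object contiguous blocks in one scan."""
    v1 = np.empty(offsets[-1], dtype=np.float64)
    v2 = np.empty(offsets[-1], dtype=np.float64)
    cursor = offsets[:-1].copy()           # next free slot per object
    flat_l, f1, f2 = labels.ravel(), pix1.ravel(), pix2.ravel()
    for i in range(flat_l.size):
        lab = flat_l[i]
        if lab > 0:
            k = lut[lab]
            v1[cursor[k]] = f1[i]          # non-finite values are kept
            v2[cursor[k]] = f2[i]
            cursor[k] += 1
    return v1, v2

labels = np.array([[0, 2, 2], [5, 0, 2]])
pix1 = np.arange(6, dtype=np.float64).reshape(2, 3)
pix2 = pix1 * 10
lut, n, offsets = labels_to_offsets_sketch(labels)
v1, v2 = flatten_pairs_grouped_sketch(pix1, pix2, labels, lut, offsets)
```

Object `k` then owns the slice `offsets[k]:offsets[k + 1]` of both value vectors, which is all a per-object kernel needs regardless of the image's dimensionality.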

bzyx: colocalization is a (pixels_1, pixels_2, masks) triple, normalised by calling to_bzyx twice on the shared mask and reusing the single unwrap.

Speedups (1080², 144 objects, float pixels, JIT warmed)

| feature | numpy | numba | speedup |
| --- | --- | --- | --- |
| pearson | 232 ms | 7.4 ms | 31.6× |
| manders_fold | 454 ms | 7.3 ms | 62.0× |
| overlap | 496 ms | 7.3 ms | 67.7× |
| rwc | 560 ms | 100 ms | 5.6× |

rwc is sort-bound — its per-object dense-rank argsort is intrinsic to a rank metric (4 exact alternatives measured slower or no-op).
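
To show why a rank metric needs a sort per object, here is a minimal dense-rank sketch in NumPy. The `rank_weights` helper follows the commonly used rank-weighted form (weight = (R - |rank1 - rank2|) / R); that the kernel in this PR uses exactly this form is an assumption, and both function names are hypothetical.

```python
import numpy as np

def dense_rank(values):
    """Dense rank: ties share a rank and ranks run 0, 1, 2, ... without gaps."""
    values = np.asarray(values).ravel()
    # np.unique sorts its input, which is the O(n log n) step.
    _, inverse = np.unique(values, return_inverse=True)
    return inverse

def rank_weights(a, b):
    """Assumed rank-weight form: pixels with similar ranks get weight near 1."""
    ra, rb = dense_rank(a), dense_rank(b)
    r = max(ra.max(), rb.max()) + 1
    return (r - np.abs(ra - rb)) / r

ranks = dense_rank([3.0, 1.0, 1.0, 2.0])
weights = rank_weights([3.0, 1.0, 1.0, 2.0], [10.0, 10.0, 30.0, 20.0])
```

Integer-valued inputs produce many ties, which is why the golden tests mentioned in the notes use them for rank-tie coverage.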

Notes

  • overlap is not in the numpy _CORRELATION registry, so the numba correlation registry intentionally exposes one feature the numpy one does not (flagged in _numba_registries).
  • Pixels upcast to float64 — strictly more accurate than the reference on integer-dtype input (uint8 fi*si overflow; float32 lstsq slope), so golden tests use integer-valued float64 for rank-tie coverage. Real float images unaffected.
  • costes is a stacked follow-up.
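
The uint8 overflow mentioned in the notes is easy to reproduce. This small sketch uses made-up pixel values to show how an elementwise product wraps in uint8 but is exact after an upcast to float64:

```python
import numpy as np

fi = np.array([200, 16], dtype=np.uint8)
si = np.array([2, 16], dtype=np.uint8)

# uint8 array arithmetic wraps modulo 256 without raising.
wrapped = fi * si

# Upcasting first keeps the true products.
exact = fi.astype(np.float64) * si.astype(np.float64)
```

Here `wrapped` holds 144 and 0 while `exact` holds 400.0 and 256.0, so any sum of products computed on the raw integer dtype would be wrong.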

Tests: 34 added (golden vs numpy: 2D/3D, single/batch, continuous + rank-tie; kernel units incl. labels_to_offsets). Full suite 114 passed, lint clean.

timtreis and others added 13 commits June 2, 2026 04:47
First real accelerator end-to-end on top of the merged #49 dispatch:
`set_accelerator("numba")` now routes `intensity` to a numba implementation
and composes it with the numpy backend for every other feature.

- _detect.py: capability flags (HAS_NUMBA/HAS_JAX/HAS_JAX_GPU) via find_spec,
  resolved once at import. No try/except — an absent backend is never attempted,
  a present-but-broken one raises.
- primitives/: shared host segment layer. flatten_labeled reduces a labeled
  (Z,Y,X) image to flat (values, seg0, coords); a single kernel set then covers
  2D, 3D and future batches with no image/batch axis baked in. max_position is a
  host scipy.ndimage.maximum_position call for bit-exact parity with the numpy
  backend's tie-break.
- primitives/_segment_numba.py: @njit(cache=True), single-threaded kernels —
  fused single-pass moments + centroid cross-sums, residual-sumsq std, CSR
  per-segment quantiles/MAD.
- core/numba/: import-selected backend (`from cp_measure.core.numba import
  get_intensity`); identical dict contract, 2D and 3D.
- bulk._dispatch: "numba" composes numba intensity + numpy rest; raises if numba
  is not installed (no silent fallback).
- numba is an optional extra ([numba]); the default install stays numba-free.
  CI tests install .[numba] and run the correctness harness.

test/test_backend_correctness.py asserts numba == numpy (2D/3D, edge on/off,
rtol=1e-6), the dispatch composition, and the absent-numba raise path.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
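
A minimal sketch of the find_spec capability-flag pattern described in this commit. The helper names and the error wording are hypothetical; only the use of `importlib.util.find_spec` and the "no silent fallback" behaviour come from the commit message.

```python
from importlib.util import find_spec

def has_module(name):
    """True if a top-level module is importable, without importing it."""
    return find_spec(name) is not None

# Resolved once at import time, as the commit describes.
HAS_NUMBA = has_module("numba")

def require_numba():
    """Raise instead of silently falling back (message text is illustrative)."""
    if not HAS_NUMBA:
        raise RuntimeError("the numba accelerator was requested but numba is not installed")
```

Because `find_spec` only locates the module, a present-but-broken install is not masked here; it would fail later at the real import.
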
Move Location_MaxIntensity_* out of the host scipy per-object call and into the
fused segment_moments kernel via a deterministic `>=`-last argmax (records the
max pixel's coordinates in the same single pass).

scipy.ndimage.maximum_position's labeled tie-break is `argsort` (quicksort) +
last-write-wins, i.e. an arbitrary tied pixel that is not stable across numpy
versions — so there is no stable rule to replicate. On real continuous data the
max is unique, so the kernel's `>=`-last result is bit-identical to scipy (the
correctness harness confirms 2D/3D, edge on/off); only exact-value ties can
differ, and the kernel's rule is the more reproducible of the two.

Drops the now-unused max_position_per_object host helper (and its scipy import)
from the primitive layer.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
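
The `>=`-last rule can be illustrated with a plain-Python loop over flat segment arrays. This is a sketch with a hypothetical function name, not the fused `segment_moments` kernel itself, and it tracks a flat index rather than (z, y, x) coordinates.

```python
import numpy as np

def segment_argmax_last(values, seg, n):
    """Per-segment max and the index of its LAST occurrence in scan order."""
    best = np.full(n, -np.inf)
    where = np.full(n, -1, dtype=np.int64)
    for i in range(values.size):
        k = seg[i]
        # '>=' lets a later tied pixel replace an earlier one, so the
        # result depends only on scan order and is deterministic.
        if values[i] >= best[k]:
            best[k] = values[i]
            where[k] = i
    return best, where

values = np.array([1.0, 5.0, 5.0, 2.0, 7.0, 7.0])
seg = np.array([0, 0, 0, 1, 1, 1])
best, where = segment_argmax_last(values, seg, 2)
```
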
- flatten_labeled: derive (z,y,x) coords from numpy.nonzero(lmask) instead of
  materialising three full-volume mgrid arrays then masking them — same coords
  in the same C order, no per-call O(volume) temporaries.
- label_to_idx_lut: drop the unused sorted-labels return value (now just
  (lut, n)); the max_position-in-kernel refactor removed its only consumer.
- add a lighter segment_stats kernel (count/sum/min/max) and use it for the edge
  path, replacing the segment_moments call that needed throwaway zero coordinate
  arrays and discarded the centroid cross-sums.

No behaviour change; correctness harness + full suite stay green.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
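
The first bullet of this commit rests on `numpy.nonzero` yielding the same coordinates, in the same C order, as masking full `mgrid` arrays. A small check on a toy mask (the shapes and values are made up):

```python
import numpy as np

lmask = np.zeros((2, 3, 4), dtype=bool)
lmask[0, 1, 2] = lmask[1, 0, 0] = lmask[1, 2, 3] = True

# Direct route: one tuple of index arrays, no full-volume temporaries.
z, y, x = np.nonzero(lmask)

# Old route: materialise three full-volume coordinate grids, then mask them.
zz, yy, xx = np.mgrid[0:2, 0:3, 0:4]
same = (
    np.array_equal(z, zz[lmask])
    and np.array_equal(y, yy[lmask])
    and np.array_equal(x, xx[lmask])
)
```
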
flatten_labeled built the flat (values, seg0, coords) arrays with a numpy
(masks>0)&isfinite mask + numpy.nonzero + two fancy-index gathers — several
full-image passes plus a boolean-array allocation, and the dominant cost of the
non-edge path.

Replace it with flatten_numba: two grid scans (count, then fill) in a single
@njit kernel, coordinates taken from the loop indices. The flat-segment kernels
and the rest of the backend are unchanged — only how the flat arrays are built.

Measured (single image, non-edge core): flatten step ~4-10x faster (10x at
1024^2), full core ~1.1x (256^2) / ~1.5x (1024^2); the gain grows with image
size. Bit-identical output (correctness harness stays green).

The numpy flatten_labeled (its only consumer) is removed; primitives/segment.py
now holds just the numpy label->index lookup, the numba layer owns the flatten.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
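
The count-then-fill structure can be sketched as follows. This is plain Python under assumptions: the function name is hypothetical, it returns raw labels instead of the zero-based segment index, and the real `flatten_numba` is an @njit kernel whose signature is not shown here.

```python
import numpy as np

def flatten_two_scan(image, masks):
    """Two grid scans: count kept pixels, then fill exactly sized outputs."""
    Z, Y, X = masks.shape
    n = 0
    for z in range(Z):                      # scan 1: count
        for y in range(Y):
            for x in range(X):
                if masks[z, y, x] > 0 and np.isfinite(image[z, y, x]):
                    n += 1
    values = np.empty(n, dtype=np.float64)  # upcast happens on write
    labels = np.empty(n, dtype=np.int64)
    coords = np.empty((n, 3), dtype=np.int64)
    i = 0
    for z in range(Z):                      # scan 2: fill
        for y in range(Y):
            for x in range(X):
                if masks[z, y, x] > 0 and np.isfinite(image[z, y, x]):
                    values[i] = image[z, y, x]
                    labels[i] = masks[z, y, x]
                    coords[i] = (z, y, x)   # coordinates are the loop indices
                    i += 1
    return values, labels, coords

image = np.array([[[1.0, np.nan, 3.0], [4.0, 5.0, 6.0]]], dtype=np.float32)
masks = np.array([[[1, 1, 0], [0, 2, 2]]])
values, labels, coords = flatten_two_scan(image, masks)
```

Scanning twice avoids both the boolean-mask allocation and any output resizing, which is why the approach pays off once the loops are compiled.
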
- _detect.py: drop the unused HAS_JAX / HAS_JAX_GPU flags. Besides being dead
  for this PR, HAS_JAX_GPU eagerly imported jax at module load whenever jax was
  installed, just to set a flag nothing reads. jax detection lands with the jax
  backend; HAS_NUMBA alone establishes the find_spec pattern.
- flatten the image without a forced float64 copy: pass masked_image through
  ascontiguousarray without dtype=, and let flatten_numba upcast the kept values
  into its float64 output. Avoids a full-image float64 temporary for non-float64
  inputs (e.g. float32 microscopy data); bit-identical for float64.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
label_to_idx_lut used numpy.unique(masks) — a full-image sort — to find the
present labels. scipy.ndimage.find_objects (scipy is already a core dep) returns
the same ascending present-label set in one O(P) pass, giving a bit-identical
LUT ~3-5x faster (12.4->3.5 ms at 1024^2; 21.9->4.4 ms on a 32x240x240 volume).

Trick borrowed from Alan's pure-numpy speedup (#55); unlike its
percentile/MAD changes, this one preserves output exactly (verified identical).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Profiling the sparse-large regime (1024^2, 64 obj, edge on) showed
skimage.find_boundaries was ~37% of the call (~20-29 ms) — the morphology
dominates, not the scan. A one-pass numba inner-boundary kernel (4-neighbour
check, the cp_measure_fast approach) is bit-identical to find_boundaries(
mode="inner") and 12-27x faster, verified exact across (H,W) and (1,H,W).

Used for 2D planes (Z==1); true 3D keeps skimage (6-neighbourhood). Single-image
1024^2/64 edge-on drops ~47->32 ms, and per-image batch ~445->264 ms.

No correctness change (exact boundary match; harness stays green).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
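
A one-pass 4-neighbour inner-boundary check might look like the sketch below. The function name is hypothetical, out-of-image neighbours are ignored by assumption, and parity with `find_boundaries(mode="inner")` is the commit's claim about its own kernel; it has not been verified for this sketch.

```python
import numpy as np

def inner_boundary_4n(labels):
    """Mark labelled pixels with any in-image 4-neighbour of a different label."""
    H, W = labels.shape
    out = np.zeros((H, W), dtype=np.bool_)
    for y in range(H):
        for x in range(W):
            lab = labels[y, x]
            if lab == 0:
                continue                    # background is never a boundary
            if (
                (y > 0 and labels[y - 1, x] != lab)
                or (y < H - 1 and labels[y + 1, x] != lab)
                or (x > 0 and labels[y, x - 1] != lab)
                or (x < W - 1 and labels[y, x + 1] != lab)
            ):
                out[y, x] = True
    return out

square = np.zeros((5, 5), dtype=np.int64)
square[1:4, 1:4] = 1
ring = inner_boundary_4n(square)
uniform = inner_boundary_4n(np.ones((3, 3), dtype=np.int64))
```
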
Address PR #54 review:
- bulk._dispatch: reword the absent-numba RuntimeError from the imperative
  "install it via" to "you can install it via" (avoid issuing pip commands
  imperatively at the user).
- primitives is an internal layer with no public API to curate; import
  label_to_idx_lut directly from primitives.segment (matching how the
  _segment_numba kernels are already imported) and drop the __init__
  re-export. Documents the convention in the package docstring.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
feat(accelerator): numba intensity backend
Shared foundational helper used by the numba intensity/granularity/zernike
backends to normalise any input (2D/3D/4D/list) to the canonical batch-of-volumes
form: single image = batch of 1, returning a dict for a lone image/volume and a
list of dicts for a batch. Pure numpy, no numba. Extracted to its own PR so it
can be reviewed first and unblock the feature backends (#56/#57/#58).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
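
The normalise-then-unwrap contract described here could look roughly like this. Everything in the sketch is an assumption: the real `to_bzyx` signature and return type are not shown in this PR text, so the names below are hypothetical and only the stated behaviour is modelled (single image = batch of 1, dict for a lone input, list of dicts for a batch).

```python
import numpy as np

def _as_zyx(arr):
    """Promote (Y, X) to (1, Y, X); pass (Z, Y, X) through."""
    if arr.ndim == 2:
        return arr[np.newaxis]
    if arr.ndim == 3:
        return arr
    raise ValueError(f"expected a 2D or 3D array, got ndim={arr.ndim}")

def to_bzyx_sketch(data):
    """Return (volumes, is_batch) where volumes is a list of (Z, Y, X) arrays."""
    if isinstance(data, (list, tuple)):
        return [_as_zyx(np.asarray(a)) for a in data], True
    arr = np.asarray(data)
    if arr.ndim == 4:
        return [arr[b] for b in range(arr.shape[0])], True
    return [_as_zyx(arr)], False

def unwrap(results, is_batch):
    """A lone image/volume yields one dict; a batch yields a list of dicts."""
    return results if is_batch else results[0]

single, single_is_batch = to_bzyx_sketch(np.zeros((4, 5)))
batch, batch_is_batch = to_bzyx_sketch(np.zeros((2, 3, 4, 5)))
```
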
Drop-in numba backends for the colocalization correlation features, replacing
the reference's per-object (N,H,W) boolean-stack + scipy.ndimage passes with one
grouped pair-flatten primitive and a single fused per-object kernel.

- primitives/_segment_numba.flatten_pairs_grouped: O(M) counting-sort grouping of
  two co-registered channels into per-object contiguous blocks (non-finite kept,
  to mirror the reference's pixels[mask] extraction).
- core/numba/_colocalization.coloc_per_object: one fused kernel yielding Pearson
  r + slope, Manders M1/M2, Overlap + K1/K2, and (gated) rank-weighted RWC1/RWC2;
  serial per object, no in-kernel parallelism.
- core/numba/measurecolocalization: four to_bzyx-normalised wrappers. The triple
  (pixels_1, pixels_2, masks) is normalised by calling to_bzyx twice on the shared
  mask and reusing the single unwrap. Features are value-vector-only, so the
  kernels never branch on 2D/3D and the (1,Y,X)-vs-(H,W) divergence cannot occur.

overlap is the unregistered 5th feature, surfaced numba-only (flagged in
_numba_registries). costes is a stacked follow-up. Pixels upcast to float64, so
genuine integer-dtype input can differ from the reference's uint8-overflow /
float32-lstsq artifacts (documented; real float images unaffected).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…oor)

labels_to_offsets() derives (lut, n, offsets) from a single np.bincount over the
raster, replacing scipy find_objects + the grouped flatten's separate count scan.
flatten_pairs_grouped now takes the precomputed offsets and does a single scatter
scan. Cuts the per-call prep shared by all four features from 3 full-image passes
to 2 (~6.2ms -> ~3.1ms here), bit-identical lut/offsets/values.

Speedups (1080^2, 144 obj, float): pearson 18.7->31.6x, manders 36.4->62.0x,
overlap 39.8->67.7x; rwc 5.3->5.6x (its per-object argsort dominates and is
intrinsic to the rank metric — a global lexsort alternative measured 2.3x slower).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
coloc_per_object did three sweeps per object over the value block: pass 1
(means + per-channel maxima), pass 2 (centred second moments -> Pearson/slope),
pass 3 (threshold-gated Manders/Overlap/RWC sums). Passes 2 and 3 are
independent given pass 1's means and maxima, so they now share one sweep.

Bit-identical to the previous kernel (each accumulator's add-order is unchanged;
verified array_equal across pearson/manders/rwc). ~3-4% faster on pearson/manders
(12.3 -> 11.9 ms, 1080^2/144 obj); rwc is sort-bound so ~1%. 28 coloc tests green.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
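
The two-sweep structure can be sketched for a single object's value vectors. This is an illustration with a hypothetical function name: the Pearson r and least-squares slope formulas are standard, but the gated sum is a simplified stand-in for the Manders/Overlap/RWC accumulators, and the `thr` fraction-of-maximum threshold is an assumption.

```python
import math

def coloc_two_sweeps(a, b, thr=0.5):
    """Sweep 1: means and maxima. Sweep 2: centred moments and gated sums."""
    n = len(a)
    sum_a = sum_b = 0.0
    max_a = max_b = -math.inf
    for x, y in zip(a, b):                 # sweep 1
        sum_a += x
        sum_b += y
        max_a = max(max_a, x)
        max_b = max(max_b, y)
    mean_a, mean_b = sum_a / n, sum_b / n
    ta, tb = thr * max_a, thr * max_b

    sxx = syy = sxy = gated_a = 0.0
    for x, y in zip(a, b):                 # sweep 2: both jobs share one pass
        dx, dy = x - mean_a, y - mean_b
        sxx += dx * dx
        syy += dy * dy
        sxy += dx * dy
        if x > ta and y > tb:              # needs only sweep 1's maxima
            gated_a += x
    # Constant input makes sxx or syy zero; a real kernel must guard this.
    r = sxy / math.sqrt(sxx * syy)
    slope = sxy / sxx
    return r, slope, gated_a

r, slope, gated_a = coloc_two_sweeps([1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0])
```

Each accumulator still adds its terms in the same order as in a separate pass, which is consistent with the bit-identical result reported in the commit.
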
timtreis commented Jun 4, 2026

Collaborator Author

Follow-up perf (commit c22d856): fused coloc_per_object's centred-moment pass and threshold-gated pass into one sweep per object (they're independent given pass 1's means + maxima). Bit-identical to the prior kernel (verified array_equal across pearson/manders/rwc; each accumulator's add-order is unchanged), ~3-4% faster on pearson/manders (12.3→11.9 ms, 1080²/144 obj); rwc stays sort-bound. 28 coloc tests + full suite (114) green.

I also investigated the "flatten once and fan coloc_per_object's 9-tuple out to all features" idea (the featurizer runs pearson/manders/rwc as 3 separate wrappers on the same image). POC result: only ~1.12× in the realistic case (219.7→196.0 ms), because rwc's argsort dominates the coloc cost (~196 ms) and is paid once either way; sharing only saves the two cheap non-rwc kernel passes. It would be ~3× only if rwc is disabled. Given the architectural cost (combined dispatch + featurizer change) vs ~1.12× with rwc on, not worth bundling. coloc is effectively at its floor (the rwc sort).
