diff --git a/pyproject.toml b/pyproject.toml index 06ffac1bf..da330145d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,8 @@ dynamic = [ dependencies = [ "aiohttp>=3.8.1", "anndata>=0.9", + "centrosome>=1.2.3", + "cp-measure>=0.1.19,<0.2", "cycler>=0.11", "dask[array]>=2021.2", "dask-image>=0.5", diff --git a/src/squidpy/experimental/im/__init__.py b/src/squidpy/experimental/im/__init__.py index 24f7cf6ee..52da878bd 100644 --- a/src/squidpy/experimental/im/__init__.py +++ b/src/squidpy/experimental/im/__init__.py @@ -6,6 +6,7 @@ WekaParams, detect_tissue, ) +from ._feature import calculate_image_features from ._make_tiles import make_tiles, make_tiles_from_spots from ._qc_image import qc_image from ._qc_metrics import QCMetric @@ -15,6 +16,7 @@ "FelzenszwalbParams", "QCMetric", "WekaParams", + "calculate_image_features", "detect_tissue", "make_tiles", "make_tiles_from_spots", diff --git a/src/squidpy/experimental/im/_feature.py b/src/squidpy/experimental/im/_feature.py new file mode 100644 index 000000000..bf84b9360 --- /dev/null +++ b/src/squidpy/experimental/im/_feature.py @@ -0,0 +1,998 @@ +"""Experimental feature extraction module. + +Extracts per-cell features from segmentation masks using cp_measure, +scikit-image regionprops, and squidpy-specific metrics. Large images +are automatically tiled so that each tile is processed independently. 
+""" + +from __future__ import annotations + +import time +import warnings +from dataclasses import dataclass, field, fields +from typing import Any, Literal, NamedTuple + +import anndata as ad +import numpy as np +import pandas as pd +import xarray as xr +from cp_measure.featurizer import featurize, make_featurizer_config +from joblib import Parallel, delayed +from skimage import measure +from skimage.feature import graycomatrix, graycoprops +from skimage.segmentation import relabel_sequential +from spatialdata import SpatialData, rasterize +from spatialdata._logging import logger as logg +from spatialdata.models import TableModel, get_channel_names +from spatialdata.transformations import get_transformation +from tqdm.auto import tqdm + +from squidpy.experimental.im._tiling import ( + build_tile_specs, + compute_cell_info, + compute_cell_info_multiscale, + compute_cell_info_tiled, + extract_tile_lazy, +) + +# --------------------------------------------------------------------------- +# Drop accounting +# --------------------------------------------------------------------------- + + +@dataclass +class DropReport: + """Counters for cells that were excluded during a featurization run. + + Emitted once at the end of ``calculate_image_features`` so users know + why their cell count shrank. + """ + + outside_image_extent: int = 0 + partial_at_image_boundary: int = 0 + cp_measure_no_data: int = 0 + empty_tiles: int = 0 + other: dict[str, int] = field(default_factory=dict) + + def summary(self) -> str: + lines = ["Cell drop report:"] + for f in fields(self): + v = getattr(self, f.name) + if isinstance(v, int) and v > 0: + lines.append(f" {f.name}: {v}") + elif isinstance(v, dict): + for k, vv in v.items(): + if vv: + lines.append(f" {k}: {vv}") + if len(lines) == 1: + return "Cell drop report: no cells dropped." 
+ return "\n".join(lines) + + +__all__ = ["calculate_image_features"] + +# --------------------------------------------------------------------------- +# Skimage property sets +# --------------------------------------------------------------------------- + +_MASK_PROPS = frozenset( + { + "area", + "area_filled", + "area_convex", + "axis_major_length", + "axis_minor_length", + "eccentricity", + "equivalent_diameter_area", + "extent", + "feret_diameter_max", + "solidity", + "euler_number", + "centroid", + "centroid_local", + "perimeter", + "perimeter_crofton", + "inertia_tensor", + "inertia_tensor_eigvals", + } +) +_INTENSITY_PROPS = frozenset( + { + "intensity_max", + "intensity_mean", + "intensity_min", + "intensity_std", + } +) + +# cp_measure feature name → make_featurizer_config keyword(s) +_CPMEASURE_FLAGS: dict[str, dict[str, bool]] = { + "cpmeasure:intensity": {"intensity": True}, + "cpmeasure:sizeshape": {"sizeshape": True}, + "cpmeasure:texture": {"texture": True}, + "cpmeasure:granularity": {"granularity": True}, + "cpmeasure:zernike": {"zernike": True}, + "cpmeasure:feret": {"feret": True}, + "cpmeasure:radial": {"radial_distribution": True, "radial_zernikes": True}, + "cpmeasure:correlation": { + "correlation_pearson": True, + "correlation_costes": True, + "correlation_manders_fold": True, + "correlation_rwc": True, + }, + "cpmeasure:correlation_pearson": {"correlation_pearson": True}, + "cpmeasure:correlation_costes": {"correlation_costes": True}, + "cpmeasure:correlation_manders_fold": {"correlation_manders_fold": True}, + "cpmeasure:correlation_rwc": {"correlation_rwc": True}, +} + +# All known top-level feature group names (used for validation) +_ALL_FEATURES = ( + set(_CPMEASURE_FLAGS.keys()) + | {"skimage:label", "skimage:label+image"} + | {"squidpy:summary", "squidpy:texture", "squidpy:color_hist"} +) + + +# --------------------------------------------------------------------------- +# Feature parsing +# 
def _parse_features(features: list[str] | str | None) -> _ParsedFeatures:
    """Translate user-facing feature names into a :class:`_ParsedFeatures`.

    Parameters
    ----------
    features
        ``None``, a single feature name, or a list of feature names. A bare
        string is treated as a one-element list. Recognized names are the
        keys of ``_CPMEASURE_FLAGS``, ``"skimage:label"``,
        ``"skimage:label+image"``, the fine-grained forms
        ``"skimage:label:<prop>"`` / ``"skimage:label+image:<prop>"``, and
        ``"squidpy:summary"``, ``"squidpy:texture"``, ``"squidpy:color_hist"``.

    Returns
    -------
    Parsed configuration. ``cp_flags`` is ``{}`` when ``features`` is ``None``,
    a non-empty dict of featurizer keywords when at least one ``cpmeasure:*``
    name was given, and ``None`` otherwise. The skimage property sets are
    ``None`` when nothing was requested for them.

    Raises
    ------
    ValueError
        If a name is not recognized, or a fine-grained skimage name refers to
        a property missing from ``_MASK_PROPS`` / ``_INTENSITY_PROPS``.
    """
    if features is None:
        # No selection given: request cp_measure only, with an empty flag dict.
        return _ParsedFeatures(
            cp_flags={},  # empty dict (as opposed to None) marks "cp_measure requested, no explicit flags"
            skimage_label_props=None,
            skimage_intensity_props=None,
            squidpy_summary=False,
            squidpy_texture=False,
            squidpy_color_hist=False,
        )

    if isinstance(features, str):
        features = [features]

    cp_flags: dict[str, bool] = {}
    # Tracked separately from ``cp_flags`` so that "no cp_measure feature
    # requested" can be reported as ``None`` instead of an empty dict.
    has_any_cp = False
    label_props: set[str] | None = None
    intensity_props: set[str] | None = None
    sq_summary = False
    sq_texture = False
    sq_color_hist = False

    for f in features:
        # cp_measure groups: merge the featurizer keywords for this group.
        if f in _CPMEASURE_FLAGS:
            has_any_cp = True
            cp_flags.update(_CPMEASURE_FLAGS[f])

        # skimage group-level names select the complete property set.
        elif f == "skimage:label":
            label_props = set(_MASK_PROPS)
        elif f == "skimage:label+image":
            intensity_props = set(_INTENSITY_PROPS)

        # skimage fine-grained: "skimage:label:prop" or "skimage:label+image:prop".
        # ``split(":", 2)`` keeps everything after the second colon as the property name.
        elif f.startswith("skimage:label:"):
            prop = f.split(":", 2)[2]
            if prop not in _MASK_PROPS:
                raise ValueError(f"Unknown skimage label property: '{prop}'. Available: {sorted(_MASK_PROPS)}")
            label_props = (label_props or set()) | {prop}
        elif f.startswith("skimage:label+image:"):
            prop = f.split(":", 2)[2]
            if prop not in _INTENSITY_PROPS:
                raise ValueError(f"Unknown skimage intensity property: '{prop}'. Available: {sorted(_INTENSITY_PROPS)}")
            intensity_props = (intensity_props or set()) | {prop}

        # squidpy per-cell feature families are simple on/off switches.
        elif f == "squidpy:summary":
            sq_summary = True
        elif f == "squidpy:texture":
            sq_texture = True
        elif f == "squidpy:color_hist":
            sq_color_hist = True

        else:
            raise ValueError(
                f"Unknown feature: '{f}'. Available top-level features: {sorted(_ALL_FEATURES)}, "
                f"or use 'skimage:label:property' / 'skimage:label+image:property' for individual properties."
            )

    return _ParsedFeatures(
        cp_flags=cp_flags if has_any_cp else None,
        # Empty or never-set property sets collapse to None ("not requested").
        skimage_label_props=frozenset(label_props) if label_props else None,
        skimage_intensity_props=frozenset(intensity_props) if intensity_props else None,
        squidpy_summary=sq_summary,
        squidpy_texture=sq_texture,
        squidpy_color_hist=sq_color_hist,
    )
+ """ + if not cp_flags: + # All defaults (everything True) + return make_featurizer_config(channel_names) + + # Start with everything off, then enable requested features + all_off = { + "intensity": False, + "texture": False, + "granularity": False, + "radial_distribution": False, + "radial_zernikes": False, + "sizeshape": False, + "zernike": False, + "feret": False, + "correlation_pearson": False, + "correlation_costes": False, + "correlation_manders_fold": False, + "correlation_rwc": False, + } + all_off.update(cp_flags) + return make_featurizer_config(channel_names, **all_off) + + +# --------------------------------------------------------------------------- +# Per-tile feature computation +# --------------------------------------------------------------------------- + + +def _featurize_tile( + tile_image: np.ndarray, + tile_labels: np.ndarray, + parsed: _ParsedFeatures, + channel_names: list[str], + *, + cp_config: dict | None = None, +) -> pd.DataFrame: + """Compute all requested features for a single tile. + + Parameters + ---------- + tile_image + ``(C, H, W)`` image tile. + tile_labels + ``(H, W)`` label tile with only owned cells. + parsed + Parsed feature configuration. + channel_names + Channel names for column naming. + cp_config + Pre-built cp_measure featurizer config. When ``None`` (default), the + config is built locally from ``parsed.cp_flags``. ``calculate_image_features`` + builds it once and reuses it across tiles; direct callers can rely on + the fallback. + + Returns + ------- + DataFrame indexed by cell label ID with one column per feature. 
def _regionprops_to_row(region: Any, props: frozenset[str]) -> dict[str, float]:
    """Flatten the requested attributes of one region into scalar features.

    Parameters
    ----------
    region
        Object exposing the requested properties as attributes (in this
        module, an item produced by ``skimage.measure.regionprops``).
    props
        Names of the attributes to read from ``region``.

    Returns
    -------
    Mapping from column name to ``float``. Scalars keep the property name,
    1-D values become ``"<prop>_<i>"`` and 2-D values become
    ``"<prop>_<i>x<j>"``. Values with more than two dimensions are skipped.
    Keys are inserted in sorted property order so the resulting column order
    is reproducible across interpreter runs.
    """
    row: dict[str, float] = {}
    # ``props`` is a frozenset of strings, whose iteration order depends on
    # the per-process string hash seed. Sorting pins the column order.
    for prop in sorted(props):
        try:
            arr = np.asarray(getattr(region, prop))
            if arr.ndim == 0:
                flat = {prop: float(arr)}
            elif arr.ndim == 1:
                flat = {f"{prop}_{i}": float(v) for i, v in enumerate(arr)}
            elif arr.ndim == 2:
                # ndenumerate walks in row-major order: (0, 0), (0, 1), ...
                flat = {f"{prop}_{i}x{j}": float(v) for (i, j), v in np.ndenumerate(arr)}
            else:
                continue
        except (ValueError, TypeError, AttributeError):
            # Best effort: a property that is unavailable or not convertible
            # to floats is left out entirely instead of aborting the row.
            continue
        # Commit only after the whole property converted, so a failure midway
        # through a vector/matrix never leaves a partial set of columns.
        row.update(flat)
    return row
def _glcm_features(channel_crop: np.ndarray, mask: np.ndarray, ch_name: str) -> dict[str, float]:
    """GLCM texture features for a single channel, restricted to one cell.

    Parameters
    ----------
    channel_crop
        ``(h, w)`` intensities of one channel inside the cell's bounding box.
    mask
        ``(h, w)`` boolean array, ``True`` on pixels belonging to the cell.
    ch_name
        Channel name appended to every feature key.

    Returns
    -------
    Dict with ``texture_{contrast,dissimilarity,homogeneity,energy,ASM,
    correlation}_<ch_name>``. An empty dict is returned when the mask is
    empty, when the cell has no horizontally adjacent pixel pair, or when
    scikit-image rejects the input.

    Notes
    -----
    Intensities are min-max scaled over the cell's own pixels and quantized
    to 32 grey levels. Background pixels inside the bounding box are given a
    reserved level ``0`` and the corresponding row/column of the
    co-occurrence matrix is discarded, so only pixel pairs that lie entirely
    inside the cell contribute. Zeroing the background alone is not enough:
    level ``0`` would still enter the matrix and would be indistinguishable
    from the cell's darkest pixels. For a cell that fills its whole bounding
    box the result equals that of a plain 32-level GLCM.
    """
    quant_levels = 32

    cell_vals = channel_crop[mask]
    if cell_vals.size == 0:
        return {}
    ch_min, ch_max = float(cell_vals.min()), float(cell_vals.max())

    # Level 0 = background, levels 1..quant_levels = in-cell intensities.
    ch_q = np.zeros(channel_crop.shape, dtype=np.uint8)
    if ch_max > ch_min and np.isfinite(ch_max - ch_min):
        scaled = (cell_vals - ch_min) / (ch_max - ch_min)
        levels = np.clip((scaled * (quant_levels - 1)).round(), 0, quant_levels - 1)
        ch_q[mask] = levels.astype(np.uint8) + 1
    else:
        # Constant (or non-finite) cell: a single grey level.
        ch_q[mask] = 1

    try:
        glcm = graycomatrix(
            ch_q, distances=[1], angles=[0], levels=quant_levels + 1, symmetric=True, normed=False
        )
        # Drop every pair that involves the background level, then normalize
        # over the remaining in-cell pairs.
        in_cell = glcm[1:, 1:, :, :].astype(np.float64)
        n_pairs = in_cell.sum()
        if n_pairs == 0:
            return {}
        in_cell /= n_pairs
        return {
            f"texture_{prop}_{ch_name}": float(graycoprops(in_cell, prop)[0, 0])
            for prop in ("contrast", "dissimilarity", "homogeneity", "energy", "ASM", "correlation")
        }
    except (ValueError, IndexError):
        # Best effort: texture is simply omitted for cells scikit-image cannot handle.
        return {}
def _is_close_identity(m: np.ndarray, atol: float = 1e-6) -> bool:
    """Return ``True`` if ``m`` is a square matrix within ``atol`` of identity.

    Parameters
    ----------
    m
        Matrix to test (in this module, a 3x3 affine).
    atol
        Maximum absolute deviation allowed for every entry.

    Returns
    -------
    ``True`` when ``m`` is 2-D, square, and every entry differs from the
    identity by at most ``atol``. Any other input yields ``False``.
    """
    m = np.asarray(m)
    if m.ndim != 2 or m.shape[0] != m.shape[1]:
        return False
    # rtol=0 makes ``atol`` the only tolerance. With numpy's default
    # rtol=1e-5 the diagonal would be accepted up to atol + 1e-5.
    return bool(np.allclose(m, np.eye(m.shape[0]), rtol=0.0, atol=atol))
def _align_to_image_grid(
    sdata: SpatialData,
    image_key: str,
    labels_key: str,
    image_da: xr.DataArray,
    labels_da: xr.DataArray,
    align_mode: Literal["strict", "rasterize"],
    drop_report: DropReport,
) -> tuple[xr.DataArray, xr.DataArray]:
    """Crop image and labels to their pixel-grid overlap, honoring transforms.

    The labels-to-image affine is computed in a coordinate system shared by
    both elements and handled as follows:

    * identity, or identity plus an integer pixel translation: the
      translation is used as the offset of the labels within the image;
    * anything else with ``align_mode="strict"``: a ``ValueError`` is raised;
    * anything else with ``align_mode="rasterize"``: the labels element is
      resampled onto the image grid and the offset becomes ``(0, 0)``.

    Both arrays are then sliced to the rectangle where they overlap.

    Parameters
    ----------
    sdata
        Container holding the image and labels elements.
    image_key, labels_key
        Keys of the two elements in ``sdata.images`` / ``sdata.labels``.
    image_da, labels_da
        The already resolved arrays for those elements.
    align_mode
        ``"strict"`` or ``"rasterize"``, see above.
    drop_report
        Mutated in place: ``outside_image_extent`` and
        ``partial_at_image_boundary`` are incremented with the counts
        returned by ``_classify_dropped_cells``.

    Returns
    -------
    ``(image_crop, labels_crop)`` covering the same pixel rectangle.

    Raises
    ------
    ValueError
        If the grids are not pixel-aligned in strict mode, or if image and
        labels do not overlap at all.

    Warns
    -----
    UserWarning
        If any cell lies fully or partially outside the overlap rectangle.
    """
    cs = _shared_coordinate_system(sdata, image_key, labels_key)
    m = _relative_affine(sdata, image_key, labels_key, cs)

    # Integer-pixel offset of labels relative to image. (tx, ty) means
    # labels pixel (0, 0) lands at image pixel (tx, ty) in (x, y) order.
    if _is_close_identity(m):
        tx, ty = 0, 0
    elif (decomposed := _decompose_pixel_translation(m)) is not None:
        tx, ty = decomposed
    elif align_mode == "strict":
        raise ValueError(
            f"Labels not aligned to image pixel grid in coordinate system '{cs}'. "
            f"Relative affine (x,y) =\n{m}\n"
            f"Pass align_mode='rasterize' to resample labels onto the image grid "
            f"(via spatialdata.rasterize), or pre-align with spatialdata.rasterize "
            f"in your pipeline."
        )
    else:
        # Resampled labels live on the image grid, so no offset remains.
        labels_da = _rasterize_to_image_grid(sdata.labels[labels_key], image_da, cs)
        tx, ty = 0, 0

    # Determine overlap rectangle in image-pixel coords.
    img_h = int(image_da.sizes["y"])
    img_w = int(image_da.sizes["x"])
    # Fall back to the trailing two axes when the labels array has no named y/x dims.
    lbl_h = int(labels_da.sizes.get("y", labels_da.shape[-2]))
    lbl_w = int(labels_da.sizes.get("x", labels_da.shape[-1]))

    # Labels pixel (i_y, i_x) in label coords maps to image pixel (i_y+ty, i_x+tx).
    img_y0 = max(0, ty)
    img_x0 = max(0, tx)
    img_y1 = min(img_h, lbl_h + ty)
    img_x1 = min(img_w, lbl_w + tx)
    if img_y1 <= img_y0 or img_x1 <= img_x0:
        raise ValueError(f"Image '{image_key}' and labels '{labels_key}' do not overlap in coordinate system '{cs}'.")

    # Same rectangle expressed in label-pixel coords.
    lbl_y0 = img_y0 - ty
    lbl_x0 = img_x0 - tx
    lbl_y1 = img_y1 - ty
    lbl_x1 = img_x1 - tx

    image_crop = image_da.isel(y=slice(img_y0, img_y1), x=slice(img_x0, img_x1))
    labels_crop = labels_da.isel(y=slice(lbl_y0, lbl_y1), x=slice(lbl_x0, lbl_x1))

    # Count cells that fall (partially) outside the labels_crop.
    # NOTE(review): cells that are only partially outside keep their in-crop
    # pixels in ``labels_crop``; nothing in this function removes them, so
    # "dropping" in the warning below presumably relies on later stages — confirm.
    cells_inside, cells_partial, cells_outside = _classify_dropped_cells(labels_da, lbl_y0, lbl_x0, lbl_y1, lbl_x1)
    if cells_outside or cells_partial:
        drop_report.outside_image_extent += cells_outside
        drop_report.partial_at_image_boundary += cells_partial
        warnings.warn(
            f"Dropping {cells_outside} cells outside the image extent and "
            f"{cells_partial} cells partially outside. See end-of-run drop report.",
            UserWarning,
            stacklevel=2,
        )

    return image_crop, labels_crop
+ """ + lbl_h = int(labels_da.sizes.get("y", labels_da.shape[-2])) + lbl_w = int(labels_da.sizes.get("x", labels_da.shape[-1])) + if y0 <= 0 and x0 <= 0 and y1 >= lbl_h and x1 >= lbl_w: + return 0, 0, 0 + + cell_info = compute_cell_info_tiled(labels_da) + fully_inside = 0 + partial = 0 + fully_outside = 0 + for ci in cell_info.values(): + by0, bx0 = ci.bbox_y0, ci.bbox_x0 + by1, bx1 = by0 + ci.bbox_h, bx0 + ci.bbox_w + if by1 <= y0 or by0 >= y1 or bx1 <= x0 or bx0 >= x1: + fully_outside += 1 + elif by0 >= y0 and by1 <= y1 and bx0 >= x0 and bx1 <= x1: + fully_inside += 1 + else: + partial += 1 + return fully_inside, partial, fully_outside + + +def _resolve_da(node: xr.DataTree | xr.DataArray, scale: str | None) -> xr.DataArray: + """Get a DataArray from a DataTree or single-scale element (stays lazy).""" + if not isinstance(node, xr.DataTree): + return node + if scale is None: + raise ValueError("Scale must be provided for DataTree data.") + if scale not in node: + raise ValueError(f"Scale '{scale}' not found. 
def _validate_inputs(
    sdata: SpatialData,
    image_key: str,
    labels_key: str | None,
    shapes_key: str | None,
    scale: str | None,
) -> None:
    """Run all input validation checks (no data loading).

    Parameters
    ----------
    sdata
        Container whose ``images``, ``labels`` and ``shapes`` are inspected.
    image_key
        Must be a key of ``sdata.images``.
    labels_key, shapes_key
        Exactly one of the two must be given, and it must exist in
        ``sdata.labels`` / ``sdata.shapes`` respectively.
    scale
        Required whenever the image, or the selected labels element, is an
        ``xr.DataTree`` (multi-scale).

    Raises
    ------
    ValueError
        On the first failed check, in the order listed above.
    """
    if image_key not in sdata.images:
        raise ValueError(f"Image key '{image_key}' not found, valid keys: {list(sdata.images.keys())}")
    # Exactly one segmentation source: neither and both are rejected.
    if labels_key is None and shapes_key is None:
        raise ValueError("Provide either `labels_key` or `shapes_key`.")
    if labels_key is not None and shapes_key is not None:
        raise ValueError("Use either `labels_key` or `shapes_key`, not both.")
    if labels_key is not None and labels_key not in sdata.labels:
        raise ValueError(f"Labels key '{labels_key}' not found, valid keys: {list(sdata.labels.keys())}")
    if shapes_key is not None and shapes_key not in sdata.shapes:
        raise ValueError(f"Shapes key '{shapes_key}' not found, valid keys: {list(sdata.shapes.keys())}")
    # Multi-scale elements are DataTrees; a scale name is needed to pick one level.
    if labels_key is not None and isinstance(sdata.labels[labels_key], xr.DataTree) and scale is None:
        raise ValueError("When using multi-scale labels, please specify the scale.")
    if isinstance(sdata.images[image_key], xr.DataTree) and scale is None:
        raise ValueError("When using multi-scale images, please specify the scale.")
+ """ + _validate_inputs(sdata, image_key, labels_key, shapes_key, scale) + + # Image DataArray (lazy) + image_da = _resolve_da(sdata.images[image_key], scale) + if "c" not in image_da.dims: + image_da = image_da.expand_dims("c") + + # Labels DataArray (lazy for labels_key, materialized for shapes_key) + if labels_key is not None: + labels_da = _resolve_da(sdata.labels[labels_key], scale) + else: + logg.info("Converting shapes to labels.") + img_shape = {d: image_da.sizes[d] for d in ("y", "x")} + try: + labels_result = rasterize( + sdata.shapes[shapes_key], + ["x", "y"], + min_coordinate=[0, 0], + max_coordinate=[img_shape["x"], img_shape["y"]], + target_coordinate_system="global", + target_unit_to_pixels=1.0, + return_regions_as_labels=True, + ) + except ValueError as e: + raise ValueError( + "Failed to rasterize shapes; geometries may be empty or unsupported. " + "Filter out empty/non-polygon geometries or choose a different shapes_key." + ) from e + if isinstance(labels_result, xr.DataArray): + labels_da = labels_result + else: + labels_da = xr.DataArray(np.asarray(labels_result), dims=["y", "x"]) + + # Align labels to image pixel grid via SpatialData transformations. + # For the shapes_key path, rasterize already targets the image grid, so + # the transforms are identity and this is a cheap no-op. + if labels_key is not None: + image_da, labels_da = _align_to_image_grid( + sdata, image_key, labels_key, image_da, labels_da, align_mode, drop_report + ) + + # Resolve channel names through spatialdata's canonical accessor so we + # honor c_coords set at parse time. Always cast to str. + all_ch = [str(v) for v in get_channel_names(sdata.images[image_key])] + if len(all_ch) != image_da.sizes["c"]: + # Multiscale element where get_channel_names may report from a + # different scale than image_da. Fall back to positional naming. 
def calculate_image_features(
    sdata: SpatialData,
    image_key: str,
    labels_key: str | None = None,
    shapes_key: str | None = None,
    scale: str | None = None,
    channels: list[str] | None = None,
    features: list[str] | str | None = None,
    tile_size: int = 2048,
    overlap_margin: int | Literal["auto"] = "auto",
    align_mode: Literal["strict", "rasterize"] = "strict",
    adata_key_added: str = "morphology",
    invalid_as_zero: bool = True,
    n_jobs: int = 1,
    inplace: bool = True,
) -> ad.AnnData | None:
    """
    Calculate per-cell features from segmentation masks.

    Uses `cp_measure <https://github.com/afermg/cp_measure>`_ for
    CellProfiler-derived features, scikit-image ``regionprops`` for
    morphological/intensity features, and squidpy-specific per-cell
    metrics (summary statistics, GLCM texture, colour histograms).

    Large images are automatically tiled into ``tile_size × tile_size``
    chunks with overlap so that every cell is fully contained in exactly
    one tile.

    Parameters
    ----------
    sdata
        SpatialData object.
    image_key
        Key in ``sdata.images``.
    labels_key
        Key in ``sdata.labels`` with segmentation masks.
    shapes_key
        Key in ``sdata.shapes`` (rasterized to labels internally).
    scale
        Scale level for multi-scale data.
    channels
        Subset of channel names to use, matching those returned by
        :func:`spatialdata.models.get_channel_names`. ``None`` uses all
        channels. Integer indices are not accepted -- always pass names.
    features
        Which features to compute. Accepts a single string or a list of
        strings:

        - ``"cpmeasure:intensity"``, ``"cpmeasure:sizeshape"``,
          ``"cpmeasure:texture"``, ``"cpmeasure:granularity"``,
          ``"cpmeasure:zernike"``, ``"cpmeasure:feret"``,
          ``"cpmeasure:radial"``, ``"cpmeasure:correlation"`` (all four
          correlation measures) or one of
          ``"cpmeasure:correlation_pearson"``,
          ``"cpmeasure:correlation_costes"``,
          ``"cpmeasure:correlation_manders_fold"``,
          ``"cpmeasure:correlation_rwc"``
        - ``"skimage:label"`` (all mask props), ``"skimage:label:area"``
          (single prop), ``"skimage:label+image"`` (all intensity props),
          ``"skimage:label+image:intensity_mean"`` (single prop)
        - ``"squidpy:summary"``, ``"squidpy:texture"``,
          ``"squidpy:color_hist"``

        ``None`` enables all cp_measure features.
    tile_size
        Side length of the tiling grid (pixels).
    overlap_margin
        Overlap around each tile to capture boundary cells.
        ``"auto"`` computes the minimum from the largest cell's bounding box.
    align_mode
        How to handle image/labels coordinate-system alignment when their
        pixel grids do not match.

        * ``"strict"`` (default): require the relative transform between
          image and labels to be identity or an integer-pixel translation.
          Raise otherwise with a hint pointing to :func:`spatialdata.rasterize`.
        * ``"rasterize"``: resample labels onto the image pixel grid using
          :func:`spatialdata.rasterize` when the transforms are not
          pixel-aligned. Logs a warning because this materializes the
          full label grid in memory.
    adata_key_added
        Key under which to store the result in ``sdata.tables``.
    invalid_as_zero
        Replace ``inf`` and ``NaN`` values (including features missing for
        a cell) with zero.
    n_jobs
        Number of parallel jobs for tile processing.
    inplace
        If ``True``, store result in ``sdata.tables``. Otherwise return it.

    Returns
    -------
    :class:`~anndata.AnnData` when ``inplace=False``, otherwise ``None``.

    Raises
    ------
    ValueError
        If no feature is requested, the labels contain no cells, or no tile
        produced any feature. Input validation and alignment helpers raise
        further ``ValueError``/``TypeError`` for invalid keys, channels or
        misaligned grids.
    """
    # --- Parse & validate ---
    parsed = _parse_features(features)
    if not _has_any_features(parsed):
        raise ValueError("No valid features requested.")

    drop_report = DropReport()

    image_da, labels_da, channel_names = _prepare_lazy(
        sdata, image_key, labels_key, shapes_key, scale, channels, align_mode, drop_report
    )

    # --- Warmup: compute centroids without materializing full arrays ---
    cell_info = _compute_centroids(sdata, labels_key, labels_da, scale)
    if not cell_info:
        logg.info(drop_report.summary())
        raise ValueError("No cells found in labels (all zeros).")

    H = int(labels_da.sizes.get("y", labels_da.shape[-2]))
    W = int(labels_da.sizes.get("x", labels_da.shape[-1]))

    # --- Tile ---
    specs = build_tile_specs((H, W), cell_info, tile_size=tile_size, overlap_margin=overlap_margin)
    total_tiles = len(specs)
    logg.info(f"Processing {total_tiles} tiles ({tile_size}x{tile_size}, margin={overlap_margin}).")

    # Build cp_measure config once; the same dict is reused for every tile.
    cp_config = _build_cp_config(parsed.cp_flags, channel_names) if parsed.cp_flags is not None else None

    # --- Process tiles (each worker materializes only its own ~2k x 2k crop) ---
    def _process_one(spec):
        tile_img, tile_lbl = extract_tile_lazy(image_da, labels_da, spec)
        return _featurize_tile(tile_img, tile_lbl, parsed, channel_names, cp_config=cp_config)

    log_every = max(1, total_tiles // 10)
    start_t = time.monotonic()
    tile_dfs: list[pd.DataFrame] = []
    # Results arrive in completion order; the final sort_index restores a
    # deterministic row order.
    results_iter = Parallel(n_jobs=n_jobs, prefer="threads", return_as="generator_unordered")(
        delayed(_process_one)(spec) for spec in specs
    )
    for done, df in enumerate(
        tqdm(results_iter, total=total_tiles, desc="Featurizing tiles", unit="tile"),
        start=1,
    ):
        if df.empty:
            # ``DataFrame.empty`` is also True for a frame that has rows
            # (cells) but no columns. Such a tile was not empty: its cells
            # are being lost, so account for them per cell instead of
            # hiding them behind the empty-tile counter.
            n_lost = len(df.index)
            if n_lost == 0:
                drop_report.empty_tiles += 1
            elif parsed.cp_flags is not None:
                drop_report.cp_measure_no_data += n_lost
            else:
                drop_report.other["no_features_computed"] = (
                    drop_report.other.get("no_features_computed", 0) + n_lost
                )
        else:
            tile_dfs.append(df)
        if done == 1 or done == total_tiles or done % log_every == 0:
            elapsed = time.monotonic() - start_t
            logg.info(f"Tile {done}/{total_tiles} done (elapsed {elapsed:.1f}s).")

    if not tile_dfs:
        logg.info(drop_report.summary())
        raise ValueError("No features computed for any tile.")

    combined = pd.concat(tile_dfs, axis=0)

    # --- Post-process ---
    if invalid_as_zero:
        combined = combined.replace([np.inf, -np.inf], 0).fillna(0)

    # Sort by cell label for deterministic output
    combined = combined.sort_index()

    # --- Build AnnData ---
    adata = ad.AnnData(X=combined.values.astype(np.float32))
    adata.obs_names = [f"cell_{i}" for i in combined.index]
    adata.var_names = list(combined.columns)

    region_key_value = labels_key if labels_key is not None else shapes_key
    adata.uns["spatialdata_attrs"] = {
        "region": region_key_value,
        "region_key": "region",
        "instance_key": "label_id",
    }
    adata.obs["region"] = pd.Categorical([region_key_value] * len(adata))

    # NOTE(review): the positional assignment below assumes that rasterized
    # label IDs, once sorted, follow the order of the shapes index — confirm
    # against spatialdata.rasterize. When any shape is missing from the
    # result, the raw raster label IDs are stored instead.
    if shapes_key is not None and len(sdata.shapes[shapes_key]) == len(adata):
        adata.obs["label_id"] = sdata.shapes[shapes_key].index.values
    else:
        adata.obs["label_id"] = combined.index.values

    logg.info(drop_report.summary())

    if inplace:
        sdata.tables[adata_key_added] = TableModel.parse(adata)
        return None
    return adata
``(y0, x0, y1, x1)`` used for centroid + Non-overlapping region ``(y0, x0, y1, x1)`` used for centroid ownership. Tiles partition the image into a grid of base regions. crop - The extended region ``(y0, x0, y1, x1)`` that includes the overlap + Extended region ``(y0, x0, y1, x1)`` that includes the overlap margin. This is the actual slice extracted from the image/labels. owned_ids Label IDs whose centroid falls inside ``base``. Only these labels @@ -50,16 +72,16 @@ class TileSpec: base: tuple[int, int, int, int] crop: tuple[int, int, int, int] - owned_ids: frozenset[int] + owned_ids: frozenset[int] = field(default_factory=frozenset) # --------------------------------------------------------------------------- -# Centroid computation +# Cell info — in-memory # --------------------------------------------------------------------------- def compute_cell_info(labels: np.ndarray) -> dict[int, CellInfo]: - """Compute centroid and bounding-box size for every label from a numpy array. + """Compute centroid and bounding-box size for every label. Parameters ---------- @@ -70,132 +92,193 @@ def compute_cell_info(labels: np.ndarray) -> dict[int, CellInfo]: ------- Mapping from label ID to :class:`CellInfo`. 
""" + if labels.ndim != 2: + raise ValueError(f"Expected 2-D labels, got shape {labels.shape}") props = regionprops(labels) info: dict[int, CellInfo] = {} for p in props: min_row, min_col, max_row, max_col = p.bbox info[p.label] = CellInfo( label=p.label, - centroid_y=p.centroid[0], - centroid_x=p.centroid[1], + centroid_y=float(p.centroid[0]), + centroid_x=float(p.centroid[1]), bbox_h=max_row - min_row, bbox_w=max_col - min_col, + bbox_y0=min_row, + bbox_x0=min_col, ) return info +# --------------------------------------------------------------------------- +# Cell info — multiscale (read coarse pyramid level, scale back to target) +# --------------------------------------------------------------------------- + + +def _pick_coarsest_scale(label_tree: xr.DataTree) -> str: + """Return the coarsest scale key in a multiscale DataTree.""" + scales = sorted(label_tree.keys(), key=lambda s: int(s.replace("scale", ""))) + return scales[-1] + + +def _scale_dims(node: xr.DataTree | xr.DataArray) -> tuple[int, int]: + """Return (H, W) of a single scale level.""" + if isinstance(node, xr.DataTree): + # spatialdata stores the array under .ds["image"] + da = node.ds["image"] + else: + da = node + return int(da.sizes["y"]), int(da.sizes["x"]) + + def compute_cell_info_multiscale( - labels_node: xr.DataTree, - target_scale: str = "scale0", + label_tree: xr.DataTree, + target_scale: str, ) -> dict[int, CellInfo]: - """Compute centroids using the coarsest scale of a multiscale label pyramid. + """Compute cell info from the coarsest scale, rescaled to target scale. + + Reading the coarsest scale avoids materializing the full-res labels + just to find centroids. + + Parameters + ---------- + label_tree + Multi-scale labels (e.g. ``sdata.labels[key]``). + target_scale + Scale level whose pixel grid the returned centroids/bbox refer to. - Reads only the smallest resolution, then scales coordinates to *target_scale*. 
+ Returns + ------- + Cell info dict, in ``target_scale`` pixel coordinates. """ - available = list(labels_node.keys()) - if not available: - return {} - - def _spatial_size(k: str) -> int: - da = labels_node[k].ds["image"] - h = int(da.sizes.get("y", da.shape[-2])) - w = int(da.sizes.get("x", da.shape[-1])) - return h * w - - coarsest = min(available, key=_spatial_size) - coarse_da = labels_node[coarsest].ds["image"] - coarse_labels = np.asarray(coarse_da.values).squeeze() - - if coarse_labels.ndim != 2: - raise ValueError(f"Expected 2-D labels at scale {coarsest}, got shape {coarse_labels.shape}") - - target_da = labels_node[target_scale].ds["image"] - target_h, target_w = target_da.sizes.get("y", target_da.shape[-2]), target_da.sizes.get("x", target_da.shape[-1]) - coarse_h, coarse_w = coarse_labels.shape - scale_y = target_h / coarse_h - scale_x = target_w / coarse_w - - props = regionprops(coarse_labels) - return { - p.label: CellInfo( - label=p.label, - centroid_y=p.centroid[0] * scale_y, - centroid_x=p.centroid[1] * scale_x, - bbox_h=int(np.ceil((p.bbox[2] - p.bbox[0]) * scale_y)), - bbox_w=int(np.ceil((p.bbox[3] - p.bbox[1]) * scale_x)), + if target_scale not in label_tree: + raise ValueError(f"target_scale '{target_scale}' not found in DataTree. 
Available: {list(label_tree.keys())}") + + coarsest = _pick_coarsest_scale(label_tree) + if coarsest == target_scale: + labels_arr = label_tree[coarsest].ds["image"].values + if labels_arr.ndim > 2: + labels_arr = labels_arr.squeeze() + return compute_cell_info(labels_arr) + + coarse_h, coarse_w = _scale_dims(label_tree[coarsest]) + target_h, target_w = _scale_dims(label_tree[target_scale]) + + sy = target_h / coarse_h + sx = target_w / coarse_w + + labels_arr = label_tree[coarsest].ds["image"].values + if labels_arr.ndim > 2: + labels_arr = labels_arr.squeeze() + coarse_info = compute_cell_info(labels_arr) + + rescaled: dict[int, CellInfo] = {} + for lid, ci in coarse_info.items(): + rescaled[lid] = CellInfo( + label=ci.label, + centroid_y=ci.centroid_y * sy, + centroid_x=ci.centroid_x * sx, + bbox_h=int(np.ceil(ci.bbox_h * sy)), + bbox_w=int(np.ceil(ci.bbox_w * sx)), + bbox_y0=int(np.floor(ci.bbox_y0 * sy)), + bbox_x0=int(np.floor(ci.bbox_x0 * sx)), ) - for p in props - } + return rescaled + + +# --------------------------------------------------------------------------- +# Cell info — tiled (single-scale large labels, no full materialization) +# --------------------------------------------------------------------------- def compute_cell_info_tiled( labels_da: xr.DataArray, chunk_size: int = 4096, ) -> dict[int, CellInfo]: - """Compute centroids by reading label tiles — never materializes the full array. + """Compute cell info by tile-streaming the labels array. - For cells spanning multiple chunks, centroids are computed as - area-weighted means of per-chunk centroids. + Accumulates pixel sums + bbox per label across non-overlapping tiles. + Cells that span tile boundaries are merged correctly because the per- + label statistics are additive. Parameters ---------- labels_da - 2-D (y, x) dask-backed xarray DataArray. + Lazy/eager 2-D xarray DataArray of integer labels. chunk_size - Size of chunks to read at a time. + Tile side length for streaming reads. 
+ + Returns + ------- + Cell info dict in ``labels_da``'s native pixel grid. """ - H = int(labels_da.sizes.get("y", labels_da.shape[-2])) - W = int(labels_da.sizes.get("x", labels_da.shape[-1])) + if labels_da.ndim > 2: + labels_da = labels_da.squeeze() + if labels_da.ndim != 2: + raise ValueError(f"Expected 2-D labels, got shape {labels_da.shape}") + + H, W = int(labels_da.sizes["y"]), int(labels_da.sizes["x"]) - # Per-label accumulators: [sum_y*area, sum_x*area, total_area, min_y, max_y, min_x, max_x] - stats: dict[int, list[float]] = {} + area: dict[int, int] = {} + sum_y: dict[int, float] = {} + sum_x: dict[int, float] = {} + min_y: dict[int, int] = {} + min_x: dict[int, int] = {} + max_y: dict[int, int] = {} + max_x: dict[int, int] = {} for y0 in range(0, H, chunk_size): y1 = min(y0 + chunk_size, H) for x0 in range(0, W, chunk_size): x1 = min(x0 + chunk_size, W) - chunk = labels_da.isel(y=slice(y0, y1), x=slice(x0, x1)).values - if chunk.ndim > 2: - chunk = chunk.squeeze() - - for p in regionprops(chunk): - lid = p.label - cy_global = float(p.centroid[0] + y0) - cx_global = float(p.centroid[1] + x0) - area = float(p.area) - min_row = float(p.bbox[0] + y0) - max_row = float(p.bbox[2] + y0) - min_col = float(p.bbox[1] + x0) - max_col = float(p.bbox[3] + x0) - - if lid not in stats: - stats[lid] = [cy_global * area, cx_global * area, area, min_row, max_row, min_col, max_col] - else: - s = stats[lid] - s[0] += cy_global * area - s[1] += cx_global * area - s[2] += area - s[3] = min(s[3], min_row) - s[4] = max(s[4], max_row) - s[5] = min(s[5], min_col) - s[6] = max(s[6], max_col) - - result: dict[int, CellInfo] = {} - for lid, s in stats.items(): - if lid == 0: - continue - result[lid] = CellInfo( - label=lid, - centroid_y=s[0] / s[2], - centroid_x=s[1] / s[2], - bbox_h=int(s[4] - s[3]), - bbox_w=int(s[6] - s[5]), + tile = labels_da.isel(y=slice(y0, y1), x=slice(x0, x1)).values + if tile.ndim > 2: + tile = tile.squeeze() + flat = tile.ravel() + counts = 
np.bincount(flat) + if counts.size <= 1 or not counts[1:].any(): + continue + + yy, xx = np.mgrid[y0 : y0 + tile.shape[0], x0 : x0 + tile.shape[1]] + sums_y = np.bincount(flat, weights=yy.ravel(), minlength=counts.size) + sums_x = np.bincount(flat, weights=xx.ravel(), minlength=counts.size) + slices = ndi.find_objects(tile) + + for lid_int, count in enumerate(counts): + if lid_int == 0 or count == 0: + continue + area[lid_int] = area.get(lid_int, 0) + int(count) + sum_y[lid_int] = sum_y.get(lid_int, 0.0) + float(sums_y[lid_int]) + sum_x[lid_int] = sum_x.get(lid_int, 0.0) + float(sums_x[lid_int]) + sl = slices[lid_int - 1] + if sl is None: + continue + y_slice, x_slice = sl + ymin = y0 + y_slice.start + ymax = y0 + y_slice.stop - 1 + xmin = x0 + x_slice.start + xmax = x0 + x_slice.stop - 1 + min_y[lid_int] = min(min_y.get(lid_int, ymin), ymin) + min_x[lid_int] = min(min_x.get(lid_int, xmin), xmin) + max_y[lid_int] = max(max_y.get(lid_int, ymax), ymax) + max_x[lid_int] = max(max_x.get(lid_int, xmax), xmax) + + info: dict[int, CellInfo] = {} + for lid_int, a in area.items(): + info[lid_int] = CellInfo( + label=lid_int, + centroid_y=sum_y[lid_int] / a, + centroid_x=sum_x[lid_int] / a, + bbox_h=max_y[lid_int] - min_y[lid_int] + 1, + bbox_w=max_x[lid_int] - min_x[lid_int] + 1, + bbox_y0=min_y[lid_int], + bbox_x0=min_x[lid_int], ) - return result + return info # --------------------------------------------------------------------------- -# Tile spec building +# Tile specification # --------------------------------------------------------------------------- @@ -204,42 +287,41 @@ def _auto_margin(cell_info: dict[int, CellInfo]) -> int: if not cell_info: return 0 max_extent = max(max(c.bbox_h, c.bbox_w) for c in cell_info.values()) - # Centroid can be at most half a bbox away from the cell's edge. - # Add 1 pixel for safety (rounding / off-by-one). 
return int(np.ceil(max_extent / 2)) + 1 def build_tile_specs( - image_shape: tuple[int, int], + shape: tuple[int, int], cell_info: dict[int, CellInfo], tile_size: int = 2048, overlap_margin: int | Literal["auto"] = "auto", ) -> list[TileSpec]: - """Build tile specifications from pre-computed centroids. + """Build tile specifications from precomputed cell info. - No pixel data is needed — only the image dimensions and centroid dict. + The new ``(shape, cell_info)`` signature makes this agnostic to label + materialization — caller supplies dims and centroids, this function + just partitions. Parameters ---------- - image_shape - ``(H, W)`` of the full-resolution image/labels. + shape + ``(H, W)`` of the labels array. cell_info - Pre-computed centroids from :func:`compute_cell_info`, - :func:`compute_cell_info_multiscale`, or :func:`compute_cell_info_tiled`. + Output of :func:`compute_cell_info` (or one of its variants). tile_size Side length of the non-overlapping base grid cells. overlap_margin - Pixel margin added around each base region. ``"auto"`` computes the - minimum margin from the largest cell's bounding box. + Pixel margin added around each base region. ``"auto"`` computes + the minimum margin from the largest cell's bounding box. Returns ------- List of :class:`TileSpec`, one per grid cell that owns at least one - label. Empty tiles (no cells) are omitted. + label. Empty tiles are omitted. 
""" - H, W = image_shape if tile_size <= 0: raise ValueError(f"tile_size must be positive, got {tile_size}") + H, W = shape if isinstance(overlap_margin, str) and overlap_margin == "auto": margin = _auto_margin(cell_info) @@ -248,36 +330,41 @@ def build_tile_specs( if margin < 0: raise ValueError(f"overlap_margin must be non-negative, got {margin}") + # Assign each cell to a base-grid cell by its centroid cell_to_tile: dict[int, tuple[int, int]] = {} for lid, ci in cell_info.items(): - tile_row = min(int(ci.centroid_y) // tile_size, (H - 1) // tile_size) - tile_col = min(int(ci.centroid_x) // tile_size, (W - 1) // tile_size) + tile_row = min(int(ci.centroid_y) // tile_size, max((H - 1) // tile_size, 0)) + tile_col = min(int(ci.centroid_x) // tile_size, max((W - 1) // tile_size, 0)) cell_to_tile[lid] = (tile_row, tile_col) tile_to_cells: dict[tuple[int, int], set[int]] = {} for lid, key in cell_to_tile.items(): tile_to_cells.setdefault(key, set()).add(lid) + n_rows = max((H + tile_size - 1) // tile_size, 1) + n_cols = max((W + tile_size - 1) // tile_size, 1) + specs: list[TileSpec] = [] - for (row, col), owned in sorted(tile_to_cells.items()): - by0 = row * tile_size - bx0 = col * tile_size - by1 = min(by0 + tile_size, H) - bx1 = min(bx0 + tile_size, W) - - cy0 = max(by0 - margin, 0) - cx0 = max(bx0 - margin, 0) - cy1 = min(by1 + margin, H) - cx1 = min(bx1 + margin, W) - - specs.append( - TileSpec( - base=(by0, bx0, by1, bx1), - crop=(cy0, cx0, cy1, cx1), - owned_ids=frozenset(owned), + for row in range(n_rows): + for col in range(n_cols): + owned = tile_to_cells.get((row, col), set()) + if not owned: + continue + by0 = row * tile_size + bx0 = col * tile_size + by1 = min(by0 + tile_size, H) + bx1 = min(bx0 + tile_size, W) + cy0 = max(by0 - margin, 0) + cx0 = max(bx0 - margin, 0) + cy1 = min(by1 + margin, H) + cx1 = min(bx1 + margin, W) + specs.append( + TileSpec( + base=(by0, bx0, by1, bx1), + crop=(cy0, cx0, cy1, cx1), + owned_ids=frozenset(owned), + ) ) - ) - 
return specs @@ -286,30 +373,41 @@ def build_tile_specs( # --------------------------------------------------------------------------- +def _zero_non_owned(tile_labels: np.ndarray, owned: frozenset[int]) -> np.ndarray: + """Return a copy of ``tile_labels`` with non-owned labels set to 0.""" + if not owned: + return np.zeros_like(tile_labels) + owned_arr = np.fromiter(owned, dtype=tile_labels.dtype, count=len(owned)) + keep = np.isin(tile_labels, owned_arr) + return np.where(keep, tile_labels, 0) + + def extract_tile( image: np.ndarray, labels: np.ndarray, spec: TileSpec, ) -> tuple[np.ndarray, np.ndarray]: - """Extract a tile from numpy arrays, zeroing out non-owned cells. + """Extract a tile's image and mask from in-memory arrays. Parameters ---------- image - ``(C, H, W)`` numpy array. + 3-D array of shape ``(C, H, W)``. labels - ``(H, W)`` numpy label array. + 2-D integer label image of shape ``(H, W)``. spec Tile specification. Returns ------- - tile_image, tile_labels + tile_image + Cropped image of shape ``(C, crop_h, crop_w)``. + tile_labels + Cropped label image with non-owned cells zeroed out. """ cy0, cx0, cy1, cx1 = spec.crop tile_image = image[:, cy0:cy1, cx0:cx1] - tile_labels = labels[cy0:cy1, cx0:cx1].copy() - _zero_non_owned(tile_labels, spec.owned_ids) + tile_labels = _zero_non_owned(labels[cy0:cy1, cx0:cx1], spec.owned_ids) return tile_image, tile_labels @@ -318,32 +416,30 @@ def extract_tile_lazy( labels_da: xr.DataArray, spec: TileSpec, ) -> tuple[np.ndarray, np.ndarray]: - """Extract a tile from dask-backed xarray arrays. - - Materializes only the tile's crop region (~2k×2k), not the full image. + """Extract a tile by lazy slicing then materializing only the crop. Parameters ---------- image_da - ``(c, y, x)`` dask-backed DataArray. + Lazy DataArray of shape ``(C, H, W)``. labels_da - ``(y, x)`` dask-backed DataArray. + Lazy 2-D DataArray of labels. spec Tile specification. Returns ------- tile_image - ``(C, crop_h, crop_w)`` numpy array. 
+ Numpy ``(C, crop_h, crop_w)``. tile_labels - ``(crop_h, crop_w)`` numpy array with non-owned cells zeroed. + Numpy ``(crop_h, crop_w)`` with non-owned cells zeroed out. """ cy0, cx0, cy1, cx1 = spec.crop tile_image = image_da.isel(y=slice(cy0, cy1), x=slice(cx0, cx1)).values - tile_labels = labels_da.isel(y=slice(cy0, cy1), x=slice(cx0, cx1)).values.copy() - if tile_labels.ndim > 2: - tile_labels = tile_labels.squeeze() - _zero_non_owned(tile_labels, spec.owned_ids) + tile_labels_raw = labels_da.isel(y=slice(cy0, cy1), x=slice(cx0, cx1)).values + if tile_labels_raw.ndim > 2: + tile_labels_raw = tile_labels_raw.squeeze() + tile_labels = _zero_non_owned(tile_labels_raw, spec.owned_ids) return tile_image, tile_labels @@ -351,58 +447,27 @@ def extract_labels_tile_lazy( labels_da: xr.DataArray, spec: TileSpec, ) -> np.ndarray: - """Extract a labels-only tile from a dask-backed DataArray. + """Extract a labels-only tile from a lazy DataArray. - Like :func:`extract_tile_lazy` but skips the image entirely. - Materializes only the crop region. + Like :func:`extract_tile_lazy` but skips the image entirely. Materializes + only the crop region. Parameters ---------- labels_da - ``(y, x)`` dask-backed DataArray. + Lazy 2-D DataArray of labels. spec Tile specification. Returns ------- - ``(crop_h, crop_w)`` numpy array with non-owned cells zeroed. + Numpy ``(crop_h, crop_w)`` with non-owned cells zeroed out. """ cy0, cx0, cy1, cx1 = spec.crop - tile_labels = labels_da.isel(y=slice(cy0, cy1), x=slice(cx0, cx1)).values.copy() - if tile_labels.ndim > 2: - tile_labels = tile_labels.squeeze() - _zero_non_owned(tile_labels, spec.owned_ids) - return tile_labels - - -def _zero_non_owned(tile_labels: np.ndarray, owned_ids: frozenset[int]) -> None: - """Zero out labels not in *owned_ids* (in-place). - - Uses a boolean lookup table indexed by label ID for O(n) per-pixel - cost when label IDs are dense. 
Falls back to :func:`numpy.isin` - when the maximum label ID is large relative to the tile size, so - sparse-but-large ID spaces (e.g. globally-unique segmentation IDs - from multi-FOV pipelines) don't allocate an oversized LUT. - """ - if tile_labels.size == 0: - return - - if not owned_ids: - tile_labels[:] = 0 - return - - max_id = int(tile_labels.max()) - # LUT is cheaper than np.isin only when max_id fits in roughly one - # tile's worth of bool entries; above that, the alloc dominates. - if max_id < tile_labels.size: - lut = np.zeros(max_id + 1, dtype=bool) - for lid in owned_ids: - if lid <= max_id: - lut[lid] = True - tile_labels[~lut[tile_labels]] = 0 - else: - owned_arr = np.fromiter(owned_ids, dtype=tile_labels.dtype, count=len(owned_ids)) - tile_labels[~np.isin(tile_labels, owned_arr)] = 0 + tile_labels_raw = labels_da.isel(y=slice(cy0, cy1), x=slice(cx0, cx1)).values + if tile_labels_raw.ndim > 2: + tile_labels_raw = tile_labels_raw.squeeze() + return _zero_non_owned(tile_labels_raw, spec.owned_ids) # --------------------------------------------------------------------------- @@ -410,35 +475,29 @@ def _zero_non_owned(tile_labels: np.ndarray, owned_ids: frozenset[int]) -> None: # --------------------------------------------------------------------------- -def verify_coverage( - all_label_ids: set[int], - specs: list[TileSpec], -) -> None: +def verify_coverage(label_ids: set[int], specs: list[TileSpec]) -> None: """Assert that tile specs provide full, non-overlapping cell coverage. Parameters ---------- - all_label_ids - Set of all nonzero label IDs expected in the image. + label_ids + Set of all expected nonzero label IDs. specs - Tile specifications to verify. + Tile specifications. Raises ------ - ValueError - If any cell is missing or assigned to more than one tile. + AssertionError + If any cell is missing, duplicated, or unknown. 
""" owned_union: set[int] = set() for spec in specs: overlap = owned_union & spec.owned_ids - if overlap: - raise ValueError(f"Cells {overlap} assigned to multiple tiles") + assert not overlap, f"Cells {overlap} assigned to multiple tiles" owned_union |= spec.owned_ids - missing = all_label_ids - owned_union - if missing: - raise ValueError(f"Cells {missing} not assigned to any tile") + missing = label_ids - owned_union + assert not missing, f"Cells {missing} not assigned to any tile" - extra = owned_union - all_label_ids - if extra: - raise ValueError(f"Tile specs reference non-existent labels {extra}") + extra = owned_union - label_ids + assert not extra, f"Tile specs reference non-existent labels {extra}" diff --git a/tests/_images/TilingVisual_tile_assignment_gap.png b/tests/_images/TilingVisual_tile_assignment_gap.png new file mode 100644 index 000000000..1fb306de3 Binary files /dev/null and b/tests/_images/TilingVisual_tile_assignment_gap.png differ diff --git a/tests/_images/TilingVisual_tile_assignment_touching.png b/tests/_images/TilingVisual_tile_assignment_touching.png new file mode 100644 index 000000000..f8b3ac98d Binary files /dev/null and b/tests/_images/TilingVisual_tile_assignment_touching.png differ diff --git a/tests/experimental/test_calculate_image_features.py b/tests/experimental/test_calculate_image_features.py new file mode 100644 index 000000000..4a7e50e41 --- /dev/null +++ b/tests/experimental/test_calculate_image_features.py @@ -0,0 +1,640 @@ +"""Tests for calculate_image_features. + +Uses a small synthetic SpatialData (200×200 image, ~20 cells) so tests +run in seconds without downloading real data. 
+""" + +from __future__ import annotations + +import anndata as ad +import numpy as np +import pandas as pd +import pytest +import xarray as xr +from spatialdata import SpatialData +from spatialdata.models import Image2DModel, Labels2DModel + +import squidpy as sq + + +@pytest.fixture() +def sdata_synthetic(): + """Synthetic SpatialData with a small 3-channel image and ~20 rectangular cells.""" + rng = np.random.default_rng(42) + H, W, C = 200, 200, 3 + + image_data = rng.integers(0, 255, (C, H, W), dtype=np.uint8) + image_xr = xr.DataArray( + image_data, + dims=["c", "y", "x"], + coords={"c": ["R", "G", "B"]}, + ) + + # Place ~20 rectangular cells in a grid (non-overlapping, 30×30 each) + labels_data = np.zeros((H, W), dtype=np.int32) + cell_id = 0 + for y in range(10, H - 30, 40): + for x in range(10, W - 30, 40): + cell_id += 1 + labels_data[y : y + 30, x : x + 30] = cell_id + + labels_xr = xr.DataArray(labels_data, dims=["y", "x"]) + + return SpatialData( + images={"test_img": Image2DModel.parse(image_xr)}, + labels={"test_labels": Labels2DModel.parse(labels_xr)}, + ) + + +class TestCalculateImageFeatures: + """Tests for calculate_image_features function.""" + + # --- Basic functionality --- + + def test_skimage_label_inplace(self, sdata_synthetic): + """Inplace stores AnnData in sdata.tables.""" + sq.experimental.im.calculate_image_features( + sdata_synthetic, + image_key="test_img", + labels_key="test_labels", + features=["skimage:label"], + adata_key_added="morphology", + inplace=True, + ) + + assert "morphology" in sdata_synthetic.tables + adata = sdata_synthetic.tables["morphology"] + assert adata.n_obs > 0 + assert adata.n_vars > 0 + assert "spatialdata_attrs" in adata.uns + assert adata.uns["spatialdata_attrs"]["region"] == "test_labels" + assert "region" in adata.obs + assert "label_id" in adata.obs + + def test_not_inplace_returns_anndata(self, sdata_synthetic): + result = sq.experimental.im.calculate_image_features( + sdata_synthetic, + 
image_key="test_img", + labels_key="test_labels", + features=["skimage:label"], + inplace=False, + ) + assert isinstance(result, ad.AnnData) + assert result.n_obs > 0 + assert result.n_vars > 0 + + # --- Feature sources --- + + def test_skimage_label_properties(self, sdata_synthetic): + """skimage:label produces mask-only morphological features.""" + result = sq.experimental.im.calculate_image_features( + sdata_synthetic, + image_key="test_img", + labels_key="test_labels", + features=["skimage:label"], + inplace=False, + ) + assert "area" in result.var_names + + def test_skimage_label_single_property(self, sdata_synthetic): + """Fine-grained: skimage:label:area → only area column.""" + result = sq.experimental.im.calculate_image_features( + sdata_synthetic, + image_key="test_img", + labels_key="test_labels", + features=["skimage:label:area"], + inplace=False, + ) + assert list(result.var_names) == ["area"] + + def test_skimage_intensity(self, sdata_synthetic): + """skimage:label+image produces per-channel intensity features.""" + result = sq.experimental.im.calculate_image_features( + sdata_synthetic, + image_key="test_img", + labels_key="test_labels", + features=["skimage:label+image"], + inplace=False, + ) + assert result.n_vars > 0 + assert any("_" in col for col in result.var_names) + + def test_skimage_intensity_single_property(self, sdata_synthetic): + """Fine-grained: only intensity_mean per channel.""" + result = sq.experimental.im.calculate_image_features( + sdata_synthetic, + image_key="test_img", + labels_key="test_labels", + features=["skimage:label+image:intensity_mean"], + inplace=False, + ) + assert all(col.startswith("intensity_mean_") for col in result.var_names) + assert not any(col.startswith("intensity_max") for col in result.var_names) + + def test_cpmeasure_sizeshape(self, sdata_synthetic): + result = sq.experimental.im.calculate_image_features( + sdata_synthetic, + image_key="test_img", + labels_key="test_labels", + 
features=["cpmeasure:sizeshape"], + inplace=False, + ) + assert isinstance(result, ad.AnnData) + assert result.n_obs > 0 + assert any("Area" in col for col in result.var_names) + + def test_cpmeasure_intensity(self, sdata_synthetic): + result = sq.experimental.im.calculate_image_features( + sdata_synthetic, + image_key="test_img", + labels_key="test_labels", + features=["cpmeasure:intensity"], + inplace=False, + ) + assert isinstance(result, ad.AnnData) + assert result.n_obs > 0 + # intensity features are per-channel, so column names contain "__" + assert any("__" in col for col in result.var_names) + + def test_cpmeasure_correlation(self, sdata_synthetic): + result = sq.experimental.im.calculate_image_features( + sdata_synthetic, + image_key="test_img", + labels_key="test_labels", + features=["cpmeasure:correlation"], + inplace=False, + ) + assert isinstance(result, ad.AnnData) + assert result.n_obs > 0 + + def test_cpmeasure_default_all(self, sdata_synthetic): + """features=None enables all cp_measure features.""" + result = sq.experimental.im.calculate_image_features( + sdata_synthetic, + image_key="test_img", + labels_key="test_labels", + inplace=False, + ) + assert isinstance(result, ad.AnnData) + assert result.n_obs > 0 + assert result.n_vars > 50 + + def test_squidpy_summary(self, sdata_synthetic): + result = sq.experimental.im.calculate_image_features( + sdata_synthetic, + image_key="test_img", + labels_key="test_labels", + features=["squidpy:summary"], + inplace=False, + ) + assert isinstance(result, ad.AnnData) + assert result.n_obs > 0 + assert any(col.startswith("summary_mean") for col in result.var_names) + + def test_squidpy_texture(self, sdata_synthetic): + result = sq.experimental.im.calculate_image_features( + sdata_synthetic, + image_key="test_img", + labels_key="test_labels", + features=["squidpy:texture"], + inplace=False, + ) + assert isinstance(result, ad.AnnData) + assert result.n_obs > 0 + assert any(col.startswith("texture_contrast") for 
col in result.var_names) + + def test_squidpy_color_hist(self, sdata_synthetic): + result = sq.experimental.im.calculate_image_features( + sdata_synthetic, + image_key="test_img", + labels_key="test_labels", + features=["squidpy:color_hist"], + inplace=False, + ) + assert isinstance(result, ad.AnnData) + assert result.n_obs > 0 + assert any(col.startswith("color_hist_bin") for col in result.var_names) + + # --- Mixed sources --- + + def test_mixed_cpmeasure_and_skimage(self, sdata_synthetic): + result = sq.experimental.im.calculate_image_features( + sdata_synthetic, + image_key="test_img", + labels_key="test_labels", + features=["cpmeasure:sizeshape", "skimage:label:area"], + inplace=False, + ) + assert "area" in result.var_names + assert any("Area" in col for col in result.var_names) + + # --- Validation errors --- + + def test_invalid_image_key(self, sdata_synthetic): + with pytest.raises(ValueError, match="Image key 'nonexistent' not found"): + sq.experimental.im.calculate_image_features( + sdata_synthetic, + image_key="nonexistent", + labels_key="test_labels", + features=["skimage:label"], + ) + + def test_invalid_labels_key(self, sdata_synthetic): + with pytest.raises(ValueError, match="Labels key 'nonexistent' not found"): + sq.experimental.im.calculate_image_features( + sdata_synthetic, + image_key="test_img", + labels_key="nonexistent", + features=["skimage:label"], + ) + + def test_both_labels_and_shapes_error(self, sdata_synthetic): + with pytest.raises(ValueError, match="Use either"): + sq.experimental.im.calculate_image_features( + sdata_synthetic, + image_key="test_img", + labels_key="test_labels", + shapes_key="fake", + features=["skimage:label"], + ) + + def test_missing_labels_and_shapes(self, sdata_synthetic): + with pytest.raises(ValueError, match="Provide either"): + sq.experimental.im.calculate_image_features( + sdata_synthetic, + image_key="test_img", + features=["skimage:label"], + ) + + def test_invalid_feature(self, sdata_synthetic): + with 
pytest.raises(ValueError, match="Unknown feature"): + sq.experimental.im.calculate_image_features( + sdata_synthetic, + image_key="test_img", + labels_key="test_labels", + features=["nonexistent:measurement"], + ) + + def test_no_valid_features(self, sdata_synthetic): + with pytest.raises(ValueError, match="No valid features requested"): + sq.experimental.im.calculate_image_features( + sdata_synthetic, + image_key="test_img", + labels_key="test_labels", + features=[], + inplace=False, + ) + + def test_dimension_mismatch_identity_succeeds_with_overlap(self, capsys): + """Image and labels have different dims but identity transforms; + we crop to the overlap rectangle and process the cells that fall inside. + Regression for the PR #982 'fails despite alignment transform' note. + """ + rng = np.random.default_rng(42) + image_xr = xr.DataArray( + rng.integers(0, 255, (3, 200, 200), dtype=np.uint8), + dims=["c", "y", "x"], + coords={"c": ["R", "G", "B"]}, + ) + labels_arr = np.zeros((100, 100), dtype=np.int32) + labels_arr[10:40, 10:40] = 1 + labels_arr[60:90, 60:90] = 2 + labels_xr = xr.DataArray(labels_arr, dims=["y", "x"]) + sdata = SpatialData( + images={"img": Image2DModel.parse(image_xr)}, + labels={"lbl": Labels2DModel.parse(labels_xr)}, + ) + + adata = sq.experimental.im.calculate_image_features( + sdata, + image_key="img", + labels_key="lbl", + features=["skimage:label"], + inplace=False, + ) + assert adata is not None and adata.n_obs == 2 + captured = capsys.readouterr() + # No cells dropped: labels (100, 100) sits fully inside image (200, 200) at origin + assert "Cell drop report" in captured.out + + # --- Channel selection --- + + def test_channel_selection_by_name(self, sdata_synthetic): + """Selecting a single channel reduces feature columns.""" + result_all = sq.experimental.im.calculate_image_features( + sdata_synthetic, + image_key="test_img", + labels_key="test_labels", + features=["skimage:label+image:intensity_mean"], + inplace=False, + ) + # 
Image2DModel.parse converts channel coords to integers [0,1,2] + result_one = sq.experimental.im.calculate_image_features( + sdata_synthetic, + image_key="test_img", + labels_key="test_labels", + channels=["0"], + features=["skimage:label+image:intensity_mean"], + inplace=False, + ) + # All channels → 3 columns; one channel → 1 column + assert result_all.n_vars == 3 + assert result_one.n_vars == 1 + assert "intensity_mean_0" in result_one.var_names + + def test_channel_selection_rejects_int(self, sdata_synthetic): + """Integer channel indices are no longer accepted -- names only.""" + with pytest.raises(TypeError, match="channels must contain strings"): + sq.experimental.im.calculate_image_features( + sdata_synthetic, + image_key="test_img", + labels_key="test_labels", + channels=[0], # int, not str -- should fail validation + features=["squidpy:summary"], + inplace=False, + ) + + def test_channel_selection_invalid(self, sdata_synthetic): + with pytest.raises(ValueError, match="Channel 'DAPI' not found"): + sq.experimental.im.calculate_image_features( + sdata_synthetic, + image_key="test_img", + labels_key="test_labels", + channels=["DAPI"], + features=["skimage:label"], + ) + + # --- Tiled vs non-tiled equivalence --- + + def test_tiled_vs_single_tile_equivalence(self, sdata_synthetic): + """Tile-invariant features should be identical whether we tile or not. + + Position-dependent features (centroid, perimeter_crofton) are expected + to differ across tile boundaries, so we test with ``area`` and + ``squidpy:summary`` which depend only on the cell's pixel values. 
+ """ + kw = { + "image_key": "test_img", + "labels_key": "test_labels", + "features": ["skimage:label:area", "squidpy:summary"], + "inplace": False, + "invalid_as_zero": True, + } + # Single tile (tile_size >= image → no tiling) + result_single = sq.experimental.im.calculate_image_features(sdata_synthetic, tile_size=1000, **kw) + # Multiple tiles (tile_size=100 → 4 tiles on 200×200) + result_tiled = sq.experimental.im.calculate_image_features(sdata_synthetic, tile_size=100, **kw) + + # Same cells, same features + assert result_single.n_obs == result_tiled.n_obs + assert set(result_single.var_names) == set(result_tiled.var_names) + + # Align columns and rows for comparison + common_cols = list(result_single.var_names) + df_single = pd.DataFrame(result_single.X, index=result_single.obs["label_id"].values, columns=common_cols) + df_tiled = pd.DataFrame( + result_tiled[:, common_cols].X, index=result_tiled.obs["label_id"].values, columns=common_cols + ) + df_single = df_single.sort_index() + df_tiled = df_tiled.sort_index() + + np.testing.assert_array_equal(df_single.index, df_tiled.index) + np.testing.assert_allclose(df_single.values, df_tiled.values, rtol=1e-5, atol=1e-5) + + # --- Parallelization --- + + def test_n_jobs_produces_same_result(self, sdata_synthetic): + """n_jobs>1 produces the same result as n_jobs=1.""" + kw = { + "image_key": "test_img", + "labels_key": "test_labels", + "features": ["skimage:label:area"], + "inplace": False, + } + result_seq = sq.experimental.im.calculate_image_features(sdata_synthetic, n_jobs=1, **kw) + result_par = sq.experimental.im.calculate_image_features(sdata_synthetic, n_jobs=2, **kw) + + assert result_seq.n_obs == result_par.n_obs + np.testing.assert_array_equal( + result_seq.X[np.argsort(result_seq.obs["label_id"].values)], + result_par.X[np.argsort(result_par.obs["label_id"].values)], + ) + + +# --------------------------------------------------------------------------- +# Per-PR-#982-concern regression tests +# 
--------------------------------------------------------------------------- + + +def _toy_sdata( + image_shape: tuple[int, int] = (200, 200), + n_channels: int = 3, + channel_names: list[str] | None = None, + labels_shape: tuple[int, int] | None = None, + labels_translation: tuple[float, float] | None = None, + labels_scale: tuple[float, float] | None = None, + label_ids: list[int] | None = None, +) -> SpatialData: + """Build a synthetic SpatialData with controllable label/image transforms.""" + from spatialdata.transformations import Scale, Translation, set_transformation + + rng = np.random.default_rng(0) + H, W = image_shape + image_data = rng.integers(0, 255, (n_channels, H, W), dtype=np.uint8) + image_xr = xr.DataArray(image_data, dims=["c", "y", "x"]) + + LH, LW = labels_shape if labels_shape is not None else image_shape + labels_data = np.zeros((LH, LW), dtype=np.int32) + ids = label_ids if label_ids is not None else list(range(1, 6)) + cell_h, cell_w = max(LH // 8, 4), max(LW // 8, 4) + for i, lid in enumerate(ids): + row = i // 3 + col = i % 3 + y0 = 10 + row * (cell_h + 6) + x0 = 10 + col * (cell_w + 6) + if y0 + cell_h > LH or x0 + cell_w > LW: + continue + labels_data[y0 : y0 + cell_h, x0 : x0 + cell_w] = lid + labels_xr = xr.DataArray(labels_data, dims=["y", "x"]) + + img_el = ( + Image2DModel.parse(image_xr, c_coords=channel_names) + if channel_names is not None + else Image2DModel.parse(image_xr) + ) + lbl_el = Labels2DModel.parse(labels_xr) + + if labels_translation is not None: + ty, tx = labels_translation + set_transformation(lbl_el, Translation([tx, ty], axes=("x", "y")), "global") + if labels_scale is not None: + sy, sx = labels_scale + set_transformation(lbl_el, Scale([sx, sy], axes=("x", "y")), "global") + + return SpatialData(images={"img": img_el}, labels={"lbl": lbl_el}) + + +class TestPR982Concerns: + """Regression tests for the six open concerns on PR #982.""" + + # -- Concern 1: channel names are str-typed in output columns -- + + def 
test_concern1_channel_str_names_in_columns(self): + sdata = _toy_sdata(channel_names=["DAPI", "CD3", "CD8"]) + adata = sq.experimental.im.calculate_image_features( + sdata, + image_key="img", + labels_key="lbl", + features=["squidpy:summary"], + inplace=False, + ) + cols = list(adata.var_names) + assert any("_DAPI" in c for c in cols) + assert any("_CD3" in c for c in cols) + assert any("_CD8" in c for c in cols) + # Make sure the numeric-fallback names did not slip in: + assert not any(c.endswith("_0") or c.endswith("_1") or c.endswith("_2") for c in cols) + + # -- Concern 2: progress logs are emitted -- + + def test_concern2_progress_log_emitted(self, capsys): + sdata = _toy_sdata() + sq.experimental.im.calculate_image_features( + sdata, + image_key="img", + labels_key="lbl", + features=["skimage:label:area"], + tile_size=80, # forces >1 tile on 200x200 + inplace=False, + ) + captured = capsys.readouterr() + import re + + # spatialdata's logger renders via rich and injects ANSI escapes + # between tokens, so the digits in "Tile 1/9" are wrapped. + ansi_re = re.compile(r"\x1b\[[0-9;]*m") + plain = ansi_re.sub("", captured.out) + assert re.search(r"Tile \d+/\d+", plain), f"no progress log in:\n{plain}" + + # -- Concern 3 (a): identity transforms, dim mismatch -> overlap path -- + + def test_concern3_identity_dim_mismatch_uses_overlap(self, capsys): + # labels (100, 100) sits inside image (200, 200) at the origin under Identity. 
+ sdata = _toy_sdata(image_shape=(200, 200), labels_shape=(100, 100)) + adata = sq.experimental.im.calculate_image_features( + sdata, + image_key="img", + labels_key="lbl", + features=["skimage:label:area"], + inplace=False, + ) + captured = capsys.readouterr() + assert "Cell drop report" in captured.out + # All cells fit inside the (100, 100) labels image, so none are dropped: + assert adata.n_obs > 0 + + # -- Concern 3 (b): integer translation -> overlap is offset -- + + def test_concern3_integer_translation(self): + sdata = _toy_sdata( + image_shape=(200, 200), + labels_shape=(100, 100), + labels_translation=(50.0, 0.0), + ) + adata = sq.experimental.im.calculate_image_features( + sdata, + image_key="img", + labels_key="lbl", + features=["skimage:label:area"], + inplace=False, + ) + assert adata.n_obs > 0 + + # -- Concern 3 (c): non-integer scale, strict -> raises with hint -- + + def test_concern3_strict_raises_on_non_integer_scale(self): + sdata = _toy_sdata( + image_shape=(200, 200), + labels_shape=(100, 100), + labels_scale=(1.7, 1.7), + ) + with pytest.raises(ValueError, match="spatialdata.rasterize"): + sq.experimental.im.calculate_image_features( + sdata, + image_key="img", + labels_key="lbl", + features=["skimage:label:area"], + align_mode="strict", + inplace=False, + ) + + # -- Concern 3 (d): align_mode='rasterize' resamples and proceeds -- + + def test_concern3_rasterize_path_succeeds(self, capsys): + sdata = _toy_sdata( + image_shape=(200, 200), + labels_shape=(100, 100), + labels_scale=(1.7, 1.7), + ) + adata = sq.experimental.im.calculate_image_features( + sdata, + image_key="img", + labels_key="lbl", + features=["skimage:label:area"], + align_mode="rasterize", + inplace=False, + ) + assert adata.n_obs > 0 + captured = capsys.readouterr() + # rasterize path emits a warning about materialization + assert "rasterize" in captured.out.lower() or "rasterize" in captured.err.lower() + + # -- Concern 4: channel subset selection -- + + def 
test_concern4_channel_subset_by_name(self): + sdata = _toy_sdata(n_channels=4, channel_names=["c0", "c1", "c2", "c3"]) + adata = sq.experimental.im.calculate_image_features( + sdata, + image_key="img", + labels_key="lbl", + features=["squidpy:summary"], + channels=["c0", "c2"], + inplace=False, + ) + cols = list(adata.var_names) + assert any("_c0" in c for c in cols) + assert any("_c2" in c for c in cols) + assert not any("_c1" in c for c in cols) + assert not any("_c3" in c for c in cols) + + # -- Concern 5: spatialdata_attrs on output table -- + + def test_concern5_spatialdata_attrs_present(self): + sdata = _toy_sdata() + sq.experimental.im.calculate_image_features( + sdata, + image_key="img", + labels_key="lbl", + features=["skimage:label:area"], + inplace=True, + adata_key_added="morphology", + ) + attrs = sdata.tables["morphology"].uns["spatialdata_attrs"] + assert "region" in attrs + assert "region_key" in attrs + assert "instance_key" in attrs + assert attrs["region"] == "lbl" + + # -- Concern 6: non-contiguous label IDs survive cp_measure roundtrip -- + + def test_concern6_non_contiguous_label_ids(self): + sdata = _toy_sdata(label_ids=[1, 37, 82]) + adata = sq.experimental.im.calculate_image_features( + sdata, + image_key="img", + labels_key="lbl", + features=["skimage:label:area"], + inplace=False, + ) + observed = set(adata.obs["label_id"].astype(int).tolist()) + assert {1, 37, 82}.issubset(observed) diff --git a/tests/experimental/test_tiling.py b/tests/experimental/test_tiling.py new file mode 100644 index 000000000..d8cf01b4a --- /dev/null +++ b/tests/experimental/test_tiling.py @@ -0,0 +1,530 @@ +"""Tests for cell-aware tiling logic. + +Uses a deterministic "brick-pattern" grid of rectangular cells on a +500×500 image. Even rows are aligned; odd rows are shifted right by +half a cell width, like bricks in a wall. The image divides into 4 +tiles of 250×250. Because cell positions are predictable we can check +*exactly* which cell lands in which tile. 
+""" + +from __future__ import annotations + +import matplotlib.pyplot as plt +import numpy as np +import pytest +import xarray as xr + +from squidpy.experimental.im._tiling import ( + build_tile_specs, + compute_cell_info, + compute_cell_info_multiscale, + compute_cell_info_tiled, + extract_tile, + extract_tile_lazy, + verify_coverage, +) +from tests.conftest import PlotTester, PlotTesterMeta + +# --------------------------------------------------------------------------- +# Brick-pattern fixture +# --------------------------------------------------------------------------- + +_IMAGE_SIZE = 500 +_CELL_H = 20 +_CELL_W = 30 + + +def _make_brick_labels( + image_size: int = _IMAGE_SIZE, + cell_h: int = _CELL_H, + cell_w: int = _CELL_W, + gap: int = 10, +) -> tuple[np.ndarray, dict[int, tuple[float, float]]]: + """Create a brick-pattern label image and return centroids. + + Parameters + ---------- + image_size + Side length of the square image. + cell_h, cell_w + Height and width of each rectangular cell. + gap + Gap between cells (0 = touching). + + Returns + ------- + labels + ``(image_size, image_size)`` int32 array. + centroids + Mapping from label ID → ``(centroid_y, centroid_x)``. 
+ """ + labels = np.zeros((image_size, image_size), dtype=np.int32) + centroids: dict[int, tuple[float, float]] = {} + + step_y = cell_h + gap + step_x = cell_w + gap + cell_id = 0 + + row_idx = 0 + y = gap // 2 # start with half-gap from top + while y + cell_h <= image_size: + # Odd rows shift right by half a cell+gap step + x_offset = (step_x // 2) if (row_idx % 2 == 1) else 0 + x = x_offset + gap // 2 + while x + cell_w <= image_size: + cell_id += 1 + labels[y : y + cell_h, x : x + cell_w] = cell_id + # Match regionprops centroid: mean of pixel indices [y, y+cell_h-1] + cy = y + (cell_h - 1) / 2.0 + cx = x + (cell_w - 1) / 2.0 + centroids[cell_id] = (cy, cx) + x += step_x + y += step_y + row_idx += 1 + + return labels, centroids + + +def _make_image(image_size: int = _IMAGE_SIZE, n_channels: int = 3) -> np.ndarray: + rng = np.random.default_rng(42) + return rng.integers(0, 255, (n_channels, image_size, image_size), dtype=np.uint8) + + +def _expected_tile_key(cy: float, cx: float, tile_size: int, image_size: int) -> tuple[int, int]: + """Which tile base-grid cell a centroid falls into.""" + max_row = (image_size - 1) // tile_size + max_col = (image_size - 1) // tile_size + row = min(int(cy) // tile_size, max_row) + col = min(int(cx) // tile_size, max_col) + return (row, col) + + +_TILE_SIZE = 250 # 500 / 250 = 2×2 = 4 tiles + + +def _specs_from_labels(labels, tile_size=_TILE_SIZE, overlap_margin="auto"): + """Convenience: compute cell info + build tile specs from a numpy label array.""" + cell_info = compute_cell_info(labels) + return build_tile_specs(labels.shape, cell_info, tile_size=tile_size, overlap_margin=overlap_margin) + + +def _label_ids(labels): + """All nonzero label IDs as a set.""" + ids = set(np.unique(labels).tolist()) + ids.discard(0) + return ids + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + 
@pytest.fixture(params=[10, 0], ids=["gap=10", "gap=0"])
def brick_labels(request):
    """Brick-pattern labels with gap (non-touching) or without (touching).

    Returns
    -------
    A ``(labels, centroids, gap)`` tuple, where ``labels`` and ``centroids``
    come from ``_make_brick_labels(gap=gap)``.
    """
    gap = request.param
    labels, centroids = _make_brick_labels(gap=gap)
    return labels, centroids, gap


@pytest.fixture()
def brick_image():
    """Image array produced by ``_make_image()`` with its default arguments."""
    return _make_image()


# ---------------------------------------------------------------------------
# build_tile_specs — deterministic checks
# ---------------------------------------------------------------------------


class TestBuildTileSpecs:
    """Checks of ``build_tile_specs`` on the deterministic brick pattern."""

    def test_four_tiles(self, brick_labels):
        """500×500 with tile_size=250 produces at most 4 tiles."""
        labels, _, _ = brick_labels
        specs = _specs_from_labels(labels, tile_size=_TILE_SIZE)
        assert len(specs) <= 4

    def test_full_coverage(self, brick_labels):
        """Every cell is assigned to exactly one tile."""
        labels, _, _ = brick_labels
        specs = _specs_from_labels(labels, tile_size=_TILE_SIZE)
        verify_coverage(_label_ids(labels), specs)

    def test_cell_assigned_to_centroid_tile(self, brick_labels):
        """Each cell's tile matches the tile we predict from its centroid."""
        labels, centroids, _ = brick_labels
        specs = _specs_from_labels(labels, tile_size=_TILE_SIZE)

        # Build actual mapping: cell_id → tile base origin
        actual: dict[int, tuple[int, int]] = {}
        for spec in specs:
            for lid in spec.owned_ids:
                actual[lid] = (spec.base[0], spec.base[1])

        for lid, (cy, cx) in centroids.items():
            expected_row, expected_col = _expected_tile_key(cy, cx, _TILE_SIZE, _IMAGE_SIZE)
            expected_origin = (expected_row * _TILE_SIZE, expected_col * _TILE_SIZE)
            assert actual[lid] == expected_origin, (
                f"Cell {lid} centroid=({cy:.1f},{cx:.1f}): expected tile origin {expected_origin}, got {actual[lid]}"
            )

    def test_no_duplicates(self, brick_labels):
        """No cell ID appears in more than one tile."""
        labels, _, _ = brick_labels
        specs = _specs_from_labels(labels, tile_size=_TILE_SIZE)

        seen: set[int] = set()
        for spec in specs:
            overlap = seen & spec.owned_ids
            assert not overlap, f"Duplicate cell IDs: {overlap}"
            seen |= spec.owned_ids

    def test_boundary_cells_exist(self, brick_labels):
        """With the brick offset, some cells straddle the y=250 or x=250 boundary.

        The test first proves that such cells exist (otherwise the ownership
        check below would pass vacuously) and then checks that every one of
        them is owned by some tile.
        """
        labels, centroids, _ = brick_labels
        boundary = _TILE_SIZE  # the only interior tile edge of the 500x500 image

        # Centroids are means of pixel indices, so a cell of height h covers
        # rows [cy - (h-1)/2, cy + (h-1)/2]; the extents are loop-invariant.
        half_h = (_CELL_H - 1) / 2.0
        half_w = (_CELL_W - 1) / 2.0

        # A cell straddles the edge when its first pixel lies before it and
        # its last pixel lies on or after it.
        boundary_cells = [
            lid
            for lid, (cy, cx) in centroids.items()
            if (cy - half_h < boundary <= cy + half_h) or (cx - half_w < boundary <= cx + half_w)
        ]
        assert boundary_cells, "Brick pattern produced no cells straddling a tile boundary"

        specs = _specs_from_labels(labels, tile_size=_TILE_SIZE)
        all_owned: set[int] = set()
        for s in specs:
            all_owned |= s.owned_ids
        for lid in boundary_cells:
            assert lid in all_owned, f"Boundary cell {lid} not assigned"

    def test_crop_contains_owned_cells_fully(self, brick_labels):
        """Every owned cell's rectangle fits inside its tile's crop region."""
        labels, centroids, _ = brick_labels
        specs = _specs_from_labels(labels, tile_size=_TILE_SIZE, overlap_margin="auto")

        # Centroid is the mean of [y, y+cell_h-1], so half-extent = (cell_h-1)/2
        half_h = (_CELL_H - 1) / 2.0
        half_w = (_CELL_W - 1) / 2.0

        for spec in specs:
            cy0, cx0, cy1, cx1 = spec.crop
            for lid in spec.owned_ids:
                cent_y, cent_x = centroids[lid]
                # Reconstruct the cell's first/last pixel from its centroid
                cell_y0 = cent_y - half_h
                cell_y1 = cent_y + half_h
                cell_x0 = cent_x - half_w
                cell_x1 = cent_x + half_w
                # NOTE(review): the extract_tile tests slice labels[cy0:cy1, cx0:cx1],
                # i.e. cy1/cx1 look like exclusive ends; `<=` therefore tolerates a
                # one-pixel overshoot. Consider `<` once build_tile_specs is confirmed.
                assert cy0 <= cell_y0 and cell_y1 <= cy1, (
                    f"Cell {lid} y-range [{cell_y0:.0f},{cell_y1:.0f}] not in crop y-range [{cy0},{cy1}]"
                )
                assert cx0 <= cell_x0 and cell_x1 <= cx1, (
                    f"Cell {lid} x-range [{cell_x0:.0f},{cell_x1:.0f}] not in crop x-range [{cx0},{cx1}]"
                )


class TestBuildTileSpecsEdgeCases:
    """Degenerate inputs for ``build_tile_specs``."""

    def test_empty_labels(self):
        """An all-background label image yields no tile specs."""
        labels = np.zeros((500, 500), dtype=np.int32)
        specs = _specs_from_labels(labels, tile_size=250)
        assert specs == []
        verify_coverage(_label_ids(labels), specs)

    def test_single_cell_whole_image(self):
        """One cell that fills most of the image."""
        labels = np.zeros((500, 500), dtype=np.int32)
        labels[10:490, 10:490] = 1
        specs = _specs_from_labels(labels, tile_size=250)
        verify_coverage(_label_ids(labels), specs)
        assert len(specs) == 1  # centroid is at ~(250,250), lands in one tile

    def test_invalid_tile_size(self):
        """``tile_size=0`` is rejected with a ``ValueError``."""
        with pytest.raises(ValueError, match="tile_size must be positive"):
            build_tile_specs((100, 100), {}, tile_size=0)

    def test_tile_size_larger_than_image(self):
        """tile_size > image → single tile."""
        labels, _ = _make_brick_labels(image_size=100, gap=5)
        specs = _specs_from_labels(labels, tile_size=1000)
        verify_coverage(_label_ids(labels), specs)
        assert len(specs) == 1


# ---------------------------------------------------------------------------
# extract_tile
# ---------------------------------------------------------------------------


class TestExtractTile:
    """Checks of the eager ``extract_tile`` helper."""

    def test_non_owned_cells_zeroed(self, brick_labels, brick_image):
        """Only owned cells survive in the extracted tile mask."""
        labels, _, _ = brick_labels
        specs = _specs_from_labels(labels, tile_size=_TILE_SIZE)

        for spec in specs:
            _, tile_lbl = extract_tile(brick_image, labels, spec)
            present = set(np.unique(tile_lbl))
            present.discard(0)
            assert present == spec.owned_ids, f"Tile base={spec.base}: expected {spec.owned_ids}, got {present}"

    def test_owned_cell_pixels_preserved(self, brick_labels, brick_image):
        """Pixel values for owned cells match the original labels."""
        labels, _, _ = brick_labels
        specs = _specs_from_labels(labels, tile_size=_TILE_SIZE)

        for spec in specs:
            cy0, cx0, cy1, cx1 = spec.crop
            _, tile_lbl = extract_tile(brick_image, labels, spec)
            for lid in spec.owned_ids:
                orig_in_crop = labels[cy0:cy1, cx0:cx1] == lid
                tile_matches = tile_lbl == lid
                np.testing.assert_array_equal(orig_in_crop, tile_matches)

    def test_original_labels_not_mutated(self, brick_labels, brick_image):
        """Extracting tiles leaves the input label array unchanged."""
        labels, _, _ = brick_labels
        labels_copy = labels.copy()
        specs = _specs_from_labels(labels, tile_size=_TILE_SIZE)
        for spec in specs:
            extract_tile(brick_image, labels, spec)
        np.testing.assert_array_equal(labels, labels_copy)

    def test_image_crop_shape(self, brick_labels, brick_image):
        """Extracted image has shape (C, crop_h, crop_w)."""
        labels, _, _ = brick_labels
        specs = _specs_from_labels(labels, tile_size=_TILE_SIZE)
        for spec in specs:
            tile_img, tile_lbl = extract_tile(brick_image, labels, spec)
            cy0, cx0, cy1, cx1 = spec.crop
            assert tile_img.shape == (3, cy1 - cy0, cx1 - cx0)
            assert tile_lbl.shape == (cy1 - cy0, cx1 - cx0)


# ---------------------------------------------------------------------------
# End-to-end roundtrip
# ---------------------------------------------------------------------------


class TestEndToEnd:
    """Spec building plus tile extraction, checked together."""

    def test_roundtrip_no_cells_lost(self, brick_labels, brick_image):
        """Build specs → extract tiles → union of labels == all cells."""
        labels, centroids, _ = brick_labels
        specs = _specs_from_labels(labels, tile_size=_TILE_SIZE)
        verify_coverage(_label_ids(labels), specs)

        recovered: set[int] = set()
        for spec in specs:
            _, tile_lbl = extract_tile(brick_image, labels, spec)
            tile_ids = set(np.unique(tile_lbl))
            tile_ids.discard(0)
            assert tile_ids == spec.owned_ids
            recovered |= tile_ids

        assert recovered == set(centroids.keys())

    def test_touching_cells_no_merge(self):
        """With gap=0, adjacent cells still get distinct labels and assignments."""
        labels, centroids = _make_brick_labels(gap=0)
        n_cells = len(centroids)
        assert n_cells > 0

        specs = _specs_from_labels(labels, tile_size=_TILE_SIZE)
        verify_coverage(_label_ids(labels), specs)

        # Total owned cells across all tiles == total cells
        total_owned = sum(len(s.owned_ids) for s in specs)
        assert total_owned == n_cells

    def test_nontouching_cells_same_result(self):
        """With gap=10, same coverage guarantees hold."""
        labels, centroids = _make_brick_labels(gap=10)
        n_cells = len(centroids)
        assert n_cells > 0

        specs = _specs_from_labels(labels, tile_size=_TILE_SIZE)
        verify_coverage(_label_ids(labels), specs)

        total_owned = sum(len(s.owned_ids) for s in specs)
        assert total_owned == n_cells


# ---------------------------------------------------------------------------
# Visual test — tile assignment plot
# ---------------------------------------------------------------------------

# Tile colors: one distinct color per tile quadrant
_TILE_COLORS = [
    (0.12, 0.47, 0.71),  # blue — top-left
    (1.00, 0.50, 0.05),  # orange — top-right
    (0.17, 0.63, 0.17),  # green — bottom-left
    (0.84, 0.15, 0.16),  # red — bottom-right
]


def _plot_tile_assignment(labels, specs, title=""):
    """Render each cell colored by its owning tile, with grid lines.

    Parameters
    ----------
    labels
        2-D integer label array; background pixels stay white.
    specs
        Tile specs; ``owned_ids`` selects the cells to color and ``base``
        gives the ``(y0, x0, y1, x1)`` rectangle drawn as a dashed outline.
    title
        Axes title; falls back to ``"Tile assignment"`` when empty.

    The figure is created with ``plt.subplots`` and left as the current
    figure; nothing is returned.
    """
    rgb = np.ones((*labels.shape, 3), dtype=np.float32)  # white background

    for i, spec in enumerate(specs):
        color = _TILE_COLORS[i % len(_TILE_COLORS)]
        # One mask per tile instead of one full-image scan per cell.
        rgb[np.isin(labels, list(spec.owned_ids))] = color

    _, ax = plt.subplots(1, 1, figsize=(6, 6))
    ax.imshow(rgb, origin="upper")

    # Draw tile base-grid lines
    for spec in specs:
        by0, bx0, by1, bx1 = spec.base
        rect = plt.Rectangle(
            (bx0 - 0.5, by0 - 0.5),
            bx1 - bx0,
            by1 - by0,
            linewidth=1.5,
            edgecolor="black",
            facecolor="none",
            linestyle="--",
        )
        ax.add_patch(rect)

    ax.set_xlim(-0.5, labels.shape[1] - 0.5)
    ax.set_ylim(labels.shape[0] - 0.5, -0.5)
    ax.set_title(title or "Tile assignment")
    ax.set_xlabel("x")
    ax.set_ylabel("y")


# ---------------------------------------------------------------------------
# Lazy / multiscale helpers
# ---------------------------------------------------------------------------


def _make_multiscale_tree(labels: np.ndarray, n_scales: int = 3) -> xr.DataTree:
    """Return a small multiscale ``DataTree`` made by strided subsampling.

    Level ``scale{i}`` stores ``labels[::2**i, ::2**i]`` as the ``"image"``
    variable with dims ``("y", "x")``.
    """
    levels: dict[str, xr.DataTree] = {
        f"scale{level}": xr.DataTree(
            xr.Dataset({"image": xr.DataArray(labels[:: 2**level, :: 2**level], dims=("y", "x"))})
        )
        for level in range(n_scales)
    }
    return xr.DataTree.from_dict(levels)


def _assert_cell_info_matches(actual, expected, *, abs_tol, compare_bbox=False):
    """Assert two cell-info mappings share keys and agree on centroids.

    Centroids are compared with ``pytest.approx(..., abs=abs_tol)``; bounding
    box height and width are compared exactly when ``compare_bbox`` is set.
    """
    assert set(actual.keys()) == set(expected.keys())
    for lid in expected:
        got, want = actual[lid], expected[lid]
        assert got.centroid_y == pytest.approx(want.centroid_y, abs=abs_tol)
        assert got.centroid_x == pytest.approx(want.centroid_x, abs=abs_tol)
        if compare_bbox:
            assert got.bbox_h == want.bbox_h
            assert got.bbox_w == want.bbox_w


class TestComputeCellInfoMultiscale:
    """``compute_cell_info_multiscale`` against the eager ``compute_cell_info``."""

    def test_target_is_coarsest_matches_eager(self):
        """Targeting the coarsest level agrees with eager info on that level."""
        brick, _ = _make_brick_labels(gap=10)
        pyramid = _make_multiscale_tree(brick, n_scales=3)
        # scale2 is coarsest. Target it -> use that scale directly.
        from_tree = compute_cell_info_multiscale(pyramid, target_scale="scale2")
        reference = compute_cell_info(pyramid["scale2"].ds["image"].values)
        _assert_cell_info_matches(from_tree, reference, abs_tol=0.5)

    def test_rescale_to_finer(self):
        """Targeting the finest level stays within 4 px of the eager centroids."""
        brick, _ = _make_brick_labels(gap=10)
        pyramid = _make_multiscale_tree(brick, n_scales=3)
        from_tree = compute_cell_info_multiscale(pyramid, target_scale="scale0")
        reference = compute_cell_info(brick)
        # Centroids should be close (within ~1 px due to coarse-scale quantization)
        _assert_cell_info_matches(from_tree, reference, abs_tol=4.0)

    def test_missing_target_raises(self):
        """An unknown ``target_scale`` raises ``ValueError``."""
        brick, _ = _make_brick_labels(gap=10)
        pyramid = _make_multiscale_tree(brick, n_scales=2)
        with pytest.raises(ValueError, match="not found in DataTree"):
            compute_cell_info_multiscale(pyramid, target_scale="scale99")


class TestComputeCellInfoTiled:
    """``compute_cell_info_tiled`` against the eager ``compute_cell_info``."""

    def test_matches_eager_no_cell_spans_tiles(self):
        """Brick cells (20x30) with ``chunk_size=128`` reproduce the eager result."""
        brick, _ = _make_brick_labels(gap=10)  # cells are 20x30, well below chunk
        lazy_labels = xr.DataArray(brick, dims=("y", "x"))
        chunked = compute_cell_info_tiled(lazy_labels, chunk_size=128)
        reference = compute_cell_info(brick)
        _assert_cell_info_matches(chunked, reference, abs_tol=1e-6, compare_bbox=True)

    def test_matches_eager_cells_span_tile_boundary(self):
        """A single 100x100 cell processed with ``chunk_size=50`` matches eager."""
        # A 100x100 cell crossing chunk boundary at 50.
        canvas = np.zeros((200, 200), dtype=np.int32)
        canvas[30:130, 30:130] = 1
        lazy_labels = xr.DataArray(canvas, dims=("y", "x"))
        chunked = compute_cell_info_tiled(lazy_labels, chunk_size=50)
        reference = compute_cell_info(canvas)
        _assert_cell_info_matches(chunked, reference, abs_tol=1e-6, compare_bbox=True)

    def test_empty_labels(self):
        """An all-zero label image yields an empty mapping."""
        blank = xr.DataArray(np.zeros((100, 100), dtype=np.int32), dims=("y", "x"))
        assert compute_cell_info_tiled(blank, chunk_size=32) == {}


class TestExtractTileLazy:
    """``extract_tile_lazy`` against the eager ``extract_tile``."""

    def test_matches_eager(self, brick_labels, brick_image):
        """Both extractors return equal image and label tiles for every spec."""
        labels, _, _ = brick_labels
        specs = _specs_from_labels(labels, tile_size=_TILE_SIZE)
        lazy_labels = xr.DataArray(labels, dims=("y", "x"))
        lazy_image = xr.DataArray(brick_image, dims=("c", "y", "x"))
        for spec in specs:
            eager_img, eager_lbl = extract_tile(brick_image, labels, spec)
            lazy_img, lazy_lbl = extract_tile_lazy(lazy_image, lazy_labels, spec)
            np.testing.assert_array_equal(eager_img, lazy_img)
            np.testing.assert_array_equal(eager_lbl, lazy_lbl)


def _make_ci(label: int, cy: float, cx: float, h: int = 4, w: int = 4):
    """Build a ``CellInfo`` from a label, a centroid and a bounding-box size."""
    from squidpy.experimental.im._tiling import CellInfo

    return CellInfo(label=label, centroid_y=cy, centroid_x=cx, bbox_h=h, bbox_w=w)


class TestVerifyCoverage:
    """``verify_coverage`` must reject inconsistent cell-to-tile assignments."""

    def test_detects_duplicate(self):
        """The same cell owned by two spec lists is reported."""
        first = build_tile_specs((100, 100), {1: _make_ci(1, 25, 25)}, tile_size=50)
        second = build_tile_specs((100, 100), {1: _make_ci(1, 25, 25)}, tile_size=50)
        doubled = first + second
        with pytest.raises(AssertionError, match="multiple tiles"):
            verify_coverage({1}, doubled)

    def test_detects_missing(self):
        """A label ID that no tile owns is reported."""
        no_cells = build_tile_specs((100, 100), {}, tile_size=50)
        with pytest.raises(AssertionError, match="not assigned"):
            verify_coverage({42}, no_cells)

    def test_detects_extra(self):
        """A tile owning an ID absent from the expected set is reported."""
        one_cell = build_tile_specs((100, 100), {1: _make_ci(1, 25, 25)}, tile_size=50, overlap_margin=0)
        with pytest.raises(AssertionError, match="non-existent"):
            verify_coverage(set(), one_cell)


class TestTilingVisual(PlotTester, metaclass=PlotTesterMeta):
    """Plot-based tests of the tile assignment."""

    def test_plot_tile_assignment_gap(self):
        """Visual: brick pattern (gap=10), cells colored by tile."""
        brick, _ = _make_brick_labels(gap=10)
        tile_specs = _specs_from_labels(brick, tile_size=_TILE_SIZE)
        _plot_tile_assignment(brick, tile_specs, title="Tile assignment (gap=10)")

    def test_plot_tile_assignment_touching(self):
        """Visual: brick pattern (gap=0, touching), cells colored by tile."""
        brick, _ = _make_brick_labels(gap=0)
        tile_specs = _specs_from_labels(brick, tile_size=_TILE_SIZE)
        _plot_tile_assignment(brick, tile_specs, title="Tile assignment (gap=0, touching)")