Skip to content
Draft
Show file tree
Hide file tree
Changes from 93 commits
Commits
Show all changes
102 commits
Select commit Hold shift + click to select a range
b5ee428
fix zarr checking
measty Jan 16, 2025
1ac61c5
use cross-version syntax for zarr.group
measty Jan 16, 2025
f12b048
more zarr v3 changes
measty Jan 16, 2025
a0e1650
Merge branch 'develop' into fix-zarr-check
shaneahmed Jan 24, 2025
9d19a5e
:pushpin: Remove Pin from Zarr
shaneahmed Jan 31, 2025
2212eef
Merge branch 'develop' into fix-zarr-check
shaneahmed Feb 7, 2025
eb4cecd
Merge branch 'develop' into fix-zarr-check
shaneahmed Feb 21, 2025
4b65046
Merge branch 'develop' into fix-zarr-check
shaneahmed Mar 3, 2025
67e5088
Merge branch 'develop' into fix-zarr-check
shaneahmed Mar 19, 2025
cba6f24
Merge branch 'develop' into fix-zarr-check
shaneahmed Apr 11, 2025
5557be5
Merge branch 'develop' into fix-zarr-check
shaneahmed Apr 25, 2025
5c9144f
Merge branch 'develop' into fix-zarr-check
shaneahmed May 23, 2025
1b4da41
Merge branch 'develop' into fix-zarr-check
shaneahmed Jun 9, 2025
1eee555
Merge branch 'develop' into fix-zarr-check
shaneahmed Jun 13, 2025
cc36cc7
Merge branch 'develop' into fix-zarr-check
shaneahmed Jun 20, 2025
c93c3da
Merge branch 'develop' into fix-zarr-check
shaneahmed Jul 11, 2025
a2dd86f
Merge branch 'develop' into fix-zarr-check
shaneahmed Aug 15, 2025
8ea2ff6
Merge branch 'develop' into fix-zarr-check
shaneahmed Sep 5, 2025
c2b79b9
Merge branch 'develop' into fix-zarr-check
shaneahmed Oct 3, 2025
beb3d24
Merge branch 'develop' into fix-zarr-check
shaneahmed Oct 10, 2025
464f3ec
Merge branch 'develop' into fix-zarr-check
shaneahmed Oct 16, 2025
32df9b8
Merge branch 'develop' into fix-zarr-check
shaneahmed Mar 5, 2026
4c226a0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2026
e6484ac
Merge branch 'develop' into fix-zarr-check
shaneahmed Mar 12, 2026
ec53cc9
:pushpin: Pin `zarr` and `tifffile`
shaneahmed Mar 12, 2026
b25838b
:pushpin: Pin `tifffile`
shaneahmed Mar 12, 2026
9ca7b4d
:pushpin: Update Python Versions
shaneahmed Mar 12, 2026
0bb7d6a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 12, 2026
06c041f
:bug: Fix setup.py to include 3.14
shaneahmed Mar 12, 2026
c1e5c97
Merge remote-tracking branch 'origin/dev-remove-python3.10' into dev-…
shaneahmed Mar 12, 2026
6d2c84a
:fire: Remove openslide from requirements.conda.yml
shaneahmed Mar 12, 2026
ff72d53
:fire: Remove Python 3.14
shaneahmed Mar 12, 2026
bdafc46
Merge branch 'dev-remove-python3.10' into fix-zarr-check
shaneahmed Mar 12, 2026
01492d2
Merge branch 'develop' into fix-zarr-check
shaneahmed Mar 16, 2026
11937c8
:bug: Fix Semantic Segmentor
shaneahmed Mar 16, 2026
e137acf
:bug: Fix misc.py
shaneahmed Mar 17, 2026
fe9e16c
:bug: Fix `wsireader.py`
shaneahmed Mar 17, 2026
d12bb1c
:white_check_mark: Force skip test
shaneahmed Mar 18, 2026
2522a54
:bug: Fix NGFF Reader
shaneahmed Mar 19, 2026
9e162ab
:bug: Fix NGFF Reader
shaneahmed Mar 19, 2026
4beb04a
:white_check_mark: Use FsspecStore for remote files.
shaneahmed Mar 19, 2026
750ca02
Merge remote-tracking branch 'origin/fix-zarr-check' into fix-zarr-check
shaneahmed Mar 19, 2026
efd7765
:recycle: Update how `register_codec` is used
shaneahmed Mar 19, 2026
bd7299a
:white_check_mark: Add support to read s3 using NGFFWSIReader
shaneahmed Mar 19, 2026
62daa88
:pencil2: Fix typos
shaneahmed Mar 19, 2026
33d9c8f
:bug: Fix `nucleus_detector.py`
shaneahmed Mar 23, 2026
3fc7f73
:bug: Zarr uses str indexing instead of int
shaneahmed Mar 23, 2026
a1b86ff
:white_check_mark: Test with np full
shaneahmed Mar 25, 2026
0edae64
:white_check_mark: Fix mtsegmentor patches and tiles_no_metadata
shaneahmed Mar 25, 2026
4c8633c
:white_check_mark: Fix mtsegmentor patches
shaneahmed Mar 25, 2026
2a67908
:white_check_mark: Fix test_single_task_mtsegmentor
shaneahmed Mar 25, 2026
e904129
:white_check_mark: Fix test_wsi_mtsegmentor_correct_nonsquare_shape a…
shaneahmed Mar 25, 2026
a99d744
:white_check_mark: Fix test_wsi_segmentor_annotationstore
shaneahmed Mar 25, 2026
0cfe3a1
Merge branch 'develop' into fix-zarr-check
shaneahmed Mar 26, 2026
2dcc471
:white_check_mark: Fix test_micronet_output
shaneahmed Mar 26, 2026
9ac1ad2
Merge branch 'develop' into fix-zarr-check
shaneahmed Mar 30, 2026
af22da9
Merge branch 'develop' into fix-zarr-check
shaneahmed Mar 31, 2026
306a1f7
:arrow_up: `FsspecJsonWSIReader` Zarr 3 Fix (#1049)
aacic Apr 2, 2026
ec74cb2
Merge branch 'develop' into fix-zarr-check
shaneahmed Apr 7, 2026
65ade70
Merge branch 'develop' into fix-zarr-check
shaneahmed Apr 9, 2026
1f387e4
fix mypy errors
Jiaqi-Lv Apr 15, 2026
aefa5d8
:bug: Fix missing mask for contours
shaneahmed Apr 16, 2026
6f9e6b7
:hammer: Use `skip` for follow imports
shaneahmed Apr 16, 2026
07b4746
fix mypy type errors
Jiaqi-Lv Apr 16, 2026
d7d95d2
:hammer: Mark s3 test for NGFF as expected to fail
shaneahmed Apr 17, 2026
7a70666
:bug: Fix deepsource error
shaneahmed Apr 17, 2026
b759b0c
:bug: Fix deepsource error cyclomatic complexity too high.
shaneahmed Apr 17, 2026
58762df
:bug: Fix instance test with zarr.Array
shaneahmed Apr 17, 2026
ad6cf1f
:white_check_mark: Add tests for coverage
shaneahmed Apr 21, 2026
2d3f874
:white_check_mark: Add tests for wsireader coverage
shaneahmed Apr 22, 2026
8bc83c8
:white_check_mark: Add tests for multi_task_segmentor coverage
shaneahmed Apr 22, 2026
4e8dc5c
:fire: Remove dtype object check
shaneahmed Apr 22, 2026
845749a
:bug: Fix "store" attribute error with dictionary
shaneahmed Apr 22, 2026
5b120b4
:bulb: Address Co-Pilot comments
shaneahmed Apr 22, 2026
7f8dcd2
:bulb: Address Co-Pilot comments
shaneahmed Apr 22, 2026
be395e2
:white_check_mark: Add tests to improve coverage
shaneahmed Apr 22, 2026
38efdd9
:white_check_mark: Add tests to improve coverage
shaneahmed Apr 24, 2026
ccb0a14
:technologist: Address quality check issues
shaneahmed Apr 24, 2026
0d341d7
Merge branch 'develop' into fix-zarr-check
shaneahmed Apr 29, 2026
4534250
:bug: Fix pip install workflow
shaneahmed Apr 29, 2026
c0a72af
Merge branch 'develop' into fix-zarr-check
shaneahmed May 6, 2026
2abff8b
:bug: Replace `create_dataset` with `create_array`
shaneahmed May 6, 2026
9dd0a71
:bug: Fix test_clear_zarr
shaneahmed May 6, 2026
8e4dca6
:bug: Fix chunksize 0 unsupported by zarr v3.2.0+
shaneahmed May 6, 2026
323eca1
add cerberus initial attempt
measty May 7, 2026
f190f83
Merge branch 'fix-zarr-check' of https://github.com/TissueImageAnalyt…
measty May 7, 2026
46119f5
remove alternative architectures
measty May 8, 2026
6d16f3a
add test
measty May 8, 2026
d981346
Merge branch 'develop' into add-cerberus
shaneahmed May 8, 2026
25d8171
restructure code
measty May 8, 2026
5c5e304
add tests
measty May 8, 2026
a6bed60
Merge branch 'add-cerberus' of https://github.com/TissueImageAnalytic…
measty May 8, 2026
6a6cc40
halo postproc
measty May 22, 2026
27e84dc
fix broken margin behaviour
measty May 22, 2026
159959f
Merge branch 'develop' of https://github.com/TissueImageAnalytics/tia…
measty May 27, 2026
16ba015
add docstrings
measty May 27, 2026
280d65b
deepsource fixes
measty May 27, 2026
0c4ca2f
mypy fixes
measty May 27, 2026
1eed5af
fix test
measty May 27, 2026
a572ac5
add tests
measty May 27, 2026
6d5c595
postproc test
measty May 27, 2026
67381e5
Merge branch 'develop' into add-cerberus
shaneahmed Jun 26, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions tests/engines/test_multi_task_segmentor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@
DaskDelayedJSONStore,
MultiTaskSegmentor,
_clear_zarr,
_crop_halo_post_process_output,
_get_postproc_tile_read_bounds,
_get_sel_indices_margin_lines,
_normalise_postproc_halo,
_post_save_json_store,
_process_instance_predictions,
_save_multitask_vertical_to_cache,
Expand Down Expand Up @@ -1079,6 +1082,76 @@ def test_get_tile_info_small_image_triggers_early_return(
assert np.all(flag == 0)


def test_postproc_halo_bounds_and_output_crop() -> None:
    """Check halo normalisation, read-bound expansion and core-space cropping.

    A ``(3, 2)`` halo must normalise to ``[2, 3]``, the expanded read window
    must come back as ``(2, 2, 12, 13)``, and cropping must keep only the
    first two instances with their geometry shifted into tile coordinates.
    """
    core_bounds = (4, 5, 10, 11)

    halo_xy = _normalise_postproc_halo((3, 2))
    assert np.array_equal(halo_xy, np.array([2, 3]))

    read_bounds = _get_postproc_tile_read_bounds(
        tile_bounds=core_bounds,
        postproc_halo_xy=halo_xy,
        image_shape=(12, 13),
    )
    assert read_bounds == (2, 2, 12, 13)

    # 11 x 10 map of unique values so that any mis-aligned crop is detectable.
    predictions = np.arange(110).reshape(11, 10)

    boxes = np.array(
        [
            [2, 3, 4, 5],
            [5, 6, 7, 8],
            [9, 6, 11, 8],
        ],
        dtype=np.int32,
    )
    centroids = np.array(
        [
            [3, 4],
            [6, 7],
            [10, 7],
        ],
        dtype=np.float32,
    )
    contours = np.array(
        [
            [[2, 3], [4, 3], [4, 5], [2, 5]],
            [[5, 6], [7, 6], [7, 8], [5, 8]],
            [[9, 6], [11, 6], [11, 8], [9, 8]],
        ],
        dtype=np.int32,
    )
    instance_types = np.array([1, 2, 3], dtype=np.int32)

    task_output = {
        "task_type": "gland",
        "seg_type": "instance",
        "predictions": predictions,
        "info_dict": {
            "box": boxes,
            "centroid": centroids,
            "contours": contours,
            "type": instance_types,
        },
    }

    cropped_outputs = _crop_halo_post_process_output(
        post_process_output=(task_output,),
        tile_bounds=core_bounds,
        tile_read_bounds=read_bounds,
    )
    cropped = cropped_outputs[0]
    cropped_info = cropped["info_dict"]

    assert np.array_equal(cropped["predictions"], predictions[3:9, 2:8])
    # The third instance is expected to be dropped by the crop.
    assert np.array_equal(cropped_info["type"], np.array([1, 2]))
    assert np.array_equal(
        cropped_info["box"],
        np.array([[0, 0, 2, 2], [3, 3, 5, 5]], dtype=np.int32),
    )
    assert np.array_equal(
        cropped_info["centroid"],
        np.array([[1, 1], [4, 4]], dtype=np.float32),
    )
    assert np.array_equal(
        cropped_info["contours"][0],
        np.array([[0, 0], [2, 0], [2, 2], [0, 2]], dtype=np.int32),
    )


class FakeSeg(MultiTaskSegmentor):
"""Minimal subclass that allows us to override internals cleanly."""

Expand Down
295 changes: 295 additions & 0 deletions tests/models/test_arch_cerberus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,295 @@
"""Unit tests for the Cerberus architecture."""

from __future__ import annotations

import dask.array as da
import numpy as np
import pytest
import torch

from tiatoolbox.models import Cerberus
from tiatoolbox.models.architecture import get_pretrained_model
from tiatoolbox.models.architecture.cerberus.model import (
_build_tissue_raw_map,
_crop_center_tensor,
_inst_dict_for_dask_processing,
_pad_contours,
)
from tiatoolbox.models.architecture.cerberus.postproc import (
PostProcInstErodedContourMap,
get_bounding_box,
)
from tiatoolbox.models.engine.io_config import IOInstanceSegmentorConfig

# Spatial shape the tests in this module expect for every Cerberus head output.
PATCH_OUTPUT_SHAPE = (144, 144)
# Spatial shape of the dummy input batch fed to ``infer_batch`` in this module.
INFER_INPUT_SHAPE = (256, 256)


def _module_prefixed_state_dict(model: Cerberus) -> dict[str, torch.Tensor]:
    """Copy ``model``'s state dict with ``module.`` prepended to every key.

    This mimics the key layout of a checkpoint saved from a DataParallel
    wrapper.
    """
    prefixed: dict[str, torch.Tensor] = {}
    for name, tensor in model.state_dict().items():
        prefixed["module." + name] = tensor
    return prefixed


def test_cerberus_load_weights_from_desc_checkpoint(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check weight loading from a ``desc`` checkpoint with ``module.`` keys."""
    reference_model = Cerberus()
    fake_checkpoint = {"desc": _module_prefixed_state_dict(reference_model)}

    def _fake_torch_load(
        *_args: object,
        **_kwargs: object,
    ) -> dict[str, dict[str, torch.Tensor]]:
        """Stand in for ``torch.load`` and return the in-memory checkpoint."""
        return fake_checkpoint

    # No file is read: ``torch.load`` is replaced for the duration of the test.
    monkeypatch.setattr(torch, "load", _fake_torch_load)

    loaded_model = Cerberus()
    loaded_model.load_weights_from_file("weights.tar")

    probe_key = "backbone.conv1.weight"
    loaded_weight = loaded_model.state_dict()[probe_key]
    reference_weight = reference_model.state_dict()[probe_key]
    assert torch.equal(loaded_weight, reference_weight)


def test_cerberus_pretrained_registry(monkeypatch: pytest.MonkeyPatch) -> None:
    """Check the ``cerberus-resnet34`` registry entry and its IO configuration."""
    fake_checkpoint = {"desc": _module_prefixed_state_dict(Cerberus())}

    def _fake_torch_load(
        *_args: object,
        **_kwargs: object,
    ) -> dict[str, dict[str, torch.Tensor]]:
        """Stand in for ``torch.load`` and return the in-memory checkpoint."""
        return fake_checkpoint

    # No file is read: ``torch.load`` is replaced for the duration of the test.
    monkeypatch.setattr(torch, "load", _fake_torch_load)

    model, ioconfig = get_pretrained_model(
        "cerberus-resnet34",
        pretrained_weights="weights.tar",
    )

    assert isinstance(model, Cerberus)
    assert isinstance(ioconfig, IOInstanceSegmentorConfig)

    assert tuple(ioconfig.patch_input_shape) == (448, 448)
    assert tuple(ioconfig.patch_output_shape) == PATCH_OUTPUT_SHAPE
    assert tuple(ioconfig.stride_shape) == PATCH_OUTPUT_SHAPE

    # One output resolution entry is expected per model head.
    head_count = len(Cerberus.head_names)
    assert len(ioconfig.output_resolutions) == head_count


def test_cerberus_infer_batch_output_shapes() -> None:
    """Check the number, shape and dtype of ``infer_batch`` outputs."""
    model = Cerberus()
    batch = torch.zeros((1, *INFER_INPUT_SHAPE, 3), dtype=torch.uint8)

    outputs = model.infer_batch(model, batch, device="cpu")

    assert len(outputs) == len(Cerberus.head_names)

    # Expected channel count of each output, in output order.
    channels_per_output = (2, 1, 2, 1, 2, 1)
    expected_shapes = [
        (1, *PATCH_OUTPUT_SHAPE, channels) for channels in channels_per_output
    ]
    for head_output, wanted_shape in zip(outputs, expected_shapes, strict=True):
        assert head_output.shape == wanted_shape
        assert head_output.dtype == np.float32


def test_cerberus_postproc_empty_maps() -> None:
    """Check the post-processing output layout when all raw maps are zero."""
    # Channel count of each raw map, in the order passed to ``postproc``.
    channels_per_map = (2, 1, 2, 1, 2, 1)
    raw_maps = [
        np.zeros((*PATCH_OUTPUT_SHAPE, channels), dtype=np.float32)
        for channels in channels_per_map
    ]

    outputs = Cerberus().postproc(raw_maps, offset=(3, 5))

    task_types = [output["task_type"] for output in outputs]
    assert task_types == ["nuclei", "gland", "lumen"]

    # Expected (shape, dtype) of every instance field when nothing is found.
    empty_field_specs = {
        "box": ((0, 4), np.int32),
        "centroid": ((0, 2), np.float32),
        "contours": ((0, 0, 2), np.int32),
        "prob": ((0,), np.float32),
        "type": ((0,), np.int32),
    }

    for output in outputs:
        assert output["seg_type"] == "instance"
        assert output["predictions"].shape == PATCH_OUTPUT_SHAPE
        assert output["predictions"].dtype == np.int32

        info_dict = output["info_dict"]
        for field, (wanted_shape, wanted_dtype) in empty_field_specs.items():
            assert info_dict[field].shape == wanted_shape
            assert info_dict[field].dtype == wanted_dtype


def test_cerberus_postproc_dask_maps_and_lumen_gland_mask(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check Dask-backed post-processing and masking of lumen by gland.

    Both the contour post-processor and the instance-info extractor are
    replaced with deterministic fakes so that only the Cerberus orchestration
    is exercised.
    """
    output_shape = (16, 16)
    channels_per_map = (2, 1, 2, 1, 2, 1)
    raw_maps = [
        da.from_array(
            np.zeros((*output_shape, channels), dtype=np.float32),
            chunks=(8, 8, channels),
        )
        for channels in channels_per_map
    ]
    calls = []

    def _fake_post_process(
        raw_map: np.ndarray,
        idx_dict: dict[str, list[int]],
        tissue_mode: str,
        ds_factor: float,
    ) -> tuple[np.ndarray, np.ndarray | None]:
        """Record the call and return a fixed instance map per tissue mode."""
        calls.append((tissue_mode, raw_map.shape, idx_dict, ds_factor))
        inst_map = np.zeros(output_shape, dtype=np.int32)
        if tissue_mode == "Nuclei":
            inst_map[2:5, 2:5] = 1
            return inst_map, np.ones(output_shape, dtype=np.uint8)
        if tissue_mode == "Gland":
            inst_map[1:8, 1:8] = 1
            return inst_map, np.ones(output_shape, dtype=np.uint8)
        # Any other mode: one instance inside the fake gland area and one
        # outside it, with no type map.
        inst_map[3:6, 3:6] = 1
        inst_map[10:13, 10:13] = 2
        return inst_map, None

    def _fake_get_instance_info(
        inst_map: np.ndarray,
        type_map: np.ndarray | None,
        offset: tuple[int, int],
        verbose: object,
    ) -> dict[int, dict]:
        """Validate the forwarded arguments and return one canned instance."""
        assert offset == (7, 11)
        assert verbose is False
        if type_map is None:
            type_value = 0
        else:
            type_value = int(type_map[inst_map > 0][0])
        instance = {
            "box": np.array([1, 2, 3, 4], dtype=np.int32),
            "centroid": np.array([2.5, 3.5], dtype=np.float32),
            "contours": np.array([[1, 2], [3, 4]], dtype=np.int32),
            "prob": 0.75,
            "type": type_value,
        }
        return {1: instance}

    monkeypatch.setattr(
        PostProcInstErodedContourMap,
        "post_process",
        _fake_post_process,
    )
    monkeypatch.setattr(
        "tiatoolbox.models.architecture.cerberus.model.HoVerNet.get_instance_info",
        _fake_get_instance_info,
    )

    outputs = Cerberus().postproc(raw_maps, offset=(7, 11))

    called_modes = [call[0] for call in calls]
    assert called_modes == ["Nuclei", "Gland", "Lumen"]

    first_call = calls[0]
    assert first_call[1:] == (
        (16, 16, 3),
        {"Nuclei-INST": [0, 2], "Nuclei-TYPE": [2, 3]},
        1.0,
    )

    task_types = [output["task_type"] for output in outputs]
    assert task_types == ["nuclei", "gland", "lumen"]

    # The lumen instance inside the gland survives; the one outside is zeroed.
    lumen_map = outputs[2]["predictions"].compute()
    assert np.all(lumen_map[3:6, 3:6] == 1)
    assert np.all(lumen_map[10:13, 10:13] == 0)

    for output in outputs:
        predictions = output["predictions"]
        assert isinstance(predictions, da.Array)
        assert predictions.dtype == np.int32

        info_dict = output["info_dict"]
        assert info_dict["box"].compute().dtype == np.int32
        assert info_dict["centroid"].compute().dtype == np.float32
        assert info_dict["contours"].compute().shape == (1, 2, 2)
        assert info_dict["prob"].compute().dtype == np.float32
        assert info_dict["type"].compute().dtype == np.int32


def test_cerberus_model_helpers() -> None:
    """Exercise the private Cerberus helper functions one by one."""
    # --- _build_tissue_raw_map: INST and TYPE maps combined into one array.
    head_maps = {
        "Nuclei-INST": np.zeros((4, 5, 2), dtype=np.float32),
        "Nuclei-TYPE": np.ones((4, 5), dtype=np.float32),
    }
    tissue_map, idx_dict = _build_tissue_raw_map(head_maps, "Nuclei")
    assert tissue_map.shape == (4, 5, 3)
    assert idx_dict == {"Nuclei-INST": [0, 2], "Nuclei-TYPE": [2, 3]}

    # --- _crop_center_tensor: central (3, 4) window of a (1, 5, 6, 1) tensor.
    tensor = torch.arange(30, dtype=torch.float32).reshape(1, 5, 6, 1)
    cropped = _crop_center_tensor(tensor, (3, 4))
    assert cropped.shape == (1, 3, 4, 1)
    assert torch.equal(cropped, tensor[:, 1:4, 1:5, :])

    # --- _pad_contours: ragged contours padded to a common length.
    long_contour = np.array([[1, 2], [3, 4]], dtype=np.int32)
    short_contour = np.array([[5, 6]], dtype=np.int32)
    contours = np.array([long_contour, short_contour], dtype=object)
    padded = _pad_contours(contours)
    pad_value = np.iinfo(np.int32).min
    assert padded.shape == (2, 2, 2)
    assert np.array_equal(padded[1, 0], [5, 6])
    assert np.array_equal(padded[1, 1], [pad_value] * 2)

    # --- _inst_dict_for_dask_processing: empty input, Dask output.
    dask_info = _inst_dict_for_dask_processing({}, is_dask=True)
    assert dask_info["contours"].compute().shape == (0, 0, 2)
    assert dask_info["type"].compute().dtype == np.int32


def test_cerberus_eroded_contour_postproc_non_empty_and_errors() -> None:
    """Check contour post-processing on non-empty maps and its error paths."""
    post_process = PostProcInstErodedContourMap.post_process

    # Gland: one confident square region, with the type channel set to 2.
    gland_raw_map = np.zeros((80, 80, 3), dtype=np.float32)
    gland_raw_map[10:60, 10:60, 0] = 0.9
    gland_raw_map[..., 2] = 2

    inst_map, type_map = post_process(
        raw_map=gland_raw_map,
        idx_dict={"Gland-INST": [0, 2], "Gland-TYPE": [2, 3]},
        tissue_mode="Gland",
    )

    assert inst_map.shape == (80, 80)
    assert inst_map.max() == 1
    assert type_map is not None
    assert type_map.shape == (80, 80)
    assert np.all(type_map == 2)
    assert get_bounding_box(inst_map > 0) == (6, 65, 6, 65)

    # Lumen: instance channels only, so no type map is expected back.
    lumen_raw_map = np.zeros((40, 40, 2), dtype=np.float32)
    lumen_raw_map[8:25, 8:25, 0] = 0.9
    lumen_inst_map, lumen_type_map = post_process(
        raw_map=lumen_raw_map,
        idx_dict={"Lumen-INST": [0, 2]},
        tissue_mode="Lumen",
    )
    assert lumen_inst_map.max() == 1
    assert lumen_type_map is None

    # An unknown tissue mode must be rejected.
    with pytest.raises(ValueError, match="Unsupported Cerberus tissue mode"):
        post_process(
            raw_map=lumen_raw_map,
            idx_dict={"Lumen-INST": [0, 2]},
            tissue_mode="Stroma",
        )

    # A missing required map index must be rejected.
    with pytest.raises(KeyError, match="Missing required Cerberus map"):
        post_process(
            raw_map=lumen_raw_map,
            idx_dict={},
            tissue_mode="Lumen",
        )
Loading
Loading