Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 214 additions & 0 deletions tests/test_processor_lifecycle.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import unittest
import importlib
import os
import tempfile
from unittest.mock import MagicMock, patch

from trident import Processor
from trident.IO import create_lock

processor_module = importlib.import_module("trident.Processor")

Expand All @@ -22,7 +25,39 @@ def __exit__(self, exc_type, exc, tb):
return False


class _DummyLoop:
def __init__(self, items, **kwargs):
self.items = items

def __iter__(self):
return iter(self.items)

def set_postfix_str(self, _message):
return None


class TestProcessorLifecycle(unittest.TestCase):
def _processor_with_wsi(self, job_dir, wsi):
processor = Processor.__new__(Processor)
processor.job_dir = job_dir
processor.skip_errors = True
processor.max_workers = 1
processor.wsis = [wsi]
processor.save_config = MagicMock()
return processor

def _wsi(self):
wsi = MagicMock()
wsi.name = "slide"
wsi.ext = ".svs"
wsi.slide_path = "/tmp/slide.svs"
wsi.dimensions = (100, 100)
wsi.mpp = 0.5
wsi.mag = 20
wsi.level_count = 1
wsi.release = MagicMock()
return wsi

def test_release_closes_exitstack_and_clears_wsis(self):
exit_calls = {"count": 0}

Expand Down Expand Up @@ -111,6 +146,185 @@ def fake_exists(path):

processor.run_patch_feature_extraction_job.assert_not_called()

def test_segmentation_error_removes_lock_when_skip_errors(self):
with tempfile.TemporaryDirectory() as tmpdir:
wsi = self._wsi()
contour_path = os.path.join(tmpdir, "contours", "slide.jpg")

def write_partial_contour_then_fail(**_kwargs):
with open(contour_path, "w", encoding="utf-8") as f:
f.write("partial")
raise RuntimeError("boom")

wsi.segment_tissue.side_effect = write_partial_contour_then_fail
processor = self._processor_with_wsi(tmpdir, wsi)

with patch.object(processor_module, "tqdm", side_effect=lambda items, **kwargs: _DummyLoop(items, **kwargs)), \
patch.object(processor_module, "update_task_state"), \
patch.object(processor_module, "update_log"):
processor.run_segmentation_job(
segmentation_model=MagicMock(),
seg_mag=10,
batch_size=1,
device="cpu",
)

self.assertFalse(os.path.exists(os.path.join(tmpdir, "contours", "slide.jpg.lock")))
self.assertFalse(os.path.exists(contour_path))

def test_segmentation_post_acquire_existing_output_releases_lock(self):
with tempfile.TemporaryDirectory() as tmpdir:
wsi = self._wsi()
processor = self._processor_with_wsi(tmpdir, wsi)

def create_lock_then_finish(path):
acquired = create_lock(path)
with open(path, "w", encoding="utf-8") as f:
f.write("finished by another worker")
return acquired

with patch.object(processor_module, "tqdm", side_effect=lambda items, **kwargs: _DummyLoop(items, **kwargs)), \
patch.object(processor_module, "create_lock", side_effect=create_lock_then_finish), \
patch.object(processor_module, "update_task_state"), \
patch.object(processor_module, "update_log"):
processor.run_segmentation_job(
segmentation_model=MagicMock(),
seg_mag=10,
batch_size=1,
device="cpu",
)

wsi.segment_tissue.assert_not_called()
self.assertFalse(os.path.exists(os.path.join(tmpdir, "contours", "slide.jpg.lock")))

def test_coords_error_removes_lock_when_skip_errors(self):
with tempfile.TemporaryDirectory() as tmpdir:
tissue_seg_path = os.path.join(tmpdir, "slide.geojson")
with open(tissue_seg_path, "w", encoding="utf-8") as f:
f.write("{}")

wsi = self._wsi()
wsi.tissue_seg_path = tissue_seg_path
coords_path = os.path.join(tmpdir, "20x_256px_0px_overlap", "patches", "slide_patches.h5")

def write_partial_coords_then_fail(**_kwargs):
with open(coords_path, "w", encoding="utf-8") as f:
f.write("partial")
raise RuntimeError("boom")

wsi.extract_tissue_coords.side_effect = write_partial_coords_then_fail
processor = self._processor_with_wsi(tmpdir, wsi)

with patch.object(processor_module, "tqdm", side_effect=lambda items, **kwargs: _DummyLoop(items, **kwargs)), \
patch.object(processor_module.gpd, "read_file", return_value=MagicMock(empty=False)), \
patch.object(processor_module, "update_task_state"), \
patch.object(processor_module, "update_log"):
coords_dir = processor.run_patching_job(
target_magnification=20,
patch_size=256,
overlap=0,
)

self.assertFalse(os.path.exists(os.path.join(coords_dir, "patches", "slide_patches.h5.lock")))
self.assertFalse(os.path.exists(coords_path))

def test_coords_existing_output_with_lock_is_reported_locked(self):
with tempfile.TemporaryDirectory() as tmpdir:
coords_dir = "20x_256px_0px_overlap"
patches_dir = os.path.join(tmpdir, coords_dir, "patches")
os.makedirs(patches_dir, exist_ok=True)
coords_path = os.path.join(patches_dir, "slide_patches.h5")
with open(coords_path, "w", encoding="utf-8") as f:
f.write("partial")
self.assertTrue(create_lock(coords_path))

wsi = self._wsi()
wsi.tissue_seg_path = os.path.join(tmpdir, "slide.geojson")
processor = self._processor_with_wsi(tmpdir, wsi)

with patch.object(processor_module, "tqdm", side_effect=lambda items, **kwargs: _DummyLoop(items, **kwargs)), \
patch.object(processor_module, "update_task_state") as update_state, \
patch.object(processor_module, "update_log"):
processor.run_patching_job(
target_magnification=20,
patch_size=256,
overlap=0,
)

wsi.extract_tissue_coords.assert_not_called()
self.assertEqual(update_state.call_args.kwargs["reason"], "locked")
self.assertTrue(os.path.exists(f"{coords_path}.lock"))

def test_patch_features_error_removes_lock_when_skip_errors(self):
with tempfile.TemporaryDirectory() as tmpdir:
coords_dir = "20x_256px_0px_overlap"
patches_dir = os.path.join(tmpdir, coords_dir, "patches")
os.makedirs(patches_dir, exist_ok=True)
with open(os.path.join(patches_dir, "slide_patches.h5"), "w", encoding="utf-8") as f:
f.write("coords")

wsi = self._wsi()
processor = self._processor_with_wsi(tmpdir, wsi)
patch_encoder = MagicMock()
patch_encoder.enc_name = "uni_v1"
features_path = os.path.join(tmpdir, coords_dir, "features_uni_v1", "slide.h5")

def write_partial_features_then_fail(**_kwargs):
with open(features_path, "w", encoding="utf-8") as f:
f.write("partial")
raise RuntimeError("boom")

wsi.extract_patch_features.side_effect = write_partial_features_then_fail

with patch.object(processor_module, "tqdm", side_effect=lambda items, **kwargs: _DummyLoop(items, **kwargs)), \
patch.object(processor_module, "update_task_state"), \
patch.object(processor_module, "update_log"):
features_dir = processor.run_patch_feature_extraction_job(
coords_dir=coords_dir,
patch_encoder=patch_encoder,
device="cpu",
saveas="h5",
batch_limit=1,
)

self.assertFalse(os.path.exists(os.path.join(features_dir, "slide.h5.lock")))
self.assertFalse(os.path.exists(features_path))

def test_slide_features_error_removes_lock_when_skip_errors(self):
with tempfile.TemporaryDirectory() as tmpdir:
coords_dir = "20x_256px_0px_overlap"
patch_features_dir = os.path.join(tmpdir, coords_dir, "features_mockpatch")
os.makedirs(patch_features_dir, exist_ok=True)
with open(os.path.join(patch_features_dir, "slide.h5"), "w", encoding="utf-8") as f:
f.write("features")

wsi = self._wsi()
processor = self._processor_with_wsi(tmpdir, wsi)
processor.run_patch_feature_extraction_job = MagicMock()
slide_encoder = MagicMock()
slide_encoder.enc_name = "mean-mockpatch"
slide_features_path = os.path.join(tmpdir, coords_dir, "slide_features_mean-mockpatch", "slide.h5")

def write_partial_slide_features_then_fail(**_kwargs):
with open(slide_features_path, "w", encoding="utf-8") as f:
f.write("partial")
raise RuntimeError("boom")

wsi.extract_slide_features.side_effect = write_partial_slide_features_then_fail

with patch.object(processor_module, "tqdm", side_effect=lambda items, **kwargs: _DummyLoop(items, **kwargs)), \
patch.object(processor_module, "update_task_state"), \
patch.object(processor_module, "update_log"):
slide_features_dir = processor.run_slide_feature_extraction_job(
coords_dir=coords_dir,
slide_encoder=slide_encoder,
device="cpu",
saveas="h5",
)

self.assertFalse(os.path.exists(os.path.join(slide_features_dir, "slide.h5.lock")))
self.assertFalse(os.path.exists(slide_features_path))


if __name__ == "__main__":
unittest.main()
44 changes: 43 additions & 1 deletion tests/test_wsi_core_behaviors.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import unittest
import json
from concurrent.futures import ThreadPoolExecutor
from threading import Barrier
from unittest.mock import patch
import tempfile
import os
import numpy as np

from trident.IO import splitext, coords_to_h5, read_coords
from trident.IO import splitext, coords_to_h5, read_coords, create_lock, remove_lock
import trident.wsi_objects.WSIFactory as wsifactory
from trident.wsi_objects.WSI import WSI
from trident.wsi_objects.WSIPatcher import WSIPatcher
Expand Down Expand Up @@ -43,6 +46,45 @@ def test_splitext_handles_ome_zarr(self):
self.assertEqual(ext, ".ome.zarr")


class TestIOLocks(unittest.TestCase):
def test_create_lock_does_not_overwrite_existing_lock(self):
with tempfile.TemporaryDirectory() as tmpdir:
target = os.path.join(tmpdir, "output.h5")

self.assertTrue(create_lock(target))
with open(f"{target}.lock", "r", encoding="utf-8") as f:
first_payload = json.load(f)

self.assertFalse(create_lock(target))
with open(f"{target}.lock", "r", encoding="utf-8") as f:
second_payload = json.load(f)

self.assertEqual(second_payload, first_payload)
remove_lock(target)
self.assertFalse(os.path.exists(f"{target}.lock"))

def test_create_lock_allows_only_one_concurrent_acquirer(self):
with tempfile.TemporaryDirectory() as tmpdir:
target = os.path.join(tmpdir, "output.h5")
workers = 8
barrier = Barrier(workers)

def acquire_once():
barrier.wait()
return create_lock(target)

with ThreadPoolExecutor(max_workers=workers) as executor:
results = list(executor.map(lambda _: acquire_once(), range(workers)))

self.assertEqual(sum(1 for result in results if result), 1)
self.assertTrue(os.path.exists(f"{target}.lock"))
with open(f"{target}.lock", "r", encoding="utf-8") as f:
payload = json.load(f)
self.assertIn("pid", payload)
self.assertIn("hostname", payload)
self.assertIn("created_at", payload)


class TestWSIFactoryRouting(unittest.TestCase):
def test_auto_reader_routes_ome_tif_to_openslide(self):
with patch.object(wsifactory, "OpenSlideWSI", return_value="open_reader") as open_mock, \
Expand Down
38 changes: 29 additions & 9 deletions trident/IO.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def get_weights_path(model_type: str, encoder_name: str) -> str:
return path


def create_lock(path: str, suffix: Optional[str] = None) -> None:
def create_lock(path: str, suffix: Optional[str] = None) -> bool:
"""
Create a lock file to signal that a particular file or process
is currently being worked on. This is especially useful in multiprocessing or distributed
Expand All @@ -238,18 +238,38 @@ def create_lock(path: str, suffix: Optional[str] = None) -> None:
-------
>>> create_lock("/path/to/resource")
>>> # Creates a file named "/path/to/resource.lock" to indicate the resource is locked.

Returns:
bool:
True if the lock was acquired, False if it already exists.
"""
if suffix is not None:
path = f"{path}_{suffix}"
lock_file = f"{path}.lock"
with open(lock_file, 'w') as f:
# Write metadata to allow safe dead-lock cleanup.
payload = {
"pid": os.getpid(),
"hostname": socket.gethostname(),
"created_at": time.time(),
}
f.write(json.dumps(payload))
payload = {
"pid": os.getpid(),
"hostname": socket.gethostname(),
"created_at": time.time(),
}
try:
fd = os.open(lock_file, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
except FileExistsError:
return False

try:
with os.fdopen(fd, 'w') as f:
fd = None
# Write metadata to allow safe dead-lock cleanup.
f.write(json.dumps(payload))
except Exception:
if fd is not None:
os.close(fd)
try:
os.remove(lock_file)
except FileNotFoundError:
pass
raise
return True

#####################

Expand Down
Loading
Loading