diff --git a/README.md b/README.md index 77175654..a679ccb7 100644 --- a/README.md +++ b/README.md @@ -40,6 +40,7 @@ This project was developed by the [Mahmood Lab](https://faisal.ai/) at Harvard M Optional install profiles: - `pip install -e ".[patch-encoders]"` for CONCH/MUSK/CTransPath-related extras. - `pip install -e ".[slide-encoders]"` for PRISM/GigaPath/Madeleine-related extras. +- `pip install -e ".[omezarr]"` for OME Zarr WSI reader support - `pip install -e ".[convert]"` for slide conversion dependencies. - `pip install -e ".[full]"` to install all pip-installable optional dependencies. @@ -232,7 +233,7 @@ main() - **A**: Yes using the `--custom_list_of_wsis` argument. Provide a list of WSI names in a CSV (with slide extension, `wsi`). Optionally, provide the mpp (field `mpp`) - **Q**: Do I need to install any additional packages to use Trident? - - **A**: `pip install -e .` installs core dependencies. Some optional components still require extra installs. Use profiles (`.[patch-encoders]`, `.[slide-encoders]`, `.[convert]`, or `.[full]`) and run `trident-doctor` for preflight checks. + - **A**: `pip install -e .` installs core dependencies. Some optional components still require extra installs. Use profiles (`.[patch-encoders]`, `.[slide-encoders]`, `.[convert]`, `.[omezarr]` or `.[full]`) and run `trident-doctor` for preflight checks. ## License and Terms of Use diff --git a/docs/quickstart.rst b/docs/quickstart.rst index 768423b9..42b44cc7 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -123,7 +123,7 @@ WSI discovery and reading - ``--search_nested``: recursively discover slides in nested subfolders. - ``--custom_list_of_wsis``: CSV subset list to process selected slides only. - ``--custom_mpp_keys``: metadata keys to read MPP from non-standard slide headers. -- ``--reader_type``: force backend reader (``openslide``, ``cucim``, ``image``, ``sdpc``). +- ``--reader_type``: force backend reader (``openslide``, ``cucim``, ``image``, ``sdpc``, ``omezarr``). When to change: diff --git a/pyproject.toml b/pyproject.toml index d064723b..34dd7036 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,10 @@ musk = { git = "https://github.com/lilab-stanford/MUSK", optional = true } gigapath = { git = "https://github.com/prov-gigapath/prov-gigapath.git", optional = true } madeleine = { git = "https://github.com/mahmoodlab/MADELEINE.git", optional = true } pylibCZIrw = { version = "*", optional = true } +cf-units = { version = ">=3.1", optional = true } +ngff-zarr = { version = ">=0.12", optional = true } +zarr = { version = ">=3.0.0a0", optional = true } +dask = { version = ">=2024.10", optional = true } [tool.poetry.extras] patch-encoders = [ @@ -82,6 +86,12 @@ full = [ "madeleine", "pylibCZIrw", ] +omezarr = [ + "ngff-zarr", + "dask", + "cf-units", + "zarr" +] [tool.poetry.dev-dependencies] # Optional development dependencies diff --git a/run_batch_of_slides.py b/run_batch_of_slides.py index a9f70da0..3a098547 100644 --- a/run_batch_of_slides.py +++ b/run_batch_of_slides.py @@ -59,8 +59,8 @@ def build_parser() -> argparse.ArgumentParser: help='Custom keys used to store the resolution as MPP (micron per pixel) in your list of whole-slide image.') parser.add_argument('--custom_list_of_wsis', type=str, default=None, help='Custom list of WSIs specified in a csv file.') - parser.add_argument('--reader_type', type=str, choices=['openslide', 'image', 'cucim', 'sdpc'], default=None, - help='Force the use of a specific WSI image reader. Options are ["openslide", "image", "cucim", "sdpc"]. Defaults to None (auto-determine which reader to use).') + parser.add_argument('--reader_type', type=str, choices=['openslide', 'image', 'cucim', 'sdpc', 'omezarr'], default=None, + help='Force the use of a specific WSI image reader. Options are ["openslide", "image", "cucim", "sdpc", "omezarr"]. Defaults to None (auto-determine which reader to use).') parser.add_argument("--search_nested", action="store_true", help=("If set, recursively search for whole-slide images (WSIs) within all subdirectories of " "`wsi_source`. Uses `os.walk` to include slides from nested folders. " diff --git a/trident/Processor.py b/trident/Processor.py index f1f0e934..f61abc58 100644 --- a/trident/Processor.py +++ b/trident/Processor.py @@ -11,7 +11,7 @@ from trident import load_wsi, WSIReaderType from trident.IO import create_lock, remove_lock, is_locked, update_log, collect_valid_slides, splitext from trident.Maintenance import deprecated -from trident.wsi_objects.WSIFactory import OPENSLIDE_EXTENSIONS, PIL_EXTENSIONS, SDPC_EXTENSIONS +from trident.wsi_objects.WSIFactory import OPENSLIDE_EXTENSIONS, PIL_EXTENSIONS, SDPC_EXTENSIONS, OMEZARR_EXTENSIONS class Processor: @@ -74,7 +74,7 @@ def __init__( Maximum number of workers for data loading. If None, the default behavior will be used. Defaults to None. reader_type (WSIReaderType, optional): - Force the image reader engine to use. Options are are ["openslide", "image", "cucim"]. Defaults to None + Force the image reader engine to use. Options are are ["openslide", "image", "cucim", "sdpc", "omezarr"]. Defaults to None (auto-determine the right engine based on image extension). search_nested (bool, optional): If True, the processor will recursively search for WSIs within all subdirectories of `wsi_source`. @@ -108,7 +108,7 @@ def __init__( self.job_dir = job_dir self.wsi_source = wsi_source - self.wsi_ext = wsi_ext or (list(PIL_EXTENSIONS) + list(OPENSLIDE_EXTENSIONS) + list(SDPC_EXTENSIONS)) + self.wsi_ext = wsi_ext or (list(PIL_EXTENSIONS) + list(OPENSLIDE_EXTENSIONS) + list(SDPC_EXTENSIONS) + list(OMEZARR_EXTENSIONS)) self.skip_errors = skip_errors self.custom_mpp_keys = custom_mpp_keys self.max_workers = max_workers diff --git a/trident/__init__.py b/trident/__init__.py index 6cc936fd..6938995a 100644 --- a/trident/__init__.py +++ b/trident/__init__.py @@ -9,6 +9,7 @@ from trident.wsi_objects.CuCIMWSI import CuCIMWSI from trident.wsi_objects.ImageWSI import ImageWSI from trident.wsi_objects.SDPCWSI import SDPCWSI +from trident.wsi_objects.OMEZarrWSI import OMEZarrWSI from trident.wsi_objects.WSIFactory import load_wsi, WSIReaderType from trident.wsi_objects.WSIPatcher import OpenSlideWSIPatcher, WSIPatcher from trident.wsi_objects.WSIPatcherDataset import WSIPatcherDataset @@ -28,6 +29,7 @@ "ImageWSI", "CuCIMWSI", "SDPCWSI", + "OMEZarrWSI", "WSIPatcher", "OpenSlideWSIPatcher", "WSIPatcherDataset", diff --git a/trident/wsi_objects/OMEZarrWSI.py b/trident/wsi_objects/OMEZarrWSI.py new file mode 100644 index 00000000..251376d8 --- /dev/null +++ b/trident/wsi_objects/OMEZarrWSI.py @@ -0,0 +1,346 @@ +from typing import Tuple, Union, Any +from trident.wsi_objects.WSI import WSI, ReadMode +from PIL import Image +import numpy as np + +try: + from zarr import open as zarr_open + from dask.config import set as dask_config_set + from ngff_zarr import from_ngff_zarr + from cf_units import Unit as cf_Unit + + _HAS_OME_ZARR = True + _EXCEPT_MESSAGE = None +except ImportError as e: # ModuleNotFoundError is likely + _HAS_OME_ZARR = False + _EXCEPT_MESSAGE = e + + +class OMEZarrWSI(WSI): + """ + WSI implementation for reading zarrfiles following the OME specification. + """ + + def __init__(self, slide_path: str, **kwargs: Any) -> None: + """ + Initialize a OMEZarr instance for OME-Zarr whole-slide images. + + Parameters + ---------- + slide_path : str + Path to an .zarr OME multiscale file. + **kwargs : dict + Additional keyword arguments forwarded to the base `WSI` class. + - lazy_init (bool, default=True): Whether to defer loading WSI and metadata. + + Example + ------- + >>> wsi = OMEZarrWSI(slide_path="path/to/wsi", lazy_init=False) + >>> print(wsi) + + """ + super().__init__(slide_path, **kwargs) + + def _lazy_initialize(self) -> None: + """ + Lazily initialize the WSI using ngff-zarr. + + This method opens a whole-slide image using the ngff-zarr backend, extracting + key metadata including dimensions, magnification, and multiresolution pyramid + information. If a tissue segmentation mask is provided, it is also loaded. + + Raises + ------ + FileNotFoundError + If the WSI file or the tissue segmentation mask cannot be found. + RuntimeError + If an unexpected error occurs during WSI initialization. Including if there + are not 3 dimensions in an image, as read_region depends on this property. + + Notes + ----- + After initialization, the following attributes are set: + - `width` and `height`: spatial dimensions of the base level. + - `dimensions`: (width, height) tuple from the highest resolution. + - `level_count`: number of resolution levels in the image pyramid. + - `level_downsamples`: downsampling factors for each level. + - `level_dimensions`: image dimensions at each level. + - `properties`: metadata object from ngff-zarr. + - `mpp`: microns per pixel, inferred if not manually specified. + - `mag`: estimated magnification level (via WSI.py). + - `gdf_contours`: loaded from `tissue_seg_path` if provided (via WSI.py). + """ + + super()._lazy_initialize() + + if not self._initialized: + + if not _HAS_OME_ZARR: + raise ImportError( + "ngff-zarr, zarr, dask, and cf_units are required for omezarr support. " + "Install them with pip, or pip install .[omezarr] when installing TRIDENT. " + f"When trying to import, got message {_EXCEPT_MESSAGE}" + ) + + try: + self.img = from_ngff_zarr( + self.slide_path + ) # Multiscales dataclass from ngff-zarr + + idx_tuple, dimname_tuple = self._fetch_dimension_metadata() + self._idx_x, self._idx_y, self._idx_c = idx_tuple + self._xname, self._yname, self._cname = dimname_tuple + + self._transpose_order = (self._idx_y, self._idx_x, self._idx_c) + + # x -> width, y -> height + _get_W_and_H = lambda ngffimg: ( + ngffimg.data.shape[self._idx_x], + ngffimg.data.shape[self._idx_y], + ) + self.dimensions = _get_W_and_H( + self.img.images[0] + ) # use the top level image (largest resolution) + + self.width, self.height = self.dimensions + self.level_count = len(self.img.images) + self.level_dimensions = tuple(map(_get_W_and_H, self.img.images)) + self.level_downsamples = self._fetch_downsamples() + if self.mpp is None: + self.mpp = self._fetch_mpp() + self.mag = self._fetch_magnification() + try: + self.properties = dict( + zarr_open(self.slide_path).attrs + ) # get the whole zarr.json object + except: + self.properties = None + + self._initialized = True + + except Exception as e: + raise RuntimeError( + f"Failed to initialize WSI with ngff-zarr: {e}" + ) from e + + def _fetch_mpp(self): + """ + Retrieve microns per pixel (MPP) from OME Zarr metadata. The OME spec + has a designated axes unit property in UDUNITS-2, so custom_mpp_keys not requried. + + Returns + ------- + np.float64 + MPP value in microns per pixel. + """ + try: + scale, scale_unit = ( + self.img.images[0].scale[self._xname], + self.img.images[0].axes_units[self._xname], + ) + return cf_Unit(scale_unit).convert( + scale, cf_Unit("micrometers") + ) # mpp for the x axis at the highest res image + except: + raise ValueError( + f"Unable to extract MPP from slide metadata: '{self.slide_path}'.\n" + "Suggestions:\n" + "- Set the unit in the x/width axes metadata of the OME-Zarr Multiscales " + "(likely having to update the corresponding scale property).\n" + "- Set the MPP explicitly via the class constructor.\n" + "- If using the `run_batch_of_slides.py` script, pass the MPP via the " + "`--custom_list_of_wsis` argument in a CSV file. Refer to TRIDENT/README/Q&A." + ) + + + def _fetch_downsamples(self): + """ + Calculate the downsampling factors for each resolution level. + + Computes the ratio of the highest resolution level's x-axis dimension to + each subsequent level's x-axis dimension. The base level defaults to 1.0. + + Returns + ------- + Tuple[float] + Downsample factors for each level in the image pyramid. + """ + return tuple( + [1.0] + + [ + ( + self.img.images[0].data.shape[self._idx_x] + / ngff_img.data.shape[self._idx_x] + ) + for ngff_img in self.img.images[1:] + ] + ) + + def _fetch_dimension_metadata(self): + """ + Parse dimension metadata to identify spatial and channel axes. + + Extracts and maps the indices and original string names for the x-axis, + y-axis, and channel dimensions from the image metadata. + + Returns + ------- + Tuple[Tuple[int, int, int], Tuple[str, str, str]] + A pair of tuples containing the integer indices (idx_x, idx_y, idx_c) + and the matched string names (x_name, y_name, c_name), respectively. + + Raises + ------ + AssertionError + If the image does not have exactly 3 dimensions or contains unrecognized dimension names. + ValueError + If the dimensions do not consist of exactly one X-type, one Y-type, and one C-type axis. + """ + + dimnames = self.img.metadata.dimension_names + possible_dimnames_lowercase = {"x", "y", "c", "width", "height", 'channel'} + + strlower = lambda x: x.lower() + assert (len(dimnames) == 3) and ( + set(map(strlower, dimnames)).issubset(possible_dimnames_lowercase) + ), f"Err, read_region expects 3 dimensional image data with {possible_dimnames_lowercase} dim names, found {dimnames}" + + try: + _xname = next(d for d in dimnames if d.lower() in {"x", "width"}) + _yname = next(d for d in dimnames if d.lower() in {"y", "height"}) + _cname = next(d for d in dimnames if d.lower() in {"c", "channel"}) + except: + raise ValueError( + "Err, expecting one of each space/channel type dim in " + f"{possible_dimnames_lowercase}, found {dimnames}." + ) + + _dimname_to_index = { + name: i for i, name in enumerate(self.img.metadata.dimension_names) + } + _idx_x, _idx_y, _idx_c = ( + _dimname_to_index[_xname], + _dimname_to_index[_yname], + _dimname_to_index[_cname], + ) + + return (_idx_x, _idx_y, _idx_c), (_xname, _yname, _cname) + + def read_region( + self, + location: Tuple[int, int], + level: int, + size: Tuple[int, int], + read_as: ReadMode = "pil", + ) -> Union[Image.Image, np.ndarray]: + """ + Extract a specific region from the whole-slide image (WSI). + + Parameters + ---------- + location : Tuple[int, int] + (x, y) coordinates of the top-left corner of the region to extract. + level : int + Pyramid level to read from. + size : Tuple[int, int] + (width, height) of the region to extract. + read_as : {'pil', 'numpy'}, optional + Output format for the region: + - 'pil': returns a PIL Image (default) + - 'numpy': returns a NumPy array (H, W, 3) + + Returns + ------- + Union[PIL.Image.Image, np.ndarray] + Extracted image region in the specified format. + + Raises + ------ + ValueError + If `read_as` is not one of 'pil' or 'numpy'. + + Examples + -------- + >>> region = wsi.read_region((0, 0), level=0, size=(512, 512), read_as='numpy') + >>> print(region.shape) + (512, 512, 3) + """ + # 'location' is relative to the level as calls are made to the data array + downsample_factor = self.level_downsamples[level] + location_ = ( + int(location[0] / downsample_factor), + int(location[1] / downsample_factor), + ) + + x, y = location_ + width_size, height_size = size + + region_as_slice = [None, None, None] + region_as_slice[self._idx_y] = slice(y, y + height_size) + region_as_slice[self._idx_x] = slice(x, x + width_size) + region_as_slice[self._idx_c] = slice(None) + region_as_slice = tuple(region_as_slice) + + # prevent deadlock that occurs when reading while nested in pytorch's distributed operations + with dask_config_set(scheduler="synchronous"): + region = ( + self.img.images[level] + .data[region_as_slice] + .compute() + .transpose(self._transpose_order) + ) + + if read_as == "pil": + return Image.fromarray(region).convert("RGB") + elif read_as == "numpy": + return region + else: + raise ValueError( + f"Invalid `read_as` value: {read_as}. Must be 'pil', 'numpy'." + ) + + def get_dimensions(self) -> Tuple[int, int]: + """ + Return the dimensions (width, height) of the WSI. + + Returns + ------- + tuple of int + (width, height) in pixels. + """ + return self.dimensions + + def get_thumbnail(self, size: tuple[int, int]) -> Image.Image: + """ + Generate a thumbnail of the WSI. + + Parameters + ---------- + size : tuple of int + Desired (width, height) of the thumbnail. + + Returns + ------- + PIL.Image.Image + RGB thumbnail as a PIL Image. + """ + width, height = size + # takes the average ratio between the thumbsize and the object's (level dimension) size then applies abs(x - 1) so min finds + # the size ratio closest to 1 + get_dim_to_size_adjusted_ratio = lambda x: abs( + (((x[0] / width) + (x[1] / height)) / 2) - 1 + ) + # get the min index rather than value + closest_level = min( + range(self.level_count), + key=lambda i: list( + map(get_dim_to_size_adjusted_ratio, self.level_dimensions) + )[i], + ) + + thumbimg_data = ( + self.img.images[closest_level] + .data.compute() + .transpose(self._transpose_order) + ) + return Image.fromarray(thumbimg_data).convert("RGB").resize(size) diff --git a/trident/wsi_objects/WSIFactory.py b/trident/wsi_objects/WSIFactory.py index 7d5b5217..ac778601 100644 --- a/trident/wsi_objects/WSIFactory.py +++ b/trident/wsi_objects/WSIFactory.py @@ -6,11 +6,13 @@ from trident.wsi_objects.ImageWSI import ImageWSI from trident.wsi_objects.CuCIMWSI import CuCIMWSI from trident.wsi_objects.SDPCWSI import SDPCWSI -WSIReaderType = Literal['openslide', 'image', 'cucim', 'sdpc'] +from trident.wsi_objects.OMEZarrWSI import OMEZarrWSI +WSIReaderType = Literal['openslide', 'image', 'cucim', 'sdpc', 'omezarr'] OPENSLIDE_EXTENSIONS = {'.svs', '.tif', '.tiff', '.ndpi', '.vms', '.vmu', '.scn', '.mrxs'} CUCIM_EXTENSIONS = {'.svs', '.tif', '.tiff'} SDPC_EXTENSIONS = {'.sdpc'} PIL_EXTENSIONS = {'.png', '.jpg', '.jpeg'} +OMEZARR_EXTENSIONS = {'.zarr'} def load_wsi( @@ -18,7 +20,7 @@ def load_wsi( reader_type: Optional[WSIReaderType] = None, lazy_init: bool = False, **kwargs -) -> Union[OpenSlideWSI, ImageWSI, CuCIMWSI, SDPCWSI]: +) -> Union[OpenSlideWSI, ImageWSI, CuCIMWSI, SDPCWSI, OMEZarrWSI]: """ Load a whole-slide image (WSI) using the appropriate backend. @@ -30,7 +32,7 @@ def load_wsi( ---------- slide_path : str Path to the whole-slide image. - reader_type : {'openslide', 'image', 'cucim', 'sdpc'}, optional + reader_type : {'openslide', 'image', 'cucim', 'sdpc', 'omezarr'}, optional Manually specify the WSI reader to use. If None (default), selection is automatic based on file extension. lazy_init : bool, optional @@ -41,7 +43,7 @@ def load_wsi( Returns ------- - Union[OpenSlideWSI, ImageWSI, CuCIMWSI, SDPCWSI] + Union[OpenSlideWSI, ImageWSI, CuCIMWSI, SDPCWSI, OMEZarrWSI] An instance of the appropriate WSI reader. Raises @@ -53,7 +55,7 @@ def load_wsi( """ ext = os.path.splitext(slide_path)[1].lower() - assert reader_type in ['openslide', 'image', 'cucim', 'sdpc', None], f"Unknown reader_type: {reader_type}. Choose from 'openslide', 'image', 'cucim', or 'sdpc'." + assert reader_type in ['openslide', 'image', 'cucim', 'sdpc', 'omezarr', None], f"Unknown reader_type: {reader_type}. Choose from 'openslide', 'image', 'cucim', or 'sdpc' 'omezarr'." if reader_type == 'openslide': return OpenSlideWSI(slide_path=slide_path, lazy_init=lazy_init, **kwargs) @@ -78,11 +80,22 @@ def load_wsi( f"Unsupported file format '{ext}' for CuCIM. " f"Supported whole-slide image formats are: {', '.join(CUCIM_EXTENSIONS)}." ) - + + elif reader_type == 'omezarr': + if ext in OMEZARR_EXTENSIONS: + return OMEZarrWSI(slide_path=slide_path, lazy_init=lazy_init, **kwargs) + else: + raise ValueError( + f"Unsupported file format '{ext}' for Ome-Zarr. " + f"Supported whole-slide image formats are: {', '.join(OMEZARR_EXTENSIONS)}." + ) + elif reader_type is None: if ext in OPENSLIDE_EXTENSIONS: return OpenSlideWSI(slide_path=slide_path, lazy_init=lazy_init, **kwargs) elif ext in SDPC_EXTENSIONS: return SDPCWSI(slide_path=slide_path, lazy_init=lazy_init, **kwargs) + elif ext in OMEZARR_EXTENSIONS: + return OMEZarrWSI(slide_path=slide_path, lazy_init=lazy_init, **kwargs) else: return ImageWSI(slide_path=slide_path, lazy_init=lazy_init, **kwargs)