From cd368bd3573c6b13a7dab569bc0c28fda7bc4e9c Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Wed, 22 Oct 2025 10:00:29 +0100 Subject: [PATCH 01/15] :new: Define `DeepFeatureExtractor` --- .../models/engine/deep_feature_extractor.py | 177 ++++++++++++++++++ 1 file changed, 177 insertions(+) create mode 100644 tiatoolbox/models/engine/deep_feature_extractor.py diff --git a/tiatoolbox/models/engine/deep_feature_extractor.py b/tiatoolbox/models/engine/deep_feature_extractor.py new file mode 100644 index 000000000..6e2eef62c --- /dev/null +++ b/tiatoolbox/models/engine/deep_feature_extractor.py @@ -0,0 +1,177 @@ +"""Define DeepFeatureExtractor class.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +from typing_extensions import Unpack + +from tiatoolbox.models.dataset.dataset_abc import WSIStreamDataset + +from .semantic_segmentor import SemanticSegmentor, SemanticSegmentorRunParams + +if TYPE_CHECKING: # pragma: no cover + import os + from collections.abc import Callable + from pathlib import Path + + from tiatoolbox.annotation import AnnotationStore + from tiatoolbox.models.engine.io_config import IOSegmentorConfig + from tiatoolbox.models.models_abc import ModelABC + from tiatoolbox.wsicore import WSIReader + + +class DeepFeatureExtractor(SemanticSegmentor): + """Generic CNN Feature Extractor. + + AN engine for using any CNN model as a feature extractor. Note, if + `model` is supplied in the arguments, it will ignore the + `pretrained_model` and `pretrained_weights` arguments. + + Args: + model (nn.Module): + Use externally defined PyTorch model for prediction with + weights already loaded. Default is `None`. If provided, + `pretrained_model` argument is ignored. + pretrained_model (str): + Name of the existing models support by tiatoolbox for + processing the data. By default, the corresponding + pretrained weights will also be downloaded. However, you can + override with your own set of weights via the + `pretrained_weights` argument. Argument is case-insensitive. + Refer to + :class:`tiatoolbox.models.architecture.vanilla.CNNBackbone` + for list of supported pretrained models. + pretrained_weights (str): + Path to the weight of the corresponding `pretrained_model`. + batch_size (int): + Number of images fed into the model each time. + num_loader_workers (int): + Number of workers to load the data. Take note that they will + also perform preprocessing. + num_postproc_workers (int): + This value is there to maintain input compatibility with + `tiatoolbox.models.classification` and is not used. + verbose (bool): + Whether to output logging information. + dataset_class (obj): + Dataset class to be used instead of default. + auto_generate_mask(bool): + To automatically generate tile/WSI tissue mask if is not + provided. + + Examples: + >>> # Sample output of a network + >>> from tiatoolbox.models.architecture.vanilla import CNNBackbone + >>> wsis = ['A/wsi.svs', 'B/wsi.svs'] + >>> # create resnet50 with pytorch pretrained weights + >>> model = CNNBackbone('resnet50') + >>> predictor = DeepFeatureExtractor(model=model) + >>> output = predictor.predict(wsis, mode='wsi') + >>> list(output.keys()) + [('A/wsi.svs', 'output/0') , ('B/wsi.svs', 'output/1')] + >>> # If a network have 2 output heads, for 'A/wsi.svs', + >>> # there will be 3 outputs, and they are respectively stored at + >>> # 'output/0.position.npy' # will always be output + >>> # 'output/0.features.0.npy' # output of head 0 + >>> # 'output/0.features.1.npy' # output of head 1 + >>> # Each file will contain a same number of items, and the item at each + >>> # index corresponds to 1 patch. The item in `.*position.npy` will + >>> # be the corresponding patch bounding box. The box coordinates are at + >>> # the inference resolution defined within the provided `ioconfig`. + + """ + + def __init__( + self: DeepFeatureExtractor, + model: str | ModelABC, + batch_size: int = 8, + num_workers: int = 0, + weights: str | Path | None = None, + dataset_class: Callable = WSIStreamDataset, + *, + device: str = "cpu", + verbose: bool = True, + ) -> None: + """Initialize :class:`DeepFeatureExtractor`.""" + super().__init__( + model=model, + batch_size=batch_size, + num_workers=num_workers, + weights=weights, + device=device, + verbose=verbose, + ) + self.process_prediction_per_batch = False + self.dataset_class = dataset_class + + def _process_predictions( + self: DeepFeatureExtractor, + cum_batch_predictions: list, + wsi_reader: WSIReader, # skipcq: PYL-W0613 # noqa: ARG002 + ioconfig: IOSegmentorConfig, + save_path: str, + cache_dir: str, # skipcq: PYL-W0613 # noqa: ARG002 + ) -> None: + """Define how the aggregated predictions are processed. + + This includes merging the prediction if necessary and also + saving afterward. + + Args: + cum_batch_predictions (list): + List of batch predictions. Each item within the list + should be of (location, patch_predictions). + wsi_reader (:class:`WSIReader`): + A reader for the image where the predictions come from. + Not used here. Added for consistency with the API. + ioconfig (:class:`IOSegmentorConfig`): + A configuration object contains input and output + information. + save_path (str): + Root path to save current WSI predictions. + cache_dir (str): + Root path to cache current WSI data. + Not used here. Added for consistency with the API. + + """ + # assume prediction_list is N, each item has L output elements + location_list, prediction_list = list(zip(*cum_batch_predictions, strict=False)) + # Nx4 (N x [tl_x, tl_y, br_x, br_y), denotes the location of output + # patch, this can exceed the image bound at the requested resolution + # remove singleton due to split. + location_list = np.array([v[0] for v in location_list]) + np.save(f"{save_path}.position.npy", location_list) + for idx, _ in enumerate(ioconfig.output_resolutions): + # assume resolution idx to be in the same order as L + # 0 idx is to remove singleton without removing other axes singleton + prediction_list = [v[idx][0] for v in prediction_list] + prediction_list = np.array(prediction_list) + np.save(f"{save_path}.features.{idx}.npy", prediction_list) + + def run( + self: DeepFeatureExtractor, + images: list[os.PathLike | Path | WSIReader] | np.ndarray, + masks: list[os.PathLike | Path] | np.ndarray | None = None, + labels: list | None = None, + ioconfig: IOSegmentorConfig | None = None, + *, + patch_mode: bool = True, + save_dir: os.PathLike | Path | None = None, + overwrite: bool = False, + output_type: str = "dict", + **kwargs: Unpack[SemanticSegmentorRunParams], + ) -> AnnotationStore | Path | str | dict | list[Path]: + """Run the DeepFeatureExtractor engine on input images.""" + return super().run( + images=images, + masks=masks, + labels=labels, + ioconfig=ioconfig, + patch_mode=patch_mode, + save_dir=save_dir, + overwrite=overwrite, + output_type=output_type, + **kwargs, + ) From f14fdaa8d1453800a750486cb8b4edcac59c8465 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Wed, 5 Nov 2025 14:16:13 +0000 Subject: [PATCH 02/15] :fire: Remove incorrect docstring --- .../models/engine/deep_feature_extractor.py | 60 +------------------ 1 file changed, 1 insertion(+), 59 deletions(-) diff --git a/tiatoolbox/models/engine/deep_feature_extractor.py b/tiatoolbox/models/engine/deep_feature_extractor.py index 6e2eef62c..4ed84da16 100644 --- a/tiatoolbox/models/engine/deep_feature_extractor.py +++ b/tiatoolbox/models/engine/deep_feature_extractor.py @@ -23,65 +23,7 @@ class DeepFeatureExtractor(SemanticSegmentor): - """Generic CNN Feature Extractor. - - AN engine for using any CNN model as a feature extractor. Note, if - `model` is supplied in the arguments, it will ignore the - `pretrained_model` and `pretrained_weights` arguments. - - Args: - model (nn.Module): - Use externally defined PyTorch model for prediction with - weights already loaded. Default is `None`. If provided, - `pretrained_model` argument is ignored. - pretrained_model (str): - Name of the existing models support by tiatoolbox for - processing the data. By default, the corresponding - pretrained weights will also be downloaded. However, you can - override with your own set of weights via the - `pretrained_weights` argument. Argument is case-insensitive. - Refer to - :class:`tiatoolbox.models.architecture.vanilla.CNNBackbone` - for list of supported pretrained models. - pretrained_weights (str): - Path to the weight of the corresponding `pretrained_model`. - batch_size (int): - Number of images fed into the model each time. - num_loader_workers (int): - Number of workers to load the data. Take note that they will - also perform preprocessing. - num_postproc_workers (int): - This value is there to maintain input compatibility with - `tiatoolbox.models.classification` and is not used. - verbose (bool): - Whether to output logging information. - dataset_class (obj): - Dataset class to be used instead of default. - auto_generate_mask(bool): - To automatically generate tile/WSI tissue mask if is not - provided. - - Examples: - >>> # Sample output of a network - >>> from tiatoolbox.models.architecture.vanilla import CNNBackbone - >>> wsis = ['A/wsi.svs', 'B/wsi.svs'] - >>> # create resnet50 with pytorch pretrained weights - >>> model = CNNBackbone('resnet50') - >>> predictor = DeepFeatureExtractor(model=model) - >>> output = predictor.predict(wsis, mode='wsi') - >>> list(output.keys()) - [('A/wsi.svs', 'output/0') , ('B/wsi.svs', 'output/1')] - >>> # If a network have 2 output heads, for 'A/wsi.svs', - >>> # there will be 3 outputs, and they are respectively stored at - >>> # 'output/0.position.npy' # will always be output - >>> # 'output/0.features.0.npy' # output of head 0 - >>> # 'output/0.features.1.npy' # output of head 1 - >>> # Each file will contain a same number of items, and the item at each - >>> # index corresponds to 1 patch. The item in `.*position.npy` will - >>> # be the corresponding patch bounding box. The box coordinates are at - >>> # the inference resolution defined within the provided `ioconfig`. - - """ + """Generic CNN Feature Extractor.""" def __init__( self: DeepFeatureExtractor, From aa4c812e025a48bcbe0db2e16f6b1a01a2bb0f6d Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 6 Nov 2025 17:26:58 +0000 Subject: [PATCH 03/15] :test_tube: Initial implementation --- tests/engines/test_feature_extractor.py | 166 ++++++++++++++ tiatoolbox/models/__init__.py | 2 + tiatoolbox/models/engine/__init__.py | 2 + .../models/engine/deep_feature_extractor.py | 207 ++++++++++++++---- 4 files changed, 339 insertions(+), 38 deletions(-) create mode 100644 tests/engines/test_feature_extractor.py diff --git a/tests/engines/test_feature_extractor.py b/tests/engines/test_feature_extractor.py new file mode 100644 index 000000000..7eed5ff03 --- /dev/null +++ b/tests/engines/test_feature_extractor.py @@ -0,0 +1,166 @@ +"""Test for feature extractor.""" + +import shutil +from collections.abc import Callable +from pathlib import Path + +import numpy as np +import pytest +import torch +import zarr + +from tiatoolbox.models import IOSegmentorConfig +from tiatoolbox.models.architecture.vanilla import CNNBackbone, TimmBackbone +from tiatoolbox.models.engine.deep_feature_extractor import DeepFeatureExtractor +from tiatoolbox.utils import env_detection as toolbox_env +from tiatoolbox.utils.misc import select_device +from tiatoolbox.wsicore.wsireader import WSIReader + +ON_GPU = not toolbox_env.running_on_ci() and toolbox_env.has_gpu() + +# ------------------------------------------------------------------------------------- +# Engine +# ------------------------------------------------------------------------------------- + +device = "cuda" if toolbox_env.has_gpu() else "cpu" + + +def test_engine(remote_sample: Callable, track_tmp_path: Path) -> None: + """Test feature extraction with DeepFeatureExtractor engine.""" + save_dir = track_tmp_path / "output" + # # convert to pathlib Path to prevent wsireader complaint + mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs")) + + # * test providing pretrained from torch vs pretrained_model.yaml + shutil.rmtree(save_dir, ignore_errors=True) # default output dir test + + extractor = DeepFeatureExtractor(batch_size=1, model="fcn-tissue_mask") + output = extractor.run( + images=[mini_wsi_svs], + return_probabilities=False, + return_labels=False, + device=device, + patch_mode=False, + save_dir=track_tmp_path / "wsi_out_check", + batch_size=2, + output_type="zarr", + ) + + output_ = zarr.open(output[mini_wsi_svs], mode="r") + assert len(output_["coordinates"].shape) == 2 + assert len(output_["probabilities"].shape) + + +@pytest.mark.parametrize( + "model", [CNNBackbone("resnet50"), TimmBackbone("efficientnet_b0", pretrained=True)] +) +def test_full_inference( + remote_sample: Callable, track_tmp_path: Path, model: Callable +) -> None: + """Test full inference with CNNBackbone and TimmBackbone models.""" + save_dir = track_tmp_path / "output" + # pre-emptive clean up + shutil.rmtree(save_dir, ignore_errors=True) # default output dir test + + mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs")) + + ioconfig = IOSegmentorConfig( + input_resolutions=[ + {"units": "mpp", "resolution": 0.25}, + ], + output_resolutions=[ + {"units": "mpp", "resolution": 0.25}, + ], + patch_input_shape=[512, 512], + patch_output_shape=[512, 512], + stride_shape=[256, 256], + save_resolution={"units": "mpp", "resolution": 8.0}, + ) + + extractor = DeepFeatureExtractor(batch_size=4, model=model) + output = extractor.run( + images=[mini_wsi_svs], + device=device, + save_dir=track_tmp_path / "wsi_out_check", + batch_size=2, + output_type="zarr", + ioconfig=ioconfig, + patch_mode=False, + ) + + output_ = zarr.open(output[mini_wsi_svs], mode="r") + + positions = output_["coordinates"] + features = output_["probabilities"] + + reader = WSIReader.open(mini_wsi_svs) + patches = [ + reader.read_bounds( + positions[patch_idx], + resolution=0.25, + units="mpp", + pad_constant_values=0, + coord_space="resolution", + ) + for patch_idx in range(4) + ] + patches = np.array(patches) + patches = torch.from_numpy(patches) # NHWC + patches = patches.permute(0, 3, 1, 2) # NCHW + patches = patches.type(torch.float32) + model = model.to("cpu") + # Inference mode + model.eval() + with torch.inference_mode(): + _features = model(patches).numpy() + # ! must maintain same batch size and likely same ordering + # ! else the output values will not exactly be the same (still < 1.0e-4 + # ! of epsilon though) + assert np.mean(np.abs(features[:4] - _features)) < 1.0e-1 + + +@pytest.mark.skipif( + toolbox_env.running_on_ci() or not ON_GPU, + reason="Local test on machine with GPU.", +) +def test_multi_gpu_feature_extraction( + remote_sample: Callable, track_tmp_path: Path +) -> None: + """Local functionality test for feature extraction using multiple GPUs.""" + save_dir = track_tmp_path / "output" + mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs")) + shutil.rmtree(save_dir, ignore_errors=True) + + # Use multiple GPUs + device = select_device(on_gpu=ON_GPU) + + wsi_ioconfig = IOSegmentorConfig( + input_resolutions=[{"units": "mpp", "resolution": 0.5}], + patch_input_shape=[224, 224], + output_resolutions=[{"units": "mpp", "resolution": 0.5}], + patch_output_shape=[224, 224], + stride_shape=[224, 224], + ) + + model = TimmBackbone(backbone="UNI", pretrained=True) + extractor = DeepFeatureExtractor( + model=model, + auto_generate_mask=True, + batch_size=32, + num_loader_workers=4, + num_postproc_workers=4, + ) + + output_list = extractor.predict( + [mini_wsi_svs], + mode="wsi", + device=device, + ioconfig=wsi_ioconfig, + crash_on_exception=True, + save_dir=save_dir, + ) + wsi_0_root_path = output_list[0][1] + positions = np.load(f"{wsi_0_root_path}.position.npy") + features = np.load(f"{wsi_0_root_path}.features.0.npy") + assert len(positions.shape) == 2 + assert len(features.shape) == 2 diff --git a/tiatoolbox/models/__init__.py b/tiatoolbox/models/__init__.py index 5de543aad..4265f7caf 100644 --- a/tiatoolbox/models/__init__.py +++ b/tiatoolbox/models/__init__.py @@ -11,6 +11,7 @@ from .architecture.nuclick import NuClick from .architecture.sccnn import SCCNN from .dataset import PatchDataset, WSIPatchDataset, WSIStreamDataset +from .engine.deep_feature_extractor import DeepFeatureExtractor from .engine.io_config import ( IOInstanceSegmentorConfig, IOPatchPredictorConfig, @@ -24,6 +25,7 @@ __all__ = [ "SCCNN", + "DeepFeatureExtractor", "HoVerNet", "HoVerNetPlus", "IDaRS", diff --git a/tiatoolbox/models/engine/__init__.py b/tiatoolbox/models/engine/__init__.py index 9c00ac4a2..2509b51a3 100644 --- a/tiatoolbox/models/engine/__init__.py +++ b/tiatoolbox/models/engine/__init__.py @@ -1,6 +1,7 @@ """Engines to run models implemented in tiatoolbox.""" from . import ( + deep_feature_extractor, engine_abc, nucleus_instance_segmentor, patch_predictor, @@ -8,6 +9,7 @@ ) __all__ = [ + "deep_feature_extractor", "engine_abc", "nucleus_instance_segmentor", "patch_predictor", diff --git a/tiatoolbox/models/engine/deep_feature_extractor.py b/tiatoolbox/models/engine/deep_feature_extractor.py index 4ed84da16..2abd80bec 100644 --- a/tiatoolbox/models/engine/deep_feature_extractor.py +++ b/tiatoolbox/models/engine/deep_feature_extractor.py @@ -4,10 +4,11 @@ from typing import TYPE_CHECKING -import numpy as np +import dask.array as da from typing_extensions import Unpack from tiatoolbox.models.dataset.dataset_abc import WSIStreamDataset +from tiatoolbox.utils.misc import get_tqdm from .semantic_segmentor import SemanticSegmentor, SemanticSegmentorRunParams @@ -16,6 +17,9 @@ from collections.abc import Callable from pathlib import Path + import numpy as np + from torch.utils.data import DataLoader + from tiatoolbox.annotation import AnnotationStore from tiatoolbox.models.engine.io_config import IOSegmentorConfig from tiatoolbox.models.models_abc import ModelABC @@ -48,49 +52,173 @@ def __init__( self.process_prediction_per_batch = False self.dataset_class = dataset_class - def _process_predictions( + def infer_wsi( + self: SemanticSegmentor, + dataloader: DataLoader, + save_path: Path, + **kwargs: Unpack[SemanticSegmentorRunParams], + ) -> dict[str, da.Array]: + """Perform model inference on a whole slide image (WSI). + + This method processes a WSI using the provided DataLoader, merges + patch-level predictions into a full-resolution canvas, and returns + the aggregated output. It supports memory-aware caching and optional + inclusion of coordinates and labels. + + Args: + dataloader (DataLoader): + PyTorch DataLoader configured for WSI processing. + save_path (Path): + Path to save the intermediate output. The intermediate output + is saved in a Zarr file. + **kwargs (SemanticSegmentorRunParams): + Additional runtime parameters, including: + - return_probabilities (bool): Whether to return probability maps. + - return_labels (bool): Whether to include labels in the output. + - memory_threshold (int): Memory usage threshold to trigger disk + caching. + + Returns: + dict[str, dask.array.Array]: + Dictionary containing merged prediction results: + - "probabilities": Full-resolution probability map. + - "coordinates": Patch coordinates. + - "labels": Ground truth labels (if `return_labels` is True). + + """ + _ = kwargs.get("patch_mode", False) + _ = save_path + keys = ["probabilities", "coordinates"] + probabilities, coordinates = [], [] + + # Main output dictionary + raw_predictions = dict( + zip(keys, [da.empty(shape=(0, 0))] * len(keys), strict=False) + ) + + # Inference loop + tqdm = get_tqdm() + tqdm_loop = ( + tqdm(dataloader, leave=False, desc="Inferring patches") + if self.verbose + else dataloader + ) + + for batch_data in tqdm_loop: + batch_output = self.model.infer_batch( + self.model, + batch_data["image"], + device=self.device, + ) + + probabilities.append(da.from_array(batch_output)) + coordinates.append( + da.from_array( + self._get_coordinates(batch_data), + ) + ) + + raw_predictions["coordinates"] = da.concatenate(coordinates, axis=0) + raw_predictions["probabilities"] = da.concatenate(probabilities, axis=0) + + return raw_predictions + + def post_process_patches( self: DeepFeatureExtractor, - cum_batch_predictions: list, - wsi_reader: WSIReader, # skipcq: PYL-W0613 # noqa: ARG002 - ioconfig: IOSegmentorConfig, - save_path: str, - cache_dir: str, # skipcq: PYL-W0613 # noqa: ARG002 - ) -> None: - """Define how the aggregated predictions are processed. + raw_predictions: da.Array, + prediction_shape: tuple[int, ...], + prediction_dtype: type, + **kwargs: Unpack[DeepFeatureExtractor], + ) -> da.Array: + """Post-process raw patch predictions from model inference. - This includes merging the prediction if necessary and also - saving afterward. + This method applies the model's post-processing function to the raw predictions + obtained from `infer_patches()`. The output is wrapped in a Dask array for + efficient computation and memory handling. Args: - cum_batch_predictions (list): - List of batch predictions. Each item within the list - should be of (location, patch_predictions). - wsi_reader (:class:`WSIReader`): - A reader for the image where the predictions come from. - Not used here. Added for consistency with the API. - ioconfig (:class:`IOSegmentorConfig`): - A configuration object contains input and output - information. - save_path (str): - Root path to save current WSI predictions. - cache_dir (str): - Root path to cache current WSI data. - Not used here. Added for consistency with the API. + raw_predictions (da.Array | np.ndarray): + Raw model predictions. + prediction_shape (tuple[int, ...]): + Expected shape of the prediction output. + prediction_dtype (type): + Data type of the prediction output. + **kwargs (PredictorRunParams): + Additional runtime parameters, including `return_probabilities`. + + Returns: + dask.array.Array: Post-processed predictions as a Dask array. """ - # assume prediction_list is N, each item has L output elements - location_list, prediction_list = list(zip(*cum_batch_predictions, strict=False)) - # Nx4 (N x [tl_x, tl_y, br_x, br_y), denotes the location of output - # patch, this can exceed the image bound at the requested resolution - # remove singleton due to split. - location_list = np.array([v[0] for v in location_list]) - np.save(f"{save_path}.position.npy", location_list) - for idx, _ in enumerate(ioconfig.output_resolutions): - # assume resolution idx to be in the same order as L - # 0 idx is to remove singleton without removing other axes singleton - prediction_list = [v[idx][0] for v in prediction_list] - prediction_list = np.array(prediction_list) - np.save(f"{save_path}.features.{idx}.npy", prediction_list) + _ = kwargs.get("return_probabilities") + _ = prediction_shape + _ = prediction_dtype + + return raw_predictions + + def _update_run_params( + self: SemanticSegmentor, + images: list[os.PathLike | Path | WSIReader] | np.ndarray, + masks: list[os.PathLike | Path] | np.ndarray | None = None, + labels: list | None = None, + save_dir: os.PathLike | Path | None = None, + ioconfig: IOSegmentorConfig | None = None, + output_type: str = "dict", + *, + overwrite: bool = False, + patch_mode: bool, + **kwargs: Unpack[SemanticSegmentorRunParams], + ) -> Path | None: + """Update runtime parameters for the PatchPredictor engine. + + This method sets internal attributes such as caching, batch size, + IO configuration, and output format based on user input and keyword arguments. + It also configures whether to include probabilities in the output. + + Args: + images (list[PathLike | WSIReader] | np.ndarray): + Input images or patches. + masks (list[PathLike] | np.ndarray | None): + Optional masks for WSI processing. + labels (list | None): + Optional labels for input images. + save_dir (PathLike | None): + Directory to save output files. Required for WSI mode. + ioconfig (ModelIOConfigABC | None): + IO configuration for patch extraction and resolution. + output_type (str): + Desired output format: "dict", "zarr", or "annotationstore". + overwrite (bool): + Whether to overwrite existing output files. Default is False. + patch_mode (bool): + Whether to treat input as patches (`True`) or WSIs (`False`). + **kwargs (SemanticSegmentorRunParams): + Additional runtime parameters. + + Returns: + Path | None: + Path to the save directory if applicable, otherwise None. + + Raises: + ValueError: + If `labels` are requested for WSI processing. + + """ + if output_type != "zarr": + msg = "Only zarr output is supported for `DeepFeatureExtractor`." + raise ValueError(msg) + + return super()._update_run_params( + images=images, + masks=masks, + labels=labels, + save_dir=save_dir, + ioconfig=ioconfig, + overwrite=overwrite, + patch_mode=patch_mode, + output_type=output_type, + **kwargs, + ) def run( self: DeepFeatureExtractor, @@ -106,6 +234,9 @@ def run( **kwargs: Unpack[SemanticSegmentorRunParams], ) -> AnnotationStore | Path | str | dict | list[Path]: """Run the DeepFeatureExtractor engine on input images.""" + # return_probabilities is always True for FeatureExtractor. + kwargs["return_probabilities"] = True + return super().run( images=images, masks=masks, From 4df8ea4095405c1e4d59a2281c6db311de26d1a3 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 6 Nov 2025 17:39:11 +0000 Subject: [PATCH 04/15] :test_tube: Initial implementation --- .../models/engine/deep_feature_extractor.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/tiatoolbox/models/engine/deep_feature_extractor.py b/tiatoolbox/models/engine/deep_feature_extractor.py index 2abd80bec..32904e649 100644 --- a/tiatoolbox/models/engine/deep_feature_extractor.py +++ b/tiatoolbox/models/engine/deep_feature_extractor.py @@ -114,7 +114,7 @@ def infer_wsi( probabilities.append(da.from_array(batch_output)) coordinates.append( da.from_array( - self._get_coordinates(batch_data), + batch_data["output_locs"].numpy(), ) ) @@ -156,6 +156,20 @@ def post_process_patches( return raw_predictions + def save_predictions( + self: SemanticSegmentor, + processed_predictions: dict, + output_type: str, + save_path: Path | None = None, + **kwargs: Unpack[SemanticSegmentorRunParams], + ) -> dict | Path: + """Save patch predictions to disk.""" + # no need to compute predictions + self.drop_keys.append("predictions") + return super().save_predictions( + processed_predictions, output_type, save_path=save_path, **kwargs + ) + def _update_run_params( self: SemanticSegmentor, images: list[os.PathLike | Path | WSIReader] | np.ndarray, From 8460a2d0a292868c000c5d32a5ebaf3f1144131a Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 6 Nov 2025 18:09:18 +0000 Subject: [PATCH 05/15] :white_check_mark: Add tests for `DeepFeatuureExtractor` --- tests/engines/test_feature_extractor.py | 19 +++++++++---------- .../models/engine/deep_feature_extractor.py | 4 ++-- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/tests/engines/test_feature_extractor.py b/tests/engines/test_feature_extractor.py index 7eed5ff03..b15918a00 100644 --- a/tests/engines/test_feature_extractor.py +++ b/tests/engines/test_feature_extractor.py @@ -82,7 +82,7 @@ def test_full_inference( images=[mini_wsi_svs], device=device, save_dir=track_tmp_path / "wsi_out_check", - batch_size=2, + batch_size=4, output_type="zarr", ioconfig=ioconfig, patch_mode=False, @@ -145,22 +145,21 @@ def test_multi_gpu_feature_extraction( model = TimmBackbone(backbone="UNI", pretrained=True) extractor = DeepFeatureExtractor( model=model, - auto_generate_mask=True, batch_size=32, - num_loader_workers=4, - num_postproc_workers=4, + num_workers=4, ) - output_list = extractor.predict( + output = extractor.run( [mini_wsi_svs], - mode="wsi", + patch_mode=False, device=device, ioconfig=wsi_ioconfig, - crash_on_exception=True, save_dir=save_dir, + auto_get_mask=True, ) - wsi_0_root_path = output_list[0][1] - positions = np.load(f"{wsi_0_root_path}.position.npy") - features = np.load(f"{wsi_0_root_path}.features.0.npy") + output_ = zarr.open(output[mini_wsi_svs], mode="r") + + positions = output_["coordinates"] + features = output_["probabilities"] assert len(positions.shape) == 2 assert len(features.shape) == 2 diff --git a/tiatoolbox/models/engine/deep_feature_extractor.py b/tiatoolbox/models/engine/deep_feature_extractor.py index 32904e649..e9a420833 100644 --- a/tiatoolbox/models/engine/deep_feature_extractor.py +++ b/tiatoolbox/models/engine/deep_feature_extractor.py @@ -111,10 +111,10 @@ def infer_wsi( device=self.device, ) - probabilities.append(da.from_array(batch_output)) + probabilities.append(da.from_array(batch_output[0])) coordinates.append( da.from_array( - batch_data["output_locs"].numpy(), + self._get_coordinates(batch_data), ) ) From c9f0e59d369887aff76cc02131c3125b58ca4c5e Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 7 Nov 2025 15:40:25 +0000 Subject: [PATCH 06/15] :bug: Fix error due to inconsistent results Results are inconsistent as the model is redefined on a different device. --- tests/engines/test_feature_extractor.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/engines/test_feature_extractor.py b/tests/engines/test_feature_extractor.py index b15918a00..a674961aa 100644 --- a/tests/engines/test_feature_extractor.py +++ b/tests/engines/test_feature_extractor.py @@ -99,20 +99,20 @@ def test_full_inference( positions[patch_idx], resolution=0.25, units="mpp", - pad_constant_values=0, + pad_constant_values=255, coord_space="resolution", ) for patch_idx in range(4) ] patches = np.array(patches) patches = torch.from_numpy(patches) # NHWC - patches = patches.permute(0, 3, 1, 2) # NCHW - patches = patches.type(torch.float32) - model = model.to("cpu") + patches = patches.permute(0, 3, 1, 2).contiguous() # NCHW + patches = patches.to(device).type(torch.float32) + model = extractor.model # Inference mode model.eval() with torch.inference_mode(): - _features = model(patches).numpy() + _features = model(patches).cpu().numpy() # ! must maintain same batch size and likely same ordering # ! else the output values will not exactly be the same (still < 1.0e-4 # ! of epsilon though) From 35c964bc2b4b842b9785f62b7a30976f6f1204c5 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 7 Nov 2025 16:27:19 +0000 Subject: [PATCH 07/15] :white_check_mark: Add tests for coverage and update docstrings. --- tests/engines/test_feature_extractor.py | 41 +++- .../models/engine/deep_feature_extractor.py | 182 ++++++++++++++---- .../models/engine/semantic_segmentor.py | 2 +- 3 files changed, 187 insertions(+), 38 deletions(-) diff --git a/tests/engines/test_feature_extractor.py b/tests/engines/test_feature_extractor.py index a674961aa..b55ba19eb 100644 --- a/tests/engines/test_feature_extractor.py +++ b/tests/engines/test_feature_extractor.py @@ -25,7 +25,46 @@ device = "cuda" if toolbox_env.has_gpu() else "cpu" -def test_engine(remote_sample: Callable, track_tmp_path: Path) -> None: +def test_feature_extractor_patches( + remote_sample: Callable, track_tmp_path: Path +) -> None: + """Tests DeepFeatureExtractor on image patches.""" + extractor = DeepFeatureExtractor( + model="fcn-tissue_mask", batch_size=32, verbose=False, device=device + ) + + sample_image = remote_sample("thumbnail-1k-1k") + + inputs = [sample_image, sample_image] + + assert not extractor.patch_mode + output = extractor.run( + images=inputs, + return_probabilities=True, + return_labels=False, + device=device, + patch_mode=True, + output_type="zarr", + save_dir=track_tmp_path / "wsi_out_check", + ) + + output_ = zarr.open(output, mode="r") + + assert 0.48 < np.mean(output_["probabilities"][:]) < 0.52 + + with pytest.raises( + ValueError, match=r".*Only zarr output is supported for `DeepFeatureExtractor`" + ): + _ = extractor.run( + images=inputs, + return_probabilities=True, + return_labels=False, + device=device, + patch_mode=True, + ) + + +def test_feature_extractor_wsi(remote_sample: Callable, track_tmp_path: Path) -> None: """Test feature extraction with DeepFeatureExtractor engine.""" save_dir = track_tmp_path / "output" # # convert to pathlib Path to prevent wsireader complaint diff --git a/tiatoolbox/models/engine/deep_feature_extractor.py b/tiatoolbox/models/engine/deep_feature_extractor.py index e9a420833..65a46c96f 100644 --- a/tiatoolbox/models/engine/deep_feature_extractor.py +++ b/tiatoolbox/models/engine/deep_feature_extractor.py @@ -7,14 +7,12 @@ import dask.array as da from typing_extensions import Unpack -from tiatoolbox.models.dataset.dataset_abc import WSIStreamDataset from tiatoolbox.utils.misc import get_tqdm from .semantic_segmentor import SemanticSegmentor, SemanticSegmentorRunParams if TYPE_CHECKING: # pragma: no cover import os - from collections.abc import Callable from pathlib import Path import numpy as np @@ -27,7 +25,37 @@ class DeepFeatureExtractor(SemanticSegmentor): - """Generic CNN Feature Extractor.""" + r"""Generic CNN-based feature extractor for digital pathology images. + + This class extends :class:`SemanticSegmentor` to extract deep features from + whole slide images (WSIs) or image patches using a CNN model. It is designed + for use cases where the goal is to obtain intermediate feature representations + (e.g., embeddings) rather than final classification or segmentation outputs. + + The extracted features are returned or saved in Zarr format for downstream + analysis, such as clustering, visualization, or training other machine learning + models. + + Args: + model (str | ModelABC): + A PyTorch model instance or the name of a pretrained model from TIAToolbox. + batch_size (int): + Number of image patches processed per forward pass. Default is 8. + num_workers (int): + Number of workers for data loading. Default is 0. + weights (str | Path | None): + Path to model weights. If None, default weights are used. + device (str): + Device to run the model on (e.g., "cpu", "cuda"). Default is "cpu". + verbose (bool): + Whether to enable verbose logging. Default is True. + + Attributes: + process_prediction_per_batch (bool): + Flag to control whether predictions are processed per batch. + Default is False. + + """ def __init__( self: DeepFeatureExtractor, @@ -35,12 +63,29 @@ def __init__( batch_size: int = 8, num_workers: int = 0, weights: str | Path | None = None, - dataset_class: Callable = WSIStreamDataset, *, device: str = "cpu", verbose: bool = True, ) -> None: - """Initialize :class:`DeepFeatureExtractor`.""" + """Initialize :class:`DeepFeatureExtractor`. + + Args: + model (str | ModelABC): + A PyTorch model instance or the name of a pretrained model from + TIAToolbox. If a string is provided, the corresponding pretrained + weights will be downloaded unless overridden via `weights`. + batch_size (int): + Number of image patches processed per forward pass. Default is 8. + num_workers (int): + Number of workers for data loading. Default is 0. + weights (str | Path | None): + Path to model weights. If None, default weights are used. + device (str): + Device to run the model on (e.g., "cpu", "cuda"). Default is "cpu". + verbose (bool): + Whether to enable verbose logging. Default is True. + + """ super().__init__( model=model, batch_size=batch_size, @@ -50,7 +95,6 @@ def __init__( verbose=verbose, ) self.process_prediction_per_batch = False - self.dataset_class = dataset_class def infer_wsi( self: SemanticSegmentor, @@ -60,30 +104,25 @@ def infer_wsi( ) -> dict[str, da.Array]: """Perform model inference on a whole slide image (WSI). - This method processes a WSI using the provided DataLoader, merges - patch-level predictions into a full-resolution canvas, and returns - the aggregated output. It supports memory-aware caching and optional - inclusion of coordinates and labels. + This method processes a WSI using the provided DataLoader and extracts + deep features from each patch using the model. The extracted features + are returned as a Dask array along with the corresponding patch coordinates. Args: dataloader (DataLoader): PyTorch DataLoader configured for WSI processing. save_path (Path): - Path to save the intermediate output. The intermediate output - is saved in a Zarr file. + Path to save the intermediate output. (Unused in this implementation.) **kwargs (SemanticSegmentorRunParams): Additional runtime parameters, including: - - return_probabilities (bool): Whether to return probability maps. - - return_labels (bool): Whether to include labels in the output. - - memory_threshold (int): Memory usage threshold to trigger disk - caching. + - return_probabilities (bool): Whether to return feature maps. + - memory_threshold (int): Memory usage threshold for caching. Returns: dict[str, dask.array.Array]: - Dictionary containing merged prediction results: - - "probabilities": Full-resolution probability map. - - "coordinates": Patch coordinates. - - "labels": Ground truth labels (if `return_labels` is True). + Dictionary containing: + - "probabilities": Extracted feature maps from the model. + - "coordinates": Patch coordinates corresponding to the features. """ _ = kwargs.get("patch_mode", False) @@ -128,26 +167,27 @@ def post_process_patches( raw_predictions: da.Array, prediction_shape: tuple[int, ...], prediction_dtype: type, - **kwargs: Unpack[DeepFeatureExtractor], + **kwargs: Unpack[SemanticSegmentorRunParams], ) -> da.Array: """Post-process raw patch predictions from model inference. - This method applies the model's post-processing function to the raw predictions - obtained from `infer_patches()`. The output is wrapped in a Dask array for - efficient computation and memory handling. + This method overrides the base implementation to return raw feature maps + without applying any additional processing. It is intended for use cases + where intermediate CNN features are required as output. Args: - raw_predictions (da.Array | np.ndarray): - Raw model predictions. + raw_predictions (dask.array.Array): + Raw model predictions as a Dask array. prediction_shape (tuple[int, ...]): Expected shape of the prediction output. prediction_dtype (type): Data type of the prediction output. - **kwargs (PredictorRunParams): - Additional runtime parameters, including `return_probabilities`. + **kwargs (SemanticSegmentorRunParams): + Additional runtime parameters. Returns: - dask.array.Array: Post-processed predictions as a Dask array. + dask.array.Array: + Unmodified raw predictions. """ _ = kwargs.get("return_probabilities") @@ -163,7 +203,35 @@ def save_predictions( save_path: Path | None = None, **kwargs: Unpack[SemanticSegmentorRunParams], ) -> dict | Path: - """Save patch predictions to disk.""" + """Save patch-level feature predictions to disk or return them in memory. + + This method saves the extracted deep features in the specified output format. + Only the "zarr" format is supported for this engine. The method disables + saving the "predictions" key, as it is not relevant for feature extraction. + + Args: + processed_predictions (dict): + Dictionary containing processed model outputs. + output_type (str): + Desired output format. Must be "zarr". + save_path (Path | None): + Path to save the output file. Required for "zarr" format. + **kwargs (SemanticSegmentorRunParams): + Additional runtime parameters, including: + - output_file (str): Name of the output file. + - scale_factor (tuple[float, float]): For coordinate transformation. + - class_dict (dict): Optional class index-to-name mapping. + + Returns: + dict | Path: + - If `output_type` is "zarr": returns the path to the saved Zarr file. + - If `output_type` is "dict": returns predictions as a dictionary. + + Raises: + ValueError: + If an unsupported output format is provided. + + """ # no need to compute predictions self.drop_keys.append("predictions") return super().save_predictions( @@ -183,11 +251,11 @@ def _update_run_params( patch_mode: bool, **kwargs: Unpack[SemanticSegmentorRunParams], ) -> Path | None: - """Update runtime parameters for the PatchPredictor engine. + """Update runtime parameters for the DeepFeatureExtractor engine. This method sets internal attributes such as caching, batch size, IO configuration, and output format based on user input and keyword arguments. - It also configures whether to include probabilities in the output. + It also validates that the output format is supported. Args: images (list[PathLike | WSIReader] | np.ndarray): @@ -198,10 +266,10 @@ def _update_run_params( Optional labels for input images. save_dir (PathLike | None): Directory to save output files. Required for WSI mode. - ioconfig (ModelIOConfigABC | None): + ioconfig (IOSegmentorConfig | None): IO configuration for patch extraction and resolution. output_type (str): - Desired output format: "dict", "zarr", or "annotationstore". + Desired output format. Must be "zarr". overwrite (bool): Whether to overwrite existing output files. Default is False. patch_mode (bool): @@ -215,7 +283,7 @@ def _update_run_params( Raises: ValueError: - If `labels` are requested for WSI processing. + If `output_type` is not "zarr", which is the only supported format. """ if output_type != "zarr": @@ -247,7 +315,49 @@ def run( output_type: str = "dict", **kwargs: Unpack[SemanticSegmentorRunParams], ) -> AnnotationStore | Path | str | dict | list[Path]: - """Run the DeepFeatureExtractor engine on input images.""" + """Run the DeepFeatureExtractor engine on input images. + + This method orchestrates the full inference pipeline, including preprocessing, + model inference, and saving of extracted deep features. It supports both + patch-level and whole slide image (WSI) modes. The output is returned or saved + in Zarr format. + + Note: + The `return_probabilities` flag is always set to True for this engine, + as it is designed to extract intermediate feature maps. + + Args: + images (list[PathLike | WSIReader] | np.ndarray): + Input images or patches. Can be a list of file paths, WSIReader objects, + or a NumPy array of image patches. + masks (list[PathLike] | np.ndarray | None): + Optional masks for WSI processing. Only used when `patch_mode` is False. + labels (list | None): + Optional labels for input images. Only one label per image is supported. + ioconfig (IOSegmentorConfig | None): + IO configuration for patch extraction and resolution. + patch_mode (bool): + Whether to treat input as patches (`True`) or WSIs (`False`). + Default is True. + save_dir (PathLike | None): + Directory to save output files. Required for WSI mode. + overwrite (bool): + Whether to overwrite existing output files. Default is False. + output_type (str): + Desired output format. Must be "zarr". + **kwargs (SemanticSegmentorRunParams): + Additional runtime parameters to update engine attributes. + + Returns: + AnnotationStore | Path | str | dict | list[Path]: + - If `patch_mode` is True: returns predictions or path to saved output. + - If `patch_mode` is False: returns a dictionary mapping each WSI + to its output path. + + Raises: + ValueError: + If `output_type` is not "zarr". + """ # return_probabilities is always True for FeatureExtractor. kwargs["return_probabilities"] = True diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index a33bcf028..1a0e4d95f 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -670,7 +670,7 @@ def _update_run_params( patch_mode: bool, **kwargs: Unpack[SemanticSegmentorRunParams], ) -> Path | None: - """Update runtime parameters for the PatchPredictor engine. + """Update runtime parameters for the SemanticSegmentor engine. This method sets internal attributes such as caching, batch size, IO configuration, and output format based on user input and keyword arguments. From 4b6df14993853ec849745685d22cafcffc359638 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 10 Nov 2025 16:57:35 +0000 Subject: [PATCH 08/15] :white_check_mark: Add cache support for large WSIs. --- tests/engines/test_feature_extractor.py | 3 +- .../models/engine/deep_feature_extractor.py | 131 +++++++++++++++++- 2 files changed, 127 insertions(+), 7 deletions(-) diff --git a/tests/engines/test_feature_extractor.py b/tests/engines/test_feature_extractor.py index b55ba19eb..270cabc8c 100644 --- a/tests/engines/test_feature_extractor.py +++ b/tests/engines/test_feature_extractor.py @@ -44,7 +44,6 @@ def test_feature_extractor_patches( return_labels=False, device=device, patch_mode=True, - output_type="zarr", save_dir=track_tmp_path / "wsi_out_check", ) @@ -83,6 +82,7 @@ def test_feature_extractor_wsi(remote_sample: Callable, track_tmp_path: Path) -> save_dir=track_tmp_path / "wsi_out_check", batch_size=2, output_type="zarr", + memory_threshold=1, ) output_ = zarr.open(output[mini_wsi_svs], mode="r") @@ -195,6 +195,7 @@ def test_multi_gpu_feature_extraction( ioconfig=wsi_ioconfig, save_dir=save_dir, auto_get_mask=True, + output_type="zarr", ) output_ = zarr.open(output[mini_wsi_svs], mode="r") diff --git a/tiatoolbox/models/engine/deep_feature_extractor.py b/tiatoolbox/models/engine/deep_feature_extractor.py index 65a46c96f..609fbe176 100644 --- a/tiatoolbox/models/engine/deep_feature_extractor.py +++ b/tiatoolbox/models/engine/deep_feature_extractor.py @@ -2,9 +2,13 @@ from __future__ import annotations +import gc from typing import TYPE_CHECKING import dask.array as da +import psutil +import zarr +from dask import compute from typing_extensions import Unpack from tiatoolbox.utils.misc import get_tqdm @@ -24,6 +28,62 @@ from tiatoolbox.wsicore import WSIReader +def save_to_cache( + probabilities: list[da.Array], + coordinates: list[da.Array], + probabilities_zarr: zarr.Array, + coordinates_zarr: zarr.Array, + save_path: str | Path = "temp.zarr", +) -> tuple[zarr.Array, zarr.Array]: + """Save to cache.""" + if len(probabilities) == 0: + return probabilities_zarr, coordinates_zarr + + coordinates = da.concatenate(coordinates, axis=0) + probabilities = da.concatenate(probabilities, axis=0) + + computed_values = compute(*[probabilities, coordinates]) + probabilities_computed, coordinates_computed = computed_values + + chunk_shape = tuple(chunk[0] for chunk in probabilities.chunks) + if probabilities_zarr is None: + zarr_group = zarr.open(str(save_path), mode="w") + + probabilities_zarr = zarr_group.create_dataset( + name="canvas", + shape=(0, *probabilities_computed.shape[1:]), + chunks=(chunk_shape[0], *probabilities_computed.shape[1:]), + dtype=probabilities_computed.dtype, + overwrite=True, + ) + + coordinates_zarr = zarr_group.create_dataset( + name="count", + shape=(0, *coordinates_computed.shape[1:]), + dtype=coordinates_computed.dtype, + chunks=(chunk_shape[0], *coordinates_computed.shape[1:]), + overwrite=True, + ) + + probabilities_zarr.resize( + ( + probabilities_zarr.shape[0] + probabilities_computed.shape[0], + *probabilities_zarr.shape[1:], + ) + ) + probabilities_zarr[-probabilities_computed.shape[0] :] = probabilities_computed + + coordinates_zarr.resize( + ( + coordinates_zarr.shape[0] + coordinates_computed.shape[0], + *coordinates_zarr.shape[1:], + ) + ) + coordinates_zarr[-coordinates_computed.shape[0] :] = coordinates_computed + + return probabilities_zarr, coordinates_zarr + + class DeepFeatureExtractor(SemanticSegmentor): r"""Generic CNN-based feature extractor for digital pathology images. @@ -125,7 +185,9 @@ def infer_wsi( - "coordinates": Patch coordinates corresponding to the features. """ - _ = kwargs.get("patch_mode", False) + # Default Memory threshold percentage is 80. + memory_threshold = kwargs.get("memory_threshold", 80) + vm = psutil.virtual_memory() _ = save_path keys = ["probabilities", "coordinates"] probabilities, coordinates = [], [] @@ -143,6 +205,9 @@ def infer_wsi( else dataloader ) + probabilities_zarr, coordinates_zarr = None, None + + probabilities_used_percent = 0 for batch_data in tqdm_loop: batch_output = self.model.infer_batch( self.model, @@ -157,8 +222,59 @@ def infer_wsi( ) ) - raw_predictions["coordinates"] = da.concatenate(coordinates, axis=0) - raw_predictions["probabilities"] = da.concatenate(probabilities, axis=0) + used_percent = vm.percent + probabilities_used_percent = ( + probabilities_used_percent + (probabilities[-1].nbytes / vm.free) * 100 + ) + if ( + used_percent > memory_threshold + or probabilities_used_percent > memory_threshold + ): + tqdm_loop.desc = "Spill intermediate data to disk" + used_percent = ( + probabilities_used_percent + if (probabilities_used_percent > memory_threshold) + else used_percent + ) + msg = ( + f"Current Memory usage: {used_percent} % " + f"exceeds specified threshold: {memory_threshold}. " + f"Saving intermediate results to disk." + ) + + tqdm.write(msg) + # Flush data in Memory and clear dask graph + probabilities_zarr, coordinates_zarr = save_to_cache( + probabilities, + coordinates, + probabilities_zarr, + coordinates_zarr, + save_path=save_path, + ) + + probabilities, coordinates = [], [] + probabilities_used_percent = 0 + gc.collect() + tqdm_loop.desc = "Inferring patches" + + if probabilities_zarr is not None: + probabilities_zarr, coordinates_zarr = save_to_cache( + probabilities, + coordinates, + probabilities_zarr, + coordinates_zarr, + save_path=save_path, + ) + # Wrap zarr in dask array + raw_predictions["probabilities"] = da.from_zarr( + probabilities_zarr, chunks=probabilities_zarr.chunks + ) + raw_predictions["coordinates"] = da.from_zarr( + coordinates_zarr, chunks=coordinates_zarr.chunks + ) + else: + raw_predictions["coordinates"] = da.concatenate(coordinates, axis=0) + raw_predictions["probabilities"] = da.concatenate(probabilities, axis=0) return raw_predictions @@ -286,8 +402,11 @@ def _update_run_params( If `output_type` is not "zarr", which is the only supported format. """ - if output_type != "zarr": - msg = "Only zarr output is supported for `DeepFeatureExtractor`." + if output_type not in ["zarr", "dict"]: + msg = ( + f"output_type: `{output_type}` is not supported for " + f"`DeepFeatureExtractor` engine." + ) raise ValueError(msg) return super()._update_run_params( @@ -344,7 +463,7 @@ def run( overwrite (bool): Whether to overwrite existing output files. Default is False. output_type (str): - Desired output format. Must be "zarr". + Desired output format. Must be "zarr" or "dict". **kwargs (SemanticSegmentorRunParams): Additional runtime parameters to update engine attributes. From 998ddcbcd0438eb43c71e70980dc4eb1f4d0b813 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 10 Nov 2025 17:02:27 +0000 Subject: [PATCH 09/15] :white_check_mark: Add support for `dict` output. --- tests/engines/test_feature_extractor.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/engines/test_feature_extractor.py b/tests/engines/test_feature_extractor.py index 270cabc8c..96d1a0361 100644 --- a/tests/engines/test_feature_extractor.py +++ b/tests/engines/test_feature_extractor.py @@ -26,7 +26,7 @@ def test_feature_extractor_patches( - remote_sample: Callable, track_tmp_path: Path + remote_sample: Callable, ) -> None: """Tests DeepFeatureExtractor on image patches.""" extractor = DeepFeatureExtractor( @@ -44,15 +44,14 @@ def test_feature_extractor_patches( return_labels=False, device=device, patch_mode=True, - save_dir=track_tmp_path / "wsi_out_check", ) - output_ = zarr.open(output, mode="r") - - assert 0.48 < np.mean(output_["probabilities"][:]) < 0.52 + assert 0.48 < np.mean(output["probabilities"][:]) < 0.52 with pytest.raises( - ValueError, match=r".*Only zarr output is supported for `DeepFeatureExtractor`" + ValueError, + match=r".*output_type: `annotationstore` is not supported " + r"for `DeepFeatureExtractor` engine", ): _ = extractor.run( images=inputs, @@ -60,6 +59,7 @@ def test_feature_extractor_patches( return_labels=False, device=device, patch_mode=True, + output_type="annotationstore", ) From 38f84fbec5ad7fba519449212d47246312c70a9c Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 10 Nov 2025 17:12:44 +0000 Subject: [PATCH 10/15] [skip ci] :memo: Update docstring --- .../models/engine/deep_feature_extractor.py | 41 +++++++++++++++---- 1 file changed, 34 insertions(+), 7 deletions(-) diff --git a/tiatoolbox/models/engine/deep_feature_extractor.py b/tiatoolbox/models/engine/deep_feature_extractor.py index 609fbe176..4e79c74ba 100644 --- a/tiatoolbox/models/engine/deep_feature_extractor.py +++ b/tiatoolbox/models/engine/deep_feature_extractor.py @@ -35,7 +35,30 @@ def save_to_cache( coordinates_zarr: zarr.Array, save_path: str | Path = "temp.zarr", ) -> tuple[zarr.Array, zarr.Array]: - """Save to cache.""" + """Save computed feature and coordinate arrays to Zarr cache. + + This function computes the given Dask arrays (`probabilities` and `coordinates`), + resizes the corresponding Zarr datasets to accommodate the new data, and appends + the results. If the Zarr datasets do not exist, it initializes them within the + specified Zarr group. + + Args: + probabilities (list[dask.array.Array]): + List of Dask arrays representing extracted feature maps. + coordinates (list[dask.array.Array]): + List of Dask arrays representing patch coordinates. + probabilities_zarr (zarr.Array | None): + Existing Zarr dataset for feature maps. If None, a new one is created. + coordinates_zarr (zarr.Array | None): + Existing Zarr dataset for coordinates. If None, a new one is created. + save_path (str | Path): + Path to the Zarr group for saving datasets. Defaults to "temp.zarr". + + Returns: + tuple[zarr.Array, zarr.Array]: + Updated Zarr datasets for feature maps and coordinates. + + """ if len(probabilities) == 0: return probabilities_zarr, coordinates_zarr @@ -165,18 +188,21 @@ def infer_wsi( """Perform model inference on a whole slide image (WSI). This method processes a WSI using the provided DataLoader and extracts - deep features from each patch using the model. The extracted features - are returned as a Dask array along with the corresponding patch coordinates. + deep features from each patch using the model. It supports memory-aware + caching by spilling intermediate results to disk when memory usage exceeds + a specified threshold. The final output includes feature maps and their + corresponding spatial coordinates. Args: dataloader (DataLoader): PyTorch DataLoader configured for WSI processing. save_path (Path): - Path to save the intermediate output. (Unused in this implementation.) + Path to save intermediate Zarr output. Used for caching. **kwargs (SemanticSegmentorRunParams): Additional runtime parameters, including: - return_probabilities (bool): Whether to return feature maps. - - memory_threshold (int): Memory usage threshold for caching. + - memory_threshold (int): Memory usage threshold (%) to trigger + disk caching. Returns: dict[str, dask.array.Array]: @@ -385,7 +411,7 @@ def _update_run_params( ioconfig (IOSegmentorConfig | None): IO configuration for patch extraction and resolution. output_type (str): - Desired output format. Must be "zarr". + Desired output format. Must be "zarr" or "dict". overwrite (bool): Whether to overwrite existing output files. Default is False. patch_mode (bool): @@ -399,7 +425,8 @@ def _update_run_params( Raises: ValueError: - If `output_type` is not "zarr", which is the only supported format. + If `output_type` is not "zarr" or "dict", which are the + only supported formats. """ if output_type not in ["zarr", "dict"]: From 3ab5f6801e86dee9f76452f9ebc6c6a36e425055 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 11 Nov 2025 10:19:15 +0000 Subject: [PATCH 11/15] :sparkles: Add command line interface to deep feature extractor --- tests/engines/test_feature_extractor.py | 27 +++++ tests/engines/test_semantic_segmentor.py | 2 +- tiatoolbox/cli/__init__.py | 2 + tiatoolbox/cli/deep_feature_extractor.py | 113 ++++++++++++++++++ .../models/engine/deep_feature_extractor.py | 39 +++++- 5 files changed, 181 insertions(+), 2 deletions(-) create mode 100644 tiatoolbox/cli/deep_feature_extractor.py diff --git a/tests/engines/test_feature_extractor.py b/tests/engines/test_feature_extractor.py index 96d1a0361..31976ad0d 100644 --- a/tests/engines/test_feature_extractor.py +++ b/tests/engines/test_feature_extractor.py @@ -8,7 +8,9 @@ import pytest import torch import zarr +from click.testing import CliRunner +from tiatoolbox import cli from tiatoolbox.models import IOSegmentorConfig from tiatoolbox.models.architecture.vanilla import CNNBackbone, TimmBackbone from tiatoolbox.models.engine.deep_feature_extractor import DeepFeatureExtractor @@ -203,3 +205,28 @@ def test_multi_gpu_feature_extraction( features = output_["probabilities"] assert len(positions.shape) == 2 assert len(features.shape) == 2 + + +# ------------------------------------------------------------------------------------- +# Command Line Interface +# ------------------------------------------------------------------------------------- + + +def test_cli_model_single_file(sample_svs: Path, track_tmp_path: Path) -> None: + """Test for feature extractor CLI single file.""" + runner = CliRunner() + models_wsi_result = runner.invoke( + cli.main, + [ + "deep-feature-extractor", + "--img-input", + str(sample_svs), + "--patch-mode", + "False", + "--output-path", + str(track_tmp_path / "output"), + ], + ) + + assert models_wsi_result.exit_code == 0 + assert (track_tmp_path / "output" / (sample_svs.stem + ".zarr")).exists() diff --git a/tests/engines/test_semantic_segmentor.py b/tests/engines/test_semantic_segmentor.py index ae37a074a..14d492a8c 100644 --- a/tests/engines/test_semantic_segmentor.py +++ b/tests/engines/test_semantic_segmentor.py @@ -498,7 +498,7 @@ def test_wsi_segmentor_annotationstore( def test_cli_model_single_file(sample_svs: Path, track_tmp_path: Path) -> None: - """Test for models CLI single file.""" + """Test semantic segmentor CLI single file.""" runner = CliRunner() models_wsi_result = runner.invoke( cli.main, diff --git a/tiatoolbox/cli/__init__.py b/tiatoolbox/cli/__init__.py index 38c69aa85..b11f31f96 100644 --- a/tiatoolbox/cli/__init__.py +++ b/tiatoolbox/cli/__init__.py @@ -7,6 +7,7 @@ from tiatoolbox import __version__ from tiatoolbox.cli.common import tiatoolbox_cli +from tiatoolbox.cli.deep_feature_extractor import deep_feature_extractor from tiatoolbox.cli.nucleus_instance_segment import nucleus_instance_segment from tiatoolbox.cli.patch_predictor import patch_predictor from tiatoolbox.cli.read_bounds import read_bounds @@ -43,6 +44,7 @@ def main() -> click.BaseCommand: main.add_command(read_bounds) main.add_command(save_tiles) main.add_command(semantic_segmentor) +main.add_command(deep_feature_extractor) main.add_command(slide_info) main.add_command(slide_thumbnail) main.add_command(tissue_mask) diff --git a/tiatoolbox/cli/deep_feature_extractor.py b/tiatoolbox/cli/deep_feature_extractor.py new file mode 100644 index 000000000..327fcc58a --- /dev/null +++ b/tiatoolbox/cli/deep_feature_extractor.py @@ -0,0 +1,113 @@ +"""Command line interface for deep feature extractor.""" + +from __future__ import annotations + +from tiatoolbox.cli.common import ( + cli_auto_get_mask, + cli_batch_size, + cli_device, + cli_file_type, + cli_img_input, + cli_masks, + cli_memory_threshold, + cli_model, + cli_num_workers, + cli_output_path, + cli_output_type, + cli_patch_mode, + cli_return_labels, + cli_return_probabilities, + cli_verbose, + cli_weights, + cli_yaml_config_path, + prepare_ioconfig, + prepare_model_cli, + tiatoolbox_cli, +) + + +@tiatoolbox_cli.command() +@cli_img_input() +@cli_output_path( + usage_help="Output directory where model features will be saved.", + default="deep_feature_extractor", +) +@cli_file_type( + default="*.png, *.jpg, *.jpeg, *.tif, *.tiff, *.svs, *.ndpi, *.jp2, *.mrxs", +) +@cli_model(default="fcn-tissue_mask") +@cli_weights() +@cli_device(default="cpu") +@cli_batch_size(default=1) +@cli_yaml_config_path() +@cli_masks(default=None) +@cli_num_workers(default=0) +@cli_output_type( + default="zarr", +) +@cli_memory_threshold(default=80) +@cli_patch_mode(default=False) +@cli_return_probabilities(default=True) +@cli_return_labels(default=False) +@cli_auto_get_mask(default=True) +@cli_verbose(default=True) +def deep_feature_extractor( + model: str, + weights: str, + img_input: str, + file_types: str, + masks: str | None, + output_path: str, + batch_size: int, + yaml_config_path: str, + num_workers: int, + device: str, + output_type: str, + memory_threshold: int, + *, + patch_mode: bool, + return_probabilities: bool, + return_labels: bool, + auto_get_mask: bool, + verbose: bool, +) -> None: + """Process a set of input images with a deep feature extractor engine.""" + from tiatoolbox.models import ( # noqa: PLC0415 + DeepFeatureExtractor, + IOSegmentorConfig, + ) + + files_all, masks_all, output_path = prepare_model_cli( + img_input=img_input, + output_path=output_path, + masks=masks, + file_types=file_types, + ) + + ioconfig = prepare_ioconfig( + IOSegmentorConfig, + pretrained_weights=weights, + yaml_config_path=yaml_config_path, + ) + + extractor = DeepFeatureExtractor( + model=model, + weights=weights, + batch_size=batch_size, + num_workers=num_workers, + verbose=verbose, + ) + + _ = extractor.run( + images=files_all, + masks=masks_all, + patch_mode=patch_mode, + ioconfig=ioconfig, + device=device, + save_dir=output_path, + output_type=output_type, + return_probabilities=return_probabilities, + return_labels=return_labels, + auto_get_mask=auto_get_mask, + memory_threshold=memory_threshold, + ) diff --git a/tiatoolbox/models/engine/deep_feature_extractor.py b/tiatoolbox/models/engine/deep_feature_extractor.py index 4e79c74ba..3fc307f73 100644 --- a/tiatoolbox/models/engine/deep_feature_extractor.py +++ b/tiatoolbox/models/engine/deep_feature_extractor.py @@ -1,4 +1,41 @@ -"""Define DeepFeatureExtractor class.""" +"""Deep Feature Extraction Engine for Digital Pathology. + +This module defines the `DeepFeatureExtractor` class, which extends +`SemanticSegmentor` to extract intermediate CNN feature representations +from whole slide images (WSIs) or image patches. Unlike segmentation +or classification engines, this extractor focuses on generating feature +embeddings for downstream tasks such as clustering, visualization, or +training other machine learning models. + +Key Components: +--------------- +Functions: + - save_to_cache: + Utility to spill intermediate feature and coordinate arrays to + disk using Zarr for memory-efficient processing. + +Classes: + - DeepFeatureExtractor: + Core engine for extracting deep features from WSIs or patches. + Supports memory-aware caching and outputs in Zarr format. + +Features: +--------- +- Handles large-scale WSIs with memory-aware caching. +- Outputs feature maps and patch coordinates for downstream analysis. +- Compatible with TIAToolbox pretrained models and custom PyTorch models. +- Supports both patch-based and WSI-based workflows. + +Example: +-------- +>>> from tiatoolbox.models.engine.deep_feature_extractor import DeepFeatureExtractor +>>> extractor = DeepFeatureExtractor(model="resnet50-kather100k") +>>> wsis = ["slide1.svs", "slide2.svs"] +>>> output = extractor.run(wsis, patch_mode=False, output_type="zarr") +>>> print(output) +'/path/to/output.zarr' + +""" from __future__ import annotations From 227e3177743137de3e9df2ef139b0f1becc2bd9b Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 11 Nov 2025 10:40:46 +0000 Subject: [PATCH 12/15] :white_check_mark: Improve coverage --- tests/engines/test_feature_extractor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/engines/test_feature_extractor.py b/tests/engines/test_feature_extractor.py index 31976ad0d..a3d806aac 100644 --- a/tests/engines/test_feature_extractor.py +++ b/tests/engines/test_feature_extractor.py @@ -69,7 +69,7 @@ def test_feature_extractor_wsi(remote_sample: Callable, track_tmp_path: Path) -> """Test feature extraction with DeepFeatureExtractor engine.""" save_dir = track_tmp_path / "output" # # convert to pathlib Path to prevent wsireader complaint - mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs")) + mini_wsi_svs = Path(remote_sample("wsi2_4k_4k_svs")) # * test providing pretrained from torch vs pretrained_model.yaml shutil.rmtree(save_dir, ignore_errors=True) # default output dir test @@ -82,7 +82,7 @@ def test_feature_extractor_wsi(remote_sample: Callable, track_tmp_path: Path) -> device=device, patch_mode=False, save_dir=track_tmp_path / "wsi_out_check", - batch_size=2, + batch_size=1, output_type="zarr", memory_threshold=1, ) From 62cfe01cb1a2a484bab07dcd0f7ed3bc4a738457 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Wed, 12 Nov 2025 13:58:30 +0000 Subject: [PATCH 13/15] :bug: Address Co-Pilot suggestions. --- tests/engines/test_feature_extractor.py | 2 +- tiatoolbox/models/engine/deep_feature_extractor.py | 13 ++++++------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/engines/test_feature_extractor.py b/tests/engines/test_feature_extractor.py index a3d806aac..85e03eef9 100644 --- a/tests/engines/test_feature_extractor.py +++ b/tests/engines/test_feature_extractor.py @@ -89,7 +89,7 @@ def test_feature_extractor_wsi(remote_sample: Callable, track_tmp_path: Path) -> output_ = zarr.open(output[mini_wsi_svs], mode="r") assert len(output_["coordinates"].shape) == 2 - assert len(output_["probabilities"].shape) + assert len(output_["probabilities"].shape) == 4 @pytest.mark.parametrize( diff --git a/tiatoolbox/models/engine/deep_feature_extractor.py b/tiatoolbox/models/engine/deep_feature_extractor.py index 3fc307f73..85c3afe84 100644 --- a/tiatoolbox/models/engine/deep_feature_extractor.py +++ b/tiatoolbox/models/engine/deep_feature_extractor.py @@ -68,8 +68,8 @@ def save_to_cache( probabilities: list[da.Array], coordinates: list[da.Array], - probabilities_zarr: zarr.Array, - coordinates_zarr: zarr.Array, + probabilities_zarr: zarr.Array | None, + coordinates_zarr: zarr.Array | None, save_path: str | Path = "temp.zarr", ) -> tuple[zarr.Array, zarr.Array]: """Save computed feature and coordinate arrays to Zarr cache. @@ -217,7 +217,7 @@ def __init__( self.process_prediction_per_batch = False def infer_wsi( - self: SemanticSegmentor, + self: DeepFeatureExtractor, dataloader: DataLoader, save_path: Path, **kwargs: Unpack[SemanticSegmentorRunParams], @@ -251,7 +251,6 @@ def infer_wsi( # Default Memory threshold percentage is 80. memory_threshold = kwargs.get("memory_threshold", 80) vm = psutil.virtual_memory() - _ = save_path keys = ["probabilities", "coordinates"] probabilities, coordinates = [], [] @@ -376,7 +375,7 @@ def post_process_patches( return raw_predictions def save_predictions( - self: SemanticSegmentor, + self: DeepFeatureExtractor, processed_predictions: dict, output_type: str, save_path: Path | None = None, @@ -418,7 +417,7 @@ def save_predictions( ) def _update_run_params( - self: SemanticSegmentor, + self: DeepFeatureExtractor, images: list[os.PathLike | Path | WSIReader] | np.ndarray, masks: list[os.PathLike | Path] | np.ndarray | None = None, labels: list | None = None, @@ -539,7 +538,7 @@ def run( Raises: ValueError: - If `output_type` is not "zarr". + If `output_type` is not "zarr" or "dict". """ # return_probabilities is always True for FeatureExtractor. kwargs["return_probabilities"] = True From 4e62d4a93529d5ebdc79ffe5622def7cc06be260 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 13 Nov 2025 11:48:26 +0000 Subject: [PATCH 14/15] :bug: Fix test assertion --- tests/engines/test_feature_extractor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/engines/test_feature_extractor.py b/tests/engines/test_feature_extractor.py index 85e03eef9..c705dc431 100644 --- a/tests/engines/test_feature_extractor.py +++ b/tests/engines/test_feature_extractor.py @@ -89,7 +89,7 @@ def test_feature_extractor_wsi(remote_sample: Callable, track_tmp_path: Path) -> output_ = zarr.open(output[mini_wsi_svs], mode="r") assert len(output_["coordinates"].shape) == 2 - assert len(output_["probabilities"].shape) == 4 + assert len(output_["probabilities"].shape) == 3 @pytest.mark.parametrize( From 6c3b82106938e0042cee8e830acc47b4c890af69 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 14 Nov 2025 12:30:18 +0000 Subject: [PATCH 15/15] :memo: Use features instead of probabilities in the ouptut. Using features in inference will require major change in the base class. --- tests/engines/test_feature_extractor.py | 8 ++++---- tiatoolbox/models/engine/deep_feature_extractor.py | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/engines/test_feature_extractor.py b/tests/engines/test_feature_extractor.py index c705dc431..830d34bd7 100644 --- a/tests/engines/test_feature_extractor.py +++ b/tests/engines/test_feature_extractor.py @@ -48,7 +48,7 @@ def test_feature_extractor_patches( patch_mode=True, ) - assert 0.48 < np.mean(output["probabilities"][:]) < 0.52 + assert 0.48 < np.mean(output["features"][:]) < 0.52 with pytest.raises( ValueError, @@ -89,7 +89,7 @@ def test_feature_extractor_wsi(remote_sample: Callable, track_tmp_path: Path) -> output_ = zarr.open(output[mini_wsi_svs], mode="r") assert len(output_["coordinates"].shape) == 2 - assert len(output_["probabilities"].shape) == 3 + assert len(output_["features"].shape) == 3 @pytest.mark.parametrize( @@ -132,7 +132,7 @@ def test_full_inference( output_ = zarr.open(output[mini_wsi_svs], mode="r") positions = output_["coordinates"] - features = output_["probabilities"] + features = output_["features"] reader = WSIReader.open(mini_wsi_svs) patches = [ @@ -202,7 +202,7 @@ def test_multi_gpu_feature_extraction( output_ = zarr.open(output[mini_wsi_svs], mode="r") positions = output_["coordinates"] - features = output_["probabilities"] + features = output_["features"] assert len(positions.shape) == 2 assert len(features.shape) == 2 diff --git a/tiatoolbox/models/engine/deep_feature_extractor.py b/tiatoolbox/models/engine/deep_feature_extractor.py index 85c3afe84..c0bd2aaa0 100644 --- a/tiatoolbox/models/engine/deep_feature_extractor.py +++ b/tiatoolbox/models/engine/deep_feature_extractor.py @@ -412,6 +412,7 @@ def save_predictions( """ # no need to compute predictions self.drop_keys.append("predictions") + processed_predictions["features"] = processed_predictions.pop("probabilities") return super().save_predictions( processed_predictions, output_type, save_path=save_path, **kwargs )