diff --git a/tests/engines/test_feature_extractor.py b/tests/engines/test_feature_extractor.py new file mode 100644 index 000000000..830d34bd7 --- /dev/null +++ b/tests/engines/test_feature_extractor.py @@ -0,0 +1,232 @@ +"""Test for feature extractor.""" + +import shutil +from collections.abc import Callable +from pathlib import Path + +import numpy as np +import pytest +import torch +import zarr +from click.testing import CliRunner + +from tiatoolbox import cli +from tiatoolbox.models import IOSegmentorConfig +from tiatoolbox.models.architecture.vanilla import CNNBackbone, TimmBackbone +from tiatoolbox.models.engine.deep_feature_extractor import DeepFeatureExtractor +from tiatoolbox.utils import env_detection as toolbox_env +from tiatoolbox.utils.misc import select_device +from tiatoolbox.wsicore.wsireader import WSIReader + +ON_GPU = not toolbox_env.running_on_ci() and toolbox_env.has_gpu() + +# ------------------------------------------------------------------------------------- +# Engine +# ------------------------------------------------------------------------------------- + +device = "cuda" if toolbox_env.has_gpu() else "cpu" + + +def test_feature_extractor_patches( + remote_sample: Callable, +) -> None: + """Tests DeepFeatureExtractor on image patches.""" + extractor = DeepFeatureExtractor( + model="fcn-tissue_mask", batch_size=32, verbose=False, device=device + ) + + sample_image = remote_sample("thumbnail-1k-1k") + + inputs = [sample_image, sample_image] + + assert not extractor.patch_mode + output = extractor.run( + images=inputs, + return_probabilities=True, + return_labels=False, + device=device, + patch_mode=True, + ) + + assert 0.48 < np.mean(output["features"][:]) < 0.52 + + with pytest.raises( + ValueError, + match=r".*output_type: `annotationstore` is not supported " + r"for `DeepFeatureExtractor` engine", + ): + _ = extractor.run( + images=inputs, + return_probabilities=True, + return_labels=False, + device=device, + patch_mode=True, + output_type="annotationstore", + ) + + +def test_feature_extractor_wsi(remote_sample: Callable, track_tmp_path: Path) -> None: + """Test feature extraction with DeepFeatureExtractor engine.""" + save_dir = track_tmp_path / "output" + # # convert to pathlib Path to prevent wsireader complaint + mini_wsi_svs = Path(remote_sample("wsi2_4k_4k_svs")) + + # * test providing pretrained from torch vs pretrained_model.yaml + shutil.rmtree(save_dir, ignore_errors=True) # default output dir test + + extractor = DeepFeatureExtractor(batch_size=1, model="fcn-tissue_mask") + output = extractor.run( + images=[mini_wsi_svs], + return_probabilities=False, + return_labels=False, + device=device, + patch_mode=False, + save_dir=track_tmp_path / "wsi_out_check", + batch_size=1, + output_type="zarr", + memory_threshold=1, + ) + + output_ = zarr.open(output[mini_wsi_svs], mode="r") + assert len(output_["coordinates"].shape) == 2 + assert len(output_["features"].shape) == 3 + + +@pytest.mark.parametrize( + "model", [CNNBackbone("resnet50"), TimmBackbone("efficientnet_b0", pretrained=True)] +) +def test_full_inference( + remote_sample: Callable, track_tmp_path: Path, model: Callable +) -> None: + """Test full inference with CNNBackbone and TimmBackbone models.""" + save_dir = track_tmp_path / "output" + # pre-emptive clean up + shutil.rmtree(save_dir, ignore_errors=True) # default output dir test + + mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs")) + + ioconfig = IOSegmentorConfig( + input_resolutions=[ + {"units": "mpp", "resolution": 0.25}, + ], + output_resolutions=[ + {"units": "mpp", "resolution": 0.25}, + ], + patch_input_shape=[512, 512], + patch_output_shape=[512, 512], + stride_shape=[256, 256], + save_resolution={"units": "mpp", "resolution": 8.0}, + ) + + extractor = DeepFeatureExtractor(batch_size=4, model=model) + output = extractor.run( + images=[mini_wsi_svs], + device=device, + save_dir=track_tmp_path / "wsi_out_check", + batch_size=4, + output_type="zarr", + ioconfig=ioconfig, + patch_mode=False, + ) + + output_ = zarr.open(output[mini_wsi_svs], mode="r") + + positions = output_["coordinates"] + features = output_["features"] + + reader = WSIReader.open(mini_wsi_svs) + patches = [ + reader.read_bounds( + positions[patch_idx], + resolution=0.25, + units="mpp", + pad_constant_values=255, + coord_space="resolution", + ) + for patch_idx in range(4) + ] + patches = np.array(patches) + patches = torch.from_numpy(patches) # NHWC + patches = patches.permute(0, 3, 1, 2).contiguous() # NCHW + patches = patches.to(device).type(torch.float32) + model = extractor.model + # Inference mode + model.eval() + with torch.inference_mode(): + _features = model(patches).cpu().numpy() + # ! must maintain same batch size and likely same ordering + # ! else the output values will not exactly be the same (still < 1.0e-4 + # ! of epsilon though) + assert np.mean(np.abs(features[:4] - _features)) < 1.0e-1 + + +@pytest.mark.skipif( + toolbox_env.running_on_ci() or not ON_GPU, + reason="Local test on machine with GPU.", +) +def test_multi_gpu_feature_extraction( + remote_sample: Callable, track_tmp_path: Path +) -> None: + """Local functionality test for feature extraction using multiple GPUs.""" + save_dir = track_tmp_path / "output" + mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs")) + shutil.rmtree(save_dir, ignore_errors=True) + + # Use multiple GPUs + device = select_device(on_gpu=ON_GPU) + + wsi_ioconfig = IOSegmentorConfig( + input_resolutions=[{"units": "mpp", "resolution": 0.5}], + patch_input_shape=[224, 224], + output_resolutions=[{"units": "mpp", "resolution": 0.5}], + patch_output_shape=[224, 224], + stride_shape=[224, 224], + ) + + model = TimmBackbone(backbone="UNI", pretrained=True) + extractor = DeepFeatureExtractor( + model=model, + batch_size=32, + num_workers=4, + ) + + output = extractor.run( + [mini_wsi_svs], + patch_mode=False, + device=device, + ioconfig=wsi_ioconfig, + save_dir=save_dir, + auto_get_mask=True, + output_type="zarr", + ) + output_ = zarr.open(output[mini_wsi_svs], mode="r") + + positions = output_["coordinates"] + features = output_["features"] + assert len(positions.shape) == 2 + assert len(features.shape) == 2 + + +# ------------------------------------------------------------------------------------- +# Command Line Interface +# ------------------------------------------------------------------------------------- + + +def test_cli_model_single_file(sample_svs: Path, track_tmp_path: Path) -> None: + """Test for feature extractor CLI single file.""" + runner = CliRunner() + models_wsi_result = runner.invoke( + cli.main, + [ + "deep-feature-extractor", + "--img-input", + str(sample_svs), + "--patch-mode", + "False", + "--output-path", + str(track_tmp_path / "output"), + ], + ) + + assert models_wsi_result.exit_code == 0 + assert (track_tmp_path / "output" / (sample_svs.stem + ".zarr")).exists() diff --git a/tests/engines/test_semantic_segmentor.py b/tests/engines/test_semantic_segmentor.py index ae37a074a..14d492a8c 100644 --- a/tests/engines/test_semantic_segmentor.py +++ b/tests/engines/test_semantic_segmentor.py @@ -498,7 +498,7 @@ def test_wsi_segmentor_annotationstore( def test_cli_model_single_file(sample_svs: Path, track_tmp_path: Path) -> None: - """Test for models CLI single file.""" + """Test semantic segmentor CLI single file.""" runner = CliRunner() models_wsi_result = runner.invoke( cli.main, diff --git a/tiatoolbox/cli/__init__.py b/tiatoolbox/cli/__init__.py index 38c69aa85..b11f31f96 100644 --- a/tiatoolbox/cli/__init__.py +++ b/tiatoolbox/cli/__init__.py @@ -7,6 +7,7 @@ from tiatoolbox import __version__ from tiatoolbox.cli.common import tiatoolbox_cli +from tiatoolbox.cli.deep_feature_extractor import deep_feature_extractor from tiatoolbox.cli.nucleus_instance_segment import nucleus_instance_segment from tiatoolbox.cli.patch_predictor import patch_predictor from tiatoolbox.cli.read_bounds import read_bounds @@ -43,6 +44,7 @@ def main() -> click.BaseCommand: main.add_command(read_bounds) main.add_command(save_tiles) main.add_command(semantic_segmentor) +main.add_command(deep_feature_extractor) main.add_command(slide_info) main.add_command(slide_thumbnail) main.add_command(tissue_mask) diff --git a/tiatoolbox/cli/deep_feature_extractor.py b/tiatoolbox/cli/deep_feature_extractor.py new file mode 100644 index 000000000..327fcc58a --- /dev/null +++ b/tiatoolbox/cli/deep_feature_extractor.py @@ -0,0 +1,113 @@ +"""Command line interface for deep feature extractor.""" + +from __future__ import annotations + +from tiatoolbox.cli.common import ( + cli_auto_get_mask, + cli_batch_size, + cli_device, + cli_file_type, + cli_img_input, + cli_masks, + cli_memory_threshold, + cli_model, + cli_num_workers, + cli_output_path, + cli_output_type, + cli_patch_mode, + cli_return_labels, + cli_return_probabilities, + cli_verbose, + cli_weights, + cli_yaml_config_path, + prepare_ioconfig, + prepare_model_cli, + tiatoolbox_cli, +) + + +@tiatoolbox_cli.command() +@cli_img_input() +@cli_output_path( + usage_help="Output directory where model features will be saved.", + default="deep_feature_extractor", +) +@cli_file_type( + default="*.png, *.jpg, *.jpeg, *.tif, *.tiff, *.svs, *.ndpi, *.jp2, *.mrxs", +) +@cli_model(default="fcn-tissue_mask") +@cli_weights() +@cli_device(default="cpu") +@cli_batch_size(default=1) +@cli_yaml_config_path() +@cli_masks(default=None) +@cli_num_workers(default=0) +@cli_output_type( + default="zarr", +) +@cli_memory_threshold(default=80) +@cli_patch_mode(default=False) +@cli_return_probabilities(default=True) +@cli_return_labels(default=False) +@cli_auto_get_mask(default=True) +@cli_verbose(default=True) +def deep_feature_extractor( + model: str, + weights: str, + img_input: str, + file_types: str, + masks: str | None, + output_path: str, + batch_size: int, + yaml_config_path: str, + num_workers: int, + device: str, + output_type: str, + memory_threshold: int, + *, + patch_mode: bool, + return_probabilities: bool, + return_labels: bool, + auto_get_mask: bool, + verbose: bool, +) -> None: + """Process a set of input images with a deep feature extractor engine.""" + from tiatoolbox.models import ( # noqa: PLC0415 + DeepFeatureExtractor, + IOSegmentorConfig, + ) + + files_all, masks_all, output_path = prepare_model_cli( + img_input=img_input, + output_path=output_path, + masks=masks, + file_types=file_types, + ) + + ioconfig = prepare_ioconfig( + IOSegmentorConfig, + pretrained_weights=weights, + yaml_config_path=yaml_config_path, + ) + + extractor = DeepFeatureExtractor( + model=model, + weights=weights, + batch_size=batch_size, + num_workers=num_workers, + verbose=verbose, + ) + + _ = extractor.run( + images=files_all, + masks=masks_all, + patch_mode=patch_mode, + ioconfig=ioconfig, + device=device, + save_dir=output_path, + output_type=output_type, + return_probabilities=return_probabilities, + return_labels=return_labels, + auto_get_mask=auto_get_mask, + memory_threshold=memory_threshold, + ) diff --git a/tiatoolbox/models/__init__.py b/tiatoolbox/models/__init__.py index 5de543aad..4265f7caf 100644 --- a/tiatoolbox/models/__init__.py +++ b/tiatoolbox/models/__init__.py @@ -11,6 +11,7 @@ from .architecture.nuclick import NuClick from .architecture.sccnn import SCCNN from .dataset import PatchDataset, WSIPatchDataset, WSIStreamDataset +from .engine.deep_feature_extractor import DeepFeatureExtractor from .engine.io_config import ( IOInstanceSegmentorConfig, IOPatchPredictorConfig, @@ -24,6 +25,7 @@ __all__ = [ "SCCNN", + "DeepFeatureExtractor", "HoVerNet", "HoVerNetPlus", "IDaRS", diff --git a/tiatoolbox/models/engine/__init__.py b/tiatoolbox/models/engine/__init__.py index 9c00ac4a2..2509b51a3 100644 --- a/tiatoolbox/models/engine/__init__.py +++ b/tiatoolbox/models/engine/__init__.py @@ -1,6 +1,7 @@ """Engines to run models implemented in tiatoolbox.""" from . import ( + deep_feature_extractor, engine_abc, nucleus_instance_segmentor, patch_predictor, @@ -8,6 +9,7 @@ ) __all__ = [ + "deep_feature_extractor", "engine_abc", "nucleus_instance_segmentor", "patch_predictor", diff --git a/tiatoolbox/models/engine/deep_feature_extractor.py b/tiatoolbox/models/engine/deep_feature_extractor.py new file mode 100644 index 000000000..c0bd2aaa0 --- /dev/null +++ b/tiatoolbox/models/engine/deep_feature_extractor.py @@ -0,0 +1,557 @@ +"""Deep Feature Extraction Engine for Digital Pathology. + +This module defines the `DeepFeatureExtractor` class, which extends +`SemanticSegmentor` to extract intermediate CNN feature representations +from whole slide images (WSIs) or image patches. Unlike segmentation +or classification engines, this extractor focuses on generating feature +embeddings for downstream tasks such as clustering, visualization, or +training other machine learning models. + +Key Components: +--------------- +Functions: + - save_to_cache: + Utility to spill intermediate feature and coordinate arrays to + disk using Zarr for memory-efficient processing. + +Classes: + - DeepFeatureExtractor: + Core engine for extracting deep features from WSIs or patches. + Supports memory-aware caching and outputs in Zarr format. + +Features: +--------- +- Handles large-scale WSIs with memory-aware caching. +- Outputs feature maps and patch coordinates for downstream analysis. +- Compatible with TIAToolbox pretrained models and custom PyTorch models. +- Supports both patch-based and WSI-based workflows. + +Example: +-------- +>>> from tiatoolbox.models.engine.deep_feature_extractor import DeepFeatureExtractor +>>> extractor = DeepFeatureExtractor(model="resnet50-kather100k") +>>> wsis = ["slide1.svs", "slide2.svs"] +>>> output = extractor.run(wsis, patch_mode=False, output_type="zarr") +>>> print(output) +'/path/to/output.zarr' + +""" + +from __future__ import annotations + +import gc +from typing import TYPE_CHECKING + +import dask.array as da +import psutil +import zarr +from dask import compute +from typing_extensions import Unpack + +from tiatoolbox.utils.misc import get_tqdm + +from .semantic_segmentor import SemanticSegmentor, SemanticSegmentorRunParams + +if TYPE_CHECKING: # pragma: no cover + import os + from pathlib import Path + + import numpy as np + from torch.utils.data import DataLoader + + from tiatoolbox.annotation import AnnotationStore + from tiatoolbox.models.engine.io_config import IOSegmentorConfig + from tiatoolbox.models.models_abc import ModelABC + from tiatoolbox.wsicore import WSIReader + + +def save_to_cache( + probabilities: list[da.Array], + coordinates: list[da.Array], + probabilities_zarr: zarr.Array | None, + coordinates_zarr: zarr.Array | None, + save_path: str | Path = "temp.zarr", +) -> tuple[zarr.Array, zarr.Array]: + """Save computed feature and coordinate arrays to Zarr cache. + + This function computes the given Dask arrays (`probabilities` and `coordinates`), + resizes the corresponding Zarr datasets to accommodate the new data, and appends + the results. If the Zarr datasets do not exist, it initializes them within the + specified Zarr group. + + Args: + probabilities (list[dask.array.Array]): + List of Dask arrays representing extracted feature maps. + coordinates (list[dask.array.Array]): + List of Dask arrays representing patch coordinates. + probabilities_zarr (zarr.Array | None): + Existing Zarr dataset for feature maps. If None, a new one is created. + coordinates_zarr (zarr.Array | None): + Existing Zarr dataset for coordinates. If None, a new one is created. + save_path (str | Path): + Path to the Zarr group for saving datasets. Defaults to "temp.zarr". + + Returns: + tuple[zarr.Array, zarr.Array]: + Updated Zarr datasets for feature maps and coordinates. + + """ + if len(probabilities) == 0: + return probabilities_zarr, coordinates_zarr + + coordinates = da.concatenate(coordinates, axis=0) + probabilities = da.concatenate(probabilities, axis=0) + + computed_values = compute(*[probabilities, coordinates]) + probabilities_computed, coordinates_computed = computed_values + + chunk_shape = tuple(chunk[0] for chunk in probabilities.chunks) + if probabilities_zarr is None: + zarr_group = zarr.open(str(save_path), mode="w") + + probabilities_zarr = zarr_group.create_dataset( + name="canvas", + shape=(0, *probabilities_computed.shape[1:]), + chunks=(chunk_shape[0], *probabilities_computed.shape[1:]), + dtype=probabilities_computed.dtype, + overwrite=True, + ) + + coordinates_zarr = zarr_group.create_dataset( + name="count", + shape=(0, *coordinates_computed.shape[1:]), + dtype=coordinates_computed.dtype, + chunks=(chunk_shape[0], *coordinates_computed.shape[1:]), + overwrite=True, + ) + + probabilities_zarr.resize( + ( + probabilities_zarr.shape[0] + probabilities_computed.shape[0], + *probabilities_zarr.shape[1:], + ) + ) + probabilities_zarr[-probabilities_computed.shape[0] :] = probabilities_computed + + coordinates_zarr.resize( + ( + coordinates_zarr.shape[0] + coordinates_computed.shape[0], + *coordinates_zarr.shape[1:], + ) + ) + coordinates_zarr[-coordinates_computed.shape[0] :] = coordinates_computed + + return probabilities_zarr, coordinates_zarr + + +class DeepFeatureExtractor(SemanticSegmentor): + r"""Generic CNN-based feature extractor for digital pathology images. + + This class extends :class:`SemanticSegmentor` to extract deep features from + whole slide images (WSIs) or image patches using a CNN model. It is designed + for use cases where the goal is to obtain intermediate feature representations + (e.g., embeddings) rather than final classification or segmentation outputs. + + The extracted features are returned or saved in Zarr format for downstream + analysis, such as clustering, visualization, or training other machine learning + models. + + Args: + model (str | ModelABC): + A PyTorch model instance or the name of a pretrained model from TIAToolbox. + batch_size (int): + Number of image patches processed per forward pass. Default is 8. + num_workers (int): + Number of workers for data loading. Default is 0. + weights (str | Path | None): + Path to model weights. If None, default weights are used. + device (str): + Device to run the model on (e.g., "cpu", "cuda"). Default is "cpu". + verbose (bool): + Whether to enable verbose logging. Default is True. + + Attributes: + process_prediction_per_batch (bool): + Flag to control whether predictions are processed per batch. + Default is False. + + """ + + def __init__( + self: DeepFeatureExtractor, + model: str | ModelABC, + batch_size: int = 8, + num_workers: int = 0, + weights: str | Path | None = None, + *, + device: str = "cpu", + verbose: bool = True, + ) -> None: + """Initialize :class:`DeepFeatureExtractor`. + + Args: + model (str | ModelABC): + A PyTorch model instance or the name of a pretrained model from + TIAToolbox. If a string is provided, the corresponding pretrained + weights will be downloaded unless overridden via `weights`. + batch_size (int): + Number of image patches processed per forward pass. Default is 8. + num_workers (int): + Number of workers for data loading. Default is 0. + weights (str | Path | None): + Path to model weights. If None, default weights are used. + device (str): + Device to run the model on (e.g., "cpu", "cuda"). Default is "cpu". + verbose (bool): + Whether to enable verbose logging. Default is True. + + """ + super().__init__( + model=model, + batch_size=batch_size, + num_workers=num_workers, + weights=weights, + device=device, + verbose=verbose, + ) + self.process_prediction_per_batch = False + + def infer_wsi( + self: DeepFeatureExtractor, + dataloader: DataLoader, + save_path: Path, + **kwargs: Unpack[SemanticSegmentorRunParams], + ) -> dict[str, da.Array]: + """Perform model inference on a whole slide image (WSI). + + This method processes a WSI using the provided DataLoader and extracts + deep features from each patch using the model. It supports memory-aware + caching by spilling intermediate results to disk when memory usage exceeds + a specified threshold. The final output includes feature maps and their + corresponding spatial coordinates. + + Args: + dataloader (DataLoader): + PyTorch DataLoader configured for WSI processing. + save_path (Path): + Path to save intermediate Zarr output. Used for caching. + **kwargs (SemanticSegmentorRunParams): + Additional runtime parameters, including: + - return_probabilities (bool): Whether to return feature maps. + - memory_threshold (int): Memory usage threshold (%) to trigger + disk caching. + + Returns: + dict[str, dask.array.Array]: + Dictionary containing: + - "probabilities": Extracted feature maps from the model. + - "coordinates": Patch coordinates corresponding to the features. + + """ + # Default Memory threshold percentage is 80. + memory_threshold = kwargs.get("memory_threshold", 80) + vm = psutil.virtual_memory() + keys = ["probabilities", "coordinates"] + probabilities, coordinates = [], [] + + # Main output dictionary + raw_predictions = dict( + zip(keys, [da.empty(shape=(0, 0))] * len(keys), strict=False) + ) + + # Inference loop + tqdm = get_tqdm() + tqdm_loop = ( + tqdm(dataloader, leave=False, desc="Inferring patches") + if self.verbose + else dataloader + ) + + probabilities_zarr, coordinates_zarr = None, None + + probabilities_used_percent = 0 + for batch_data in tqdm_loop: + batch_output = self.model.infer_batch( + self.model, + batch_data["image"], + device=self.device, + ) + + probabilities.append(da.from_array(batch_output[0])) + coordinates.append( + da.from_array( + self._get_coordinates(batch_data), + ) + ) + + used_percent = vm.percent + probabilities_used_percent = ( + probabilities_used_percent + (probabilities[-1].nbytes / vm.free) * 100 + ) + if ( + used_percent > memory_threshold + or probabilities_used_percent > memory_threshold + ): + tqdm_loop.desc = "Spill intermediate data to disk" + used_percent = ( + probabilities_used_percent + if (probabilities_used_percent > memory_threshold) + else used_percent + ) + msg = ( + f"Current Memory usage: {used_percent} % " + f"exceeds specified threshold: {memory_threshold}. " + f"Saving intermediate results to disk." + ) + + tqdm.write(msg) + # Flush data in Memory and clear dask graph + probabilities_zarr, coordinates_zarr = save_to_cache( + probabilities, + coordinates, + probabilities_zarr, + coordinates_zarr, + save_path=save_path, + ) + + probabilities, coordinates = [], [] + probabilities_used_percent = 0 + gc.collect() + tqdm_loop.desc = "Inferring patches" + + if probabilities_zarr is not None: + probabilities_zarr, coordinates_zarr = save_to_cache( + probabilities, + coordinates, + probabilities_zarr, + coordinates_zarr, + save_path=save_path, + ) + # Wrap zarr in dask array + raw_predictions["probabilities"] = da.from_zarr( + probabilities_zarr, chunks=probabilities_zarr.chunks + ) + raw_predictions["coordinates"] = da.from_zarr( + coordinates_zarr, chunks=coordinates_zarr.chunks + ) + else: + raw_predictions["coordinates"] = da.concatenate(coordinates, axis=0) + raw_predictions["probabilities"] = da.concatenate(probabilities, axis=0) + + return raw_predictions + + def post_process_patches( + self: DeepFeatureExtractor, + raw_predictions: da.Array, + prediction_shape: tuple[int, ...], + prediction_dtype: type, + **kwargs: Unpack[SemanticSegmentorRunParams], + ) -> da.Array: + """Post-process raw patch predictions from model inference. + + This method overrides the base implementation to return raw feature maps + without applying any additional processing. It is intended for use cases + where intermediate CNN features are required as output. + + Args: + raw_predictions (dask.array.Array): + Raw model predictions as a Dask array. + prediction_shape (tuple[int, ...]): + Expected shape of the prediction output. + prediction_dtype (type): + Data type of the prediction output. + **kwargs (SemanticSegmentorRunParams): + Additional runtime parameters. + + Returns: + dask.array.Array: + Unmodified raw predictions. + + """ + _ = kwargs.get("return_probabilities") + _ = prediction_shape + _ = prediction_dtype + + return raw_predictions + + def save_predictions( + self: DeepFeatureExtractor, + processed_predictions: dict, + output_type: str, + save_path: Path | None = None, + **kwargs: Unpack[SemanticSegmentorRunParams], + ) -> dict | Path: + """Save patch-level feature predictions to disk or return them in memory. + + This method saves the extracted deep features in the specified output format. + Only the "zarr" format is supported for this engine. The method disables + saving the "predictions" key, as it is not relevant for feature extraction. + + Args: + processed_predictions (dict): + Dictionary containing processed model outputs. + output_type (str): + Desired output format. Must be "zarr". + save_path (Path | None): + Path to save the output file. Required for "zarr" format. + **kwargs (SemanticSegmentorRunParams): + Additional runtime parameters, including: + - output_file (str): Name of the output file. + - scale_factor (tuple[float, float]): For coordinate transformation. + - class_dict (dict): Optional class index-to-name mapping. + + Returns: + dict | Path: + - If `output_type` is "zarr": returns the path to the saved Zarr file. + - If `output_type` is "dict": returns predictions as a dictionary. + + Raises: + ValueError: + If an unsupported output format is provided. + + """ + # no need to compute predictions + self.drop_keys.append("predictions") + processed_predictions["features"] = processed_predictions.pop("probabilities") + return super().save_predictions( + processed_predictions, output_type, save_path=save_path, **kwargs + ) + + def _update_run_params( + self: DeepFeatureExtractor, + images: list[os.PathLike | Path | WSIReader] | np.ndarray, + masks: list[os.PathLike | Path] | np.ndarray | None = None, + labels: list | None = None, + save_dir: os.PathLike | Path | None = None, + ioconfig: IOSegmentorConfig | None = None, + output_type: str = "dict", + *, + overwrite: bool = False, + patch_mode: bool, + **kwargs: Unpack[SemanticSegmentorRunParams], + ) -> Path | None: + """Update runtime parameters for the DeepFeatureExtractor engine. + + This method sets internal attributes such as caching, batch size, + IO configuration, and output format based on user input and keyword arguments. + It also validates that the output format is supported. + + Args: + images (list[PathLike | WSIReader] | np.ndarray): + Input images or patches. + masks (list[PathLike] | np.ndarray | None): + Optional masks for WSI processing. + labels (list | None): + Optional labels for input images. + save_dir (PathLike | None): + Directory to save output files. Required for WSI mode. + ioconfig (IOSegmentorConfig | None): + IO configuration for patch extraction and resolution. + output_type (str): + Desired output format. Must be "zarr" or "dict". + overwrite (bool): + Whether to overwrite existing output files. Default is False. + patch_mode (bool): + Whether to treat input as patches (`True`) or WSIs (`False`). + **kwargs (SemanticSegmentorRunParams): + Additional runtime parameters. + + Returns: + Path | None: + Path to the save directory if applicable, otherwise None. + + Raises: + ValueError: + If `output_type` is not "zarr" or "dict", which are the + only supported formats. + + """ + if output_type not in ["zarr", "dict"]: + msg = ( + f"output_type: `{output_type}` is not supported for " + f"`DeepFeatureExtractor` engine." + ) + raise ValueError(msg) + + return super()._update_run_params( + images=images, + masks=masks, + labels=labels, + save_dir=save_dir, + ioconfig=ioconfig, + overwrite=overwrite, + patch_mode=patch_mode, + output_type=output_type, + **kwargs, + ) + + def run( + self: DeepFeatureExtractor, + images: list[os.PathLike | Path | WSIReader] | np.ndarray, + masks: list[os.PathLike | Path] | np.ndarray | None = None, + labels: list | None = None, + ioconfig: IOSegmentorConfig | None = None, + *, + patch_mode: bool = True, + save_dir: os.PathLike | Path | None = None, + overwrite: bool = False, + output_type: str = "dict", + **kwargs: Unpack[SemanticSegmentorRunParams], + ) -> AnnotationStore | Path | str | dict | list[Path]: + """Run the DeepFeatureExtractor engine on input images. + + This method orchestrates the full inference pipeline, including preprocessing, + model inference, and saving of extracted deep features. It supports both + patch-level and whole slide image (WSI) modes. The output is returned or saved + in Zarr format. + + Note: + The `return_probabilities` flag is always set to True for this engine, + as it is designed to extract intermediate feature maps. + + Args: + images (list[PathLike | WSIReader] | np.ndarray): + Input images or patches. Can be a list of file paths, WSIReader objects, + or a NumPy array of image patches. + masks (list[PathLike] | np.ndarray | None): + Optional masks for WSI processing. Only used when `patch_mode` is False. + labels (list | None): + Optional labels for input images. Only one label per image is supported. + ioconfig (IOSegmentorConfig | None): + IO configuration for patch extraction and resolution. + patch_mode (bool): + Whether to treat input as patches (`True`) or WSIs (`False`). + Default is True. + save_dir (PathLike | None): + Directory to save output files. Required for WSI mode. + overwrite (bool): + Whether to overwrite existing output files. Default is False. + output_type (str): + Desired output format. Must be "zarr" or "dict". + **kwargs (SemanticSegmentorRunParams): + Additional runtime parameters to update engine attributes. + + Returns: + AnnotationStore | Path | str | dict | list[Path]: + - If `patch_mode` is True: returns predictions or path to saved output. + - If `patch_mode` is False: returns a dictionary mapping each WSI + to its output path. + + Raises: + ValueError: + If `output_type` is not "zarr" or "dict". + """ + # return_probabilities is always True for FeatureExtractor. + kwargs["return_probabilities"] = True + + return super().run( + images=images, + masks=masks, + labels=labels, + ioconfig=ioconfig, + patch_mode=patch_mode, + save_dir=save_dir, + overwrite=overwrite, + output_type=output_type, + **kwargs, + ) diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index a33bcf028..1a0e4d95f 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -670,7 +670,7 @@ def _update_run_params( patch_mode: bool, **kwargs: Unpack[SemanticSegmentorRunParams], ) -> Path | None: - """Update runtime parameters for the PatchPredictor engine. + """Update runtime parameters for the SemanticSegmentor engine. This method sets internal attributes such as caching, batch size, IO configuration, and output format based on user input and keyword arguments.