Skip to content
Open
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
cd368bd
:new: Define `DeepFeatureExtractor`
shaneahmed Oct 22, 2025
f14fdaa
:fire: Remove incorrect docstring
shaneahmed Nov 5, 2025
3fa9136
Merge branch 'dev-define-engines-abc' into dev-define-DeepFeatureExtr…
shaneahmed Nov 6, 2025
d3d0650
Merge remote-tracking branch 'origin/dev-define-DeepFeatureExtractor'…
shaneahmed Nov 6, 2025
aa4c812
:test_tube: Initial implementation
shaneahmed Nov 6, 2025
4df8ea4
:test_tube: Initial implementation
shaneahmed Nov 6, 2025
8460a2d
:white_check_mark: Add tests for `DeepFeatureExtractor`
shaneahmed Nov 6, 2025
c9f0e59
:bug: Fix error due to inconsistent results
shaneahmed Nov 7, 2025
35c964b
:white_check_mark: Add tests for coverage and update docstrings.
shaneahmed Nov 7, 2025
4b6df14
:white_check_mark: Add cache support for large WSIs.
shaneahmed Nov 10, 2025
998ddcb
:white_check_mark: Add support for `dict` output.
shaneahmed Nov 10, 2025
38f84fb
[skip ci] :memo: Update docstring
shaneahmed Nov 10, 2025
3ab5f68
:sparkles: Add command line interface to deep feature extractor
shaneahmed Nov 11, 2025
227e317
:white_check_mark: Improve coverage
shaneahmed Nov 11, 2025
62cfe01
:bug: Address Co-Pilot suggestions.
shaneahmed Nov 12, 2025
4e62d4a
:bug: Fix test assertion
shaneahmed Nov 13, 2025
6c3b821
:memo: Use features instead of probabilities in the output.
shaneahmed Nov 14, 2025
5c07200
Merge branch 'dev-define-engines-abc' into dev-define-DeepFeatureExtr…
shaneahmed Nov 17, 2025
d452bca
Merge branch 'dev-define-engines-abc' into dev-define-DeepFeatureExtr…
shaneahmed Dec 1, 2025
1080d5d
:fire: Remove references to CNN
shaneahmed Dec 1, 2025
4461912
:fire: Remove ioconfig from EngineRunParams
shaneahmed Dec 1, 2025
98e2312
:memo: Update docstrings to include all the kwargs options
shaneahmed Dec 1, 2025
3838285
:art: Remove explicit assignment of ioconfig outside kwargs
shaneahmed Dec 1, 2025
901aa8e
:memo: Fix docstring
shaneahmed Dec 1, 2025
d2be484
Revert ":memo: Fix docstring"
shaneahmed Dec 1, 2025
0842581
Revert ":art: Remove explicit assignment of ioconfig outside kwargs"
shaneahmed Dec 1, 2025
7db8642
:art: Move labels input to kwargs as it's not required for all the en…
shaneahmed Dec 2, 2025
ab80306
:building_construction: Move input_resolutions from **kwargs to named…
shaneahmed Dec 2, 2025
305588f
:building_construction: Move patch_input_shape from **kwargs to named…
shaneahmed Dec 2, 2025
b5e6535
:memo: Update docstring for RunParams
shaneahmed Dec 2, 2025
ec2f431
Merge branch 'dev-define-engines-abc' into dev-define-DeepFeatureExtr…
shaneahmed Dec 2, 2025
e015933
:fire: Remove large models from tests on GitHub actions.
shaneahmed Dec 2, 2025
d24dfbe
:sparkles: Add support for architectures in `timm.list_models()`
shaneahmed Dec 2, 2025
acc6f97
:white_check_mark: Fix test for alexnet raise error check, as timm su…
shaneahmed Dec 2, 2025
cce3883
:building_construction: Rebase deep feature extractor on patchpredictor
shaneahmed Dec 2, 2025
68923e6
:fire: Labels are not required in cli
shaneahmed Dec 3, 2025
9872abf
:white_check_mark: Add `patch_input_shape` parameter to cli
shaneahmed Dec 3, 2025
220d030
:white_check_mark: Add `overwrite` and `stride_shape` parameter to cli
shaneahmed Dec 3, 2025
df7ce8b
:white_check_mark: Add `scale_factor` parameter to cli
shaneahmed Dec 3, 2025
5315bba
:white_check_mark: Add `output_file` parameter to cli
shaneahmed Dec 3, 2025
306408c
:white_check_mark: Add `class_dict` parameter to cli
shaneahmed Dec 3, 2025
75ca70d
:white_check_mark: Add `class_dict` parameter to cli
shaneahmed Dec 3, 2025
751971e
:white_check_mark: Add `input-resolutions` parameter to cli
shaneahmed Dec 3, 2025
f01adf8
:bug: Fix missing import
shaneahmed Dec 3, 2025
ffc5d24
:white_check_mark: Update cli for `semantic_segmentor`
shaneahmed Dec 3, 2025
cf1da0b
:white_check_mark: Add tests for additional inputs for cli
shaneahmed Dec 3, 2025
691093e
:white_check_mark: Add tests for additional inputs for cli
shaneahmed Dec 3, 2025
c508549
:white_check_mark: Add tests for deep feature extractor cli
shaneahmed Dec 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
232 changes: 232 additions & 0 deletions tests/engines/test_feature_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
"""Test for feature extractor."""

import shutil
from collections.abc import Callable
from pathlib import Path

import numpy as np
import pytest
import torch
import zarr
from click.testing import CliRunner

from tiatoolbox import cli
from tiatoolbox.models import IOSegmentorConfig
from tiatoolbox.models.architecture.vanilla import CNNBackbone, TimmBackbone
from tiatoolbox.models.engine.deep_feature_extractor import DeepFeatureExtractor
from tiatoolbox.utils import env_detection as toolbox_env
from tiatoolbox.utils.misc import select_device
from tiatoolbox.wsicore.wsireader import WSIReader

# True only when the tests are not running on CI and the machine reports a GPU.
ON_GPU = not toolbox_env.running_on_ci() and toolbox_env.has_gpu()

# -------------------------------------------------------------------------------------
# Engine
# -------------------------------------------------------------------------------------

# Device string handed to the engine by the tests below. Unlike ON_GPU, this
# only checks for a GPU and does not take the CI check into account.
device = "cuda" if toolbox_env.has_gpu() else "cpu"


def test_feature_extractor_patches(
    remote_sample: Callable,
) -> None:
    """Run DeepFeatureExtractor in patch mode on two copies of a sample image.

    Checks the mean of the returned features and that requesting the
    `annotationstore` output type raises a `ValueError`.

    """
    extractor = DeepFeatureExtractor(
        model="fcn-tissue_mask", batch_size=32, verbose=False, device=device
    )

    image_path = remote_sample("thumbnail-1k-1k")

    # Arguments shared by both engine runs below.
    run_kwargs = {
        "images": [image_path, image_path],
        "return_probabilities": True,
        "return_labels": False,
        "device": device,
        "patch_mode": True,
    }

    # Before run() is called the engine reports `patch_mode` as falsy.
    assert not extractor.patch_mode
    output = extractor.run(**run_kwargs)

    mean_feature = np.mean(output["features"][:])
    assert 0.48 < mean_feature < 0.52

    # The feature extractor must reject the annotation store output type.
    with pytest.raises(
        ValueError,
        match=r".*output_type: `annotationstore` is not supported "
        r"for `DeepFeatureExtractor` engine",
    ):
        extractor.run(**run_kwargs, output_type="annotationstore")


def test_feature_extractor_wsi(remote_sample: Callable, track_tmp_path: Path) -> None:
    """Test WSI-mode feature extraction with the DeepFeatureExtractor engine.

    Runs the `fcn-tissue_mask` model on a small SVS slide with `zarr` output
    and checks the rank of the saved `coordinates` and `features` arrays.

    """
    # Pre-emptively clean the directory the engine actually writes to, so the
    # cleanup and the `save_dir` passed to run() always refer to the same path.
    save_dir = track_tmp_path / "wsi_out_check"
    shutil.rmtree(save_dir, ignore_errors=True)

    # Convert to pathlib Path to prevent wsireader complaint; the same Path
    # object is used below to look up this slide's entry in the run output.
    mini_wsi_svs = Path(remote_sample("wsi2_4k_4k_svs"))

    extractor = DeepFeatureExtractor(batch_size=1, model="fcn-tissue_mask")
    output = extractor.run(
        images=[mini_wsi_svs],
        return_probabilities=False,
        return_labels=False,
        device=device,
        patch_mode=False,
        save_dir=save_dir,
        batch_size=1,
        output_type="zarr",
        # NOTE(review): presumably a very low threshold forces the cached
        # (large WSI) code path -- confirm against the engine implementation.
        memory_threshold=1,
    )

    # The run output maps each input slide to the location of its zarr store.
    output_ = zarr.open(output[mini_wsi_svs], mode="r")
    assert len(output_["coordinates"].shape) == 2
    assert len(output_["features"].shape) == 3


@pytest.mark.parametrize(
    "model", [CNNBackbone("resnet50"), TimmBackbone("efficientnet_b0", pretrained=True)]
)
def test_full_inference(
    remote_sample: Callable, track_tmp_path: Path, model: Callable
) -> None:
    """Test full inference with CNNBackbone and TimmBackbone models.

    Extracts features from a small WSI through the engine, then re-reads the
    first batch of patches at the saved coordinates, pushes them through the
    engine's model directly and checks both feature sets agree.

    """
    # The engine batch size and the number of patches re-inferred manually
    # must match, otherwise the two sets of features are not comparable.
    batch_size = 4

    # Pre-emptively clean the directory the engine actually writes to.
    save_dir = track_tmp_path / "wsi_out_check"
    shutil.rmtree(save_dir, ignore_errors=True)

    mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs"))

    ioconfig = IOSegmentorConfig(
        input_resolutions=[
            {"units": "mpp", "resolution": 0.25},
        ],
        output_resolutions=[
            {"units": "mpp", "resolution": 0.25},
        ],
        patch_input_shape=[512, 512],
        patch_output_shape=[512, 512],
        stride_shape=[256, 256],
        save_resolution={"units": "mpp", "resolution": 8.0},
    )

    extractor = DeepFeatureExtractor(batch_size=batch_size, model=model)
    output = extractor.run(
        images=[mini_wsi_svs],
        device=device,
        save_dir=save_dir,
        batch_size=batch_size,
        output_type="zarr",
        ioconfig=ioconfig,
        patch_mode=False,
    )

    output_ = zarr.open(output[mini_wsi_svs], mode="r")

    positions = output_["coordinates"]
    features = output_["features"]

    # Re-read the first batch of patches at the coordinates saved by the
    # engine, using the same resolution as the ioconfig above.
    reader = WSIReader.open(mini_wsi_svs)
    patches = np.array(
        [
            reader.read_bounds(
                positions[patch_idx],
                resolution=0.25,
                units="mpp",
                pad_constant_values=255,
                coord_space="resolution",
            )
            for patch_idx in range(batch_size)
        ]
    )
    patches = torch.from_numpy(patches)  # NHWC
    patches = patches.permute(0, 3, 1, 2).contiguous()  # NCHW
    patches = patches.to(device).type(torch.float32)

    # Use the model held by the engine rather than the parametrized argument.
    engine_model = extractor.model
    # Inference mode
    engine_model.eval()
    with torch.inference_mode():
        _features = engine_model(patches).cpu().numpy()

    # ! must maintain same batch size and likely same ordering
    # ! else the output values will not exactly be the same (still < 1.0e-4
    # ! of epsilon though)
    # NOTE(review): the tolerance asserted here (1.0e-1) is much looser than
    # the 1.0e-4 quoted above -- confirm which one is intended.
    assert np.mean(np.abs(features[:batch_size] - _features)) < 1.0e-1


# NOTE(review): ON_GPU is already False when running on CI, so the explicit
# `running_on_ci()` term in the skip condition below is redundant.
@pytest.mark.skipif(
toolbox_env.running_on_ci() or not ON_GPU,
reason="Local test on machine with GPU.",
)
def test_multi_gpu_feature_extraction(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@adamshephard Please can you test this? Thanks

remote_sample: Callable, track_tmp_path: Path
) -> None:
"""Local functionality test for feature extraction using multiple GPUs."""
save_dir = track_tmp_path / "output"
mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs"))
# Pre-emptive clean up of the directory passed to run() as `save_dir`.
shutil.rmtree(save_dir, ignore_errors=True)

# Use multiple GPUs
# This local `device` shadows the module-level `device` string.
device = select_device(on_gpu=ON_GPU)

wsi_ioconfig = IOSegmentorConfig(
input_resolutions=[{"units": "mpp", "resolution": 0.5}],
patch_input_shape=[224, 224],
output_resolutions=[{"units": "mpp", "resolution": 0.5}],
patch_output_shape=[224, 224],
stride_shape=[224, 224],
)

model = TimmBackbone(backbone="UNI", pretrained=True)
extractor = DeepFeatureExtractor(
model=model,
batch_size=32,
num_workers=4,
)

output = extractor.run(
[mini_wsi_svs],
patch_mode=False,
device=device,
ioconfig=wsi_ioconfig,
save_dir=save_dir,
auto_get_mask=True,
output_type="zarr",
)
# The run output is indexed by the input slide path to locate its zarr store.
output_ = zarr.open(output[mini_wsi_svs], mode="r")

positions = output_["coordinates"]
features = output_["features"]
# Both saved arrays are expected to be two-dimensional for this model.
assert len(positions.shape) == 2
assert len(features.shape) == 2


# -------------------------------------------------------------------------------------
# Command Line Interface
# -------------------------------------------------------------------------------------


def test_cli_model_single_file(sample_svs: Path, track_tmp_path: Path) -> None:
    """Run the `deep-feature-extractor` CLI command on a single slide.

    The command must exit cleanly and leave a `<slide stem>.zarr` entry in the
    requested output directory.

    """
    output_dir = track_tmp_path / "output"
    cli_args = [
        "deep-feature-extractor",
        "--img-input",
        str(sample_svs),
        "--patch-mode",
        "False",
        "--output-path",
        str(output_dir),
    ]

    result = CliRunner().invoke(cli.main, cli_args)

    assert result.exit_code == 0
    expected_store = output_dir / (sample_svs.stem + ".zarr")
    assert expected_store.exists()
2 changes: 1 addition & 1 deletion tests/engines/test_semantic_segmentor.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ def test_wsi_segmentor_annotationstore(


def test_cli_model_single_file(sample_svs: Path, track_tmp_path: Path) -> None:
"""Test for models CLI single file."""
"""Test semantic segmentor CLI single file."""
runner = CliRunner()
models_wsi_result = runner.invoke(
cli.main,
Expand Down
2 changes: 2 additions & 0 deletions tiatoolbox/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from tiatoolbox import __version__
from tiatoolbox.cli.common import tiatoolbox_cli
from tiatoolbox.cli.deep_feature_extractor import deep_feature_extractor
from tiatoolbox.cli.nucleus_instance_segment import nucleus_instance_segment
from tiatoolbox.cli.patch_predictor import patch_predictor
from tiatoolbox.cli.read_bounds import read_bounds
Expand Down Expand Up @@ -43,6 +44,7 @@ def main() -> click.BaseCommand:
main.add_command(read_bounds)
main.add_command(save_tiles)
main.add_command(semantic_segmentor)
main.add_command(deep_feature_extractor)
main.add_command(slide_info)
main.add_command(slide_thumbnail)
main.add_command(tissue_mask)
Expand Down
113 changes: 113 additions & 0 deletions tiatoolbox/cli/deep_feature_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
"""Command line interface for deep feature extractor."""

from __future__ import annotations

from tiatoolbox.cli.common import (
cli_auto_get_mask,
cli_batch_size,
cli_device,
cli_file_type,
cli_img_input,
cli_masks,
cli_memory_threshold,
cli_model,
cli_num_workers,
cli_output_path,
cli_output_type,
cli_patch_mode,
cli_return_labels,
cli_return_probabilities,
cli_verbose,
cli_weights,
cli_yaml_config_path,
prepare_ioconfig,
prepare_model_cli,
tiatoolbox_cli,
)


@tiatoolbox_cli.command()
@cli_img_input()
@cli_output_path(
    usage_help="Output directory where model features will be saved.",
    default="deep_feature_extractor",
)
@cli_file_type(
    default="*.png, *.jpg, *.jpeg, *.tif, *.tiff, *.svs, *.ndpi, *.jp2, *.mrxs",
)
@cli_model(default="fcn-tissue_mask")
@cli_weights()
@cli_device(default="cpu")
@cli_batch_size(default=1)
@cli_yaml_config_path()
@cli_masks(default=None)
@cli_num_workers(default=0)
@cli_output_type(default="zarr")
@cli_memory_threshold(default=80)
@cli_patch_mode(default=False)
@cli_return_probabilities(default=True)
@cli_return_labels(default=False)
@cli_auto_get_mask(default=True)
@cli_verbose(default=True)
def deep_feature_extractor(
    model: str,
    weights: str,
    img_input: str,
    file_types: str,
    masks: str | None,
    output_path: str,
    batch_size: int,
    yaml_config_path: str,
    num_workers: int,
    device: str,
    output_type: str,
    memory_threshold: int,
    *,
    patch_mode: bool,
    return_probabilities: bool,
    return_labels: bool,
    auto_get_mask: bool,
    verbose: bool,
) -> None:
    """Process a set of input images with a deep feature extractor engine."""
    # NOTE: the docstring above is kept short on purpose; it is a CLI
    # command docstring and is left unchanged so user-facing help stays as is.

    # Imported lazily inside the command rather than at module import time.
    from tiatoolbox.models import (  # noqa: PLC0415
        DeepFeatureExtractor,
        IOSegmentorConfig,
    )

    # Resolve the CLI inputs into the image list, mask list and save directory.
    image_files, mask_files, save_dir = prepare_model_cli(
        img_input=img_input,
        output_path=output_path,
        masks=masks,
        file_types=file_types,
    )

    # Build the IO configuration from the weights / YAML options.
    io_configuration = prepare_ioconfig(
        IOSegmentorConfig,
        pretrained_weights=weights,
        yaml_config_path=yaml_config_path,
    )

    # Engine-level settings go to the constructor ...
    engine = DeepFeatureExtractor(
        model=model,
        weights=weights,
        batch_size=batch_size,
        num_workers=num_workers,
        verbose=verbose,
    )

    # ... while per-run settings are passed to run(); its return value is
    # not used by the command.
    engine.run(
        images=image_files,
        masks=mask_files,
        patch_mode=patch_mode,
        ioconfig=io_configuration,
        device=device,
        save_dir=save_dir,
        output_type=output_type,
        return_probabilities=return_probabilities,
        return_labels=return_labels,
        auto_get_mask=auto_get_mask,
        memory_threshold=memory_threshold,
    )
2 changes: 2 additions & 0 deletions tiatoolbox/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .architecture.nuclick import NuClick
from .architecture.sccnn import SCCNN
from .dataset import PatchDataset, WSIPatchDataset, WSIStreamDataset
from .engine.deep_feature_extractor import DeepFeatureExtractor
from .engine.io_config import (
IOInstanceSegmentorConfig,
IOPatchPredictorConfig,
Expand All @@ -24,6 +25,7 @@

__all__ = [
"SCCNN",
"DeepFeatureExtractor",
"HoVerNet",
"HoVerNetPlus",
"IDaRS",
Expand Down
Loading