diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 0ab9fe63d5..3fd33b76da 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -293,6 +293,7 @@ AsDiscrete, DistanceTransformEDT, FillHoles, + GenerateHeatmap, Invert, KeepLargestConnectedComponent, LabelFilter, @@ -319,6 +320,9 @@ FillHolesD, FillHolesd, FillHolesDict, + GenerateHeatmapd, + GenerateHeatmapD, + GenerateHeatmapDict, InvertD, Invertd, InvertDict, diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 2e733c4f6c..47623b748d 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -38,7 +38,14 @@ remove_small_objects, ) from monai.transforms.utils_pytorch_numpy_unification import unravel_index -from monai.utils import TransformBackends, convert_data_type, convert_to_tensor, ensure_tuple, look_up_option +from monai.utils import ( + TransformBackends, + convert_data_type, + convert_to_tensor, + ensure_tuple, + get_equivalent_dtype, + look_up_option, +) from monai.utils.type_conversion import convert_to_dst_type __all__ = [ @@ -54,6 +61,7 @@ "SobelGradients", "VoteEnsemble", "Invert", + "GenerateHeatmap", "DistanceTransformEDT", ] @@ -742,6 +750,154 @@ def __call__(self, img: Sequence[NdarrayOrTensor] | NdarrayOrTensor) -> NdarrayO return self.post_convert(out_pt, img) +class GenerateHeatmap(Transform): + """ + Generate per-landmark Gaussian heatmaps for 2D or 3D coordinates. + + Notes: + - Coordinates are interpreted in voxel units and expected in (Y, X) for 2D or (Z, Y, X) for 3D. + - Target spatial_shape is (Y, X) for 2D and (Z, Y, X) for 3D. + - Output layout uses channel-first convention with one channel per landmark. + - Input points shape: (N, D) where N is number of landmarks, D is spatial dimensions (2 or 3). + - Output heatmap shape: (N, Y, X) for 2D or (N, Z, Y, X) for 3D. + - Each channel index corresponds to one landmark. + + Args: + sigma: gaussian standard deviation. A single value is broadcast across all spatial dimensions. + spatial_shape: optional fallback spatial shape. If ``None`` it must be provided when calling the transform. + truncated: extent, in multiples of ``sigma``, used to crop the gaussian support window. + normalize: normalize every heatmap channel to ``[0, 1]`` when ``True``. + dtype: target dtype for the generated heatmaps (accepts numpy or torch dtypes). + + Raises: + ValueError: when ``sigma`` is non-positive or ``spatial_shape`` cannot be resolved. 
+ + """ + + backend = [TransformBackends.NUMPY, TransformBackends.TORCH] + + def __init__( + self, + sigma: Sequence[float] | float = 5.0, + spatial_shape: Sequence[int] | None = None, + truncated: float = 4.0, + normalize: bool = True, + dtype: np.dtype | torch.dtype | type = np.float32, + ) -> None: + if isinstance(sigma, Sequence) and not isinstance(sigma, (str, bytes)): + if any(s <= 0 for s in sigma): + raise ValueError("Argument `sigma` values must be positive.") + self._sigma = tuple(float(s) for s in sigma) + else: + if float(sigma) <= 0: + raise ValueError("Argument `sigma` must be positive.") + self._sigma = (float(sigma),) + if truncated <= 0: + raise ValueError("Argument `truncated` must be positive.") + self.truncated = float(truncated) + self.normalize = normalize + self.torch_dtype = get_equivalent_dtype(dtype, torch.Tensor) + self.numpy_dtype = get_equivalent_dtype(dtype, np.ndarray) + # Validate that dtype is floating-point for meaningful Gaussian values + if not self.torch_dtype.is_floating_point: + raise ValueError(f"Argument `dtype` must be a floating-point type, got {self.torch_dtype}") + self.spatial_shape = None if spatial_shape is None else tuple(int(s) for s in spatial_shape) + + def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | None = None) -> NdarrayOrTensor: + """ + Args: + points: landmark coordinates as ndarray/Tensor with shape (N, D), + ordered as (Y, X) for 2D or (Z, Y, X) for 3D, where N is the number + of landmarks and D is the spatial dimensionality. + spatial_shape: spatial size as a sequence. If None, uses the value provided at construction. + + Returns: + Heatmaps with shape (N, *spatial), one channel per landmark. + + Raises: + ValueError: if points shape/dimension or spatial_shape is invalid. + """ + original_points = points + points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False) + + if points_t.ndim != 2: + raise ValueError( + f"Argument `points` must be a 2D array with shape (num_points, spatial_dims), got shape {points_t.shape}." 
+ ) + + if points_t.shape[-1] not in (2, 3): + raise ValueError("GenerateHeatmap only supports 2D or 3D landmarks.") + + device = points_t.device + num_points, spatial_dims = points_t.shape + + target_shape = self._resolve_spatial_shape(spatial_shape, spatial_dims) + sigma = self._resolve_sigma(spatial_dims) + + # Create sparse image with impulses at landmark locations + heatmap = torch.zeros((num_points, *target_shape), dtype=self.torch_dtype, device=device) + bounds_t = torch.as_tensor(target_shape, device=device, dtype=points_t.dtype) + + for idx, center in enumerate(points_t): + if not torch.isfinite(center).all(): + continue + if not ((center >= 0).all() and (center < bounds_t).all()): + continue + # Round to nearest integer for impulse placement, then clamp to valid index range + center_int = center.round().long() + # Clamp indices to [0, size-1] to avoid out-of-bounds (e.g., 9.7 rounds to 10 in size-10 array) + bounds_max = (bounds_t - 1).long() + center_int = torch.minimum(torch.maximum(center_int, torch.zeros_like(center_int)), bounds_max) + # Place impulse (use maximum in case of overlapping landmarks) + current_val = heatmap[idx][tuple(center_int)] + heatmap[idx][tuple(center_int)] = torch.maximum( + current_val, torch.tensor(1.0, dtype=self.torch_dtype, device=device) + ) + + # Apply Gaussian blur using GaussianFilter + # Reshape to (num_points, 1, *spatial) for per-channel filtering + heatmap_input = heatmap.unsqueeze(1) # Add channel dimension + + gaussian_filter = GaussianFilter( + spatial_dims=spatial_dims, sigma=sigma, truncated=self.truncated, approx="erf", requires_grad=False + ).to(device=device, dtype=self.torch_dtype) + + heatmap_blurred = gaussian_filter(heatmap_input) + heatmap = heatmap_blurred.squeeze(1) # Remove channel dimension + + # Normalize per channel if requested + if self.normalize: + for idx in range(num_points): + peak = heatmap[idx].amax() + if peak > 0: + heatmap[idx].div_(peak) + + target_dtype = self.torch_dtype if isinstance(original_points, (torch.Tensor, MetaTensor)) else self.numpy_dtype + converted, _, _ = convert_to_dst_type(heatmap, original_points, dtype=target_dtype) + return converted + + def _resolve_spatial_shape(self, call_shape: Sequence[int] | None, spatial_dims: int) -> tuple[int, ...]: + shape = call_shape if call_shape is not None else self.spatial_shape + if shape is None: + raise ValueError("Argument `spatial_shape` must be provided either at construction time or call time.") + shape_tuple = ensure_tuple(shape) + if len(shape_tuple) != spatial_dims: + if len(shape_tuple) == 1: + shape_tuple = shape_tuple * spatial_dims # type: ignore + else: + raise ValueError( + "Argument `spatial_shape` length must match the landmarks' spatial dims (or pass a single int to broadcast)." 
+ ) + return tuple(int(s) for s in shape_tuple) + + def _resolve_sigma(self, spatial_dims: int) -> tuple[float, ...]: + if len(self._sigma) == spatial_dims: + return self._sigma + if len(self._sigma) == 1: + return self._sigma * spatial_dims + raise ValueError("Argument `sigma` sequence length must equal the number of spatial dimensions.") + + class ProbNMS(Transform): """ Performs probability based non-maximum suppression (NMS) on the probabilities map via diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 7e1e074f71..65fdd22b22 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -35,6 +35,7 @@ AsDiscrete, DistanceTransformEDT, FillHoles, + GenerateHeatmap, KeepLargestConnectedComponent, LabelFilter, LabelToContour, @@ -48,6 +49,7 @@ from monai.transforms.utility.array import ToTensor from monai.transforms.utils import allow_missing_keys_mode, convert_applied_interp_mode from monai.utils import PostFix, convert_to_tensor, ensure_tuple, ensure_tuple_rep +from monai.utils.type_conversion import convert_to_dst_type __all__ = [ "ActivationsD", @@ -95,6 +97,9 @@ "DistanceTransformEDTd", "DistanceTransformEDTD", "DistanceTransformEDTDict", + "GenerateHeatmapd", + "GenerateHeatmapD", + "GenerateHeatmapDict", ] DEFAULT_POST_FIX = PostFix.meta() @@ -508,6 +513,208 @@ def __init__(self, keys: KeysCollection, output_key: str | None = None, num_clas super().__init__(keys, ensemble, output_key) +class GenerateHeatmapd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.GenerateHeatmap`. + Converts landmark coordinates into gaussian heatmaps and optionally copies metadata from a reference image. + + Args: + keys: keys of the corresponding items in the dictionary, where each key references a tensor + of landmark point coordinates with shape (N, D), where N is the number of landmarks + and D is the spatial dimensionality (2 or 3). + sigma: standard deviation for the Gaussian kernel. Can be a single value or a sequence matching the number + of spatial dimensions. + heatmap_keys: keys to store output heatmaps. Default: "{key}_heatmap" for each key. + ref_image_keys: keys of reference images to inherit spatial metadata from. When provided, heatmaps will + have the same shape, affine, and spatial metadata as the reference images. + spatial_shape: spatial dimensions of output heatmaps. Can be: + - Single shape (tuple): applied to all keys + - List of shapes: one per key (must match keys length) + truncated: truncation distance for Gaussian kernel computation (in sigmas). + normalize: if True, normalize each heatmap's peak value to 1.0. + dtype: output data type for heatmaps. Defaults to np.float32. + allow_missing_keys: if True, don't raise error if some keys are missing in data. + + Returns: + Dictionary with original data plus generated heatmaps at specified keys. + + Raises: + ValueError: If heatmap_keys/ref_image_keys length doesn't match keys length. + ValueError: If no spatial shape can be determined (need spatial_shape or ref_image_keys). + ValueError: If input points have invalid shape (must be 2D array with shape (N, D)). + + Example: + .. 
code-block:: python + + import numpy as np + from monai.transforms import GenerateHeatmapd + + # Create sample data with landmark points and a reference image + data = { + "landmarks": np.array([[10.0, 15.0], [20.0, 25.0]]), # 2 points in 2D + "image": np.zeros((32, 32)) # reference image + } + + # Transform with reference image + transform = GenerateHeatmapd( + keys="landmarks", + sigma=2.0, + ref_image_keys="image" + ) + result = transform(data) + # result["landmarks_heatmap"] has shape (2, 32, 32) - one channel per landmark + + # Or with explicit spatial_shape + transform = GenerateHeatmapd( + keys="landmarks", + sigma=2.0, + spatial_shape=(64, 64) + ) + result = transform(data) + # result["landmarks_heatmap"] has shape (2, 64, 64) + + Notes: + - Default heatmap_keys are generated as "{key}_heatmap" for each input key + - Shape inference precedence: static spatial_shape > ref_image + - Input points shape: (N, D) where N is number of landmarks, D is spatial dimensions + - Output heatmap shape: (N, H, W) for 2D or (N, H, W, D) for 3D + - When using ref_image_keys, heatmaps inherit affine and spatial metadata from reference + """ + + backend = GenerateHeatmap.backend + + # Error messages + _ERR_HEATMAP_KEYS_LEN = "Argument `heatmap_keys` length must match keys length." + _ERR_REF_KEYS_LEN = "Argument `ref_image_keys` length must match keys length when provided." + _ERR_SHAPE_LEN = "Argument `spatial_shape` length must match keys length when providing per-key shapes." + _ERR_NO_SHAPE = "Unable to determine spatial shape for GenerateHeatmapd. Provide spatial_shape or ref_image_keys." + _ERR_INVALID_POINTS = "Landmark arrays must be 2D with shape (N, D)." + _ERR_REF_NO_SHAPE = "Reference data must define a shape attribute." + + def __init__( + self, + keys: KeysCollection, + sigma: Sequence[float] | float = 5.0, + heatmap_keys: KeysCollection | None = None, + ref_image_keys: KeysCollection | None = None, + spatial_shape: Sequence[int] | Sequence[Sequence[int]] | None = None, + truncated: float = 4.0, + normalize: bool = True, + dtype: np.dtype | torch.dtype | type = np.float32, + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) + self.heatmap_keys = self._prepare_heatmap_keys(heatmap_keys) + self.ref_image_keys = self._prepare_optional_keys(ref_image_keys) + self.static_shapes = self._prepare_shapes(spatial_shape) + self.generator = GenerateHeatmap( + sigma=sigma, spatial_shape=None, truncated=truncated, normalize=normalize, dtype=dtype + ) + + def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]: + d = dict(data) + for key, out_key, ref_key, static_shape in self.key_iterator( + d, self.heatmap_keys, self.ref_image_keys, self.static_shapes + ): + points = d[key] + shape = self._determine_shape(points, static_shape, d, ref_key) + # The GenerateHeatmap transform will handle type conversion based on input points + heatmap = self.generator(points, spatial_shape=shape) + # If there's a reference image and we need to match its type/device + reference = d.get(ref_key) if ref_key is not None and ref_key in d else None + if reference is not None and isinstance(reference, (torch.Tensor, np.ndarray)): + # Convert to match reference type and device while preserving heatmap's dtype + heatmap, _, _ = convert_to_dst_type( + heatmap, reference, dtype=heatmap.dtype, device=getattr(reference, "device", None) + ) + # Copy metadata if reference is MetaTensor + if isinstance(reference, MetaTensor) and isinstance(heatmap, MetaTensor): + heatmap.affine = 
reference.affine + self._update_spatial_metadata(heatmap, shape) + d[out_key] = heatmap + return d + + def _prepare_heatmap_keys(self, heatmap_keys: KeysCollection | None) -> tuple[Hashable, ...]: + if heatmap_keys is None: + return tuple(f"{key}_heatmap" for key in self.keys) + keys_tuple = ensure_tuple(heatmap_keys) + if len(keys_tuple) == 1 and len(self.keys) > 1: + keys_tuple = keys_tuple * len(self.keys) + if len(keys_tuple) != len(self.keys): + raise ValueError(self._ERR_HEATMAP_KEYS_LEN) + return keys_tuple + + def _prepare_optional_keys(self, maybe_keys: KeysCollection | None) -> tuple[Hashable | None, ...]: + if maybe_keys is None: + return (None,) * len(self.keys) + keys_tuple = ensure_tuple(maybe_keys) + if len(keys_tuple) == 1 and len(self.keys) > 1: + keys_tuple = keys_tuple * len(self.keys) + if len(keys_tuple) != len(self.keys): + raise ValueError(self._ERR_REF_KEYS_LEN) + return tuple(keys_tuple) + + def _prepare_shapes( + self, spatial_shape: Sequence[int] | Sequence[Sequence[int]] | None + ) -> tuple[tuple[int, ...] | None, ...]: + if spatial_shape is None: + return (None,) * len(self.keys) + shape_tuple = ensure_tuple(spatial_shape) + if shape_tuple and all(isinstance(v, (int, np.integer)) for v in shape_tuple): + shape = tuple(int(v) for v in shape_tuple) + return (shape,) * len(self.keys) + if len(shape_tuple) == 1 and len(self.keys) > 1: + shape_tuple = shape_tuple * len(self.keys) + if len(shape_tuple) != len(self.keys): + raise ValueError(self._ERR_SHAPE_LEN) + prepared: list[tuple[int, ...] | None] = [] + for item in shape_tuple: + if item is None: + prepared.append(None) + else: + dims = ensure_tuple(item) + prepared.append(tuple(int(v) for v in dims)) + return tuple(prepared) + + def _determine_shape( + self, points: Any, static_shape: tuple[int, ...] | None, data: Mapping[Hashable, Any], ref_key: Hashable | None + ) -> tuple[int, ...]: + points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False) + if points_t.ndim != 2: + raise ValueError(f"{self._ERR_INVALID_POINTS} Got {points_t.ndim}D tensor.") + spatial_dims = int(points_t.shape[-1]) + if static_shape is not None: + if len(static_shape) == 1 and spatial_dims > 1: + static_shape = tuple([static_shape[0]] * spatial_dims) + if len(static_shape) != spatial_dims: + raise ValueError( + f"Provided static spatial_shape has {len(static_shape)} dims; expected {spatial_dims}." 
+ ) + return static_shape + if ref_key is not None and ref_key in data: + return self._shape_from_reference(data[ref_key], spatial_dims) + raise ValueError(self._ERR_NO_SHAPE) + + def _shape_from_reference(self, reference: Any, spatial_dims: int) -> tuple[int, ...]: + if isinstance(reference, MetaTensor): + meta_shape = reference.meta.get("spatial_shape") + if meta_shape is not None: + dims = ensure_tuple(meta_shape) + if len(dims) == spatial_dims: + return tuple(int(v) for v in dims) + return tuple(int(v) for v in reference.shape[-spatial_dims:]) + if hasattr(reference, "shape"): + return tuple(int(v) for v in reference.shape[-spatial_dims:]) + raise ValueError(self._ERR_REF_NO_SHAPE) + + def _update_spatial_metadata(self, heatmap: MetaTensor, spatial_shape: tuple[int, ...]) -> None: + """Set spatial_shape explicitly from resolved shape.""" + heatmap.meta["spatial_shape"] = tuple(int(v) for v in spatial_shape) + + +GenerateHeatmapD = GenerateHeatmapDict = GenerateHeatmapd + + class ProbNMSd(MapTransform): """ Performs probability based non-maximum suppression (NMS) on the probabilities map via diff --git a/tests/transforms/test_generate_heatmap.py b/tests/transforms/test_generate_heatmap.py new file mode 100644 index 0000000000..0dd108429d --- /dev/null +++ b/tests/transforms/test_generate_heatmap.py @@ -0,0 +1,261 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
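+
+# Illustrative usage (not executed by the tests below): a minimal sketch of the array transform,
+# assuming 2D landmark coordinates given in (Y, X) voxel order as described in the GenerateHeatmap
+# docstring; the sigma and spatial size here are arbitrary examples.
+#
+#     points = np.array([[4.0, 7.0], [12.0, 3.0]], dtype=np.float32)
+#     heatmaps = GenerateHeatmap(sigma=1.5, spatial_shape=(16, 16))(points)  # shape: (2, 16, 16)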
+ +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms.post.array import GenerateHeatmap +from tests.test_utils import TEST_NDARRAYS + + +def _argmax_nd(x) -> np.ndarray: + """argmax for N-D array → returns coordinate vector (z,y,x) or (y,x).""" + if isinstance(x, torch.Tensor): + x = x.cpu().numpy() + return np.asarray(np.unravel_index(np.argmax(x), x.shape)) + + +# Test cases for 2D array inputs with different data types +TEST_CASES_2D = [ + [ + f"2d_basic_type{idx}", + p(np.array([[4.2, 7.8], [12.3, 3.6]], dtype=np.float32)), + {"sigma": 1.5, "spatial_shape": (16, 16)}, + (2, 16, 16), + ] + for idx, p in enumerate(TEST_NDARRAYS) +] + +# Test cases for 3D torch outputs with explicit dtype +TEST_CASES_3D_TORCH = [ + [ + f"3d_torch_{str(dtype).replace('torch.', '')}", + torch.tensor([[1.5, 2.5, 3.5]], dtype=torch.float32), + {"sigma": 1.0, "spatial_shape": (8, 8, 8), "dtype": dtype}, + (1, 8, 8, 8), + dtype, + ] + for dtype in [torch.float32, torch.float64] +] + +# Test cases for 3D numpy outputs with explicit dtype +TEST_CASES_3D_NUMPY = [ + [ + f"3d_numpy_{dtype_obj.__name__}", + np.array([[1.5, 2.5, 3.5]], dtype=np.float32), + {"sigma": 1.0, "spatial_shape": (8, 8, 8), "dtype": dtype_obj}, + (1, 8, 8, 8), + dtype_obj, + ] + for dtype_obj in [np.float32, np.float64] +] + +# Test cases for different sigma values +TEST_CASES_SIGMA = [ + [ + f"sigma_{sigma}", + np.array([[8.0, 8.0]], dtype=np.float32), + {"sigma": sigma, "spatial_shape": (16, 16)}, + (1, 16, 16), + ] + for sigma in [0.5, 1.0, 2.0, 3.0] +] + +# Test cases for truncated parameter +TEST_CASES_TRUNCATED = [ + [ + f"truncated_{truncated}", + np.array([[8.0, 8.0]], dtype=np.float32), + {"sigma": 2.0, "spatial_shape": (32, 32), "truncated": truncated}, + (1, 32, 32), + ] + for truncated in [2.0, 4.0, 6.0] +] + +# Test cases for device and dtype propagation (torch only) +test_device = "cuda:0" if torch.cuda.is_available() else "cpu" +test_dtypes = [torch.float32, torch.float64] +if torch.cuda.is_available(): + test_dtypes.append(torch.float16) + +TEST_CASES_DEVICE_DTYPE = [ + [ + f"{test_device.split(':')[0]}_{str(dtype).replace('torch.', '')}", + torch.tensor([[3.0, 4.0, 5.0]], dtype=torch.float32, device=test_device), + {"sigma": 1.2, "spatial_shape": (10, 10, 10), "dtype": dtype}, + (1, 10, 10, 10), + dtype, + test_device, + ] + for dtype in test_dtypes +] + + +class TestGenerateHeatmap(unittest.TestCase): + @parameterized.expand(TEST_CASES_2D) + def test_array_2d(self, _, points, params, expected_shape): + transform = GenerateHeatmap(**params) + heatmap = transform(points) + + # Check output type matches input type + if isinstance(points, torch.Tensor): + self.assertIsInstance(heatmap, torch.Tensor) + self.assertEqual(heatmap.dtype, torch.float32) # Default dtype for torch + heatmap_np = heatmap.cpu().numpy() + points_np = points.cpu().numpy() + else: + self.assertIsInstance(heatmap, np.ndarray) + self.assertEqual(heatmap.dtype, np.float32) # Default dtype for numpy + heatmap_np = heatmap + points_np = points + + self.assertEqual(heatmap.shape, expected_shape) + np.testing.assert_allclose(heatmap_np.max(axis=(1, 2)), np.ones(expected_shape[0]), rtol=1e-5, atol=1e-5) + + # peak should be close to original point location (<= 1px tolerance due to discretization) + for idx in range(expected_shape[0]): + peak = _argmax_nd(heatmap_np[idx]) + self.assertTrue(np.all(np.abs(peak - points_np[idx]) <= 1.0), msg=f"peak={peak}, 
point={points_np[idx]}") + self.assertLess(heatmap_np[idx, 0, 0], 1e-3) + + @parameterized.expand(TEST_CASES_3D_TORCH) + def test_array_3d_torch_output(self, _, points, params, expected_shape, expected_dtype): + transform = GenerateHeatmap(**params) + heatmap = transform(points) + + self.assertIsInstance(heatmap, torch.Tensor) + self.assertEqual(heatmap.device, points.device) + self.assertEqual(tuple(heatmap.shape), expected_shape) + self.assertEqual(heatmap.dtype, expected_dtype) + self.assertTrue(torch.isclose(heatmap.max(), torch.tensor(1.0, dtype=heatmap.dtype, device=heatmap.device))) + + @parameterized.expand(TEST_CASES_3D_NUMPY) + def test_array_3d_numpy_output(self, _, points, params, expected_shape, expected_dtype): + transform = GenerateHeatmap(**params) + heatmap = transform(points) + + self.assertIsInstance(heatmap, np.ndarray) + self.assertEqual(heatmap.shape, expected_shape) + self.assertEqual(heatmap.dtype, expected_dtype) + np.testing.assert_allclose(heatmap.max(), 1.0, rtol=1e-5) + + @parameterized.expand(TEST_CASES_DEVICE_DTYPE) + def test_array_torch_device_and_dtype_propagation( + self, _, pts, params, expected_shape, expected_dtype, expected_device + ): + tr = GenerateHeatmap(**params) + hm = tr(pts) + + self.assertIsInstance(hm, torch.Tensor) + self.assertEqual(str(hm.device).split(":")[0], expected_device.split(":")[0]) + self.assertEqual(hm.dtype, expected_dtype) + self.assertEqual(tuple(hm.shape), expected_shape) + self.assertTrue(torch.all(hm >= 0)) + + def test_array_channel_order_identity(self): + # ensure the order of channels follows the order of input points + pts = np.array([[2.0, 2.0], [12.0, 2.0], [2.0, 12.0]], dtype=np.float32) # point A # point B # point C + hm = GenerateHeatmap(sigma=1.2, spatial_shape=(16, 16))(pts) + + self.assertIsInstance(hm, np.ndarray) + self.assertEqual(hm.shape, (3, 16, 16)) + + peaks = np.vstack([_argmax_nd(hm[i]) for i in range(3)]) + # y,x close to points + np.testing.assert_allclose(peaks, pts, atol=1.0) + + def test_array_points_out_of_bounds(self): + # points outside spatial domain: heatmap should still be valid (no NaN/Inf) and not all-zeros + pts = np.array( + [[-5.0, -5.0], [100.0, 100.0], [8.0, 8.0]], # outside top-left # outside bottom-right # inside + dtype=np.float32, + ) + hm = GenerateHeatmap(sigma=2.0, spatial_shape=(16, 16))(pts) + + self.assertIsInstance(hm, np.ndarray) + self.assertEqual(hm.shape, (3, 16, 16)) + self.assertFalse(np.isnan(hm).any() or np.isinf(hm).any()) + + # inside point channel should have max≈1; others may clip at border (≤1) + self.assertGreater(hm[2].max(), 0.9) + + @parameterized.expand(TEST_CASES_SIGMA) + def test_array_sigma_scaling_effect(self, _, pt, params, expected_shape): + heatmap = GenerateHeatmap(**params)(pt)[0] + self.assertEqual(heatmap.shape, expected_shape[1:]) + + # All should have peak normalized to 1.0 + np.testing.assert_allclose(heatmap.max(), 1.0, rtol=1e-5) + + # Verify heatmap is valid + self.assertFalse(np.isnan(heatmap).any() or np.isinf(heatmap).any()) + + def test_invalid_points_shape_raises(self): + # points must be (N, D) with D in {2,3} + tr = GenerateHeatmap(sigma=1.0, spatial_shape=(8, 8)) + with self.assertRaises((ValueError, AssertionError, IndexError, RuntimeError)): + tr(np.zeros((2,), dtype=np.float32)) # wrong rank + + with self.assertRaises((ValueError, AssertionError, IndexError, RuntimeError)): + tr(np.zeros((2, 4), dtype=np.float32)) # D=4 unsupported + + @parameterized.expand(TEST_CASES_TRUNCATED) + def test_truncated_parameter(self, _, pt, 
params, expected_shape): + heatmap = GenerateHeatmap(**params)(pt)[0] + + # All should have same peak value (normalized to 1.0) + np.testing.assert_allclose(heatmap.max(), 1.0, rtol=1e-5) + + # Verify shape and no NaN/Inf + self.assertEqual(heatmap.shape, expected_shape[1:]) + self.assertFalse(np.isnan(heatmap).any() or np.isinf(heatmap).any()) + + def test_torch_to_torch_type_preservation(self): + """Test that torch input produces torch output""" + pts = torch.tensor([[4.0, 4.0]], dtype=torch.float32) + hm = GenerateHeatmap(sigma=1.0, spatial_shape=(8, 8))(pts) + + self.assertIsInstance(hm, torch.Tensor) + self.assertEqual(hm.dtype, torch.float32) + self.assertEqual(hm.device, pts.device) + + def test_numpy_to_numpy_type_preservation(self): + """Test that numpy input produces numpy output""" + pts = np.array([[4.0, 4.0]], dtype=np.float32) + hm = GenerateHeatmap(sigma=1.0, spatial_shape=(8, 8))(pts) + + self.assertIsInstance(hm, np.ndarray) + self.assertEqual(hm.dtype, np.float32) + + def test_dtype_override_torch(self): + """Test dtype parameter works with torch tensors""" + pts = torch.tensor([[4.0, 4.0, 4.0]], dtype=torch.float32) + hm = GenerateHeatmap(sigma=1.0, spatial_shape=(8, 8, 8), dtype=torch.float64)(pts) + + self.assertIsInstance(hm, torch.Tensor) + self.assertEqual(hm.dtype, torch.float64) + + def test_dtype_override_numpy(self): + """Test dtype parameter works with numpy arrays""" + pts = np.array([[4.0, 4.0, 4.0]], dtype=np.float32) + hm = GenerateHeatmap(sigma=1.0, spatial_shape=(8, 8, 8), dtype=np.float64)(pts) + + self.assertIsInstance(hm, np.ndarray) + self.assertEqual(hm.dtype, np.float64) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/transforms/test_generate_heatmapd.py b/tests/transforms/test_generate_heatmapd.py new file mode 100644 index 0000000000..0867a959e5 --- /dev/null +++ b/tests/transforms/test_generate_heatmapd.py @@ -0,0 +1,225 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
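+
+# Illustrative usage (not executed by the tests below): a minimal sketch of per-key spatial shapes,
+# which the GenerateHeatmapd docstring mentions but does not demonstrate; the key names and sizes
+# here are arbitrary examples.
+#
+#     tr = GenerateHeatmapd(keys=["lm2d", "lm3d"], spatial_shape=[(16, 16), (8, 8, 8)], sigma=1.5)
+#     out = tr({"lm2d": np.zeros((3, 2), np.float32), "lm3d": np.zeros((2, 3), np.float32)})
+#     out["lm2d_heatmap"].shape  # (3, 16, 16)
+#     out["lm3d_heatmap"].shape  # (2, 8, 8, 8)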
+ +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.data import MetaTensor +from monai.transforms.post.dictionary import GenerateHeatmapd +from tests.test_utils import assert_allclose + +# Test cases for dictionary transforms with reference image +# Only test with non-MetaTensor types to avoid affine conflicts +TEST_CASES_WITH_REF = [ + [ + "dict_with_ref_3d_numpy", + np.array([[2.5, 2.5, 3.0], [5.0, 5.0, 4.0]], dtype=np.float32), + {"sigma": 2.0}, + (2, 8, 8, 8), + torch.float32, + True, # uses reference image + ], + [ + "dict_with_ref_3d_torch", + torch.tensor([[2.5, 2.5, 3.0], [5.0, 5.0, 4.0]], dtype=torch.float32), + {"sigma": 2.0}, + (2, 8, 8, 8), + torch.float32, + True, # uses reference image + ], +] + +# Test cases for dictionary transforms with static spatial shape +TEST_CASES_STATIC_SHAPE = [ + [ + f"dict_static_shape_{len(shape)}d", + np.array([[1.0] * len(shape)], dtype=np.float32), + {"spatial_shape": shape}, + (1, *shape), + np.float32, + ] + for shape in [(6, 6), (8, 8, 8), (10, 10, 10)] +] + +# Test cases for dtype control +TEST_CASES_DTYPE = [ + [ + f"dict_dtype_{str(dtype).replace('torch.', '')}", + np.array([[2.0, 3.0, 4.0]], dtype=np.float32), + {"sigma": 1.4, "dtype": dtype}, + (1, 10, 10, 10), + dtype, + ] + for dtype in [torch.float16, torch.float32, torch.float64] +] + +# Test cases for various sigma values +TEST_CASES_SIGMA_VALUES = [ + [ + f"dict_sigma_{sigma}", + np.array([[4.0, 4.0, 4.0]], dtype=np.float32), + {"sigma": sigma, "spatial_shape": (8, 8, 8)}, + (1, 8, 8, 8), + ] + for sigma in [0.5, 1.0, 2.0, 3.0] +] + + +class TestGenerateHeatmapd(unittest.TestCase): + @parameterized.expand(TEST_CASES_WITH_REF) + def test_dict_with_reference_meta(self, _, points, params, expected_shape, *_unused): + affine = torch.eye(4) + image = MetaTensor(torch.zeros((1, 8, 8, 8), dtype=torch.float32), affine=affine) + image.meta["spatial_shape"] = (8, 8, 8) + data = {"points": points, "image": image} + + transform = GenerateHeatmapd(keys="points", heatmap_keys="heatmap", ref_image_keys="image", **params) + result = transform(data) + heatmap = result["heatmap"] + + self.assertIsInstance(heatmap, MetaTensor) + self.assertEqual(tuple(heatmap.shape), expected_shape) + self.assertEqual(heatmap.meta["spatial_shape"], (8, 8, 8)) + # The heatmap should inherit the reference image's affine + assert_allclose(heatmap.affine, image.affine, type_test=False) + + # Check max values are normalized to 1.0 + max_vals = heatmap.cpu().numpy().max(axis=tuple(range(1, len(expected_shape)))) + np.testing.assert_allclose(max_vals, np.ones(expected_shape[0]), rtol=1e-5, atol=1e-5) + + @parameterized.expand(TEST_CASES_STATIC_SHAPE) + def test_dict_static_shape(self, _, points, params, expected_shape, expected_dtype): + transform = GenerateHeatmapd(keys="points", heatmap_keys="heatmap", **params) + result = transform({"points": points}) + heatmap = result["heatmap"] + + self.assertIsInstance(heatmap, np.ndarray) + self.assertEqual(heatmap.shape, expected_shape) + self.assertEqual(heatmap.dtype, expected_dtype) + + # Verify no NaN or Inf values + self.assertFalse(np.isnan(heatmap).any() or np.isinf(heatmap).any()) + + # Verify max value is 1.0 for normalized heatmaps + np.testing.assert_allclose(heatmap.max(), 1.0, rtol=1e-5) + + def test_dict_missing_shape_raises(self): + # Without ref image or explicit spatial_shape, must raise + transform = GenerateHeatmapd(keys="points", heatmap_keys="heatmap") + with 
self.assertRaisesRegex(ValueError, "spatial_shape|ref_image_keys"): + transform({"points": np.zeros((1, 2), dtype=np.float32)}) + + @parameterized.expand(TEST_CASES_DTYPE) + def test_dict_dtype_control(self, _, points, params, expected_shape, expected_dtype): + ref = MetaTensor(torch.zeros((1, 10, 10, 10), dtype=torch.float32), affine=torch.eye(4)) + d = {"pts": points, "img": ref} + + tr = GenerateHeatmapd(keys="pts", heatmap_keys="hm", ref_image_keys="img", **params) + out = tr(d) + hm = out["hm"] + + self.assertIsInstance(hm, MetaTensor) + self.assertEqual(tuple(hm.shape), expected_shape) + self.assertEqual(hm.dtype, expected_dtype) + + @parameterized.expand(TEST_CASES_SIGMA_VALUES) + def test_dict_various_sigma(self, _, points, params, expected_shape): + transform = GenerateHeatmapd(keys="points", heatmap_keys="heatmap", **params) + result = transform({"points": points}) + heatmap = result["heatmap"] + + self.assertEqual(heatmap.shape, expected_shape) + # Verify heatmap is normalized + np.testing.assert_allclose(heatmap.max(), 1.0, rtol=1e-5) + # Verify no NaN or Inf + self.assertFalse(np.isnan(heatmap).any() or np.isinf(heatmap).any()) + + def test_dict_multiple_keys(self): + """Test dictionary transform with multiple input/output keys""" + points1 = np.array([[2.0, 2.0]], dtype=np.float32) + points2 = np.array([[4.0, 4.0]], dtype=np.float32) + + data = {"pts1": points1, "pts2": points2} + transform = GenerateHeatmapd( + keys=["pts1", "pts2"], heatmap_keys=["hm1", "hm2"], spatial_shape=(8, 8), sigma=1.0 + ) + + result = transform(data) + + self.assertIn("hm1", result) + self.assertIn("hm2", result) + self.assertEqual(result["hm1"].shape, (1, 8, 8)) + self.assertEqual(result["hm2"].shape, (1, 8, 8)) + + # Verify peaks are at different locations + self.assertNotEqual(np.argmax(result["hm1"]), np.argmax(result["hm2"])) + + def test_dict_mismatched_heatmap_keys_length(self): + """Test ValueError when heatmap_keys length doesn't match keys""" + with self.assertRaises(ValueError): + GenerateHeatmapd( + keys=["pts1", "pts2"], + heatmap_keys=["hm1", "hm2", "hm3"], # Mismatch: 3 heatmap keys for 2 input keys + spatial_shape=(8, 8), + ) + + def test_dict_mismatched_ref_image_keys_length(self): + """Test ValueError when ref_image_keys length doesn't match keys""" + with self.assertRaises(ValueError): + GenerateHeatmapd( + keys=["pts1", "pts2"], + heatmap_keys=["hm1", "hm2"], + ref_image_keys=["img1", "img2", "img3"], # Mismatch: 3 ref keys for 2 input keys + spatial_shape=(8, 8), + ) + + def test_dict_per_key_spatial_shape_mismatch(self): + """Test ValueError when per-key spatial_shape length doesn't match keys""" + with self.assertRaises(ValueError): + GenerateHeatmapd( + keys=["pts1", "pts2"], + heatmap_keys=["hm1", "hm2"], + spatial_shape=[(8, 8), (8, 8), (8, 8)], # Mismatch: 3 shapes for 2 keys + sigma=1.0, + ) + + def test_metatensor_points_with_ref(self): + """Test MetaTensor points with reference image - documents current behavior""" + from monai.data import MetaTensor + + # Create MetaTensor points with non-identity affine + points_affine = torch.tensor([[2.0, 0, 0, 0], [0, 2.0, 0, 0], [0, 0, 2.0, 0], [0, 0, 0, 1.0]]) + points = MetaTensor(torch.tensor([[2.5, 2.5, 3.0], [5.0, 5.0, 4.0]], dtype=torch.float32), affine=points_affine) + + # Reference image with identity affine + ref_affine = torch.eye(4) + image = MetaTensor(torch.zeros((1, 8, 8, 8), dtype=torch.float32), affine=ref_affine) + image.meta["spatial_shape"] = (8, 8, 8) + + data = {"points": points, "image": image} + transform = 
GenerateHeatmapd(keys="points", heatmap_keys="heatmap", ref_image_keys="image", sigma=2.0) + result = transform(data) + heatmap = result["heatmap"] + + self.assertIsInstance(heatmap, MetaTensor) + self.assertEqual(tuple(heatmap.shape), (2, 8, 8, 8)) + + # Heatmap should inherit affine from the reference image + assert_allclose(heatmap.affine, image.affine, type_test=False) + + +if __name__ == "__main__": + unittest.main()
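A minimal end-to-end sketch of the two transforms added by this patch (the key names, sigma values, and shapes below are illustrative, not part of the patch itself):

    import numpy as np
    from monai.transforms import Compose, GenerateHeatmap, GenerateHeatmapd

    # array transform: (N, D) voxel coordinates -> (N, *spatial) heatmaps, one channel per landmark
    pts = np.array([[10.0, 15.0], [20.0, 25.0]], dtype=np.float32)
    hm = GenerateHeatmap(sigma=2.0, spatial_shape=(32, 32))(pts)  # shape: (2, 32, 32)

    # dictionary transform in a Compose pipeline, inferring the heatmap shape from a reference image
    pipeline = Compose([GenerateHeatmapd(keys="landmarks", ref_image_keys="image", sigma=2.0)])
    sample = {"landmarks": pts, "image": np.zeros((1, 32, 32), dtype=np.float32)}
    out = pipeline(sample)
    out["landmarks_heatmap"].shape  # (2, 32, 32)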