Skip to content

Commit aaf2833

Browse files
committed
address the code review.
Signed-off-by: sewon.jeon <sewon.jeon@connecteve.com>
1 parent 9f10dcf commit aaf2833

File tree

3 files changed

+42
-22
lines changed

3 files changed

+42
-22
lines changed

monai/transforms/post/array.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import warnings
1818
from collections.abc import Callable, Iterable, Sequence
19+
from typing import ClassVar
1920

2021
import numpy as np
2122
import torch
@@ -766,7 +767,7 @@ class GenerateHeatmap(Transform):
766767
767768
"""
768769

769-
backend = [TransformBackends.NUMPY, TransformBackends.TORCH]
770+
backend: ClassVar[list] = [TransformBackends.NUMPY, TransformBackends.TORCH]
770771

771772
def __init__(
772773
self,
@@ -862,7 +863,10 @@ def _resolve_sigma(self, spatial_dims: int) -> tuple[float, ...]:
862863

863864
@staticmethod
864865
def _is_inside(center: Sequence[float], bounds: tuple[int, ...]) -> bool:
865-
return all(0 <= c < size for c, size in zip(center, bounds))
866+
for c, size in zip(center, bounds):
867+
if not (0 <= c < size):
868+
return False
869+
return True
866870

867871
def _make_window(
868872
self, center: Sequence[float], radius: tuple[int, ...], bounds: tuple[int, ...], device: torch.device
@@ -879,6 +883,16 @@ def _make_window(
879883
return tuple(slices), tuple(coord_shifts)
880884

881885
def _evaluate_gaussian(self, coord_shifts: tuple[torch.Tensor, ...], sigma: tuple[float, ...]) -> torch.Tensor:
886+
"""
887+
Evaluate Gaussian at given coordinate shifts with specified sigmas.
888+
889+
Args:
890+
coord_shifts: Per-dimension coordinate offsets from center.
891+
sigma: Per-dimension standard deviations.
892+
893+
Returns:
894+
Gaussian values at the specified coordinates.
895+
"""
882896
device = coord_shifts[0].device
883897
shape = tuple(len(axis) for axis in coord_shifts)
884898
if 0 in shape:

monai/transforms/post/dictionary.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,14 @@ class GenerateHeatmapd(MapTransform):
521521

522522
backend = GenerateHeatmap.backend
523523

524+
# Error messages
525+
_ERR_HEATMAP_KEYS_LEN = "heatmap_keys length must match keys length."
526+
_ERR_REF_KEYS_LEN = "ref_image_keys length must match keys length when provided."
527+
_ERR_SHAPE_LEN = "spatial_shape length must match keys length when providing per-key shapes."
528+
_ERR_NO_SHAPE = "Unable to determine spatial shape for GenerateHeatmapd. Provide spatial_shape or ref_image_keys."
529+
_ERR_INVALID_POINTS = "landmark arrays must be 2D or 3D with shape (N, D) or (B, N, D)."
530+
_ERR_REF_NO_SHAPE = "Reference data must define a shape attribute."
531+
524532
def __init__(
525533
self,
526534
keys: KeysCollection,
@@ -570,7 +578,7 @@ def _prepare_heatmap_keys(self, heatmap_keys: KeysCollection | None) -> tuple[Ha
570578
if len(keys_tuple) == 1 and len(self.keys) > 1:
571579
keys_tuple = keys_tuple * len(self.keys)
572580
if len(keys_tuple) != len(self.keys):
573-
raise ValueError("heatmap_keys length must match keys length.")
581+
raise ValueError(self._ERR_HEATMAP_KEYS_LEN)
574582
return keys_tuple
575583

576584
def _prepare_optional_keys(self, maybe_keys: KeysCollection | None) -> tuple[Hashable | None, ...]:
@@ -580,7 +588,7 @@ def _prepare_optional_keys(self, maybe_keys: KeysCollection | None) -> tuple[Has
580588
if len(keys_tuple) == 1 and len(self.keys) > 1:
581589
keys_tuple = keys_tuple * len(self.keys)
582590
if len(keys_tuple) != len(self.keys):
583-
raise ValueError("ref_image_keys length must match keys length when provided.")
591+
raise ValueError(self._ERR_REF_KEYS_LEN)
584592
return tuple(keys_tuple)
585593

586594
def _prepare_shapes(
@@ -595,7 +603,7 @@ def _prepare_shapes(
595603
if len(shape_tuple) == 1 and len(self.keys) > 1:
596604
shape_tuple = shape_tuple * len(self.keys)
597605
if len(shape_tuple) != len(self.keys):
598-
raise ValueError("spatial_shape length must match keys length when providing per-key shapes.")
606+
raise ValueError(self._ERR_SHAPE_LEN)
599607
prepared: list[tuple[int, ...] | None] = []
600608
for item in shape_tuple:
601609
if item is None:
@@ -612,13 +620,11 @@ def _determine_shape(
612620
return static_shape
613621
points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False)
614622
if points_t.ndim not in (2, 3):
615-
raise ValueError("landmark arrays must be 2D or 3D with shape (N, D) or (B, N, D).")
623+
raise ValueError(self._ERR_INVALID_POINTS)
616624
spatial_dims = int(points_t.shape[-1])
617625
if ref_key is not None and ref_key in data:
618626
return self._shape_from_reference(data[ref_key], spatial_dims)
619-
raise ValueError(
620-
"Unable to determine spatial shape for GenerateHeatmapd. Provide spatial_shape or ref_image_keys."
621-
)
627+
raise ValueError(self._ERR_NO_SHAPE)
622628

623629
def _shape_from_reference(self, reference: Any, spatial_dims: int) -> tuple[int, ...]:
624630
if isinstance(reference, MetaTensor):
@@ -630,23 +636,23 @@ def _shape_from_reference(self, reference: Any, spatial_dims: int) -> tuple[int,
630636
return tuple(int(v) for v in reference.shape[-spatial_dims:])
631637
if hasattr(reference, "shape"):
632638
return tuple(int(v) for v in reference.shape[-spatial_dims:])
633-
raise ValueError("Reference data must define a shape attribute.")
639+
raise ValueError(self._ERR_REF_NO_SHAPE)
634640

635641
def _update_spatial_metadata(self, heatmap: MetaTensor, reference: MetaTensor) -> None:
636642
"""Update spatial metadata of heatmap based on its dimensions."""
637-
# Update spatial_shape metadata based on heatmap dimensions
643+
# Determine if batched based on reference's batch dimension
644+
ref_spatial_shape = reference.meta.get("spatial_shape", [])
645+
ref_is_batched = len(reference.shape) > len(ref_spatial_shape) + 1
646+
638647
if heatmap.ndim == 5: # 3D batched: (B, C, H, W, D)
639-
heatmap.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[2:])
648+
spatial_shape = heatmap.shape[2:]
640649
elif heatmap.ndim == 4: # 2D batched (B, C, H, W) or 3D non-batched (C, H, W, D)
641-
# Need to check if this is batched 2D or non-batched 3D
642-
if len(heatmap.shape[1:]) == len(reference.meta.get("spatial_shape", [])):
643-
# Non-batched 3D
644-
heatmap.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[1:])
645-
else:
646-
# Batched 2D
647-
heatmap.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[2:])
650+
# Disambiguate: 2D batched vs 3D non-batched
651+
spatial_shape = heatmap.shape[2:] if ref_is_batched else heatmap.shape[1:]
648652
else: # 2D non-batched: (C, H, W)
649-
heatmap.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[1:])
653+
spatial_shape = heatmap.shape[1:]
654+
655+
heatmap.meta["spatial_shape"] = tuple(int(v) for v in spatial_shape)
650656

651657

652658
GenerateHeatmapD = GenerateHeatmapDict = GenerateHeatmapd

tests/transforms/test_generate_heatmapd.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@
9898

9999
class TestGenerateHeatmapd(unittest.TestCase):
100100
@parameterized.expand(TEST_CASES_WITH_REF)
101-
def test_dict_with_reference_meta(self, _, points, params, expected_shape, expected_dtype, uses_ref):
101+
def test_dict_with_reference_meta(self, _, points, params, expected_shape, *_unused):
102102
affine = torch.eye(4)
103103
image = MetaTensor(torch.zeros((1, 8, 8, 8), dtype=torch.float32), affine=affine)
104104
image.meta["spatial_shape"] = (8, 8, 8)
@@ -148,7 +148,7 @@ def test_dict_dtype_control(self, _, points, params, expected_shape, expected_dt
148148
self.assertEqual(hm.dtype, expected_dtype)
149149

150150
@parameterized.expand(TEST_CASES_BATCHED)
151-
def test_dict_batched_with_ref(self, _, points, params, expected_shape, expected_dtype):
151+
def test_dict_batched_with_ref(self, _, points, params, expected_shape, _expected_dtype):
152152
affine = torch.eye(4)
153153
# A single reference image is used for the whole batch
154154
image = MetaTensor(torch.zeros((1, 8, 8, 8), dtype=torch.float32), affine=affine)

0 commit comments

Comments
 (0)