@@ -521,6 +521,14 @@ class GenerateHeatmapd(MapTransform):
521521
522522 backend = GenerateHeatmap .backend
523523
524+ # Error messages
525+ _ERR_HEATMAP_KEYS_LEN = "heatmap_keys length must match keys length."
526+ _ERR_REF_KEYS_LEN = "ref_image_keys length must match keys length when provided."
527+ _ERR_SHAPE_LEN = "spatial_shape length must match keys length when providing per-key shapes."
528+ _ERR_NO_SHAPE = "Unable to determine spatial shape for GenerateHeatmapd. Provide spatial_shape or ref_image_keys."
529+ _ERR_INVALID_POINTS = "landmark arrays must be 2D or 3D with shape (N, D) or (B, N, D)."
530+ _ERR_REF_NO_SHAPE = "Reference data must define a shape attribute."
531+
524532 def __init__ (
525533 self ,
526534 keys : KeysCollection ,
@@ -570,7 +578,7 @@ def _prepare_heatmap_keys(self, heatmap_keys: KeysCollection | None) -> tuple[Ha
570578 if len (keys_tuple ) == 1 and len (self .keys ) > 1 :
571579 keys_tuple = keys_tuple * len (self .keys )
572580 if len (keys_tuple ) != len (self .keys ):
573- raise ValueError ("heatmap_keys length must match keys length." )
581+ raise ValueError (self . _ERR_HEATMAP_KEYS_LEN )
574582 return keys_tuple
575583
576584 def _prepare_optional_keys (self , maybe_keys : KeysCollection | None ) -> tuple [Hashable | None , ...]:
@@ -580,7 +588,7 @@ def _prepare_optional_keys(self, maybe_keys: KeysCollection | None) -> tuple[Has
580588 if len (keys_tuple ) == 1 and len (self .keys ) > 1 :
581589 keys_tuple = keys_tuple * len (self .keys )
582590 if len (keys_tuple ) != len (self .keys ):
583- raise ValueError ("ref_image_keys length must match keys length when provided." )
591+ raise ValueError (self . _ERR_REF_KEYS_LEN )
584592 return tuple (keys_tuple )
585593
586594 def _prepare_shapes (
@@ -595,7 +603,7 @@ def _prepare_shapes(
595603 if len (shape_tuple ) == 1 and len (self .keys ) > 1 :
596604 shape_tuple = shape_tuple * len (self .keys )
597605 if len (shape_tuple ) != len (self .keys ):
598- raise ValueError ("spatial_shape length must match keys length when providing per-key shapes." )
606+ raise ValueError (self . _ERR_SHAPE_LEN )
599607 prepared : list [tuple [int , ...] | None ] = []
600608 for item in shape_tuple :
601609 if item is None :
@@ -612,13 +620,11 @@ def _determine_shape(
612620 return static_shape
613621 points_t = convert_to_tensor (points , dtype = torch .float32 , track_meta = False )
614622 if points_t .ndim not in (2 , 3 ):
615- raise ValueError ("landmark arrays must be 2D or 3D with shape (N, D) or (B, N, D)." )
623+ raise ValueError (self . _ERR_INVALID_POINTS )
616624 spatial_dims = int (points_t .shape [- 1 ])
617625 if ref_key is not None and ref_key in data :
618626 return self ._shape_from_reference (data [ref_key ], spatial_dims )
619- raise ValueError (
620- "Unable to determine spatial shape for GenerateHeatmapd. Provide spatial_shape or ref_image_keys."
621- )
627+ raise ValueError (self ._ERR_NO_SHAPE )
622628
623629 def _shape_from_reference (self , reference : Any , spatial_dims : int ) -> tuple [int , ...]:
624630 if isinstance (reference , MetaTensor ):
@@ -630,23 +636,23 @@ def _shape_from_reference(self, reference: Any, spatial_dims: int) -> tuple[int,
630636 return tuple (int (v ) for v in reference .shape [- spatial_dims :])
631637 if hasattr (reference , "shape" ):
632638 return tuple (int (v ) for v in reference .shape [- spatial_dims :])
633- raise ValueError ("Reference data must define a shape attribute." )
639+ raise ValueError (self . _ERR_REF_NO_SHAPE )
634640
635641 def _update_spatial_metadata (self , heatmap : MetaTensor , reference : MetaTensor ) -> None :
636642 """Update spatial metadata of heatmap based on its dimensions."""
637- # Update spatial_shape metadata based on heatmap dimensions
643+ # Determine if batched based on reference's batch dimension
644+ ref_spatial_shape = reference .meta .get ("spatial_shape" , [])
645+ ref_is_batched = len (reference .shape ) > len (ref_spatial_shape ) + 1
646+
638647 if heatmap .ndim == 5 : # 3D batched: (B, C, H, W, D)
639- heatmap . meta [ " spatial_shape" ] = tuple ( int ( v ) for v in heatmap .shape [2 :])
648+ spatial_shape = heatmap .shape [2 :]
640649 elif heatmap .ndim == 4 : # 2D batched (B, C, H, W) or 3D non-batched (C, H, W, D)
641- # Need to check if this is batched 2D or non-batched 3D
642- if len (heatmap .shape [1 :]) == len (reference .meta .get ("spatial_shape" , [])):
643- # Non-batched 3D
644- heatmap .meta ["spatial_shape" ] = tuple (int (v ) for v in heatmap .shape [1 :])
645- else :
646- # Batched 2D
647- heatmap .meta ["spatial_shape" ] = tuple (int (v ) for v in heatmap .shape [2 :])
650+ # Disambiguate: 2D batched vs 3D non-batched
651+ spatial_shape = heatmap .shape [2 :] if ref_is_batched else heatmap .shape [1 :]
648652 else : # 2D non-batched: (C, H, W)
649- heatmap .meta ["spatial_shape" ] = tuple (int (v ) for v in heatmap .shape [1 :])
653+ spatial_shape = heatmap .shape [1 :]
654+
655+ heatmap .meta ["spatial_shape" ] = tuple (int (v ) for v in spatial_shape )
650656
651657
652658GenerateHeatmapD = GenerateHeatmapDict = GenerateHeatmapd
0 commit comments