@@ -548,9 +548,19 @@ def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]:
548548 ):
549549 points = d [key ]
550550 shape = self ._determine_shape (points , static_shape , d , ref_key )
551+ # The GenerateHeatmap transform will handle type conversion based on input points
551552 heatmap = self .generator (points , spatial_shape = shape )
553+ # If there's a reference image and we need to match its type/device
552554 reference = d .get (ref_key ) if ref_key is not None and ref_key in d else None
553- d [out_key ] = self ._prepare_output (heatmap , reference )
555+ if reference is not None and isinstance (reference , (torch .Tensor , np .ndarray )):
556+ # Convert to match reference type and device while preserving heatmap's dtype
557+ heatmap , _ , _ = convert_to_dst_type (
558+ heatmap , reference , dtype = heatmap .dtype , device = getattr (reference , "device" , None )
559+ )
560+ # Copy metadata if reference is MetaTensor
561+ if isinstance (reference , MetaTensor ) and isinstance (heatmap , MetaTensor ):
562+ self ._update_spatial_metadata (heatmap , reference )
563+ d [out_key ] = heatmap
554564 return d
555565
556566 def _prepare_heatmap_keys (self , heatmap_keys : KeysCollection | None ) -> tuple [Hashable , ...]:
@@ -622,29 +632,21 @@ def _shape_from_reference(self, reference: Any, spatial_dims: int) -> tuple[int,
622632 return tuple (int (v ) for v in reference .shape [- spatial_dims :])
623633 raise ValueError ("Reference data must define a shape attribute." )
624634
625- def _prepare_output (self , heatmap : NdarrayOrTensor , reference : Any ) -> Any :
626- if isinstance (reference , MetaTensor ):
627- # Use heatmap's dtype (from generator), not reference's dtype
628- converted , _ , _ = convert_to_dst_type (heatmap , reference , dtype = heatmap .dtype , device = reference .device )
629- # For batched data shape is (B, C, *spatial), for non-batched it's (C, *spatial)
630- if heatmap .ndim == 5 : # 3D batched: (B, C, H, W, D)
631- converted .meta ["spatial_shape" ] = tuple (int (v ) for v in heatmap .shape [2 :])
632- elif heatmap .ndim == 4 : # 2D batched (B, C, H, W) or 3D non-batched (C, H, W, D)
633- # Need to check if this is batched 2D or non-batched 3D
634- if len (heatmap .shape [1 :]) == len (reference .meta .get ("spatial_shape" , [])):
635- # Non-batched 3D
636- converted .meta ["spatial_shape" ] = tuple (int (v ) for v in heatmap .shape [1 :])
637- else :
638- # Batched 2D
639- converted .meta ["spatial_shape" ] = tuple (int (v ) for v in heatmap .shape [2 :])
640- else : # 2D non-batched: (C, H, W)
641- converted .meta ["spatial_shape" ] = tuple (int (v ) for v in heatmap .shape [1 :])
642- return converted
643- if isinstance (reference , torch .Tensor ):
644- # Use heatmap's dtype (from generator), not reference's dtype
645- converted , _ , _ = convert_to_dst_type (heatmap , reference , dtype = heatmap .dtype , device = reference .device )
646- return converted
647- return heatmap
635+ def _update_spatial_metadata (self , heatmap : MetaTensor , reference : MetaTensor ) -> None :
636+ """Update spatial metadata of heatmap based on its dimensions."""
637+ # Update spatial_shape metadata based on heatmap dimensions
638+ if heatmap .ndim == 5 : # 3D batched: (B, C, H, W, D)
639+ heatmap .meta ["spatial_shape" ] = tuple (int (v ) for v in heatmap .shape [2 :])
640+ elif heatmap .ndim == 4 : # 2D batched (B, C, H, W) or 3D non-batched (C, H, W, D)
641+ # Need to check if this is batched 2D or non-batched 3D
642+ if len (heatmap .shape [1 :]) == len (reference .meta .get ("spatial_shape" , [])):
643+ # Non-batched 3D
644+ heatmap .meta ["spatial_shape" ] = tuple (int (v ) for v in heatmap .shape [1 :])
645+ else :
646+ # Batched 2D
647+ heatmap .meta ["spatial_shape" ] = tuple (int (v ) for v in heatmap .shape [2 :])
648+ else : # 2D non-batched: (C, H, W)
649+ heatmap .meta ["spatial_shape" ] = tuple (int (v ) for v in heatmap .shape [1 :])
648650
649651
650652GenerateHeatmapD = GenerateHeatmapDict = GenerateHeatmapd
0 commit comments