@@ -72,14 +72,7 @@ def test_array_torch_device_and_dtype_propagation(self):
7272
7373 def test_array_channel_order_identity (self ):
7474 # ensure the order of channels follows the order of input points
75- pts = np .array (
76- [
77- [2.0 , 2.0 ], # point A
78- [12.0 , 2.0 ], # point B
79- [2.0 , 12.0 ], # point C
80- ],
81- dtype = np .float32 ,
82- )
75+ pts = np .array ([[2.0 , 2.0 ], [12.0 , 2.0 ], [2.0 , 12.0 ]], dtype = np .float32 ) # point A # point B # point C
8376 hm = GenerateHeatmap (sigma = 1.2 , spatial_shape = (16 , 16 ))(pts )
8477 self .assertEqual (hm .shape , (3 , 16 , 16 ))
8578
@@ -90,11 +83,7 @@ def test_array_channel_order_identity(self):
9083 def test_array_points_out_of_bounds (self ):
9184 # points outside spatial domain: heatmap should still be valid (no NaN/Inf) and not all-zeros
9285 pts = np .array (
93- [
94- [- 5.0 , - 5.0 ], # outside top-left
95- [100.0 , 100.0 ], # outside bottom-right
96- [8.0 , 8.0 ], # inside
97- ],
86+ [[- 5.0 , - 5.0 ], [100.0 , 100.0 ], [8.0 , 8.0 ]], # outside top-left # outside bottom-right # inside
9887 dtype = np .float32 ,
9988 )
10089 hm = GenerateHeatmap (sigma = 2.0 , spatial_shape = (16 , 16 ))(pts )
@@ -118,12 +107,7 @@ def test_dict_with_reference_meta(self):
118107 image .meta ["spatial_shape" ] = (8 , 8 , 8 )
119108 data = {"points" : points , "image" : image }
120109
121- transform = GenerateHeatmapd (
122- keys = "points" ,
123- heatmap_keys = "heatmap" ,
124- ref_image_keys = "image" ,
125- sigma = 2.0 ,
126- )
110+ transform = GenerateHeatmapd (keys = "points" , heatmap_keys = "heatmap" , ref_image_keys = "image" , sigma = 2.0 )
127111
128112 result = transform (data )
129113 heatmap = result ["heatmap" ]
@@ -172,13 +156,7 @@ def test_dict_dtype_control(self):
172156 self .assertEqual (hm .dtype , torch .float16 )
173157
174158 def test_array_batched_3d (self ):
175- points = np .array (
176- [
177- [[4.2 , 7.8 , 1.0 ]], # Batch 1
178- [[12.3 , 3.6 , 2.0 ]], # Batch 2
179- ],
180- dtype = np .float32 ,
181- )
159+ points = np .array ([[[4.2 , 7.8 , 1.0 ]], [[12.3 , 3.6 , 2.0 ]]], dtype = np .float32 ) # Batch 1 # Batch 2
182160 transform = GenerateHeatmap (sigma = 1.5 , spatial_shape = (16 , 16 , 16 ))
183161
184162 heatmap = transform (points )
@@ -193,25 +171,14 @@ def test_array_batched_3d(self):
193171 self .assertTrue (np .all (np .abs (peak - points [i , 0 ]) <= 1.0 ), msg = f"peak={ peak } , point={ points [i , 0 ]} " )
194172
195173 def test_dict_batched_with_ref (self ):
196- points = torch .tensor (
197- [
198- [[1.5 , 2.5 , 3.5 ]], # Batch 1
199- [[4.5 , 5.5 , 6.5 ]], # Batch 2
200- ],
201- dtype = torch .float32 ,
202- )
174+ points = torch .tensor ([[[1.5 , 2.5 , 3.5 ]], [[4.5 , 5.5 , 6.5 ]]], dtype = torch .float32 ) # Batch 1 # Batch 2
203175 affine = torch .eye (4 )
204176 # A single reference image is used for the whole batch
205177 image = MetaTensor (torch .zeros ((1 , 8 , 8 , 8 ), dtype = torch .float32 ), affine = affine )
206178 image .meta ["spatial_shape" ] = (8 , 8 , 8 )
207179 data = {"points" : points , "image" : image }
208180
209- transform = GenerateHeatmapd (
210- keys = "points" ,
211- heatmap_keys = "heatmap" ,
212- ref_image_keys = "image" ,
213- sigma = 1.0 ,
214- )
181+ transform = GenerateHeatmapd (keys = "points" , heatmap_keys = "heatmap" , ref_image_keys = "image" , sigma = 1.0 )
215182
216183 result = transform (data )
217184 heatmap = result ["heatmap" ]
0 commit comments