 import torch
 import torch.nn as nn
 from lpips import LPIPS
+from torchvision.models import ResNet50_Weights, resnet50
+from torchvision.models.feature_extraction import create_feature_extractor


 class PerceptualLoss(nn.Module):
@@ -22,20 +24,29 @@ class PerceptualLoss(nn.Module):
     pretrained on: ImageNet that use the LPIPS approach from Zhang, et al. "The unreasonable effectiveness of deep
     features as a perceptual metric." https://arxiv.org/abs/1801.03924 ; RadImagenet from Mei, et al. "RadImageNet: An
     Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"
-    https://pubs.rsna.org/doi/full/10.1148/ryai.210315 ; and MedicalNet from Chen et al. "Med3D: Transfer Learning for
-    3D Medical Image Analysis" https://arxiv.org/abs/1904.00625 .
+    https://pubs.rsna.org/doi/full/10.1148/ryai.210315 ; MedicalNet from Chen et al. "Med3D: Transfer Learning for
+    3D Medical Image Analysis" https://arxiv.org/abs/1904.00625 ;
+    and ResNet50 from Torchvision: https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html .

     The fake 3D implementation is based on a 2.5D approach where we calculate the 2D perceptual loss on slices from
     the three axes.

     Args:
         spatial_dims: number of spatial dimensions.
         network_type: {``"alex"``, ``"vgg"``, ``"squeeze"``, ``"radimagenet_resnet50"``,
-            ``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``}
+            ``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``, ``"resnet50"``}
             Specifies the network architecture to use. Defaults to ``"alex"``.
         is_fake_3d: if True, use the 2.5D approach for a 3D perceptual loss.
         fake_3d_ratio: ratio of how many slices per axis are used in the 2.5D approach.
         cache_dir: path to cache directory to save the pretrained network weights.
+        pretrained: whether to load pretrained weights. This argument only works when using networks from
+            LPIPS or Torchvision. Defaults to ``True``.
+        pretrained_path: if `pretrained` is `True`, users can specify a weights file to be loaded
+            via this argument. This argument only works when ``network_type`` is ``"resnet50"``.
+            Defaults to `None`.
+        pretrained_state_dict_key: if `pretrained_path` is not `None`, this argument is used to
+            extract the expected state dict. This argument only works when ``network_type`` is ``"resnet50"``.
+            Defaults to `None`.
     """

     def __init__(
@@ -45,6 +56,9 @@ def __init__(
         is_fake_3d: bool = True,
         fake_3d_ratio: float = 0.5,
         cache_dir: str | None = None,
+        pretrained: bool = True,
+        pretrained_path: str | None = None,
+        pretrained_state_dict_key: str | None = None,
     ):
         super().__init__()

@@ -65,8 +79,15 @@ def __init__(
             self.perceptual_function = MedicalNetPerceptualSimilarity(net=network_type, verbose=False)
         elif "radimagenet_" in network_type:
             self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False)
+        elif network_type == "resnet50":
+            self.perceptual_function = TorchvisionModelPerceptualSimilarity(
+                net=network_type,
+                pretrained=pretrained,
+                pretrained_path=pretrained_path,
+                pretrained_state_dict_key=pretrained_state_dict_key,
+            )
         else:
-            self.perceptual_function = LPIPS(pretrained=True, net=network_type, verbose=False)
+            self.perceptual_function = LPIPS(pretrained=pretrained, net=network_type, verbose=False)
         self.is_fake_3d = is_fake_3d
         self.fake_3d_ratio = fake_3d_ratio
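
For context, a minimal usage sketch of the new arguments (the import path follows MONAI's layout and is assumed;
the checkpoint file name and state-dict key are hypothetical placeholders):

import torch
from monai.losses.perceptual import PerceptualLoss  # import path assumed

# Default behaviour: torchvision's ImageNet-pretrained ResNet50 backbone.
loss_fn = PerceptualLoss(spatial_dims=2, network_type="resnet50", pretrained=True)

# Loading a custom checkpoint instead of the torchvision weights:
# loss_fn = PerceptualLoss(
#     spatial_dims=2,
#     network_type="resnet50",
#     pretrained=True,
#     pretrained_path="my_resnet50.pth",        # hypothetical file
#     pretrained_state_dict_key="state_dict",   # hypothetical key
# )

pred, target = torch.rand(2, 1, 64, 64), torch.rand(2, 1, 64, 64)
print(loss_fn(pred, target))  # scalar perceptual distance
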
@@ -247,10 +268,95 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         return results


+class TorchvisionModelPerceptualSimilarity(nn.Module):
+    """
+    Component to perform the perceptual evaluation with TorchVision models.
+    Currently, only ResNet50 is supported. The network structure is based on:
+    https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html
+
+    Args:
+        net: {``"resnet50"``}
+            Specifies the network architecture to use. Defaults to ``"resnet50"``.
+        pretrained: whether to load pretrained weights. Defaults to `True`.
+        pretrained_path: if `pretrained` is `True`, users can specify a weights file to be loaded
+            via this argument. Defaults to `None`.
+        pretrained_state_dict_key: if `pretrained_path` is not `None`, this argument is used to
+            extract the expected state dict. Defaults to `None`.
+    """
+
+    def __init__(
+        self,
+        net: str = "resnet50",
+        pretrained: bool = True,
+        pretrained_path: str | None = None,
+        pretrained_state_dict_key: str | None = None,
+    ) -> None:
+        super().__init__()
+        supported_networks = ["resnet50"]
+        if net not in supported_networks:
+            raise NotImplementedError(
+                f"'net' {net} is not supported, please select a network from {supported_networks}."
+            )
+
+        if pretrained_path is None:
+            network = resnet50(weights=ResNet50_Weights.DEFAULT if pretrained else None)
+        else:
+            network = resnet50(weights=None)
+            if pretrained is True:
+                state_dict = torch.load(pretrained_path)
+                if pretrained_state_dict_key is not None:
+                    state_dict = state_dict[pretrained_state_dict_key]
+                network.load_state_dict(state_dict)
+        # extract features at the last ReLU of the final bottleneck block
+        self.final_layer = "layer4.2.relu_2"
+        self.model = create_feature_extractor(network, [self.final_layer])
+        self.eval()
+
+        for param in self.parameters():
+            param.requires_grad = False
+    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """
+        We expect that the input is normalised between [0, 1]. Given the preprocessing performed during training
+        at https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html#torchvision.models.ResNet50_Weights,
+        we make sure that the input and target have 3 channels, and then apply Z-score normalisation.
+        The outputs are normalised across the channels, and we obtain the mean from the spatial dimensions
+        (similar approach to the lpips package).
+        """
+        # If input has just 1 channel, repeat channel to have 3 channels
+        if input.shape[1] == 1 and target.shape[1] == 1:
+            input = input.repeat(1, 3, 1, 1)
+            target = target.repeat(1, 3, 1, 1)
+
+        # Input normalization
+        input = torchvision_zscore_norm(input)
+        target = torchvision_zscore_norm(target)
+
+        # Get model outputs
+        outs_input = self.model.forward(input)[self.final_layer]
+        outs_target = self.model.forward(target)[self.final_layer]
+
+        # Normalise through the channels
+        feats_input = normalize_tensor(outs_input)
+        feats_target = normalize_tensor(outs_target)
+
+        results = (feats_input - feats_target) ** 2
+        results = spatial_average(results.sum(dim=1, keepdim=True), keepdim=True)
+
+        return results


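As a quick sanity check (a sketch, assuming only that torchvision is installed), the node name "layer4.2.relu_2"
used for feature extraction above can be confirmed with torchvision's get_graph_node_names helper:

from torchvision.models import resnet50
from torchvision.models.feature_extraction import get_graph_node_names

network = resnet50(weights=None)
train_nodes, eval_nodes = get_graph_node_names(network)
assert "layer4.2.relu_2" in eval_nodes  # last ReLU of the final bottleneck block
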
 def spatial_average(x: torch.Tensor, keepdim: bool = True) -> torch.Tensor:
     return x.mean([2, 3], keepdim=keepdim)


+def torchvision_zscore_norm(x: torch.Tensor) -> torch.Tensor:
+    mean = [0.485, 0.456, 0.406]
+    std = [0.229, 0.224, 0.225]
+    x[:, 0, :, :] = (x[:, 0, :, :] - mean[0]) / std[0]
+    x[:, 1, :, :] = (x[:, 1, :, :] - mean[1]) / std[1]
+    x[:, 2, :, :] = (x[:, 2, :, :] - mean[2]) / std[2]
+    return x


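Note that torchvision_zscore_norm writes into x in place, so callers that need the original tensor should pass a
clone. A small check (a sketch, not part of the patch) that it matches torchvision.transforms.Normalize with the
same statistics:

import torch
from torchvision.transforms import Normalize

x = torch.rand(2, 3, 32, 32)
expected = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(x.clone())
actual = torchvision_zscore_norm(x.clone())  # clone because the function mutates its input
assert torch.allclose(actual, expected)
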
 def subtract_mean(x: torch.Tensor) -> torch.Tensor:
     mean = [0.406, 0.456, 0.485]
     x[:, 0, :, :] -= mean[0]