This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit f969c24

Merge pull request #390 from yiheng-wang-nv/389-add-resnet50-support

389 add torchvision resnet50 support

2 parents fd04ec6 + eb01aee

File tree: 2 files changed, +116 −6 lines

generative/losses/perceptual.py

Lines changed: 110 additions & 4 deletions

@@ -14,6 +14,8 @@
 import torch
 import torch.nn as nn
 from lpips import LPIPS
+from torchvision.models import ResNet50_Weights, resnet50
+from torchvision.models.feature_extraction import create_feature_extractor
 
 
 class PerceptualLoss(nn.Module):
@@ -22,20 +24,29 @@ class PerceptualLoss(nn.Module):
     pretrained on: ImageNet that use the LPIPS approach from Zhang, et al. "The unreasonable effectiveness of deep
     features as a perceptual metric." https://arxiv.org/abs/1801.03924 ; RadImagenet from Mei, et al. "RadImageNet: An
     Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"
-    https://pubs.rsna.org/doi/full/10.1148/ryai.210315 ; and MedicalNet from Chen et al. "Med3D: Transfer Learning for
-    3D Medical Image Analysis" https://arxiv.org/abs/1904.00625 .
+    https://pubs.rsna.org/doi/full/10.1148/ryai.210315 ; MedicalNet from Chen et al. "Med3D: Transfer Learning for
+    3D Medical Image Analysis" https://arxiv.org/abs/1904.00625 ;
+    and ResNet50 from Torchvision: https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html .
 
     The fake 3D implementation is based on a 2.5D approach where we calculate the 2D perceptual loss on slices
     from the three axes.
 
     Args:
         spatial_dims: number of spatial dimensions.
         network_type: {``"alex"``, ``"vgg"``, ``"squeeze"``, ``"radimagenet_resnet50"``,
-            ``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``}
+            ``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``, ``"resnet50"``}
             Specifies the network architecture to use. Defaults to ``"alex"``.
         is_fake_3d: if True use 2.5D approach for a 3D perceptual loss.
         fake_3d_ratio: ratio of how many slices per axis are used in the 2.5D approach.
         cache_dir: path to cache directory to save the pretrained network weights.
+        pretrained: whether to load pretrained weights. This argument only works when using networks from
+            LPIPS or Torchvision. Defaults to ``True``.
+        pretrained_path: if `pretrained` is `True`, users can specify a weights file to be loaded
+            via this argument. This argument only works when ``network_type`` is ``"resnet50"``.
+            Defaults to `None`.
+        pretrained_state_dict_key: if `pretrained_path` is not `None`, this argument is used to
+            extract the expected state dict. This argument only works when ``network_type`` is ``"resnet50"``.
+            Defaults to `None`.
     """
 
     def __init__(
@@ -45,6 +56,9 @@ def __init__(
         is_fake_3d: bool = True,
         fake_3d_ratio: float = 0.5,
         cache_dir: str | None = None,
+        pretrained: bool = True,
+        pretrained_path: str | None = None,
+        pretrained_state_dict_key: str | None = None,
     ):
         super().__init__()
 
@@ -65,8 +79,15 @@ def __init__(
             self.perceptual_function = MedicalNetPerceptualSimilarity(net=network_type, verbose=False)
         elif "radimagenet_" in network_type:
             self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False)
+        elif network_type == "resnet50":
+            self.perceptual_function = TorchvisionModelPerceptualSimilarity(
+                net=network_type,
+                pretrained=pretrained,
+                pretrained_path=pretrained_path,
+                pretrained_state_dict_key=pretrained_state_dict_key,
+            )
         else:
-            self.perceptual_function = LPIPS(pretrained=True, net=network_type, verbose=False)
+            self.perceptual_function = LPIPS(pretrained=pretrained, net=network_type, verbose=False)
         self.is_fake_3d = is_fake_3d
         self.fake_3d_ratio = fake_3d_ratio
 
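The new constructor arguments are forwarded to the network selection above. A minimal usage sketch of the new options follows; the checkpoint path and state-dict key are hypothetical placeholders, not files shipped with this commit:

import torch

from generative.losses import PerceptualLoss

# Default behaviour: use torchvision's ImageNet weights for ResNet50.
loss_fn = PerceptualLoss(spatial_dims=2, network_type="resnet50", pretrained=True)

# Alternative: load a local checkpoint; "my_resnet50.pt" and "state_dict"
# are hypothetical placeholders for a user-supplied file and key.
loss_fn_local = PerceptualLoss(
    spatial_dims=2,
    network_type="resnet50",
    pretrained=True,
    pretrained_path="my_resnet50.pt",
    pretrained_state_dict_key="state_dict",
)

# Inputs are expected in [0, 1]; single-channel images are repeated to 3 channels.
input_img = torch.rand(2, 1, 64, 64)
target_img = torch.rand(2, 1, 64, 64)
loss = loss_fn(input_img, target_img)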
@@ -247,10 +268,95 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         return results
 
 
+class TorchvisionModelPerceptualSimilarity(nn.Module):
+    """
+    Component to perform the perceptual evaluation with TorchVision models.
+    Currently, only ResNet50 is supported. The network structure is based on:
+    https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html
+
+    Args:
+        net: {``"resnet50"``}
+            Specifies the network architecture to use. Defaults to ``"resnet50"``.
+        pretrained: whether to load pretrained weights. Defaults to `True`.
+        pretrained_path: if `pretrained` is `True`, users can specify a weights file to be loaded
+            via this argument. Defaults to `None`.
+        pretrained_state_dict_key: if `pretrained_path` is not `None`, this argument is used to
+            extract the expected state dict. Defaults to `None`.
+    """
+
+    def __init__(
+        self,
+        net: str = "resnet50",
+        pretrained: bool = True,
+        pretrained_path: str | None = None,
+        pretrained_state_dict_key: str | None = None,
+    ) -> None:
+        super().__init__()
+        supported_networks = ["resnet50"]
+        if net not in supported_networks:
+            raise NotImplementedError(
+                f"'net' {net} is not supported, please select a network from {supported_networks}."
+            )
+
+        if pretrained_path is None:
+            network = resnet50(weights=ResNet50_Weights.DEFAULT if pretrained else None)
+        else:
+            network = resnet50(weights=None)
+            if pretrained is True:
+                state_dict = torch.load(pretrained_path)
+                if pretrained_state_dict_key is not None:
+                    state_dict = state_dict[pretrained_state_dict_key]
+                network.load_state_dict(state_dict)
+        self.final_layer = "layer4.2.relu_2"
+        self.model = create_feature_extractor(network, [self.final_layer])
+        self.eval()
+
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """
+        We expect that the input is normalised between [0, 1]. Given the preprocessing performed during the
+        training at
+        https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html#torchvision.models.ResNet50_Weights,
+        we make sure that the input and target have 3 channels, and then do Z-Score normalization.
+        The outputs are normalised across the channels, and we obtain the mean from the spatial dimensions
+        (similar approach to the lpips package).
+        """
+        # If input has just 1 channel, repeat channel to have 3 channels
+        if input.shape[1] == 1 and target.shape[1] == 1:
+            input = input.repeat(1, 3, 1, 1)
+            target = target.repeat(1, 3, 1, 1)
+
+        # Input normalization
+        input = torchvision_zscore_norm(input)
+        target = torchvision_zscore_norm(target)
+
+        # Get model outputs
+        outs_input = self.model.forward(input)[self.final_layer]
+        outs_target = self.model.forward(target)[self.final_layer]
+
+        # Normalise through the channels
+        feats_input = normalize_tensor(outs_input)
+        feats_target = normalize_tensor(outs_target)
+
+        results = (feats_input - feats_target) ** 2
+        results = spatial_average(results.sum(dim=1, keepdim=True), keepdim=True)
+
+        return results
+
+
 def spatial_average(x: torch.Tensor, keepdim: bool = True) -> torch.Tensor:
     return x.mean([2, 3], keepdim=keepdim)
 
 
+def torchvision_zscore_norm(x: torch.Tensor) -> torch.Tensor:
+    mean = [0.485, 0.456, 0.406]
+    std = [0.229, 0.224, 0.225]
+    x[:, 0, :, :] = (x[:, 0, :, :] - mean[0]) / std[0]
+    x[:, 1, :, :] = (x[:, 1, :, :] - mean[1]) / std[1]
+    x[:, 2, :, :] = (x[:, 2, :, :] - mean[2]) / std[2]
+    return x
+
+
 def subtract_mean(x: torch.Tensor) -> torch.Tensor:
     mean = [0.406, 0.456, 0.485]
     x[:, 0, :, :] -= mean[0]
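A note on the hard-coded node name: "layer4.2.relu_2" is the last ReLU of ResNet50's final bottleneck block. To verify or adapt this choice, torchvision's get_graph_node_names lists every node name that create_feature_extractor accepts; a short sketch, assuming a torchvision version with FX feature extraction available:

import torch
from torchvision.models import resnet50
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names

network = resnet50(weights=None)
train_nodes, eval_nodes = get_graph_node_names(network)
print(eval_nodes[-4:])  # expected to end with "layer4.2.relu_2", "avgpool", "flatten", "fc"

# Extract the same feature map the commit uses.
extractor = create_feature_extractor(network, ["layer4.2.relu_2"])
extractor.eval()
with torch.no_grad():
    features = extractor(torch.rand(1, 3, 64, 64))["layer4.2.relu_2"]
print(features.shape)  # torch.Size([1, 2048, 2, 2]) for a 64x64 input (stride-32 backbone)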

tests/test_perceptual_loss.py

Lines changed: 6 additions & 2 deletions

@@ -14,9 +14,8 @@
 import unittest
 
 import torch
-from parameterized import parameterized
-
 from generative.losses import PerceptualLoss
+from parameterized import parameterized
 
 TEST_CASES = [
     [{"spatial_dims": 2, "network_type": "squeeze"}, (2, 1, 64, 64), (2, 1, 64, 64)],
@@ -37,6 +36,11 @@
         (2, 1, 64, 64, 64),
         (2, 1, 64, 64, 64),
     ],
+    [
+        {"spatial_dims": 3, "network_type": "resnet50", "is_fake_3d": True, "pretrained": True, "fake_3d_ratio": 0.2},
+        (2, 1, 64, 64, 64),
+        (2, 1, 64, 64, 64),
+    ],
 ]
 
 
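The test method that consumes TEST_CASES is unchanged and therefore not part of this diff. For context, a self-contained sketch of how such a parameterized shape test typically looks; the method name and the scalar-shape assertion are assumptions, not shown in this commit:

import unittest

import torch
from generative.losses import PerceptualLoss
from parameterized import parameterized

TEST_CASES = [
    [
        {"spatial_dims": 3, "network_type": "resnet50", "is_fake_3d": True, "pretrained": True, "fake_3d_ratio": 0.2},
        (2, 1, 64, 64, 64),
        (2, 1, 64, 64, 64),
    ],
]


class TestPerceptualLoss(unittest.TestCase):
    @parameterized.expand(TEST_CASES)
    def test_shape(self, input_param, input_shape, target_shape):
        # Build the loss from the case's kwargs and check it reduces to a scalar.
        loss = PerceptualLoss(**input_param)
        result = loss(torch.randn(input_shape), torch.randn(target_shape))
        self.assertEqual(result.shape, torch.Size([]))


if __name__ == "__main__":
    unittest.main()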