Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit 9b209eb

Browse files
committed
Run tests
Signed-off-by: Walter Hugo Lopez Pinaya <ianonimato@hotmail.com>
1 parent f969c24 commit 9b209eb

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

generative/losses/perceptual.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
319319
We expect that the input is normalised between [0, 1]. Given the preprocessing performed during the training at
320320
https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html#torchvision.models.ResNet50_Weights,
321321
we make sure that the input and target have 3 channels, and then do Z-Score normalization.
322-
The outputs are normalised across the channels, and we obtain the mean from the spatial dimensions (similar approach to the lpips package).
322+
The outputs are normalised across the channels, and we obtain the mean from the spatial dimensions (similar
323+
approach to the lpips package).
323324
"""
324325
# If input has just 1 channel, repeat channel to have 3 channels
325326
if input.shape[1] == 1 and target.shape[1] == 1:

tests/test_perceptual_loss.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@
1414
import unittest
1515

1616
import torch
17-
from generative.losses import PerceptualLoss
1817
from parameterized import parameterized
1918

19+
from generative.losses import PerceptualLoss
20+
2021
TEST_CASES = [
2122
[{"spatial_dims": 2, "network_type": "squeeze"}, (2, 1, 64, 64), (2, 1, 64, 64)],
2223
[

0 commit comments

Comments
 (0)