Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions pytorch_msssim/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def _ssim(
data_range: float,
win: Tensor,
size_average: bool = True,
K: Union[Tuple[float, float], List[float]] = (0.01, 0.03)
K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
r""" Calculate ssim index for X and Y

Expand All @@ -70,6 +71,7 @@ def _ssim(
data_range (float or int): value range of input images. (usually 1.0 or 255)
win (torch.Tensor): 1-D gauss kernel
size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
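mask (torch.Tensor, optional): boolean mask broadcastable to the ssim/cs maps (ssim() passes one cropped by win_size // 2 on the last two dims); if given, only positions where the mask is True are averaged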

Returns:
Tuple[torch.Tensor, torch.Tensor]: ssim results.
Expand Down Expand Up @@ -97,6 +99,10 @@ def _ssim(
cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2) # set alpha=beta=gamma=1
ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map

if mask is not None:
ssim = torch.masked_select(ssim_map, mask).mean()
cs = torch.masked_select(cs_map, mask).mean()
return ssim, cs
ssim_per_channel = torch.flatten(ssim_map, 2).mean(-1)
cs = torch.flatten(cs_map, 2).mean(-1)
return ssim_per_channel, cs
Expand All @@ -112,6 +118,7 @@ def ssim(
win: Optional[Tensor] = None,
K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
nonnegative_ssim: bool = False,
mask: Optional[Tensor] = None,
) -> Tensor:
r""" interface of ssim
Args:
Expand All @@ -124,16 +131,27 @@ def ssim(
win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma
K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu
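mask (torch.Tensor, optional): boolean mask with the same shape as X and Y; if given, only positions where the mask is True are averaged (requires size_average=True)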

Returns:
torch.Tensor: ssim results
"""
if not X.shape == Y.shape:
raise ValueError(f"Input images should have the same dimensions, but got {X.shape} and {Y.shape}.")

if mask is not None and mask.shape != X.shape:
raise ValueError(f"Input mask should have the same dimensions as input images, but got {mask.shape} and {X.shape}.")

for d in range(len(X.shape) - 1, 1, -1):
X = X.squeeze(dim=d)
Y = Y.squeeze(dim=d)
if mask is not None:
mask = mask.squeeze(dim=d)

if mask is not None:
assert size_average is True, "per channel ssim is not available if mask exist"
margin = win_size // 2
mask = mask[..., margin:-margin, margin:-margin]

if len(X.shape) not in (4, 5):
raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {X.shape}")
Expand All @@ -151,11 +169,13 @@ def ssim(
win = _fspecial_gauss_1d(win_size, win_sigma)
win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1))

ssim_per_channel, cs = _ssim(X, Y, data_range=data_range, win=win, size_average=False, K=K)
ssim_per_channel, cs = _ssim(X, Y, data_range=data_range, win=win, size_average=False, K=K, mask=mask)
if nonnegative_ssim:
ssim_per_channel = torch.relu(ssim_per_channel)

if size_average:
if mask is not None:
return ssim_per_channel
elif size_average:
return ssim_per_channel.mean()
else:
return ssim_per_channel.mean(1)
Expand Down Expand Up @@ -274,7 +294,7 @@ def __init__(
self.K = K
self.nonnegative_ssim = nonnegative_ssim

def forward(self, X: Tensor, Y: Tensor) -> Tensor:
def forward(self, X: Tensor, Y: Tensor, mask: Optional[Tensor]) -> Tensor:
return ssim(
X,
Y,
Expand All @@ -283,6 +303,7 @@ def forward(self, X: Tensor, Y: Tensor) -> Tensor:
win=self.win,
K=self.K,
nonnegative_ssim=self.nonnegative_ssim,
mask=mask,
)


Expand Down