diff --git a/pytorch_msssim/ssim.py b/pytorch_msssim/ssim.py index 16380e2..10103b2 100644 --- a/pytorch_msssim/ssim.py +++ b/pytorch_msssim/ssim.py @@ -60,7 +60,8 @@ def _ssim( data_range: float, win: Tensor, size_average: bool = True, - K: Union[Tuple[float, float], List[float]] = (0.01, 0.03) + K: Union[Tuple[float, float], List[float]] = (0.01, 0.03), + mask: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor]: r""" Calculate ssim index for X and Y @@ -70,6 +71,7 @@ def _ssim( data_range (float or int): value range of input images. (usually 1.0 or 255) win (torch.Tensor): 1-D gauss kernel size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar + mask (torch.Tensor): boolean mask same size as X and Y Returns: Tuple[torch.Tensor, torch.Tensor]: ssim results. @@ -97,6 +99,10 @@ def _ssim( cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2) # set alpha=beta=gamma=1 ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map + if mask is not None: + ssim = torch.masked_select(ssim_map, mask).mean() + cs = torch.masked_select(cs_map, mask).mean() + return ssim, cs ssim_per_channel = torch.flatten(ssim_map, 2).mean(-1) cs = torch.flatten(cs_map, 2).mean(-1) return ssim_per_channel, cs @@ -112,6 +118,7 @@ def ssim( win: Optional[Tensor] = None, K: Union[Tuple[float, float], List[float]] = (0.01, 0.03), nonnegative_ssim: bool = False, + mask: Optional[Tensor] = None, ) -> Tensor: r""" interface of ssim Args: @@ -124,6 +131,7 @@ def ssim( win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results. nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu + mask (torch.Tensor): boolean mask same size as X and Y Returns: torch.Tensor: ssim results @@ -131,9 +139,19 @@ def ssim( if not X.shape == Y.shape: raise ValueError(f"Input images should have the same dimensions, but got {X.shape} and {Y.shape}.") + if mask is not None and mask.shape != X.shape: + raise ValueError(f"Input mask should have the same dimensions as input images, but got {mask.shape} and {X.shape}.") + for d in range(len(X.shape) - 1, 1, -1): X = X.squeeze(dim=d) Y = Y.squeeze(dim=d) + if mask is not None: + mask = mask.squeeze(dim=d) + + if mask is not None: + assert size_average is True, "per channel ssim is not available if mask exist" + margin = win_size // 2 + mask = mask[..., margin:-margin, margin:-margin] if len(X.shape) not in (4, 5): raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {X.shape}") @@ -151,11 +169,13 @@ def ssim( win = _fspecial_gauss_1d(win_size, win_sigma) win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1)) - ssim_per_channel, cs = _ssim(X, Y, data_range=data_range, win=win, size_average=False, K=K) + ssim_per_channel, cs = _ssim(X, Y, data_range=data_range, win=win, size_average=False, K=K, mask=mask) if nonnegative_ssim: ssim_per_channel = torch.relu(ssim_per_channel) - if size_average: + if mask is not None: + return ssim_per_channel + elif size_average: return ssim_per_channel.mean() else: return ssim_per_channel.mean(1) @@ -274,7 +294,7 @@ def __init__( self.K = K self.nonnegative_ssim = nonnegative_ssim - def forward(self, X: Tensor, Y: Tensor) -> Tensor: + def forward(self, X: Tensor, Y: Tensor, mask: Optional[Tensor]) -> Tensor: return ssim( X, Y, @@ -283,6 +303,7 @@ def forward(self, X: Tensor, Y: Tensor) -> Tensor: win=self.win, K=self.K, nonnegative_ssim=self.nonnegative_ssim, + mask=mask, )