diff --git a/pytorch_msssim/ssim.py b/pytorch_msssim/ssim.py
index 16380e2..92272a1 100644
--- a/pytorch_msssim/ssim.py
+++ b/pytorch_msssim/ssim.py
@@ -60,7 +60,8 @@ def _ssim(
     data_range: float,
     win: Tensor,
     size_average: bool = True,
-    K: Union[Tuple[float, float], List[float]] = (0.01, 0.03)
+    K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
+    alpha_beta_gamma: Union[Tuple[float, float, float], List[float]] = (1., 1., 1.),
 ) -> Tuple[Tensor, Tensor]:
 
     r""" Calculate ssim index for X and Y
@@ -80,6 +81,7 @@ def _ssim(
 
     C1 = (K1 * data_range) ** 2
     C2 = (K2 * data_range) ** 2
+    C3 = C2 / 2
 
     win = win.to(X.device, dtype=X.dtype)
 
@@ -93,9 +95,16 @@ def _ssim(
     sigma1_sq = compensation * (gaussian_filter(X * X, win) - mu1_sq)
     sigma2_sq = compensation * (gaussian_filter(Y * Y, win) - mu2_sq)
     sigma12 = compensation * (gaussian_filter(X * Y, win) - mu1_mu2)
+    sigma1 = sigma1_sq ** 0.5
+    sigma2 = sigma2_sq ** 0.5
+
+    luminance = (2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)
+    contrast = (2 * sigma1 * sigma2 + C2) / (sigma1_sq + sigma2_sq + C2)
+    structure = (sigma12 + C3) / (sigma1 * sigma2 + C3)
 
-    cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)  # set alpha=beta=gamma=1
-    ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map
+    alpha, beta, gamma = alpha_beta_gamma
+    cs_map = (contrast ** beta) * (structure ** gamma)
+    ssim_map = (luminance ** alpha) * cs_map
 
     ssim_per_channel = torch.flatten(ssim_map, 2).mean(-1)
     cs = torch.flatten(cs_map, 2).mean(-1)
@@ -111,6 +120,7 @@ def ssim(
     win_sigma: float = 1.5,
     win: Optional[Tensor] = None,
     K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
+    alpha_beta_gamma: Union[Tuple[float, float, float], List[float]] = (1., 1., 1.),
     nonnegative_ssim: bool = False,
 ) -> Tensor:
     r""" interface of ssim
@@ -123,6 +133,7 @@ def ssim(
         win_sigma: (float, optional): sigma of normal distribution
         win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma
         K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
+        alpha_beta_gamma (list or tuple, optional): scalar constants (alpha, beta, gamma). Controls the relative strength of the luminance, contrast, and structure terms.
         nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu
 
     Returns:
@@ -151,7 +162,7 @@ def ssim(
         win = _fspecial_gauss_1d(win_size, win_sigma)
         win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1))
 
-    ssim_per_channel, cs = _ssim(X, Y, data_range=data_range, win=win, size_average=False, K=K)
+    ssim_per_channel, cs = _ssim(X, Y, data_range=data_range, win=win, size_average=False, K=K, alpha_beta_gamma=alpha_beta_gamma)
     if nonnegative_ssim:
         ssim_per_channel = torch.relu(ssim_per_channel)
 
@@ -170,7 +181,9 @@ def ms_ssim(
     win_sigma: float = 1.5,
     win: Optional[Tensor] = None,
     weights: Optional[List[float]] = None,
-    K: Union[Tuple[float, float], List[float]] = (0.01, 0.03)
+    K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
+    alpha_beta_gamma: Union[Tuple[float, float, float], List[float]] = (1., 1., 1.),
+
 ) -> Tensor:
     r""" interface of ms-ssim
     Args:
@@ -183,6 +196,8 @@ def ms_ssim(
         win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma
         weights (list, optional): weights for different levels
         K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
+        alpha_beta_gamma (list or tuple, optional): scalar constants (alpha, beta, gamma). Controls the relative strength of the luminance, contrast, and structure terms.
+
     Returns:
         torch.Tensor: ms-ssim results
     """
@@ -225,7 +240,7 @@ def ms_ssim(
     levels = weights_tensor.shape[0]
     mcs = []
     for i in range(levels):
-        ssim_per_channel, cs = _ssim(X, Y, win=win, data_range=data_range, size_average=False, K=K)
+        ssim_per_channel, cs = _ssim(X, Y, win=win, data_range=data_range, size_average=False, K=K, alpha_beta_gamma=alpha_beta_gamma)
 
         if i < levels - 1:
             mcs.append(torch.relu(cs))
@@ -253,6 +268,7 @@ def __init__(
         channel: int = 3,
         spatial_dims: int = 2,
         K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
+        alpha_beta_gamma: Union[Tuple[float, float, float], List[float]] = (1., 1., 1.),
         nonnegative_ssim: bool = False,
     ) -> None:
         r""" class for ssim
@@ -263,6 +279,7 @@ def __init__(
             win_sigma: (float, optional): sigma of normal distribution
             channel (int, optional): input channels (default: 3)
            K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
+            alpha_beta_gamma (list or tuple, optional): scalar constants (alpha, beta, gamma). Controls the relative strength of the luminance, contrast, and structure terms.
             nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu.
         """
 
@@ -272,6 +289,7 @@ def __init__(
         self.size_average = size_average
         self.data_range = data_range
         self.K = K
+        self.alpha_beta_gamma = alpha_beta_gamma
         self.nonnegative_ssim = nonnegative_ssim
 
     def forward(self, X: Tensor, Y: Tensor) -> Tensor:
@@ -282,6 +300,7 @@ def forward(self, X: Tensor, Y: Tensor) -> Tensor:
             size_average=self.size_average,
             win=self.win,
             K=self.K,
+            alpha_beta_gamma=self.alpha_beta_gamma,
             nonnegative_ssim=self.nonnegative_ssim,
         )
 
@@ -297,6 +316,7 @@ def __init__(
         spatial_dims: int = 2,
         weights: Optional[List[float]] = None,
         K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
+        alpha_beta_gamma: Union[Tuple[float, float, float], List[float]] = (1., 1., 1.),
     ) -> None:
         r""" class for ms-ssim
         Args:
@@ -307,6 +327,7 @@ def __init__(
             channel (int, optional): input channels (default: 3)
             weights (list, optional): weights for different levels
             K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
+            alpha_beta_gamma (list or tuple, optional): scalar constants (alpha, beta, gamma). Controls the relative strength of the luminance, contrast, and structure terms.
         """
 
         super(MS_SSIM, self).__init__()
@@ -316,6 +337,7 @@ def __init__(
         self.data_range = data_range
         self.weights = weights
         self.K = K
+        self.alpha_beta_gamma = alpha_beta_gamma
 
     def forward(self, X: Tensor, Y: Tensor) -> Tensor:
         return ms_ssim(
@@ -326,4 +348,5 @@ def forward(self, X: Tensor, Y: Tensor) -> Tensor:
             win=self.win,
             weights=self.weights,
             K=self.K,
+            alpha_beta_gamma=self.alpha_beta_gamma,
         )
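
Not part of the patch, for reviewers: a minimal usage sketch of the new alpha_beta_gamma parameter, assuming the patched package is importable as pytorch_msssim and using random placeholder tensors. Because the patch sets C3 = C2 / 2, the default exponents (1., 1., 1.) make contrast * structure collapse back to the original cs_map, so the previous behaviour is preserved analytically.

import torch
from pytorch_msssim import ssim, SSIM

# Illustrative inputs only: two batches of images scaled to [0, 1].
X = torch.rand(4, 3, 256, 256)
Y = (X + 0.05 * torch.randn_like(X)).clamp(0, 1)

# Default exponents: matches the pre-patch SSIM formulation.
val_default = ssim(X, Y, data_range=1.0, alpha_beta_gamma=(1., 1., 1.))

# Weight the structure term more heavily than luminance and contrast.
val_weighted = ssim(X, Y, data_range=1.0, alpha_beta_gamma=(1., 1., 2.))

# Module form, e.g. for a training loss with fixed exponents.
ssim_module = SSIM(data_range=1.0, channel=3, alpha_beta_gamma=(1., 1., 2.))
loss = 1 - ssim_module(X, Y)

print(val_default.item(), val_weighted.item(), loss.item())

One caveat worth noting in review: sigma1_sq and sigma2_sq can dip slightly below zero numerically, so the new square roots (and a non-integer gamma applied to a negative structure term) can yield NaN; the existing docstring advice about raising K2 applies here as well.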