Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions pytorch_msssim/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def _ssim(
data_range: float,
win: Tensor,
size_average: bool = True,
K: Union[Tuple[float, float], List[float]] = (0.01, 0.03)
K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
alpha_beta_gamma: Union[Tuple[float, float, float], List[float]] = (1., 1., 1.),
) -> Tuple[Tensor, Tensor]:
r""" Calculate ssim index for X and Y

Expand All @@ -80,6 +81,7 @@ def _ssim(

C1 = (K1 * data_range) ** 2
C2 = (K2 * data_range) ** 2
C3 = C2 / 2

win = win.to(X.device, dtype=X.dtype)

Expand All @@ -93,9 +95,16 @@ def _ssim(
sigma1_sq = compensation * (gaussian_filter(X * X, win) - mu1_sq)
sigma2_sq = compensation * (gaussian_filter(Y * Y, win) - mu2_sq)
sigma12 = compensation * (gaussian_filter(X * Y, win) - mu1_mu2)
sigma1 = sigma1_sq ** 0.5
sigma2 = sigma2_sq ** 0.5

luminance = (2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)
contrast = (2 * sigma1 * sigma2 + C2) / (sigma1_sq + sigma2_sq + C2)
structure = (sigma12 + C3) / (sigma1 * sigma2 + C3)

cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2) # set alpha=beta=gamma=1
ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map
alpha, beta, gamma = alpha_beta_gamma
cs_map = (contrast ** beta) * (structure ** gamma)
ssim_map = (luminance ** alpha) * cs_map

ssim_per_channel = torch.flatten(ssim_map, 2).mean(-1)
cs = torch.flatten(cs_map, 2).mean(-1)
Expand All @@ -111,6 +120,7 @@ def ssim(
win_sigma: float = 1.5,
win: Optional[Tensor] = None,
K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
alpha_beta_gamma: Union[Tuple[float, float, float], List[float]] = (1., 1., 1.),
nonnegative_ssim: bool = False,
) -> Tensor:
r""" interface of ssim
Expand All @@ -123,6 +133,7 @@ def ssim(
win_sigma: (float, optional): sigma of normal distribution
win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma
K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
alpha_beta_gamma (list or tuple, optional): scalar constants (alpha, beta, gamma). Controls relative strength of luminance, contrast, and structure terms.
nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu

Returns:
Expand Down Expand Up @@ -151,7 +162,7 @@ def ssim(
win = _fspecial_gauss_1d(win_size, win_sigma)
win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1))

ssim_per_channel, cs = _ssim(X, Y, data_range=data_range, win=win, size_average=False, K=K)
ssim_per_channel, cs = _ssim(X, Y, data_range=data_range, win=win, size_average=False, K=K, alpha_beta_gamma=alpha_beta_gamma)
if nonnegative_ssim:
ssim_per_channel = torch.relu(ssim_per_channel)

Expand All @@ -170,7 +181,9 @@ def ms_ssim(
win_sigma: float = 1.5,
win: Optional[Tensor] = None,
weights: Optional[List[float]] = None,
K: Union[Tuple[float, float], List[float]] = (0.01, 0.03)
K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
alpha_beta_gamma: Union[Tuple[float, float, float], List[float]] = (1., 1., 1.),

) -> Tensor:
r""" interface of ms-ssim
Args:
Expand All @@ -183,6 +196,8 @@ def ms_ssim(
win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma
weights (list, optional): weights for different levels
K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
alpha_beta_gamma (list or tuple, optional): scalar constants (alpha, beta, gamma). Controls relative strength of luminance, contrast, and structure terms.

Returns:
torch.Tensor: ms-ssim results
"""
Expand Down Expand Up @@ -225,7 +240,7 @@ def ms_ssim(
levels = weights_tensor.shape[0]
mcs = []
for i in range(levels):
ssim_per_channel, cs = _ssim(X, Y, win=win, data_range=data_range, size_average=False, K=K)
ssim_per_channel, cs = _ssim(X, Y, win=win, data_range=data_range, size_average=False, K=K, alpha_beta_gamma=alpha_beta_gamma)

if i < levels - 1:
mcs.append(torch.relu(cs))
Expand Down Expand Up @@ -253,6 +268,7 @@ def __init__(
channel: int = 3,
spatial_dims: int = 2,
K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
alpha_beta_gamma: Union[Tuple[float, float, float], List[float]] = (1., 1., 1.),
nonnegative_ssim: bool = False,
) -> None:
r""" class for ssim
Expand All @@ -263,6 +279,7 @@ def __init__(
win_sigma: (float, optional): sigma of normal distribution
channel (int, optional): input channels (default: 3)
K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
alpha_beta_gamma (list or tuple, optional): scalar constants (alpha, beta, gamma). Controls relative strength of luminance, contrast, and structure terms.
nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu.
"""

Expand All @@ -272,6 +289,7 @@ def __init__(
self.size_average = size_average
self.data_range = data_range
self.K = K
self.alpha_beta_gamma = alpha_beta_gamma
self.nonnegative_ssim = nonnegative_ssim

def forward(self, X: Tensor, Y: Tensor) -> Tensor:
Expand All @@ -282,6 +300,7 @@ def forward(self, X: Tensor, Y: Tensor) -> Tensor:
size_average=self.size_average,
win=self.win,
K=self.K,
alpha_beta_gamma=self.alpha_beta_gamma,
nonnegative_ssim=self.nonnegative_ssim,
)

Expand All @@ -297,6 +316,7 @@ def __init__(
spatial_dims: int = 2,
weights: Optional[List[float]] = None,
K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
alpha_beta_gamma: Union[Tuple[float, float, float], List[float]] = (1., 1., 1.),
) -> None:
r""" class for ms-ssim
Args:
Expand All @@ -307,6 +327,7 @@ def __init__(
channel (int, optional): input channels (default: 3)
weights (list, optional): weights for different levels
K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
alpha_beta_gamma (list or tuple, optional): scalar constants (alpha, beta, gamma). Controls relative strength of luminance, contrast, and structure terms.
"""

super(MS_SSIM, self).__init__()
Expand All @@ -316,6 +337,7 @@ def __init__(
self.data_range = data_range
self.weights = weights
self.K = K
self.alpha_beta_gamma = alpha_beta_gamma

def forward(self, X: Tensor, Y: Tensor) -> Tensor:
return ms_ssim(
Expand All @@ -326,4 +348,5 @@ def forward(self, X: Tensor, Y: Tensor) -> Tensor:
win=self.win,
weights=self.weights,
K=self.K,
alpha_beta_gamma = self.alpha_beta_gamma
)