Skip to content

Commit f02c08a

Browse files
bibhabasumohapatrasdesrozisvfdev-5
authored andcommitted
Warnings fixed for test_ssim.py (#2360)
* first commit * changing channel_axis to -1 for multichannel argument * deprecate permute(0,2,3,1) instead use channel_axis=1 that is (B,C,H,W) Co-authored-by: Sylvain Desroziers <sylvain.desroziers@gmail.com> Co-authored-by: vfdev <vfdev.5@gmail.com>
1 parent 5dcee96 commit f02c08a

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

tests/ignite/metrics/test_ssim.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,14 @@ def _test_ssim(y_pred, y, data_range, kernel_size, sigma, gaussian, use_sample_c
7474
ssim.update((y_pred, y))
7575
ignite_ssim = ssim.compute()
7676

77-
skimg_pred = y_pred.permute(0, 2, 3, 1).cpu().numpy()
77+
skimg_pred = y_pred.cpu().numpy()
7878
skimg_y = skimg_pred * 0.8
7979
skimg_ssim = ski_ssim(
8080
skimg_pred,
8181
skimg_y,
8282
win_size=kernel_size,
8383
sigma=sigma,
84-
multichannel=True,
84+
channel_axis=1,
8585
gaussian_weights=gaussian,
8686
data_range=data_range,
8787
use_sample_covariance=use_sample_covariance,
@@ -135,14 +135,14 @@ def update(engine, i):
135135
assert "ssim" in engine.state.metrics
136136
res = engine.state.metrics["ssim"]
137137

138-
np_pred = y_pred.permute(0, 2, 3, 1).cpu().numpy()
138+
np_pred = y_pred.cpu().numpy()
139139
np_true = np_pred * 0.65
140140
true_res = ski_ssim(
141141
np_pred,
142142
np_true,
143143
win_size=11,
144144
sigma=1.5,
145-
multichannel=True,
145+
channel_axis=1,
146146
gaussian_weights=True,
147147
data_range=1.0,
148148
use_sample_covariance=False,
@@ -159,9 +159,9 @@ def update(engine, i):
159159
assert "ssim" in engine.state.metrics
160160
res = engine.state.metrics["ssim"]
161161

162-
np_pred = y_pred.permute(0, 2, 3, 1).cpu().numpy()
162+
np_pred = y_pred.cpu().numpy()
163163
np_true = np_pred * 0.65
164-
true_res = ski_ssim(np_pred, np_true, win_size=7, multichannel=True, gaussian_weights=False, data_range=1.0)
164+
true_res = ski_ssim(np_pred, np_true, win_size=7, channel_axis=1, gaussian_weights=False, data_range=1.0)
165165

166166
assert pytest.approx(res, abs=tol) == true_res
167167

0 commit comments

Comments
 (0)