Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit 00e3c8e

Browse files
committed
Added some unit tests for new functionalities
1 parent 547e2ba commit 00e3c8e

File tree

1 file changed

+80
-2
lines changed

1 file changed

+80
-2
lines changed

tests/test_patch_gan.py

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,65 @@
5858
[(1, 1, 32, 64, 32), (1, 1, 4, 8, 4)],
5959
[4, 7],
6060
]
61+
TEST_3D_POOL = [
62+
{
63+
"num_d": 2,
64+
"num_layers_d": 3,
65+
"spatial_dims": 3,
66+
"num_channels": 8,
67+
"in_channels": 3,
68+
"out_channels": 1,
69+
"kernel_size": 3,
70+
"pooling_method": "max",
71+
"activation": "LEAKYRELU",
72+
"norm": "instance",
73+
"bias": False,
74+
"dropout": 0.1,
75+
"minimum_size_im": 256,
76+
},
77+
torch.rand([1, 3, 256, 512, 256]),
78+
[(1, 1, 32, 64, 32), (1, 1, 16, 32, 16)],
79+
[4, 4],
80+
]
81+
TEST_2D_POOL = [
82+
{
83+
"num_d": 4,
84+
"num_layers_d": 3,
85+
"spatial_dims": 2,
86+
"num_channels": 8,
87+
"in_channels": 3,
88+
"out_channels": 1,
89+
"kernel_size": 3,
90+
"pooling_method": "avg",
91+
"activation": "LEAKYRELU",
92+
"norm": "instance",
93+
"bias": False,
94+
"dropout": 0.1,
95+
"minimum_size_im": 256,
96+
},
97+
torch.rand([1, 3, 256, 512]),
98+
[(1, 1, 32, 64), (1, 1, 16, 32), (1, 1, 8, 16), (1, 1, 4, 8)],
99+
[4, 4, 4, 4],
100+
]
101+
TEST_LAYER_LIST = [
102+
{
103+
"num_d": 3,
104+
"num_layers_d": [3,4,5],
105+
"spatial_dims": 2,
106+
"num_channels": 8,
107+
"in_channels": 3,
108+
"out_channels": 1,
109+
"kernel_size": 3,
110+
"activation": "LEAKYRELU",
111+
"norm": "instance",
112+
"bias": False,
113+
"dropout": 0.1,
114+
"minimum_size_im": 256,
115+
},
116+
torch.rand([1, 3, 256, 512]),
117+
[(1, 1, 32, 64), (1, 1, 16, 32), (1, 1, 8, 16)],
118+
[4, 5, 6],
119+
]
61120
TEST_TOO_SMALL_SIZE = [
62121
{
63122
"num_d": 2,
@@ -74,9 +133,24 @@
74133
"minimum_size_im": 256,
75134
}
76135
]
136+
TEST_MISMATCHED_NUM_LAYERS = [
137+
{
138+
"num_d": 5,
139+
"num_layers_d": [3,4,5],
140+
"spatial_dims": 2,
141+
"num_channels": 8,
142+
"in_channels": 3,
143+
"out_channels": 1,
144+
"kernel_size": 3,
145+
"activation": "LEAKYRELU",
146+
"norm": "instance",
147+
"bias": False,
148+
"dropout": 0.1,
149+
"minimum_size_im": 256,
150+
}
151+
]
77152

78-
CASES = [TEST_2D, TEST_3D]
79-
153+
CASES = [TEST_2D, TEST_3D, TEST_3D_POOL, TEST_2D_POOL, TEST_LAYER_LIST]
80154

81155
class TestPatchGAN(unittest.TestCase):
82156
@parameterized.expand(CASES)
@@ -93,6 +167,10 @@ def test_too_small_shape(self):
93167
with self.assertRaises(AssertionError):
94168
MultiScalePatchDiscriminator(**TEST_TOO_SMALL_SIZE[0])
95169

170+
def test_mismatched_num_layers(self):
171+
with self.assertRaises(AssertionError):
172+
MultiScalePatchDiscriminator(**TEST_MISMATCHED_NUM_LAYERS[0])
173+
96174
def test_script(self):
97175
net = MultiScalePatchDiscriminator(
98176
num_d=2,

0 commit comments

Comments
 (0)