5858 [(1 , 1 , 32 , 64 , 32 ), (1 , 1 , 4 , 8 , 4 )],
5959 [4 , 7 ],
6060]
61+ TEST_3D_POOL = [
62+ {
63+ "num_d" : 2 ,
64+ "num_layers_d" : 3 ,
65+ "spatial_dims" : 3 ,
66+ "num_channels" : 8 ,
67+ "in_channels" : 3 ,
68+ "out_channels" : 1 ,
69+ "kernel_size" : 3 ,
70+ "pooling_method" : "max" ,
71+ "activation" : "LEAKYRELU" ,
72+ "norm" : "instance" ,
73+ "bias" : False ,
74+ "dropout" : 0.1 ,
75+ "minimum_size_im" : 256 ,
76+ },
77+ torch .rand ([1 , 3 , 256 , 512 , 256 ]),
78+ [(1 , 1 , 32 , 64 , 32 ), (1 , 1 , 16 , 32 , 16 )],
79+ [4 , 4 ],
80+ ]
81+ TEST_2D_POOL = [
82+ {
83+ "num_d" : 4 ,
84+ "num_layers_d" : 3 ,
85+ "spatial_dims" : 2 ,
86+ "num_channels" : 8 ,
87+ "in_channels" : 3 ,
88+ "out_channels" : 1 ,
89+ "kernel_size" : 3 ,
90+ "pooling_method" : "avg" ,
91+ "activation" : "LEAKYRELU" ,
92+ "norm" : "instance" ,
93+ "bias" : False ,
94+ "dropout" : 0.1 ,
95+ "minimum_size_im" : 256 ,
96+ },
97+ torch .rand ([1 , 3 , 256 , 512 ]),
98+ [(1 , 1 , 32 , 64 ), (1 , 1 , 16 , 32 ), (1 , 1 , 8 , 16 ), (1 , 1 , 4 , 8 )],
99+ [4 , 4 , 4 , 4 ],
100+ ]
101+ TEST_LAYER_LIST = [
102+ {
103+ "num_d" : 3 ,
104+ "num_layers_d" : [3 ,4 ,5 ],
105+ "spatial_dims" : 2 ,
106+ "num_channels" : 8 ,
107+ "in_channels" : 3 ,
108+ "out_channels" : 1 ,
109+ "kernel_size" : 3 ,
110+ "activation" : "LEAKYRELU" ,
111+ "norm" : "instance" ,
112+ "bias" : False ,
113+ "dropout" : 0.1 ,
114+ "minimum_size_im" : 256 ,
115+ },
116+ torch .rand ([1 , 3 , 256 , 512 ]),
117+ [(1 , 1 , 32 , 64 ), (1 , 1 , 16 , 32 ), (1 , 1 , 8 , 16 )],
118+ [4 , 5 , 6 ],
119+ ]
61120TEST_TOO_SMALL_SIZE = [
62121 {
63122 "num_d" : 2 ,
74133 "minimum_size_im" : 256 ,
75134 }
76135]
136+ TEST_MISMATCHED_NUM_LAYERS = [
137+ {
138+ "num_d" : 5 ,
139+ "num_layers_d" : [3 ,4 ,5 ],
140+ "spatial_dims" : 2 ,
141+ "num_channels" : 8 ,
142+ "in_channels" : 3 ,
143+ "out_channels" : 1 ,
144+ "kernel_size" : 3 ,
145+ "activation" : "LEAKYRELU" ,
146+ "norm" : "instance" ,
147+ "bias" : False ,
148+ "dropout" : 0.1 ,
149+ "minimum_size_im" : 256 ,
150+ }
151+ ]
77152
78- CASES = [TEST_2D , TEST_3D ]
79-
153+ CASES = [TEST_2D , TEST_3D , TEST_3D_POOL , TEST_2D_POOL , TEST_LAYER_LIST ]
80154
81155class TestPatchGAN (unittest .TestCase ):
82156 @parameterized .expand (CASES )
@@ -93,6 +167,10 @@ def test_too_small_shape(self):
93167 with self .assertRaises (AssertionError ):
94168 MultiScalePatchDiscriminator (** TEST_TOO_SMALL_SIZE [0 ])
95169
170+ def test_mismatched_num_layers (self ):
171+ with self .assertRaises (AssertionError ):
172+ MultiScalePatchDiscriminator (** TEST_MISMATCHED_NUM_LAYERS [0 ])
173+
96174 def test_script (self ):
97175 net = MultiScalePatchDiscriminator (
98176 num_d = 2 ,
0 commit comments