@@ -38,6 +38,8 @@ class MultiScalePatchDiscriminator(nn.Sequential):
3838 spatial_dims: number of spatial dimensions (1D, 2D etc.)
3939 num_channels: number of filters in the first convolutional layer (double of the value is taken from then on)
4040 in_channels: number of input channels
41+             pooling_method: pooling method applied ``i`` times to downsample the input before the ``i``-th discriminator (the first discriminator receives no pooling).
42+                 If None, no pooling is used; instead the ``i``-th discriminator is built with ``num_layers_d * (i + 1)`` layers, so increasing depth provides the coarser scales.
4143 out_channels: number of output channels in each discriminator
4244 kernel_size: kernel size of the convolution layers
4345 activation: activation layer type
@@ -52,10 +54,11 @@ class MultiScalePatchDiscriminator(nn.Sequential):
5254 def __init__ (
5355 self ,
5456 num_d : int ,
55- num_layers_d : int ,
57+ num_layers_d : int | list [ int ] ,
5658 spatial_dims : int ,
5759 num_channels : int ,
5860 in_channels : int ,
61+ pooling_method : str = None ,
5962 out_channels : int = 1 ,
6063 kernel_size : int = 4 ,
6164 activation : str | tuple = (Act .LEAKYRELU , {"negative_slope" : 0.2 }),
@@ -67,31 +70,71 @@ def __init__(
6770 ) -> None :
6871 super ().__init__ ()
6972 self .num_d = num_d
73+ if isinstance (num_layers_d , int ) and pooling_method is None :
74+ # if pooling_method is None, calculate the number of layers for each discriminator by multiplying by the number of discriminators
75+ num_layers_d = [num_layers_d * i for i in range (1 , num_d + 1 )]
76+ elif isinstance (num_layers_d , int ) and pooling_method is not None :
77+ # if pooling_method is not None, the number of layers is the same for all discriminators
78+ num_layers_d = [num_layers_d ] * num_d
7079 self .num_layers_d = num_layers_d
80+ assert (
81+ len (self .num_layers_d ) == self .num_d
82+ ), f"MultiScalePatchDiscriminator: num_d { num_d } must match the number of num_layers_d. { num_layers_d } "
83+
84+ if pooling_method is None :
85+ self .pool = None
86+ elif pooling_method .lower () == "max" :
87+ self .pool = nn .MaxPool2d (kernel_size = kernel_size , stride = 2 )
88+ elif pooling_method .lower () == "avg" :
89+ self .pool = nn .AvgPool2d (kernel_size = kernel_size , stride = 2 )
90+ else :
91+ raise ValueError (f"MultiScalePatchDiscriminator: Pooling method { pooling_method } is not supported." )
92+ print (
93+ f"Initialising MultiScalePatchDiscriminator with { self .num_d } discriminators, { self .num_layers_d } layers and pooling method { self .pool .__class__ .__name__ } ."
94+ )
7195 self .num_channels = num_channels
7296 self .padding = tuple ([int ((kernel_size - 1 ) / 2 )] * spatial_dims )
7397 for i_ in range (self .num_d ):
74- num_layers_d_i = self .num_layers_d * ( i_ + 1 )
98+ num_layers_d_i = self .num_layers_d [ i_ ]
7599 output_size = float (minimum_size_im ) / (2 ** num_layers_d_i )
76100 if output_size < 1 :
77101 raise AssertionError (
78102 "Your image size is too small to take in up to %d discriminators with num_layers = %d."
79103 "Please reduce num_layers, reduce num_D or enter bigger images." % (i_ , num_layers_d_i )
80104 )
81- subnet_d = PatchDiscriminator (
82- spatial_dims = spatial_dims ,
83- num_channels = self .num_channels ,
84- in_channels = in_channels ,
85- out_channels = out_channels ,
86- num_layers_d = num_layers_d_i ,
87- kernel_size = kernel_size ,
88- activation = activation ,
89- norm = norm ,
90- bias = bias ,
91- padding = self .padding ,
92- dropout = dropout ,
93- last_conv_kernel_size = last_conv_kernel_size ,
94- )
105+ if i_ == 0 or self .pool is None :
106+ subnet_d = PatchDiscriminator (
107+ spatial_dims = spatial_dims ,
108+ num_channels = self .num_channels ,
109+ in_channels = in_channels ,
110+ out_channels = out_channels ,
111+ num_layers_d = num_layers_d_i ,
112+ kernel_size = kernel_size ,
113+ activation = activation ,
114+ norm = norm ,
115+ bias = bias ,
116+ padding = self .padding ,
117+ dropout = dropout ,
118+ last_conv_kernel_size = last_conv_kernel_size ,
119+ )
120+ else :
121+ subnet_d = nn .Sequential (
122+ * [self .pool ] * i_ ,
123+ PatchDiscriminator (
124+ spatial_dims = spatial_dims ,
125+ num_channels = self .num_channels ,
126+ in_channels = in_channels ,
127+ out_channels = out_channels ,
128+ num_layers_d = num_layers_d_i ,
129+ kernel_size = kernel_size ,
130+ activation = activation ,
131+ norm = norm ,
132+ bias = bias ,
133+ padding = self .padding ,
134+ dropout = dropout ,
135+ last_conv_kernel_size = last_conv_kernel_size ,
136+ ),
137+ )
95138
96139 self .add_module ("discriminator_%d" % i_ , subnet_d )
97140
0 commit comments