Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit 9f06461

Browse files
committed
Added option to initialize MultiScalePatchGAN with pooling layers
1 parent d294bbf commit 9f06461

File tree

1 file changed

+59
-16
lines changed

1 file changed

+59
-16
lines changed

generative/networks/nets/patchgan_discriminator.py

Lines changed: 59 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ class MultiScalePatchDiscriminator(nn.Sequential):
3838
spatial_dims: number of spatial dimensions (1D, 2D etc.)
3939
num_channels: number of filters in the first convolutional layer (double of the value is taken from then on)
4040
in_channels: number of input channels
41+
pooling_method: pooling method to be applied before each discriminator after the first.
42+
If None, the number of layers is multiplied by the number of discriminators.
4143
out_channels: number of output channels in each discriminator
4244
kernel_size: kernel size of the convolution layers
4345
activation: activation layer type
@@ -52,10 +54,11 @@ class MultiScalePatchDiscriminator(nn.Sequential):
5254
def __init__(
5355
self,
5456
num_d: int,
55-
num_layers_d: int,
57+
num_layers_d: int | list[int],
5658
spatial_dims: int,
5759
num_channels: int,
5860
in_channels: int,
61+
pooling_method: str = None,
5962
out_channels: int = 1,
6063
kernel_size: int = 4,
6164
activation: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}),
@@ -67,31 +70,71 @@ def __init__(
6770
) -> None:
6871
super().__init__()
6972
self.num_d = num_d
73+
if isinstance(num_layers_d, int) and pooling_method is None:
74+
# if pooling_method is None, calculate the number of layers for each discriminator by multiplying by the number of discriminators
75+
num_layers_d = [num_layers_d * i for i in range(1, num_d + 1)]
76+
elif isinstance(num_layers_d, int) and pooling_method is not None:
77+
# if pooling_method is not None, the number of layers is the same for all discriminators
78+
num_layers_d = [num_layers_d] * num_d
7079
self.num_layers_d = num_layers_d
80+
assert (
81+
len(self.num_layers_d) == self.num_d
82+
), f"MultiScalePatchDiscriminator: num_d {num_d} must match the number of num_layers_d. {num_layers_d}"
83+
84+
if pooling_method is None:
85+
self.pool = None
86+
elif pooling_method.lower() == "max":
87+
self.pool = nn.MaxPool2d(kernel_size=kernel_size, stride=2)
88+
elif pooling_method.lower() == "avg":
89+
self.pool = nn.AvgPool2d(kernel_size=kernel_size, stride=2)
90+
else:
91+
raise ValueError(f"MultiScalePatchDiscriminator: Pooling method {pooling_method} is not supported.")
92+
print(
93+
f"Initialising MultiScalePatchDiscriminator with {self.num_d} discriminators, {self.num_layers_d} layers and pooling method {self.pool.__class__.__name__}."
94+
)
7195
self.num_channels = num_channels
7296
self.padding = tuple([int((kernel_size - 1) / 2)] * spatial_dims)
7397
for i_ in range(self.num_d):
74-
num_layers_d_i = self.num_layers_d * (i_ + 1)
98+
num_layers_d_i = self.num_layers_d[i_]
7599
output_size = float(minimum_size_im) / (2**num_layers_d_i)
76100
if output_size < 1:
77101
raise AssertionError(
78102
"Your image size is too small to take in up to %d discriminators with num_layers = %d."
79103
"Please reduce num_layers, reduce num_D or enter bigger images." % (i_, num_layers_d_i)
80104
)
81-
subnet_d = PatchDiscriminator(
82-
spatial_dims=spatial_dims,
83-
num_channels=self.num_channels,
84-
in_channels=in_channels,
85-
out_channels=out_channels,
86-
num_layers_d=num_layers_d_i,
87-
kernel_size=kernel_size,
88-
activation=activation,
89-
norm=norm,
90-
bias=bias,
91-
padding=self.padding,
92-
dropout=dropout,
93-
last_conv_kernel_size=last_conv_kernel_size,
94-
)
105+
if i_ == 0 or self.pool is None:
106+
subnet_d = PatchDiscriminator(
107+
spatial_dims=spatial_dims,
108+
num_channels=self.num_channels,
109+
in_channels=in_channels,
110+
out_channels=out_channels,
111+
num_layers_d=num_layers_d_i,
112+
kernel_size=kernel_size,
113+
activation=activation,
114+
norm=norm,
115+
bias=bias,
116+
padding=self.padding,
117+
dropout=dropout,
118+
last_conv_kernel_size=last_conv_kernel_size,
119+
)
120+
else:
121+
subnet_d = nn.Sequential(
122+
*[self.pool] * i_,
123+
PatchDiscriminator(
124+
spatial_dims=spatial_dims,
125+
num_channels=self.num_channels,
126+
in_channels=in_channels,
127+
out_channels=out_channels,
128+
num_layers_d=num_layers_d_i,
129+
kernel_size=kernel_size,
130+
activation=activation,
131+
norm=norm,
132+
bias=bias,
133+
padding=self.padding,
134+
dropout=dropout,
135+
last_conv_kernel_size=last_conv_kernel_size,
136+
),
137+
)
95138

96139
self.add_module("discriminator_%d" % i_, subnet_d)
97140

0 commit comments

Comments
 (0)