Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit 547e2ba

Browse files
committed
Fixed pooling method working with different spatial_dims
1 parent 9f06461 commit 547e2ba

File tree

1 file changed

+6
-10
lines changed

1 file changed

+6
-10
lines changed

generative/networks/nets/patchgan_discriminator.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import torch
1818
import torch.nn as nn
1919
from monai.networks.blocks import Convolution
20-
from monai.networks.layers import Act
20+
from monai.networks.layers import (Act, get_pool_layer)
2121

2222

2323
class MultiScalePatchDiscriminator(nn.Sequential):
@@ -82,15 +82,11 @@ def __init__(
8282
), f"MultiScalePatchDiscriminator: num_d {num_d} must match the number of num_layers_d. {num_layers_d}"
8383

8484
if pooling_method is None:
85-
self.pool = None
86-
elif pooling_method.lower() == "max":
87-
self.pool = nn.MaxPool2d(kernel_size=kernel_size, stride=2)
88-
elif pooling_method.lower() == "avg":
89-
self.pool = nn.AvgPool2d(kernel_size=kernel_size, stride=2)
85+
pool = None
9086
else:
91-
raise ValueError(f"MultiScalePatchDiscriminator: Pooling method {pooling_method} is not supported.")
87+
pool = get_pool_layer((pooling_method, {"kernel_size": kernel_size, "stride": 2}), spatial_dims=spatial_dims)
9288
print(
93-
f"Initialising MultiScalePatchDiscriminator with {self.num_d} discriminators, {self.num_layers_d} layers and pooling method {self.pool.__class__.__name__}."
89+
f"Initialising {spatial_dims}D MultiScalePatchDiscriminator with {self.num_d} discriminators, {self.num_layers_d} layers and pooling method {pool.__class__.__name__}."
9490
)
9591
self.num_channels = num_channels
9692
self.padding = tuple([int((kernel_size - 1) / 2)] * spatial_dims)
@@ -102,7 +98,7 @@ def __init__(
10298
"Your image size is too small to take in up to %d discriminators with num_layers = %d."
10399
"Please reduce num_layers, reduce num_D or enter bigger images." % (i_, num_layers_d_i)
104100
)
105-
if i_ == 0 or self.pool is None:
101+
if i_ == 0 or pool is None:
106102
subnet_d = PatchDiscriminator(
107103
spatial_dims=spatial_dims,
108104
num_channels=self.num_channels,
@@ -119,7 +115,7 @@ def __init__(
119115
)
120116
else:
121117
subnet_d = nn.Sequential(
122-
*[self.pool] * i_,
118+
*[pool] * i_,
123119
PatchDiscriminator(
124120
spatial_dims=spatial_dims,
125121
num_channels=self.num_channels,

0 commit comments

Comments
 (0)