1717import torch
1818import torch .nn as nn
1919from monai .networks .blocks import Convolution
20- from monai .networks .layers import Act
20+ from monai .networks .layers import ( Act , get_pool_layer )
2121
2222
2323class MultiScalePatchDiscriminator (nn .Sequential ):
@@ -82,15 +82,11 @@ def __init__(
8282 ), f"MultiScalePatchDiscriminator: num_d { num_d } must match the number of num_layers_d. { num_layers_d } "
8383
8484 if pooling_method is None :
85- self .pool = None
86- elif pooling_method .lower () == "max" :
87- self .pool = nn .MaxPool2d (kernel_size = kernel_size , stride = 2 )
88- elif pooling_method .lower () == "avg" :
89- self .pool = nn .AvgPool2d (kernel_size = kernel_size , stride = 2 )
85+ pool = None
9086 else :
91- raise ValueError ( f"MultiScalePatchDiscriminator: Pooling method { pooling_method } is not supported." )
87+ pool = get_pool_layer (( pooling_method , { "kernel_size" : kernel_size , "stride" : 2 }), spatial_dims = spatial_dims )
9288 print (
93- f"Initialising MultiScalePatchDiscriminator with { self .num_d } discriminators, { self .num_layers_d } layers and pooling method { self . pool .__class__ .__name__ } ."
89+ f"Initialising { spatial_dims } D MultiScalePatchDiscriminator with { self .num_d } discriminators, { self .num_layers_d } layers and pooling method { pool .__class__ .__name__ } ."
9490 )
9591 self .num_channels = num_channels
9692 self .padding = tuple ([int ((kernel_size - 1 ) / 2 )] * spatial_dims )
@@ -102,7 +98,7 @@ def __init__(
10298 "Your image size is too small to take in up to %d discriminators with num_layers = %d."
10399 "Please reduce num_layers, reduce num_D or enter bigger images." % (i_ , num_layers_d_i )
104100 )
105- if i_ == 0 or self . pool is None :
101+ if i_ == 0 or pool is None :
106102 subnet_d = PatchDiscriminator (
107103 spatial_dims = spatial_dims ,
108104 num_channels = self .num_channels ,
@@ -119,7 +115,7 @@ def __init__(
119115 )
120116 else :
121117 subnet_d = nn .Sequential (
122- * [self . pool ] * i_ ,
118+ * [pool ] * i_ ,
123119 PatchDiscriminator (
124120 spatial_dims = spatial_dims ,
125121 num_channels = self .num_channels ,
0 commit comments