"""
Discriminator and Generator implementation from DCGAN paper

Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-11-01: Initial coding
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
"""
48
59import torch
@@ -11,9 +15,7 @@ def __init__(self, channels_img, features_d):
1115 super (Discriminator , self ).__init__ ()
1216 self .disc = nn .Sequential (
1317 # input: N x channels_img x 64 x 64
14- nn .Conv2d (
15- channels_img , features_d , kernel_size = 4 , stride = 2 , padding = 1
16- ),
18+ nn .Conv2d (channels_img , features_d , kernel_size = 4 , stride = 2 , padding = 1 ),
1719 nn .LeakyReLU (0.2 ),
1820 # _block(in_channels, out_channels, kernel_size, stride, padding)
1921 self ._block (features_d , features_d * 2 , 4 , 2 , 1 ),
@@ -34,7 +36,7 @@ def _block(self, in_channels, out_channels, kernel_size, stride, padding):
3436 padding ,
3537 bias = False ,
3638 ),
37- #nn.BatchNorm2d(out_channels),
39+ # nn.BatchNorm2d(out_channels),
3840 nn .LeakyReLU (0.2 ),
3941 )
4042
@@ -68,7 +70,7 @@ def _block(self, in_channels, out_channels, kernel_size, stride, padding):
6870 padding ,
6971 bias = False ,
7072 ),
71- #nn.BatchNorm2d(out_channels),
73+ # nn.BatchNorm2d(out_channels),
7274 nn .ReLU (),
7375 )
7476
@@ -82,6 +84,7 @@ def initialize_weights(model):
8284 if isinstance (m , (nn .Conv2d , nn .ConvTranspose2d , nn .BatchNorm2d )):
8385 nn .init .normal_ (m .weight .data , 0.0 , 0.02 )
8486
87+
8588def test ():
8689 N , in_channels , H , W = 8 , 3 , 64 , 64
8790 noise_dim = 100
@@ -91,6 +94,8 @@ def test():
9194 gen = Generator (noise_dim , in_channels , 8 )
9295 z = torch .randn ((N , noise_dim , 1 , 1 ))
9396 assert gen (z ).shape == (N , in_channels , H , W ), "Generator test failed"
97+ print ("Success, tests passed!" )
9498
9599
96- # test()
100+ if __name__ == "__main__" :
101+ test ()