Commit 44bba3e

comments
1 parent 867fc26 commit 44bba3e

1 file changed: +16 -3 lines changed


SP21/GAN/vanilla_gan.ipynb

Lines changed: 16 additions & 3 deletions
@@ -68,8 +68,8 @@
 " ------------\n",
 " Input Image: img_shape\n",
 " Flattened\n",
-" Linear MLP(128, 512, 256, 1)\n",
-" Relu activation after every layer except last.\n",
+" Linear MLP(1024, 512, 256, 1)\n",
+" Leaky Relu activation after every layer except last.\n",
 " Sigmoid activation after last layer to normalize in range 0 to 1\n",
 " \"\"\"\n",
 " def __init__(self, img_shape):\n",
@@ -122,7 +122,8 @@
 "gen_optim = torch.optim.Adam(generator.parameters(), lr=2e-4)\n",
 "disc_optim = torch.optim.Adam(discriminator.parameters(), lr=2e-4)\n",
 "\n",
-"# use gpu if possible\n",
+"# .to(device) moves the networks / models to that device, which is either CPU or the GPU depending on what was detected\n",
+"# if moved to GPU, then the networks can make use of the GPU for computations which is much faster!\n",
 "generator = generator.to(device)\n",
 "discriminator = discriminator.to(device)"
 ]
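The new comments refer to a device variable that is defined in an earlier cell, outside this diff. A minimal sketch of how such a variable is commonly set up (the exact cell contents are an assumption, not part of the commit):

import torch

# pick the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# for an nn.Module, .to(device) moves the parameters in place and returns the module;
# for a tensor it returns a copy on the target device, so reassignment works for both
model = torch.nn.Linear(4, 2).to(device)
batch = torch.randn(8, 4).to(device)
out = model(batch)    # both operands now live on the same device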
@@ -140,31 +141,43 @@
 " print(\"Epoch {}\".format(epoch))\n",
 " avg_g_loss = 0\n",
 " avg_d_loss = 0\n",
+" \n",
+" # notebook.tqdm is a nice way of displaying progress on a jupyter or colab notebook while we loop over the data in train_dataloader\n",
 " pbar = notebook.tqdm(train_dataloader, total=len(train_dataloader))\n",
 " i = 0\n",
 " for data in pbar:\n",
 " i += 1\n",
 " real_images = data[0].to(device)\n",
 " ### Train Generator ###\n",
+"\n",
+" # .zero_grad() is important in PyTorch. Don't forget it. If you do, gradients from previous iterations accumulate and the updates go wrong.\n",
 " generator_optim.zero_grad()\n",
 " \n",
 " latent_input = torch.randn((len(real_images), 1, *latent_shape)).to(device)\n",
 " fake_images = generator(latent_input)\n",
 "\n",
 " fake_res = discriminator(fake_images)\n",
 " \n",
+" # we penalize the generator for being unable to make the discriminator predict 1s for generated fake images\n",
 " generator_loss = adversarial_loss(fake_res, torch.ones_like(fake_res))\n",
+"\n",
+" # .backward() computes gradients of the loss with respect to anything that is not detached\n",
 " generator_loss.backward()\n",
+" # .step() uses an optimizer to apply the gradients to the model parameters, updating the model to reduce the loss\n",
 " generator_optim.step()\n",
 " \n",
 " ### Train Discriminator ###\n",
 " discriminator_optim.zero_grad()\n",
 " \n",
 " real_res = discriminator(real_images)\n",
 "\n",
+" # .detach() removes the fake_images tensor from gradient computation, meaning our \n",
+" # generator is not going to be updated when we use the optimizer\n",
 " fake_res = discriminator(fake_images.detach())\n",
 "\n",
+" # we penalize the discriminator for not predicting 1s for real images\n",
 " discriminator_real_loss = adversarial_loss(real_res, torch.ones_like(real_res))\n",
+" # we penalize the discriminator for not predicting 0s for generated, fake images\n",
 " discriminator_fake_loss = adversarial_loss(fake_res, torch.zeros_like(real_res))\n",
 " \n",
 " discriminator_loss = (discriminator_real_loss + discriminator_fake_loss) / 2\n",
