|
68 | 68 | " ------------\n", |
69 | 69 | " Input Image: img_shape\n", |
70 | 70 | " Flattened\n", |
71 | | - " Linear MLP(128, 512, 256, 1)\n", |
72 | | - " Relu activation after every layer except last.\n", |
| 71 | + " Linear MLP(1024, 512, 256, 1)\n", |
| 72 | + " Leaky Relu activation after every layer except last.\n", |
73 | 73 | " Sigmoid activation after last layer to normalize in range 0 to 1\n", |
74 | 74 | " \"\"\"\n", |
75 | 75 | " def __init__(self, img_shape):\n", |
|
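For reference, a minimal sketch of the discriminator the updated docstring describes, assuming img_shape flattens to 1024 features (e.g. a 1×32×32 image) and a LeakyReLU slope of 0.2; the actual __init__ body is not shown in this hunk, so treat the layer definitions below as an illustration rather than the notebook's code.

    import torch
    import torch.nn as nn

    class Discriminator(nn.Module):
        def __init__(self, img_shape):
            super().__init__()
            # input features after flattening, e.g. 1*32*32 = 1024 (assumption)
            in_features = int(torch.prod(torch.tensor(img_shape)))
            self.model = nn.Sequential(
                nn.Flatten(),
                nn.Linear(in_features, 512),
                nn.LeakyReLU(0.2),   # Leaky ReLU after every layer except the last
                nn.Linear(512, 256),
                nn.LeakyReLU(0.2),
                nn.Linear(256, 1),
                nn.Sigmoid(),        # squash the realness score into the range 0 to 1
            )

        def forward(self, img):
            return self.model(img)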
122 | 122 | "gen_optim = torch.optim.Adam(generator.parameters(), lr=2e-4)\n", |
123 | 123 | "disc_optim = torch.optim.Adam(discriminator.parameters(), lr=2e-4)\n", |
124 | 124 | "\n", |
125 | | - "# use gpu if possible\n", |
| 125 | + "# .to(device) moves the networks / models to that device, which is either CPU or the GPU depending on what was detected\n", |
| 126 | + "# if moved to GPU, then the networks can make use of the GPU for computations which is much faster!\n", |
126 | 127 | "generator = generator.to(device)\n", |
127 | 128 | "discriminator = discriminator.to(device)" |
128 | 129 | ] |
|
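The diff never shows how device is defined; a common pattern (an assumption here, not necessarily the notebook's exact line) is to pick the GPU when CUDA is available and fall back to the CPU otherwise:

    import torch

    # use the GPU if one was detected, otherwise stay on the CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

With this in place, the .to(device) calls on the models and on each batch keep everything on the same device, which PyTorch requires for the forward and backward passes.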
140 | 141 | " print(\"Epoch {}\".format(epoch))\n", |
141 | 142 | " avg_g_loss = 0\n", |
142 | 143 | " avg_d_loss = 0\n", |
| 144 | + " \n", |
| 145 | + " # notebook.tqdm is a nice way of displaying progress on a jupyter or colab notebook while we loop over the data in train_dataloader\n", |
143 | 146 | " pbar = notebook.tqdm(train_dataloader, total=len(train_dataloader))\n", |
144 | 147 | " i = 0\n", |
145 | 148 | " for data in pbar:\n", |
146 | 149 | " i += 1\n", |
147 | 150 | " real_images = data[0].to(device)\n", |
148 | 151 | " ### Train Generator ###\n", |
| 152 | + "\n", |
| 153 | + " # .zero_grad() is important in PyTorch. Don't forget it. If you do, the optimizer won't work.\n", |
149 | 154 | " generator_optim.zero_grad()\n", |
150 | 155 | " \n", |
151 | 156 | " latent_input = torch.randn((len(real_images), 1, *latent_shape)).to(device)\n", |
152 | 157 | " fake_images = generator(latent_input)\n", |
153 | 158 | "\n", |
154 | 159 | " fake_res = discriminator(fake_images)\n", |
155 | 160 | " \n", |
| 161 | + " # we penalize the generator for being unable to make the discrminator predict 1s for generated fake images\n", |
156 | 162 | " generator_loss = adversarial_loss(fake_res, torch.ones_like(fake_res))\n", |
| 163 | + "\n", |
| 164 | + " # .backward() computes gradients for the loss function with respect to anything that is not detached\n", |
157 | 165 | " generator_loss.backward()\n", |
| 166 | + " # .step() uses a optimizer to apply the gradients to the model parameters, updating the model to reduce the loss\n", |
158 | 167 | " generator_optim.step()\n", |
159 | 168 | " \n", |
160 | 169 | " ### Train Discriminator ###\n", |
161 | 170 | " discriminator_optim.zero_grad()\n", |
162 | 171 | " \n", |
163 | 172 | " real_res = discriminator(real_images)\n", |
164 | 173 | "\n", |
| 174 | + " # .detach() removes fake_images variable from gradient computation, meaning our \n", |
| 175 | + " # generator is not going to be updated when we use the optimizer\n", |
165 | 176 | " fake_res = discriminator(fake_images.detach())\n", |
166 | 177 | "\n", |
| 178 | + " # we penalize the discriminator for not predicting 1s for real images\n", |
167 | 179 | " discriminator_real_loss = adversarial_loss(real_res, torch.ones_like(real_res))\n", |
| 180 | + " # we penalize the discriminator for not predicting 0s for generated, fake images\n", |
168 | 181 | " discriminator_fake_loss = adversarial_loss(fake_res, torch.zeros_like(real_res))\n", |
169 | 182 | " \n", |
170 | 183 | " discriminator_loss = (discriminator_real_loss + discriminator_fake_loss) / 2\n", |
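The definition of adversarial_loss is not part of this diff; since the discriminator ends in a sigmoid, binary cross-entropy is the natural fit. The standalone sketch below (an assumption, with illustrative values) shows why the generator targets ones while the discriminator targets zeros for the same fake scores:

    import torch
    import torch.nn as nn

    # assumption: adversarial_loss is binary cross-entropy, matching the sigmoid output
    adversarial_loss = nn.BCELoss()

    fake_res = torch.tensor([[0.9], [0.2]])      # discriminator scores for two fake images
    g_target = torch.ones_like(fake_res)         # the generator wants fakes scored as real (1)
    d_target = torch.zeros_like(fake_res)        # the discriminator wants fakes scored as fake (0)

    print(adversarial_loss(fake_res, g_target))  # small when the generator fools the discriminator
    print(adversarial_loss(fake_res, d_target))  # small when the discriminator catches the fakes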
|
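To back up the .zero_grad() comment above, here is a tiny standalone demo (all names are illustrative) showing that PyTorch accumulates gradients across .backward() calls unless they are cleared:

    import torch

    w = torch.ones(1, requires_grad=True)
    opt = torch.optim.SGD([w], lr=0.1)

    (w * 2).sum().backward()
    print(w.grad)   # tensor([2.])

    # without zero_grad(), the next backward() adds to the existing gradient
    (w * 2).sum().backward()
    print(w.grad)   # tensor([4.])  -- accumulated, not replaced

    opt.zero_grad()
    (w * 2).sum().backward()
    print(w.grad)   # tensor([2.])  -- cleared first, so the gradient is fresh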