Skip to content

Commit 536b7b6

Browse files
committed
add ddpm utkface data load/train example
1 parent 24946ec commit 536b7b6

File tree

7 files changed

+83
-15
lines changed

7 files changed

+83
-15
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,3 +136,5 @@ dmypy.json
136136
*.npy
137137
*.png
138138
*.jpeg
139+
*.zip
140+
*.nt

README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,18 @@ Some [examples](examples/) were trained on the [MNIST](https://pjreddie.com/proj
145145

146146
#### More details about some of them:
147147

148+
<details>
149+
<summary>Denoising Diffusion Probabilistic Model (DDPM)</summary>
150+
151+
<p align="center">
152+
<img src="generated images/ddpm_mnist.gif" width=20% height=20%>
153+
<img src="generated images/ddpm_utkface.gif" width=20% height=20%>
154+
</p>
155+
156+
Code:
157+
*[Model Example](examples/ddpm.py)*
158+
</details>
159+
148160
<details>
149161
<summary>Convolutional Classifier</summary>
150162

data_loader.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import zipfile
12
from pathlib import Path
23

34
import numpy as np
@@ -6,13 +7,13 @@
67
from mnist_data_downloader import download_data
78

89

9-
def prepare_data(data):
10+
def prepare_mnist_data(data):
    """Parse raw MNIST CSV lines into normalized pixel vectors and integer labels.

    Each line is expected to look like ``"<label>,<pixel>,<pixel>,..."``.
    Pixel values are rescaled from [0; 255] to [-1; 1] via ``x / 127.5 - 1``.

    Args:
        data: Iterable of CSV-formatted strings, one sample per line.

    Returns:
        A ``(inputs, targets)`` tuple: ``inputs`` is a list of 1-D float arrays,
        ``targets`` is a list of ``int`` labels in the same order.
    """
    inputs, targets = [], []

    for raw_line in tqdm(data, desc="preparing data"):
        # A trailing newline in the source file yields an empty record; skip it
        # instead of crashing on int("").
        if not raw_line.strip():
            continue

        line = raw_line.split(",")

        # np.asfarray was removed in NumPy 2.0; asarray with an explicit float
        # dtype performs the same string -> float conversion.
        pixels = np.asarray(line[1:], dtype=np.float64)
        inputs.append(pixels / 127.5 - 1)  # normalization: / 255 => [0; 1], / 127.5 - 1 => [-1; 1]
        targets.append(int(line[0]))

    return inputs, targets
@@ -31,10 +32,10 @@ def load_mnist(path="datasets/mnist/"):
3132

3233

3334
if not (Path(path) / "mnist_train.npy").exists() or not (Path(path) / "mnist_test.npy").exists():
34-
training_inputs, training_targets = prepare_data(training_data)
35+
training_inputs, training_targets = prepare_mnist_data(training_data)
3536
training_inputs = np.asfarray(training_inputs)
3637

37-
test_inputs, test_targets = prepare_data(test_data)
38+
test_inputs, test_targets = prepare_mnist_data(test_data)
3839
test_inputs = np.asfarray(test_inputs)
3940

4041
np.save(path + "mnist_train.npy", training_inputs)
@@ -53,3 +54,47 @@ def load_mnist(path="datasets/mnist/"):
5354
test_dataset = test_inputs
5455

5556
return training_dataset, test_dataset, training_targets, test_targets
57+
58+
59+
import os
60+
61+
62+
def prepare_utkface_data(path, image_size=(3, 32, 32)):
    """Load every image file in *path* as a normalized channel-first array.

    Images are visited in random order, converted to the requested number of
    channels, resized, and rescaled from [0; 255] to [-1; 1].

    Args:
        path: Directory containing the image files (``str`` or ``Path``).
        image_size: Target ``(channels, height, width)``. ``channels`` must be
            3 (RGB) or 1 (grayscale).

    Returns:
        A float array of shape ``(n_images, channels, height, width)``.

    Raises:
        ValueError: If ``channels`` is neither 1 nor 3.
    """
    import random

    from PIL import Image

    channels, height, width = image_size
    if channels == 3:
        mode = "RGB"
    elif channels == 1:
        mode = "L"
    else:
        raise ValueError(f"image_size[0] must be 1 or 3 channels, got {channels}")

    # Path(...) accepts both str and Path, unlike `path + "/" + name`, which
    # raises TypeError for Path arguments. Subdirectories are skipped.
    image_paths = [entry for entry in Path(path).iterdir() if entry.is_file()]
    random.shuffle(image_paths)

    training_inputs = []
    for image_path in tqdm(image_paths, desc="preparing data"):
        # The context manager closes the underlying file handle once the pixel
        # data has been copied out.
        with Image.open(image_path) as image:
            # Forcing the mode keeps every sample at the same channel count, so
            # stray grayscale/RGBA/palette files do not break stacking.
            # PIL's resize expects (width, height), not (height, width).
            resized = image.convert(mode).resize((width, height))
            pixels = np.asarray(resized, dtype=np.float64)

        if pixels.ndim == 2:
            # Grayscale images come back as (H, W); add the channel axis.
            pixels = pixels[:, :, np.newaxis]

        pixels = pixels.transpose(2, 0, 1)  # (H, W, C) -> (C, H, W)
        training_inputs.append(pixels / 127.5 - 1)  # [0; 255] -> [-1; 1]

    return np.array(training_inputs)
82+
83+
84+
def load_utkface(path="datasets/utkface/", image_size=(3, 32, 32)):
    """Return the UTKFace training images, extracting and caching them on demand.

    On first use the function unpacks ``archive.zip`` found in *path*, converts
    the images with ``prepare_utkface_data`` and stores the result in
    ``UTKFace.npy``. Later calls load that cache directly.

    Args:
        path: Dataset directory; it is created if missing and must contain
            either an extracted ``UTKFace`` folder or ``archive.zip``.
        image_size: Target ``(channels, height, width)`` of every sample.

    Returns:
        Array of shape ``(n_images, channels, height, width)``.

    Raises:
        FileNotFoundError: If neither the extracted folder nor ``archive.zip``
            is present in *path*.
    """
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)

    images_dir = path / "UTKFace"
    if not images_dir.exists():
        archive_path = path / "archive.zip"
        if not archive_path.exists():
            raise FileNotFoundError(
                f"UTKFace archive not found at '{archive_path}'. Download it from "
                "https://www.kaggle.com/datasets/jangedoo/utkface-new and place it there."
            )
        with zipfile.ZipFile(archive_path, "r") as zip_ref:
            zip_ref.extractall(path)

    save_path = path / "UTKFace.npy"
    if save_path.exists():
        training_inputs = np.load(save_path)
        # Reuse the cache only if it was built for the requested image size;
        # otherwise fall through and regenerate it.
        if training_inputs.shape[1:] == tuple(image_size):
            return training_inputs

    # Pass a plain string so the helper works whether it joins paths with
    # string concatenation or with pathlib.
    training_inputs = prepare_utkface_data(str(images_dir), image_size)
    np.save(save_path, training_inputs)

    return training_inputs

datasets/utkface/readme.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
To run the training examples, download the dataset zip from https://www.kaggle.com/datasets/jangedoo/utkface-new and place it in this folder as archive.zip (the loader looks for that exact file name).

examples/ddpm.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import neunet as nnet
1414
import neunet.nn as nn
15-
from data_loader import load_mnist
15+
from data_loader import load_mnist, load_utkface # noqa F401
1616
from neunet import Tensor
1717
from neunet.optim import Adam
1818

@@ -274,7 +274,7 @@ def denormalize(x):
274274
else:
275275
return Image.fromarray(images_array)
276276

277-
def train(self, dataset, epochs, batch_size, image_path, image_size, save_every_epochs=1):
277+
def train(self, dataset, epochs, batch_size, image_path, image_size, save_path, save_every_epochs=1):
278278
channels, H_size, W_size = image_size
279279

280280
data_batches = np.array_split(dataset, np.arange(batch_size, len(dataset), batch_size))
@@ -333,6 +333,11 @@ def train(self, dataset, epochs, batch_size, image_path, image_size, save_every_
333333
loop=0,
334334
)
335335

336+
if not Path(save_path).exists():
337+
Path(save_path).mkdir(parents=True, exist_ok=True)
338+
339+
nnet.save(self.model.state_dict(), f"{save_path}/ddpm_{epoch + 1}.nt")
340+
336341
loss_history.append(epoch_loss)
337342

338343
return loss_history
@@ -506,29 +511,32 @@ def forward(self, x, t):
506511

507512

508513
device = "cuda"
514+
image_size = (3, 32, 32)
515+
# image_size = (1, 28, 28) # for mnist
509516

517+
training_data = load_utkface(image_size=(3, 32, 32))
518+
# training_data, _, _, _ = load_mnist() # for mnist
510519

511520
diffusion = Diffusion(
512521
model=SimpleUNet(
513-
image_channels=1,
514-
image_size=28,
515-
down_channels=(32, 64, 128),
516-
up_channels=(128, 64, 32),
522+
image_channels=image_size[0],
523+
image_size = image_size[2],
524+
down_channels=(128, 256, 512, 1024), # (32, 64, 128) for mnist
525+
up_channels=(1024, 512, 256, 128), # (128, 64, 32) for mnist
517526
).to(device),
518527
timesteps=300,
519528
beta_start=0.0001,
520529
beta_end=0.02,
521530
criterion=nn.MSELoss(),
522531
)
523532

524-
training_data, test_data, training_labels, test_labels = load_mnist()
525-
training_data = training_data / 127.5 - 1 # normalization: / 255 => [0; 1] #/ 127.5-1 => [-1; 1]
533+
# diffusion.model.load_state_dict(nnet.load("saved models/utkface_model/ddpm_3.nt")) # load saved model example if it exists
526534

527-
# diffusion.ddpm_denoise_sample(25, (1, 28, 28))
528535
diffusion.train(
529536
training_data,
530537
epochs=3,
531-
batch_size=16,
538+
batch_size=5,
532539
image_path="generated images",
533-
image_size=(1, 28, 28),
540+
save_path = "saved models/utkface_model", # for "saved models/mnist_model" mnist
541+
image_size=image_size,
534542
)
7.67 MB
Loading

generated images/ddpm_utkface.gif

10.4 MB
Loading

0 commit comments

Comments
 (0)