Skip to content

Commit 9773642

Browse files
committed
add gpt example, minor changes
1 parent a662097 commit 9773642

File tree

10 files changed

+1542
-291
lines changed

10 files changed

+1542
-291
lines changed

README.md

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -136,12 +136,13 @@ Some [examples](examples/) were trained on the [MNIST](https://pjreddie.com/proj
136136
3. *[Conway`s Game of Life](examples/conway.py)*
137137
4. *[Denoising Diffusion Probabilistic Model](examples/ddpm.py)*
138138
5. *[Generative Adversarial Network](examples/gan.py)*
139-
6. *[Recurrent Digits Classifier](examples/recurrent_digits_classifier.py)*
140-
7. *[Recurrent Sequences Classifier](examples/recurrent_sequences_classifier.py)*
141-
8. *[Seq2Seq Transformer](examples/seq2seq.py)*
142-
9. *[Variational Autoencoder](examples/vae.py)*
143-
10. *[Vector Quantized Variational Autoencoder](examples/vqvae.py)*
144-
11. *[Word2Vec](examples/word2vec.py)*
139+
6. *[Generative Pre-trained Transformer](examples/gpt.py)*
140+
7. *[Recurrent Digits Classifier](examples/recurrent_digits_classifier.py)*
141+
8. *[Recurrent Sequences Classifier](examples/recurrent_sequences_classifier.py)*
142+
9. *[Seq2Seq Transformer](examples/seq2seq.py)*
143+
10. *[Variational Autoencoder](examples/vae.py)*
144+
11. *[Vector Quantized Variational Autoencoder](examples/vqvae.py)*
145+
12. *[Word2Vec](examples/word2vec.py)*
145146

146147

147148

@@ -334,7 +335,7 @@ Code:
334335
<details>
335336
<summary>Seq2Seq Transformer</summary>
336337

337-
#### Examples of translated sentences of validation set:
338+
#### Examples of translated sentences (EN -> DE) of validation set:
338339

339340
>Example №1
340341
*Input sentence: These four people are standing outdoors, with 3 dogs.
@@ -654,6 +655,28 @@ Training process Example | Interpolation between images Example
654655
<img src="generated images/gan_training_process.gif"> | <img src="generated images/gan_vectors_interpolation.gif">
655656
</details>
656657

658+
<details>
659+
<summary>Generative Pre-trained Transformer</summary>
660+
661+
#### Examples of a model trained to generate prompts for Stable Diffusion:
662+
663+
>Example №1
664+
*a detailed image of a dark haired cyborg - car 3 d model, a glowing aura, symmetrical, intricate, elegant, highly detailed, digital painting, artstation, concept art, smooth, sharp focus, illustration, art by krenz cushart and artem demura*
665+
666+
>Example №2
667+
*an female warrior, full length, red hair, dark eyes, symmetrical face, highly detailed, digital art, sharp focus, trending on art station, anime art style*
668+
669+
>Example №3
670+
*portrait of a young ruggedly handsome but joyful pirate, male, masculine, upper body, red hair, long hair, d & d, fantasy, sharp features, piercing gaze, sharp features, digital painting, artstation, concept art, matte, sharp*
671+
672+
>Example №4
673+
*an anthropomorphic fox wizard, fine art, award winning, intricate, elegant, sharp focus, cinematic lighting, highly detailed, digital painting, 8 k concept art, art by guweiz and z. w. gu, masterpiece, trending on artstation*
674+
675+
>Example №5
676+
*a beautiful portrait painting of a cyberpunk city by simon stalenhag and pascal blanche and alphonse mucha, in style of colorful comic. symmetry, hyper detailed. octanev render. trending on artstation*
677+
678+
</details>
679+
657680
<details>
658681
<summary>Conway`s Game of Life Neural Network Simulation</summary>
659682

@@ -842,5 +865,5 @@ Native implementation Example | Neural network Example
842865

843866
### TODO:
844867
- [x] Add Seq2Seq Transformer example
845-
- [ ] Add GPT example
868+
- [x] Add GPT example
846869
- [ ] Add lr schedulers

data_loader.py

Lines changed: 0 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -104,61 +104,6 @@ def load_utkface(path="datasets/utkface/", image_size=(3, 32, 32)):
104104

105105

106106

107-
def load_multi30k(path="datasets/multi30k/"):
108-
#References: https://pytorch.org/text/stable/_modules/torchtext/datasets/multi30k.html
109-
urls = {
110-
"train": r"https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/training.tar.gz",
111-
"valid": r"https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/validation.tar.gz",
112-
"test": r"https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/mmt16_task1_test.tar.gz",
113-
}
114-
115-
filenames = ["mmt16_task1_test.tar.gz", "training.tar.gz", "validation.tar.gz"]
116-
117-
path = Path(path)
118-
if not path.exists():
119-
path.mkdir(parents=True)
120-
121-
download_multi30k_data(urls.values(), path, filenames)
122-
123-
for filename in filenames:
124-
tar = tarfile.open(Path(path) / filename)
125-
tar.extractall(path)
126-
tar.close()
127-
128-
print(f'Extracted {filename}')
129-
130-
131-
ret = []
132-
filenames = ["train", "val", "test"]
133-
134-
for filename in filenames:
135-
136-
examples = []
137-
138-
en_path = os.path.join(path, filename + '.en')
139-
de_path = os.path.join(path, filename + '.de')
140-
141-
en_file = [l.strip() for l in open(en_path, 'r', encoding='utf-8')]
142-
de_file = [l.strip() for l in open(de_path, 'r', encoding='utf-8')]
143-
144-
assert len(en_file) == len(de_file)
145-
146-
for i in range(len(en_file)):
147-
if en_file[i] == '' or de_file[i] == '':
148-
continue
149-
en_seq, de_seq = en_file[i], de_file[i]
150-
151-
examples.append({'en': en_seq, 'de': de_seq})
152-
153-
ret.append(examples)
154-
155-
train_dataset, valid_dataset, test_dataset = ret
156-
return train_dataset, valid_dataset, test_dataset
157-
158-
159-
160-
161-
162107

163108

164109

0 commit comments

Comments
 (0)