|
1 | | -import numpy as np |
2 | 1 | import math |
3 | | -import neunet |
4 | | -import neunet.nn as nn |
5 | | -import matplotlib.pyplot as plt |
6 | | -from neunet.optim import Adam |
7 | | -from neunet import Tensor |
8 | | -from tqdm import tqdm |
9 | 2 | from pathlib import Path |
10 | | -from data_loader import load_multi30k |
| 3 | + |
| 4 | +import matplotlib.pyplot as plt |
| 5 | +import numpy as np |
11 | 6 | from tokenizers import ByteLevelBPETokenizer |
12 | 7 | from tokenizers.processors import TemplateProcessing |
| 8 | +from tqdm import tqdm |
13 | 9 |
|
| 10 | +import neunet |
| 11 | +import neunet.nn as nn |
| 12 | +from data_loader import load_multi30k |
| 13 | +from neunet import Tensor |
| 14 | +from neunet.optim import Adam |
14 | 15 |
|
15 | 16 | """ |
16 | 17 | Seq2Seq Transformer for language translation from English to German |
@@ -242,10 +243,10 @@ def forward(self, src: np.ndarray, tgt: np.ndarray) -> tuple[Tensor, Tensor]: |
242 | 243 |
|
243 | 244 | BATCH_SIZE = 32 |
244 | 245 |
|
245 | | -PAD_TOKEN = '<pad>' |
246 | | -SOS_TOKEN = '<sos>' |
247 | | -EOS_TOKEN = '<eos>' |
248 | | -UNK_TOKEN = '<unk>' |
| 246 | +PAD_TOKEN = '<pad>' # noqa: S105 |
| 247 | +SOS_TOKEN = '<sos>' # noqa: S105 |
| 248 | +EOS_TOKEN = '<eos>' # noqa: S105 |
| 249 | +UNK_TOKEN = '<unk>' # noqa: S105 |
249 | 250 |
|
250 | 251 | DATASET_PATH = Path("./datasets/multi30k/") |
251 | 252 | SAVE_PATH = Path("./saved models/seq2seq/") |
@@ -460,7 +461,7 @@ def train(train_data: np.ndarray, val_data: np.ndarray, epochs: int, save_every_ |
460 | 461 |
|
461 | 462 | neunet.save(model.state_dict(), f"{save_path}/seq2seq_{epoch + 1}.nt") |
462 | 463 | else: |
463 | | - print(f'Current validation loss is higher than previous. Not saved.') |
| 464 | + print('Current validation loss is higher than previous. Not saved.') |
464 | 465 | break |
465 | 466 |
|
466 | 467 | return train_loss_history, val_loss_history |
|
0 commit comments