Skip to content

Commit 1f26c8b

Browse files
committed
minor changes
1 parent 6bdc202 commit 1f26c8b

File tree

2 files changed

+15
-14
lines changed

2 files changed

+15
-14
lines changed

examples/seq2seq.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
1-
import numpy as np
21
import math
3-
import neunet
4-
import neunet.nn as nn
5-
import matplotlib.pyplot as plt
6-
from neunet.optim import Adam
7-
from neunet import Tensor
8-
from tqdm import tqdm
92
from pathlib import Path
10-
from data_loader import load_multi30k
3+
4+
import matplotlib.pyplot as plt
5+
import numpy as np
116
from tokenizers import ByteLevelBPETokenizer
127
from tokenizers.processors import TemplateProcessing
8+
from tqdm import tqdm
139

10+
import neunet
11+
import neunet.nn as nn
12+
from data_loader import load_multi30k
13+
from neunet import Tensor
14+
from neunet.optim import Adam
1415

1516
"""
1617
Seq2Seq Transformer for language translation from English to German
@@ -242,10 +243,10 @@ def forward(self, src: np.ndarray, tgt: np.ndarray) -> tuple[Tensor, Tensor]:
242243

243244
BATCH_SIZE = 32
244245

245-
PAD_TOKEN = '<pad>'
246-
SOS_TOKEN = '<sos>'
247-
EOS_TOKEN = '<eos>'
248-
UNK_TOKEN = '<unk>'
246+
PAD_TOKEN = '<pad>' # noqa: S105
247+
SOS_TOKEN = '<sos>' # noqa: S105
248+
EOS_TOKEN = '<eos>' # noqa: S105
249+
UNK_TOKEN = '<unk>' # noqa: S105
249250

250251
DATASET_PATH = Path("./datasets/multi30k/")
251252
SAVE_PATH = Path("./saved models/seq2seq/")
@@ -460,7 +461,7 @@ def train(train_data: np.ndarray, val_data: np.ndarray, epochs: int, save_every_
460461

461462
neunet.save(model.state_dict(), f"{save_path}/seq2seq_{epoch + 1}.nt")
462463
else:
463-
print(f'Current validation loss is higher than previous. Not saved.')
464+
print('Current validation loss is higher than previous. Not saved.')
464465
break
465466

466467
return train_loss_history, val_loss_history

multi30k_data_downloader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
def download_multi30k_data(urls, path, filenames):
99
for _, (url, filename) in enumerate(zip(urls, filenames, strict=False)):
10-
resp = requests.get(url, stream=True, verify=False)
10+
resp = requests.get(url, stream=True, verify=False, timeout=10)
1111
total = int(resp.headers.get('content-length', 0))
1212
with open(Path(path) / filename, 'wb') as file, tqdm(
1313
desc = f'downloading {filename = } to {path = }',

0 commit comments

Comments
 (0)