 import math
 from pathlib import Path
+from typing import Optional

 import matplotlib.pyplot as plt
 import numpy as np

 import neunet
 import neunet.nn as nn
-from datasets import load_dataset
+from datasets import load_dataset  # type: ignore
 from neunet import Tensor
 from neunet.optim import Adam

@@ -27,7 +28,8 @@ def __init__(self, d_model, n_heads, dropout=0.1):
         self.scale = math.sqrt(d_model)
         self.dropout = nn.Dropout(dropout)

-        assert d_model % n_heads == 0
+        if d_model % n_heads != 0:
+            raise ValueError("d_model must be divisible by n_heads")

         self.depth = d_model // n_heads

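The hunk above trades the bare assert for an explicit exception because assert statements are stripped when CPython runs with the -O flag, so the divisibility check would silently disappear in optimized runs; a raised ValueError fires regardless. A minimal sketch of the difference (the 512/7 values are illustrative, not from the commit):

d_model, n_heads = 512, 7          # deliberately incompatible sizes
assert d_model % n_heads == 0      # skipped entirely under python -O
if d_model % n_heads != 0:         # still enforced under python -O
    raise ValueError("d_model must be divisible by n_heads")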
@@ -37,7 +39,7 @@ def __init__(self, d_model, n_heads, dropout=0.1):

         self.fc = nn.Linear(d_model, d_model)

-    def forward(self, q: Tensor, k: Tensor, v: Tensor, mask: Tensor = None):
+    def forward(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
         batch_size = q.shape[0]
         q = self.wq(q).contiguous().reshape(batch_size, -1, self.n_heads, self.depth).transpose(0, 2, 1, 3)
         k = self.wk(k).contiguous().reshape(batch_size, -1, self.n_heads, self.depth).transpose(0, 2, 1, 3)
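The old signature mask: Tensor = None relies on the implicit-Optional convention; current type checkers expect the None possibility to be spelled out, and Optional[Tensor] does that without changing runtime behaviour. A small self-contained illustration, with int standing in for Tensor:

from typing import Optional

def masked_score(score: int, mask: Optional[int] = None) -> int:
    # the annotation now matches the None default explicitly
    return score if mask is None else score + mask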
@@ -246,7 +248,7 @@ def forward(self, src: np.ndarray, tgt: np.ndarray) -> tuple[Tensor, Tensor]:
 PAD_TOKEN = '<pad>'  # noqa: S105
 SOS_TOKEN = '<sos>'  # noqa: S105
 EOS_TOKEN = '<eos>'  # noqa: S105
-# UNK_TOKEN = '<unk>'  # noqa: S105
+# UNK_TOKEN = '<unk>'

 DATASET_PATH = Path("./datasets/multi30k/")
 SAVE_PATH = Path("./saved models/seq2seq/")
@@ -255,23 +257,22 @@ def forward(self, src: np.ndarray, tgt: np.ndarray) -> tuple[Tensor, Tensor]:
     data = load_dataset("bentrevett/multi30k", cache_dir="datasets/multi30k")

     for split, split_dataset in data.items():
-        with open(f"./datasets/multi30k/{split}.en", 'w', encoding='utf-8') as f:
+        with Path(f"./datasets/multi30k/{split}.en").open('w', encoding='utf-8') as f:
             for item in split_dataset:
                 f.write(item['en'] + '\n')

-        with open(f"./datasets/multi30k/{split}.de", 'w', encoding='utf-8') as f:
+        with Path(f"./datasets/multi30k/{split}.de").open('w', encoding='utf-8') as f:
             for item in split_dataset:
                 f.write(item['de'] + '\n')

 FILE_PATHS = [DATASET_PATH / "train.en", DATASET_PATH / "train.de", DATASET_PATH / "val.en", DATASET_PATH / "val.de", DATASET_PATH / "test.en", DATASET_PATH / "test.de"]
-FILE_PATHS = [str(path) for path in FILE_PATHS]


 # [Train and load Tokenizer]
 if not (SAVE_PATH / "vocab").exists():
     tokenizer = ByteLevelBPETokenizer()

-    tokenizer.train(files=FILE_PATHS, vocab_size=15000, min_frequency=1, special_tokens=[
+    tokenizer.train(files=[str(path) for path in FILE_PATHS], vocab_size=15000, min_frequency=1, special_tokens=[
         PAD_TOKEN,
         SOS_TOKEN,
         EOS_TOKEN,
@@ -298,7 +299,7 @@ class DataPreprocessor():
     def __init__(self, tokenizer: ByteLevelBPETokenizer):
         self.tokenizer = tokenizer

-        self.tokenizer._tokenizer.post_processor = TemplateProcessing(
+        self.tokenizer._tokenizer.post_processor = TemplateProcessing(  # noqa: SLF001
             single=f"{SOS_TOKEN} $A {EOS_TOKEN}",
             special_tokens=[
                 (f"{SOS_TOKEN}", tokenizer.token_to_id(f"{SOS_TOKEN}")),
@@ -309,13 +310,13 @@ def __init__(self, tokenizer: ByteLevelBPETokenizer):
         # self.tokenizer.enable_truncation(max_length=128)
         self.tokenizer.enable_padding(pad_token=PAD_TOKEN)

-    def tokenize(self, paths: list[str], batch_size: int, lines_limit: int = None) -> np.ndarray:
+    def tokenize(self, paths: list[str], batch_size: int, lines_limit: Optional[int] = None) -> list[np.ndarray]:
         examples = []

         for src_file in paths:
             print(f"Processing {src_file}")
-            src_file = Path(src_file)
-            lines = src_file.read_text(encoding="utf-8").splitlines()
+            path_src_file = Path(src_file)
+            lines = path_src_file.read_text(encoding="utf-8").splitlines()

             if lines_limit:
                 lines = lines[:lines_limit]
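The corrected annotation reflects what tokenize actually returns: a Python list of per-batch token-id arrays whose padded lengths differ from batch to batch, which cannot be combined into one regular np.ndarray. A quick sketch of why list[np.ndarray] is the honest type (the shapes are made up for illustration):

import numpy as np

batches = [
    np.zeros((32, 18), dtype=np.int64),  # batch padded to length 18
    np.zeros((32, 25), dtype=np.int64),  # batch padded to length 25
]
# np.stack(batches) raises ValueError because the second dimensions differ,
# so the function is typed as returning list[np.ndarray] instead of np.ndarray.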
@@ -326,20 +327,20 @@ def tokenize(self, paths: list[str], batch_size: int, lines_limit: int = None) -

         return examples

-    def __call__(self, paths: list[str], batch_size: int, lines_limit: int = None) -> np.ndarray:
+    def __call__(self, paths: list[str], batch_size: int, lines_limit: Optional[int] = None) -> list[np.ndarray]:
         return self.tokenize(paths, batch_size, lines_limit)


 data_post_processor = DataPreprocessor(tokenizer)

-train_src = data_post_processor([DATASET_PATH / "train.en"], batch_size=BATCH_SIZE)
-train_tgt = data_post_processor([DATASET_PATH / "train.de"], batch_size=BATCH_SIZE)
+train_src = data_post_processor([str(DATASET_PATH / "train.en")], batch_size=BATCH_SIZE)
+train_tgt = data_post_processor([str(DATASET_PATH / "train.de")], batch_size=BATCH_SIZE)

-val_src = data_post_processor([DATASET_PATH / "val.en"], batch_size=BATCH_SIZE)
-val_tgt = data_post_processor([DATASET_PATH / "val.de"], batch_size=BATCH_SIZE)
+val_src = data_post_processor([str(DATASET_PATH / "val.en")], batch_size=BATCH_SIZE)
+val_tgt = data_post_processor([str(DATASET_PATH / "val.de")], batch_size=BATCH_SIZE)

-test_src = data_post_processor([DATASET_PATH / "test.en"], batch_size=BATCH_SIZE)
-test_tgt = data_post_processor([DATASET_PATH / "test.de"], batch_size=BATCH_SIZE)
+test_src = data_post_processor([str(DATASET_PATH / "test.en")], batch_size=BATCH_SIZE)
+test_tgt = data_post_processor([str(DATASET_PATH / "test.de")], batch_size=BATCH_SIZE)


 train_data = train_src, train_tgt
@@ -386,11 +387,11 @@ def __call__(self, paths: list[str], batch_size: int, lines_limit: int = None) -

 # [train, eval, predict methods definition]

-def train_step(source: np.ndarray, target: np.ndarray, epoch: int, epochs: int) -> float:
+def train_step(source: list[np.ndarray], target: list[np.ndarray], epoch: int, epochs: int) -> float:
     loss_history = []
     model.train()

-    tqdm_range = tqdm(enumerate(zip(source, target)), total=len(source))
+    tqdm_range = tqdm(enumerate(zip(source, target, strict=False)), total=len(source))
     for batch_num, (source_batch, target_batch) in tqdm_range:

         output, _ = model.forward(source_batch, target_batch[:,:-1])
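zip(..., strict=False) uses the strict keyword added in Python 3.10: strict=False keeps the historical behaviour of stopping at the shorter iterable, but states that choice explicitly, which is what lint checks for an unguarded zip ask for; the source and target batch lists are expected to have equal length here anyway. A tiny demonstration:

a = [1, 2, 3]
b = [10, 20]
print(list(zip(a, b, strict=False)))  # [(1, 10), (2, 20)] -- the extra item is silently dropped
# list(zip(a, b, strict=True)) would raise ValueError on the length mismatch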
@@ -419,11 +420,11 @@ def train_step(source: np.ndarray, target: np.ndarray, epoch: int, epochs: int)

     return epoch_loss

-def eval(source: np.ndarray, target: np.ndarray) -> float:
+def eval(source: list[np.ndarray], target: list[np.ndarray]) -> float:
     loss_history = []
     model.eval()

-    tqdm_range = tqdm(enumerate(zip(source, target)), total=len(source))
+    tqdm_range = tqdm(enumerate(zip(source, target, strict=False)), total=len(source))
     for batch_num, (source_batch, target_batch) in tqdm_range:

         output, _ = model.forward(source_batch, target_batch[:,:-1])
@@ -447,7 +448,7 @@ def eval(source: np.ndarray, target: np.ndarray) -> float:
     return epoch_loss


-def train(train_data: np.ndarray, val_data: np.ndarray, epochs: int, save_every_epochs: int, save_path: str = None, validation_check: bool = False):
+def train(train_data: tuple[list[np.ndarray], list[np.ndarray]], val_data: tuple[list[np.ndarray], list[np.ndarray]], epochs: int, save_every_epochs: int, save_path: Optional[str] = None, validation_check: bool = False):
     best_val_loss = float('inf')

     train_loss_history = []
@@ -547,23 +548,22 @@ def plot_loss_history(train_loss_history, val_loss_history):



-test_data = []
+raw_test_data: list[dict[str, str]] = []

-with open(DATASET_PATH / "test.en", 'r') as f:
-    en_file = [l.strip() for l in open(DATASET_PATH / "test.en", 'r', encoding='utf-8')]
-    de_file = [l.strip() for l in open(DATASET_PATH / "test.de", 'r', encoding='utf-8')]
+en_file = [l.strip() for l in Path(DATASET_PATH / "test.en").open('r', encoding='utf-8')]
+de_file = [l.strip() for l in Path(DATASET_PATH / "test.de").open('r', encoding='utf-8')]

 for i in range(len(en_file)):
     if en_file[i] == '' or de_file[i] == '':
         continue
     en_seq, de_seq = en_file[i], de_file[i]

-    test_data.append({'en': en_seq, 'de': de_seq})
+    raw_test_data.append({'en': en_seq, 'de': de_seq})

 sentences_num = 10

-random_indices = np.random.randint(0, len(test_data), sentences_num)
-sentences_selection = [test_data[i] for i in random_indices]
+random_indices = np.random.randint(0, len(raw_test_data), sentences_num)
+sentences_selection = [raw_test_data[i] for i in random_indices]

 # [Translate sentences from validation set]
 for i, example in enumerate(sentences_selection):
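The rewritten test-set loader drops the unused with open(...) as f: wrapper and reads through Path.open, in line with the pathlib-first style used elsewhere in the commit. An alternative sketch using read_text (same file names as the script; it assumes the Multi30k files have already been written to disk, and the en_lines/de_lines names are only for this example):

from pathlib import Path

DATASET_PATH = Path("./datasets/multi30k/")
en_lines = (DATASET_PATH / "test.en").read_text(encoding='utf-8').splitlines()
de_lines = (DATASET_PATH / "test.de").read_text(encoding='utf-8').splitlines()
raw_test_data = [
    {'en': en.strip(), 'de': de.strip()}
    for en, de in zip(en_lines, de_lines, strict=True)
    if en.strip() and de.strip()
]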
@@ -575,7 +575,8 @@ def plot_loss_history(train_loss_history, val_loss_history):


 def plot_attention(sentence: str, translation: str, attention: Tensor, heads_num: int = 8, rows_num: int = 2, cols_num: int = 4):
-    assert rows_num * cols_num == heads_num
+    if rows_num * cols_num != heads_num:
+        raise ValueError("heads_num must be equal to rows_num * cols_num")
     attention = attention.detach().cpu().numpy().squeeze()

     sentence = tokenizer.encode(sentence, add_special_tokens=False).tokens