@@ -39,6 +39,7 @@ def save_list(item2index, path):
def clip(v, low, high):
    """Clamp *v* into the interval bounded by *low* and *high*.

    Equivalent to ``max(low, min(v, high))``: the value is first capped
    at *high*, then floored at *low* (so *low* wins if the bounds cross).
    """
    upper_bounded = min(v, high)
    return max(low, upper_bounded)
4141
42+
# A single question record is the basic unit of the dataset: its id, raw
# content, the correct answer, the distractor options, and its meta labels.
Question = namedtuple(
    'Question',
    'id content answer false_options labels'
)
@@ -334,7 +335,7 @@ def __init__(
334335 option_key = lambda x : x ['ques_options' ],
335336 pipeline = None ,
336337 skip = 0
337- ):
338+ ):
338339
339340 self .filename = filename
340341 self .skip = skip
@@ -349,7 +350,7 @@ def __init__(
349350 tokenizer = QuesNetTokenizer (
350351 meta = ['know_name' ],
351352 img_dir = img_dir
352- )
353+ )
353354 self .tokenizer = tokenizer
354355 self .meta = meta if meta else tokenizer .meta
355356 self .load_data_lines ()
@@ -358,16 +359,15 @@ def __init__(
358359 key = lambda x : x ['ques_content' ],
359360 trim_min_count = 2 ,
360361 silent = False
361- )
362+ )
362363 tokenizer .set_meta_vocab (self .lines , silent = False )
363-
364364
365365 def load_data_lines (self ):
366366 '''Read data by row from a JSON file
367-
367+
368368 Important: the data file is loaded during initialization.
369369 '''
370-
370+
371371 # TODO: All data is read into memory without chunking.
372372 # This may lead to low efficiency.
373373 data_dir = self .filename
@@ -402,7 +402,7 @@ def __getitem__(self, index):
402402 meta = token ['meta_idx' ]
403403
404404 if self .answer_key (line ).isalpha () and len (self .answer_key (line )) == 1 \
405- and ord (self .answer_key (line )) < 128 and len (self .option_key (line )) > 0 :
405+ and ord (self .answer_key (line )) < 128 and len (self .option_key (line )) > 0 :
406406 answer_idx = ord (self .answer_key (line ).upper ()) - ord ('A' )
407407 options = self .option_key (line )
408408 answer = self .tokenizer (options .pop (answer_idx ), meta = self .meta )['seq_idx' ]
@@ -417,7 +417,7 @@ def __getitem__(self, index):
417417 answer = answer ,
418418 false_options = false_options ,
419419 labels = meta
420- )
420+ )
421421
422422 if callable (self .pipeline ):
423423 qs = self .pipeline (qs )
@@ -556,17 +556,25 @@ def pretrain_embedding_layer(dataset: EmbeddingDataset, ae: AE, lr: float = 1e-3
556556
def optimizer(*models, **kwargs):
    """Create one optimizer per model.

    Each model that declares an ``optim_cls`` attribute gets an instance of
    that optimizer class; any other model falls back to
    ``torch.optim.Adam``. Extra keyword arguments (e.g. ``lr``) are
    forwarded to every optimizer constructor.

    Returns the single optimizer when exactly one model is given,
    otherwise a list of optimizers (one per model, in order).
    """
    optims = []
    for model in models:
        if hasattr(model, 'optim_cls'):
            optims.append(model.optim_cls(model.parameters(), **kwargs))
        else:
            optims.append(torch.optim.Adam(model.parameters(), **kwargs))
    # Preserve the historical contract: unwrap when there is only one.
    return optims[0] if len(optims) == 1 else optims
567567
568-
569- def pretrain_quesnet (path , output_dir , pretrain_dir = None , img_dir = None , save_embs = False , load_embs = False , train_params = None ):
568+
569+ def pretrain_quesnet (
570+ path ,
571+ output_dir ,
572+ pretrain_dir = None ,
573+ img_dir = None ,
574+ save_embs = False ,
575+ load_embs = False ,
576+ train_params = None
577+ ):
570578 """ pretrain quesnet
571579
572580 Parameters
@@ -672,7 +680,7 @@ def pretrain_quesnet(path, output_dir, pretrain_dir=None, img_dir=None, save_emb
672680 sentences = [[item ] for item in emb_dict .keys ()],
673681 min_count = 1 ,
674682 vector_size = emb_size
675- )
683+ )
676684 gensim_w2v .init_weights ()
677685 gensim_w2v .train (corpus_iterable = w2v_corpus , total_examples = len (w2v_corpus ), epochs = train_params ['n_epochs' ])
678686 w2v_emb = gensim_w2v .syn1neg
@@ -699,7 +707,7 @@ def pretrain_quesnet(path, output_dir, pretrain_dir=None, img_dir=None, save_emb
699707 batch_size = train_params ['batch_size' ],
700708 epochs = train_params ['n_epochs' ],
701709 device = device
702- )
710+ )
703711 if save_embs :
704712 torch .save (trained_ie .state_dict (), os .path .join (output_dir , 'trained_ie.pt' ))
705713 model .quesnet .load_img (trained_ie )
@@ -718,7 +726,7 @@ def pretrain_quesnet(path, output_dir, pretrain_dir=None, img_dir=None, save_emb
718726 batch_size = train_params ['batch_size' ],
719727 epochs = train_params ['n_epochs' ],
720728 device = device
721- )
729+ )
722730 if save_embs :
723731 torch .save (trained_me .state_dict (), os .path .join (output_dir , 'trained_me.pt' ))
724732 model .quesnet .load_meta (trained_me )
0 commit comments