@@ -157,12 +157,12 @@ def _convert_to_ids(self, item: Union[str, dict, list], key=lambda x: x,
         token_idx = []
         for _, w in enumerate(token_item):
             if isinstance(w, FigureSegment) and 'ques_figure_ids' in item.keys():
-                # image
+                # image
 
                 try:
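                     # the [10:-1] slice strips the figure markup wrapper to recover the raw figure id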
                     fig_id = f"{w.src[10:-1]}"
                     fig_index = item['ques_figure_ids'].index(fig_id)
-
+
                     if self.img_dir != "":
                         fig_src = os.path.join(self.img_dir, fig_id)
                         if '.png' in item['ques_figure_paths'][fig_index]:
@@ -171,12 +171,12 @@ def _convert_to_ids(self, item: Union[str, dict, list], key=lambda x: x,
                             fig_src += '.jpg'
                     else:
                         fig_src = item['ques_figure_paths'][fig_index]
-
+
                     print(f"Open figure {fig_src}")
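                     # figures are resized to 56x56 and converted to grayscale before embedding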
                     im = Image.open(fig_src)
                     im = im.resize((56, 56))
                     token_idx.append(to_grayscale(im))
-
+
                 except Exception:
                     warnings.warn('Open image error!')
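                     # fall back to the generic image token when the figure cannot be read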
                     token_idx.append(self.stoi['word'][self.img_token])
@@ -316,23 +316,26 @@ def padding(self, idx, max_length, type='word'):
 
     def set_img_dir(self, path):
         self.img_dir = path
-
-
+
+
 class QuesnetDataset(Dataset):
     '''
     Quesnet-specific datasets
     '''
-    def __init__(self, filename: str,
-                 tokenizer: QuesNetTokenizer = None,
-                 img_dir: str = "",
-                 meta: Optional[list] = None,
-                 content_key=lambda x: x['ques_content'],
-                 meta_key=lambda x: x['know_name'],
-                 answer_key=lambda x: x['ques_answer'],
-                 option_key=lambda x: x['ques_options'],
-                 pipeline=None,
-                 skip=0
-                 ):
+    def __init__(
+            self,
+            filename: str,
+            tokenizer: QuesNetTokenizer = None,
+            img_dir: str = "",
+            meta: Optional[list] = None,
+            content_key=lambda x: x['ques_content'],
+            meta_key=lambda x: x['know_name'],
+            answer_key=lambda x: x['ques_answer'],
+            option_key=lambda x: x['ques_options'],
+            pipeline=None,
+            skip=0
+    ):
+
         self.filename = filename
         self.skip = skip
         self.img_dir = img_dir
@@ -341,16 +344,22 @@ def __init__(self, filename: str,
         self.answer_key = answer_key
         self.option_key = option_key
         self.pipeline = pipeline
-
-        if tokenizer == None:
-            tokenizer = QuesNetTokenizer(meta=['know_name'],
-                                         img_dir=img_dir)
+
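+        # no tokenizer supplied: build a default one; its vocab is set from the data below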
+        if tokenizer is None:
+            tokenizer = QuesNetTokenizer(
+                meta=['know_name'],
+                img_dir=img_dir
+            )
         self.tokenizer = tokenizer
         self.meta = meta if meta else tokenizer.meta
         self.load_data_lines()
-        tokenizer.set_vocab(self.lines, key=lambda x: x['ques_content'],
-                            trim_min_count=2, silent=False)
-        tokenizer.set_meta_vocab(self.lines, silent=False)
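+        # build the word and meta vocabularies from the loaded question lines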
+        tokenizer.set_vocab(
+            self.lines,
+            key=lambda x: x['ques_content'],
+            trim_min_count=2,
+            silent=False
+        )
+        tokenizer.set_meta_vocab(self.lines, silent=False)
 
 
     def load_data_lines(self):
@@ -376,15 +385,13 @@ def load_data_lines(self):
                 if not line:
                     break
                 self.lines.append(json.loads(line.strip()))
-
+
         self.length = row - skip - 1
         assert self.length > 0, f'{data_dir} is empty, or file length is less than skip length.'
 
-
     def __len__(self):
         return len(self.lines)
 
-
     def __getitem__(self, index):
         if isinstance(index, int):
             line = self.lines[index]
@@ -393,33 +400,40 @@ def __getitem__(self, index):
             token = self.tokenizer(line, key=self.content_key, meta=self.meta)
             content = token['seq_idx']
             meta = token['meta_idx']
-
-            if self.answer_key(line).isalpha() and len(self.answer_key(line)) == 1 and ord(self.answer_key(line)) < 128 and len(self.option_key(line)) > 0:
+
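+            # a single ASCII letter answer with non-empty options marks a multiple-choice item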
+            if self.answer_key(line).isalpha() and len(self.answer_key(line)) == 1 \
+                    and ord(self.answer_key(line)) < 128 and len(self.option_key(line)) > 0:
                 answer_idx = ord(self.answer_key(line).upper()) - ord('A')
                 options = self.option_key(line)
                 answer = self.tokenizer(options.pop(answer_idx), meta=self.meta)['seq_idx']
                 false_options = [(self.tokenizer(option, meta=self.meta))['seq_idx'] for option in options]
             else:
                 answer = (self.tokenizer(self.answer_key(line), meta=self.meta))['seq_idx']
                 false_options = [[0], [0], [0]]
-
-            qs = Question(id=qid, content=content, answer=answer,
-                          false_options=false_options, labels=meta)
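+            # bundle the tokenized fields into a Question record for pretraining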
+            qs = Question(
+                id=qid,
+                content=content,
+                answer=answer,
+                false_options=false_options,
+                labels=meta
+            )
+
             if callable(self.pipeline):
-                qs = self.pipeline(qs)
-
+                qs = self.pipeline(qs)
+
             return qs
-
-
+
         elif isinstance(index, slice):
             results = []
             for i in range(*index.indices(len(self))):
                 results.append(self[i])
             return results
-
+
         else:
             raise TypeError('Invalid argument type. Index type should be int or slice.')
-
+
+
 class EmbeddingDataset(Dataset):
     def __init__(self, data, data_type='image'):
         self.data = data
@@ -434,7 +448,7 @@ def __getitem__(self, idx):
             return to_tensor(self.data[idx])
         elif self.data_type == 'meta':
             return self.data[idx]
-
+
 
 class PrefetchIter:
     """Iterator on data and labels, with states for save and restore."""
@@ -445,7 +459,7 @@ def __init__(self, data, *label, length=None, batch_size=1, shuffle=True):
         self.batch_size = batch_size
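         # bounded queue of prefetched batches, filled by produce()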
         self.queue = queue.Queue(maxsize=8)
         self.length = length if length is not None else len(data)
-
+
         assert all(self.length == len(lab) for lab in label), \
             'data and label must have same lengths'
 
@@ -497,10 +511,11 @@ def produce(self):
         except Exception as e:
             self.queue.put(e)
             return
-
+
 
 sigint_handler = signal.getsignal(signal.SIGINT)
 
+
 def critical(f):
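     # iterate f while deferring SIGINT; a deferred signal is re-raised after the step completes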
     it = iter(f)
     signal_received = ()
@@ -519,7 +534,7 @@ def handler(sig, frame):
                 sigint_handler(*signal_received)
         except StopIteration:
             break
-
+
 
 def pretrain_embedding_layer(dataset: EmbeddingDataset, ae: AE, lr: float = 1e-3, log_step: int = 1, epochs: int = 3,
                              batch_size: int = 4, device=torch.device('cpu')):
@@ -538,18 +553,20 @@ def pretrain_embedding_layer(dataset: EmbeddingDataset, ae: AE, lr: float = 1e-3
                 logger.info(f"[Epoch{i}][Batch{batch}]Training {train_type} Embedding layer, loss:{loss}")
     return ae
 
+
 def optimizer(*models, **kwargs):
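     # use each model's preferred optim_cls when defined, otherwise default to Adam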
-    _cur_optim = [m.optim_cls(m.parameters(), **kwargs)
-                  if hasattr(m, 'optim_cls')
-                  else torch.optim.Adam(m.parameters(), **kwargs)
-                  for m in models]
+    _cur_optim = [
+        m.optim_cls(m.parameters(), **kwargs)
+        if hasattr(m, 'optim_cls')
+        else torch.optim.Adam(m.parameters(), **kwargs) for m in models
+    ]
     if len(_cur_optim) == 1:
         return _cur_optim[0]
     else:
-        return _cur_optim
-
+        return _cur_optim
+
 
-def pretrain_quesnet(path, output_dir, pretrain_dir = None, img_dir = None, save_embs = False, load_embs = False, train_params = None):
+def pretrain_quesnet(path, output_dir, pretrain_dir=None, img_dir=None, save_embs=False, load_embs=False, train_params=None):
     """ pretrain quesnet
 
     Parameters
@@ -597,7 +614,7 @@ def pretrain_quesnet(path, output_dir, pretrain_dir = None, img_dir = None, save
597614 """
598615 os .makedirs (output_dir , exist_ok = True )
599616 device = torch .device (train_params ['device' ])
600-
617+
601618 default_train_params = {
602619 # train params
603620 "n_epochs" : 1 ,
@@ -648,11 +665,14 @@ def pretrain_quesnet(path, output_dir, pretrain_dir = None, img_dir = None, save
         meta_corpus.append(meta_vector)
 
     # train word2vec for text embedding
-    if pretrain_dir != None and load_embs:
+    if pretrain_dir is not None and load_embs:
         model.quesnet.load_emb(np.load(os.path.join(pretrain_dir, 'w2v_embs.npy')))
     else:
-        gensim_w2v = Word2Vec(sentences=[[item] for item in emb_dict.keys()], min_count=1,
-                              vector_size=emb_size)
+        gensim_w2v = Word2Vec(
+            sentences=[[item] for item in emb_dict.keys()],
+            min_count=1,
+            vector_size=emb_size
+        )
         gensim_w2v.init_weights()
         gensim_w2v.train(corpus_iterable=w2v_corpus, total_examples=len(w2v_corpus), epochs=train_params['n_epochs'])
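         # syn1neg holds the output-layer (negative-sampling) weights, used here as the word vectors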
         w2v_emb = gensim_w2v.syn1neg
@@ -667,35 +687,46 @@ def pretrain_quesnet(path, output_dir, pretrain_dir = None, img_dir = None, save
     logger.info('quesnet Word Embedding loaded')
 
     # train auto-encoder loss for image embedding
-    if pretrain_dir != None and load_embs:
+    if pretrain_dir is not None and load_embs:
         model.quesnet.load_img(torch.load(os.path.join(pretrain_dir, 'trained_ie.pt')))
     else:
         img_dataset = EmbeddingDataset(data=img_corpus, data_type='image')
-        trained_ie = pretrain_embedding_layer(dataset=img_dataset, ae=model.quesnet.ie, lr=train_params['lr'],
-                                              log_step=train_params['log_steps'], batch_size=train_params['batch_size'],
-                                              epochs=train_params['n_epochs'], device=device)
+        trained_ie = pretrain_embedding_layer(
+            dataset=img_dataset,
+            ae=model.quesnet.ie,
+            lr=train_params['lr'],
+            log_step=train_params['log_steps'],
+            batch_size=train_params['batch_size'],
+            epochs=train_params['n_epochs'],
+            device=device
+        )
     if save_embs:
         torch.save(trained_ie.state_dict(), os.path.join(output_dir, 'trained_ie.pt'))
     model.quesnet.load_img(trained_ie)
     logger.info('quesnet Image Embedding loaded')
-
 
     # train auto-encoder loss for meta embedding
-    if pretrain_dir != None and load_embs:
+    if pretrain_dir is not None and load_embs:
         model.quesnet.load_meta(torch.load(os.path.join(pretrain_dir, 'trained_me.pt')))
     else:
         meta_dataset = EmbeddingDataset(data=meta_corpus, data_type='meta')
-        trained_me = pretrain_embedding_layer(dataset=meta_dataset, ae=model.quesnet.me, lr=train_params['lr'],
-                                              log_step=train_params['log_steps'], batch_size=train_params['batch_size'],
-                                              epochs=train_params['n_epochs'], device=device)
+        trained_me = pretrain_embedding_layer(
+            dataset=meta_dataset,
+            ae=model.quesnet.me,
+            lr=train_params['lr'],
+            log_step=train_params['log_steps'],
+            batch_size=train_params['batch_size'],
+            epochs=train_params['n_epochs'],
+            device=device
+        )
     if save_embs:
         torch.save(trained_me.state_dict(), os.path.join(output_dir, 'trained_me.pt'))
     model.quesnet.load_meta(trained_me)
     logger.info('quesnet Meta Embedding loaded')
-
+
     logger.info("quesnet Word, Image and Meta Embeddings training is done")
     # DONE for datasets
-
+
     # HLM and DOO training
     dataset.pipeline = partial(model.quesnet.make_batch, device=device, pretrain=True)
     model.train()
@@ -726,6 +757,3 @@ def pretrain_quesnet(path, output_dir, pretrain_dir = None, img_dir = None, save
 
     model.save_pretrained(output_dir)
     tokenizer.save_pretrained(output_dir)
-
-
-