
Commit 35c1057

update codes for flake8
1 parent 729e506 commit 35c1057
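
The diff below is a set of flake8-oriented style cleanups in quesnet_vec.py and conftest.py: comparisons against None rewritten with "is" / "is not" instead of "==" / "!=", spaces removed around keyword-argument defaults, long calls and signatures wrapped over multiple lines, and blank lines normalized. A minimal, hypothetical sketch of those patterns (the function and names here are illustrative only, not code from EduNLP):

    # Hypothetical illustration of the kinds of flake8 fixes applied in this commit.
    def make_config(name=None, options=None):  # E251: no spaces around '=' in defaults
        # E711: compare with None using 'is' / 'is not', never '==' / '!='
        if options is None:
            options = {}
        # E501: wrap long calls with one argument per line
        return dict(
            name=name if name is not None else 'default',
            options=options,
        )

Such changes are typically verified by running flake8 over the touched files.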

File tree

2 files changed: +95 -68 lines changed


EduNLP/Pretrain/quesnet_vec.py

Lines changed: 95 additions & 67 deletions
@@ -157,12 +157,12 @@ def _convert_to_ids(self, item: Union[str, dict, list], key=lambda x: x,
         token_idx = []
         for _, w in enumerate(token_item):
             if isinstance(w, FigureSegment) and 'ques_figure_ids' in item.keys():
-                # image
+                # image

                 try:
                     fig_id = f"{w.src[10:-1]}"
                     fig_index = item['ques_figure_ids'].index(fig_id)
-
+
                     if self.img_dir != "":
                         fig_src = os.path.join(self.img_dir, fig_id)
                         if '.png' in item['ques_figure_paths'][fig_index]:
@@ -171,12 +171,12 @@ def _convert_to_ids(self, item: Union[str, dict, list], key=lambda x: x,
                             fig_src += '.jpg'
                     else:
                         fig_src = item['ques_figure_paths'][fig_index]
-
+
                     print(f"Open figure {fig_src}")
                     im = Image.open(fig_src)
                     im = im.resize((56, 56))
                     token_idx.append(to_grayscale(im))
-
+
                 except Exception:
                     warnings.warn('Open image error!')
                     token_idx.append(self.stoi['word'][self.img_token])
@@ -316,23 +316,26 @@ def padding(self, idx, max_length, type='word'):

     def set_img_dir(self, path):
         self.img_dir = path
-
-
+
+
 class QuesnetDataset(Dataset):
     '''
     Quesnet-specific datasets
     '''
-    def __init__(self, filename: str,
-                 tokenizer: QuesNetTokenizer = None,
-                 img_dir: str = "",
-                 meta: Optional[list] = None,
-                 content_key=lambda x: x['ques_content'],
-                 meta_key=lambda x: x['know_name'],
-                 answer_key=lambda x: x['ques_answer'],
-                 option_key=lambda x: x['ques_options'],
-                 pipeline=None,
-                 skip=0
-                 ):
+    def __init__(
+        self,
+        filename: str,
+        tokenizer: QuesNetTokenizer = None,
+        img_dir: str = "",
+        meta: Optional[list] = None,
+        content_key=lambda x: x['ques_content'],
+        meta_key=lambda x: x['know_name'],
+        answer_key=lambda x: x['ques_answer'],
+        option_key=lambda x: x['ques_options'],
+        pipeline=None,
+        skip=0
+    ):
+
         self.filename = filename
         self.skip = skip
         self.img_dir = img_dir
@@ -341,16 +344,22 @@ def __init__(self, filename: str,
         self.answer_key = answer_key
         self.option_key = option_key
         self.pipeline = pipeline
-
-        if tokenizer == None:
-            tokenizer = QuesNetTokenizer(meta=['know_name'],
-                                         img_dir=img_dir)
+
+        if tokenizer is None:
+            tokenizer = QuesNetTokenizer(
+                meta=['know_name'],
+                img_dir=img_dir
+            )
         self.tokenizer = tokenizer
         self.meta = meta if meta else tokenizer.meta
         self.load_data_lines()
-        tokenizer.set_vocab(self.lines, key=lambda x: x['ques_content'],
-                            trim_min_count=2, silent=False)
-        tokenizer.set_meta_vocab(self.lines, silent=False)
+        tokenizer.set_vocab(
+            self.lines,
+            key=lambda x: x['ques_content'],
+            trim_min_count=2,
+            silent=False
+        )
+        tokenizer.set_meta_vocab(self.lines, silent=False)


     def load_data_lines(self):
@@ -376,15 +385,13 @@ def load_data_lines(self):
                 if not line:
                     break
                 self.lines.append(json.loads(line.strip()))
-
+
         self.length = row - skip - 1
         assert self.length > 0, f'{data_dir} is empty. Or file length is less than skip length.'

-
     def __len__(self):
         return len(self.lines)

-
     def __getitem__(self, index):
         if isinstance(index, int):
             line = self.lines[index]
@@ -393,33 +400,40 @@ def __getitem__(self, index):
             token = self.tokenizer(line, key=self.content_key, meta=self.meta)
             content = token['seq_idx']
             meta = token['meta_idx']
-
-            if self.answer_key(line).isalpha() and len(self.answer_key(line)) == 1 and ord(self.answer_key(line)) < 128 and len(self.option_key(line)) > 0:
+
+            if self.answer_key(line).isalpha() and len(self.answer_key(line)) == 1 \
+                    and ord(self.answer_key(line)) < 128 and len(self.option_key(line)) > 0:
                 answer_idx = ord(self.answer_key(line).upper()) - ord('A')
                 options = self.option_key(line)
                 answer = self.tokenizer(options.pop(answer_idx), meta=self.meta)['seq_idx']
                 false_options = [(self.tokenizer(option, meta=self.meta))['seq_idx'] for option in options]
             else:
                 answer = (self.tokenizer(self.answer_key(line), meta=self.meta))['seq_idx']
                 false_options = [[0], [0], [0]]
-
-            qs = Question(id=qid, content=content, answer=answer,
-                          false_options=false_options, labels=meta)
+
+            qs = Question(
+                id=qid,
+                content=content,
+                answer=answer,
+                false_options=false_options,
+                labels=meta
+            )
+
             if callable(self.pipeline):
-                qs = self.pipeline(qs)
-
+                qs = self.pipeline(qs)
+
             return qs
-
-
+
         elif isinstance(index, slice):
             results = []
             for i in range(*index.indices(len(self))):
                 results.append(self[i])
             return results
-
+
         else:
             raise TypeError('Invalid argument type. Index type should be int or slice.')
-
+
+
 class EmbeddingDataset(Dataset):
     def __init__(self, data, data_type='image'):
         self.data = data
@@ -434,7 +448,7 @@ def __getitem__(self, idx):
             return to_tensor(self.data[idx])
         elif self.data_type == 'meta':
             return self.data[idx]
-
+

 class PrefetchIter:
     """Iterator on data and labels, with states for save and restore."""
@@ -445,7 +459,7 @@ def __init__(self, data, *label, length=None, batch_size=1, shuffle=True):
         self.batch_size = batch_size
         self.queue = queue.Queue(maxsize=8)
         self.length = length if length is not None else len(data)
-
+
         assert all(self.length == len(lab) for lab in label), \
             'data and label must have same lengths'

@@ -497,10 +511,11 @@ def produce(self):
             except Exception as e:
                 self.queue.put(e)
                 return
-
+

 sigint_handler = signal.getsignal(signal.SIGINT)

+
 def critical(f):
     it = iter(f)
     signal_received = ()
@@ -519,7 +534,7 @@ def handler(sig, frame):
             sigint_handler(*signal_received)
         except StopIteration:
             break
-
+

 def pretrain_embedding_layer(dataset: EmbeddingDataset, ae: AE, lr: float = 1e-3, log_step: int = 1, epochs: int = 3,
                              batch_size: int = 4, device=torch.device('cpu')):
@@ -538,18 +553,20 @@ def pretrain_embedding_layer(dataset: EmbeddingDataset, ae: AE, lr: float = 1e-3
             logger.info(f"[Epoch{i}][Batch{batch}]Training {train_type} Embedding layer, loss:{loss}")
     return ae

+
 def optimizer(*models, **kwargs):
-    _cur_optim = [m.optim_cls(m.parameters(), **kwargs)
-                  if hasattr(m, 'optim_cls')
-                  else torch.optim.Adam(m.parameters(), **kwargs)
-                  for m in models]
+    _cur_optim = [
+        m.optim_cls(m.parameters(), **kwargs)
+        if hasattr(m, 'optim_cls')
+        else torch.optim.Adam(m.parameters(), **kwargs) for m in models
+    ]
     if len(_cur_optim) == 1:
         return _cur_optim[0]
     else:
-        return _cur_optim
-
+        return _cur_optim
+

-def pretrain_quesnet(path, output_dir, pretrain_dir = None, img_dir = None, save_embs = False, load_embs = False, train_params = None):
+def pretrain_quesnet(path, output_dir, pretrain_dir=None, img_dir=None, save_embs=False, load_embs=False, train_params=None):
     """ pretrain quesnet

     Parameters
@@ -597,7 +614,7 @@ def pretrain_quesnet(path, output_dir, pretrain_dir = None, img_dir = None, save
     """
     os.makedirs(output_dir, exist_ok=True)
     device = torch.device(train_params['device'])
-
+
     default_train_params = {
         # train params
         "n_epochs": 1,
@@ -648,11 +665,14 @@ def pretrain_quesnet(path, output_dir, pretrain_dir = None, img_dir = None, save
         meta_corpus.append(meta_vector)

     # train word2vec for text embedding
-    if pretrain_dir != None and load_embs:
+    if pretrain_dir is not None and load_embs:
         model.quesnet.load_emb(np.load(os.path.join(output_dir, 'w2v_embs.npy')))
     else:
-        gensim_w2v = Word2Vec(sentences=[[item] for item in emb_dict.keys()], min_count=1,
-                              vector_size=emb_size)
+        gensim_w2v = Word2Vec(
+            sentences=[[item] for item in emb_dict.keys()],
+            min_count=1,
+            vector_size=emb_size
+        )
         gensim_w2v.init_weights()
         gensim_w2v.train(corpus_iterable=w2v_corpus, total_examples=len(w2v_corpus), epochs=train_params['n_epochs'])
         w2v_emb = gensim_w2v.syn1neg
@@ -667,35 +687,46 @@ def pretrain_quesnet(path, output_dir, pretrain_dir = None, img_dir = None, save
     logger.info('quesnet Word Embedding loaded')

     # train auto-encoder loss for image embedding
-    if pretrain_dir != None and load_embs:
+    if pretrain_dir is not None and load_embs:
         model.quesnet.load_img(torch.load(os.path.join(pretrain_dir, 'trained_ie.pt')))
     else:
         img_dataset = EmbeddingDataset(data=img_corpus, data_type='image')
-        trained_ie = pretrain_embedding_layer(dataset=img_dataset, ae=model.quesnet.ie, lr=train_params['lr'],
-                                              log_step=train_params['log_steps'], batch_size=train_params['batch_size'],
-                                              epochs=train_params['n_epochs'], device=device)
+        trained_ie = pretrain_embedding_layer(
+            dataset=img_dataset,
+            ae=model.quesnet.ie,
+            lr=train_params['lr'],
+            log_step=train_params['log_steps'],
+            batch_size=train_params['batch_size'],
+            epochs=train_params['n_epochs'],
+            device=device
+        )
         if save_embs:
             torch.save(trained_ie.state_dict(), os.path.join(output_dir, 'trained_ie.pt'))
         model.quesnet.load_img(trained_ie)
         logger.info('quesnet Image Embedding loaded')
-

     # train auto-encoder loss for meta embedding
-    if pretrain_dir != None and load_embs:
+    if pretrain_dir is not None and load_embs:
         model.quesnet.load_meta(torch.load(os.path.join(pretrain_dir, 'trained_me.pt')))
     else:
         meta_dateset = EmbeddingDataset(data=meta_corpus, data_type='meta')
-        trained_me = pretrain_embedding_layer(dataset=meta_dateset, ae=model.quesnet.me, lr=train_params['lr'],
-                                              log_step=train_params['log_steps'], batch_size=train_params['batch_size'],
-                                              epochs=train_params['n_epochs'], device=device)
+        trained_me = pretrain_embedding_layer(
+            dataset=meta_dateset,
+            ae=model.quesnet.me,
+            lr=train_params['lr'],
+            log_step=train_params['log_steps'],
+            batch_size=train_params['batch_size'],
+            epochs=train_params['n_epochs'],
+            device=device
+        )
         if save_embs:
             torch.save(trained_me.state_dict(), os.path.join(output_dir, 'trained_me.pt'))
         model.quesnet.load_meta(trained_me)
         logger.info('quesnet Meta Embedding loaded')
-
+
     logger.info("quesnet Word, Image and Meta Embeddings training is done")
     # DONE for datasets
-
+
     # HLM and DOO training
     dataset.pipeline = partial(model.quesnet.make_batch, device=device, pretrain=True)
     model.train()
@@ -726,6 +757,3 @@ def pretrain_quesnet(path, output_dir, pretrain_dir = None, img_dir = None, save

     model.save_pretrained(output_dir)
     tokenizer.save_pretrained(output_dir)
-
-
-

tests/test_pretrain/conftest.py

Lines changed: 0 additions & 1 deletion
@@ -9,7 +9,6 @@
 # TEST_GPU = torch.cuda.is_available()


-
 @pytest.fixture(scope="module")
 def standard_luna_data():
     data_path = path_append(abs_current_dir(__file__), "../../static/test_data/standard_luna_data.json", to_str=True)

0 commit comments
