Skip to content

Commit 9bd85db

Browse files
Fix lint errors reported by flake8
1 parent 35c1057 commit 9bd85db

File tree

3 files changed

+32
-39
lines changed

3 files changed

+32
-39
lines changed

EduNLP/ModelZoo/quesnet/quesnet.py

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def make_batch(self, data, device, pretrain=False):
160160

161161
for i, fo in enumerate(q.false_options):
162162
false_options[i].append([0] + fo)
163-
163+
164164
lembs = SeqBatch(lembs, device=device)
165165
rembs = SeqBatch(rembs, device=device)
166166
embs = SeqBatch(embs, device=device)
@@ -195,24 +195,8 @@ def make_batch(self, data, device, pretrain=False):
195195

196196
words = torch.cat(words, dim=0) if words else None
197197
ims = torch.cat(ims, dim=0) if ims else None
198-
metas = torch.cat(metas, dim=0) if metas else None
199-
200-
201-
# print("debug1")
202-
# print(lembs)
203-
# print(rembs)
204-
# print(words)
205-
# print(ims)
206-
# print(metas)
207-
# print(wmask)
208-
# print(imask)
209-
# print(mmask)
210-
# print(embs)
211-
# print(ans_input)
212-
# print(ans_output)
213-
# print(false_opt_input)
214-
215-
198+
metas = torch.cat(metas, dim=0) if metas else None
199+
216200
if pretrain:
217201
return (
218202
lembs, rembs, words, ims, metas, wmask, imask, mmask,
@@ -331,7 +315,7 @@ def forward(self, batch):
331315
h = outputs.hidden
332316

333317
x = ans_input.packed()
334-
318+
335319
y, _ = self.ans_decode(PackedSequence(self.quesnet.we(x[0].data), x.batch_sizes),
336320
h.repeat(self.config.layers, 1, 1))
337321
floss = F.cross_entropy(self.ans_output(y.data),

EduNLP/ModelZoo/quesnet/util.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def __init__(self, seqs, dtype=None, device=None):
2323
self._prefix = [0]
2424
self._index = {}
2525
c = 0
26-
26+
2727
for i in range(self.lens[0]):
2828
for j in range(len(self.lens)):
2929
if self.lens[j] <= i:
@@ -40,8 +40,9 @@ def packed(self):
4040

4141
def padded(self, max_len=None, batch_first=False):
4242
if not self.seqs:
43-
return torch.empty((0, 0), dtype=self.dtype, device=self.device), torch.empty((0, 0), dtype=torch.bool, device=self.device)
44-
43+
return torch.empty((0, 0), dtype=self.dtype, device=self.device), \
44+
torch.empty((0, 0), dtype=torch.bool, device=self.device)
45+
4546
seqs = [torch.tensor(s, dtype=self.dtype, device=self.device)
4647
if not isinstance(s, torch.Tensor) else s
4748
for s in self.seqs]

EduNLP/Pretrain/quesnet_vec.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def save_list(item2index, path):
3939
def clip(v, low, high):
4040
return max(low, min(v, high))
4141

42+
4243
# Basic unit of Dataset
4344
Question = namedtuple('Question',
4445
['id', 'content', 'answer', 'false_options', 'labels'])
@@ -334,7 +335,7 @@ def __init__(
334335
option_key=lambda x: x['ques_options'],
335336
pipeline=None,
336337
skip=0
337-
):
338+
):
338339

339340
self.filename = filename
340341
self.skip = skip
@@ -349,7 +350,7 @@ def __init__(
349350
tokenizer = QuesNetTokenizer(
350351
meta=['know_name'],
351352
img_dir=img_dir
352-
)
353+
)
353354
self.tokenizer = tokenizer
354355
self.meta = meta if meta else tokenizer.meta
355356
self.load_data_lines()
@@ -358,16 +359,15 @@ def __init__(
358359
key=lambda x: x['ques_content'],
359360
trim_min_count=2,
360361
silent=False
361-
)
362+
)
362363
tokenizer.set_meta_vocab(self.lines, silent=False)
363-
364364

365365
def load_data_lines(self):
366366
'''Read data by row from a JSON file
367-
367+
368368
Important: the data file is loaded during initialization.
369369
'''
370-
370+
371371
# TODO: All data is read into memory without chunking.
372372
# This may lead to low efficiency.
373373
data_dir = self.filename
@@ -402,7 +402,7 @@ def __getitem__(self, index):
402402
meta = token['meta_idx']
403403

404404
if self.answer_key(line).isalpha() and len(self.answer_key(line)) == 1 \
405-
and ord(self.answer_key(line)) < 128 and len(self.option_key(line)) > 0:
405+
and ord(self.answer_key(line)) < 128 and len(self.option_key(line)) > 0:
406406
answer_idx = ord(self.answer_key(line).upper()) - ord('A')
407407
options = self.option_key(line)
408408
answer = self.tokenizer(options.pop(answer_idx), meta=self.meta)['seq_idx']
@@ -417,7 +417,7 @@ def __getitem__(self, index):
417417
answer=answer,
418418
false_options=false_options,
419419
labels=meta
420-
)
420+
)
421421

422422
if callable(self.pipeline):
423423
qs = self.pipeline(qs)
@@ -556,17 +556,25 @@ def pretrain_embedding_layer(dataset: EmbeddingDataset, ae: AE, lr: float = 1e-3
556556

557557
def optimizer(*models, **kwargs):
558558
_cur_optim = [
559-
m.optim_cls(m.parameters(), **kwargs)
560-
if hasattr(m, 'optim_cls')
559+
m.optim_cls(m.parameters(), **kwargs)
560+
if hasattr(m, 'optim_cls')
561561
else torch.optim.Adam(m.parameters(), **kwargs) for m in models
562-
]
562+
]
563563
if len(_cur_optim) == 1:
564564
return _cur_optim[0]
565565
else:
566566
return _cur_optim
567567

568-
569-
def pretrain_quesnet(path, output_dir, pretrain_dir=None, img_dir=None, save_embs=False, load_embs=False, train_params=None):
568+
569+
def pretrain_quesnet(
570+
path,
571+
output_dir,
572+
pretrain_dir=None,
573+
img_dir=None,
574+
save_embs=False,
575+
load_embs=False,
576+
train_params=None
577+
):
570578
""" pretrain quesnet
571579
572580
Parameters
@@ -672,7 +680,7 @@ def pretrain_quesnet(path, output_dir, pretrain_dir=None, img_dir=None, save_emb
672680
sentences=[[item] for item in emb_dict.keys()],
673681
min_count=1,
674682
vector_size=emb_size
675-
)
683+
)
676684
gensim_w2v.init_weights()
677685
gensim_w2v.train(corpus_iterable=w2v_corpus, total_examples=len(w2v_corpus), epochs=train_params['n_epochs'])
678686
w2v_emb = gensim_w2v.syn1neg
@@ -699,7 +707,7 @@ def pretrain_quesnet(path, output_dir, pretrain_dir=None, img_dir=None, save_emb
699707
batch_size=train_params['batch_size'],
700708
epochs=train_params['n_epochs'],
701709
device=device
702-
)
710+
)
703711
if save_embs:
704712
torch.save(trained_ie.state_dict(), os.path.join(output_dir, 'trained_ie.pt'))
705713
model.quesnet.load_img(trained_ie)
@@ -718,7 +726,7 @@ def pretrain_quesnet(path, output_dir, pretrain_dir=None, img_dir=None, save_emb
718726
batch_size=train_params['batch_size'],
719727
epochs=train_params['n_epochs'],
720728
device=device
721-
)
729+
)
722730
if save_embs:
723731
torch.save(trained_me.state_dict(), os.path.join(output_dir, 'trained_me.pt'))
724732
model.quesnet.load_meta(trained_me)

0 commit comments

Comments (0)