Commit 729e506

fix bugs for quesnet
1 parent 21a7d6d commit 729e506

5 files changed: +87 -46 lines

EduNLP/ModelZoo/quesnet/quesnet.py

Lines changed: 51 additions & 27 deletions
@@ -114,6 +114,10 @@ def make_batch(self, data, device, pretrain=False):
         ans_input = []
         ans_output = []
         false_options = [[] for i in range(3)]
+
+        if not isinstance(data, list):
+            data = [data]
+
         for q in data:
             meta = torch.zeros(len(self.stoi[self.meta])).to(device)
             meta[q.labels.get(self.meta) or []] = 1
@@ -156,7 +160,7 @@ def make_batch(self, data, device, pretrain=False):
 
         for i, fo in enumerate(q.false_options):
             false_options[i].append([0] + fo)
-
+
         lembs = SeqBatch(lembs, device=device)
         rembs = SeqBatch(rembs, device=device)
         embs = SeqBatch(embs, device=device)
@@ -192,6 +196,23 @@ def make_batch(self, data, device, pretrain=False):
         words = torch.cat(words, dim=0) if words else None
         ims = torch.cat(ims, dim=0) if ims else None
         metas = torch.cat(metas, dim=0) if metas else None
+
+
+        # print("debug1")
+        # print(lembs)
+        # print(rembs)
+        # print(words)
+        # print(ims)
+        # print(metas)
+        # print(wmask)
+        # print(imask)
+        # print(mmask)
+        # print(embs)
+        # print(ans_input)
+        # print(ans_output)
+        # print(false_opt_input)
+
+
         if pretrain:
             return (
                 lembs, rembs, words, ims, metas, wmask, imask, mmask,
@@ -302,67 +323,70 @@ def __init__(self, _stoi=None, pretrained_embs: np.ndarray = None, pretrained_im
         self.config = PretrainedConfig.from_dict(self.config)
 
     def forward(self, batch):
-        left, right, words, ims, metas, wmask, imask, mmask, inputs, ans_input, ans_output, false_opt_input = batch
+        left, right, words, ims, metas, wmask, imask, mmask, inputs, ans_input, ans_output, false_opt_input = batch[0]
 
         # high-level loss
         outputs = self.quesnet(inputs)
         embeded = outputs.embeded
         h = outputs.hidden
 
         x = ans_input.packed()
-        y, _ = self.ans_decode(PackedSequence(self.quesnet.we(x.data), x.batch_sizes),
+
+        y, _ = self.ans_decode(PackedSequence(self.quesnet.we(x[0].data), x.batch_sizes),
                                h.repeat(self.config.layers, 1, 1))
         floss = F.cross_entropy(self.ans_output(y.data),
                                 ans_output.packed().data)
         floss = floss + F.binary_cross_entropy_with_logits(self.ans_judge(y.data),
                                                            torch.ones_like(self.ans_judge(y.data)))
         for false_opt in false_opt_input:
             x = false_opt.packed()
-            y, _ = self.ans_decode(PackedSequence(self.quesnet.we(x.data), x.batch_sizes),
+            if x == (None, None):
+                continue
+            y, _ = self.ans_decode(PackedSequence(self.quesnet.we(x[0].data), x.batch_sizes),
                                    h.repeat(self.config.layers, 1, 1))
             floss = floss + F.binary_cross_entropy_with_logits(self.ans_judge(y.data),
                                                                torch.zeros_like(self.ans_judge(y.data)))
         loss = floss * self.lambda_loss[1]
         # low-level loss
-        left_hid = self.quesnet(left).pack_embeded.data[:, :self.rnn_size]
-        right_hid = self.quesnet(right).pack_embeded.data[:, self.rnn_size:]
+        left_hid = self.quesnet(left).pack_embeded.data[:, :self.rnn_size].clone()
+        right_hid = self.quesnet(right).pack_embeded.data[:, self.rnn_size:].clone()
 
         wloss = iloss = mloss = None
 
         if words is not None:
-            lwfea = torch.masked_select(left_hid, wmask.unsqueeze(1).bool()) \
-                .view(-1, self.rnn_size)
-            lout = self.lwoutput(lwfea)
-            rwfea = torch.masked_select(right_hid, wmask.unsqueeze(1).bool()) \
-                .view(-1, self.rnn_size)
-            rout = self.rwoutput(rwfea)
-            out = self.woutput(torch.cat([lwfea, rwfea], dim=1))
+            lwfea = torch.masked_select(left_hid.clone(), wmask.unsqueeze(1).bool()) \
+                .view(-1, self.rnn_size).clone()
+            lout = self.lwoutput(lwfea.clone())
+            rwfea = torch.masked_select(right_hid.clone(), wmask.unsqueeze(1).bool()) \
+                .view(-1, self.rnn_size).clone()
+            rout = self.rwoutput(rwfea.clone())
+            out = self.woutput(torch.cat([lwfea.clone(), rwfea.clone()], dim=1).clone())
             wloss = (F.cross_entropy(out, words) + F.cross_entropy(lout, words) + F.
                      cross_entropy(rout, words)) * self.quesnet.lambda_input[0] / 3
             wloss *= self.lambda_loss[0]
             loss = loss + wloss
 
         if ims is not None:
-            lifea = torch.masked_select(left_hid, imask.unsqueeze(1).bool()) \
-                .view(-1, self.rnn_size)
-            lout = self.lioutput(lifea)
-            rifea = torch.masked_select(right_hid, imask.unsqueeze(1).bool()) \
-                .view(-1, self.rnn_size)
-            rout = self.rioutput(rifea)
-            out = self.ioutput(torch.cat([lifea, rifea], dim=1))
+            lifea = torch.masked_select(left_hid.clone(), imask.unsqueeze(1).bool()) \
+                .view(-1, self.rnn_size).clone()
+            lout = self.lioutput(lifea.clone())
+            rifea = torch.masked_select(right_hid.clone(), imask.unsqueeze(1).bool()) \
+                .view(-1, self.rnn_size).clone()
+            rout = self.rioutput(rifea.clone())
+            out = self.ioutput(torch.cat([lifea.clone(), rifea.clone()], dim=1).clone())
             iloss = (self.quesnet.ie.loss(ims, out) + self.quesnet.ie.loss(ims, lout) + self.quesnet.ie.
                      loss(ims, rout)) * self.quesnet.lambda_input[1] / 3
             iloss *= self.lambda_loss[0]
             loss = loss + iloss
 
         if metas is not None:
-            lmfea = torch.masked_select(left_hid, mmask.unsqueeze(1).bool()) \
-                .view(-1, self.rnn_size)
-            lout = self.lmoutput(lmfea)
-            rmfea = torch.masked_select(right_hid, mmask.unsqueeze(1).bool()) \
-                .view(-1, self.rnn_size)
-            rout = self.rmoutput(rmfea)
-            out = self.moutput(torch.cat([lmfea, rmfea], dim=1))
+            lmfea = torch.masked_select(left_hid.clone(), mmask.unsqueeze(1).bool()) \
+                .view(-1, self.rnn_size).clone()
+            lout = self.lmoutput(lmfea.clone())
+            rmfea = torch.masked_select(right_hid.clone(), mmask.unsqueeze(1).bool()) \
+                .view(-1, self.rnn_size).clone()
+            rout = self.rmoutput(rmfea.clone())
+            out = self.moutput(torch.cat([lmfea.clone(), rmfea.clone()], dim=1).clone())
             mloss = (self.quesnet.me.loss(metas, out) + self.quesnet.me.loss(metas, lout) + self.quesnet.me.
                      loss(metas, rout)) * self.quesnet.lambda_input[2] / 3
             mloss *= self.lambda_loss[0]
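Note: taken together, these changes let make_batch accept a single question as well as a list, and let forward skip false-option batches whose packed() call yields nothing. Below is a minimal sketch of the new guard contract; FakeSeqBatch is a hypothetical stand-in for quesnet's SeqBatch (pure Python, so it runs without PyTorch):

from typing import List, Optional, Tuple

class FakeSeqBatch:
    """Hypothetical stand-in for quesnet's SeqBatch, used only to
    illustrate the new (None, None) contract of packed()."""
    def __init__(self, seqs: List[List[int]]):
        self.seqs = seqs

    def packed(self) -> Tuple[Optional[List[List[int]]], Optional[List[int]]]:
        if not self.seqs:
            return None, None  # empty batch: nothing to pack
        return self.seqs, [len(s) for s in self.seqs]

# forward() now skips empty false-option batches instead of crashing:
for false_opt in [FakeSeqBatch([[1, 2]]), FakeSeqBatch([])]:
    x = false_opt.packed()
    if x == (None, None):  # mirrors the guard added in forward()
        continue
    print("decoding", x[0])

The pervasive .clone() calls in the low-level loss serve a different purpose: they appear to copy the masked features out of the shared packed hidden state so that later in-place operations cannot invalidate tensors autograd still needs for the backward pass, trading extra memory for safety.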

EduNLP/ModelZoo/quesnet/util.py

Lines changed: 11 additions & 1 deletion
@@ -11,14 +11,19 @@ def __init__(self, seqs, dtype=None, device=None):
         self.dtype = dtype
         self.device = device
         self.seqs = seqs
-        self.lens = [len(x) for x in seqs]
+
+        if not seqs:
+            self.lens = [0]
+        else:
+            self.lens = [len(x) for x in seqs]
 
         self.ind = argsort(self.lens)[::-1]
         self.inv = argsort(self.ind)
         self.lens.sort(reverse=True)
         self._prefix = [0]
         self._index = {}
         c = 0
+
         for i in range(self.lens[0]):
             for j in range(len(self.lens)):
                 if self.lens[j] <= i:
@@ -28,10 +33,15 @@ def __init__(self, seqs, dtype=None, device=None):
 
     def packed(self):
         ind = torch.tensor(self.ind, dtype=torch.long, device=self.device)
+        if not ind.numel() or ind.max() >= self.padded()[0].size(1):
+            return None, None
         padded = self.padded()[0].index_select(1, ind)
         return pack_padded_sequence(padded, torch.tensor(self.lens))
 
     def padded(self, max_len=None, batch_first=False):
+        if not self.seqs:
+            return torch.empty((0, 0), dtype=self.dtype, device=self.device), torch.empty((0, 0), dtype=torch.bool, device=self.device)
+
         seqs = [torch.tensor(s, dtype=self.dtype, device=self.device)
                 if not isinstance(s, torch.Tensor) else s
                 for s in self.seqs]
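Note: the net effect is that an empty SeqBatch now degrades gracefully: lens becomes [0], padded() returns empty tensors, and packed() returns (None, None) rather than raising. A small sketch of that contract, assuming PyTorch; padded_or_empty is an illustrative helper, not part of EduNLP:

import torch
from torch.nn.utils.rnn import pad_sequence

def padded_or_empty(seqs, dtype=torch.long, device=None):
    # Mirrors the early return the patch adds to SeqBatch.padded().
    if not seqs:
        return (torch.empty((0, 0), dtype=dtype, device=device),
                torch.empty((0, 0), dtype=torch.bool, device=device))
    tensors = [torch.tensor(s, dtype=dtype, device=device) for s in seqs]
    padded = pad_sequence(tensors)  # shape: (max_len, batch)
    return padded, padded.ne(0)     # simple non-padding mask

print(padded_or_empty([])[0].shape)             # torch.Size([0, 0])
print(padded_or_empty([[1, 2], [3]])[0].shape)  # torch.Size([2, 2])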

EduNLP/Pretrain/quesnet_vec.py

Lines changed: 15 additions & 9 deletions
@@ -156,8 +156,9 @@ def _convert_to_ids(self, item: Union[str, dict, list], key=lambda x: x,
         token_item = self.tokenize(item, key)
         token_idx = []
         for _, w in enumerate(token_item):
-            if isinstance(w, FigureSegment):
+            if isinstance(w, FigureSegment) and 'ques_figure_ids' in item.keys():
                 # image
+
                 try:
                     fig_id = f"{w.src[10:-1]}"
                     fig_index = item['ques_figure_ids'].index(fig_id)
@@ -171,11 +172,13 @@ def _convert_to_ids(self, item: Union[str, dict, list], key=lambda x: x,
                 else:
                     fig_src = item['ques_figure_paths'][fig_index]
 
+                    print(f"Open figure {fig_src}")
                     im = Image.open(fig_src)
                     im = im.resize((56, 56))
                     token_idx.append(to_grayscale(im))
+
                 except Exception:
-                    warnings.warn('Open image error! path = ' + fig_src)
+                    warnings.warn('Open image error!')
                     token_idx.append(self.stoi['word'][self.img_token])
             else:
                 # word
@@ -390,6 +393,7 @@ def __getitem__(self, index):
         token = self.tokenizer(line, key=self.content_key, meta=self.meta)
         content = token['seq_idx']
         meta = token['meta_idx']
+
         if self.answer_key(line).isalpha() and len(self.answer_key(line)) == 1 and ord(self.answer_key(line)) < 128 and len(self.option_key(line)) > 0:
             answer_idx = ord(self.answer_key(line).upper()) - ord('A')
             options = self.option_key(line)
@@ -441,7 +445,7 @@ def __init__(self, data, *label, length=None, batch_size=1, shuffle=True):
         self.batch_size = batch_size
         self.queue = queue.Queue(maxsize=8)
         self.length = length if length is not None else len(data)
-
+
         assert all(self.length == len(lab) for lab in label), \
             'data and label must have same lengths'
 
@@ -545,7 +549,7 @@ def optimizer(*models, **kwargs):
     return _cur_optim
 
 
-def pretrain_quesnet(path, output_dir, pretrain_dir = None, img_dir = None, save_embs = False, train_params = None):
+def pretrain_quesnet(path, output_dir, pretrain_dir = None, img_dir = None, save_embs = False, load_embs = False, train_params = None):
     """ pretrain quesnet
 
 
@@ -558,6 +562,8 @@ def pretrain_quesnet(path, output_dir, pretrain_dir = None, img_dir = None, save
         quesnet tokenizer
     save_embs : bool, optional
         whether to save pretrained word/image/meta embeddings separately
+    load_embs : bool, optional
+        whether to load pretrained word/image/meta embeddings separately
     train_params : dict, optional
         the training parameters and model parameters, by default None
         - "n_epochs": int, default = 1
@@ -609,7 +615,7 @@ def pretrain_quesnet(path, output_dir, pretrain_dir = None, img_dir = None, save
     default_train_params.update(train_params)
     train_params = default_train_params
 
-    dataset = QuesnetDataset(path)
+    dataset = QuesnetDataset(path, img_dir=img_dir)
     tokenizer = dataset.tokenizer
     tokenizer.save_pretrained(output_dir)
     model = QuesNetForPreTraining(_stoi=tokenizer.stoi, feat_size=train_params['feat_size'],
@@ -642,7 +648,7 @@ def pretrain_quesnet(path, output_dir, pretrain_dir = None, img_dir = None, save
         meta_corpus.append(meta_vector)
 
     # train word2vec for text embedding
-    if pretrain_dir != None:
+    if pretrain_dir != None and load_embs:
         model.quesnet.load_emb(np.load(os.path.join(output_dir, 'w2v_embs.npy')))
     else:
         gensim_w2v = Word2Vec(sentences=[[item] for item in emb_dict.keys()], min_count=1,
@@ -661,7 +667,7 @@ def pretrain_quesnet(path, output_dir, pretrain_dir = None, img_dir = None, save
     logger.info('quesnet Word Embedding loaded')
 
     # train auto-encoder loss for image embedding
-    if pretrain_dir != None:
+    if pretrain_dir != None and load_embs:
         model.quesnet.load_img(torch.load(os.path.join(pretrain_dir, 'trained_ie.pt')))
     else:
         img_dataset = EmbeddingDataset(data=img_corpus, data_type='image')
@@ -675,7 +681,7 @@ def pretrain_quesnet(path, output_dir, pretrain_dir = None, img_dir = None, save
 
 
     # train auto-encoder loss for meta embedding
-    if pretrain_dir != None:
+    if pretrain_dir != None and load_embs:
         model.quesnet.load_meta(torch.load(os.path.join(pretrain_dir, 'trained_me.pt')))
     else:
         meta_dateset = EmbeddingDataset(data=meta_corpus, data_type='meta')
@@ -696,7 +702,7 @@ def pretrain_quesnet(path, output_dir, pretrain_dir = None, img_dir = None, save
     optim = optimizer(model, lr=train_params['lr'])
     n_batches = 0
     for epoch in range(0, train_params['n_epochs']):
-        train_iter = PrefetchIter(dataset, train_params['batch_size'])
+        train_iter = PrefetchIter(dataset, batch_size=train_params['batch_size'])
         bar = enumerate(tqdm(train_iter, initial=train_iter.pos),
                         train_iter.pos)
         for i, batch in critical(bar):
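Note: with the new load_embs flag, a populated pretrain_dir alone no longer triggers loading of saved embeddings; callers must opt in. A hypothetical invocation under stated assumptions (the import path is assumed from the file's location in EduNLP/Pretrain, and all paths below are placeholders):

from EduNLP.Pretrain import pretrain_quesnet  # assumed export of quesnet_vec.pretrain_quesnet

# Reuse previously saved word/image/meta embeddings instead of retraining them.
pretrain_quesnet(
    path="data/questions.json",          # placeholder corpus path
    output_dir="output/quesnet",
    pretrain_dir="output/quesnet_prev",  # holds trained_ie.pt / trained_me.pt
    img_dir="data/figures",              # now forwarded to QuesnetDataset
    save_embs=True,
    load_embs=True,                      # without this, embeddings are retrained
    train_params={"n_epochs": 1, "batch_size": 8},
)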
