@@ -114,6 +114,10 @@ def make_batch(self, data, device, pretrain=False):
114114 ans_input = []
115115 ans_output = []
116116 false_options = [[] for i in range (3 )]
117+
118+ if not isinstance (data , list ):
119+ data = [data ]
120+
117121 for q in data :
118122 meta = torch .zeros (len (self .stoi [self .meta ])).to (device )
119123 meta [q .labels .get (self .meta ) or []] = 1
@@ -156,7 +160,7 @@ def make_batch(self, data, device, pretrain=False):
156160
157161 for i , fo in enumerate (q .false_options ):
158162 false_options [i ].append ([0 ] + fo )
159-
163+
160164 lembs = SeqBatch (lembs , device = device )
161165 rembs = SeqBatch (rembs , device = device )
162166 embs = SeqBatch (embs , device = device )
@@ -192,6 +196,23 @@ def make_batch(self, data, device, pretrain=False):
192196 words = torch .cat (words , dim = 0 ) if words else None
193197 ims = torch .cat (ims , dim = 0 ) if ims else None
194198 metas = torch .cat (metas , dim = 0 ) if metas else None
199+
200+
201+ # print("debug1")
202+ # print(lembs)
203+ # print(rembs)
204+ # print(words)
205+ # print(ims)
206+ # print(metas)
207+ # print(wmask)
208+ # print(imask)
209+ # print(mmask)
210+ # print(embs)
211+ # print(ans_input)
212+ # print(ans_output)
213+ # print(false_opt_input)
214+
215+
195216 if pretrain :
196217 return (
197218 lembs , rembs , words , ims , metas , wmask , imask , mmask ,
@@ -302,67 +323,70 @@ def __init__(self, _stoi=None, pretrained_embs: np.ndarray = None, pretrained_im
302323 self .config = PretrainedConfig .from_dict (self .config )
303324
304325 def forward (self , batch ):
305- left , right , words , ims , metas , wmask , imask , mmask , inputs , ans_input , ans_output , false_opt_input = batch
326+ left , right , words , ims , metas , wmask , imask , mmask , inputs , ans_input , ans_output , false_opt_input = batch [ 0 ]
306327
307328 # high-level loss
308329 outputs = self .quesnet (inputs )
309330 embeded = outputs .embeded
310331 h = outputs .hidden
311332
312333 x = ans_input .packed ()
313- y , _ = self .ans_decode (PackedSequence (self .quesnet .we (x .data ), x .batch_sizes ),
334+
335+ y , _ = self .ans_decode (PackedSequence (self .quesnet .we (x [0 ].data ), x .batch_sizes ),
314336 h .repeat (self .config .layers , 1 , 1 ))
315337 floss = F .cross_entropy (self .ans_output (y .data ),
316338 ans_output .packed ().data )
317339 floss = floss + F .binary_cross_entropy_with_logits (self .ans_judge (y .data ),
318340 torch .ones_like (self .ans_judge (y .data )))
319341 for false_opt in false_opt_input :
320342 x = false_opt .packed ()
321- y , _ = self .ans_decode (PackedSequence (self .quesnet .we (x .data ), x .batch_sizes ),
343+ if x == (None , None ):
344+ continue
345+ y , _ = self .ans_decode (PackedSequence (self .quesnet .we (x [0 ].data ), x .batch_sizes ),
322346 h .repeat (self .config .layers , 1 , 1 ))
323347 floss = floss + F .binary_cross_entropy_with_logits (self .ans_judge (y .data ),
324348 torch .zeros_like (self .ans_judge (y .data )))
325349 loss = floss * self .lambda_loss [1 ]
326350 # low-level loss
327- left_hid = self .quesnet (left ).pack_embeded .data [:, :self .rnn_size ]
328- right_hid = self .quesnet (right ).pack_embeded .data [:, self .rnn_size :]
351+ left_hid = self .quesnet (left ).pack_embeded .data [:, :self .rnn_size ]. clone ()
352+ right_hid = self .quesnet (right ).pack_embeded .data [:, self .rnn_size :]. clone ()
329353
330354 wloss = iloss = mloss = None
331355
332356 if words is not None :
333- lwfea = torch .masked_select (left_hid , wmask .unsqueeze (1 ).bool ()) \
334- .view (- 1 , self .rnn_size )
335- lout = self .lwoutput (lwfea )
336- rwfea = torch .masked_select (right_hid , wmask .unsqueeze (1 ).bool ()) \
337- .view (- 1 , self .rnn_size )
338- rout = self .rwoutput (rwfea )
339- out = self .woutput (torch .cat ([lwfea , rwfea ], dim = 1 ))
357+ lwfea = torch .masked_select (left_hid . clone () , wmask .unsqueeze (1 ).bool ()) \
358+ .view (- 1 , self .rnn_size ). clone ()
359+ lout = self .lwoutput (lwfea . clone () )
360+ rwfea = torch .masked_select (right_hid . clone () , wmask .unsqueeze (1 ).bool ()) \
361+ .view (- 1 , self .rnn_size ). clone ()
362+ rout = self .rwoutput (rwfea . clone () )
363+ out = self .woutput (torch .cat ([lwfea . clone () , rwfea . clone () ], dim = 1 ). clone ( ))
340364 wloss = (F .cross_entropy (out , words ) + F .cross_entropy (lout , words ) + F .
341365 cross_entropy (rout , words )) * self .quesnet .lambda_input [0 ] / 3
342366 wloss *= self .lambda_loss [0 ]
343367 loss = loss + wloss
344368
345369 if ims is not None :
346- lifea = torch .masked_select (left_hid , imask .unsqueeze (1 ).bool ()) \
347- .view (- 1 , self .rnn_size )
348- lout = self .lioutput (lifea )
349- rifea = torch .masked_select (right_hid , imask .unsqueeze (1 ).bool ()) \
350- .view (- 1 , self .rnn_size )
351- rout = self .rioutput (rifea )
352- out = self .ioutput (torch .cat ([lifea , rifea ], dim = 1 ))
370+ lifea = torch .masked_select (left_hid . clone () , imask .unsqueeze (1 ).bool ()) \
371+ .view (- 1 , self .rnn_size ). clone ()
372+ lout = self .lioutput (lifea . clone () )
373+ rifea = torch .masked_select (right_hid . clone () , imask .unsqueeze (1 ).bool ()) \
374+ .view (- 1 , self .rnn_size ). clone ()
375+ rout = self .rioutput (rifea . clone () )
376+ out = self .ioutput (torch .cat ([lifea . clone () , rifea . clone () ], dim = 1 ). clone ( ))
353377 iloss = (self .quesnet .ie .loss (ims , out ) + self .quesnet .ie .loss (ims , lout ) + self .quesnet .ie .
354378 loss (ims , rout )) * self .quesnet .lambda_input [1 ] / 3
355379 iloss *= self .lambda_loss [0 ]
356380 loss = loss + iloss
357381
358382 if metas is not None :
359- lmfea = torch .masked_select (left_hid , mmask .unsqueeze (1 ).bool ()) \
360- .view (- 1 , self .rnn_size )
361- lout = self .lmoutput (lmfea )
362- rmfea = torch .masked_select (right_hid , mmask .unsqueeze (1 ).bool ()) \
363- .view (- 1 , self .rnn_size )
364- rout = self .rmoutput (rmfea )
365- out = self .moutput (torch .cat ([lmfea , rmfea ], dim = 1 ))
383+ lmfea = torch .masked_select (left_hid . clone () , mmask .unsqueeze (1 ).bool ()) \
384+ .view (- 1 , self .rnn_size ). clone ()
385+ lout = self .lmoutput (lmfea . clone () )
386+ rmfea = torch .masked_select (right_hid . clone () , mmask .unsqueeze (1 ).bool ()) \
387+ .view (- 1 , self .rnn_size ). clone ()
388+ rout = self .rmoutput (rmfea . clone () )
389+ out = self .moutput (torch .cat ([lmfea . clone () , rmfea . clone () ], dim = 1 ). clone ( ))
366390 mloss = (self .quesnet .me .loss (metas , out ) + self .quesnet .me .loss (metas , lout ) + self .quesnet .me .
367391 loss (metas , rout )) * self .quesnet .lambda_input [2 ] / 3
368392 mloss *= self .lambda_loss [0 ]
0 commit comments