@@ -114,6 +114,10 @@ def make_batch(self, data, device, pretrain=False):
         ans_input = []
         ans_output = []
         false_options = [[] for i in range(3)]
+
+        if not isinstance(data, list):
+            data = [data]
+
         for q in data:
             meta = torch.zeros(len(self.stoi[self.meta])).to(device)
             meta[q.labels.get(self.meta) or []] = 1
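The `isinstance` guard added above lets `make_batch` accept either a single question or a list of questions. A minimal standalone sketch of that normalization pattern (the `normalize_to_list` helper is illustrative only, not part of the repository):

```python
def normalize_to_list(data):
    # Wrap a single item in a one-element list so downstream code
    # can always iterate over a batch uniformly.
    if not isinstance(data, list):
        data = [data]
    return data

assert normalize_to_list("q1") == ["q1"]
assert normalize_to_list(["q1", "q2"]) == ["q1", "q2"]
```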
@@ -192,6 +196,7 @@ def make_batch(self, data, device, pretrain=False):
         words = torch.cat(words, dim=0) if words else None
         ims = torch.cat(ims, dim=0) if ims else None
         metas = torch.cat(metas, dim=0) if metas else None
+
         if pretrain:
             return (
                 lembs, rembs, words, ims, metas, wmask, imask, mmask,
@@ -302,67 +307,70 @@ def __init__(self, _stoi=None, pretrained_embs: np.ndarray = None, pretrained_im
         self.config = PretrainedConfig.from_dict(self.config)

     def forward(self, batch):
-        left, right, words, ims, metas, wmask, imask, mmask, inputs, ans_input, ans_output, false_opt_input = batch
+        left, right, words, ims, metas, wmask, imask, mmask, inputs, ans_input, ans_output, false_opt_input = batch[0]

         # high-level loss
         outputs = self.quesnet(inputs)
         embeded = outputs.embeded
         h = outputs.hidden

         x = ans_input.packed()
-        y, _ = self.ans_decode(PackedSequence(self.quesnet.we(x.data), x.batch_sizes),
+
+        y, _ = self.ans_decode(PackedSequence(self.quesnet.we(x[0].data), x.batch_sizes),
                                h.repeat(self.config.layers, 1, 1))
         floss = F.cross_entropy(self.ans_output(y.data),
                                 ans_output.packed().data)
         floss = floss + F.binary_cross_entropy_with_logits(self.ans_judge(y.data),
                                                            torch.ones_like(self.ans_judge(y.data)))
         for false_opt in false_opt_input:
             x = false_opt.packed()
-            y, _ = self.ans_decode(PackedSequence(self.quesnet.we(x.data), x.batch_sizes),
+            if x == (None, None):
+                continue
+            y, _ = self.ans_decode(PackedSequence(self.quesnet.we(x[0].data), x.batch_sizes),
                                    h.repeat(self.config.layers, 1, 1))
             floss = floss + F.binary_cross_entropy_with_logits(self.ans_judge(y.data),
                                                                torch.zeros_like(self.ans_judge(y.data)))
         loss = floss * self.lambda_loss[1]
         # low-level loss
-        left_hid = self.quesnet(left).pack_embeded.data[:, :self.rnn_size]
-        right_hid = self.quesnet(right).pack_embeded.data[:, self.rnn_size:]
+        left_hid = self.quesnet(left).pack_embeded.data[:, :self.rnn_size].clone()
+        right_hid = self.quesnet(right).pack_embeded.data[:, self.rnn_size:].clone()

         wloss = iloss = mloss = None

         if words is not None:
-            lwfea = torch.masked_select(left_hid, wmask.unsqueeze(1).bool()) \
-                .view(-1, self.rnn_size)
-            lout = self.lwoutput(lwfea)
-            rwfea = torch.masked_select(right_hid, wmask.unsqueeze(1).bool()) \
-                .view(-1, self.rnn_size)
-            rout = self.rwoutput(rwfea)
-            out = self.woutput(torch.cat([lwfea, rwfea], dim=1))
+            lwfea = torch.masked_select(left_hid.clone(), wmask.unsqueeze(1).bool()) \
+                .view(-1, self.rnn_size).clone()
+            lout = self.lwoutput(lwfea.clone())
+            rwfea = torch.masked_select(right_hid.clone(), wmask.unsqueeze(1).bool()) \
+                .view(-1, self.rnn_size).clone()
+            rout = self.rwoutput(rwfea.clone())
+            out = self.woutput(torch.cat([lwfea.clone(), rwfea.clone()], dim=1).clone())
             wloss = (F.cross_entropy(out, words) + F.cross_entropy(lout, words) + F.
                      cross_entropy(rout, words)) * self.quesnet.lambda_input[0] / 3
             wloss *= self.lambda_loss[0]
             loss = loss + wloss

         if ims is not None:
-            lifea = torch.masked_select(left_hid, imask.unsqueeze(1).bool()) \
-                .view(-1, self.rnn_size)
-            lout = self.lioutput(lifea)
-            rifea = torch.masked_select(right_hid, imask.unsqueeze(1).bool()) \
-                .view(-1, self.rnn_size)
-            rout = self.rioutput(rifea)
-            out = self.ioutput(torch.cat([lifea, rifea], dim=1))
+            lifea = torch.masked_select(left_hid.clone(), imask.unsqueeze(1).bool()) \
+                .view(-1, self.rnn_size).clone()
+            lout = self.lioutput(lifea.clone())
+            rifea = torch.masked_select(right_hid.clone(), imask.unsqueeze(1).bool()) \
+                .view(-1, self.rnn_size).clone()
+            rout = self.rioutput(rifea.clone())
+            out = self.ioutput(torch.cat([lifea.clone(), rifea.clone()], dim=1).clone())
             iloss = (self.quesnet.ie.loss(ims, out) + self.quesnet.ie.loss(ims, lout) + self.quesnet.ie.
                      loss(ims, rout)) * self.quesnet.lambda_input[1] / 3
             iloss *= self.lambda_loss[0]
             loss = loss + iloss

         if metas is not None:
-            lmfea = torch.masked_select(left_hid, mmask.unsqueeze(1).bool()) \
-                .view(-1, self.rnn_size)
-            lout = self.lmoutput(lmfea)
-            rmfea = torch.masked_select(right_hid, mmask.unsqueeze(1).bool()) \
-                .view(-1, self.rnn_size)
-            rout = self.rmoutput(rmfea)
-            out = self.moutput(torch.cat([lmfea, rmfea], dim=1))
+            lmfea = torch.masked_select(left_hid.clone(), mmask.unsqueeze(1).bool()) \
+                .view(-1, self.rnn_size).clone()
+            lout = self.lmoutput(lmfea.clone())
+            rmfea = torch.masked_select(right_hid.clone(), mmask.unsqueeze(1).bool()) \
+                .view(-1, self.rnn_size).clone()
+            rout = self.rmoutput(rmfea.clone())
+            out = self.moutput(torch.cat([lmfea.clone(), rmfea.clone()], dim=1).clone())
             mloss = (self.quesnet.me.loss(metas, out) + self.quesnet.me.loss(metas, lout) + self.quesnet.me.
                      loss(metas, rout)) * self.quesnet.lambda_input[2] / 3
             mloss *= self.lambda_loss[0]
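For reference, a self-contained sketch of the `torch.masked_select` pattern used in the low-level loss above: rows of the left hidden state are gathered wherever the mask is set, reshaped to `(-1, rnn_size)`, and fed to a linear head. The tensor sizes, the `head` layer, and the random targets are illustrative stand-ins, not the repository's modules:

```python
import torch
import torch.nn.functional as F

rnn_size, seq_len, vocab = 4, 6, 10
left_hid = torch.randn(seq_len, rnn_size)       # stand-in for pack_embeded.data[:, :rnn_size]
wmask = torch.tensor([1, 0, 1, 1, 0, 0])        # 1 where the position holds a word token
head = torch.nn.Linear(rnn_size, vocab)         # stand-in for self.lwoutput

# Broadcast the (seq_len, 1) boolean mask over the feature dimension,
# gather the selected rows, and restore the (num_selected, rnn_size) shape.
lwfea = torch.masked_select(left_hid, wmask.unsqueeze(1).bool()).view(-1, rnn_size)
lout = head(lwfea)                              # (num_selected, vocab)

words = torch.randint(0, vocab, (lwfea.size(0),))  # stand-in targets
wloss = F.cross_entropy(lout, words)
```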