Hello @pemami4911,
The problem really was with the mask. I've fixed it and the network started to learn. My Decoder now is:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

# Var is a small helper (defined elsewhere in my code) that
# returns a learnable nn.Parameter of the given shape.

class Decoder(nn.Module):
    def __init__(self, feactures_dim, hidden_size, n_layers=1):
        super(Decoder, self).__init__()
        self.W1 = Var(hidden_size, hidden_size)
        self.W2 = Var(hidden_size, hidden_size)
        self.b2 = Var(hidden_size)
        self.V = Var(hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size,
                            batch_first=True, num_layers=n_layers)

    def forward(self, input, hidden, enc_outputs, mask, prev_idxs):
        # Additive (Bahdanau-style) attention over the encoder outputs
        w1e = torch.matmul(enc_outputs, self.W1)
        w2h = (torch.matmul(hidden[0][-1], self.W2) + self.b2).unsqueeze(1)
        u = F.tanh(w1e + w2h)
        a = torch.matmul(u, self.V)
        # Mask out already-visited indices before the softmax
        a, mask = self.apply_mask(a, mask, prev_idxs)
        a = F.softmax(a)
        res, hidden = self.lstm(input, hidden)
        return a, hidden, mask

    def apply_mask(self, attentions, mask, prev_idxs):
        if mask is None:
            mask = Variable(torch.ones(attentions.size())).cuda()
        maskk = mask.clone()
        if prev_idxs is not None:
            # Zero out the mask entries chosen at previous decoding steps
            for i, j in zip(range(attentions.size(0)), prev_idxs.data):
                maskk[i, j[0]] = 0
            # maskk.log() is -inf wherever maskk == 0, so the softmax
            # assigns those indices probability exactly 0
            masked = maskk * attentions + maskk.log()
        else:
            masked = attentions
        return masked, maskk
```
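For context, the masking trick in `apply_mask` works by adding `log(0) = -inf` to the attention scores of already-visited indices, which drives their softmax probability to exactly zero. A minimal standalone sketch of just that step (the tensors here are illustrative, not from the repo):

```python
import torch
import torch.nn.functional as F

# Attention scores over 5 candidate cities for a batch of 1
scores = torch.tensor([[1.0, 2.0, 0.5, 3.0, 1.5]])

# Mask: 1 = still selectable, 0 = already visited (indices 1 and 3)
mask = torch.tensor([[1.0, 0.0, 1.0, 0.0, 1.0]])

# mask.log() is 0 for selectable entries and -inf for visited ones,
# so adding it forces the softmax probability of visited entries to 0
masked = mask * scores + mask.log()
probs = F.softmax(masked, dim=1)

print(probs)  # indices 1 and 3 get probability exactly 0
```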
For n=10, I'm obtaining the following while training the AC (actor-critic) version:
```
Step 0
Average train model loss: -81.07959747314453
Average train critic loss: 29.443866729736328
Average train pred-reward: -0.08028505742549896
Average train reward: 5.2866997718811035
Average loss: -15.106804847717285
------------------------
Step 1000
Average train model loss: -0.7814755792869255
Average train critic loss: 0.7740849611759186
Average train pred-reward: 4.219553744537756
Average train reward: 4.272005982398987
Average loss: -6.201663847446442
------------------------
(...)
Step 19000
Average train model loss: -0.06441724334075116
Average train critic loss: 0.1361817416474223
Average train pred-reward: 3.0573679950237276
Average train reward: 3.0583059163093567
Average loss: -1.5689961900115013
```
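For reference, the critic loss and "pred-reward" above come from an actor-critic baseline: the critic predicts the reward (tour length), and the actor is updated with the advantage, i.e. reward minus prediction. A minimal sketch of that update, assuming generic log-probabilities and critic outputs (function and variable names are illustrative, not from the repo, and the sign convention depends on whether you minimize or maximize the reward):

```python
import torch
import torch.nn.functional as F

def actor_critic_losses(log_probs, rewards, pred_rewards):
    """REINFORCE-with-baseline losses.

    log_probs:    summed log-probabilities of the chosen actions, shape (batch,)
    rewards:      observed rewards (e.g. tour lengths), shape (batch,)
    pred_rewards: critic's predicted rewards, shape (batch,)
    """
    # Advantage uses a detached baseline so critic gradients
    # don't flow through the actor update
    advantage = rewards - pred_rewards.detach()
    # Actor: reweight log-probs by how much better/worse the sampled
    # tour was than the critic's estimate
    actor_loss = (advantage * log_probs).mean()
    # Critic: regress predictions toward the observed rewards
    critic_loss = F.mse_loss(pred_rewards, rewards)
    return actor_loss, critic_loss

# Toy check: when the critic is exact, the advantage (and both losses) vanish
lp = torch.tensor([-1.0, -2.0])
r = torch.tensor([3.0, 4.0])
a_loss, c_loss = actor_critic_losses(lp, r, r.clone())
```

This matches the pattern in the log: the critic loss shrinks as pred-reward converges toward the actual average reward.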
I've checked that the returned solutions are all feasible, so it seems it is really converging.
I will clean up my code, and I hope that by the end of the week I will have some training-history plots and test-set validation results.
If you like, I can share my notebook.
Best regards