def forward(self, inputs):
    """Encode ``inputs``, run them through the RNN and decode the result.

    Args:
        inputs: Batched input tensor; its first dimension is taken as the
            batch size (``inputs.size(0)``).

    Returns:
        The decoder output with its last dimension removed if that dimension
        has size 1. The batch dimension is always preserved, so a batch of a
        single sample yields shape ``(1,)`` rather than a 0-d tensor.
    """
    # Avoid breaking if the last batch has a different size: keep
    # self.batch_size in sync with the current batch. Presumably
    # init_hidden() sizes the initial hidden state from it -- confirm.
    self.batch_size = inputs.size(0)

    encoded = self.encoder(inputs)
    output, _ = self.rnn(encoded, self.init_hidden())

    # NOTE(review): for a standard torch.nn RNN/LSTM/GRU the last axis of
    # `output` is the hidden-feature axis, so [:, :, -1] takes the last
    # feature at every time step, not the last time step. If the intent is
    # "final time step" with batch_first=True, this should be
    # output[:, -1, :] -- confirm against the decoder's input size.
    decoded = self.decoder(output[:, :, -1])

    # squeeze(-1) instead of squeeze(): a bare squeeze() also drops the batch
    # axis when the batch holds a single sample (e.g. a short final batch).
    return decoded.squeeze(-1)