Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions src/hierarchical_att_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,10 @@ def __init__(self, word_hidden_size, sent_hidden_size, batch_size, num_classes,
self.sent_att_net = SentAttNet(sent_hidden_size, word_hidden_size, num_classes)
self._init_hidden_state()

def _init_hidden_state(self, last_batch_size=None):
if last_batch_size:
batch_size = last_batch_size
else:
batch_size = self.batch_size
self.word_hidden_state = torch.zeros(2, batch_size, self.word_hidden_size)
self.sent_hidden_state = torch.zeros(2, batch_size, self.sent_hidden_size)
def _init_hidden_state(self, current_batch_size):
# Hidden state initialization always takes batch size from the train/eval batch
self.word_hidden_state = torch.zeros(2, current_batch_size, self.word_hidden_size)
self.sent_hidden_state = torch.zeros(2, current_batch_size, self.sent_hidden_size)
if torch.cuda.is_available():
self.word_hidden_state = self.word_hidden_state.cuda()
self.sent_hidden_state = self.sent_hidden_state.cuda()
Expand Down
2 changes: 1 addition & 1 deletion src/sent_att_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def forward(self, input, hidden_state):

f_output, h_output = self.gru(input, hidden_state)
output = matrix_mul(f_output, self.sent_weight, self.sent_bias)
output = matrix_mul(output, self.context_weight).permute(1, 0)
output = matrix_mul(output, self.context_weight,apply_tanh=False).permute(1, 0)
output = F.softmax(output)
output = element_wise_mul(f_output, output.permute(1, 0)).squeeze(0)
output = self.fc(output)
Expand Down
6 changes: 4 additions & 2 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@ def get_evaluation(y_true, y_prob, list_metrics):
output['confusion_matrix'] = str(metrics.confusion_matrix(y_true, y_pred))
return output

def matrix_mul(input, weight, bias=False):
def matrix_mul(input, weight, bias=False, apply_tanh=True):
feature_list = []
for feature in input:
feature = torch.mm(feature, weight)
if isinstance(bias, torch.nn.parameter.Parameter):
feature = feature + bias.expand(feature.size()[0], bias.size()[1])
feature = torch.tanh(feature).unsqueeze(0)
if apply_tanh:
feature = torch.tanh(feature)
feature = feature.unsqueeze(0)
feature_list.append(feature)

return torch.cat(feature_list, 0).squeeze()
Expand Down
2 changes: 1 addition & 1 deletion src/word_att_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def forward(self, input, hidden_state):
output = self.lookup(input)
f_output, h_output = self.gru(output.float(), hidden_state) # feature output and hidden state output
output = matrix_mul(f_output, self.word_weight, self.word_bias)
output = matrix_mul(output, self.context_weight).permute(1,0)
output = matrix_mul(output, self.context_weight,apply_tanh=False).permute(1,0)
output = F.softmax(output)
output = element_wise_mul(f_output,output.permute(1,0))

Expand Down
4 changes: 3 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ def train(opt):
feature = feature.cuda()
label = label.cuda()
optimizer.zero_grad()
model._init_hidden_state()
# Adding batch size to the
train_num_sample = len(label)
model._init_hidden_state(train_num_sample)
predictions = model(feature)
loss = criterion(predictions, label)
loss.backward()
Expand Down