diff --git a/models/FCModel.py b/models/FCModel.py index f6f9a44b..6f43d7f8 100644 --- a/models/FCModel.py +++ b/models/FCModel.py @@ -177,10 +177,10 @@ def _sample(self, fc_feats, att_feats, att_masks=None, opt={}): it = it.view(-1).long() else: if temperature == 1.0: - prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1) + prob_prev = torch.exp(logprobs.data) # .cpu() # fetch prev distribution: shape Nx(M+1) else: # scale logprobs by temperature - prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() + prob_prev = torch.exp(torch.div(logprobs.data, temperature)) # .cpu() it = torch.multinomial(prob_prev, 1).cuda() sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions it = it.view(-1).long() # and flatten indices for downstream processing diff --git a/models/OldModel.py b/models/OldModel.py index 4a654034..10bd5a31 100644 --- a/models/OldModel.py +++ b/models/OldModel.py @@ -148,10 +148,10 @@ def sample(self, fc_feats, att_feats, opt={}): it = it.view(-1).long() else: if temperature == 1.0: - prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1) + prob_prev = torch.exp(logprobs.data) # .cpu() # fetch prev distribution: shape Nx(M+1) else: # scale logprobs by temperature - prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() + prob_prev = torch.exp(torch.div(logprobs.data, temperature)) # .cpu() it = torch.multinomial(prob_prev, 1).cuda() sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions it = it.view(-1).long() # and flatten indices for downstream processing diff --git a/models/ShowTellModel.py b/models/ShowTellModel.py index e466bef7..7f85dc87 100644 --- a/models/ShowTellModel.py +++ b/models/ShowTellModel.py @@ -147,10 +147,10 @@ def _sample(self, fc_feats, att_feats, att_masks=None, opt={}): it = it.view(-1).long() else: if temperature == 1.0: - prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1) + prob_prev = torch.exp(logprobs.data) # .cpu() # fetch prev distribution: shape Nx(M+1) else: # scale logprobs by temperature - prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() + prob_prev = torch.exp(torch.div(logprobs.data, temperature)) # .cpu() it = torch.multinomial(prob_prev, 1).cuda() sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions it = it.view(-1).long() # and flatten indices for downstream processing