Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit 4aee21a

Browse files
authored
Fix CUDA device mismatch (#371)
Making devices for zero or target tensors consistent with inputs. Signed-off-by: Yaochen Xie <ethanycx@tamu.edu>
1 parent b4c66b7 commit 4aee21a

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

generative/losses/adversarial_loss.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def get_target_tensor(self, input: torch.FloatTensor, target_is_real: bool) -> t
8888
Returns:
8989
"""
9090
filling_label = self.real_label if target_is_real else self.fake_label
91-
label_tensor = torch.tensor(1).fill_(filling_label).type(input.type())
91+
label_tensor = torch.tensor(1).fill_(filling_label).type(input.type()).to(input[0].device)
9292
label_tensor.requires_grad_(False)
9393
return label_tensor.expand_as(input)
9494

@@ -101,7 +101,7 @@ def get_zero_tensor(self, input: torch.FloatTensor) -> torch.Tensor:
101101
Returns:
102102
"""
103103

104-
zero_label_tensor = torch.tensor(0).type(input[0].type())
104+
zero_label_tensor = torch.tensor(0).type(input[0].type()).to(input[0].device)
105105
zero_label_tensor.requires_grad_(False)
106106
return zero_label_tensor.expand_as(input)
107107

0 commit comments

Comments
 (0)