Skip to content

Commit 9c3a147

Browse files
committed
Fix #7
1 parent 5bba9be commit 9c3a147

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/torchcrf/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,6 @@ def _log_sum_exp(tensor: Variable, dim: int) -> Variable:
306306
# Add offset back
307307
return offset + safe_log_sum_exp
308308

309-
def _new(self, *args, **kwargs) -> torch.FloatTensor:
309+
def _new(self, *args, **kwargs) -> Union[torch.FloatTensor, torch.cuda.FloatTensor]:
310310
param = next(self.parameters())
311311
return param.data.new(*args, **kwargs)

0 commit comments

Comments
 (0)