Skip to content

Commit ad27841

Browse files
authored
Arm backend: Fix mypy warnings in test_insert_int32_casts_after_... (#15630)
Fix mypy warnings in test_insert_int32_casts_after_int64_placeholders_pass.py about using Tensor instead of LongTensor. Signed-off-by: per.held@arm.com
1 parent d07a49a commit ad27841

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

backends/arm/test/passes/test_insert_int32_casts_after_int64_placeholders_pass.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,13 @@ def test_int64_model_tosa_FP():
5353
class UpcastToInt64ForIndexCopyInplaceModel(torch.nn.Module):
5454
aten_op = "torch.ops.aten.index_copy_.default"
5555

56-
def forward(self, x: torch.Tensor, index: torch.LongTensor, y: torch.tensor):
56+
def forward(self, x: torch.Tensor, index: torch.LongTensor, y: torch.Tensor):
5757
return x.index_copy_(0, index, y)
5858

5959
def get_inputs(self) -> input_t3:
6060
return (
6161
torch.zeros(5, 3),
62-
torch.tensor([0, 4, 2]),
62+
torch.LongTensor([0, 4, 2]),
6363
torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float),
6464
)
6565

@@ -85,13 +85,13 @@ def test_upcast_to_int64_for_index_copy_inplace_tosa_INT():
8585
class UpcastToInt64ForIndexCopyModel(torch.nn.Module):
8686
aten_op = "torch.ops.aten.index_copy.default"
8787

88-
def forward(self, x: torch.Tensor, index: torch.LongTensor, y: torch.tensor):
88+
def forward(self, x: torch.Tensor, index: torch.LongTensor, y: torch.Tensor):
8989
return x.index_copy(0, index, y)
9090

9191
def get_inputs(self) -> input_t3:
9292
return (
9393
torch.zeros(5, 3),
94-
torch.tensor([0, 4, 2]),
94+
torch.LongTensor([0, 4, 2]),
9595
torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float),
9696
)
9797

0 commit comments

Comments
 (0)