Skip to content

Commit 984bc6a

Browse files
committed
Arm backend: Fix mypy warnings in test_fuse_duplicate...
Fix mypy warning about type. Signed-off-by: per.held@arm.com Change-Id: I09a5f75943c12b304a2e4d4ff6af8739021aeb44
1 parent 747fc6f commit 984bc6a

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

backends/arm/test/passes/test_fuse_duplicate_users_pass.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
from typing import Tuple
6+
from typing import Dict, Tuple
77

88
import torch
99
from executorch.backends.arm._passes import FuseDuplicateUsersPass
@@ -13,7 +13,12 @@
1313
input_t = Tuple[torch.Tensor] # Input x
1414

1515

16-
class FuseaAvgPool(torch.nn.Module):
16+
class ModuleWithOps(torch.nn.Module):
17+
ops_before_pass: Dict[str, int]
18+
ops_after_pass: Dict[str, int]
19+
20+
21+
class FuseaAvgPool(ModuleWithOps):
1722
ops_before_pass = {
1823
"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 3,
1924
}
@@ -27,7 +32,7 @@ def forward(self, x):
2732
return self.avg(x) + self.avg(x) + self.avg(x)
2833

2934

30-
class FuseAvgPoolChain(torch.nn.Module):
35+
class FuseAvgPoolChain(ModuleWithOps):
3136
ops_before_pass = {
3237
"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 6,
3338
}
@@ -44,14 +49,14 @@ def forward(self, x):
4449
return first + second + third
4550

4651

47-
modules = {
52+
modules: Dict[str, ModuleWithOps] = {
4853
"fuse_avg_pool": FuseaAvgPool(),
4954
"fuse_avg_pool_chain": FuseAvgPoolChain(),
5055
}
5156

5257

5358
@common.parametrize("module", modules)
54-
def test_fuse_duplicate_ops_FP(module: torch.nn.Module):
59+
def test_fuse_duplicate_ops_FP(module: ModuleWithOps):
5560
pipeline = PassPipeline[input_t](
5661
module=module,
5762
test_data=(torch.ones(1, 1, 1, 1),),

0 commit comments

Comments
 (0)