33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
55
6- from typing import Tuple
6+ from typing import Dict , Tuple
77
88import torch
99from executorch .backends .arm ._passes import FuseDuplicateUsersPass
1313input_t = Tuple [torch .Tensor ] # Input x
1414
1515
16- class FuseaAvgPool (torch .nn .Module ):
16+ class ModuleWithOps (torch .nn .Module ):
17+ ops_before_pass : Dict [str , int ]
18+ ops_after_pass : Dict [str , int ]
19+
20+
21+ class FuseaAvgPool (ModuleWithOps ):
1722 ops_before_pass = {
1823 "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default" : 3 ,
1924 }
@@ -27,7 +32,7 @@ def forward(self, x):
2732 return self .avg (x ) + self .avg (x ) + self .avg (x )
2833
2934
30- class FuseAvgPoolChain (torch . nn . Module ):
35+ class FuseAvgPoolChain (ModuleWithOps ):
3136 ops_before_pass = {
3237 "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default" : 6 ,
3338 }
@@ -44,14 +49,14 @@ def forward(self, x):
4449 return first + second + third
4550
4651
47- modules = {
52+ modules : Dict [ str , ModuleWithOps ] = {
4853 "fuse_avg_pool" : FuseaAvgPool (),
4954 "fuse_avg_pool_chain" : FuseAvgPoolChain (),
5055}
5156
5257
5358@common .parametrize ("module" , modules )
54- def test_fuse_duplicate_ops_FP (module : torch . nn . Module ):
59+ def test_fuse_duplicate_ops_FP (module : ModuleWithOps ):
5560 pipeline = PassPipeline [input_t ](
5661 module = module ,
5762 test_data = (torch .ones (1 , 1 , 1 , 1 ),),
0 commit comments