+ # Copyright 2025 NXP
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
import numpy as np
import pytest
import torch
from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
from executorch.backends.nxp.tests.executors import (
    convert_run_compare,
+     graph_contains_any_of_ops,
    ToChannelFirstPreprocess,
    ToChannelLastPreprocess,
)
from executorch.backends.nxp.tests.models import MeanDimConvModule, MeanDimLinearModule
+ from executorch.exir.dialects._ops import ops as exir_ops
from torch.export import ExportedProgram


@@ -21,19 +28,36 @@ def reseed_model_per_test_run():
    np.random.seed(23)


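+ # Plain `torch.mean` wrapper (no conv/linear producer), used by the "formatless" tests below.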
+ class MeanDimModule(torch.nn.Module):
+     def __init__(self, dim, keepdim):
+         super().__init__()
+         self.dim = dim
+         self.keepdim = keepdim
+
+     def forward(self, x):
+         return torch.mean(x, dim=self.dim, keepdim=self.keepdim)
+
+
@pytest.mark.parametrize(
    "input_shape, dim",
    [
        pytest.param((1, 4, 8, 8), (-1, -2), id="Dim -1, -2."),
+         pytest.param((1, 4, 8, 8), (-2, -1), id="Dim -2, -1."),
+         pytest.param((1, 4, 8, 8), (2, 3), id="Dim 2, 3."),
+         pytest.param((1, 4, 8, 8), (3, 2), id="Dim 3, 2."),
    ],
)
- def test_mean_dim_conv_quant_conversion(mocker, input_shape, dim, keeepdim=True):
-     model = MeanDimConvModule(dim, keeepdim)
+ def test_mean_dim_conv_quant_conversion(mocker, input_shape, dim, keepdim=True):
+     model = MeanDimConvModule(dim, keepdim)

    converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")

    # Run conversion
-     _ = to_quantized_edge_program(model, input_shape)
+     ep = to_quantized_edge_program(model, input_shape).exported_program()
+
+     # Make sure the `mean.dim` was delegated.
+     assert not graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.mean.dim])
+     assert any("lowered_module" in n.name for n in ep.graph.nodes)

    # Capture generated model
    tflite_flatbuffers_model, io_formats = converter_spy.spy_return
@@ -61,16 +85,16 @@ def test_mean_dim_conv_quant_conversion(mocker, input_shape, dim, keeepdim=True)
    ],
)
@pytest.mark.parametrize(
-     "keeepdim",
+     "keepdim",
    [
        pytest.param(False, id="Don't keep dim."),
        pytest.param(True, id="Keep dim."),
    ],
)
def test_mean_dim_linear_unsupported_quant_conversion(
-     mocker, input_shape, dim, keeepdim
+     mocker, input_shape, dim, keepdim
):
-     model = MeanDimLinearModule(dim, keeepdim)
+     model = MeanDimLinearModule(dim, keepdim)

    converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")

@@ -107,14 +131,14 @@ def test_mean_dim_linear_unsupported_quant_conversion(
    ],
)
@pytest.mark.parametrize(
-     "keeepdim",
+     "keepdim",
    [
        pytest.param(False, id="Don't keep dim."),
        pytest.param(True, id="Keep dim."),
    ],
)
- def test_mean_dim_conv_unsupported_quant_conversion(mocker, input_shape, dim, keeepdim):
-     model = MeanDimConvModule(dim, keeepdim)
+ def test_mean_dim_conv_unsupported_quant_conversion(mocker, input_shape, dim, keepdim):
+     model = MeanDimConvModule(dim, keepdim)

    converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")

@@ -140,3 +164,93 @@ def test_mean_dim_conv_unsupported_quant_conversion(mocker, input_shape, dim, keeepdim):
        tflite_output_preprocess=ToChannelFirstPreprocess(),
        tfl_model=tflite_flatbuffers_model,
    )
+
+
+ @pytest.mark.parametrize(
+     "input_shape, dim",
+     [
+         pytest.param((1, 2, 3, 8), (1, 2), id="Dim 1, 2."),
+         pytest.param((1, 2, 3, 8), (2, 1), id="Dim 2, 1."),
+         pytest.param((1, 2, 3, 8), (-3, -2), id="Dim -3, -2."),
+         pytest.param((1, 2, 3, 8), (-2, -3), id="Dim -2, -3."),
+     ],
+ )
+ def test_mean_dim__formatless__supported(mocker, input_shape, dim, keepdim=True):
+     model = MeanDimModule(dim, keepdim)
+
+     converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
+
+     ep = to_quantized_edge_program(model, input_shape).exported_program()
+
+     # Make sure the `mean.dim` was delegated.
+     assert not graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.mean.dim])
+     assert any("lowered_module" in n.name for n in ep.graph.nodes)
+
+     # Capture generated model
+     tflite_flatbuffers_model, io_formats = converter_spy.spy_return
+
+     # Capture converted program
+     exported_program: ExportedProgram = converter_spy.call_args.args[1]
+
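+     # The quantized program takes int8 inputs, so generate random int8 data in [0, 50).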
+     input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8)
+
+     convert_run_compare(
+         exported_program,
+         input_data=input_data,
+         tfl_model=tflite_flatbuffers_model,
+         atol=1,
+     )
+
+
+ @pytest.mark.parametrize(
+     "input_shape, dim",
+     [
+         pytest.param((1, 2, 3, 8), (2, 3), id="Dim 2, 3."),
+     ],
+ )
+ def test_mean_dim__formatless__unsupported(input_shape, dim, keepdim=True):
+     model = MeanDimModule(dim, keepdim)
+
+     ep = to_quantized_edge_program(model, input_shape).exported_program()
+
+     # Make sure the `mean.dim` was NOT delegated.
+     assert graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.mean.dim])
+     assert not any("lowered_module" in n.name for n in ep.graph.nodes)
+
+
+ @pytest.mark.parametrize(
+     "input_shape, dim",
+     [
+         pytest.param(
+             (1, 8, 8, 4), (1, 2), id="Dim 1, 2 (supported), channels = 4 (unsupported)."
+         ),
+     ],
+ )
+ def test_mean_dim__formatless__unsupported_channels(input_shape, dim, keepdim=True):
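+     # Channels (= 4) is not a multiple of num_macs (8), so `mean.dim` must not be delegated.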
+     model = MeanDimModule(dim, keepdim)
+
+     ep = to_quantized_edge_program(model, input_shape).exported_program()
+
+     # Make sure the `mean.dim` was NOT delegated.
+     assert graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.mean.dim])
+     assert not any("lowered_module" in n.name for n in ep.graph.nodes)
+
+
+ @pytest.mark.parametrize(
+     "input_shape, dim",
+     [
+         pytest.param(
+             (1, 4, 8, 8), (2, 3), id="Dim 2, 3 (supported), channels = 5 (unsupported)."
+         ),
+     ],
+ )
+ def test_mean_dim__channels_first__unsupported_channels(input_shape, dim, keepdim=True):
+     model = MeanDimConvModule(
+         dim, keepdim, out_channels=5
+     )  # Only multiples of 8 (num_macs) are supported.
+
+     # Run conversion
+     ep = to_quantized_edge_program(model, input_shape).exported_program()
+
+     # Make sure the `mean.dim` was NOT delegated.
+     assert graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.mean.dim])