Skip to content

Commit 211176d

Browse files
NXP backend: Fix mean.dim delegation and tests. (#14581)
### Summary This PR updates the delegation condition for the `aten,mean,dim` operator to better reflect the requirements of Neutron. ### Test plan Unit tests provided. cc @robert-kalmar @JakeStevens @digantdesai
1 parent b1a8a68 commit 211176d

File tree

3 files changed

+149
-17
lines changed

3 files changed

+149
-17
lines changed

backends/nxp/backend/ir/converter/node_converters/ops_converters/mean_dim_converter.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
# Copyright (c) 2025 NXP
2-
# All rights reserved.
1+
# Copyright 2025 NXP
32
#
43
# This source code is licensed under the BSD-style license found in the
54
# LICENSE file in the root directory of this source tree.
@@ -20,6 +19,7 @@
2019
mean_options,
2120
)
2221
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
22+
from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT
2323
from torch.fx import Node
2424
from torch.nn import Parameter
2525

@@ -32,15 +32,33 @@ def _is_supported_on_target(
3232
parameters_mapping: dict[str, Parameter],
3333
custom_delegation_options: CustomDelegationOptions,
3434
) -> bool:
35-
dim = node.args[1]
3635
keepdim = node.args[2] if len(node.args) >= 3 else False
3736
rank = len(node.args[0].meta["val"].shape)
38-
dim = [d - rank if d > 0 else d for d in dim]
37+
dim = [MeanDimConverter._to_pos_dim(d, rank) for d in node.args[1]]
38+
39+
if rank != 4 or not keepdim:
40+
# neutron-converter/src/OperatorC/GlobalAvgPoolPlugin.cpp#74-77
41+
return False
3942

40-
# Only last 2 dimensions (H, W) and keepdim=True with rank=4 are supported on Neutron.
41-
if rank != 4 or dim not in [[-1, -2], [-2, -1]] or not keepdim:
43+
# The `mean.dim` gets converted to AveragePool by the NeutronConverter, so the channels must be a
44+
# multiple of `num_macs`.
45+
# neutron-converter/src/OperatorC/GlobalAvgPoolPlugin.cpp#59-85
46+
num_macs = neutron_target_spec.get_num_macs()
47+
channels_dim = 1 if node.meta[NXP_NODE_FORMAT].is_channels_first() else -1
48+
if (node.meta["val"].shape[channels_dim] % num_macs) != 0:
4249
return False
4350

51+
# Neutron only supports reduction over the spatial dimensions H, W.
52+
if node.meta[NXP_NODE_FORMAT].is_channels_first():
53+
# The input is NCHW. H and W are at indices 2 and 3.
54+
if dim not in [[2, 3], [3, 2]]:
55+
return False
56+
else:
57+
# The input is formatless. It can be considered as NHWC, as this is the way Neutron will look at
58+
# the dimensions. So H and W are the middle dimensions.
59+
if dim not in [[1, 2], [2, 1]]:
60+
return False
61+
4462
return True
4563

4664
@staticmethod

backends/nxp/tests/ir/converter/node_converter/test_mean_dim_converter.py

Lines changed: 123 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
# Copyright 2025 NXP
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
16
import numpy as np
27
import pytest
38
import torch
@@ -8,10 +13,12 @@
813
from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
914
from executorch.backends.nxp.tests.executors import (
1015
convert_run_compare,
16+
graph_contains_any_of_ops,
1117
ToChannelFirstPreprocess,
1218
ToChannelLastPreprocess,
1319
)
1420
from executorch.backends.nxp.tests.models import MeanDimConvModule, MeanDimLinearModule
21+
from executorch.exir.dialects._ops import ops as exir_ops
1522
from torch.export import ExportedProgram
1623

1724

@@ -21,19 +28,36 @@ def reseed_model_per_test_run():
2128
np.random.seed(23)
2229

2330

31+
class MeanDimModule(torch.nn.Module):
32+
def __init__(self, dim, keepdim):
33+
super().__init__()
34+
self.dim = dim
35+
self.keepdim = keepdim
36+
37+
def forward(self, x):
38+
return torch.mean(x, dim=self.dim, keepdim=self.keepdim)
39+
40+
2441
@pytest.mark.parametrize(
2542
"input_shape, dim",
2643
[
2744
pytest.param((1, 4, 8, 8), (-1, -2), id="Dim -1, -2."),
45+
pytest.param((1, 4, 8, 8), (-2, -1), id="Dim -2, -1."),
46+
pytest.param((1, 4, 8, 8), (2, 3), id="Dim 2, 3."),
47+
pytest.param((1, 4, 8, 8), (3, 2), id="Dim 3, 2."),
2848
],
2949
)
30-
def test_mean_dim_conv_quant_conversion(mocker, input_shape, dim, keeepdim=True):
31-
model = MeanDimConvModule(dim, keeepdim)
50+
def test_mean_dim_conv_quant_conversion(mocker, input_shape, dim, keepdim=True):
51+
model = MeanDimConvModule(dim, keepdim)
3252

3353
converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
3454

3555
# Run conversion
36-
_ = to_quantized_edge_program(model, input_shape)
56+
ep = to_quantized_edge_program(model, input_shape).exported_program()
57+
58+
# Make sure the `mean.dim` was delegated.
59+
assert not graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.mean.dim])
60+
assert any("lowered_module" in n.name for n in ep.graph.nodes)
3761

3862
# Capture generated model
3963
tflite_flatbuffers_model, io_formats = converter_spy.spy_return
@@ -61,16 +85,16 @@ def test_mean_dim_conv_quant_conversion(mocker, input_shape, dim, keeepdim=True)
6185
],
6286
)
6387
@pytest.mark.parametrize(
64-
"keeepdim",
88+
"keepdim",
6589
[
6690
pytest.param(False, id="Don't keep dim."),
6791
pytest.param(True, id="Keep dim."),
6892
],
6993
)
7094
def test_mean_dim_linear_unsupported_quant_conversion(
71-
mocker, input_shape, dim, keeepdim
95+
mocker, input_shape, dim, keepdim
7296
):
73-
model = MeanDimLinearModule(dim, keeepdim)
97+
model = MeanDimLinearModule(dim, keepdim)
7498

7599
converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
76100

@@ -107,14 +131,14 @@ def test_mean_dim_linear_unsupported_quant_conversion(
107131
],
108132
)
109133
@pytest.mark.parametrize(
110-
"keeepdim",
134+
"keepdim",
111135
[
112136
pytest.param(False, id="Don't keep dim."),
113137
pytest.param(True, id="Keep dim."),
114138
],
115139
)
116-
def test_mean_dim_conv_unsupported_quant_conversion(mocker, input_shape, dim, keeepdim):
117-
model = MeanDimConvModule(dim, keeepdim)
140+
def test_mean_dim_conv_unsupported_quant_conversion(mocker, input_shape, dim, keepdim):
141+
model = MeanDimConvModule(dim, keepdim)
118142

119143
converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
120144

@@ -140,3 +164,93 @@ def test_mean_dim_conv_unsupported_quant_conversion(mocker, input_shape, dim, ke
140164
tflite_output_preprocess=ToChannelFirstPreprocess(),
141165
tfl_model=tflite_flatbuffers_model,
142166
)
167+
168+
169+
@pytest.mark.parametrize(
170+
"input_shape, dim",
171+
[
172+
pytest.param((1, 2, 3, 8), (1, 2), id="Dim 1, 2."),
173+
pytest.param((1, 2, 3, 8), (2, 1), id="Dim 2, 1."),
174+
pytest.param((1, 2, 3, 8), (-3, -2), id="Dim -3, -2."),
175+
pytest.param((1, 2, 3, 8), (-2, -3), id="Dim -2, -3."),
176+
],
177+
)
178+
def test_mean_dim__formatless__supported(mocker, input_shape, dim, keepdim=True):
179+
model = MeanDimModule(dim, keepdim)
180+
181+
converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
182+
183+
ep = to_quantized_edge_program(model, input_shape).exported_program()
184+
185+
# Make sure the `mean.dim` was delegated.
186+
assert not graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.mean.dim])
187+
assert any("lowered_module" in n.name for n in ep.graph.nodes)
188+
189+
# Capture generated model
190+
tflite_flatbuffers_model, io_formats = converter_spy.spy_return
191+
192+
# Capture converted program
193+
exported_program: ExportedProgram = converter_spy.call_args.args[1]
194+
195+
input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8)
196+
197+
convert_run_compare(
198+
exported_program,
199+
input_data=input_data,
200+
tfl_model=tflite_flatbuffers_model,
201+
atol=1,
202+
)
203+
204+
205+
@pytest.mark.parametrize(
206+
"input_shape, dim",
207+
[
208+
pytest.param((1, 2, 3, 8), (2, 3), id="Dim 2, 3."),
209+
],
210+
)
211+
def test_mean_dim__formatless__unsupported(input_shape, dim, keepdim=True):
212+
model = MeanDimModule(dim, keepdim)
213+
214+
ep = to_quantized_edge_program(model, input_shape).exported_program()
215+
216+
# Make sure the `mean.dim` was NOT delegated.
217+
assert graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.mean.dim])
218+
assert not any("lowered_module" in n.name for n in ep.graph.nodes)
219+
220+
221+
@pytest.mark.parametrize(
222+
"input_shape, dim",
223+
[
224+
pytest.param(
225+
(1, 8, 8, 4), (1, 2), id="Dim 1, 2 (supported), channels = 4 (unsupported)."
226+
),
227+
],
228+
)
229+
def test_mean_dim__formatless__unsupported_channels(input_shape, dim, keepdim=True):
230+
model = MeanDimModule(dim, keepdim)
231+
232+
ep = to_quantized_edge_program(model, input_shape).exported_program()
233+
234+
# Make sure the `mean.dim` was NOT delegated.
235+
assert graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.mean.dim])
236+
assert not any("lowered_module" in n.name for n in ep.graph.nodes)
237+
238+
239+
@pytest.mark.parametrize(
240+
"input_shape, dim",
241+
[
242+
pytest.param(
243+
(1, 4, 8, 8), (2, 3), id="Dim 2, 3 (supported), channels = 5 (unsupported)."
244+
),
245+
],
246+
)
247+
def test_mean_dim__channels_first__unsupported_channels(input_shape, dim, keepdim=True):
248+
model = MeanDimConvModule(
249+
dim, keepdim, out_channels=5
250+
) # Only multiples of 8 (num_macs) are supported.
251+
252+
# Run conversion
253+
ep = to_quantized_edge_program(model, input_shape).exported_program()
254+
255+
# Make sure the `mean.dim` was NOT delegated.
256+
assert graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.mean.dim])

backends/nxp/tests/models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -494,9 +494,9 @@ def forward(self, x):
494494

495495

496496
class MeanDimConvModule(torch.nn.Module):
497-
def __init__(self, dim, keepdim):
497+
def __init__(self, dim, keepdim, out_channels=8):
498498
super().__init__()
499-
self.conv = Conv2dModule(stride=1, padding=1)
499+
self.conv = Conv2dModule(stride=1, padding=1, out_channels=out_channels)
500500
self.dim = dim
501501
self.keepdim = keepdim
502502

0 commit comments

Comments
 (0)