Commit d07a49a
Cortex_m backend: Add IO quantizers + tests of non-rescaling ops (#15590)
A number of ops only handle shape/meta-data without changing the dynamic range. In these cases, no rescaling needs to be performed and the int8 portable_ops kernel can be used directly. A new test is added to ensure this behaviour, as well as a test showing that operators which do change the dynamic range (SUB) are not supported. To support quantization of graphs with no-rescale ops at the beginning/end of the graph, two new quantizers, InputQuantizer and OutputQuantizer, are introduced. By explicitly stating the dtype of the graph input/output, no-rescale ops inherit dtypes from them as with any other op. Signed-off-by: Adrian Lundell <adrian.lundell@arm.com>
1 parent 4ea9ddf commit d07a49a
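
For orientation, a minimal end-to-end sketch of how these quantizers come into play is shown below. It is not part of the commit: the composed quantizer class name (CortexMQuantizer) and the torchao PT2E prepare/convert entry points are assumed from the surrounding ExecuTorch/torchao tooling rather than shown in this diff, and the toy model is only meant to place a no-rescale op at each end of the graph.

# Minimal sketch (assumptions: the composed quantizer is exposed as
# CortexMQuantizer and the standard torchao PT2E prepare/convert helpers
# are used; neither is shown in this diff).
import torch
from executorch.backends.cortex_m.quantizer.quantizer import CortexMQuantizer
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


class ViewAddPermute(torch.nn.Module):
    """Toy model: the leading view and trailing permute do not rescale,
    so they inherit int8 from the annotated graph input/output."""

    def forward(self, x):
        x = x.view(1, 3, 20)       # no-rescale op at the start of the graph
        x = x + x                  # rescaling op, handled by the binary-ops config
        return x.permute(0, 2, 1)  # no-rescale op at the end of the graph


example_inputs = (torch.randn(1, 3, 4, 5),)
exported = torch.export.export(ViewAddPermute(), example_inputs)

quantizer = CortexMQuantizer()          # now includes Input/OutputQuantizer
prepared = prepare_pt2e(exported.module(), quantizer)
prepared(*example_inputs)               # calibration pass
quantized = convert_pt2e(prepared)

With InputQuantizer and OutputQuantizer annotating the graph input and output, the leading and trailing no-rescale ops inherit the int8 dtype instead of staying in fp32, which is what the new test below verifies across a larger set of shape/meta-data ops.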

File tree

2 files changed, +216 -0 lines changed

backends/cortex_m/quantizer/quantizer.py

Lines changed: 60 additions & 0 deletions
@@ -16,6 +16,9 @@
     INT8_BINARY_OPS_OPERATOR_CONFIG,
     INT8_LINEAR_OPERATOR_CONFIG,
 )
+from executorch.backends.cortex_m.quantizer.quantization_configs import (
+    INT8_PER_TENSOR_CONFIG,
+)
 from torch._ops import OpOverload
 from torch.fx import GraphModule, Node
 from torchao.quantization.pt2e.quantizer import (

@@ -50,6 +53,8 @@ def __init__(self) -> None:
                 INT8_BINARY_OPS_OPERATOR_CONFIG, filter_fn=self.broadcasting_filter
             ),
             OperatorConfigQuantizer(INT8_LINEAR_OPERATOR_CONFIG),
+            InputQuantizer(INT8_PER_TENSOR_CONFIG),
+            OutputQuantizer(INT8_PER_TENSOR_CONFIG),
         ]
         super().__init__(quantizers)

@@ -197,3 +202,58 @@ def annotate(self, model: GraphModule) -> None:

     def validate(self, model: GraphModule) -> bool:
         return True
+
+
+class InputQuantizer(Quantizer):
+    """
+    Quantizes only the input activations of the graph.
+    """
+
+    def __init__(
+        self,
+        quantization_config: QuantizationConfig,
+        filter_fn: Callable[[Node], bool] = lambda node: False,
+    ) -> None:
+        self.quantization_config = quantization_config
+        self.filter_fn = filter_fn
+
+    def annotate(self, model: GraphModule) -> None:
+        for node in model.graph.nodes:
+            is_placeholder = node.op == "placeholder"
+            is_filtered = self.filter_fn(node)
+            if is_placeholder and not is_filtered:
+                node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
+                    {}, self.quantization_config.output_activation
+                )
+
+    def validate(self, model: GraphModule) -> bool:
+        return True
+
+
+class OutputQuantizer(Quantizer):
+    """
+    Quantizes only the output activations of the graph.
+    """
+
+    def __init__(
+        self,
+        quantization_config: QuantizationConfig,
+        filter_fn: Callable[[Node], bool] = lambda node: False,
+    ) -> None:
+        self.quantization_config = quantization_config
+        self.filter_fn = filter_fn
+
+    def annotate(self, model: GraphModule) -> None:
+        output_node = model.graph.output_node()
+        input_qspec_map = {
+            n: self.quantization_config.input_activation
+            for n in output_node.all_input_nodes
+            if not self.filter_fn(n)
+        }
+        output_qspec = self.quantization_config.output_activation
+        output_node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
+            input_qspec_map, output_qspec
+        )
+
+    def validate(self, model: GraphModule) -> bool:
+        return True
Lines changed: 156 additions & 0 deletions (new test file)
@@ -0,0 +1,156 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import pytest
+import torch
+from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
+from executorch.backends.arm.test.common import parametrize
+from executorch.backends.cortex_m.test.tester import (
+    CortexMTester,
+    McuTestCase,
+    ramp_tensor,
+)
+from executorch.exir.dialects._ops import ops as exir_ops
+
+
+class CortexMInheritAllOps(torch.nn.Module):
+    ops_before_transforms = {
+        "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3,
+        "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3,
+    }
+
+    ops_after_transforms = {
+        "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1,
+        "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1,
+    }
+
+    def forward(self, x):
+        # x shape: (1, 3, 4, 5)
+        x = x + x
+        x = torch.ops.aten.squeeze.default(x)  # Remove dim 0: (3, 4, 5)
+        x = torch.ops.aten.unsqueeze.default(x, 0)  # Add dim at 0: (1, 3, 4, 5)
+        x = torch.ops.aten.squeeze_copy.default(x)  # (3, 4, 5)
+        x = torch.ops.aten.unsqueeze_copy.default(x, 0)  # (1, 3, 4, 5)
+        x = torch.ops.aten.squeeze.dims(x, [0])  # (3, 4, 5)
+        x = torch.ops.aten.squeeze_copy.dim(
+            x, 0
+        )  # Remove first dim if size 1, otherwise same
+        x = torch.ops.aten.squeeze.dim(x, 0)  # Same
+        x = torch.ops.aten.unbind.int(x, 0)[0]  # Unbind and take first: (4, 5)
+        x = torch.ops.aten.reshape.default(x, (1, 4, 5, 1))  # (1, 4, 5, 1)
+        x = torch.ops.aten.repeat.default(x, [1, 1, 1, 2])  # (1, 4, 5, 2)
+        x = torch.ops.aten.view.default(x, (1, 4, 10))  # (1, 4, 10)
+        target_shape = torch.zeros(1, 4, 10)
+        x = torch.ops.aten.view_as.default(x, target_shape)  # (1, 4, 10)
+        x = torch.ops.aten.view_copy.default(x, (1, 2, 20))  # (1, 2, 20)
+        x = torch.ops.aten.unflatten.int(x, 2, [4, 5])  # (1, 2, 4, 5)
+        x = torch.ops.aten.flatten.using_ints(x, 1, 3)  # (1, 40)
+        x = torch.ops.aten.repeat_interleave.self_int(x, 2, 1)  # (1, 80)
+        x = torch.ops.aten.expand_copy.default(x, (2, 80))  # (2, 80)
+        x = torch.ops.aten.expand.default(x, (2, 80))  # (2, 80)
+        x = torch.ops.aten.tile.default(x, [1, 1])  # (2, 80)
+        x = torch.ops.aten.split.Tensor(x, 40, 1)[0]  # (2, 40)
+        x = torch.ops.aten.split_with_sizes.default(x, [20, 20], 1)[0]  # (2, 20)
+        x = torch.ops.aten.split_copy.Tensor(x, 10, 1)[0]  # (2, 10)
+        x = torch.ops.aten.chunk.default(x, 2, 1)[0]  # (2, 5)
+        x = torch.ops.aten.pad.default(x, [1, 1, 0, 0], "constant", 0.0)  # (2, 7)
+        x = torch.ops.aten.select.int(x, 1, 0)  # (2,)
+        x = torch.ops.aten.select_copy.int(x, 0, 0)  # scalar -> need to reshape
+        x = torch.ops.aten.unsqueeze.default(x, 0)  # (1,)
+        x = torch.ops.aten.unsqueeze.default(x, 1)  # (1, 1)
+        x = torch.ops.aten.slice.Tensor(x, 0, 0, 1)  # (1, 1)
+        x = torch.ops.aten.slice_copy.Tensor(x, 1, 0, 1)  # (1, 1)
+        x = torch.ops.aten.reshape.default(x, (1, 1))  # Ensure shape for transpose
+        x = torch.ops.aten.transpose.int(x, 0, 1)  # (1, 1)
+        x = torch.ops.aten.transpose_copy.int(x, 0, 1)  # (1, 1)
+        x = torch.ops.aten.t_copy.default(x)  # (1, 1)
+        x = torch.ops.aten.contiguous.default(x)  # (1, 1)
+        x = torch.ops.aten.permute.default(x, [1, 0])  # (1, 1)
+        x = torch.ops.aten.permute_copy.default(x, [0, 1])  # (1, 1)
+        x = torch.ops.aten.flip.default(x, [0])  # (1, 1)
+        y = torch.zeros_like(x)
+        x = torch.ops.aten.copy_.default(y, x)  # (1, 1)
+        x = torch.ops.aten.clone.default(x)  # (1, 1)
+        return x
+
+
+class CortexMOnlyInheritOps(torch.nn.Module):
+    ops_before_transforms = {
+        "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2,
+        "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2,
+    }
+
+    ops_after_transforms = {
+        "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1,
+        "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1,
+    }
+
+    def forward(self, x):
+        return torch.permute(torch.clone(x), (0, 1, 3, 2))
+
+
+class CortexMQuantizeNonSupportedSub(torch.nn.Module):
+    ops_before_transforms = {}
+
+    ops_after_transforms = {}
+
+    def forward(self, x, y):
+        return y - x
+
+
+test_cases = {
+    "all_ops": McuTestCase(
+        CortexMInheritAllOps(),
+        (ramp_tensor(0, 10, (1, 3, 4, 5)),),
+    ),
+    "only_inherit_ops": McuTestCase(
+        CortexMOnlyInheritOps(),
+        (ramp_tensor(0, 10, (1, 3, 4, 5)),),
+    ),
+}
+
+
+@parametrize("test_case", test_cases)
+def test_inherit_int8_dtype(test_case):
+    """
+    Test that ops which do not change dynamic range are able to use int8 portable kernels.
+    """
+    tester = CortexMTester(test_case.model, test_case.example_inputs)
+    tester.test_dialect(
+        test_case.model.ops_before_transforms, test_case.model.ops_after_transforms
+    )
+
+    # Check that all nodes in the graph are in int8
+    artifact = tester.get_artifact()
+    for node in artifact.exported_program().module().graph.nodes:
+        if node.op != "call_function":
+            continue
+        if node.target == exir_ops.edge.cortex_m.dequantize_per_tensor.default:
+            continue
+
+        assert get_first_fake_tensor(node).dtype == torch.int8, f"{node.name}"
+
+
+test_cases = {
+    "sub": McuTestCase(
+        CortexMQuantizeNonSupportedSub(),
+        (ramp_tensor(0, 10, (1, 3, 4, 5)), ramp_tensor(0, 1, (1, 3, 4, 5))),
+    ),
+}
+
+
+@pytest.mark.xfail(
+    reason="Unhandled ops which change dynamic range are currently not supported."
+)
+@parametrize("test_case", test_cases)
+def test_quantize_unsupported_op(test_case):
+    """
+    Test an op which does change dynamic range and which is not supported by CMSIS-NN. Currently not supported.
+    """
+    tester = CortexMTester(test_case.model, test_case.example_inputs)
+    tester.test_dialect(
+        test_case.model.ops_before_transforms, test_case.model.ops_after_transforms
+    )
