
Commit 2aeee9b

Cortex_m backend: Add mul operator (#15591)
Signed-off-by: Adrian Lundell <adrian.lundell@arm.com>
1 parent 211176d commit 2aeee9b

File tree

13 files changed: +275 −54 lines

backends/cortex_m/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -57,6 +57,7 @@ set(_cortex_m_kernels__srcs
     ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_dequantize_per_tensor.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_add.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_linear.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_mul.cpp
 )

 # Generate C++ bindings to register kernels into Executorch

backends/cortex_m/ops/op_quantized_add.cpp

Lines changed: 9 additions & 0 deletions
@@ -78,6 +78,15 @@ Tensor& quantized_add_out(
       output_mult,
       output_shift_val);

+  // Note 1: The CMSIS-NN kernel implementation uses offsets which are always
+  // added to the data, whereas zero_points are subtracted when dequantizing
+  // (for the inputs) and added when quantizing (for the output). Hence the
+  // negative signs required here.
+
+  // Note 2: It is not possible to perform the same rewrite as for mul for
+  // addition. To preserve precision when rescaling the inputs, they are first
+  // upscaled as much as possible, hence the left_shift parameter required here.
+
   // Call CMSIS-NN kernel with precomputed parameters
   arm_cmsis_nn_status status = arm_elementwise_add_s8(
       input1_int8.const_data_ptr<int8_t>(),
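Note 2's left_shift rationale is easier to see numerically. Below is a schematic Python sketch with made-up scales, zero points and an illustrative LEFT_SHIFT value; it mirrors the idea (a sum of two differently scaled inputs cannot be folded into one effective scale, so both are first moved onto a common, upshifted grid), not the exact CMSIS-NN integer arithmetic.

import torch

# Illustrative values only, not the kernel's actual parameters.
LEFT_SHIFT = 20
s1, z1, s2, z2, s_out, z_out = 0.02, 3, 0.05, -8, 0.1, 5
q1 = torch.tensor([17], dtype=torch.int32)
q2 = torch.tensor([-9], dtype=torch.int32)

# Rescale each input to the upscaled output grid, add, then shift back down.
common1 = torch.round((q1 - z1) * (s1 / s_out) * (1 << LEFT_SHIFT))
common2 = torch.round((q2 - z2) * (s2 / s_out) * (1 << LEFT_SHIFT))
yq = torch.round((common1 + common2) / (1 << LEFT_SHIFT)) + z_out

# Float reference: dequantize, add, requantize.
y = s1 * (q1 - z1) + s2 * (q2 - z2)
print(yq, torch.round(y / s_out) + z_out)  # both print tensor([7.])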
backends/cortex_m/ops/op_quantized_mul.cpp

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
+/*
+ * Copyright 2025 Arm Limited and/or its affiliates.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "cortex_m_ops_common.h"
+
+// Include CMSIS-NN headers with C linkage
+extern "C" {
+#include "arm_nnfunctions.h"
+}
+
+namespace cortex_m {
+namespace native {
+namespace {
+
+constexpr int32_t kInt8ActivationMin = std::numeric_limits<int8_t>::min();
+constexpr int32_t kInt8ActivationMax = std::numeric_limits<int8_t>::max();
+
+} // namespace
+
+using KernelRuntimeContext = torch::executor::KernelRuntimeContext;
+
+Tensor& quantized_mul_out(
+    KernelRuntimeContext& context,
+    const Tensor& input1_int8,
+    const Scalar& input1_zero_point,
+    const Tensor& input2_int8,
+    const Scalar& input2_zero_point,
+    const Scalar& output_zero_point,
+    const Scalar& output_multiplier,
+    const Scalar& output_shift,
+    Tensor& out) {
+  // Validate tensor types and quantization parameters
+  validate_cmsis_nn_tensor_requirements(input1_int8, input2_int8, out);
+
+  const Scalar kIdentityMultiplier(/*value=*/1);
+  const Scalar kZeroShift(/*value=*/0);
+  validate_quantization_params(
+      input1_zero_point,
+      kIdentityMultiplier,
+      kZeroShift,
+      input2_zero_point,
+      kIdentityMultiplier,
+      kZeroShift,
+      output_zero_point,
+      output_multiplier,
+      output_shift,
+      out);
+
+  // Extract quantization parameters
+  const int32_t zp1 = extractScalarToInt32(input1_zero_point);
+  const int32_t zp2 = extractScalarToInt32(input2_zero_point);
+  const int32_t out_zp = extractScalarToInt32(output_zero_point);
+  const int32_t output_mult = extractScalarToInt32(output_multiplier);
+  const int32_t output_shift_val = extractScalarToInt32(output_shift);
+
+  // Note 1: The CMSIS-NN kernel implementation uses offsets which are always
+  // added to the data, whereas zero_points are subtracted when dequantizing
+  // (for the inputs) and added when quantizing (for the output). Hence the
+  // negative signs required here.
+
+  // Note 2: The following rewrite is used:
+  //   yq = y / scale_out + zp_out
+  //   y = x_1 * x_2
+  //   x_i = scale_in_i * (xq_i - zp_in_i), i = 1, 2
+  //   ==>
+  //   yq = (xq_1 - zp_in_1) * (xq_2 - zp_in_2) * effective_scale + zp_out
+  //   where
+  //   effective_scale = (scale_in_1 * scale_in_2) / scale_out
+  // Hence no input quantization params are required here.
+
+  // Call CMSIS-NN elementwise multiply kernel
+  arm_cmsis_nn_status status = arm_elementwise_mul_s8(
+      input1_int8.const_data_ptr<int8_t>(),
+      input2_int8.const_data_ptr<int8_t>(),
+      -static_cast<int32_t>(zp1),
+      -static_cast<int32_t>(zp2),
+      out.mutable_data_ptr<int8_t>(),
+      static_cast<int32_t>(out_zp),
+      output_mult,
+      output_shift_val,
+      kInt8ActivationMin,
+      kInt8ActivationMax,
+      static_cast<int32_t>(out.numel()));
+
+  if (status != ARM_CMSIS_NN_SUCCESS) {
+    ET_LOG(
+        Error,
+        "quantized_mul_out: arm_elementwise_mul_s8 failed with status [%d]",
+        status);
+    context.fail(Error::Internal);
+    return out;
+  }
+
+  return out;
+}
+
+} // namespace native
+} // namespace cortex_m
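The Note 2 rewrite can be checked numerically. This Python sketch uses made-up quantization parameters (nothing the kernel ships with) to confirm that one rescale of the raw integer product by effective_scale matches the dequantize, multiply, quantize reference.

import torch

scale_in1, zp_in1 = 0.02, 3
scale_in2, zp_in2 = 0.05, -8
scale_out, zp_out = 0.1, 5
effective_scale = scale_in1 * scale_in2 / scale_out

xq1 = torch.randint(-128, 128, (10,), dtype=torch.int32)
xq2 = torch.randint(-128, 128, (10,), dtype=torch.int32)

# Reference path: dequantize, multiply, quantize.
y = (scale_in1 * (xq1 - zp_in1)) * (scale_in2 * (xq2 - zp_in2))
yq_ref = torch.clamp(torch.round(y / scale_out) + zp_out, -128, 127)

# Rewrite used by the kernel: a single rescale of the integer product.
yq = torch.clamp(
    torch.round((xq1 - zp_in1) * (xq2 - zp_in2) * effective_scale) + zp_out,
    -128,
    127,
)
assert int((yq - yq_ref).abs().max()) <= 1  # equal up to float rounding at .5 ties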

backends/cortex_m/ops/operators.py

Lines changed: 70 additions & 0 deletions
@@ -138,6 +138,10 @@ def quantized_add_meta(
     output_multiplier: int,
     output_shift: int,
 ) -> torch.Tensor:
+    assert self.shape == other.shape, (
+        "Cortex-M quantized_add: broadcasting is not yet supported — "
+        f"got self.shape={self.shape}, other.shape={other.shape}"
+    )
     broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape)
     return torch.empty(broadcasted_shape, dtype=torch.int8, device=self.device)

@@ -156,6 +160,10 @@ def quantized_add_impl(
     output_multiplier: int,
     output_shift: int,
 ) -> torch.Tensor:
+    assert self.shape == other.shape, (
+        "Cortex-M quantized_add: broadcasting is not yet supported — "
+        f"got self.shape={self.shape}, other.shape={other.shape}"
+    )
     self_shifted = (self.to(torch.int32) - self_zero_point) << SHIFT_INT8
     self_fp = requantize_cmsis(self_shifted, self_multiplier, self_shift)

@@ -168,6 +176,68 @@ def quantized_add_impl(
     return result


+# ===================================================================
+# QUANTIZED MUL OPERATION DEFINITION
+# ===================================================================
+lib.define(
+    "quantized_mul("
+    "Tensor self, Scalar self_zero_point, "
+    "Tensor other, Scalar other_zero_point, "
+    "Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift) -> Tensor"
+)
+lib.define(
+    "quantized_mul.out("
+    "Tensor self, Scalar self_zero_point, "
+    "Tensor other, Scalar other_zero_point, "
+    "Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift, "
+    "*, Tensor(a!) out) -> Tensor(a!)"
+)
+
+
+@register_fake("cortex_m::quantized_mul")
+def quantized_mul_meta(
+    self: torch.Tensor,
+    self_zero_point: int,
+    other: torch.Tensor,
+    other_zero_point: int,
+    output_zero_point: int,
+    output_multiplier: int,
+    output_shift: int,
+) -> torch.Tensor:
+    # Broadcast to output shape
+    assert self.shape == other.shape, (
+        "Cortex-M quantized_mul: broadcasting is not yet supported — "
+        f"got self.shape={self.shape}, other.shape={other.shape}"
+    )
+    broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape)
+    return torch.empty(broadcasted_shape, dtype=torch.int8, device=self.device)
+
+
+@impl(lib, "quantized_mul", "CompositeExplicitAutograd")
+def quantized_mul_impl(
+    self: torch.Tensor,
+    self_zero_point: int,
+    other: torch.Tensor,
+    other_zero_point: int,
+    output_zero_point: int,
+    output_multiplier: int,
+    output_shift: int,
+) -> torch.Tensor:
+    # CMSIS-NN kernel multiplies raw int8 tensors (after zero-point offset) and
+    # only uses the output multiplier/shift for rescaling. Mirror that here to
+    # keep the composite implementation numerically aligned with the backend.
+    assert self.shape == other.shape, (
+        "Cortex-M quantized_mul: broadcasting is not yet supported — "
+        f"got self.shape={self.shape}, other.shape={other.shape}"
+    )
+    self_int = self.to(torch.int32) - self_zero_point
+    other_int = other.to(torch.int32) - other_zero_point
+    result_fp = self_int * other_int
+    result_quantized = requantize_cmsis(result_fp, output_multiplier, output_shift)
+    result = torch.clamp(result_quantized + output_zero_point, -128, 127).to(torch.int8)
+    return result
+
+
 # ===================================================================
 # QUANTIZED LINEAR OPERATION DEFINITION
 # ===================================================================
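A quick way to exercise the composite path registered above is to call the op directly. This is only a sketch: it assumes this operators.py module has been imported so that torch.ops.cortex_m.quantized_mul resolves, and every quantization parameter below is invented for illustration; in a real graph they come from the fusion pass via quantize_multiplier_aot.

import torch

a = torch.tensor([10, 20, -30], dtype=torch.int8)
b = torch.tensor([4, -5, 6], dtype=torch.int8)

out = torch.ops.cortex_m.quantized_mul(
    a, 2,        # self, self_zero_point
    b, -1,       # other, other_zero_point
    0,           # output_zero_point
    1431655765,  # output_multiplier, roughly 2/3 in Q31
    -7,          # output_shift, i.e. a further division by 2**7
)
print(out)  # int8 tensor of shape (3,)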

backends/cortex_m/ops/operators.yaml

Lines changed: 7 additions & 1 deletion
@@ -23,8 +23,14 @@
     - arg_meta: null
       kernel_name: cortex_m::quantized_add_out

+- func: cortex_m::quantized_mul.out(Tensor self, Scalar self_zero_point, Tensor other, Scalar other_zero_point, Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: cortex_m::quantized_mul_out
+
 - func: cortex_m::quantized_linear.out(Tensor input, Tensor weights, Tensor? bias, Tensor? kernel_sum, Scalar input_offset, Scalar filter_offset, Scalar output_offset, int[] requantize_multipliers, int[] requantize_shifts, Scalar activation_max, Scalar activation_min, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
   kernels:
     - arg_meta: null
-      kernel_name: cortex_m::quantized_linear_out
+      kernel_name: cortex_m::quantized_linear_out

backends/cortex_m/passes/cortex_m_pass_manager.py

Lines changed: 0 additions & 2 deletions
@@ -5,7 +5,6 @@


 from executorch.backends.arm._passes import (
-    DecorateFp32toInt32CastingPass,
     FoldAndAnnotateQParamsPass,
     ScalarsToAttributePass,
 )
@@ -29,7 +28,6 @@ class CortexMPassManager(XNNPACKPassManager):
         ReplaceQuantNodesPass,
         QuantizedOpFusionPass,
         QuantizedLinearFusionPass,
-        DecorateFp32toInt32CastingPass,
     ]

     pass_list_transform_for_annotation: list[ExportPass] = [

backends/cortex_m/passes/passes_utils.py

Lines changed: 26 additions & 8 deletions
@@ -50,14 +50,32 @@ def requantize_cmsis(
     multiplier: int,
     shift: int,
 ) -> torch.Tensor:
-    """
-    Simulate CMSIS-NN fixed-point requantization:
-    result = round(tensor * multiplier / (2 ^ shift))
-    with double rounding
-    """
-    multiplied = torch.round(tensor.to(torch.int64) * multiplier)
-    shifted = torch.round(multiplied / (2 ** (31 - shift)))
-    return shifted.to(torch.int32)
+    """Simulate CMSIS-NN's arm_nn_requantize helper."""
+
+    tensor_64 = tensor.to(torch.int64)
+    left_shift = max(shift, 0)
+    right_shift = max(-shift, 0)
+
+    # Equivalent to val * (1 << LEFT_SHIFT(shift))
+    value = tensor_64 << left_shift
+
+    # arm_nn_doubling_high_mult_no_sat(value, multiplier)
+    product = value * int(multiplier)
+    product = product + (1 << 30)
+    result = product >> 31
+
+    if right_shift:
+        remainder_mask = (1 << right_shift) - 1
+        remainder = torch.bitwise_and(result, remainder_mask)
+        result = result >> right_shift
+        threshold = remainder_mask >> 1
+        threshold_tensor = torch.full_like(result, threshold, dtype=torch.int64)
+        threshold_tensor = torch.where(
+            result < 0, threshold_tensor + 1, threshold_tensor
+        )
+        result = result + torch.where(remainder > threshold_tensor, 1, 0)
+
+    return result.to(torch.int32)


 def extract_scalar_value(node_arg) -> float:
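A worked example of the new requantize_cmsis, assuming it is imported from this module. The multiplier and shift are chosen for illustration: 1 << 30 is 0.5 in Q31 and shift -3 adds a divide by 8, so the effective scale is 0.0625.

import torch

acc = torch.tensor([100, -77], dtype=torch.int32)  # e.g. int32 accumulators
print(requantize_cmsis(acc, multiplier=1 << 30, shift=-3))
# tensor([ 6, -5], dtype=torch.int32)
print(torch.round(acc * 0.0625))  # float reference: tensor([ 6., -5.])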

backends/cortex_m/passes/quantized_op_fusion_pass.py

Lines changed: 27 additions & 0 deletions
@@ -64,6 +64,31 @@ def _get_add_replacement(self, args, meta):

         return exir_ops.edge.cortex_m.quantized_add.default, args

+    def _get_mul_replacement(self, args, meta):
+
+        # Extract values
+        scale1 = meta["input_qparams"][0].scale
+        zero_point1 = meta["input_qparams"][0].zp
+        scale2 = meta["input_qparams"][1].scale
+        zero_point2 = meta["input_qparams"][1].zp
+        output_scale = meta["output_qparams"][0].scale
+        output_zero_point = meta["output_qparams"][0].zp
+
+        scale_factor = (scale1 * scale2) / output_scale
+        output_mult, output_shift = quantize_multiplier_aot(scale_factor)
+
+        args = (
+            args[0],
+            zero_point1,
+            args[1],
+            zero_point2,
+            output_zero_point,
+            output_mult,
+            output_shift,
+        )
+
+        return exir_ops.edge.cortex_m.quantized_mul.default, args
+
     def call_operator(
         self,
         op: EdgeOpOverload,
@@ -80,6 +105,8 @@ def call_operator(
         match op:
             case exir_ops.edge.aten.add.Tensor:
                 op, args = self._get_add_replacement(args, meta)
+            case exir_ops.edge.aten.mul.Tensor:
+                op, args = self._get_mul_replacement(args, meta)
             case _:
                 pass
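quantize_multiplier_aot is only referenced in this pass. The helper below is a rough, hypothetical sketch of the decomposition it is assumed to perform, a Q31 mantissa plus a base-2 exponent such that scale_factor is approximately multiplier / 2**31 * 2**shift; the real implementation may handle edge cases differently.

import math

def quantize_multiplier_sketch(scale_factor: float) -> tuple[int, int]:
    # Hypothetical stand-in for quantize_multiplier_aot.
    if scale_factor == 0.0:
        return 0, 0
    mantissa, exponent = math.frexp(scale_factor)  # mantissa in [0.5, 1)
    multiplier = round(mantissa * (1 << 31))
    if multiplier == (1 << 31):  # rounding pushed the mantissa up to 1.0
        multiplier //= 2
        exponent += 1
    return multiplier, exponent

# scale_factor = (scale1 * scale2) / output_scale, as computed above
mult, shift = quantize_multiplier_sketch(0.05 * 0.05 / 0.1)
print(mult, shift)              # 1717986918 -5, i.e. about 0.8 * 2**-5
print(mult / 2**31 * 2**shift)  # roughly 0.025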

backends/cortex_m/quantizer/operator_configs.py

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@
 # ----------------- OPERATOR PATTERN PRESETS -----------------
 BINARY_OP_PATTERNS = [
     [torch.ops.aten.add.Tensor],
+    [torch.ops.aten.mul.Tensor],
 ]

 LINEAR_OP_PATTERNS = [

backends/cortex_m/quantizer/quantizer.py

Lines changed: 2 additions & 3 deletions
@@ -6,13 +6,12 @@

 from typing import Callable, List, Optional

-import torch
-
 from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor

 from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
 from executorch.backends.cortex_m.passes.cortex_m_pass_manager import CortexMPassManager
 from executorch.backends.cortex_m.quantizer.operator_configs import (
+    BINARY_OP_PATTERNS,
     INT8_BINARY_OPS_OPERATOR_CONFIG,
     INT8_LINEAR_OPERATOR_CONFIG,
 )
@@ -37,7 +36,7 @@ def broadcasting_filter(self, node: Optional[Node]) -> bool:
         """
         if node is None:
             return False
-        if node.target not in [torch.ops.aten.add.Tensor]:
+        if [node.target] not in BINARY_OP_PATTERNS:
             return False

         if len(node.all_input_nodes) == 2:
