
Commit 9843222

Cortex_m backend: Simplify add + linear fusion passes (#15526)
Reuses the FoldAndAnnotateQParamsPass from the Arm backend to greatly simplify the logic for fusing the ops. Additionally updates the linear kernel to be numerically correct and computes the kernel_sum ahead of time (AOT) in the quantized_linear_fusion pass; since the kernel_sum replaces the bias node, this typically adds no extra memory usage. Updates the Linear tests to mirror this, including removing the various matmul tests: now that linear is handled as a separate op rather than as a particular kind of matmul, those tests no longer apply. Removes unnecessary stub definitions in operators.py, operators.yaml, and op_quantized_linear.cpp. Leaves a few TODOs since the patch is already large.

Signed-off-by: Adrian Lundell <adrian.lundell@arm.com>
1 parent ad27841 commit 9843222
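For context on the kernel_sum mentioned above: CMSIS-NN's fully-connected kernel consumes per-output-channel weight sums through ctx.buf, which is what the fusion pass now precomputes ahead of time. The sketch below is illustrative only (not code from this commit) and assumes the common CMSIS-NN convention of folding the input offset and the original int32 bias into a single int32 per output channel, which is why the kernel_sum tensor can stand in for the bias node without extra memory; the helper name compute_kernel_sums_aot and the row-major [out_features, in_features] weight layout are assumptions.

// Illustrative sketch only -- not code from this commit.
#include <cstdint>
#include <vector>

std::vector<int32_t> compute_kernel_sums_aot(
    const int8_t* weights, // assumed row-major [out_features, in_features]
    const int32_t* bias, // nullptr if the linear has no bias
    int32_t out_features,
    int32_t in_features,
    int32_t input_offset) {
  std::vector<int32_t> kernel_sum(out_features, 0);
  for (int32_t o = 0; o < out_features; ++o) {
    int32_t row_sum = 0;
    for (int32_t i = 0; i < in_features; ++i) {
      row_sum += static_cast<int32_t>(weights[o * in_features + i]);
    }
    // Fold the input offset and the bias into one int32 per channel, so the
    // precomputed tensor has the same size as the bias it replaces.
    kernel_sum[o] = input_offset * row_sum + (bias != nullptr ? bias[o] : 0);
  }
  return kernel_sum;
}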

File tree

9 files changed (+368, −1257 lines)

backends/cortex_m/ops/cortex_m_ops_common.h

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ using Tensor = torch::executor::Tensor;
using ScalarType = executorch::aten::ScalarType;
using Scalar = torch::executor::Scalar;
using Error = executorch::runtime::Error;
+using IntArrayRef = executorch::aten::ArrayRef<int64_t>;

// From arm_nn_math_types.h
#define ARM_NN_Q31_MAX ((int32_t)(0x7FFFFFFFL))

backends/cortex_m/ops/op_quantized_linear.cpp

Lines changed: 67 additions & 128 deletions
@@ -1,12 +1,12 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
+ * Copyright 2025 Arm Limited and/or its affiliates.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

-#include "cmsis_scratch_buffer_context.h"
#include "cortex_m_ops_common.h"

extern "C" {
@@ -20,151 +20,90 @@ using KernelRuntimeContext = torch::executor::KernelRuntimeContext;
Tensor& quantized_linear_out(
    KernelRuntimeContext& context,
    const Tensor& input,
-    const Scalar& input_zero_point,
-    const Scalar& input_multiplier,
-    const Scalar& input_shift,
    const Tensor& weights,
-    const Tensor& weight_zero_point,
-    const Tensor& weight_multiplier,
-    const Tensor& weight_shift,
    const torch::executor::optional<Tensor>& bias,
-    const Tensor& bias_multiplier,
-    const Tensor& bias_shift,
-    const Tensor& scratch_buffer,
-    const Scalar& output_zero_point,
-    const Scalar& in_features,
-    const Scalar& out_features,
+    const torch::executor::optional<Tensor>& kernel_sum,
+    const Scalar& input_offset,
+    const Scalar& filter_offset,
+    const Scalar& output_offset,
+    const IntArrayRef requantize_multipliers,
+    const IntArrayRef requantize_shifts,
+    const Scalar& activation_max,
+    const Scalar& activation_min,
    Tensor& out) {
  ET_LOG(Info, "quantized_linear_out: called");
-  validate_cmsis_nn_tensor_requirements(input, weights, out);
-
-  ET_CHECK_MSG(
-      scratch_buffer.scalar_type() == ScalarType::Char,
-      "Scratch buffer must be int8");
-
-  const int32_t batch_size = input.size(0);
-  const int32_t in_feat = static_cast<int32_t>(in_features.to<int64_t>());
-  const int32_t out_feat = static_cast<int32_t>(out_features.to<int64_t>());
-  const int32_t input_zp = static_cast<int32_t>(input_zero_point.to<int64_t>());
-  const int32_t output_zp =
-      static_cast<int32_t>(output_zero_point.to<int64_t>());
-  const bool is_per_channel = (weight_zero_point.numel() > 1);

  const int8_t* input_data = input.const_data_ptr<int8_t>();
  const int8_t* weight_data = weights.const_data_ptr<int8_t>();
  const int32_t* bias_data =
      bias.has_value() ? bias.value().const_data_ptr<int32_t>() : nullptr;
+  int32_t* kernel_sum_data =
+      kernel_sum.has_value() ? kernel_sum.value().data_ptr<int32_t>() : nullptr;
  int8_t* output_data = out.mutable_data_ptr<int8_t>();
-  const int32_t* weight_zp_data = weight_zero_point.const_data_ptr<int32_t>();
-  const int32_t* weight_mult_data = weight_multiplier.const_data_ptr<int32_t>();
-  const int32_t* weight_shift_data = weight_shift.const_data_ptr<int32_t>();
-
-  if (!validate_per_channel_quant_params(
-          weight_mult_data, weight_shift_data, out_feat)) {
-    context.fail(Error::InvalidArgument);
-    return out;
-  }
-
-  // Initialize scratch buffer context (validates early)
-  CMSISScratchBufferContext scratch_ctx(
-      const_cast<Tensor&>(scratch_buffer), weights, weight_zero_point, bias);

-  scratch_ctx.compute_kernel_sums_if_needed();
-  cmsis_nn_context ctx = scratch_ctx.get_cmsis_ctx();
+  cmsis_nn_context ctx;
+  ctx.size = 0; // Not used in CMSIS-NN
+  ctx.buf = kernel_sum_data;

  // Setup CMSIS-NN parameters
  cmsis_nn_fc_params fc_params;
-  fc_params.input_offset = -input_zp;
-  fc_params.output_offset = output_zp;
-  fc_params.activation.min = std::numeric_limits<int8_t>::min();
-  fc_params.activation.max = std::numeric_limits<int8_t>::max();
-
-  cmsis_nn_dims input_dims = {1, 1, 1, in_feat};
+  fc_params.input_offset = static_cast<int32_t>(input_offset.to<int64_t>());
+  fc_params.filter_offset = static_cast<int32_t>(filter_offset.to<int64_t>());
+  fc_params.output_offset = static_cast<int32_t>(output_offset.to<int64_t>());
+  fc_params.activation.min = static_cast<int32_t>(activation_min.to<int64_t>());
+  fc_params.activation.max = static_cast<int32_t>(activation_max.to<int64_t>());
+
+  cmsis_nn_per_tensor_quant_params per_tensor_quant_params;
+  per_tensor_quant_params.multiplier =
+      static_cast<int32_t>(requantize_multipliers.at(0));
+  per_tensor_quant_params.shift = static_cast<int32_t>(requantize_shifts.at(0));
+
+  auto in_feat = input.size(input.dim() - 1);
+  auto out_feat = out.size(out.dim() - 1);
+  auto batches = 1;
+  for (size_t i = 0; i < input.dim() - 1; i++) {
+    batches *= input.size(i);
+  }
+  ET_LOG(
+      Info,
+      "in features: %d, out_features: %d, batches: %d, kernel_sum_size: %d",
+      in_feat,
+      out_feat,
+      batches,
+      kernel_sum.has_value() ? kernel_sum.value().numel() : 0);
+  ET_LOG(
+      Info,
+      "kernel_sum[0]: %d, kernel_sum[1]: %d",
+      kernel_sum_data != nullptr ? kernel_sum_data[0] : -1,
+      kernel_sum_data != nullptr ? kernel_sum_data[1] : -1);
+  cmsis_nn_dims input_dims = {batches, 1, 1, in_feat};
  cmsis_nn_dims filter_dims = {in_feat, 1, 1, out_feat};
  cmsis_nn_dims bias_dims = {1, 1, 1, out_feat};
-  cmsis_nn_dims output_dims = {1, 1, 1, out_feat};
-
-  arm_cmsis_nn_status status;
-  for (int32_t b = 0; b < batch_size; b++) {
-    const int8_t* batch_input = input_data + b * in_feat;
-    int8_t* batch_output = output_data + b * out_feat;
-
-    ET_CHECK_MSG(
-        batch_input != nullptr && weight_data != nullptr,
-        "Null input pointers");
-    ET_CHECK_MSG(in_feat > 0 && out_feat > 0, "Invalid dimensions");
-
-    if (is_per_channel) {
-      cmsis_nn_per_channel_quant_params per_channel_quant_params;
-      per_channel_quant_params.multiplier =
-          const_cast<int32_t*>(weight_mult_data);
-      per_channel_quant_params.shift = const_cast<int32_t*>(weight_shift_data);
-
-      status = arm_fully_connected_per_channel_s8(
-          &ctx,
-          &fc_params,
-          &per_channel_quant_params,
-          &input_dims,
-          batch_input,
-          &filter_dims,
-          weight_data,
-          &bias_dims,
-          bias_data,
-          &output_dims,
-          batch_output);
-    } else {
-      fc_params.filter_offset = -weight_zp_data[0];
-      cmsis_nn_per_tensor_quant_params per_tensor_quant_params;
-      per_tensor_quant_params.multiplier = weight_mult_data[0];
-      per_tensor_quant_params.shift = weight_shift_data[0];
-
-      status = arm_fully_connected_s8(
-          &ctx,
-          &fc_params,
-          &per_tensor_quant_params,
-          &input_dims,
-          batch_input,
-          &filter_dims,
-          weight_data,
-          &bias_dims,
-          bias_data,
-          &output_dims,
-          batch_output);
-    }
-
-    if (status != ARM_CMSIS_NN_SUCCESS) {
-      ET_LOG(
-          Error,
-          "quantized_linear_out: CMSIS-NN failed with status [%d]",
-          status);
-      context.fail(Error::Internal);
-      return out;
-    }
+  cmsis_nn_dims output_dims = {batches, 1, 1, out_feat};
+
+  arm_cmsis_nn_status status = arm_fully_connected_s8(
+      &ctx,
+      &fc_params,
+      &per_tensor_quant_params,
+      &input_dims,
+      input_data,
+      &filter_dims,
+      weight_data,
+      &bias_dims,
+      bias_data,
+      &output_dims,
+      output_data);
+
+  if (status != ARM_CMSIS_NN_SUCCESS) {
+    ET_LOG(
+        Error,
+        "quantized_linear_out: CMSIS-NN failed with status [%d]",
+        status);
+    context.fail(Error::Internal);
+    return out;
  }
-  return out;
-}

-// Functional variant (stub, not used at runtime)
-Tensor quantized_linear(
-    KernelRuntimeContext& context,
-    const Tensor& input,
-    const Scalar& input_zero_point,
-    const Scalar& input_multiplier,
-    const Scalar& input_shift,
-    const Tensor& weights,
-    const Tensor& weight_zero_point,
-    const Tensor& weight_multiplier,
-    const Tensor& weight_shift,
-    const torch::executor::optional<Tensor>& bias,
-    const Tensor& bias_multiplier,
-    const Tensor& bias_shift,
-    const Tensor& scratch_buffer,
-    const Scalar& output_zero_point,
-    const Scalar& in_features,
-    const Scalar& out_features) {
-  ET_LOG(Info, "quantized_linear: called");
-  assert(false);
-  return const_cast<Tensor&>(input);
+  return out;
}

} // namespace native
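The rewritten kernel flattens all leading dimensions into a single batch count, uses per-tensor requantization parameters, and hands the whole problem to one arm_fully_connected_s8 call. As a plain-scalar reference for the "numerically correct" claim in the commit message, the sketch below spells out the int8 per-tensor fully-connected arithmetic the fused op is expected to match. It is not taken from the patch: all names are illustrative, and the requantization rounding is only an approximation of CMSIS-NN's fixed-point scheme.

// Illustrative reference only -- not code from this commit.
#include <algorithm>
#include <cstdint>

// Approximate fixed-point requantization: result ~= acc * multiplier * 2^shift / 2^31.
// Assumes shift <= 0 (the usual case for fc multipliers), so 31 - shift stays in [31, 62].
inline int32_t requantize(int64_t acc, int32_t multiplier, int32_t shift) {
  int64_t v = acc * static_cast<int64_t>(multiplier);
  int32_t right = 31 - shift;
  int64_t round = int64_t{1} << (right - 1); // rounding right shift
  return static_cast<int32_t>((v + round) >> right);
}

void reference_quantized_linear(
    const int8_t* input, // [batches, in_features]
    const int8_t* weights, // [out_features, in_features]
    const int32_t* bias, // nullptr if absent
    int8_t* output, // [batches, out_features]
    int32_t batches, int32_t in_features, int32_t out_features,
    int32_t input_offset, int32_t filter_offset, int32_t output_offset,
    int32_t multiplier, int32_t shift,
    int32_t act_min, int32_t act_max) {
  for (int32_t b = 0; b < batches; ++b) {
    for (int32_t o = 0; o < out_features; ++o) {
      // Accumulate (x + input_offset) * (w + filter_offset) in wide precision.
      int64_t acc = bias != nullptr ? bias[o] : 0;
      for (int32_t i = 0; i < in_features; ++i) {
        acc += static_cast<int64_t>(input[b * in_features + i] + input_offset) *
               static_cast<int64_t>(weights[o * in_features + i] + filter_offset);
      }
      // Requantize, add the output offset, and clamp to the activation range.
      int32_t out_val = requantize(acc, multiplier, shift) + output_offset;
      out_val = std::min(std::max(out_val, act_min), act_max);
      output[b * out_features + o] = static_cast<int8_t>(out_val);
    }
  }
}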
