11/*
22 * Copyright (c) Meta Platforms, Inc. and affiliates.
33 * All rights reserved.
4+ * Copyright 2025 Arm Limited and/or its affiliates.
45 *
56 * This source code is licensed under the BSD-style license found in the
67 * LICENSE file in the root directory of this source tree.
78 */
89
9- #include " cmsis_scratch_buffer_context.h"
1010#include " cortex_m_ops_common.h"
1111
1212extern " C" {
@@ -20,151 +20,90 @@ using KernelRuntimeContext = torch::executor::KernelRuntimeContext;
2020Tensor& quantized_linear_out (
2121 KernelRuntimeContext& context,
2222 const Tensor& input,
23- const Scalar& input_zero_point,
24- const Scalar& input_multiplier,
25- const Scalar& input_shift,
2623 const Tensor& weights,
27- const Tensor& weight_zero_point,
28- const Tensor& weight_multiplier,
29- const Tensor& weight_shift,
3024 const torch::executor::optional<Tensor>& bias,
31- const Tensor& bias_multiplier,
32- const Tensor& bias_shift,
33- const Tensor& scratch_buffer,
34- const Scalar& output_zero_point,
35- const Scalar& in_features,
36- const Scalar& out_features,
25+ const torch::executor::optional<Tensor>& kernel_sum,
26+ const Scalar& input_offset,
27+ const Scalar& filter_offset,
28+ const Scalar& output_offset,
29+ const IntArrayRef requantize_multipliers,
30+ const IntArrayRef requantize_shifts,
31+ const Scalar& activation_max,
32+ const Scalar& activation_min,
3733 Tensor& out) {
3834 ET_LOG (Info, " quantized_linear_out: called" );
39- validate_cmsis_nn_tensor_requirements (input, weights, out);
40-
41- ET_CHECK_MSG (
42- scratch_buffer.scalar_type () == ScalarType::Char,
43- " Scratch buffer must be int8" );
44-
45- const int32_t batch_size = input.size (0 );
46- const int32_t in_feat = static_cast <int32_t >(in_features.to <int64_t >());
47- const int32_t out_feat = static_cast <int32_t >(out_features.to <int64_t >());
48- const int32_t input_zp = static_cast <int32_t >(input_zero_point.to <int64_t >());
49- const int32_t output_zp =
50- static_cast <int32_t >(output_zero_point.to <int64_t >());
51- const bool is_per_channel = (weight_zero_point.numel () > 1 );
5235
5336 const int8_t * input_data = input.const_data_ptr <int8_t >();
5437 const int8_t * weight_data = weights.const_data_ptr <int8_t >();
5538 const int32_t * bias_data =
5639 bias.has_value () ? bias.value ().const_data_ptr <int32_t >() : nullptr ;
40+ int32_t * kernel_sum_data =
41+ kernel_sum.has_value () ? kernel_sum.value ().data_ptr <int32_t >() : nullptr ;
5742 int8_t * output_data = out.mutable_data_ptr <int8_t >();
58- const int32_t * weight_zp_data = weight_zero_point.const_data_ptr <int32_t >();
59- const int32_t * weight_mult_data = weight_multiplier.const_data_ptr <int32_t >();
60- const int32_t * weight_shift_data = weight_shift.const_data_ptr <int32_t >();
61-
62- if (!validate_per_channel_quant_params (
63- weight_mult_data, weight_shift_data, out_feat)) {
64- context.fail (Error::InvalidArgument);
65- return out;
66- }
67-
68- // Initialize scratch buffer context (validates early)
69- CMSISScratchBufferContext scratch_ctx (
70- const_cast <Tensor&>(scratch_buffer), weights, weight_zero_point, bias);
7143
72- scratch_ctx.compute_kernel_sums_if_needed ();
73- cmsis_nn_context ctx = scratch_ctx.get_cmsis_ctx ();
44+ cmsis_nn_context ctx;
45+ ctx.size = 0 ; // Not used in CMSIS-NN
46+ ctx.buf = kernel_sum_data;
7447
7548 // Setup CMSIS-NN parameters
7649 cmsis_nn_fc_params fc_params;
77- fc_params.input_offset = -input_zp;
78- fc_params.output_offset = output_zp;
79- fc_params.activation .min = std::numeric_limits<int8_t >::min ();
80- fc_params.activation .max = std::numeric_limits<int8_t >::max ();
81-
82- cmsis_nn_dims input_dims = {1 , 1 , 1 , in_feat};
50+ fc_params.input_offset = static_cast <int32_t >(input_offset.to <int64_t >());
51+ fc_params.filter_offset = static_cast <int32_t >(filter_offset.to <int64_t >());
52+ fc_params.output_offset = static_cast <int32_t >(output_offset.to <int64_t >());
53+ fc_params.activation .min = static_cast <int32_t >(activation_min.to <int64_t >());
54+ fc_params.activation .max = static_cast <int32_t >(activation_max.to <int64_t >());
55+
56+ cmsis_nn_per_tensor_quant_params per_tensor_quant_params;
57+ per_tensor_quant_params.multiplier =
58+ static_cast <int32_t >(requantize_multipliers.at (0 ));
59+ per_tensor_quant_params.shift = static_cast <int32_t >(requantize_shifts.at (0 ));
60+
61+ auto in_feat = input.size (input.dim () - 1 );
62+ auto out_feat = out.size (out.dim () - 1 );
63+ auto batches = 1 ;
64+ for (size_t i = 0 ; i < input.dim () - 1 ; i++) {
65+ batches *= input.size (i);
66+ }
67+ ET_LOG (
68+ Info,
69+ " in features: %d, out_features: %d, batches: %d, kernel_sum_size: %d" ,
70+ in_feat,
71+ out_feat,
72+ batches,
73+ kernel_sum.has_value () ? kernel_sum.value ().numel () : 0 );
74+ ET_LOG (
75+ Info,
76+ " kernel_sum[0]: %d, kernel_sum[1]: %d" ,
77+ kernel_sum_data != nullptr ? kernel_sum_data[0 ] : -1 ,
78+ kernel_sum_data != nullptr ? kernel_sum_data[1 ] : -1 );
79+ cmsis_nn_dims input_dims = {batches, 1 , 1 , in_feat};
8380 cmsis_nn_dims filter_dims = {in_feat, 1 , 1 , out_feat};
8481 cmsis_nn_dims bias_dims = {1 , 1 , 1 , out_feat};
85- cmsis_nn_dims output_dims = {1 , 1 , 1 , out_feat};
86-
87- arm_cmsis_nn_status status;
88- for (int32_t b = 0 ; b < batch_size; b++) {
89- const int8_t * batch_input = input_data + b * in_feat;
90- int8_t * batch_output = output_data + b * out_feat;
91-
92- ET_CHECK_MSG (
93- batch_input != nullptr && weight_data != nullptr ,
94- " Null input pointers" );
95- ET_CHECK_MSG (in_feat > 0 && out_feat > 0 , " Invalid dimensions" );
96-
97- if (is_per_channel) {
98- cmsis_nn_per_channel_quant_params per_channel_quant_params;
99- per_channel_quant_params.multiplier =
100- const_cast <int32_t *>(weight_mult_data);
101- per_channel_quant_params.shift = const_cast <int32_t *>(weight_shift_data);
102-
103- status = arm_fully_connected_per_channel_s8 (
104- &ctx,
105- &fc_params,
106- &per_channel_quant_params,
107- &input_dims,
108- batch_input,
109- &filter_dims,
110- weight_data,
111- &bias_dims,
112- bias_data,
113- &output_dims,
114- batch_output);
115- } else {
116- fc_params.filter_offset = -weight_zp_data[0 ];
117- cmsis_nn_per_tensor_quant_params per_tensor_quant_params;
118- per_tensor_quant_params.multiplier = weight_mult_data[0 ];
119- per_tensor_quant_params.shift = weight_shift_data[0 ];
120-
121- status = arm_fully_connected_s8 (
122- &ctx,
123- &fc_params,
124- &per_tensor_quant_params,
125- &input_dims,
126- batch_input,
127- &filter_dims,
128- weight_data,
129- &bias_dims,
130- bias_data,
131- &output_dims,
132- batch_output);
133- }
134-
135- if (status != ARM_CMSIS_NN_SUCCESS) {
136- ET_LOG (
137- Error,
138- " quantized_linear_out: CMSIS-NN failed with status [%d]" ,
139- status);
140- context.fail (Error::Internal);
141- return out;
142- }
82+ cmsis_nn_dims output_dims = {batches, 1 , 1 , out_feat};
83+
84+ arm_cmsis_nn_status status = arm_fully_connected_s8 (
85+ &ctx,
86+ &fc_params,
87+ &per_tensor_quant_params,
88+ &input_dims,
89+ input_data,
90+ &filter_dims,
91+ weight_data,
92+ &bias_dims,
93+ bias_data,
94+ &output_dims,
95+ output_data);
96+
97+ if (status != ARM_CMSIS_NN_SUCCESS) {
98+ ET_LOG (
99+ Error,
100+ " quantized_linear_out: CMSIS-NN failed with status [%d]" ,
101+ status);
102+ context.fail (Error::Internal);
103+ return out;
143104 }
144- return out;
145- }
146105
147- // Functional variant (stub, not used at runtime)
148- Tensor quantized_linear (
149- KernelRuntimeContext& context,
150- const Tensor& input,
151- const Scalar& input_zero_point,
152- const Scalar& input_multiplier,
153- const Scalar& input_shift,
154- const Tensor& weights,
155- const Tensor& weight_zero_point,
156- const Tensor& weight_multiplier,
157- const Tensor& weight_shift,
158- const torch::executor::optional<Tensor>& bias,
159- const Tensor& bias_multiplier,
160- const Tensor& bias_shift,
161- const Tensor& scratch_buffer,
162- const Scalar& output_zero_point,
163- const Scalar& in_features,
164- const Scalar& out_features) {
165- ET_LOG (Info, " quantized_linear: called" );
166- assert (false );
167- return const_cast <Tensor&>(input);
106+ return out;
168107}
169108
170109} // namespace native
0 commit comments