Skip to content

Commit f169ca4

Browse files
authored
[Executorch] parallelize op_choose_qparams
Differential Revision: D84962234 Pull Request resolved: #15607
1 parent b370f31 commit f169ca4

File tree

3 files changed

+133
-11
lines changed

3 files changed

+133
-11
lines changed

kernels/quantized/cpu/op_choose_qparams.cpp

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <executorch/kernels/portable/cpu/vec_ops.h>
1010
#include <executorch/runtime/kernel/kernel_includes.h>
11+
#include <executorch/runtime/kernel/thread_parallel_interface.h>
1112
#include <algorithm>
1213
#include <cinttypes>
1314
#include <cmath>
@@ -202,17 +203,42 @@ void choose_qparams_per_token(
202203
num_tokens *= input.size(i);
203204
}
204205
auto token_dim_size = input.size(input.dim() - 1);
205-
for (auto i = 0; i < num_tokens; i++) {
206-
// vec_minf uses std::min_element. Check if it actually
207-
// gets vectorized.
208-
float min = torch::executor::vec_minf(x_fp32, token_dim_size);
209-
float max = torch::executor::vec_maxf(x_fp32, token_dim_size);
210-
double scale;
211-
int32_t zero_point;
212-
calculate_scale_and_zero_point(min, max, qmin, qmax, scale, zero_point);
213-
scale_out.mutable_data_ptr<double>()[i] = scale;
214-
zero_point_out.mutable_data_ptr<int64_t>()[i] = zero_point;
215-
x_fp32 += token_dim_size;
206+
207+
const int64_t total_elements = num_tokens * token_dim_size;
208+
constexpr int64_t MIN_ELEMENTS_FOR_PARALLEL = 512;
209+
const bool use_parallel = total_elements >= MIN_ELEMENTS_FOR_PARALLEL;
210+
211+
if (use_parallel) {
212+
auto* scale_data = scale_out.mutable_data_ptr<double>();
213+
auto* zero_point_data = zero_point_out.mutable_data_ptr<int64_t>();
214+
215+
::executorch::extension::parallel_for(
216+
0, num_tokens, 1, [&](const int64_t begin, const int64_t end) {
217+
for (int64_t i = begin; i < end; i++) {
218+
const float* token_data = x_fp32 + i * token_dim_size;
219+
float min = torch::executor::vec_minf(token_data, token_dim_size);
220+
float max = torch::executor::vec_maxf(token_data, token_dim_size);
221+
double scale;
222+
int32_t zero_point;
223+
calculate_scale_and_zero_point(
224+
min, max, qmin, qmax, scale, zero_point);
225+
scale_data[i] = scale;
226+
zero_point_data[i] = zero_point;
227+
}
228+
});
229+
} else {
230+
for (auto i = 0; i < num_tokens; i++) {
231+
// vec_minf uses std::min_element. Check if it actually
232+
// gets vectorized.
233+
float min = torch::executor::vec_minf(x_fp32, token_dim_size);
234+
float max = torch::executor::vec_maxf(x_fp32, token_dim_size);
235+
double scale;
236+
int32_t zero_point;
237+
calculate_scale_and_zero_point(min, max, qmin, qmax, scale, zero_point);
238+
scale_out.mutable_data_ptr<double>()[i] = scale;
239+
zero_point_out.mutable_data_ptr<int64_t>()[i] = zero_point;
240+
x_fp32 += token_dim_size;
241+
}
216242
}
217243
}
218244
} // namespace

kernels/quantized/cpu/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ _QUANT_OPS = (
99
name = "op_choose_qparams",
1010
deps = [
1111
"//executorch/kernels/portable/cpu:vec_ops",
12+
"//executorch/extension/threadpool:threadpool",
1213
],
1314
),
1415
op_target(

kernels/quantized/test/op_choose_qparams_test.cpp

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <executorch/test/utils/DeathTest.h>
1616

1717
#include <gtest/gtest.h>
18+
#include <cmath>
1819
#include <limits>
1920

2021
using namespace ::testing;
@@ -163,3 +164,97 @@ TEST(OpChooseQparamsPerTokenAsymmetricTensorOutTest, DynamicShapeFloat) {
163164
EXPECT_TENSOR_CLOSE_WITH_TOL(scale_out, new_expected_scale, 1e-4, 1e-4);
164165
EXPECT_TENSOR_EQ(zero_point_out, new_expected_zero_point);
165166
}
167+
168+
// Runs choose_qparams_per_token_asymmetric_out on an input of 8 tokens x 128
// elements (1024 total), which is above the kernel's 512-element
// MIN_ELEMENTS_FOR_PARALLEL threshold, and checks every token's scale and
// zero_point against a reference computed inline below.
TEST(
    OpChooseQparamsPerTokenAsymmetricTensorOutTest,
    LargeInputParallelization) {
  et_pal_init();
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Double> tf_double;
  TensorFactory<ScalarType::Long> tf_long;

  // 8 tokens x 128 elements per token = 1024 total elements, which exceeds
  // the MIN_ELEMENTS_FOR_PARALLEL threshold of 512.
  const int num_tokens = 8;
  const int token_size = 128;
  std::vector<float> input_data(num_tokens * token_size);

  // Each token gets a distinct, known [min, max] range so that a result
  // written to the wrong output slot is detected.
  std::vector<float> expected_min(num_tokens);
  std::vector<float> expected_max(num_tokens);

  for (int i = 0; i < num_tokens; i++) {
    const float token_min = -1.0f * (i + 1);
    const float token_max = 2.0f * (i + 1);
    expected_min[i] = token_min;
    expected_max[i] = token_max;

    for (int j = 0; j < token_size; j++) {
      // Linearly interpolate between min and max; j == 0 yields token_min
      // and j == token_size - 1 yields token_max.
      const float t = j / static_cast<float>(token_size - 1);
      input_data[i * token_size + j] = token_min + t * (token_max - token_min);
    }
  }

  Tensor input = tf_float.make({num_tokens, token_size}, input_data);
  Tensor scale_out = tf_double.zeros({num_tokens, 1});
  Tensor zero_point_out = tf_long.zeros({num_tokens, 1});

  choose_qparams_per_token_asymmetric_out(
      input, ScalarType::Float, scale_out, zero_point_out);

  // Reference computation of scale and zero_point for an int8 range.
  // NOTE(review): intended to mirror calculate_scale_and_zero_point in the
  // kernel; keep the two in sync.
  const int32_t qmin = -128;
  const int32_t qmax = 127;
  const float SMALL_SCALE_THRESHOLD = 6.1e-5f;

  for (int i = 0; i < num_tokens; i++) {
    // The representable range must always include zero.
    float min = std::min(expected_min[i], 0.0f);
    float max = std::max(expected_max[i], 0.0f);

    // Calculate scale; fall back to 0.1 for a degenerate (zero or
    // non-invertible) scale.
    double scale = (static_cast<double>(max) - min) / (qmax - qmin);
    if (float(scale) == 0.0f || std::isinf(1.0f / float(scale))) {
      scale = 0.1;
    }

    // Cut off small scale.
    if (scale < SMALL_SCALE_THRESHOLD) {
      // Capture the pre-clamp scale: dividing by the already-clamped value
      // would always produce an amplifier of exactly 1.
      const float original_scale = static_cast<float>(scale);
      scale = SMALL_SCALE_THRESHOLD;
      if (min == 0.0f) {
        max = SMALL_SCALE_THRESHOLD * (qmax - qmin);
      } else if (max == 0.0f) {
        min = -SMALL_SCALE_THRESHOLD * (qmax - qmin);
      } else {
        const float amplifier = SMALL_SCALE_THRESHOLD / original_scale;
        min *= amplifier;
        max *= amplifier;
      }
    }

    // Calculate zero_point: derive a candidate from each end of the range
    // and keep the one with the smaller error term.
    const double zero_point_from_min = qmin - min / scale;
    const double zero_point_from_max = qmax - max / scale;
    const double zero_point_from_min_error =
        std::abs(qmin) - std::abs(min / scale);
    const double zero_point_from_max_error =
        std::abs(qmax) - std::abs(max / scale);
    const double initial_zero_point =
        zero_point_from_min_error < zero_point_from_max_error
        ? zero_point_from_min
        : zero_point_from_max;

    // Clamp the zero point into [qmin, qmax], rounding when it is in range.
    int32_t nudged_zero_point = 0;
    if (initial_zero_point < qmin) {
      nudged_zero_point = qmin;
    } else if (initial_zero_point > qmax) {
      nudged_zero_point = qmax;
    } else {
      nudged_zero_point =
          std::nearbyint(static_cast<float>(initial_zero_point));
    }

    // Verify computed values match expected.
    EXPECT_NEAR(scale_out.const_data_ptr<double>()[i], scale, 1e-6);
    EXPECT_EQ(zero_point_out.const_data_ptr<int64_t>()[i], nudged_zero_point);
  }
}

0 commit comments

Comments
 (0)