|
15 | 15 | #include <executorch/test/utils/DeathTest.h> |
16 | 16 |
|
17 | 17 | #include <gtest/gtest.h> |
| 18 | +#include <cmath> |
18 | 19 | #include <limits> |
19 | 20 |
|
20 | 21 | using namespace ::testing; |
@@ -163,3 +164,97 @@ TEST(OpChooseQparamsPerTokenAsymmetricTensorOutTest, DynamicShapeFloat) { |
163 | 164 | EXPECT_TENSOR_CLOSE_WITH_TOL(scale_out, new_expected_scale, 1e-4, 1e-4); |
164 | 165 | EXPECT_TENSOR_EQ(zero_point_out, new_expected_zero_point); |
165 | 166 | } |
| 167 | + |
| 168 | +TEST( |
| 169 | + OpChooseQparamsPerTokenAsymmetricTensorOutTest, |
| 170 | + LargeInputParallelization) { |
| 171 | + et_pal_init(); |
| 172 | + TensorFactory<ScalarType::Float> tf_float; |
| 173 | + TensorFactory<ScalarType::Double> tf_double; |
| 174 | + TensorFactory<ScalarType::Long> tf_long; |
| 175 | + |
| 176 | + // Create input with 8 tokens x 128 elements per token = 1024 total elements |
| 177 | + // This exceeds the MIN_ELEMENTS_FOR_PARALLEL threshold of 512 |
| 178 | + const int num_tokens = 8; |
| 179 | + const int token_size = 128; |
| 180 | + std::vector<float> input_data(num_tokens * token_size); |
| 181 | + |
| 182 | + // Generate test data with known min/max per token for easier verification |
| 183 | + std::vector<float> expected_min(num_tokens); |
| 184 | + std::vector<float> expected_max(num_tokens); |
| 185 | + |
| 186 | + for (int i = 0; i < num_tokens; i++) { |
| 187 | + float token_min = -1.0f * (i + 1); |
| 188 | + float token_max = 2.0f * (i + 1); |
| 189 | + expected_min[i] = token_min; |
| 190 | + expected_max[i] = token_max; |
| 191 | + |
| 192 | + for (int j = 0; j < token_size; j++) { |
| 193 | + // Linearly interpolate between min and max |
| 194 | + float t = j / static_cast<float>(token_size - 1); |
| 195 | + input_data[i * token_size + j] = token_min + t * (token_max - token_min); |
| 196 | + } |
| 197 | + } |
| 198 | + |
| 199 | + Tensor input = tf_float.make({num_tokens, token_size}, input_data); |
| 200 | + Tensor scale_out = tf_double.zeros({num_tokens, 1}); |
| 201 | + Tensor zero_point_out = tf_long.zeros({num_tokens, 1}); |
| 202 | + |
| 203 | + choose_qparams_per_token_asymmetric_out( |
| 204 | + input, ScalarType::Float, scale_out, zero_point_out); |
| 205 | + |
| 206 | + // Manually calculate expected scale and zero_point using the same algorithm |
| 207 | + // as calculate_scale_and_zero_point function |
| 208 | + const int32_t qmin = -128; |
| 209 | + const int32_t qmax = 127; |
| 210 | + const float SMALL_SCALE_THRESHOLD = 6.1e-5f; |
| 211 | + |
| 212 | + for (int i = 0; i < num_tokens; i++) { |
| 213 | + float min = std::min(expected_min[i], 0.0f); |
| 214 | + float max = std::max(expected_max[i], 0.0f); |
| 215 | + |
| 216 | + // Calculate scale |
| 217 | + double scale = (static_cast<double>(max) - min) / (qmax - qmin); |
| 218 | + if (float(scale) == 0.0f || std::isinf(1.0f / float(scale))) { |
| 219 | + scale = 0.1; |
| 220 | + } |
| 221 | + |
| 222 | + // Cut off small scale |
| 223 | + if (scale < SMALL_SCALE_THRESHOLD) { |
| 224 | + scale = SMALL_SCALE_THRESHOLD; |
| 225 | + if (min == 0.0f) { |
| 226 | + max = SMALL_SCALE_THRESHOLD * (qmax - qmin); |
| 227 | + } else if (max == 0.0f) { |
| 228 | + min = -SMALL_SCALE_THRESHOLD * (qmax - qmin); |
| 229 | + } else { |
| 230 | + float amplifier = SMALL_SCALE_THRESHOLD / scale; |
| 231 | + min *= amplifier; |
| 232 | + max *= amplifier; |
| 233 | + } |
| 234 | + } |
| 235 | + |
| 236 | + // Calculate zero_point |
| 237 | + double zero_point_from_min = qmin - min / scale; |
| 238 | + double zero_point_from_max = qmax - max / scale; |
| 239 | + double zero_point_from_min_error = std::abs(qmin) - std::abs(min / scale); |
| 240 | + double zero_point_from_max_error = std::abs(qmax) - std::abs(max / scale); |
| 241 | + double initial_zero_point = |
| 242 | + zero_point_from_min_error < zero_point_from_max_error |
| 243 | + ? zero_point_from_min |
| 244 | + : zero_point_from_max; |
| 245 | + |
| 246 | + int32_t nudged_zero_point = 0; |
| 247 | + if (initial_zero_point < qmin) { |
| 248 | + nudged_zero_point = qmin; |
| 249 | + } else if (initial_zero_point > qmax) { |
| 250 | + nudged_zero_point = qmax; |
| 251 | + } else { |
| 252 | + nudged_zero_point = |
| 253 | + std::nearbyint(static_cast<float>(initial_zero_point)); |
| 254 | + } |
| 255 | + |
| 256 | + // Verify computed values match expected |
| 257 | + EXPECT_NEAR(scale_out.const_data_ptr<double>()[i], scale, 1e-6); |
| 258 | + EXPECT_EQ(zero_point_out.const_data_ptr<int64_t>()[i], nudged_zero_point); |
| 259 | + } |
| 260 | +} |
0 commit comments