1 file changed: +19 -2 lines changed
Columns: original file line number | diff line number | diff line change
@@ -888,8 +888,25 @@ void _dequant_and_store(
888888 for (int m = 0 ; m < M; ++m) {
889889 float a_scale = *(scale_a + m * ldsa);
890890 int32_t a_zp = *(zp_a + m * ldsa);
891- #pragma omp simd
892- for (int n = 0 ; n < N; ++n) {
891+ __m512 va_scale = _mm512_set1_ps (a_scale);
892+ __m512i va_zp = _mm512_set1_epi32 (a_zp);
893+ int n = 0 ;
894+ for (; n < N; n += 16 ) {
895+ __m512i va = _mm512_loadu_si512 (input + m * ld + n);
896+ __m512i vb_comp = _mm512_loadu_si512 (comp_b + n);
897+ __m512i vc = _mm512_sub_epi32 (va, _mm512_mullo_epi32 (vb_comp, va_zp));
898+ __m512 vc_f = _mm512_cvtepi32_ps (vc);
899+ __m512 vc_f_mul = _mm512_mul_ps (vc_f, va_scale);
900+ __m512 vb_s = _mm512_loadu_ps (scale_b + n);
901+ vc_f_mul = _mm512_mul_ps (vc_f_mul, vb_s);
902+ if constexpr (accum) {
903+ __m512 vo = _mm512_loadu_ps (output + m * ld + n);
904+ _mm512_storeu_ps (output + m * ld + n, _mm512_add_ps (vo, vc_f_mul));
905+ } else {
906+ _mm512_storeu_ps (output + m * ld + n, vc_f_mul);
907+ }
908+ }
909+ for (; n < N; ++n) {
893910 float dq_val =
894911 (float )(input[m * ld + n] - a_zp * comp_b[n]) * a_scale * scale_b[n];
895912 if constexpr (accum) {
You can’t perform that action at this time.
0 commit comments