Skip to content

Commit 21c1cb8

Browse files
hexagon: introduce fastdiv, fix test-backend-ops for ADD/SUB/MUL
Co-authored-by: chraac <chraac@gmail.com>
1 parent 40d7e7d commit 21c1cb8

File tree

4 files changed

+89
-29
lines changed

4 files changed

+89
-29
lines changed

ggml/src/ggml-hexagon/htp/binary-ops.c

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ static hvx_elemwise_f32_func func_table_HVX[] = { hvx_mul_f32, hvx_add_f32,
3434
static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_opt, hvx_add_f32_opt, hvx_sub_f32_opt };
3535

3636
#define htp_binary_preamble \
37+
const struct htp_tensor * src0 = &octx->src0; \
38+
const struct htp_tensor * src1 = &octx->src1; \
39+
const struct htp_tensor * src2 = &octx->src2; \
40+
struct htp_tensor * dst = &octx->dst; \
41+
\
3742
const uint32_t ne00 = src0->ne[0]; \
3843
const uint32_t ne01 = src0->ne[1]; \
3944
const uint32_t ne02 = src0->ne[2]; \
@@ -62,16 +67,15 @@ static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_opt, hvx_add_f
6267
const uint32_t nb0 = dst->nb[0]; \
6368
const uint32_t nb1 = dst->nb[1]; \
6469
const uint32_t nb2 = dst->nb[2]; \
65-
const uint32_t nb3 = dst->nb[3];
66-
67-
static void binary_job_f32_per_thread(const struct htp_tensor * src0,
68-
const struct htp_tensor * src1,
69-
struct htp_tensor * dst,
70-
uint8_t * spad_data,
71-
uint32_t nth,
72-
uint32_t ith,
73-
uint32_t src0_nrows_per_thread,
74-
enum htp_op op) {
70+
const uint32_t nb3 = dst->nb[3]; \
71+
\
72+
const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
73+
74+
static void binary_job_f32_per_thread(struct htp_ops_context * octx,
75+
uint8_t * spad_data,
76+
uint32_t nth,
77+
uint32_t ith,
78+
enum htp_op op) {
7579
htp_binary_preamble;
7680

7781
const size_t src0_row_size = nb01;
@@ -113,10 +117,18 @@ static void binary_job_f32_per_thread(const struct htp_tensor * src0,
113117
uint8_t * restrict dst_ptr = (uint8_t *) dst->data + (src0_start_row * dst_row_size);
114118

115119
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
116-
const uint8_t * restrict src1_ptr = NULL;
117120

121+
const uint32_t ne0201 = ne02 * ne01;
118122
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
119-
src1_ptr = data_src1 + (ir % src1_nrows) * src1_row_size;
123+
const uint32_t i03 = fastdiv(ir, &octx->src0_div21);
124+
const uint32_t i02 = fastdiv(ir - i03 * ne0201, &octx->src0_div1);
125+
const uint32_t i01 = (ir - i03 * ne0201 - i02 * ne01);
126+
127+
const uint32_t i13 = fastmodulo(i03, ne13, &octx->src1_div3);
128+
const uint32_t i12 = fastmodulo(i02, ne12, &octx->src1_div2);
129+
const uint32_t i11 = fastmodulo(i01, ne11, &octx->src1_div1);
130+
131+
const uint8_t * restrict src1_ptr = data_src1 + i13 * nb13 + i12 * nb12 + i11 * src1_row_size;
120132

121133
if (ir + 1 < src0_end_row) {
122134
htp_l2fetch(src0_ptr + ne00, 1, src0_row_size, src0_row_size);
@@ -149,15 +161,11 @@ static void binary_job_f32_per_thread(const struct htp_tensor * src0,
149161
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
150162
}
151163

152-
static void binary_add_id_job_f32_per_thread(const struct htp_tensor * src0,
153-
const struct htp_tensor * src1,
154-
const struct htp_tensor * src2,
155-
struct htp_tensor * dst,
156-
uint8_t * spad_data,
157-
uint32_t nth,
158-
uint32_t ith,
159-
uint32_t src0_nrows_per_thread,
160-
hvx_elemwise_f32_func func_HVX) {
164+
static void binary_add_id_job_f32_per_thread(struct htp_ops_context * octx,
165+
uint8_t * spad_data,
166+
uint32_t nth,
167+
uint32_t ith,
168+
hvx_elemwise_f32_func func_HVX) {
161169
htp_binary_preamble;
162170

163171
const size_t src0_row_size = nb01;
@@ -234,13 +242,11 @@ static void binary_job_dispatcher_f32(unsigned int n, unsigned int i, void * dat
234242
case HTP_OP_MUL:
235243
case HTP_OP_ADD:
236244
case HTP_OP_SUB:
237-
binary_job_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->src1_spad.data, n, i,
238-
octx->src0_nrows_per_thread, octx->op);
245+
binary_job_f32_per_thread(octx, octx->src1_spad.data, n, i, octx->op);
239246
break;
240247

241248
case HTP_OP_ADD_ID:
242-
binary_add_id_job_f32_per_thread(&octx->src0, &octx->src1, &octx->src2, &octx->dst, octx->src0_spad.data, n,
243-
i, octx->src0_nrows_per_thread, hvx_add_f32);
249+
binary_add_id_job_f32_per_thread(octx, octx->src0_spad.data, n, i, hvx_add_f32);
244250
break;
245251

246252
default:
@@ -321,6 +327,16 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) {
321327

322328
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
323329

330+
octx->src0_div21 = init_fastdiv_values(src0->ne[2] * src0->ne[1]);
331+
octx->src0_div3 = init_fastdiv_values(src0->ne[3]);
332+
octx->src0_div2 = init_fastdiv_values(src0->ne[2]);
333+
octx->src0_div1 = init_fastdiv_values(src0->ne[1]);
334+
335+
octx->src1_div21 = init_fastdiv_values(src1->ne[2] * src1->ne[1]);
336+
octx->src1_div3 = init_fastdiv_values(src1->ne[3]);
337+
octx->src1_div2 = init_fastdiv_values(src1->ne[2]);
338+
octx->src1_div1 = init_fastdiv_values(src1->ne[1]);
339+
324340
worker_pool_run_func(octx->ctx->worker_pool, binary_op_func, octx, n_jobs);
325341
}
326342

ggml/src/ggml-hexagon/htp/htp-msg.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,10 @@ static const char * htp_type_name(uint32_t t) {
119119
#define HTP_MAX_DIMS 4
120120

121121
struct htp_tensor {
122-
uint32_t data; // Buffer offset in the messages, and data pointer on the NSP
123-
uint32_t type; // Data type
124-
uint32_t ne[HTP_MAX_DIMS]; // Number of elements
125-
uint32_t nb[HTP_MAX_DIMS]; // Stride in bytes (see ggml.h ggml_tensor)
122+
uint32_t data; // Buffer offset in the messages, and data pointer on the NSP
123+
uint32_t type; // Data type
124+
uint32_t ne[HTP_MAX_DIMS]; // Number of elements
125+
uint32_t nb[HTP_MAX_DIMS]; // Stride in bytes (see ggml.h ggml_tensor)
126126
};
127127

128128
#define HTP_MAX_OP_PARAMS 64

ggml/src/ggml-hexagon/htp/htp-ops.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "htp-ctx.h"
55
#include "htp-msg.h"
66
#include "worker-pool.h"
7+
#include "ops-utils.h"
78

89
#include <assert.h>
910
#include <stdint.h>
@@ -38,6 +39,16 @@ struct htp_ops_context {
3839
uint32_t src0_nrows_per_thread;
3940
uint32_t src1_nrows_per_thread;
4041

42+
struct fastdiv_values src0_div1; // fastdiv values for ne1
43+
struct fastdiv_values src0_div2; // fastdiv values for ne2
44+
struct fastdiv_values src0_div3; // fastdiv values for ne3
45+
struct fastdiv_values src0_div21; // fastdiv values for ne2 * ne1
46+
47+
struct fastdiv_values src1_div1; // fastdiv values for ne1
48+
struct fastdiv_values src1_div2; // fastdiv values for ne2
49+
struct fastdiv_values src1_div3; // fastdiv values for ne3
50+
struct fastdiv_values src1_div21; // fastdiv values for ne2 * ne1
51+
4152
uint32_t flags;
4253
};
4354

ggml/src/ggml-hexagon/htp/ops-utils.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,39 @@ static inline uint32_t htp_round_up(uint32_t n, uint32_t m) {
3131
return m * ((n + m - 1) / m);
3232
}
3333

34+
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
35+
// Precompute mp (m' in the paper) and L such that division
36+
// can be computed using a multiply (high 32b of 64b result)
37+
// and a shift:
38+
//
39+
// n/d = (mulhi(n, mp) + n) >> L;
40+
struct fastdiv_values {
41+
uint32_t mp;
42+
uint32_t l;
43+
};
44+
45+
static inline struct fastdiv_values init_fastdiv_values(uint32_t d) {
46+
struct fastdiv_values result = { 0, 0 };
47+
// compute L = ceil(log2(d));
48+
while (result.l < 32 && ((uint32_t) 1 << result.l) < d) {
49+
++(result.l);
50+
}
51+
52+
result.mp = (uint32_t) (((uint64_t) 1 << 32) * (((uint64_t) 1 << result.l) - d) / d + 1);
53+
return result;
54+
}
55+
56+
static inline uint32_t fastdiv(uint32_t n, const struct fastdiv_values * vals) {
57+
// Compute high 32 bits of n * mp
58+
const uint32_t hi = (uint32_t) (((uint64_t) n * vals->mp) >> 32); // mulhi(n, mp)
59+
// add n, apply bit shift
60+
return (hi + n) >> vals->l;
61+
}
62+
63+
static inline uint32_t fastmodulo(uint32_t n, uint32_t d, const struct fastdiv_values * vals) {
64+
return n - fastdiv(n, vals) * d;
65+
}
66+
3467
static inline void htp_l2fetch(const void * p, uint32_t height, uint32_t width, uint32_t stride) {
3568
const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height));
3669
asm volatile(" l2fetch(%0,%1) " : : "r"(p), "r"(control));

0 commit comments

Comments
 (0)