Skip to content

Commit c273d75

Browse files
hexagon: various Op fixes (#17135)
* hexagon: explicitly check for ops with zero nrows llm_graph_context::build_inp_out_ids() can generate tensors with zero nrows. Somehow other backends seems to handle this without obvious explicit checks. In the hexagon case we need to check explicitly and skip them. * hexagon: introduce fastdiv, fix test-backend-ops for ADD/SUB/MUL Co-authored-by: chraac <chraac@gmail.com> * hexagon: use fastdiv in ADD_ID * hexagon: use ggml_op_is_empty and ggml_is_empty to check for NOPs --------- Co-authored-by: chraac <chraac@gmail.com>
1 parent 7d019cf commit c273d75

File tree

5 files changed

+106
-59
lines changed

5 files changed

+106
-59
lines changed

ggml/src/ggml-hexagon/ggml-hexagon.cpp

Lines changed: 12 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3156,26 +3156,17 @@ static inline bool op_reuse_src1(const ggml_tensor * op1, const ggml_tensor * op
31563156
return (op0 && op0->src[1] == op1->src[1]);
31573157
}
31583158

3159+
// Returns true when the node does real work on the device: filters out ops
// that are NOPs by definition (reshape/view/permute/...) via ggml_op_is_empty,
// and tensors with no elements via ggml_is_empty (e.g. zero-nrows tensors
// produced by llm_graph_context::build_inp_out_ids()).
static inline bool is_compute_op(ggml_tensor *node)
{
    const bool is_nop = ggml_op_is_empty(node->op) || ggml_is_empty(node);
    return !is_nop;
}
3163+
31593164
// scan the graph and figure out last compute op index
31603165
static inline int last_compute_op(ggml_cgraph * graph) {
3161-
int last;
3166+
int last = 0;
31623167
for (int i = 0; i < graph->n_nodes; ++i) {
3163-
ggml_tensor * node = graph->nodes[i];
3164-
3165-
switch (node->op) {
3166-
case GGML_OP_MUL_MAT:
3167-
case GGML_OP_MUL_MAT_ID:
3168-
case GGML_OP_MUL:
3169-
case GGML_OP_ADD:
3170-
case GGML_OP_SUB:
3171-
case GGML_OP_RMS_NORM:
3172-
case GGML_OP_GLU:
3173-
case GGML_OP_ADD_ID:
3174-
last = i;
3175-
break;
3176-
3177-
default:
3178-
break;
3168+
if (is_compute_op(graph->nodes[i])) {
3169+
last = i;
31793170
}
31803171
}
31813172

@@ -3194,6 +3185,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
31943185
for (int i = 0; i < graph->n_nodes; ++i) {
31953186
ggml_tensor * node = graph->nodes[i];
31963187

3188+
if (!is_compute_op(node)) {
3189+
continue;
3190+
}
3191+
31973192
uint32_t flags = 0;
31983193

31993194
// skip quantizer if src1 is reused
@@ -3245,14 +3240,6 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
32453240
ggml_hexagon_rope(node, flags);
32463241
break;
32473242

3248-
// non-compute ops
3249-
case GGML_OP_NONE:
3250-
case GGML_OP_RESHAPE:
3251-
case GGML_OP_VIEW:
3252-
case GGML_OP_PERMUTE:
3253-
case GGML_OP_TRANSPOSE:
3254-
break;
3255-
32563243
default:
32573244
GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
32583245
}

ggml/src/ggml-hexagon/htp/binary-ops.c

Lines changed: 46 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ static hvx_elemwise_f32_func func_table_HVX[] = { hvx_mul_f32, hvx_add_f32,
3434
static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_opt, hvx_add_f32_opt, hvx_sub_f32_opt };
3535

3636
#define htp_binary_preamble \
37+
const struct htp_tensor * src0 = &octx->src0; \
38+
const struct htp_tensor * src1 = &octx->src1; \
39+
const struct htp_tensor * src2 = &octx->src2; \
40+
struct htp_tensor * dst = &octx->dst; \
41+
\
3742
const uint32_t ne00 = src0->ne[0]; \
3843
const uint32_t ne01 = src0->ne[1]; \
3944
const uint32_t ne02 = src0->ne[2]; \
@@ -62,16 +67,15 @@ static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_opt, hvx_add_f
6267
const uint32_t nb0 = dst->nb[0]; \
6368
const uint32_t nb1 = dst->nb[1]; \
6469
const uint32_t nb2 = dst->nb[2]; \
65-
const uint32_t nb3 = dst->nb[3];
66-
67-
static void binary_job_f32_per_thread(const struct htp_tensor * src0,
68-
const struct htp_tensor * src1,
69-
struct htp_tensor * dst,
70-
uint8_t * spad_data,
71-
uint32_t nth,
72-
uint32_t ith,
73-
uint32_t src0_nrows_per_thread,
74-
enum htp_op op) {
70+
const uint32_t nb3 = dst->nb[3]; \
71+
\
72+
const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
73+
74+
static void binary_job_f32_per_thread(struct htp_ops_context * octx,
75+
uint8_t * spad_data,
76+
uint32_t nth,
77+
uint32_t ith,
78+
enum htp_op op) {
7579
htp_binary_preamble;
7680

7781
const size_t src0_row_size = nb01;
@@ -107,16 +111,23 @@ static void binary_job_f32_per_thread(const struct htp_tensor * src0,
107111

108112
uint8_t * restrict spad_data_th = spad_data + (ith * src0_row_size);
109113

110-
const uint32_t nr0 = ne00 / ne10;
111-
112114
const uint8_t * restrict src0_ptr = (const uint8_t *) src0->data + (src0_start_row * src0_row_size);
113115
uint8_t * restrict dst_ptr = (uint8_t *) dst->data + (src0_start_row * dst_row_size);
114116

115117
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
116-
const uint8_t * restrict src1_ptr = NULL;
118+
119+
const uint32_t ne02_ne01 = ne02 * ne01;
117120

118121
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
119-
src1_ptr = data_src1 + (ir % src1_nrows) * src1_row_size;
122+
const uint32_t i03 = fastdiv(ir, &octx->src0_div21);
123+
const uint32_t i02 = fastdiv(ir - i03 * ne02_ne01, &octx->src0_div1);
124+
const uint32_t i01 = (ir - i03 * ne02_ne01 - i02 * ne01);
125+
126+
const uint32_t i13 = fastmodulo(i03, ne13, &octx->src1_div3);
127+
const uint32_t i12 = fastmodulo(i02, ne12, &octx->src1_div2);
128+
const uint32_t i11 = fastmodulo(i01, ne11, &octx->src1_div1);
129+
130+
const uint8_t * restrict src1_ptr = data_src1 + i13 * nb13 + i12 * nb12 + i11 * src1_row_size;
120131

121132
if (ir + 1 < src0_end_row) {
122133
htp_l2fetch(src0_ptr + ne00, 1, src0_row_size, src0_row_size);
@@ -125,6 +136,7 @@ static void binary_job_f32_per_thread(const struct htp_tensor * src0,
125136
}
126137
}
127138

139+
const uint32_t nr0 = ne00 / ne10;
128140
if (nr0 > 1) {
129141
if ((1 == is_aligned) && (nr0 == ne00)) {
130142
hvx_bcast_fp32_a(spad_data_th, *(float *) src1_ptr, nr0);
@@ -149,22 +161,17 @@ static void binary_job_f32_per_thread(const struct htp_tensor * src0,
149161
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
150162
}
151163

152-
static void binary_add_id_job_f32_per_thread(const struct htp_tensor * src0,
153-
const struct htp_tensor * src1,
154-
const struct htp_tensor * src2,
155-
struct htp_tensor * dst,
156-
uint8_t * spad_data,
157-
uint32_t nth,
158-
uint32_t ith,
159-
uint32_t src0_nrows_per_thread,
160-
hvx_elemwise_f32_func func_HVX) {
164+
static void binary_add_id_job_f32_per_thread(struct htp_ops_context * octx,
165+
uint8_t * spad_data,
166+
uint32_t nth,
167+
uint32_t ith,
168+
hvx_elemwise_f32_func func_HVX) {
161169
htp_binary_preamble;
162170

163171
const size_t src0_row_size = nb01;
164172
const size_t src1_row_size = nb11;
165173
const size_t dst_row_size = nb1;
166174

167-
const uint32_t ne02_ne01 = ne02 * ne01;
168175
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
169176

170177
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
@@ -187,10 +194,11 @@ static void binary_add_id_job_f32_per_thread(const struct htp_tensor * src0,
187194
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
188195
uint8_t * restrict data_dst = (uint8_t *) dst->data;
189196

197+
const uint32_t ne02_ne01 = ne02 * ne01;
190198
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
191199
// src0 indices
192-
const uint32_t i03 = ir / ne02_ne01;
193-
const uint32_t i02 = (ir - i03 * ne02_ne01) / ne01;
200+
const uint32_t i03 = fastdiv(ir, &octx->src0_div21);
201+
const uint32_t i02 = fastdiv(ir - i03 * ne02_ne01, &octx->src0_div1);
194202
const uint32_t i01 = (ir - i03 * ne02_ne01 - i02 * ne01);
195203

196204
// src1 indices
@@ -234,13 +242,11 @@ static void binary_job_dispatcher_f32(unsigned int n, unsigned int i, void * dat
234242
case HTP_OP_MUL:
235243
case HTP_OP_ADD:
236244
case HTP_OP_SUB:
237-
binary_job_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->src1_spad.data, n, i,
238-
octx->src0_nrows_per_thread, octx->op);
245+
binary_job_f32_per_thread(octx, octx->src1_spad.data, n, i, octx->op);
239246
break;
240247

241248
case HTP_OP_ADD_ID:
242-
binary_add_id_job_f32_per_thread(&octx->src0, &octx->src1, &octx->src2, &octx->dst, octx->src0_spad.data, n,
243-
i, octx->src0_nrows_per_thread, hvx_add_f32);
249+
binary_add_id_job_f32_per_thread(octx, octx->src0_spad.data, n, i, hvx_add_f32);
244250
break;
245251

246252
default:
@@ -321,6 +327,16 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) {
321327

322328
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
323329

330+
octx->src0_div21 = init_fastdiv_values(src0->ne[2] * src0->ne[1]);
331+
octx->src0_div3 = init_fastdiv_values(src0->ne[3]);
332+
octx->src0_div2 = init_fastdiv_values(src0->ne[2]);
333+
octx->src0_div1 = init_fastdiv_values(src0->ne[1]);
334+
335+
octx->src1_div21 = init_fastdiv_values(src1->ne[2] * src1->ne[1]);
336+
octx->src1_div3 = init_fastdiv_values(src1->ne[3]);
337+
octx->src1_div2 = init_fastdiv_values(src1->ne[2]);
338+
octx->src1_div1 = init_fastdiv_values(src1->ne[1]);
339+
324340
worker_pool_run_func(octx->ctx->worker_pool, binary_op_func, octx, n_jobs);
325341
}
326342

ggml/src/ggml-hexagon/htp/htp-msg.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,10 @@ static const char * htp_type_name(uint32_t t) {
119119
#define HTP_MAX_DIMS 4
120120

121121
struct htp_tensor {
122-
uint32_t data; // Buffer offset in the messages, and data pointer on the NSP
123-
uint32_t type; // Data type
124-
uint32_t ne[HTP_MAX_DIMS]; // Number of elements
125-
uint32_t nb[HTP_MAX_DIMS]; // Stride in bytes (see ggml.h ggml_tensor)
122+
uint32_t data; // Buffer offset in the messages, and data pointer on the NSP
123+
uint32_t type; // Data type
124+
uint32_t ne[HTP_MAX_DIMS]; // Number of elements
125+
uint32_t nb[HTP_MAX_DIMS]; // Stride in bytes (see ggml.h ggml_tensor)
126126
};
127127

128128
#define HTP_MAX_OP_PARAMS 64

ggml/src/ggml-hexagon/htp/htp-ops.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "htp-ctx.h"
55
#include "htp-msg.h"
66
#include "worker-pool.h"
7+
#include "ops-utils.h"
78

89
#include <assert.h>
910
#include <stdint.h>
@@ -38,6 +39,16 @@ struct htp_ops_context {
3839
uint32_t src0_nrows_per_thread;
3940
uint32_t src1_nrows_per_thread;
4041

42+
struct fastdiv_values src0_div1; // fastdiv values for ne1
43+
struct fastdiv_values src0_div2; // fastdiv values for ne2
44+
struct fastdiv_values src0_div3; // fastdiv values for ne3
45+
struct fastdiv_values src0_div21; // fastdiv values for ne2 * ne1
46+
47+
struct fastdiv_values src1_div1; // fastdiv values for ne1
48+
struct fastdiv_values src1_div2; // fastdiv values for ne2
49+
struct fastdiv_values src1_div3; // fastdiv values for ne3
50+
struct fastdiv_values src1_div21; // fastdiv values for ne2 * ne1
51+
4152
uint32_t flags;
4253
};
4354

ggml/src/ggml-hexagon/htp/ops-utils.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,39 @@ static inline uint32_t htp_round_up(uint32_t n, uint32_t m) {
3131
return m * ((n + m - 1) / m);
3232
}
3333

// Fast 32-bit division by a runtime-invariant divisor, following
// https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
// We precompute a magic multiplier mp (m' in the paper) and a shift L so
// that a division turns into a 32x32 multiply (keeping the high 32 bits
// of the 64-bit product), an add, and a right shift:
//
//   n/d = (mulhi(n, mp) + n) >> L;
struct fastdiv_values {
    uint32_t mp; // magic multiplier (m' in the paper)
    uint32_t l;  // post-shift amount, L = ceil(log2(d))
};

// Precompute the magic multiplier and shift for divisor d.
// NOTE(review): assumes d >= 1 — d == 0 would divide by zero below; callers
// pass tensor dimensions, which are always >= 1.
static inline struct fastdiv_values init_fastdiv_values(uint32_t d) {
    struct fastdiv_values fd = { 0, 0 };

    // L = ceil(log2(d))
    for (; fd.l < 32 && ((uint32_t) 1 << fd.l) < d; ++fd.l) {
    }

    // mp = floor(2^32 * (2^L - d) / d) + 1
    fd.mp = (uint32_t) (((uint64_t) 1 << 32) * (((uint64_t) 1 << fd.l) - d) / d + 1);
    return fd;
}

// n / d, where vals was initialized from d via init_fastdiv_values().
static inline uint32_t fastdiv(uint32_t n, const struct fastdiv_values * vals) {
    // high 32 bits of the 64-bit product n * mp, i.e. mulhi(n, mp)
    const uint32_t hi = (uint32_t) (((uint64_t) n * vals->mp) >> 32);
    // add n, then apply the post-shift
    return (hi + n) >> vals->l;
}

// n % d, where vals was initialized from d via init_fastdiv_values().
static inline uint32_t fastmodulo(uint32_t n, uint32_t d, const struct fastdiv_values * vals) {
    return n - fastdiv(n, vals) * d;
}
66+
3467
static inline void htp_l2fetch(const void * p, uint32_t height, uint32_t width, uint32_t stride) {
3568
const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height));
3669
asm volatile(" l2fetch(%0,%1) " : : "r"(p), "r"(control));

0 commit comments

Comments
 (0)