@@ -34,6 +34,11 @@ static hvx_elemwise_f32_func func_table_HVX[] = { hvx_mul_f32, hvx_add_f32,
3434static hvx_elemwise_f32_func func_table_HVX_opt [] = { hvx_mul_f32_opt , hvx_add_f32_opt , hvx_sub_f32_opt };
3535
3636#define htp_binary_preamble \
37+ const struct htp_tensor * src0 = &octx->src0; \
38+ const struct htp_tensor * src1 = &octx->src1; \
39+ const struct htp_tensor * src2 = &octx->src2; \
40+ struct htp_tensor * dst = &octx->dst; \
41+ \
3742 const uint32_t ne00 = src0->ne[0]; \
3843 const uint32_t ne01 = src0->ne[1]; \
3944 const uint32_t ne02 = src0->ne[2]; \
@@ -62,16 +67,15 @@ static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_opt, hvx_add_f
6267 const uint32_t nb0 = dst->nb[0]; \
6368 const uint32_t nb1 = dst->nb[1]; \
6469 const uint32_t nb2 = dst->nb[2]; \
65- const uint32_t nb3 = dst->nb[3];
66-
67- static void binary_job_f32_per_thread (const struct htp_tensor * src0 ,
68- const struct htp_tensor * src1 ,
69- struct htp_tensor * dst ,
70- uint8_t * spad_data ,
71- uint32_t nth ,
72- uint32_t ith ,
73- uint32_t src0_nrows_per_thread ,
74- enum htp_op op ) {
70+ const uint32_t nb3 = dst->nb[3]; \
71+ \
72+ const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
73+
74+ static void binary_job_f32_per_thread (struct htp_ops_context * octx ,
75+ uint8_t * spad_data ,
76+ uint32_t nth ,
77+ uint32_t ith ,
78+ enum htp_op op ) {
7579 htp_binary_preamble ;
7680
7781 const size_t src0_row_size = nb01 ;
@@ -107,16 +111,23 @@ static void binary_job_f32_per_thread(const struct htp_tensor * src0,
107111
108112 uint8_t * restrict spad_data_th = spad_data + (ith * src0_row_size );
109113
110- const uint32_t nr0 = ne00 / ne10 ;
111-
112114 const uint8_t * restrict src0_ptr = (const uint8_t * ) src0 -> data + (src0_start_row * src0_row_size );
113115 uint8_t * restrict dst_ptr = (uint8_t * ) dst -> data + (src0_start_row * dst_row_size );
114116
115117 const uint8_t * restrict data_src1 = (const uint8_t * ) src1 -> data ;
116- const uint8_t * restrict src1_ptr = NULL ;
118+
119+ const uint32_t ne02_ne01 = ne02 * ne01 ;
117120
118121 for (uint32_t ir = src0_start_row ; ir < src0_end_row ; ir ++ ) {
119- src1_ptr = data_src1 + (ir % src1_nrows ) * src1_row_size ;
122+ const uint32_t i03 = fastdiv (ir , & octx -> src0_div21 );
123+ const uint32_t i02 = fastdiv (ir - i03 * ne02_ne01 , & octx -> src0_div1 );
124+ const uint32_t i01 = (ir - i03 * ne02_ne01 - i02 * ne01 );
125+
126+ const uint32_t i13 = fastmodulo (i03 , ne13 , & octx -> src1_div3 );
127+ const uint32_t i12 = fastmodulo (i02 , ne12 , & octx -> src1_div2 );
128+ const uint32_t i11 = fastmodulo (i01 , ne11 , & octx -> src1_div1 );
129+
130+ const uint8_t * restrict src1_ptr = data_src1 + i13 * nb13 + i12 * nb12 + i11 * src1_row_size ;
120131
121132 if (ir + 1 < src0_end_row ) {
122133 htp_l2fetch (src0_ptr + ne00 , 1 , src0_row_size , src0_row_size );
@@ -125,6 +136,7 @@ static void binary_job_f32_per_thread(const struct htp_tensor * src0,
125136 }
126137 }
127138
139+ const uint32_t nr0 = ne00 / ne10 ;
128140 if (nr0 > 1 ) {
129141 if ((1 == is_aligned ) && (nr0 == ne00 )) {
130142 hvx_bcast_fp32_a (spad_data_th , * (float * ) src1_ptr , nr0 );
@@ -149,22 +161,17 @@ static void binary_job_f32_per_thread(const struct htp_tensor * src0,
149161 (unsigned ) HAP_perf_qtimer_count_to_us (t2 - t1 ));
150162}
151163
152- static void binary_add_id_job_f32_per_thread (const struct htp_tensor * src0 ,
153- const struct htp_tensor * src1 ,
154- const struct htp_tensor * src2 ,
155- struct htp_tensor * dst ,
156- uint8_t * spad_data ,
157- uint32_t nth ,
158- uint32_t ith ,
159- uint32_t src0_nrows_per_thread ,
160- hvx_elemwise_f32_func func_HVX ) {
164+ static void binary_add_id_job_f32_per_thread (struct htp_ops_context * octx ,
165+ uint8_t * spad_data ,
166+ uint32_t nth ,
167+ uint32_t ith ,
168+ hvx_elemwise_f32_func func_HVX ) {
161169 htp_binary_preamble ;
162170
163171 const size_t src0_row_size = nb01 ;
164172 const size_t src1_row_size = nb11 ;
165173 const size_t dst_row_size = nb1 ;
166174
167- const uint32_t ne02_ne01 = ne02 * ne01 ;
168175 const uint32_t src0_nrows = ne01 * ne02 * ne03 ; // src0 rows
169176
170177 const uint32_t src0_start_row = src0_nrows_per_thread * ith ;
@@ -187,10 +194,11 @@ static void binary_add_id_job_f32_per_thread(const struct htp_tensor * src0,
187194 const uint8_t * restrict data_src1 = (const uint8_t * ) src1 -> data ;
188195 uint8_t * restrict data_dst = (uint8_t * ) dst -> data ;
189196
197+ const uint32_t ne02_ne01 = ne02 * ne01 ;
190198 for (uint32_t ir = src0_start_row ; ir < src0_end_row ; ir ++ ) {
191199 // src0 indices
192- const uint32_t i03 = ir / ne02_ne01 ;
193- const uint32_t i02 = (ir - i03 * ne02_ne01 ) / ne01 ;
200+ const uint32_t i03 = fastdiv ( ir , & octx -> src0_div21 ) ;
201+ const uint32_t i02 = fastdiv (ir - i03 * ne02_ne01 , & octx -> src0_div1 ) ;
194202 const uint32_t i01 = (ir - i03 * ne02_ne01 - i02 * ne01 );
195203
196204 // src1 indices
@@ -234,13 +242,11 @@ static void binary_job_dispatcher_f32(unsigned int n, unsigned int i, void * dat
234242 case HTP_OP_MUL :
235243 case HTP_OP_ADD :
236244 case HTP_OP_SUB :
237- binary_job_f32_per_thread (& octx -> src0 , & octx -> src1 , & octx -> dst , octx -> src1_spad .data , n , i ,
238- octx -> src0_nrows_per_thread , octx -> op );
245+ binary_job_f32_per_thread (octx , octx -> src1_spad .data , n , i , octx -> op );
239246 break ;
240247
241248 case HTP_OP_ADD_ID :
242- binary_add_id_job_f32_per_thread (& octx -> src0 , & octx -> src1 , & octx -> src2 , & octx -> dst , octx -> src0_spad .data , n ,
243- i , octx -> src0_nrows_per_thread , hvx_add_f32 );
249+ binary_add_id_job_f32_per_thread (octx , octx -> src0_spad .data , n , i , hvx_add_f32 );
244250 break ;
245251
246252 default :
@@ -321,6 +327,16 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) {
321327
322328 octx -> src0_nrows_per_thread = (src0_nrows + n_jobs - 1 ) / n_jobs ;
323329
330+ octx -> src0_div21 = init_fastdiv_values (src0 -> ne [2 ] * src0 -> ne [1 ]);
331+ octx -> src0_div3 = init_fastdiv_values (src0 -> ne [3 ]);
332+ octx -> src0_div2 = init_fastdiv_values (src0 -> ne [2 ]);
333+ octx -> src0_div1 = init_fastdiv_values (src0 -> ne [1 ]);
334+
335+ octx -> src1_div21 = init_fastdiv_values (src1 -> ne [2 ] * src1 -> ne [1 ]);
336+ octx -> src1_div3 = init_fastdiv_values (src1 -> ne [3 ]);
337+ octx -> src1_div2 = init_fastdiv_values (src1 -> ne [2 ]);
338+ octx -> src1_div1 = init_fastdiv_values (src1 -> ne [1 ]);
339+
324340 worker_pool_run_func (octx -> ctx -> worker_pool , binary_op_func , octx , n_jobs );
325341 }
326342
0 commit comments