@@ -1663,16 +1663,10 @@ kernel void kernel_ssm_conv_f32(
16631663 device const void * src0,
16641664 device const void * src1,
16651665 device float * dst,
1666- threadgroup float * shared [[threadgroup(0 )]],
16671666 constant ggml_metal_kargs_ssm_conv & args,
1668- uint3 tgpig[[threadgroup_position_in_grid]],
1669- uint3 tpitg[[thread_position_in_threadgroup]],
1670- ushort sgitg[[simdgroup_index_in_threadgroup]],
1671- ushort tiisg[[thread_index_in_simdgroup]],
1672- ushort sgptg[[simdgroups_per_threadgroup]],
1673- uint3 tgpg[[threadgroups_per_grid]]) {
1674-
1675- const int64_t i0 = tpitg.x ;
1667+ uint3 tgpig[[threadgroup_position_in_grid]],
1668+ uint3 tpitg[[thread_position_in_threadgroup]],
1669+ uint3 ntg[[threads_per_threadgroup]]) {
16761670 const int64_t ir = tgpig.x ;
16771671 const int64_t i2 = tgpig.y ;
16781672 const int64_t i3 = tgpig.z ;
@@ -1687,31 +1681,13 @@ kernel void kernel_ssm_conv_f32(
16871681 device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11 );
16881682 device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2 );
16891683
1690- float sumf = s[i0] * c[i0];
1691-
1692- // Parallel sum: first sum over threads in simd group, then sum over simd
1693- // group sums
1694- sumf = simd_sum (sumf);
1684+ float sumf = 0 .0f ;
16951685
1696- // If multiple simd groups per threadgroup, sum over simd group sums
1697- if (sgptg > 1 ) {
1698- if (tiisg == 0 ) {
1699- shared[sgitg] = sumf;
1700- }
1701- threadgroup_barrier (mem_flags::mem_threadgroup);
1702- sumf = 0 .0f ;
1703- if (sgitg == 0 ) {
1704- if (tiisg < sgptg) {
1705- sumf = shared[tiisg];
1706- }
1707- sumf = simd_sum (sumf);
1708- if (tiisg == 0 ) {
1709- x[0 ] = sumf;
1710- }
1711- }
1712- } else if (tiisg == 0 ) {
1713- x[0 ] = sumf;
1686+ for (int64_t i0 = 0 ; i0 < nc; ++i0) {
1687+ sumf += s[i0] * c[i0];
17141688 }
1689+
1690+ x[0 ] = sumf;
17151691}
17161692
17171693// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
0 commit comments