File tree Expand file tree Collapse file tree 1 file changed +10
-7
lines changed Expand file tree Collapse file tree 1 file changed +10
-7
lines changed Original file line number Diff line number Diff line change @@ -1840,16 +1840,19 @@ kernel void kernel_ssm_scan_f32_group(
18401840 // sum of the individual simd groups.
18411841 threadgroup_barrier (mem_flags::mem_threadgroup);
18421842
1843- // Sum the simd buckets => threadgroup sum
1843+ // For simd group 0 at indices < num simd groups, extract the shared
1844+ // simd sum
18441845 sumf = 0 .0f ;
1845- for (int64_t i0 = 0 ; i0 < sgptg; ++i0) {
1846- sumf += shared[i0];
1846+ if (sgitg == 0 ) {
1847+ if (tiisg < sgptg) {
1848+ sumf = shared[tiisg];
1849+ }
1850+ sumf = simd_sum (sumf);
1851+ if (tiisg == 0 ) {
1852+ y[0 ] = sumf;
1853+ }
18471854 }
18481855
1849- threadgroup_barrier (mem_flags::mem_threadgroup);
1850-
1851- y[0 ] = sumf;
1852-
18531856 // recurse
18541857 s0 = s;
18551858 }
You can’t perform that action at this time.
0 commit comments