@@ -1700,10 +1700,16 @@ kernel void kernel_ssm_scan_f32(
17001700 device const void * src5,
17011701 device const void * src6,
17021702 device float * dst,
1703+ threadgroup float * shared [[threadgroup(0 )]],
17031704 constant ggml_metal_kargs_ssm_scan & args,
1704- uint3 tgpig[[threadgroup_position_in_grid]],
1705- uint3 tpitg[[thread_position_in_threadgroup]],
1706- uint3 ntg[[threads_per_threadgroup]]) {
1705+ uint3 tgpig[[threadgroup_position_in_grid]],
1706+ uint3 tpitg[[thread_position_in_threadgroup]],
1707+ ushort sgitg[[simdgroup_index_in_threadgroup]],
1708+ ushort tiisg[[thread_index_in_simdgroup]],
1709+ ushort sgptg[[simdgroups_per_threadgroup]],
1710+ uint3 tgpg[[threadgroups_per_grid]]) {
1711+
1712+ const int64_t i0 = tpitg.x ;
17071713 const int64_t i1 = 0 ;
17081714 const int64_t ir = tgpig.x ; // current head
17091715 const int64_t i3 = tgpig.y ; // current seq
@@ -1718,37 +1724,85 @@ kernel void kernel_ssm_scan_f32(
17181724 const int64_t ng = args.n_group ;
17191725 const int64_t n_t = args.n_seq_tokens ;
17201726
1721- const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof ( float ) ;
1727+ const int64_t s_off = args.s_off ;
17221728
17231729 device const int32_t * ids = (device const int32_t *) src6;
17241730
1725- device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03 );
1726- device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
1731+ device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03 );
1732+ device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
1733+ const int64_t i = i0 + i1*nc;
1734+ float s0 = s0_buff[i];
1735+ float s = s_buff[i];
1736+
1737+ device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31 );
1738+ device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13 );
1739+ device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22 );
1740+ device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1 ))*args.nb41 + i3*args.nb43 );
1741+ device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1 ))*args.nb51 + i3*args.nb53 );
1742+ device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t *nh*nr))*nb00);
17271743
17281744 for (int64_t i2 = 0 ; i2 < n_t ; ++i2) {
1729- device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13 ); // {dim, nh, nt, ns}
1730- device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22 ); // {nh, nt, ns}
1731- device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31 ); // {d_state, nh}
1732- device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1 ))*args.nb41 + i2*args.nb42 + i3*args.nb43 ); // {d_state, ng, nt, ns}
1733- device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1 ))*args.nb51 + i2*args.nb52 + i3*args.nb53 ); // {d_state, ng, nt, ns}
1734- device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t *nh*nr))*nb00); // {dim, nh, nt, ns}
1745+ device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12 ); // {dim, nh, nt, ns}
1746+ device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21 ); // {nh, nt, ns}
1747+ device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42 ); // {d_state, ng, nt, ns}
1748+ device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52 ); // {d_state, ng, nt, ns}
1749+ device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
17351750
17361751 const float dt_soft_plus = dt[0 ] <= 20 .0f ? log (1 .0f + exp (dt[0 ])) : dt[0 ];
17371752 const float x_dt = x[0 ] * dt_soft_plus;
1738- float sumf = 0 .0f ;
17391753
1740- for (int64_t i0 = 0 ; i0 < nc; ++i0) {
1741- const int64_t i = i0 + i1*nc;
1742- const float state = (s0[i] * exp (dt_soft_plus * A[i0])) + (B[i0] * x_dt);
1743- sumf += state * C[i0];
1744- s[i] = state;
1745- }
1754+ const float state = (s0 * exp (dt_soft_plus * A[i0])) + (B[i0] * x_dt);
1755+ s = state;
1756+
1757+ // Parallel sum: This relies on the fact that this kernel will be
1758+ // dispatched with each threadgroup having (d_state, 1, 1) threads which
1759+ // are subdivided into `sgptg` SIMD groups. The goal is to
1760+ // compute y = sum({state * C[i] for i in range(d_state)}).
1761+ // To parallelize this effectively, we first use simd_sum over each SIMD
1762+ // group to compute the sum of each SIMD group, then place the result in
1763+ // the SIMD group's indexed bucket in the shared memory. We then sum
1764+ // over the individual group sums to compute the final sum.
1765+
1766+ // Computed for each thread
1767+ float sumf = state * C[i0];
1768+
1769+ // Sum the threads in the simd group => simd sum
1770+ sumf = simd_sum (sumf);
17461771
1747- y[0 ] = sumf;
1772+ if (sgptg > 1 ) {
1773+
1774+ // Once per simd group, place the group sum into the shared buffer
1775+ if (tiisg == 0 ) {
1776+ shared[sgitg] = sumf;
1777+ }
1778+
1779+ // Wait for all threads in the threadgroup to reach this point. This
1780+ // ensures that all elements of the shared buffer are populated with the
1781+ // sum of the individual simd groups.
1782+ threadgroup_barrier (mem_flags::mem_threadgroup);
1783+
1784+ // For simd group 0 at indices < num simd groups, extract the shared
1785+ // simd sum
1786+ sumf = 0 .0f ;
1787+ if (sgitg == 0 ) {
1788+ if (tiisg < sgptg) {
1789+ sumf = shared[tiisg];
1790+ }
1791+ sumf = simd_sum (sumf);
1792+ if (tiisg == 0 ) {
1793+ y[0 ] = sumf;
1794+ }
1795+ }
1796+ } else if (tiisg == 0 ) {
1797+ y[0 ] = sumf;
1798+ }
17481799
17491800 // recurse
17501801 s0 = s;
17511802 }
1803+
1804+ // Assign the final state to the output buffer
1805+ s_buff[i] = s;
17521806}
17531807
17541808// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
@@ -1770,6 +1824,7 @@ kernel void kernel_ssm_scan_f32_group(
17701824 ushort sgptg[[simdgroups_per_threadgroup]],
17711825 uint3 tgpg[[threadgroups_per_grid]]) {
17721826
1827+ const int64_t i0 = tpitg.x ;
17731828 const int64_t i1 = tgpig.x ;
17741829 const int64_t ir = tgpig.y ; // current head
17751830 const int64_t i3 = tgpig.z ; // current seq
@@ -1790,7 +1845,7 @@ kernel void kernel_ssm_scan_f32_group(
17901845
17911846 device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03 );
17921847 device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
1793- const int64_t i = tpitg. x + i1*nc;
1848+ const int64_t i = i0 + i1*nc;
17941849 float s0 = s0_buff[i];
17951850 float s = s_buff[i];
17961851
@@ -1812,7 +1867,7 @@ kernel void kernel_ssm_scan_f32_group(
18121867 const float x_dt = x[0 ] * dt_soft_plus;
18131868 const float dA = exp (dt_soft_plus * A[0 ]);
18141869
1815- const float state = (s0 * dA) + (B[tpitg. x ] * x_dt);
1870+ const float state = (s0 * dA) + (B[i0 ] * x_dt);
18161871 s = state;
18171872
18181873 // Parallel sum: This relies on the fact that this kernel will be
@@ -1825,7 +1880,7 @@ kernel void kernel_ssm_scan_f32_group(
18251880 // over the individual group sums to compute the final sum.
18261881
18271882 // Computed for each thread
1828- float sumf = state * C[tpitg. x ];
1883+ float sumf = state * C[i0 ];
18291884
18301885 // Sum the threads in the simd group => simd sum
18311886 sumf = simd_sum (sumf);
0 commit comments