@@ -49,29 +49,38 @@ static __global__ void quantize_q8_1(
 
 template <mmq_q8_1_ds_layout ds_layout>
 static __global__ void quantize_mmq_q8_1(
-    const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {
+    const float * __restrict__ x, const int32_t * __restrict__ ids, void * __restrict__ vy,
+    const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+    const int64_t ne0, const int ne1, const int ne2) {
 
     constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
     constexpr int vals_per_sum   = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
 
-    const int64_t ix0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;
+    const int64_t i0  = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;
 
-    if (ix0 >= kx0_padded) {
+    if (i0 >= ne0) {
         return;
     }
 
-    const float4 * x4 = (const float4 *) x;
+    const int64_t i1 = blockIdx.y;
+    const int64_t i2 = blockIdx.z % ne2;
+    const int64_t i3 = blockIdx.z / ne2;
 
-    const int64_t ix1 = kx1*blockIdx.z + blockIdx.y;
+    const int64_t i00 = i0;
+    const int64_t i01 = ids ? ids[i1] : i1;
+    const int64_t i02 = i2;
+    const int64_t i03 = i3;
+
+    const float4 * x4 = (const float4 *) x;
 
     block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
 
     const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel
-    const int64_t ib  = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y;                   // block index in channel
-    const int64_t iqs = ix0 % (4*QK8_1);                                            // quant index in block
+    const int64_t ib  = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.y;                    // block index in channel
+    const int64_t iqs = i0 % (4*QK8_1);                                             // quant index in block
 
     // Load 4 floats per thread and calculate max. abs. value between them:
-    const float4 xi = ix0 < kx0 ? x4[(ix1*kx0 + ix0)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f);
+    const float4 xi = i0 < ne00 ? x4[(i03*s03 + i02*s02 + i01*s01 + i00)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f);
     float amax = fabsf(xi.x);
     amax = fmaxf(amax, fabsf(xi.y));
     amax = fmaxf(amax, fabsf(xi.z));
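
Note, not part of the patch: the load above no longer assumes contiguous rows. It builds the source index from per-dimension strides, and when the optional ids array is non-null (the mul_mat_id / MoE path) the logical row i1 is first gathered through ids to the physical row i01. A minimal host-side sketch of the same addressing, assuming s01/s02/s03 are strides in float elements (implied by the division by 4 on the float4 pointer):

// Hypothetical reference helper, for illustration only.
static float load_src_element(
        const float * x, const int32_t * ids,
        int64_t i0, int64_t i1, int64_t i2, int64_t i3,
        int64_t s01, int64_t s02, int64_t s03) {
    const int64_t i01 = ids ? ids[i1] : i1; // gather the selected row when ids is set
    return x[i3*s03 + i2*s02 + i01*s01 + i0]; // dims 2 and 3 map through unchanged
}
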
@@ -87,7 +96,7 @@ static __global__ void quantize_mmq_q8_1(
     if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) {
         sum = xi.x + xi.y + xi.z + xi.w;
 
-        // Exchange calculate sum across vals_per_sum/4 threads.
+        // Calculate sums across vals_per_sum/4 threads.
 #pragma unroll
         for (int offset = vals_per_sum/8; offset > 0; offset >>= 1) {
             sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, WARP_SIZE);
@@ -134,9 +143,10 @@ static __global__ void quantize_mmq_q8_1(
 }
 
 void quantize_row_q8_1_cuda(
-    const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
-    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
-
+    const float * x, const int32_t * ids, void * vy, const ggml_type type_src0,
+    const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
+    GGML_ASSERT(!ids);
     GGML_ASSERT(ne0 % QK8_1 == 0);
 
     const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
@@ -147,9 +157,9 @@ void quantize_row_q8_1_cuda(
 }
 
 void quantize_mmq_q8_1_cuda(
-    const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
-    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
-
+    const float * x, const int32_t * ids, void * vy, const ggml_type type_src0,
+    const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
     GGML_ASSERT(ne0 % (4*QK8_1) == 0);
 
     const int64_t block_num_x = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
@@ -158,21 +168,18 @@ void quantize_mmq_q8_1_cuda(
     switch (mmq_get_q8_1_ds_layout(type_src0)) {
         case MMQ_Q8_1_DS_LAYOUT_D4:
             quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4>
-                <<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1, ne0);
+                <<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
             break;
         case MMQ_Q8_1_DS_LAYOUT_DS4:
             quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_DS4>
-                <<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1, ne0);
+                <<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
             break;
         case MMQ_Q8_1_DS_LAYOUT_D2S6:
             quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D2S6>
-                <<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1, ne0);
+                <<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
             break;
         default:
             GGML_ABORT("fatal error");
             break;
     }
-    GGML_UNUSED(s01);
-    GGML_UNUSED(s02);
-    GGML_UNUSED(s03);
 }
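
Aside on the launch geometry, which sits outside the hunks shown and is therefore an assumption here: for the kernel's coordinate recovery to hold, gridDim.y must be ne1 and gridDim.z must be ne2*ne3, with blockIdx.z packing the two outer dimensions as z = i3*ne2 + i2. A self-contained sketch of that packing and its inverse:

// Hypothetical helper mirroring the kernel's div/mod unpack of blockIdx.z.
#include <cassert>
#include <cstdint>

static void unpack_z(int64_t z, int64_t ne2, int64_t * i2, int64_t * i3) {
    *i2 = z % ne2; // coordinate in dim 2
    *i3 = z / ne2; // coordinate in dim 3
}

int main() {
    int64_t i2, i3;
    unpack_z(/*z=*/7, /*ne2=*/3, &i2, &i3);
    assert(i2 == 1 && i3 == 2); // 7 == 2*3 + 1
    return 0;
}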