@@ -56,13 +56,13 @@ static __global__ void quantize_mmq_q8_1(
5656 constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32 ;
5757 constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32 ;
5858
59- const int64_t i0 = ((int64_t )blockDim .x *blockIdx .x + threadIdx .x )*4 ;
59+ const int64_t i0 = ((int64_t )blockDim .x *blockIdx .y + threadIdx .x )*4 ;
6060
6161 if (i0 >= ne0) {
6262 return ;
6363 }
6464
65- const int64_t i1 = blockIdx .y ;
65+ const int64_t i1 = blockIdx .x ;
6666 const int64_t i2 = blockIdx .z % ne2;
6767 const int64_t i3 = blockIdx .z / ne2;
6868
@@ -75,8 +75,8 @@ static __global__ void quantize_mmq_q8_1(
7575
7676 block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
7777
78- const int64_t ib0 = blockIdx .z *((int64_t )gridDim .y *gridDim .x *blockDim .x /QK8_1); // first block of channel
79- const int64_t ib = ib0 + (i0 / (4 *QK8_1))*ne1 + blockIdx .y ; // block index in channel
78+ const int64_t ib0 = blockIdx .z *((int64_t )gridDim .x *gridDim .y *blockDim .x /QK8_1); // first block of channel
79+ const int64_t ib = ib0 + (i0 / (4 *QK8_1))*ne1 + blockIdx .x ; // block index in channel
8080 const int64_t iqs = i0 % (4 *QK8_1); // quant index in block
8181
8282 // Load 4 floats per thread and calculate max. abs. value between them:
@@ -162,8 +162,9 @@ void quantize_mmq_q8_1_cuda(
162162 const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
163163 GGML_ASSERT (ne0 % (4 *QK8_1) == 0 );
164164
165- const int64_t block_num_x = (ne0 + 4 *CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1 ) / (4 *CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
166- const dim3 num_blocks (block_num_x, ne1, ne2*ne3);
165+ // ne1 tends to assume the highest values, therefore use it as the "x" dimension of the CUDA grid:
166+ const int64_t block_num_y = (ne0 + 4 *CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1 ) / (4 *CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
167+ const dim3 num_blocks (ne1, block_num_y, ne2*ne3);
167168 const dim3 block_size (CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1 , 1 );
168169 switch (mmq_get_q8_1_ds_layout (type_src0)) {
169170 case MMQ_Q8_1_DS_LAYOUT_D4:
0 commit comments