Skip to content

Commit 299f046

Browse files
committed
Reapply "CUDA: fix crash on large batch size for quant. MoE (ggml-org#13537)"
Except MMQ.
1 parent bd6d3d9 commit 299f046

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

ggml/src/ggml-cuda/quantize.cu

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,13 @@ static __global__ void quantize_mmq_q8_1(
5656
constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
5757
constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
5858

59-
const int64_t i0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;
59+
const int64_t i0 = ((int64_t)blockDim.x*blockIdx.y + threadIdx.x)*4;
6060

6161
if (i0 >= ne0) {
6262
return;
6363
}
6464

65-
const int64_t i1 = blockIdx.y;
65+
const int64_t i1 = blockIdx.x;
6666
const int64_t i2 = blockIdx.z % ne2;
6767
const int64_t i3 = blockIdx.z / ne2;
6868

@@ -75,8 +75,8 @@ static __global__ void quantize_mmq_q8_1(
7575

7676
block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
7777

78-
const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel
79-
const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.y; // block index in channel
78+
const int64_t ib0 = blockIdx.z*((int64_t)gridDim.x*gridDim.y*blockDim.x/QK8_1); // first block of channel
79+
const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.x; // block index in channel
8080
const int64_t iqs = i0 % (4*QK8_1); // quant index in block
8181

8282
// Load 4 floats per thread and calculate max. abs. value between them:
@@ -162,8 +162,9 @@ void quantize_mmq_q8_1_cuda(
162162
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
163163
GGML_ASSERT(ne0 % (4*QK8_1) == 0);
164164

165-
const int64_t block_num_x = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
166-
const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
165+
// ne1 tends to assume the highest values, therefore use it as the "x" dimension of the CUDA grid:
166+
const int64_t block_num_y = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
167+
const dim3 num_blocks(ne1, block_num_y, ne2*ne3);
167168
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
168169
switch (mmq_get_q8_1_ds_layout(type_src0)) {
169170
case MMQ_Q8_1_DS_LAYOUT_D4:

0 commit comments

Comments
 (0)