Commit 81bec47

Reapply commit "CUDA: batched+noncont MMQ, refactor bs>1 MoE code (ggml-org#13199)" quantize
1 parent 269a067 commit 81bec47

3 files changed: +39 −30 lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 2 additions & 3 deletions
@@ -1769,8 +1769,7 @@ static void ggml_cuda_op_mul_mat(
             dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), src_1_ddq_size);

             if (src1_on_device && src1_is_contiguous) {
-                quantize_src1(
-                    dev[id].src1_ddf, dev[id].src1_ddq, src0->type, ne10,
+                quantize_src1(dev[id].src1_ddf, nullptr, dev[id].src1_ddq, src0->type, ne10,
                     nb11/sizeof(float), nb12/sizeof(float), nb13/sizeof(float),
                     src1_padded_col_size, ne11, ne12, ne13, stream);
                 CUDA_CHECK(cudaGetLastError());
@@ -1871,7 +1870,7 @@ static void ggml_cuda_op_mul_mat(

         if (quantize_src1 && !src1_is_contiguous) {
             quantize_src1(
-                src1_ddf_i, src1_ddq_i, src0->type, ne10, ne10, ne11*ne10, ne12*ne11*ne10,
+                src1_ddf_i, nullptr, src1_ddq_i, src0->type, ne10, ne10, ne11*ne10, ne12*ne11*ne10,
                 src1_padded_col_size, src1_ncols, 1, 1, stream);
             CUDA_CHECK(cudaGetLastError());
         }
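Side note (illustrative, not part of the commit): both call sites above now pass nullptr for the new ids argument, since the dense mul_mat path quantizes every src1 row in order and only the MoE path supplies a row-index table. The contiguous branch hands the kernel src1's real strides converted from bytes to float elements, while the non-contiguous branch quantizes a packed staging buffer, so its strides are simply the contiguous ones derived from the shape. A minimal sketch of that distinction; the helper names are hypothetical:

#include <cstddef>
#include <cstdint>

struct strides_t { int64_t s01, s02, s03; }; // element strides between rows, matrices, batches

// Packed layout, as on the non-contiguous branch: s01 = ne10, s02 = ne11*ne10, s03 = ne12*ne11*ne10.
static strides_t contiguous_strides(int64_t ne10, int64_t ne11, int64_t ne12) {
    return { ne10, ne11*ne10, ne12*ne11*ne10 };
}

// Arbitrary (possibly padded) layout, as on the contiguous branch: convert byte strides to elements.
static strides_t element_strides(size_t nb11, size_t nb12, size_t nb13) {
    return { (int64_t)(nb11/sizeof(float)), (int64_t)(nb12/sizeof(float)), (int64_t)(nb13/sizeof(float)) };
}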

ggml/src/ggml-cuda/quantize.cu

Lines changed: 28 additions & 21 deletions
@@ -49,29 +49,38 @@ static __global__ void quantize_q8_1(

 template <mmq_q8_1_ds_layout ds_layout>
 static __global__ void quantize_mmq_q8_1(
-    const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {
+    const float * __restrict__ x, const int32_t * __restrict__ ids, void * __restrict__ vy,
+    const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+    const int64_t ne0, const int ne1, const int ne2) {

     constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
     constexpr int vals_per_sum   = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;

-    const int64_t ix0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;
+    const int64_t i0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;

-    if (ix0 >= kx0_padded) {
+    if (i0 >= ne0) {
         return;
     }

-    const float4 * x4 = (const float4 *) x;
+    const int64_t i1 = blockIdx.y;
+    const int64_t i2 = blockIdx.z % ne2;
+    const int64_t i3 = blockIdx.z / ne2;

-    const int64_t ix1 = kx1*blockIdx.z + blockIdx.y;
+    const int64_t i00 = i0;
+    const int64_t i01 = ids ? ids[i1] : i1;
+    const int64_t i02 = i2;
+    const int64_t i03 = i3;
+
+    const float4 * x4 = (const float4 *) x;

     block_q8_1_mmq * y = (block_q8_1_mmq *) vy;

     const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel
-    const int64_t ib  = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y;                   // block index in channel
-    const int64_t iqs = ix0 % (4*QK8_1);                                            // quant index in block
+    const int64_t ib  = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.y;                    // block index in channel
+    const int64_t iqs = i0 % (4*QK8_1);                                             // quant index in block

     // Load 4 floats per thread and calculate max. abs. value between them:
-    const float4 xi = ix0 < kx0 ? x4[(ix1*kx0 + ix0)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f);
+    const float4 xi = i0 < ne00 ? x4[(i03*s03 + i02*s02 + i01*s01 + i00)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f);
     float amax = fabsf(xi.x);
     amax = fmaxf(amax, fabsf(xi.y));
     amax = fmaxf(amax, fabsf(xi.z));
@@ -87,7 +96,7 @@ static __global__ void quantize_mmq_q8_1(
     if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) {
         sum = xi.x + xi.y + xi.z + xi.w;

-        // Exchange calculate sum across vals_per_sum/4 threads.
+        // Calculate sums across vals_per_sum/4 threads.
 #pragma unroll
         for (int offset = vals_per_sum/8; offset > 0; offset >>= 1) {
             sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, WARP_SIZE);
@@ -134,9 +143,10 @@ static __global__ void quantize_mmq_q8_1(
 }

 void quantize_row_q8_1_cuda(
-    const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
-    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
-
+    const float * x, const int32_t * ids, void * vy, const ggml_type type_src0,
+    const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
+    GGML_ASSERT(!ids);
     GGML_ASSERT(ne0 % QK8_1 == 0);

     const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
@@ -147,9 +157,9 @@ void quantize_row_q8_1_cuda(
 }

 void quantize_mmq_q8_1_cuda(
-    const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
-    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
-
+    const float * x, const int32_t * ids, void * vy, const ggml_type type_src0,
+    const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
     GGML_ASSERT(ne0 % (4*QK8_1) == 0);

     const int64_t block_num_x = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
@@ -158,21 +168,18 @@ void quantize_mmq_q8_1_cuda(
     switch (mmq_get_q8_1_ds_layout(type_src0)) {
         case MMQ_Q8_1_DS_LAYOUT_D4:
             quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4>
-                <<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1, ne0);
+                <<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
             break;
         case MMQ_Q8_1_DS_LAYOUT_DS4:
             quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_DS4>
-                <<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1, ne0);
+                <<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
             break;
         case MMQ_Q8_1_DS_LAYOUT_D2S6:
             quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D2S6>
-                <<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1, ne0);
+                <<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
             break;
         default:
             GGML_ABORT("fatal error");
             break;
     }
-    GGML_UNUSED(s01);
-    GGML_UNUSED(s02);
-    GGML_UNUSED(s03);
 }
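For reference (a sketch under assumptions, not code from the commit): the reworked kernel now addresses src1 through explicit element strides and an optional row-index table, reading x[i03*s03 + i02*s02 + i01*s01 + i00] with i01 = ids ? ids[i1] : i1, and treating columns at or beyond ne00 (padding up to ne0) as zero. The host-side helper below, read_src1, is hypothetical and only mirrors that index math:

#include <cstdint>
#include <cstdio>

// Mirrors the kernel's source indexing: i00 = i0, i01 = ids ? ids[i1] : i1, i02 = i2, i03 = i3.
static float read_src1(const float * x, const int32_t * ids, int64_t ne00,
                       int64_t s01, int64_t s02, int64_t s03,
                       int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
    const int64_t i01 = ids ? ids[i1] : i1; // optional row indirection (e.g. MoE expert rows)
    return i0 < ne00 ? x[i3*s03 + i2*s02 + i01*s01 + i0] : 0.0f; // padded columns read as 0
}

int main() {
    const float   x[8]   = {1, 2, 3, 4, 5, 6, 7, 8}; // 2 contiguous rows of 4 values (s01 == ne00 == 4)
    const int32_t ids[2] = {1, 0};                   // row-index table that swaps the two rows
    std::printf("%.0f %.0f\n",
                read_src1(x, nullptr, 4, 4, 8, 8, 0, 0, 0, 0),  // prints 1 (identity mapping)
                read_src1(x, ids,     4, 4, 8, 8, 0, 0, 0, 0)); // prints 5 (row 0 -> row 1)
    return 0;
}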

ggml/src/ggml-cuda/quantize.cuh

Lines changed: 9 additions & 6 deletions
@@ -12,13 +12,16 @@ static_assert(MATRIX_ROW_PADDING % CUDA_QUANTIZE_BLOCK_SIZE == 0, "Risk of out-of-bounds access.");
 static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access.");

 typedef void (*quantize_cuda_t)(
-    const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
-    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream);
+    const float * x, const int32_t * ids, void * vy,
+    ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
+    int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);

 void quantize_row_q8_1_cuda(
-    const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
-    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream);
+    const float * x, const int32_t * ids, void * vy,
+    ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
+    int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);

 void quantize_mmq_q8_1_cuda(
-    const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
-    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream);
+    const float * x, const int32_t * ids, void * vy,
+    ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
+    int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);
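As an aside (hypothetical usage, not from this commit): both entry points now match the quantize_cuda_t typedef, so a caller can select an implementation at runtime through a function pointer and forward the optional ids table; note that quantize_row_q8_1_cuda currently asserts ids == nullptr, so only the MMQ path accepts a row-index table. A minimal sketch, with quantize_src1_dispatch and use_mmq being made-up names:

#include "quantize.cuh" // declares quantize_cuda_t and the two entry points above

// Pick an implementation and forward the (possibly null) row-index table.
static void quantize_src1_dispatch(
        bool use_mmq, const float * x, const int32_t * ids, void * vy,
        ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
        int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream) {
    const quantize_cuda_t fn = use_mmq ? quantize_mmq_q8_1_cuda : quantize_row_q8_1_cuda;
    fn(x, ids, vy, type_src0, ne00, s01, s02, s03, ne0, ne1, ne2, ne3, stream); // ids must be nullptr unless use_mmq
}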
