Skip to content

Commit 0c57f84

Browse files
ikawrakow (Iwan Kawrakow)
and authored
Fix imatrix calculation for MLA models (#411)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
1 parent 553c08b commit 0c57f84

File tree

1 file changed

+29
-15
lines changed

1 file changed

+29
-15
lines changed

examples/imatrix/imatrix.cpp

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class IMatrixCollector {
6060
int m_last_call = 0;
6161
int m_last_layer = 9999;
6262
int m_last_ffn = -1;
63-
std::vector<float> m_src1_data;
63+
std::vector<char> m_src1_data;
6464
std::vector<char> m_ids; // the expert ids from ggml_mul_mat_id
6565
std::vector<float> m_last_input;
6666
std::vector<float> m_ffn_input;
@@ -189,11 +189,12 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
189189
const bool is_host = ggml_backend_buffer_is_host(src1->buffer);
190190

191191
if (!is_host) {
192-
m_src1_data.resize(ggml_nelements(src1));
193-
ggml_backend_tensor_get(src1, m_src1_data.data(), 0, ggml_nbytes(src1));
192+
auto nbytes = ggml_nbytes(src1);
193+
m_src1_data.resize(nbytes);
194+
ggml_backend_tensor_get(src1, m_src1_data.data(), 0, nbytes);
194195
}
195196

196-
const float * data = is_host ? (const float *) src1->data : m_src1_data.data();
197+
const float * data = is_host ? (const float *) src1->data : (const float *)m_src1_data.data();
197198

198199
if (m_collect_lsim) {
199200
if (wname.find(".ffn_") != std::string::npos) {
@@ -331,25 +332,38 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
331332
}
332333
auto & e = m_stats[wname];
333334
if (e.values.empty()) {
334-
e.values.resize(src1->ne[0], 0);
335-
e.counts.resize(src1->ne[0], 0);
335+
if (src0->ne[3] > 1) {
336+
fprintf(stderr, "Unsupported 4D tensor %s\n", wname.c_str());
337+
exit(1);
338+
}
339+
// If we have a 3D tensor as it is the case for the attn_k_b and attn_v_b for DeepSeek MLA models,
340+
// then we need to compute the imatrix for each head, and not just one imatrix for all heads.
341+
// Hence, the storage we need is src0->ne[0]*src0->ne[2].
342+
e.values.resize(src0->ne[0]*src0->ne[2], 0);
343+
e.counts.resize(src0->ne[0]*src0->ne[2], 0);
336344
}
337-
else if (e.values.size() != (size_t)src1->ne[0]) {
345+
else if (e.values.size() != (size_t)(src0->ne[0]*src0->ne[2])) {
338346
fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]);
339347
exit(1); //GGML_ABORT("fatal error");
340348
}
341349
++e.ncall;
342350
if (m_params.verbosity > 1) {
343351
printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->type);
344352
}
345-
for (int row = 0; row < (int)(src1->ne[1]*src1->ne[2]); ++row) {
346-
const float * x = data + row * src1->ne[0];
347-
for (int j = 0; j < (int)src1->ne[0]; ++j) {
348-
e.values[j] += x[j]*x[j];
349-
e.counts[j]++;
350-
if (!std::isfinite(e.values[j])) {
351-
fprintf(stderr, "%f detected in %s\n", e.values[j], wname.c_str());
352-
exit(1);
353+
int rk2 = src1->ne[2]/src0->ne[2];
354+
for (int i12 = 0; i12 < (int)src1->ne[2]; ++i12) { // i.e., loop over attention heads for MLA models
355+
int i02 = i12/rk2;
356+
auto values = e.values.data() + i02*src0->ne[0];
357+
auto counts = e.counts.data() + i02*src0->ne[0];
358+
for (int i11 = 0; i11 < (int)src1->ne[1]; ++i11) {
359+
const float * x = (const float *)((const char *)data + i11*src1->nb[1] + i12*src1->nb[2]);
360+
for (int j = 0; j < (int)src1->ne[0]; ++j) {
361+
values[j] += x[j]*x[j];
362+
counts[j]++;
363+
if (!std::isfinite(values[j])) {
364+
fprintf(stderr, "%f detected in %s\n", e.values[j], wname.c_str());
365+
exit(1);
366+
}
353367
}
354368
}
355369
}

0 commit comments

Comments
 (0)