@@ -60,7 +60,7 @@ class IMatrixCollector {
6060 int m_last_call = 0 ;
6161 int m_last_layer = 9999 ;
6262 int m_last_ffn = -1 ;
63- std::vector<float > m_src1_data;
63+ std::vector<char > m_src1_data;
6464 std::vector<char > m_ids; // the expert ids from ggml_mul_mat_id
6565 std::vector<float > m_last_input;
6666 std::vector<float > m_ffn_input;
@@ -189,11 +189,12 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
189189 const bool is_host = ggml_backend_buffer_is_host (src1->buffer );
190190
191191 if (!is_host) {
192- m_src1_data.resize (ggml_nelements (src1));
193- ggml_backend_tensor_get (src1, m_src1_data.data (), 0 , ggml_nbytes (src1));
192+ auto nbytes = ggml_nbytes (src1);
193+ m_src1_data.resize (nbytes);
194+ ggml_backend_tensor_get (src1, m_src1_data.data (), 0 , nbytes);
194195 }
195196
196- const float * data = is_host ? (const float *) src1->data : m_src1_data.data ();
197+ const float * data = is_host ? (const float *) src1->data : ( const float *) m_src1_data.data ();
197198
198199 if (m_collect_lsim) {
199200 if (wname.find (" .ffn_" ) != std::string::npos) {
@@ -331,25 +332,38 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
331332 }
332333 auto & e = m_stats[wname];
333334 if (e.values .empty ()) {
334- e.values .resize (src1->ne [0 ], 0 );
335- e.counts .resize (src1->ne [0 ], 0 );
335+ if (src0->ne [3 ] > 1 ) {
336+ fprintf (stderr, " Unsupported 4D tensor %s\n " , wname.c_str ());
337+ exit (1 );
338+ }
339+ // If we have a 3D tensor as it is the case for the attn_k_b and attn_v_b for DeepSeek MLA models,
340+ // than we need to compute the imatrix for each head, and not just one imatrx for all heads.
341+ // Hence, the storage we need is src0->ne[0]*src0->ne[2].
342+ e.values .resize (src0->ne [0 ]*src0->ne [2 ], 0 );
343+ e.counts .resize (src0->ne [0 ]*src0->ne [2 ], 0 );
336344 }
337- else if (e.values .size () != (size_t )src1 ->ne [0 ]) {
345+ else if (e.values .size () != (size_t )(src0 ->ne [0 ]*src0-> ne [ 2 ]) ) {
338346 fprintf (stderr, " Oops: inconsistent size for %s (%d vs %d)\n " , wname.c_str (), (int )e.values .size (), (int )src1->ne [0 ]);
339347 exit (1 ); // GGML_ABORT("fatal error");
340348 }
341349 ++e.ncall ;
342350 if (m_params.verbosity > 1 ) {
343351 printf (" %s[%d]: %32s, %s, %5d x %5d, %d\n " , __func__, m_last_call, wname.c_str (), ggml_op_name (t->op ), (int )src1->ne [0 ], (int )src1->ne [1 ], (int )src1->type );
344352 }
345- for (int row = 0 ; row < (int )(src1->ne [1 ]*src1->ne [2 ]); ++row) {
346- const float * x = data + row * src1->ne [0 ];
347- for (int j = 0 ; j < (int )src1->ne [0 ]; ++j) {
348- e.values [j] += x[j]*x[j];
349- e.counts [j]++;
350- if (!std::isfinite (e.values [j])) {
351- fprintf (stderr, " %f detected in %s\n " , e.values [j], wname.c_str ());
352- exit (1 );
353+ int rk2 = src1->ne [2 ]/src0->ne [2 ];
354+ for (int i12 = 0 ; i12 < (int )src1->ne [2 ]; ++i12) { // i.e., loop over attention heads for MLA models
355+ int i02 = i12/rk2;
356+ auto values = e.values .data () + i02*src0->ne [0 ];
357+ auto counts = e.counts .data () + i02*src0->ne [0 ];
358+ for (int i11 = 0 ; i11 < (int )src1->ne [1 ]; ++i11) {
359+ const float * x = (const float *)((const char *)data + i11*src1->nb [1 ] + i12*src1->nb [2 ]);
360+ for (int j = 0 ; j < (int )src1->ne [0 ]; ++j) {
361+ values[j] += x[j]*x[j];
362+ counts[j]++;
363+ if (!std::isfinite (values[j])) {
364+ fprintf (stderr, " %f detected in %s\n " , e.values [j], wname.c_str ());
365+ exit (1 );
366+ }
353367 }
354368 }
355369 }
0 commit comments