@@ -1142,18 +1142,18 @@ struct llm_build_context {
 
         ctx0 = ggml_init(params);
 
-        lctx.inp_tokens      = nullptr;
-        lctx.inp_embd        = nullptr;
-        lctx.inp_pos         = nullptr;
-        lctx.inp_out_ids     = nullptr;
-        lctx.inp_KQ_mask     = nullptr;
-        lctx.inp_KQ_mask_swa = nullptr;
-        lctx.inp_K_shift     = nullptr;
-        lctx.inp_mean        = nullptr;
-        lctx.inp_cls         = nullptr;
-        lctx.inp_s_copy      = nullptr;
-        lctx.inp_s_mask      = nullptr;
-        lctx.inp_s_seq       = nullptr;
+        lctx.inp_tokens        = nullptr;
+        lctx.inp_embd          = nullptr;
+        lctx.inp_pos           = nullptr;
+        lctx.inp_out_ids       = nullptr;
+        lctx.inp_KQ_mask       = nullptr;
+        lctx.inp_KQ_mask_swa   = nullptr;
+        lctx.inp_K_shift       = nullptr;
+        lctx.inp_mean          = nullptr;
+        lctx.inp_cls           = nullptr;
+        lctx.inp_s_copy        = nullptr;
+        lctx.inp_s_mask        = nullptr;
+        lctx.inp_s_seq         = nullptr;
         lctx.inp_pos_bucket    = nullptr;
         lctx.inp_embd_enc      = nullptr;
         lctx.inp_KQ_mask_cross = nullptr;
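
All of the inp_* tensors above are reset to nullptr at the start of every graph build and recreated on demand. For the K-shift input that pattern is visible in the next hunk: a fresh I32 tensor holding one position delta per cache cell is allocated and flagged with ggml_set_input(), so its data is filled in on the host just before the graph is computed. A minimal sketch of that allocation (same names as the surrounding code; the -1 layer index passed to cb is an assumption):

    lctx.inp_K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx); // one delta per KV cell
    cb(lctx.inp_K_shift, "K_shift", -1);  // name it for debug callbacks
    ggml_set_input(lctx.inp_K_shift);     // data is written later, right before compute
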
@@ -1174,9 +1174,11 @@ struct llm_build_context {
         ggml_set_input(lctx.inp_K_shift);
 
         for (int il = 0; il < n_layer; ++il) {
-            const int64_t n_head_kv = hparams.n_head_kv(il);
+            const int64_t n_head_kv    = hparams.n_head_kv(il);
             const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+
             struct ggml_tensor * rope_factors = build_rope_factors(il);
+
             struct ggml_tensor * k =
                 ggml_view_3d(ctx0, kv_self.k_l[il],
                     n_embd_head_k, n_head_kv, n_ctx,
@@ -1189,6 +1191,7 @@ struct llm_build_context {
                 // dequantize to f32 -> RoPE -> quantize back
                 tmp = ggml_cast(ctx0, k, GGML_TYPE_F32);
                 cb(tmp, "K_f32", il);
+
                 for (auto & backend : lctx.backends) {
                     // Figure out which backend KV cache belongs to
                     if (ggml_backend_supports_buft(backend.get(), ggml_backend_buffer_get_type(kv_self.k_l[il]->buffer))) {
@@ -1200,6 +1203,7 @@ struct llm_build_context {
                         lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(tmp, "K_shifted_f32", il);
+
                 tmp = ggml_cpy(ctx0, tmp, k);
             } else {
                 // we rotate only the first n_rot dimensions
@@ -1208,6 +1212,7 @@ struct llm_build_context {
                         ext_factor, attn_factor, beta_fast, beta_slow);
             }
             cb(tmp, "K_shifted", il);
+
             ggml_build_forward_expand(gf, tmp);
         }
 
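
The loop above is the whole K-shift: rather than re-encoding every cached key after cache cells move, it applies one extra RoPE pass whose rotation angles come from the per-cell deltas in inp_K_shift (routing through F32 and back via ggml_cast/ggml_cpy when the cache is quantized, since RoPE cannot run in place on quantized data). This is valid because RoPE is a pure rotation: encoding at position p and then rotating by delta is identical to encoding at p + delta. A tiny self-contained illustration of that additivity (plain C++ with a single hypothetical frequency theta; not llama.cpp API):

    #include <cassert>
    #include <cmath>

    // Rotate one 2-D pair by pos*theta, as RoPE does for each frequency pair.
    static void rope_pair(float & x0, float & x1, int pos, float theta) {
        const float c = std::cos(pos * theta);
        const float s = std::sin(pos * theta);
        const float r0 = x0 * c - x1 * s;
        const float r1 = x0 * s + x1 * c;
        x0 = r0;
        x1 = r1;
    }

    int main() {
        const float theta = 0.01f;
        float a0 = 0.3f, a1 = -1.2f;  // key encoded directly at position 7
        float b0 = a0,   b1 = a1;     // same key encoded at 5, then shifted
        rope_pair(a0, a1, 7, theta);
        rope_pair(b0, b1, 5, theta);
        rope_pair(b0, b1, 2, theta);  // the K-shift step: delta = 2
        assert(std::fabs(a0 - b0) < 1e-5f && std::fabs(a1 - b1) < 1e-5f);
        return 0;
    }
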
@@ -9201,7 +9206,7 @@ static void llama_kv_self_update_impl(llama_context & lctx) {
 
         ggml_backend_sched_alloc_graph(lctx.sched.get(), gf);
 
-        llama_set_k_shift(lctx);
+        lctx.set_k_shift(kv);
 
         llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
 
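
The final hunk is part of the same refactor: the free function llama_set_k_shift(lctx) becomes a llama_context method that takes the KV cache explicitly. Assuming the method keeps the old function's body (a sketch under that assumption, not the actual implementation), it copies each cell's accumulated position delta into inp_K_shift before the graph runs:

    // Sketch assuming the method mirrors the old llama_set_k_shift()
    // free function, with the cache now passed in explicitly.
    void llama_context::set_k_shift(llama_kv_cache & kv) {
        assert(ggml_backend_buffer_is_host(inp_K_shift->buffer)); // written directly on host
        int32_t * data = (int32_t *) inp_K_shift->data;
        for (uint32_t i = 0; i < kv.size; ++i) {
            data[i] = kv.cells[i].delta; // how far this cell's position has moved
        }
    }
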