diff --git a/llama.cpp b/llama.cpp index 78d71b117..cb769ec66 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4278,12 +4278,12 @@ static struct ggml_tensor * llm_build_kqv( float kq_scale, const llm_build_cb & cb, int il) { - const int64_t n_embd = hparams.n_embd; const int64_t n_head = hparams.n_head; const int64_t n_head_kv = hparams.n_head_kv; const int64_t n_key_dim = hparams.n_key_dim; - const int64_t n_key_gqa = n_key_dim * hparams.n_head_kv; + const int64_t n_key_gqa = hparams.n_key_gqa(); const int64_t n_value_dim = hparams.n_value_dim; + const int64_t n_value_gqa = hparams.n_value_gqa(); struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3); cb(q, "q", il); @@ -4343,7 +4343,7 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); cb(kqv_merged, "kqv_merged", il); - struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd, n_tokens); + struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_value_gqa, n_tokens); cb(cur, "kqv_merged_cont", il); cur = ggml_mul_mat(ctx, wo, cur);