Fix llm_build_kqv to use n_value_gqa
This commit is contained in:
parent
94d170b7e9
commit
e96fad12c5
1 changed file with 3 additions and 3 deletions
|
@@ -4278,12 +4278,12 @@ static struct ggml_tensor * llm_build_kqv(
|
|||
float kq_scale,
|
||||
const llm_build_cb & cb,
|
||||
int il) {
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
const int64_t n_head = hparams.n_head;
|
||||
const int64_t n_head_kv = hparams.n_head_kv;
|
||||
const int64_t n_key_dim = hparams.n_key_dim;
|
||||
const int64_t n_key_gqa = n_key_dim * hparams.n_head_kv;
|
||||
const int64_t n_key_gqa = hparams.n_key_gqa();
|
||||
const int64_t n_value_dim = hparams.n_value_dim;
|
||||
const int64_t n_value_gqa = hparams.n_value_gqa();
|
||||
|
||||
struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3);
|
||||
cb(q, "q", il);
|
||||
|
@@ -4343,7 +4343,7 @@ static struct ggml_tensor * llm_build_kqv(
|
|||
struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
|
||||
cb(kqv_merged, "kqv_merged", il);
|
||||
|
||||
struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd, n_tokens);
|
||||
struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_value_gqa, n_tokens);
|
||||
cb(cur, "kqv_merged_cont", il);
|
||||
|
||||
cur = ggml_mul_mat(ctx, wo, cur);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue