Fix llm_build_kqv to be more generic wrt n_embd_head_k

Nam Nguyen 2023-12-30 14:26:46 -08:00
parent 51e251a83c
commit f56cbce6b7


@@ -4281,7 +4281,6 @@ static struct ggml_tensor * llm_build_kqv(
                     float     kq_scale,
          const llm_build_cb & cb,
                     int       il) {
-    const int64_t n_embd        = hparams.n_embd;
     const int64_t n_head        = hparams.n_head;
     const int64_t n_head_kv     = hparams.n_head_kv;
     const int64_t n_embd_head_k = hparams.n_embd_head_k;
@@ -4346,7 +4345,7 @@
     struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
     cb(kqv_merged, "kqv_merged", il);
-    struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd, n_tokens);
+    struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens);
     cb(cur, "kqv_merged_cont", il);
     cur = ggml_mul_mat(ctx, wo, cur);
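
For reference, a minimal sketch of the dimension reasoning behind the change (not part of the commit; the hparam values below are hypothetical): kqv_merged is laid out as [n_embd_head_k, n_head, n_tokens], so the contiguous 2-D view has n_embd_head_k*n_head rows per token, which equals n_embd only when the per-head size happens to be n_embd/n_head.

// Sketch only, with hypothetical hparam values; not taken from llama.cpp.
#include <cassert>
#include <cstdint>

int main() {
    const int64_t n_embd        = 2048;
    const int64_t n_head        = 16;
    const int64_t n_embd_head_k = 256;  // per-head dim, here != n_embd/n_head (128)

    // After ggml_permute(ctx, kqv, 0, 2, 1, 3) the tensor is
    // [n_embd_head_k, n_head, n_tokens], so ggml_cont_2d must collapse the
    // first two dimensions into n_embd_head_k*n_head rows per token.
    const int64_t kqv_rows = n_embd_head_k * n_head;  // 4096

    // Passing n_embd (the removed local) is only correct when the two coincide.
    assert(kqv_rows != n_embd);  // holds for this hypothetical model
    return 0;
}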