remove plamo_llm_build_kqv and use llm_build_kqv

This commit is contained in:
okada 2023-12-24 18:14:12 +09:00
parent db1b18dc97
commit 26340a1902

View file

@ -5573,79 +5573,9 @@ struct llm_build_context {
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
auto plamo_llm_build_kqv = [](
struct ggml_context * ctx,
const llama_hparams & hparams,
const llama_kv_cache & kv,
struct ggml_tensor * wo,
struct ggml_tensor * q_cur,
struct ggml_tensor * kq_mask,
int64_t n_ctx,
int32_t n_tokens,
int32_t n_kv,
const llm_build_cb & cb,
int il) {
const int64_t n_embd = hparams.n_embd;
const int64_t n_head_kv = hparams.n_head_kv;
const int64_t n_embd_head = hparams.n_embd_head();
const int64_t n_embd_gqa = hparams.n_embd_gqa();
struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3);
cb(q, "q", il);
struct ggml_tensor * k =
ggml_view_3d(ctx, kv.k_l[il],
n_embd_head, n_kv, n_head_kv,
ggml_row_size(kv.k_l[il]->type, n_embd_gqa),
ggml_row_size(kv.k_l[il]->type, n_embd_head),
0);
cb(k, "k", il);
/*
// we should avoid to repeat K but current ggml_mul_mat generates wrong values for grouped query att
struct ggml_tensor * k_repeated = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, k->ne[0], k->ne[1], q->ne[2]);
cb(k_repeated, "k_repeated", il);
struct ggml_tensor * kq = ggml_mul_mat(ctx, ggml_repeat(ctx, k, k_repeated), q);
*/
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
cb(kq, "kq", il);
kq = ggml_soft_max_ext(ctx, kq, kq_mask, 1.0f/sqrtf(float(n_embd_head)));
cb(kq, "kq_soft_max_ext", il);
// split cached v into n_head heads
struct ggml_tensor * v =
ggml_view_3d(ctx, kv.v_l[il],
n_kv, n_embd_head, n_head_kv,
ggml_element_size(kv.v_l[il])*n_ctx,
ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head,
0);
cb(v, "v", il);
/*
// we should avoid to repeat V but current ggml_mul_mat generates wrong values for grouped query att
struct ggml_tensor * v_repeated = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, v->ne[0], v->ne[1], q->ne[2]);
cb(k_repeated, "v_repeated", il);
struct ggml_tensor * kqv = ggml_mul_mat(ctx, ggml_repeat(ctx, v, v_repeated), kq);
*/
struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
cb(kqv, "kqv", il);
struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
cb(kqv_merged, "kqv_merged", il);
struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd, n_tokens);
cb(cur, "kqv_merged_cont", il);
cur = ggml_mul_mat(ctx, wo, cur);
return cur;
};
cur = plamo_llm_build_kqv(ctx0, hparams, kv_self,
model.layers[il].wo,
Qcur, KQ_mask, n_ctx, n_tokens, n_kv, cb, il);
cur = llm_build_kqv(ctx0, hparams, kv_self,
model.layers[il].wo, NULL,
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 0.0f, cb, il);
cb(cur, "kqv_out", il);
}
struct ggml_tensor * sa_out = cur;