llama: dbrx: move norm2 after attention, fix build kv

Pierrick HYMBERT 2024-04-08 00:11:19 +02:00
parent 2897aa628c
commit 993f836029


@@ -7125,12 +7125,6 @@ struct llm_build_context {
                     LLM_NORM, cb, il);
             cb(cur, "attn_norm", il);
 
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm_2,
-                    NULL,
-                    LLM_NORM, cb, il);
-            cb(cur, "attn_norm_2", il);
-
             // self-attention
             {
                 cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
@@ -7161,7 +7155,7 @@ struct llm_build_context {
                 cb(Vcur, "Vcur", il);
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
-                        NULL, NULL,
+                        model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
             }
 
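Note on the "fix build kv" part of the commit: the old call passed NULL, NULL for the attention output projection, so llm_build_kv never applied the per-layer output weight and bias; passing model.layers[il].wo and model.layers[il].bo restores them. A minimal sketch of what the helper is assumed to do with these two arguments (internals paraphrased, not the exact llama.cpp source):

    // inside llm_build_kv, after attention has produced `cur` (assumed)
    cur = ggml_mul_mat(ctx, wo, cur);   // output projection: cur = W_o * cur
    if (wo_b) {
        cur = ggml_add(ctx, cur, wo_b); // optional output bias
    }
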
@@ -7179,9 +7173,8 @@ struct llm_build_context {
             // feed-forward network
             // MoE branch
             {
-                // FIXME REVIEW: I do not see this op in https://huggingface.co/databricks/dbrx-instruct/blob/464e701f50aef4c1b59c81fb5667819a5d08e108/modeling_dbrx.py#L727
                 cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        NULL, NULL,
+                        model.layers[il].attn_norm_2, NULL,
                         LLM_NORM, cb, il);
                 cb(cur, "ffn_norm", il);
 