From 9858fd14574f4f9df6fde19d608f3c6a940a71e6 Mon Sep 17 00:00:00 2001 From: joshcarp Date: Mon, 29 Apr 2024 23:22:02 -0400 Subject: [PATCH] Fix SwiGlu2 --- llama.cpp | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/llama.cpp b/llama.cpp index 93618bf85..89fe328e7 100644 --- a/llama.cpp +++ b/llama.cpp @@ -6242,6 +6242,7 @@ using llm_build_cb = std::functionne[0]/2, cur->ne[1], cur->nb[1], 0); + int offset = (cur->ne[0]/2) * (cur->ne[1]); + struct ggml_tensor * two = ggml_view_2d(ctx, cur, cur->ne[0]/2, cur->ne[1], cur->nb[1], offset); + cur = ggml_mul(ctx, ggml_silu(ctx, one), two); + cb(cur, "ffn_silu", il); + + } break; case LLM_FFN_GELU: { cur = ggml_gelu(ctx, cur); @@ -10734,7 +10744,8 @@ struct llm_build_context { const int64_t n_head = n_head_kv+ num_query_heads[il]; const int64_t n_kv = (num_kv_heads[il]+num_kv_heads[il])*n_embd_head; modified_hparams.n_head = n_head; - modified_hparams.n_head_kv = n_head_kv; // TODO, testing out setting this to total nmber of heads + modified_hparams.n_head = 4*n_head_k; // somehow this works. Some places expect this to be groups*n_head_kv instead of n_head. maybe this is the definition somewhere. 
+ modified_hparams.n_head_kv = n_head_kv; const int64_t n_embd_gqa = n_embd_head * n_head; const int64_t n_embd_k_gqa = modified_hparams.n_embd_k_gqa(); const int64_t n_embd_v_gqa = modified_hparams.n_embd_v_gqa(); @@ -10840,7 +10851,6 @@ struct llm_build_context { Kcur2->op = GGML_OP_REPEAT; Kcur2->grad = ggml_dup_tensor(ctx0, Vcur); Kcur2 = ggml_reshape_2d(ctx0, Vcur2, modified_hparams.n_embd_k_gqa(), n_tokens); - cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, model, modified_hparams, kv_self, gf, model.layers[il].wo, NULL, @@ -10866,12 +10876,14 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); + // Need to figure this out now + cur = llm_build_ffn(ctx0, cur, model.layers[il].ffn_up, NULL, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, - LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + LLM_FFN_SILU2, LLM_FFN_SEQ, cb, il); cb(cur, "ffn_out", il); }