Fix SwiGlu2
This commit is contained in:
parent
0084a2a8d7
commit
9858fd1457
1 changed files with 15 additions and 3 deletions
18
llama.cpp
18
llama.cpp
|
@ -6242,6 +6242,7 @@ using llm_build_cb = std::function<void(struct ggml_tensor * cur, const char * n
|
|||
|
||||
enum llm_ffn_op_type {
|
||||
LLM_FFN_SILU,
|
||||
LLM_FFN_SILU2,
|
||||
LLM_FFN_GELU,
|
||||
LLM_FFN_RELU,
|
||||
LLM_FFN_RELU_SQR,
|
||||
|
@ -6406,6 +6407,15 @@ static struct ggml_tensor * llm_build_ffn(
|
|||
cur = ggml_silu(ctx, cur);
|
||||
cb(cur, "ffn_silu", il);
|
||||
} break;
|
||||
case LLM_FFN_SILU2:
|
||||
{
|
||||
struct ggml_tensor * one = ggml_view_2d(ctx, cur, cur->ne[0]/2, cur->ne[1], cur->nb[1], 0);
|
||||
int offset = (cur->ne[0]/2) * (cur->ne[1]);
|
||||
struct ggml_tensor * two = ggml_view_2d(ctx, cur, cur->ne[0]/2, cur->ne[1], cur->nb[1], offset);
|
||||
cur = ggml_mul(ctx, ggml_silu(ctx, one), two);
|
||||
cb(cur, "ffn_silu", il);
|
||||
|
||||
} break;
|
||||
case LLM_FFN_GELU:
|
||||
{
|
||||
cur = ggml_gelu(ctx, cur);
|
||||
|
@ -10734,7 +10744,8 @@ struct llm_build_context {
|
|||
const int64_t n_head = n_head_kv+ num_query_heads[il];
|
||||
const int64_t n_kv = (num_kv_heads[il]+num_kv_heads[il])*n_embd_head;
|
||||
modified_hparams.n_head = n_head;
|
||||
modified_hparams.n_head_kv = n_head_kv; // TODO, testing out setting this to total nmber of heads
|
||||
modified_hparams.n_head = 4*n_head_k; // somehow this works. Some places expect this to be groups*n_head_kv insteal of n_head. maybe this is the defintiion somewhere.
|
||||
modified_hparams.n_head_kv = n_head_kv;
|
||||
const int64_t n_embd_gqa = n_embd_head * n_head;
|
||||
const int64_t n_embd_k_gqa = modified_hparams.n_embd_k_gqa();
|
||||
const int64_t n_embd_v_gqa = modified_hparams.n_embd_v_gqa();
|
||||
|
@ -10840,7 +10851,6 @@ struct llm_build_context {
|
|||
Kcur2->op = GGML_OP_REPEAT;
|
||||
Kcur2->grad = ggml_dup_tensor(ctx0, Vcur);
|
||||
Kcur2 = ggml_reshape_2d(ctx0, Vcur2, modified_hparams.n_embd_k_gqa(), n_tokens);
|
||||
|
||||
cb(Kcur, "Kcur", il);
|
||||
cur = llm_build_kv(ctx0, model, modified_hparams, kv_self, gf,
|
||||
model.layers[il].wo, NULL,
|
||||
|
@ -10866,12 +10876,14 @@ struct llm_build_context {
|
|||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
// Need to figure this out now
|
||||
|
||||
cur = llm_build_ffn(ctx0, cur,
|
||||
model.layers[il].ffn_up, NULL,
|
||||
NULL, NULL,
|
||||
model.layers[il].ffn_down, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
||||
LLM_FFN_SILU2, LLM_FFN_SEQ, cb, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue