Fix SwiGlu2

joshcarp 2024-04-29 23:22:02 -04:00
parent 0084a2a8d7
commit 9858fd1457


@@ -6242,6 +6242,7 @@ using llm_build_cb = std::function<void(struct ggml_tensor * cur, const char * n
 enum llm_ffn_op_type {
     LLM_FFN_SILU,
+    LLM_FFN_SILU2,
     LLM_FFN_GELU,
     LLM_FFN_RELU,
     LLM_FFN_RELU_SQR,
@@ -6406,6 +6407,15 @@ static struct ggml_tensor * llm_build_ffn(
                 cur = ggml_silu(ctx, cur);
                 cb(cur, "ffn_silu", il);
             } break;
+        case LLM_FFN_SILU2:
+            {
+                struct ggml_tensor * one = ggml_view_2d(ctx, cur, cur->ne[0]/2, cur->ne[1], cur->nb[1], 0);
+                int offset = (cur->ne[0]/2) * (cur->ne[1]);
+                struct ggml_tensor * two = ggml_view_2d(ctx, cur, cur->ne[0]/2, cur->ne[1], cur->nb[1], offset);
+                cur = ggml_mul(ctx, ggml_silu(ctx, one), two);
+                cb(cur, "ffn_silu", il);
+            } break;
         case LLM_FFN_GELU:
             {
                 cur = ggml_gelu(ctx, cur);
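The new LLM_FFN_SILU2 case implements a fused SwiGLU: gate and up projections are packed into a single ffn_up matrix, the activation is split in half along the feature dimension, and the result is silu(first_half) * second_half. Below is a minimal numeric sketch of that combine step (plain C++, illustrative names only); note that ggml_view_2d takes its final offset argument in bytes, so a row-wise split would normally use (cur->ne[0]/2) * ggml_element_size(cur) for the second view.

#include <cmath>
#include <cstddef>

// Sketch of the fused SwiGLU combine that LLM_FFN_SILU2 performs on one row,
// assuming the row is laid out as [gate | up]; 'in' holds n values, 'out' n/2.
static void swiglu_row(float * out, const float * in, std::size_t n) {
    const std::size_t half = n / 2;
    for (std::size_t i = 0; i < half; ++i) {
        const float g    = in[i];                     // gate half
        const float u    = in[half + i];              // up half
        const float silu = g / (1.0f + std::exp(-g)); // SiLU(x) = x * sigmoid(x)
        out[i] = silu * u;
    }
}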
@@ -10734,7 +10744,8 @@ struct llm_build_context {
             const int64_t n_head = n_head_kv + num_query_heads[il];
             const int64_t n_kv = (num_kv_heads[il]+num_kv_heads[il])*n_embd_head;
             modified_hparams.n_head = n_head;
-            modified_hparams.n_head_kv = n_head_kv; // TODO, testing out setting this to total number of heads
+            modified_hparams.n_head = 4*n_head_k; // somehow this works. Some places expect this to be groups*n_head_kv instead of n_head; maybe that is the definition somewhere.
+            modified_hparams.n_head_kv = n_head_kv;
             const int64_t n_embd_gqa = n_embd_head * n_head;
             const int64_t n_embd_k_gqa = modified_hparams.n_embd_k_gqa();
             const int64_t n_embd_v_gqa = modified_hparams.n_embd_v_gqa();
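The extra assignment to modified_hparams.n_head just before these lines matters because the GQA widths are derived from the per-head size and the KV head count. A hedged sketch of the relationship the surrounding code relies on (field and method names mirror llama.cpp's hparams, but the real definitions may differ by version):

#include <cstdint>

// Not the actual llama.cpp struct; just the arithmetic this hunk assumes.
struct hparams_sketch {
    int64_t n_embd_head_k; // width of one key head
    int64_t n_embd_head_v; // width of one value head
    int64_t n_head;        // query heads
    int64_t n_head_kv;     // key/value heads (n_head / n_groups under GQA)

    int64_t n_embd_k_gqa() const { return n_embd_head_k * n_head_kv; }
    int64_t n_embd_v_gqa() const { return n_embd_head_v * n_head_kv; }
};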
@@ -10840,7 +10851,6 @@ struct llm_build_context {
             Kcur2->op = GGML_OP_REPEAT;
             Kcur2->grad = ggml_dup_tensor(ctx0, Vcur);
             Kcur2 = ggml_reshape_2d(ctx0, Vcur2, modified_hparams.n_embd_k_gqa(), n_tokens);
             cb(Kcur, "Kcur", il);
             cur = llm_build_kv(ctx0, model, modified_hparams, kv_self, gf,
                     model.layers[il].wo, NULL,
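The context above tags Kcur2 with GGML_OP_REPEAT and a hand-built grad tensor to broadcast the KV heads. For comparison, the usual ggml route for that broadcast is ggml_repeat against a template tensor of the target shape; a hedged sketch, with illustrative names that are not part of the commit:

#include "ggml.h"

// Broadcast k from n_head_kv heads to n_head heads the idiomatic ggml way:
// repeat against a tensor that only supplies the target shape, then flatten.
static struct ggml_tensor * repeat_kv_heads_sketch(
        struct ggml_context * ctx,
        struct ggml_tensor  * k,   // [n_embd_head, n_head_kv, n_tokens]
        int64_t n_embd_head, int64_t n_head, int64_t n_tokens) {
    struct ggml_tensor * target = ggml_new_tensor_3d(ctx, GGML_TYPE_F32,
            n_embd_head, n_head, n_tokens);                    // target shape only
    struct ggml_tensor * k_rep = ggml_repeat(ctx, k, target);  // tile the KV heads
    return ggml_reshape_2d(ctx, k_rep, n_embd_head*n_head, n_tokens);
}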
@@ -10866,12 +10876,14 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
             cb(cur, "ffn_norm", il);
+            // Need to figure this out now
             cur = llm_build_ffn(ctx0, cur,
                     model.layers[il].ffn_up,   NULL,
                     NULL,                      NULL,
                     model.layers[il].ffn_down, NULL,
                     NULL,
-                    LLM_FFN_SILU,  LLM_FFN_PAR, cb, il);
+                    LLM_FFN_SILU2, LLM_FFN_SEQ, cb, il);
             cb(cur, "ffn_out", il);
         }
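With LLM_FFN_SEQ there is no separate gate matmul (ffn_gate is passed as NULL); the single ffn_up projection is assumed to hold gate and up packed together, and LLM_FFN_SILU2 performs the split. A hedged sketch of the resulting data flow, assuming ffn_up has ne = {n_embd, 2*n_ff} with the gate half first; ggml_cont is used because some backends expect contiguous inputs to unary ops, and view offsets are given in bytes:

#include "ggml.h"

// Sketch only, not a drop-in replacement for llm_build_ffn.
static struct ggml_tensor * build_packed_swiglu_ffn(
        struct ggml_context * ctx,
        struct ggml_tensor  * x,        // [n_embd, n_tokens]
        struct ggml_tensor  * ffn_up,   // [n_embd, 2*n_ff], gate and up packed
        struct ggml_tensor  * ffn_down  // [n_ff,  n_embd]
) {
    struct ggml_tensor * t = ggml_mul_mat(ctx, ffn_up, x);               // [2*n_ff, n_tokens]
    const int64_t n_ff = t->ne[0]/2;
    struct ggml_tensor * gate = ggml_cont(ctx,
            ggml_view_2d(ctx, t, n_ff, t->ne[1], t->nb[1], 0));          // first half of each row
    struct ggml_tensor * up = ggml_cont(ctx,
            ggml_view_2d(ctx, t, n_ff, t->ne[1], t->nb[1], n_ff*ggml_element_size(t))); // second half
    struct ggml_tensor * h = ggml_mul(ctx, ggml_silu(ctx, gate), up);    // SwiGLU
    return ggml_mul_mat(ctx, ffn_down, h);                               // [n_embd, n_tokens]
}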