refactor moe ffn to llm_build_moe_ffn

slaren 2024-04-07 21:14:23 +02:00
parent bc615548d3
commit f3f7627bd8

llama.cpp

@@ -5849,6 +5849,93 @@ static struct ggml_tensor * llm_build_ffn(
     return cur;
 }

+static struct ggml_tensor * llm_build_moe_ffn(
+        struct ggml_context * ctx,
+        struct ggml_tensor * cur,
+        struct ggml_tensor * gate_inp,
+        struct ggml_tensor * up_exps,
+        struct ggml_tensor * gate_exps,
+        struct ggml_tensor * down_exps,
+        int64_t n_expert,
+        int64_t n_expert_used,
+        llm_ffn_op_type type_op,
+        const llm_build_cb & cb,
+        int il) {
+    int64_t n_embd = cur->ne[0];
+    int64_t n_tokens = cur->ne[1];
+
+    ggml_tensor * logits = ggml_mul_mat(ctx, gate_inp, cur); // [n_expert, n_tokens]
+    cb(logits, "ffn_moe_logits", il);
+
+    ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
+    cb(probs, "ffn_moe_probs", il);
+
+    // select experts
+    ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens]
+    cb(selected_experts->src[0], "ffn_moe_argsort", il);
+    cb(selected_experts, "ffn_moe_topk", il);
+
+    ggml_tensor * weights = ggml_get_rows(ctx,
+            ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
+    cb(weights, "ffn_moe_weights", il);
+
+    weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
+
+    ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights); // [1, n_tokens]
+    cb(weights_sum, "ffn_moe_weights_sum", il);
+
+    weights = ggml_div(ctx, weights, weights_sum); // [n_expert_used, n_tokens]
+    cb(weights, "ffn_moe_weights_norm", il);
+
+    cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens);
+    ggml_tensor * up = ggml_mul_mat_id(ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+    cb(up, "ffn_moe_up", il);
+
+    ggml_tensor * gate = ggml_mul_mat_id(ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+    cb(gate, "ffn_moe_gate", il);
+
+    switch (type_op) {
+        case LLM_FFN_SILU:
+            {
+                gate = ggml_silu(ctx, gate);
+                cb(gate, "ffn_moe_silu", il);
+            } break;
+        case LLM_FFN_GELU:
+            {
+                gate = ggml_gelu(ctx, gate);
+                cb(gate, "ffn_moe_gelu", il);
+            } break;
+        default:
+            GGML_ASSERT(false);
+    }
+
+    ggml_tensor * par = ggml_mul(ctx, up, gate); // [n_ff, n_expert_used, n_tokens]
+    cb(par, "ffn_moe_gate_par", il);
+
+    ggml_tensor * experts = ggml_mul_mat_id(ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
+    cb(experts, "ffn_moe_down", il);
+
+    experts = ggml_mul(ctx, experts,
+            ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens));
+
+    // aggregate experts
+    ggml_tensor * moe_out = nullptr;
+    for (int i = 0; i < n_expert_used; ++i) {
+        ggml_tensor * cur_expert = ggml_view_2d(ctx, experts, n_embd, n_tokens,
+                experts->nb[2], i*experts->nb[1]);
+
+        // FIXME: non-contiguous add broken in cuda
+        cur_expert = ggml_cont(ctx, cur_expert);
+
+        if (i == 0) {
+            moe_out = cur_expert;
+        } else {
+            moe_out = ggml_add(ctx, moe_out, cur_expert);
+        }
+    }
+
+    return moe_out;
+}
+
 // if max_alibi_bias > 0 then apply ALiBi
 static struct ggml_tensor * llm_build_kqv(
         struct ggml_context * ctx,
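For reference, per token the routing stage of the new helper computes a softmax over all expert logits, keeps the top n_expert_used experts, and renormalizes their weights to sum to 1. A minimal scalar sketch of that sequence follows (illustrative only, not part of the commit; route_token and all local names are hypothetical):

#include <algorithm>
#include <cmath>
#include <numeric>
#include <vector>

// Scalar equivalent of ffn_moe_probs -> ffn_moe_topk -> ffn_moe_weights_norm
// for a single token. llm_build_moe_ffn does the same thing batched over
// n_tokens inside the ggml graph (ggml_top_k is an argsort plus a view of the
// first k rows, hence the "ffn_moe_argsort" callback above).
static void route_token(const std::vector<float> & logits, int n_expert_used,
                        std::vector<int> & experts, std::vector<float> & weights) {
    // softmax over all n_expert logits (ggml_soft_max)
    const float max_logit = *std::max_element(logits.begin(), logits.end());
    std::vector<float> probs(logits.size());
    float sum = 0.0f;
    for (size_t e = 0; e < logits.size(); ++e) {
        probs[e] = std::exp(logits[e] - max_logit);
        sum += probs[e];
    }
    for (float & p : probs) {
        p /= sum;
    }

    // top-k expert indices by probability (ggml_top_k / ggml_get_rows)
    std::vector<int> idx(logits.size());
    std::iota(idx.begin(), idx.end(), 0);
    std::partial_sort(idx.begin(), idx.begin() + n_expert_used, idx.end(),
            [&probs](int a, int b) { return probs[a] > probs[b]; });
    experts.assign(idx.begin(), idx.begin() + n_expert_used);

    // renormalize the selected probabilities so they sum to 1
    // (ggml_sum_rows + ggml_div)
    float sel_sum = 0.0f;
    for (int e : experts) {
        sel_sum += probs[e];
    }
    weights.clear();
    for (int e : experts) {
        weights.push_back(probs[e] / sel_sum);
    }
}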
@@ -6392,66 +6479,14 @@ struct llm_build_context {
                         LLM_NORM_RMS, cb, il);
                 cb(cur, "ffn_norm", il);

-                ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_expert, n_tokens]
-                cb(logits, "ffn_moe_logits", il);
-
-                ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_expert, n_tokens]
-                cb(probs, "ffn_moe_probs", il);
-
-                // select experts
-                ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_expert_used, n_tokens]
-                cb(selected_experts->src[0], "ffn_moe_argsort", il);
-                cb(selected_experts, "ffn_moe_topk", il);
-
-                ggml_tensor * weights = ggml_get_rows(ctx0,
-                        ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
-                cb(weights, "ffn_moe_weights", il);
-
-                weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
-
-                ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
-                cb(weights_sum, "ffn_moe_weights_sum", il);
-
-                weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
-                cb(weights, "ffn_moe_weights_norm", il);
-
-                cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
-                ggml_tensor * up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
-                cb(up, "ffn_moe_up", il);
-
-                ggml_tensor * gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
-                cb(gate, "ffn_moe_gate", il);
-
-                gate = ggml_silu(ctx0, gate);
-                cb(gate, "ffn_moe_silu", il);
-
-                ggml_tensor * par = ggml_mul(ctx0, up, gate); // [n_ff, n_expert_used, n_tokens]
-                cb(par, "ffn_moe_gate_par", il);
-
-                ggml_tensor * experts = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
-                cb(experts, "ffn_moe_down", il);
-
-                experts = ggml_mul(ctx0, experts,
-                        ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens));
-
-                // aggregate experts
-                ggml_tensor * moe_out = nullptr;
-                for (int i = 0; i < n_expert_used; ++i) {
-                    ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens,
-                            experts->nb[2], i*experts->nb[1]);
-
-                    // FIXME: non-contiguous add broken in cuda
-                    cur_expert = ggml_cont(ctx0, cur_expert);
-
-                    if (i == 0) {
-                        moe_out = cur_expert;
-                    } else {
-                        moe_out = ggml_add(ctx0, moe_out, cur_expert);
-                        cb(moe_out, "ffn_moe_out", il);
-                    }
-                }
-
-                cur = moe_out;
+                cur = llm_build_moe_ffn(ctx0, cur,
+                        model.layers[il].ffn_gate_inp,
+                        model.layers[il].ffn_up_exps,
+                        model.layers[il].ffn_gate_exps,
+                        model.layers[il].ffn_down_exps,
+                        n_expert, n_expert_used,
+                        LLM_FFN_SILU, cb, il);
+                cb(cur, "ffn_moe_out", il);
             }

             cur = ggml_add(ctx0, cur, ffn_inp);
@@ -6930,64 +6965,14 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
             cb(cur, "ffn_norm", il);

-            ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
-            cb(logits, "ffn_moe_logits", il);
-
-            ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]
-            cb(probs, "ffn_moe_probs", il);
-
-            // select experts
-            ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok]
-            cb(selected_experts->src[0], "ffn_moe_argsort", il);
-
-            ggml_tensor * weights = ggml_get_rows(ctx0,
-                    ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts);
-            cb(weights, "ffn_moe_weights", il);
-
-            weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok]
-
-            ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
-            cb(weights_sum, "ffn_moe_weights_sum", il);
-
-            weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok]
-            cb(weights, "ffn_moe_weights_norm", il);
-
-            // compute expert outputs
-            ggml_tensor * moe_out = nullptr;
-
-            for (int i = 0; i < n_expert_used; ++i) {
-                ggml_tensor * cur_expert;
-
-                // FIXME
-                ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, selected_experts, cur);
-                cb(cur_up, "ffn_moe_up", il);
-
-                ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, selected_experts, cur);
-                cb(cur_gate, "ffn_moe_gate", il);
-
-                //GeLU
-                cur_gate = ggml_gelu(ctx0, cur_gate);
-                cb(cur_gate, "ffn_moe_gelu", il);
-
-                cur_expert = ggml_mul(ctx0, cur_up, cur_gate);
-                cb(cur_expert, "ffn_moe_gate_par", il);
-
-                cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, selected_experts, cur_expert); // [n_tokens, n_embd]
-                cb(cur_expert, "ffn_moe_down", il);
-
-                cur_expert = ggml_mul(ctx0, cur_expert,
-                        ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
-                cb(cur_expert, "ffn_moe_weighted", il);
-
-                if (i == 0) {
-                    moe_out = cur_expert;
-                } else {
-                    moe_out = ggml_add(ctx0, moe_out, cur_expert);
-                    cb(moe_out, "ffn_moe_out", il);
-                }
-            }
-
-            cur = moe_out;
+            cur = llm_build_moe_ffn(ctx0, cur,
+                    model.layers[il].ffn_gate_inp,
+                    model.layers[il].ffn_up_exps,
+                    model.layers[il].ffn_gate_exps,
+                    model.layers[il].ffn_down_exps,
+                    n_expert, n_expert_used,
+                    LLM_FFN_GELU, cb, il);
+            cb(cur, "ffn_moe_out", il);

             // Grok
             // if layer_out_norm is present then apply it before adding the input
@@ -6999,7 +6984,6 @@ struct llm_build_context {
cb(cur, "layer_out_norm", il); cb(cur, "layer_out_norm", il);
} }
cur = ggml_add(ctx0, cur, ffn_inp); cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "ffn_out", il); cb(cur, "ffn_out", il);