llama : reuse build_moe_ffn()

This commit is contained in:
Georgi Gerganov 2024-04-16 13:42:19 +03:00
parent f88e6844a4
commit de3555119e
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -6587,7 +6587,7 @@ struct llm_build_context {
LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);
cur = build_moe_ffn(cur, n_tokens, LLM_FFN_SILU, il);
cur = build_moe_ffn(cur, n_tokens, LLM_FFN_SILU, true, il);
}
cur = ggml_add(ctx0, cur, ffn_inp);
@ -6620,7 +6620,7 @@ struct llm_build_context {
}
// REVIEW: will be replaced by https://github.com/ggerganov/llama.cpp/pull/6505
ggml_tensor * build_moe_ffn(ggml_tensor * cur, int32_t n_tokens, llm_ffn_op_type type_op, int il) {
ggml_tensor * build_moe_ffn(ggml_tensor * cur, int32_t n_tokens, llm_ffn_op_type type_op, bool norm_w, int il) {
ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
cb(logits, "ffn_moe_logits", il);
@ -6637,11 +6637,13 @@ struct llm_build_context {
weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok]
ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
cb(weights_sum, "ffn_moe_weights_sum", il);
if (norm_w) {
ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
cb(weights_sum, "ffn_moe_weights_sum", il);
weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok]
cb(weights, "ffn_moe_weights_norm", il);
weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok]
cb(weights, "ffn_moe_weights_norm", il);
}
// compute expert outputs
ggml_tensor * moe_out = nullptr;
@ -7138,7 +7140,7 @@ struct llm_build_context {
LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);
cur = build_moe_ffn(cur, n_tokens, LLM_FFN_GELU, il);
cur = build_moe_ffn(cur, n_tokens, LLM_FFN_GELU, true, il);
// Grok
// if layer_out_norm is present then apply it before adding the input
@ -7274,7 +7276,7 @@ struct llm_build_context {
LLM_NORM, cb, il);
cb(cur, "attn_out_norm", il);
cur = build_moe_ffn(cur, n_tokens, LLM_FFN_SILU, il);
cur = build_moe_ffn(cur, n_tokens, LLM_FFN_SILU, true, il);
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "ffn_out", il);
@ -8573,54 +8575,7 @@ struct llm_build_context {
LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);
ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
cb(logits, "ffn_moe_logits", il);
ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]
cb(probs, "ffn_moe_probs", il);
// select experts
ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok]
cb(selected_experts->src[0], "ffn_moe_argsort", il);
ggml_tensor * weights = ggml_get_rows(ctx0,
ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts);
cb(weights, "ffn_moe_weights", il);
weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok]
// compute expert outputs
ggml_tensor * moe_out = nullptr;
for (int i = 0; i < n_expert_used; ++i) {
ggml_tensor * cur_expert;
ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, selected_experts, i, cur);
cb(cur_up, "ffn_moe_up", il);
ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, selected_experts, i, cur);
cb(cur_gate, "ffn_moe_gate", il);
cur_gate = ggml_silu(ctx0, cur_gate);
cb(cur_gate, "ffn_moe_silu", il);
cur_expert = ggml_mul(ctx0, cur_up, cur_gate); // [n_tokens, n_embd]
cb(cur_expert, "ffn_moe_gate_par", il);
cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, selected_experts, i, cur_expert); // [n_tokens, n_embd]
cb(cur_expert, "ffn_moe_down", il);
cur_expert = ggml_mul(ctx0, cur_expert,
ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
cb(cur_expert, "ffn_moe_weighted", il);
if (i == 0) {
moe_out = cur_expert;
} else {
moe_out = ggml_add(ctx0, moe_out, cur_expert);
cb(moe_out, "ffn_moe_out", il);
}
}
ggml_tensor * moe_out = build_moe_ffn(cur, n_tokens, LLM_FFN_SILU, false, il);
// FFN shared expert
{