llama : simplify moe reshapes

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-04-18 13:53:01 +03:00
parent 4d8fe0764b
commit 2080a97c5b
No known key found for this signature in database
GPG key ID: BF970631944C16B7

View file

@ -6124,14 +6124,16 @@ static struct ggml_tensor * llm_build_moe_ffn(
ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
cb(weights, "ffn_moe_weights", il);
weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
if (norm_w) {
weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights); // [1, n_tokens]
cb(weights_sum, "ffn_moe_weights_sum", il);
weights = ggml_div(ctx, weights, weights_sum); // [n_expert_used, n_tokens]
cb(weights, "ffn_moe_weights_norm", il);
weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens);
}
cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens);
@ -6162,8 +6164,7 @@ static struct ggml_tensor * llm_build_moe_ffn(
ggml_tensor * experts = ggml_mul_mat_id(ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
cb(experts, "ffn_moe_down", il);
experts = ggml_mul(ctx, experts,
ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens));
experts = ggml_mul(ctx, experts, weights);
// aggregate experts
ggml_tensor * moe_out = nullptr;