diff --git a/llama.cpp b/llama.cpp index 3264181ec..489b12028 100644 --- a/llama.cpp +++ b/llama.cpp @@ -6124,14 +6124,16 @@ static struct ggml_tensor * llm_build_moe_ffn( ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] cb(weights, "ffn_moe_weights", il); - weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens); - if (norm_w) { + weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens); + ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights); // [1, n_tokens] cb(weights_sum, "ffn_moe_weights_sum", il); weights = ggml_div(ctx, weights, weights_sum); // [n_expert_used, n_tokens] cb(weights, "ffn_moe_weights_norm", il); + + weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens); } cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens); @@ -6162,8 +6164,7 @@ static struct ggml_tensor * llm_build_moe_ffn( ggml_tensor * experts = ggml_mul_mat_id(ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens] cb(experts, "ffn_moe_down", il); - experts = ggml_mul(ctx, experts, - ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens)); + experts = ggml_mul(ctx, experts, weights); // aggregate experts ggml_tensor * moe_out = nullptr;