refactor: cleanup comments a bit

This commit is contained in:
Meng Zhang 2023-09-16 00:05:32 +08:00
parent caa722095a
commit 57eaa39c16
2 changed files with 5 additions and 8 deletions

View file

@@ -209,6 +209,7 @@ for part_name in part_names:
data = data.squeeze().numpy() data = data.squeeze().numpy()
# TODO: implement MQA directly, instead of duplicate into MHA.
if name.endswith(".attn.c_attn.weight") or name.endswith(".attn.c_attn.bias"): if name.endswith(".attn.c_attn.weight") or name.endswith(".attn.c_attn.bias"):
print("Duplicate K,V heads to use MHA instead of MQA for", name) print("Duplicate K,V heads to use MHA instead of MQA for", name)

View file

@@ -3620,19 +3620,16 @@ static struct ggml_cgraph * llm_build_starcoder(
// Projection // Projection
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wo, cur), model.layers[il].bo); cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wo, cur), model.layers[il].bo);
// add the input // Add the input
cur = ggml_add(ctx0, cur, inpL); cur = ggml_add(ctx0, cur, inpL);
struct ggml_tensor * inpFF = cur; struct ggml_tensor * inpFF = cur;
// FF // FF
{ {
// norm // Norm
{ {
cur = ggml_norm(ctx0, inpFF, norm_eps); cur = ggml_norm(ctx0, inpFF, norm_eps);
// cur = ln_2_g*cur + ln_2_b
// [ 768, N]
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ffn_norm), model.layers[il].ffn_norm_b); cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ffn_norm), model.layers[il].ffn_norm_b);
} }
@@ -3641,14 +3638,14 @@ static struct ggml_cgraph * llm_build_starcoder(
// GELU activation // GELU activation
cur = ggml_gelu(ctx0, cur); cur = ggml_gelu(ctx0, cur);
// projection // Projection
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w2, cur), model.layers[il].b2); cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w2, cur), model.layers[il].b2);
} }
inpL = ggml_add(ctx0, cur, inpFF); inpL = ggml_add(ctx0, cur, inpFF);
} }
// norm // Output Norm
{ {
cur = ggml_norm(ctx0, inpL, norm_eps); cur = ggml_norm(ctx0, inpL, norm_eps);
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.output_norm), model.output_norm_b); cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.output_norm), model.output_norm_b);
@@ -3661,7 +3658,6 @@ static struct ggml_cgraph * llm_build_starcoder(
ggml_build_forward_expand(gf, cur); ggml_build_forward_expand(gf, cur);
ggml_free(ctx0); ggml_free(ctx0);
// norm
return gf; return gf;
} }