llama : fix Gemma-2 Query scaling factor

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-07-13 18:40:43 +03:00
parent df78f19612
commit acc877f4ab
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -11679,7 +11679,12 @@ struct llm_build_context {
ext_factor, attn_factor, beta_fast, beta_slow); ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur", il); cb(Qcur, "Qcur", il);
Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head))); // ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
switch (model.type) {
case e_model::MODEL_9B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k))); break;
case e_model::MODEL_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head))); break;
default: GGML_ASSERT(false);
};
cb(Qcur, "Qcur_scaled", il); cb(Qcur, "Qcur_scaled", il);
Kcur = ggml_rope_ext( Kcur = ggml_rope_ext(