From acc877f4ab345cef5f27c02fef1e0a2dca7f82aa Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sat, 13 Jul 2024 18:40:43 +0300
Subject: [PATCH] llama : fix Gemma-2 Query scaling factor

ggml-ci
---
 src/llama.cpp | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/llama.cpp b/src/llama.cpp
index 59b76a6d8..0b84fcb46 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -11679,7 +11679,12 @@ struct llm_build_context {
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Qcur, "Qcur", il);
 
-                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head)));
+                // ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
+                switch (model.type) {
+                    case e_model::MODEL_9B:  Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));   break;
+                    case e_model::MODEL_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head))); break;
+                    default: GGML_ASSERT(false);
+                };
                 cb(Qcur, "Qcur_scaled", il);
 
                 Kcur = ggml_rope_ext(
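
Context for the change: before this patch, Qcur was always scaled by 1.0f/sqrtf(n_embd/n_head). The referenced gemma_pytorch commit distinguishes the two Gemma-2 variants: 9B scales queries by 1/sqrt(head_dim), while 27B keeps 1/sqrt(hidden_size/num_heads). The standalone C++ sketch below just evaluates both candidate factors for each model; the shapes used (9B: n_embd 3584, 16 heads, head_dim 256; 27B: n_embd 4608, 32 heads, head_dim 128) are taken from the published Gemma-2 configs and are an assumption of this illustration, not part of the diff.

    // sketch.cpp - evaluates the two candidate Gemma-2 Query scaling factors.
    // Model shapes below are assumptions taken from the public Gemma-2 configs.
    #include <cmath>
    #include <cstdio>

    int main() {
        struct Cfg { const char * name; int n_embd; int n_head; int n_embd_head_k; };
        const Cfg cfgs[] = {
            { "9B",  3584, 16, 256 },  // head_dim (256) != n_embd/n_head (224)
            { "27B", 4608, 32, 128 },  // reference uses n_embd/n_head (144), not head_dim (128)
        };
        for (const Cfg & c : cfgs) {
            const float scale_hidden = 1.0f / std::sqrt(float(c.n_embd / c.n_head)); // pre-patch behavior
            const float scale_head   = 1.0f / std::sqrt(float(c.n_embd_head_k));     // 9B reference behavior
            std::printf("%-3s  n_embd/n_head = %3d  head_dim = %3d  ->  1/sqrt: %.6f vs %.6f\n",
                        c.name, c.n_embd / c.n_head, c.n_embd_head_k, scale_hidden, scale_head);
        }
        return 0;
    }

For 9B the two factors genuinely differ (1/sqrt(224) ~ 0.0668 vs 1/sqrt(256) = 0.0625), and for 27B they differ the other way (1/sqrt(144) vs 1/sqrt(128)), which is why no single expression can serve both models and the patch switches on model.type.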