From 4d3f17b4ac2f6e51910c5d17c7a71f001dc663a8 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 28 Jun 2024 15:42:19 -0400 Subject: [PATCH 1/8] Add attention and final logit softcapping. --- convert-hf-to-gguf.py | 8 ++++++++ gguf-py/gguf/constants.py | 2 ++ gguf-py/gguf/gguf_writer.py | 3 +++ src/llama.cpp | 27 +++++++++++++++++++++++++-- 4 files changed, 38 insertions(+), 2 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 5bcc849db..0c7f945ba 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2363,6 +2363,14 @@ class Gemma2Model(Model): self.gguf_writer.add_key_length(hparams["head_dim"]) self.gguf_writer.add_value_length(hparams["head_dim"]) self.gguf_writer.add_file_type(self.ftype) + self.gguf_writer.add_float32( + gguf.Keys.LLM.ATTN_LOGIT_SOFTCAPPING.format(arch=self.model_arch), + self.hparams["attn_logit_softcapping"] + ) + self.gguf_writer.add_float32( + gguf.Keys.LLM.FINAL_LOGIT_SOFTCAPPING.format(arch=self.model_arch), + self.hparams["final_logit_softcapping"] + ) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: del bid # unusem diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index cf3d09e70..9bfa891d5 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -50,6 +50,8 @@ class Keys: POOLING_TYPE = "{arch}.pooling_type" LOGIT_SCALE = "{arch}.logit_scale" DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id" + ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping" + FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping" class Attention: HEAD_COUNT = "{arch}.attention.head_count" diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 9869f6fe3..0f1153d36 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -627,6 +627,9 @@ class GGUFWriter: def add_mask_token_id(self, id: int) -> None: self.add_uint32(Keys.Tokenizer.MASK_ID, id) + def add_eot_token_id(self, id: int) -> None: + self.add_uint32(Keys.Tokenizer.EOT_ID, id) + def add_add_bos_token(self, value: bool) -> None: self.add_bool(Keys.Tokenizer.ADD_BOS, value) diff --git a/src/llama.cpp b/src/llama.cpp index 3edaa98e8..39d80d33e 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -302,6 +302,8 @@ enum llm_kv { LLM_KV_POOLING_TYPE, LLM_KV_LOGIT_SCALE, LLM_KV_DECODER_START_TOKEN_ID, + LLM_KV_ATTN_LOGIT_SOFTCAPPING, + LLM_KV_FINAL_LOGIT_SOFTCAPPING, LLM_KV_ATTENTION_HEAD_COUNT, LLM_KV_ATTENTION_HEAD_COUNT_KV, @@ -392,6 +394,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_POOLING_TYPE , "%s.pooling_type" }, { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" }, + { LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" }, + { LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" }, { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, @@ -2099,6 +2103,9 @@ struct llama_hparams { float f_norm_eps; float f_norm_rms_eps; + float f_attn_logit_softcapping; + float f_final_logit_softcapping; + float rope_attn_factor = 1.0f; float rope_freq_base_train; float rope_freq_scale_train; @@ -2115,8 +2122,9 @@ struct llama_hparams { float f_max_alibi_bias = 0.0f; float f_logit_scale = 0.0f; - bool causal_attn = true; - bool use_alibi = false; + bool causal_attn = true; + bool use_alibi = false; + bool attn_soft_cap = false; enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE; enum llama_rope_type 
rope_type = LLAMA_ROPE_TYPE_NONE; @@ -4702,6 +4710,9 @@ static void llm_load_hparams( case LLM_ARCH_GEMMA2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping); + ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping); + hparams.attn_soft_cap = true; switch (hparams.n_layer) { case 42: model.type = e_model::MODEL_9B; break; @@ -7579,6 +7590,12 @@ static struct ggml_tensor * llm_build_kqv( kq = ggml_scale(ctx, kq, 30); } + if (hparams.attn_soft_cap) { + kq = ggml_scale(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping); + kq = ggml_tanh(ctx, kq); + kq = ggml_scale(ctx, kq, hparams.f_attn_logit_softcapping); + } + kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias); cb(kq, "kq_soft_max_ext", il); @@ -11106,6 +11123,12 @@ struct llm_build_context { // lm_head cur = ggml_mul_mat(ctx0, model.output, cur); + + // final logit soft-capping + cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); + cur = ggml_tanh(ctx0, cur); + cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping); + cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); From d3d3c4eb35075e826e1f8c5b5a0b11042aa7ca6e Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 28 Jun 2024 15:46:45 -0400 Subject: [PATCH 2/8] fix --- gguf-py/gguf/gguf_writer.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 0f1153d36..9869f6fe3 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -627,9 +627,6 @@ class GGUFWriter: def add_mask_token_id(self, id: int) -> None: self.add_uint32(Keys.Tokenizer.MASK_ID, id) - def add_eot_token_id(self, id: int) -> None: - self.add_uint32(Keys.Tokenizer.EOT_ID, id) - def add_add_bos_token(self, value: bool) -> None: self.add_bool(Keys.Tokenizer.ADD_BOS, value) From d1137c20f12f5b546b3ff8ebb1bba99570083b26 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 28 Jun 2024 15:58:02 -0400 Subject: [PATCH 3/8] Add custom add_ functions --- convert-hf-to-gguf.py | 6 ++---- gguf-py/gguf/gguf_writer.py | 6 ++++++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 0c7f945ba..3ef2f69e7 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2363,12 +2363,10 @@ class Gemma2Model(Model): self.gguf_writer.add_key_length(hparams["head_dim"]) self.gguf_writer.add_value_length(hparams["head_dim"]) self.gguf_writer.add_file_type(self.ftype) - self.gguf_writer.add_float32( - gguf.Keys.LLM.ATTN_LOGIT_SOFTCAPPING.format(arch=self.model_arch), + self.gguf_writer.add_attn_logit_softcapping( self.hparams["attn_logit_softcapping"] ) - self.gguf_writer.add_float32( - gguf.Keys.LLM.FINAL_LOGIT_SOFTCAPPING.format(arch=self.model_arch), + self.gguf_writer.add_final_logit_softcapping( self.hparams["final_logit_softcapping"] ) diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 9869f6fe3..1aeb0d9b0 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -516,6 +516,12 @@ class GGUFWriter: def add_logit_scale(self, value: float) -> None: self.add_float32(Keys.LLM.LOGIT_SCALE.format(arch=self.arch), value) + def add_attn_logit_softcapping(self, value: float) -> None: + self.add_float32(Keys.LLM.ATTN_LOGIT_SOFTCAPPING.format(arch=self.arch), value) + + def add_final_logit_softcapping(self, value: float) -> None: + 
self.add_float32(Keys.LLM.FINAL_LOGIT_SOFTCAPPING.format(arch=self.arch), value) + def add_expert_count(self, count: int) -> None: self.add_uint32(Keys.LLM.EXPERT_COUNT.format(arch=self.arch), count) From f4424c150f1181cb8b2f2cb6a700c821c789facc Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 28 Jun 2024 16:00:20 -0400 Subject: [PATCH 4/8] Disable flash attention for Gemma2 --- src/llama.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/llama.cpp b/src/llama.cpp index 39d80d33e..d2ee2d6ad 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -17402,6 +17402,11 @@ struct llama_context * llama_new_context_with_model( params.flash_attn = false; } + if (params.flash_attn && model->arch == LLM_ARCH_GEMMA2) { + LLAMA_LOG_WARN("%s: flash_attn is not compatible with Gemma2 - forcing off\n", __func__); + params.flash_attn = false; + } + if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) { LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__); params.flash_attn = false; From 3a2471811f7a91511b6907dd5050e2ba1de475e0 Mon Sep 17 00:00:00 2001 From: Andrei Date: Fri, 28 Jun 2024 16:07:47 -0400 Subject: [PATCH 5/8] Update src/llama.cpp Co-authored-by: slaren --- src/llama.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index d2ee2d6ad..cb8ce3b8c 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -17402,11 +17402,12 @@ struct llama_context * llama_new_context_with_model( params.flash_attn = false; } - if (params.flash_attn && model->arch == LLM_ARCH_GEMMA2) { - LLAMA_LOG_WARN("%s: flash_attn is not compatible with Gemma2 - forcing off\n", __func__); + if (params.flash_attn && model->hparams.attn_soft_cap) { + LLAMA_LOG_WARN("%s: flash_attn is not compatible with attn_soft_cap - forcing off\n", __func__); params.flash_attn = false; } + if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) { LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__); params.flash_attn = false; From bb7159927dfcd94fd550dc673d03ce02e392aed4 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Sat, 29 Jun 2024 01:10:55 -0400 Subject: [PATCH 6/8] Add default value for attention and final logit softcap value --- src/llama.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index d2ee2d6ad..4f1447d79 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2103,8 +2103,8 @@ struct llama_hparams { float f_norm_eps; float f_norm_rms_eps; - float f_attn_logit_softcapping; - float f_final_logit_softcapping; + float f_attn_logit_softcapping = 50.0f; + float f_final_logit_softcapping = 30.0f; float rope_attn_factor = 1.0f; float rope_freq_base_train; @@ -4710,8 +4710,8 @@ static void llm_load_hparams( case LLM_ARCH_GEMMA2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping); - ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping); + ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false); + ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); hparams.attn_soft_cap = true; switch (hparams.n_layer) { From a89427908d04fcf3b4e975724596efddce4db737 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Sat, 29 Jun 2024 10:17:33 -0400 Subject: [PATCH 7/8] Add custom kq scaling from 
Gemma2Attention --- convert-hf-to-gguf.py | 3 +++ gguf-py/gguf/constants.py | 1 + gguf-py/gguf/gguf_writer.py | 3 +++ src/llama.cpp | 6 +++++- 4 files changed, 12 insertions(+), 1 deletion(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 3ef2f69e7..23a357343 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2369,6 +2369,9 @@ class Gemma2Model(Model): self.gguf_writer.add_final_logit_softcapping( self.hparams["final_logit_softcapping"] ) + self.gguf_writer.add_query_pre_attn_scalar( + self.hparams["query_pre_attn_scalar"] + ) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: del bid # unusem diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 9bfa891d5..eab5cbf69 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -52,6 +52,7 @@ class Keys: DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id" ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping" FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping" + QUERY_PRE_ATTN_SCALAR = "{arch}.query_pre_attn_scalar" class Attention: HEAD_COUNT = "{arch}.attention.head_count" diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 1aeb0d9b0..37c41a5bf 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -522,6 +522,9 @@ class GGUFWriter: def add_final_logit_softcapping(self, value: float) -> None: self.add_float32(Keys.LLM.FINAL_LOGIT_SOFTCAPPING.format(arch=self.arch), value) + def add_query_pre_attn_scalar(self, value: float) -> None: + self.add_float32(Keys.LLM.QUERY_PRE_ATTN_SCALAR.format(arch=self.arch), value) + def add_expert_count(self, count: int) -> None: self.add_uint32(Keys.LLM.EXPERT_COUNT.format(arch=self.arch), count) diff --git a/src/llama.cpp b/src/llama.cpp index 9654ffad3..56a6898c3 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -304,6 +304,7 @@ enum llm_kv { LLM_KV_DECODER_START_TOKEN_ID, LLM_KV_ATTN_LOGIT_SOFTCAPPING, LLM_KV_FINAL_LOGIT_SOFTCAPPING, + LLM_KV_QUERY_PRE_ATTN_SCALAR, LLM_KV_ATTENTION_HEAD_COUNT, LLM_KV_ATTENTION_HEAD_COUNT_KV, @@ -396,6 +397,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" }, { LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" }, { LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" }, + { LLM_KV_QUERY_PRE_ATTN_SCALAR, "%s.query_pre_attn_scalar" }, { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, @@ -2105,6 +2107,7 @@ struct llama_hparams { float f_attn_logit_softcapping = 50.0f; float f_final_logit_softcapping = 30.0f; + float f_query_pre_attn_scalar = 144.0f; float rope_attn_factor = 1.0f; float rope_freq_base_train; @@ -4712,6 +4715,7 @@ static void llm_load_hparams( ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false); ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); + ml.get_key(LLM_KV_QUERY_PRE_ATTN_SCALAR, hparams.f_query_pre_attn_scalar, false); hparams.attn_soft_cap = true; switch (hparams.n_layer) { @@ -10948,7 +10952,7 @@ struct llm_build_context { ext_factor, attn_factor, beta_fast, beta_slow); cb(Qcur, "Qcur", il); - Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k))); + Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(hparams.f_query_pre_attn_scalar)); cb(Qcur, "Qcur_scaled", il); Kcur = ggml_rope_ext( 
From 51f0bd50a1c276adfd85e6c7f7f3ce670683d540 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Sat, 29 Jun 2024 23:02:50 -0400 Subject: [PATCH 8/8] Remove custom pre attention scaling and use computed value instead. --- convert-hf-to-gguf.py | 3 --- gguf-py/gguf/constants.py | 1 - gguf-py/gguf/gguf_writer.py | 3 --- src/llama.cpp | 8 ++------ 4 files changed, 2 insertions(+), 13 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 23a357343..3ef2f69e7 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2369,9 +2369,6 @@ class Gemma2Model(Model): self.gguf_writer.add_final_logit_softcapping( self.hparams["final_logit_softcapping"] ) - self.gguf_writer.add_query_pre_attn_scalar( - self.hparams["query_pre_attn_scalar"] - ) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: del bid # unusem diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index eab5cbf69..9bfa891d5 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -52,7 +52,6 @@ class Keys: DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id" ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping" FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping" - QUERY_PRE_ATTN_SCALAR = "{arch}.query_pre_attn_scalar" class Attention: HEAD_COUNT = "{arch}.attention.head_count" diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 37c41a5bf..1aeb0d9b0 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -522,9 +522,6 @@ class GGUFWriter: def add_final_logit_softcapping(self, value: float) -> None: self.add_float32(Keys.LLM.FINAL_LOGIT_SOFTCAPPING.format(arch=self.arch), value) - def add_query_pre_attn_scalar(self, value: float) -> None: - self.add_float32(Keys.LLM.QUERY_PRE_ATTN_SCALAR.format(arch=self.arch), value) - def add_expert_count(self, count: int) -> None: self.add_uint32(Keys.LLM.EXPERT_COUNT.format(arch=self.arch), count) diff --git a/src/llama.cpp b/src/llama.cpp index 56a6898c3..2a4d73856 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -304,7 +304,6 @@ enum llm_kv { LLM_KV_DECODER_START_TOKEN_ID, LLM_KV_ATTN_LOGIT_SOFTCAPPING, LLM_KV_FINAL_LOGIT_SOFTCAPPING, - LLM_KV_QUERY_PRE_ATTN_SCALAR, LLM_KV_ATTENTION_HEAD_COUNT, LLM_KV_ATTENTION_HEAD_COUNT_KV, @@ -397,7 +396,6 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" }, { LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" }, { LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" }, - { LLM_KV_QUERY_PRE_ATTN_SCALAR, "%s.query_pre_attn_scalar" }, { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, @@ -2107,7 +2105,6 @@ struct llama_hparams { float f_attn_logit_softcapping = 50.0f; float f_final_logit_softcapping = 30.0f; - float f_query_pre_attn_scalar = 144.0f; float rope_attn_factor = 1.0f; float rope_freq_base_train; @@ -4715,7 +4712,6 @@ static void llm_load_hparams( ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false); ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); - ml.get_key(LLM_KV_QUERY_PRE_ATTN_SCALAR, hparams.f_query_pre_attn_scalar, false); hparams.attn_soft_cap = true; switch (hparams.n_layer) { @@ -10952,7 +10948,7 @@ struct llm_build_context { ext_factor, attn_factor, beta_fast, beta_slow); cb(Qcur, "Qcur", 
il); - Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(hparams.f_query_pre_attn_scalar)); + Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k))); cb(Qcur, "Qcur_scaled", il); Kcur = ggml_rope_ext( @@ -11060,7 +11056,7 @@ struct llm_build_context { ext_factor, attn_factor, beta_fast, beta_slow); cb(Qcur, "Qcur", il); - Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k))); + Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head))); cb(Qcur, "Qcur_scaled", il); Kcur = ggml_rope_ext(
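
A note on the soft-capping math used throughout this series: both the attention cap (PATCH 1/8) and the final logit cap apply the same transform, x -> cap * tanh(x / cap), which keeps values inside (-cap, +cap) while staying close to linear for small x. The Python sketch below is illustrative only and is not part of the patches; the softcap helper name is made up here, and the caps of 50.0 and 30.0 are the defaults introduced in PATCH 6/8.

    import math

    def softcap(x: float, cap: float) -> float:
        # Mirrors the ggml_scale -> ggml_tanh -> ggml_scale sequence in llama.cpp:
        # scale down by 1/cap, squash with tanh, scale back up by cap.
        # The result is bounded to (-cap, +cap).
        return cap * math.tanh(x / cap)

    # With the Gemma2 defaults from PATCH 6/8:
    print(softcap(123.0, 50.0))  # ~49.3, attention-score cap of 50
    print(softcap(123.0, 30.0))  # ~30.0, final-logit cap of 30
    print(softcap(5.0, 50.0))    # ~5.0, small values pass through almost unchanged

Because this tanh step has to run between the KQ matrix product and the softmax, the fused flash-attention path cannot apply it as-is, which is presumably why PATCH 4/8 and 5/8 force flash_attn off whenever attn_soft_cap is set.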