From ff0899c9b377fb71c7cd3281cdfb89f8dd7a033e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 12 Jan 2024 10:46:39 +0200 Subject: [PATCH] llama : fix llm_build_k_shift to use correct n_rot ggml-ci --- common/common.cpp | 3 +++ llama.cpp | 20 +++++++++----------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index b2cb0e257..3aefed01d 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1055,6 +1055,9 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & } static ggml_type kv_cache_type_from_str(const std::string & s) { + if (s == "f32") { + return GGML_TYPE_F32; + } if (s == "f16") { return GGML_TYPE_F16; } diff --git a/llama.cpp b/llama.cpp index d39ff94c7..65830e54a 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4104,7 +4104,6 @@ static void llm_build_k_shift( struct ggml_cgraph * graph, llm_rope_type type, int64_t n_ctx, - int n_rot, float freq_base, float freq_scale, const llm_build_cb & cb) { @@ -4112,14 +4111,13 @@ static void llm_build_k_shift( const int64_t n_head_kv = hparams.n_head_kv; const int64_t n_embd_head_k = hparams.n_embd_head_k; const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const int32_t n_rot = hparams.n_rot; const int32_t n_orig_ctx = cparams.n_yarn_orig_ctx; const float ext_factor = cparams.yarn_ext_factor; const float attn_factor = cparams.yarn_attn_factor; const float beta_fast = cparams.yarn_beta_fast; const float beta_slow = cparams.yarn_beta_slow; - GGML_ASSERT(n_embd_head_k % n_rot == 0); - struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_ctx); cb(K_shift, "K_shift", -1); @@ -4523,7 +4521,7 @@ struct llm_build_context { // shift the entire K-cache if needed if (do_rope_shift) { - llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb); + llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, freq_base, freq_scale, cb); } for (int il = 0; il < n_layer; ++il) { @@ -4708,7 +4706,7 @@ struct llm_build_context { // shift the entire K-cache if needed if (do_rope_shift) { - llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb); + llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, freq_base, freq_scale, cb); } for (int il = 0; il < n_layer; ++il) { @@ -4829,7 +4827,7 @@ struct llm_build_context { // shift the entire K-cache if needed if (do_rope_shift) { - llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb); + llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb); } for (int il = 0; il < n_layer; ++il) { @@ -5052,7 +5050,7 @@ struct llm_build_context { cb(KQ_mask, "KQ_mask", -1); if (do_rope_shift) { - llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb); + llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb); } for (int il = 0; il < n_layer; ++il) { @@ -5548,7 +5546,7 @@ struct llm_build_context { // shift the entire K-cache if needed if (do_rope_shift) { - llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, hparams.n_rot, freq_base, freq_scale, cb); + llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb); } for (int il = 0; il < n_layer; ++il) { @@ -5661,7 +5659,7 @@ struct llm_build_context { // shift the entire K-cache if needed if (do_rope_shift) { - llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb); + llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb); } for (int il = 0; il < n_layer; ++il) { @@ -5778,7 +5776,7 @@ struct llm_build_context { // shift the entire K-cache if needed if (do_rope_shift) { - llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb); + llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb); } for (int il = 0; il < n_layer; ++il) { @@ -5891,7 +5889,7 @@ struct llm_build_context { // shift the entire K-cache if needed if (do_rope_shift) { - llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb); + llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, freq_base, freq_scale, cb); } for (int il = 0; il < n_layer; ++il) {