From 5ed3e1a8f27212bc09c174d11ce5c9c96cc48ec7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 24 Nov 2023 18:58:03 +0200 Subject: [PATCH] llama : fix llm_build_k_shift args --- llama.cpp | 24 ++++-------------------- 1 file changed, 4 insertions(+), 20 deletions(-) diff --git a/llama.cpp b/llama.cpp index 370ca6b7a..85c0ee0c1 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3459,8 +3459,7 @@ static void llm_build_k_shift( struct ggml_cgraph * graph, llm_rope_type type, int64_t n_ctx, - int64_t n_rot, - int n_dims, + int n_rot, float freq_base, float freq_scale, const llm_build_cb & cb) { @@ -3492,32 +3491,17 @@ static void llm_build_k_shift( // we rotate only the first n_rot dimensions ggml_rope_custom_inplace(ctx, ggml_view_3d(ctx, kv.k, - n_rot, n_head_kv, n_ctx, + n_embd_head, n_head_kv, n_ctx, ggml_element_size(kv.k)*n_embd_head, ggml_element_size(kv.k)*n_embd_gqa, ggml_element_size(kv.k)*n_embd_gqa*n_ctx*il), - K_shift, n_dims, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); cb(tmp, "K_shifted", il); ggml_build_forward_expand(graph, tmp); } } -static void llm_build_k_shift( - struct ggml_context * ctx, - const llama_hparams & hparams, - const llama_cparams & cparams, - const llama_kv_cache & kv, - struct ggml_cgraph * graph, - llm_rope_type type, - int64_t n_ctx, - int64_t n_rot, - float freq_base, - float freq_scale, - const llm_build_cb & cb) { - llm_build_k_shift(ctx, hparams, cparams, kv, graph, type, n_ctx, n_rot, n_rot, freq_base, freq_scale, cb); -} - static void llm_build_kv_store( struct ggml_context * ctx, const llama_hparams & hparams, @@ -4814,7 +4798,7 @@ struct llm_build_context { // shift the entire K-cache if needed if (do_rope_shift) { - llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, hparams.n_rot, freq_base, freq_scale, cb); + llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, hparams.n_rot, freq_base, freq_scale, cb); } for (int il = 0; il < n_layer; ++il) {