From cc4ff992f561da777714805d5cdb732889072e7f Mon Sep 17 00:00:00 2001 From: chiranko <96988916+chiranko@users.noreply.github.com> Date: Fri, 19 Jan 2024 11:02:17 +0800 Subject: [PATCH] llama.cpp: fix codeshell with NeoX rope Co-authored-by: Georgi Gerganov --- llama.cpp | 23 ++++------------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/llama.cpp b/llama.cpp index bc88b9591..d94e21dd6 100644 --- a/llama.cpp +++ b/llama.cpp @@ -6075,31 +6075,16 @@ struct llm_build_context { cb(tmpk, "tmpk", il); cb(Vcur, "Vcur", il); - // butterfly transform for q, k(an evil trick to correct the tensor order) - struct ggml_tensor * tmpq_transformed = ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens); - tmpq_transformed = ggml_permute(ctx0, ggml_reshape_4d(ctx0, tmpq_transformed, n_embd_head / 2, 2, n_head, n_tokens), 1, 0, 2, 3); - tmpq_transformed = ggml_reshape_3d(ctx0, ggml_cont(ctx0, tmpq_transformed), n_embd_head, n_head, n_tokens); - cb(tmpq_transformed, "tmpq_transformed", il); - - struct ggml_tensor * tmpk_transformed = ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens); - tmpk_transformed = ggml_permute(ctx0, ggml_reshape_4d(ctx0, tmpk_transformed, n_embd_head / 2, 2, n_head_kv, n_tokens), 1, 0, 2, 3); - tmpk_transformed = ggml_reshape_3d(ctx0, ggml_cont(ctx0, tmpk_transformed), n_embd_head, n_head_kv, n_tokens); - cb(tmpk_transformed, "tmpk_transformed", il); - - ggml_build_forward_expand(gf, tmpq_transformed); - ggml_build_forward_expand(gf, tmpk_transformed); - ggml_build_forward_expand(gf, Vcur); - struct ggml_tensor * Qcur = ggml_rope_custom( - ctx0, tmpq_transformed, inp_pos, - hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale, + ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), inp_pos, + hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); struct ggml_tensor * Kcur = ggml_rope_custom( - ctx0, tmpk_transformed, inp_pos, - hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale, + ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos, + hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il);