llama.cpp: fix codeshell with NeoX rope

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
chiranko 2024-01-19 11:02:17 +08:00 committed by GitHub
parent d70e48dedf
commit cc4ff992f5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -6075,31 +6075,16 @@ struct llm_build_context {
cb(tmpk, "tmpk", il); cb(tmpk, "tmpk", il);
cb(Vcur, "Vcur", il); cb(Vcur, "Vcur", il);
// butterfly transform for q, k (an evil trick to correct the tensor order)
struct ggml_tensor * tmpq_transformed = ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens);
tmpq_transformed = ggml_permute(ctx0, ggml_reshape_4d(ctx0, tmpq_transformed, n_embd_head / 2, 2, n_head, n_tokens), 1, 0, 2, 3);
tmpq_transformed = ggml_reshape_3d(ctx0, ggml_cont(ctx0, tmpq_transformed), n_embd_head, n_head, n_tokens);
cb(tmpq_transformed, "tmpq_transformed", il);
struct ggml_tensor * tmpk_transformed = ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens);
tmpk_transformed = ggml_permute(ctx0, ggml_reshape_4d(ctx0, tmpk_transformed, n_embd_head / 2, 2, n_head_kv, n_tokens), 1, 0, 2, 3);
tmpk_transformed = ggml_reshape_3d(ctx0, ggml_cont(ctx0, tmpk_transformed), n_embd_head, n_head_kv, n_tokens);
cb(tmpk_transformed, "tmpk_transformed", il);
ggml_build_forward_expand(gf, tmpq_transformed);
ggml_build_forward_expand(gf, tmpk_transformed);
ggml_build_forward_expand(gf, Vcur);
struct ggml_tensor * Qcur = ggml_rope_custom( struct ggml_tensor * Qcur = ggml_rope_custom(
ctx0, tmpq_transformed, inp_pos, ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), inp_pos,
hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale, hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow ext_factor, attn_factor, beta_fast, beta_slow
); );
cb(Qcur, "Qcur", il); cb(Qcur, "Qcur", il);
struct ggml_tensor * Kcur = ggml_rope_custom( struct ggml_tensor * Kcur = ggml_rope_custom(
ctx0, tmpk_transformed, inp_pos, ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos,
hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale, hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow ext_factor, attn_factor, beta_fast, beta_slow
); );
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);