From a371a8b611eb69476208f400ae8ba4271f80d10d Mon Sep 17 00:00:00 2001 From: Galunid Date: Fri, 10 Nov 2023 05:54:46 +0100 Subject: [PATCH] Use ggml_view_3d --- llama.cpp | 102 ++++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 87 insertions(+), 15 deletions(-) diff --git a/llama.cpp b/llama.cpp index 100c7e6fc..35d79c284 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4323,7 +4323,7 @@ struct llm_build_context { struct ggml_tensor * Kcur = ggml_concat(ctx0, krotated, kpass); cb(Kcur, "Kcur", il); - struct ggml_tensor * Q = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 1, 2, 0, 3)); + struct ggml_tensor * Q = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 2, 1, 0, 3)); cb(Q, "Q", il); Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 2, 1, 0, 3)); @@ -4710,34 +4710,106 @@ struct llm_build_context { // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); + struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(tmpq, "tmpq", il); - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); + struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + cb(tmpk, "tmpk", il); struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, - hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow + // RoPE the first n_rot of q/k, pass the other half, and concat. + struct ggml_tensor * qrot = ggml_view_3d( + ctx0, tmpq, hparams.n_rot, n_head, n_tokens, + ggml_element_size(tmpq) * n_embd_head, + ggml_element_size(tmpq) * n_embd_head * n_head, + 0 + ); + cb(qrot, "qrot", il); + + struct ggml_tensor * krot = ggml_view_3d( + ctx0, tmpk, hparams.n_rot, n_head, n_tokens, + ggml_element_size(tmpk) * n_embd_head, + ggml_element_size(tmpk) * n_embd_head * n_head_kv, + 0 + ); + cb(krot, "krot", il); + + // get the second half of tmpq, e.g tmpq[n_rot:, :, :] + struct ggml_tensor * qpass = ggml_view_3d( + ctx0, tmpq, (n_embd_head - hparams.n_rot), n_head, n_tokens, + ggml_element_size(tmpq) * n_embd_head, + ggml_element_size(tmpq) * n_embd_head * n_head, + ggml_element_size(tmpq) * hparams.n_rot + ); + cb(qpass, "qpass", il); + + struct ggml_tensor * kpass = ggml_view_3d( + ctx0, tmpk, (n_embd_head - hparams.n_rot), n_head_kv, n_tokens, + ggml_element_size(tmpk) * (n_embd_head), + ggml_element_size(tmpk) * (n_embd_head) * n_head_kv, + ggml_element_size(tmpk) * hparams.n_rot + ); + cb(kpass, "kpass", il); + + struct ggml_tensor * qrotated = ggml_rope_custom( + ctx0, qrot, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx, + freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); + cb(qrotated, "qrotated", il); + + struct ggml_tensor * krotated = ggml_rope_custom( + ctx0, krot, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx, + freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(krotated, "krotated", il); + + // ggml currently only supports concatenation on dim=2 + // so we need to permute qrot, qpass, concat, then permute back. + qrotated = ggml_cont(ctx0, ggml_permute(ctx0, qrotated, 2, 1, 0, 3)); + cb(qrotated, "qrotated", il); + + krotated = ggml_cont(ctx0, ggml_permute(ctx0, krotated, 2, 1, 0, 3)); + cb(krotated, "krotated", il); + + qpass = ggml_cont(ctx0, ggml_permute(ctx0, qpass, 2, 1, 0, 3)); + cb(qpass, "qpass", il); + + kpass = ggml_cont(ctx0, ggml_permute(ctx0, kpass, 2, 1, 0, 3)); + cb(kpass, "kpass", il); + + struct ggml_tensor * Qcur = ggml_concat(ctx0, qrotated, qpass); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, - hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); + struct ggml_tensor * Kcur = ggml_concat(ctx0, krotated, kpass); cb(Kcur, "Kcur", il); + struct ggml_tensor * Q = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 2, 1, 0, 3)); + cb(Q, "Q", il); + + Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 2, 1, 0, 3)); + cb(Kcur, "Kcur", il); + + // Qcur = ggml_rope_custom( + // ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + // hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale, + // ext_factor, attn_factor, beta_fast, beta_slow + // ); + // cb(Qcur, "Qcur", il); + + // Kcur = ggml_rope_custom( + // ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + // hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale, + // ext_factor, attn_factor, beta_fast, beta_slow + // ); + // cb(Kcur, "Kcur", il); + llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); cur = llm_build_kqv(ctx0, hparams, kv_self, model.layers[il].wo, NULL, - Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il); + Q, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il); cb(cur, "kqv_out", il); }