From 8a887decd35083c1542534cfadc3a5ee592da964 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?=
Date: Tue, 28 Jan 2025 19:26:54 +0100
Subject: [PATCH] llama : prompt processing optimizations in DeepSeek V2

---
 src/llama.cpp | 38 ++++++++++++++++++++++++++------------
 1 file changed, 26 insertions(+), 12 deletions(-)

diff --git a/src/llama.cpp b/src/llama.cpp
index a4c78240b..5768f9215 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -6403,6 +6403,10 @@ struct llm_build_context {
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
+        // whether to use n_tokens as the matrix dimension during multiplication or n_head
+        // n_tokens is higher during prompt processing, this allows to optimize for this case
+        bool pp_opt = n_tokens > n_head;
+
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
 
@@ -6535,14 +6539,18 @@ struct llm_build_context {
                 struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope_perm);
                 cb(q_nope2, "q_nope2", il);
 
-                struct ggml_tensor * q_nope2_perm = ggml_permute(ctx0, q_nope2, 0, 2, 1, 3);
-                cb(q_nope2_perm, "q_nope2_perm", il);
+                if (!pp_opt) {
+                    q_nope2 = ggml_permute(ctx0, q_nope2, 0, 2, 1, 3);
+                    cb(q_nope2, "q_nope2_perm", il);
+                }
 
-                struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2_perm);
+                struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2);
                 cb(kq_nope, "kq_nope", il);
 
-                struct ggml_tensor * q_pe_perm = ggml_permute(ctx0, q_pe, 0, 3, 2, 1);
-                cb(q_pe_perm, "q_pe_perm", il);
+                if (pp_opt) {
+                    q_pe = ggml_permute(ctx0, q_pe, 0, 2, 1, 3);
+                    cb(q_pe, "q_pe_perm", il);
+                }
 
                 struct ggml_tensor * kq_pe = ggml_mul_mat(ctx0, kr_cache, q_pe);
                 cb(kq_pe, "kq_pe", il);
@@ -6550,20 +6558,26 @@ struct llm_build_context {
                 struct ggml_tensor * kq = ggml_add(ctx0, kq_nope, kq_pe);
                 cb(kq, "kq", il);
 
-                kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
-                cb(kq, "kq_perm", il);
+                if (!pp_opt) {
+                    kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
+                    cb(kq, "kq_perm", il);
+                }
 
                 kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
                 cb(kq, "kq_soft_max_ext", il);
 
-                struct ggml_tensor * kq_perm = ggml_permute(ctx0, kq, 0, 2, 1, 3);
-                cb(kq_perm, "kq_soft_max_ext_perm", il);
+                if (!pp_opt) {
+                    kq = ggml_permute(ctx0, kq, 0, 2, 1, 3);
+                    cb(kq, "kq_soft_max_ext_perm", il);
+                }
 
-                struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq_perm);
+                struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
                 cb(kqv_compressed, "kqv_compressed", il);
 
-                kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
-                cb(kqv_compressed, "kqv_compressed_perm", il);
+                if (!pp_opt) {
+                    kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
+                    cb(kqv_compressed, "kqv_compressed_perm", il);
+                }
 
                 struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank), ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank * n_embd_head_v), 0);
                 cb(wv_b, "wv_b", il);
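
Note (not part of the patch): the standalone sketch below illustrates the dimension choice that pp_opt encodes, as I read it from the diff. ggml_mul_mat() performs one GEMM per slice along the higher tensor dimensions, so the patch keeps whichever of n_tokens and n_head is larger as the GEMM column dimension and only permutes in the small-batch (decode) case, avoiding the extra permute/cont passes during prompt processing. The head count, LoRA rank, and per-slice shapes used here are illustrative assumptions for the kq_nope product, not values taken from the patch.

#include <cstdio>

// Hypothetical dimensions, for illustration only.
struct attn_dims {
    int n_head;       // number of attention heads
    int kv_lora_rank; // width of the compressed KV latent
    int n_kv;         // number of cached KV positions
    int n_tokens;     // tokens in the current ubatch
};

// My reading of how the kq_nope product is batched:
//   pp_opt == true  (prompt processing): slices over n_head,
//     each GEMM has n_tokens columns
//   pp_opt == false (single-token decode): slices over n_tokens,
//     each GEMM has n_head columns
static void report(const attn_dims & d) {
    const bool pp_opt  = d.n_tokens > d.n_head;
    const int n_slices = pp_opt ? d.n_head   : d.n_tokens;
    const int n_cols   = pp_opt ? d.n_tokens : d.n_head;

    printf("n_tokens=%4d n_head=%3d -> pp_opt=%d: %4d GEMMs of (%d x %d)*(%d x %d)\n",
           d.n_tokens, d.n_head, pp_opt ? 1 : 0, n_slices,
           d.n_kv, d.kv_lora_rank, d.kv_lora_rank, n_cols);
}

int main() {
    report({/*n_head=*/16, /*kv_lora_rank=*/512, /*n_kv=*/4096, /*n_tokens=*/1});   // decode
    report({/*n_head=*/16, /*kv_lora_rank=*/512, /*n_kv=*/4096, /*n_tokens=*/512}); // prefill
    return 0;
}

With the assumed numbers, decode ends up with one wide-but-thin GEMM per token (16 columns each), while prefill runs 16 GEMMs with 512 columns each; in the latter case the un-permuted layout already has the large dimension where the GEMM wants it, which is why the permutes are skipped when n_tokens > n_head.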