From 2cf4f62e12c36c7ba81efd8db3cb68a84e3121dd Mon Sep 17 00:00:00 2001
From: ochafik
Date: Fri, 18 Aug 2023 01:46:20 +0100
Subject: [PATCH] Skip computation of unused logits during batch prompt eval
 (drop other batch positions after writing their kv to cache)

---
 llama.cpp | 44 ++++++++++++++++++++++++++++++++++----------
 1 file changed, 34 insertions(+), 10 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index f2dc4da1d..838f47e36 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -2117,7 +2117,8 @@ static struct ggml_cgraph * llm_build_llama(
 
     GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT
 
-    const int N = n_tokens;
+    // Non-const: during batch prompt eval we shrink N to 1 at the last layer, keeping only the last token.
+    int N = n_tokens;
 
     const auto & model = lctx.model;
     const auto & hparams = model.hparams;
@@ -2245,18 +2246,10 @@ static struct ggml_cgraph * llm_build_llama(
             offload_func_kq(tmpk);
             ggml_set_name(tmpk, "tmpk");
 
-            struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
-            offload_func_kq(tmpq);
-            ggml_set_name(tmpq, "tmpq");
-
             struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
             offload_func_kq(Kcur);
             ggml_set_name(Kcur, "Kcur");
 
-            struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
-            offload_func_kq(Qcur);
-            ggml_set_name(Qcur, "Qcur");
-
             // store key and value to memory
             {
                 // compute the transposed [N, n_embd] V matrix
@@ -2284,6 +2277,35 @@ static struct ggml_cgraph * llm_build_llama(
                 ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
             }
 
+            if (il == n_layer - 1 && !lctx.logits_all)
+            {
+                // From here on, we only care about the last token and its logits.
+                // We proceed as if N = 1 (counting from the end), keeping only
+                // the last column of cur and inpSA: (n_embd, N) -> (n_embd, 1).
+                //
+                // Note that we do this even when N == 1, so that the number of nodes
+                // in the graph stays constant; otherwise for Metal we'd have to rebuild the concurrency list.
+
+                cur = ggml_view_2d(ctx0, cur, n_embd, 1, cur->nb[1], (N - 1)*ggml_element_size(cur)*n_embd);
+                offload_func_nr(cur);
+                ggml_set_name(cur, "cur-lastpos");
+
+                inpSA = ggml_view_2d(ctx0, inpSA, n_embd, 1, inpSA->nb[1], (N - 1)*ggml_element_size(inpSA)*n_embd);
+                offload_func_nr(inpSA);
+                ggml_set_name(inpSA, "inpSA-lastpos");
+
+                n_past += N - 1;
+                N = 1;
+            }
+
+            struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+            offload_func_kq(tmpq);
+            ggml_set_name(tmpq, "tmpq");
+
+            struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
+            offload_func_kq(Qcur);
+            ggml_set_name(Qcur, "Qcur");
+
             struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
             offload_func_kq(Q);
             ggml_set_name(Q, "Q");
@@ -2902,11 +2924,13 @@ static bool llama_eval_internal(
 
         if (lctx.logits_all) {
             logits_out.resize(n_vocab * N);
+            GGML_ASSERT(ggml_nelements(res) == n_vocab * N);
             memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*N);
         } else {
             // return result for just the last token
+            GGML_ASSERT(ggml_nelements(res) == n_vocab);
             logits_out.resize(n_vocab);
-            memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+            memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab);
         }
     }
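
Why slicing to the last column is enough, as a standalone sketch in plain C++ (no ggml; every name and size below is illustrative, not llama.cpp API): once the last layer's K and V for all N prompt positions have been written to the KV cache, only the last position's activations feed the remaining ops, so the graph can continue on a 1-column view and the output projection does 1/N of the work.

    // Standalone sketch of the optimization above. Names and sizes are
    // illustrative, not llama.cpp API.
    #include <cstdio>
    #include <vector>

    // y = W * x for a single column: W is (n_out x n_in) row-major, x has n_in entries.
    static std::vector<float> matvec(const std::vector<float> & W, const float * x,
                                     int n_out, int n_in) {
        std::vector<float> y(n_out, 0.0f);
        for (int r = 0; r < n_out; r++) {
            for (int c = 0; c < n_in; c++) {
                y[r] += W[(size_t) r*n_in + c]*x[c];
            }
        }
        return y;
    }

    int main() {
        const int n_embd  = 4; // toy sizes
        const int n_vocab = 8;
        const int N       = 3; // number of prompt positions in the batch

        // Activations for all N positions; position i is the contiguous slice
        // [i*n_embd, (i+1)*n_embd) -- the same layout the patch's view offset
        // (N - 1)*ggml_element_size(cur)*n_embd selects from.
        std::vector<float> cur((size_t) n_embd*N, 1.0f);

        // "View" of the last position only: an offset, not a copy, just like ggml_view_2d.
        const float * last = cur.data() + (size_t) (N - 1)*n_embd;

        // Project only that one column to logits: n_vocab*n_embd multiply-adds
        // instead of n_vocab*n_embd*N -- an N-fold saving on this matmul alone.
        std::vector<float> W((size_t) n_vocab*n_embd, 0.5f);
        std::vector<float> logits = matvec(W, last, n_vocab, n_embd);

        printf("computed %zu logits for the last of %d positions\n", logits.size(), N);
        return 0;
    }

Note also the `n_past += N - 1; N = 1;` pair in the patch: the single surviving query column still sits at absolute position n_past + N - 1 in the sequence, so bumping n_past before the Q RoPE keeps its rotary phase correct even though only one column remains.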
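
On the consumer side, the llama_eval_internal change means that with logits_all == false the result tensor now holds exactly n_vocab values, so the copy starts at offset 0 instead of n_vocab*(N-1); the two added GGML_ASSERTs pin that contract down. Callers are unaffected, since llama_get_logits() already exposed only the last position's logits in this mode. A hedged usage sketch against the llama.cpp C API of this era (llama_eval still took n_threads then; later versions replaced it with llama_decode):

    #include "llama.h"

    // Evaluate a whole prompt in one batch and greedily pick the next token.
    // With this patch, only the last position's logits are computed and copied.
    llama_token pick_greedy(llama_context * ctx, const llama_token * tokens, int n_tokens) {
        llama_eval(ctx, tokens, n_tokens, /*n_past=*/0, /*n_threads=*/4);

        const int     n_vocab = llama_n_vocab(ctx);
        const float * logits  = llama_get_logits(ctx); // n_vocab entries, last position only

        llama_token best = 0;
        for (llama_token id = 1; id < n_vocab; id++) {
            if (logits[id] > logits[best]) {
                best = id;
            }
        }
        return best; // greedy next-token choice
    }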