Skip computation of unused logits during batch prompt eval (drop other batch positions after writing their kv to cache)

This commit is contained in:
ochafik 2023-08-18 01:46:20 +01:00 committed by Olivier Chafik
parent cf658adc83
commit 2cf4f62e12

View file

@ -2117,7 +2117,8 @@ static struct ggml_cgraph * llm_build_llama(
GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT
const int N = n_tokens; // Non-const to allow short-circuiting to the last token in the last layer in prompt eval mode.
int N = n_tokens;
const auto & model = lctx.model; const auto & model = lctx.model;
const auto & hparams = model.hparams; const auto & hparams = model.hparams;
@ -2245,18 +2246,10 @@ static struct ggml_cgraph * llm_build_llama(
offload_func_kq(tmpk); offload_func_kq(tmpk);
ggml_set_name(tmpk, "tmpk"); ggml_set_name(tmpk, "tmpk");
struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
offload_func_kq(tmpq);
ggml_set_name(tmpq, "tmpq");
struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
offload_func_kq(Kcur); offload_func_kq(Kcur);
ggml_set_name(Kcur, "Kcur"); ggml_set_name(Kcur, "Kcur");
struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
offload_func_kq(Qcur);
ggml_set_name(Qcur, "Qcur");
// store key and value to memory // store key and value to memory
{ {
// compute the transposed [N, n_embd] V matrix // compute the transposed [N, n_embd] V matrix
@ -2284,6 +2277,35 @@ static struct ggml_cgraph * llm_build_llama(
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
} }
if (il == n_layer - 1 && !lctx.logits_all)
{
// From here on, we only care about the last token and its logits.
// We do as if N = 1 (from the end), which means we only keep
// the last column of cur and inpSA ((n_embd, N) -> (n_embd, 1)).
//
// Note that we do this even when N==1 so that we don't change the # nodes in the graph,
// otherwise for Metal we'd have to rebuild the concurrency list.
cur = ggml_view_2d(ctx0, cur, n_embd, 1, cur->nb[1], (N - 1)*ggml_element_size(cur)*n_embd);
offload_func_nr(cur);
ggml_set_name(cur, "cur-lastpos");
inpSA = ggml_view_2d(ctx0, inpSA, n_embd, 1, inpSA->nb[1], (N - 1)*ggml_element_size(inpSA)*n_embd);
offload_func_nr(inpSA);
ggml_set_name(inpSA, "inpSA-lastpos");
n_past += N - 1;
N = 1;
}
struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
offload_func_kq(tmpq);
ggml_set_name(tmpq, "tmpq");
struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
offload_func_kq(Qcur);
ggml_set_name(Qcur, "Qcur");
struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
offload_func_kq(Q); offload_func_kq(Q);
ggml_set_name(Q, "Q"); ggml_set_name(Q, "Q");
@ -2902,11 +2924,13 @@ static bool llama_eval_internal(
if (lctx.logits_all) { if (lctx.logits_all) {
logits_out.resize(n_vocab * N); logits_out.resize(n_vocab * N);
GGML_ASSERT(ggml_nelements(res) == n_vocab * N);
memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*N); memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*N);
} else { } else {
// return result for just the last token // return result for just the last token
GGML_ASSERT(ggml_nelements(res) == n_vocab);
logits_out.resize(n_vocab); logits_out.resize(n_vocab);
memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(N-1)), sizeof(float)*n_vocab); memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab);
} }
} }