Skip computation of unused logits during batch prompt eval (drop other batch positions after writing their kv to cache)
This commit is contained in:
parent
cf658adc83
commit
2cf4f62e12
1 changed file with 34 additions and 10 deletions
44
llama.cpp
44
llama.cpp
|
@ -2117,7 +2117,8 @@ static struct ggml_cgraph * llm_build_llama(
|
||||||
|
|
||||||
GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT
|
GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT
|
||||||
|
|
||||||
const int N = n_tokens;
|
// Non-const to allow short-circuiting to the last token in the last layer in prompt eval mode.
|
||||||
|
int N = n_tokens;
|
||||||
|
|
||||||
const auto & model = lctx.model;
|
const auto & model = lctx.model;
|
||||||
const auto & hparams = model.hparams;
|
const auto & hparams = model.hparams;
|
||||||
|
@ -2245,18 +2246,10 @@ static struct ggml_cgraph * llm_build_llama(
|
||||||
offload_func_kq(tmpk);
|
offload_func_kq(tmpk);
|
||||||
ggml_set_name(tmpk, "tmpk");
|
ggml_set_name(tmpk, "tmpk");
|
||||||
|
|
||||||
struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
|
|
||||||
offload_func_kq(tmpq);
|
|
||||||
ggml_set_name(tmpq, "tmpq");
|
|
||||||
|
|
||||||
struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
|
struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
|
||||||
offload_func_kq(Kcur);
|
offload_func_kq(Kcur);
|
||||||
ggml_set_name(Kcur, "Kcur");
|
ggml_set_name(Kcur, "Kcur");
|
||||||
|
|
||||||
struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
|
|
||||||
offload_func_kq(Qcur);
|
|
||||||
ggml_set_name(Qcur, "Qcur");
|
|
||||||
|
|
||||||
// store key and value to memory
|
// store key and value to memory
|
||||||
{
|
{
|
||||||
// compute the transposed [N, n_embd] V matrix
|
// compute the transposed [N, n_embd] V matrix
|
||||||
|
@ -2284,6 +2277,35 @@ static struct ggml_cgraph * llm_build_llama(
|
||||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
|
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (il == n_layer - 1 && !lctx.logits_all)
|
||||||
|
{
|
||||||
|
// From here on, we only care about the last token and its logits.
|
||||||
|
// We do as if N = 1 (from the end), which means we only keep
|
||||||
|
// the last column of cur and inpSA ((n_embd, N) -> (n_embd, 1)).
|
||||||
|
//
|
||||||
|
// Note that we do this even when N==1 so that we don't change the # nodes in the graph,
|
||||||
|
// otherwise for Metal we'd have to rebuild the concurrency list.
|
||||||
|
|
||||||
|
cur = ggml_view_2d(ctx0, cur, n_embd, 1, cur->nb[1], (N - 1)*ggml_element_size(cur)*n_embd);
|
||||||
|
offload_func_nr(cur);
|
||||||
|
ggml_set_name(cur, "cur-lastpos");
|
||||||
|
|
||||||
|
inpSA = ggml_view_2d(ctx0, inpSA, n_embd, 1, inpSA->nb[1], (N - 1)*ggml_element_size(inpSA)*n_embd);
|
||||||
|
offload_func_nr(inpSA);
|
||||||
|
ggml_set_name(inpSA, "inpSA-lastpos");
|
||||||
|
|
||||||
|
n_past += N - 1;
|
||||||
|
N = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
|
||||||
|
offload_func_kq(tmpq);
|
||||||
|
ggml_set_name(tmpq, "tmpq");
|
||||||
|
|
||||||
|
struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
|
||||||
|
offload_func_kq(Qcur);
|
||||||
|
ggml_set_name(Qcur, "Qcur");
|
||||||
|
|
||||||
struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
|
struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
|
||||||
offload_func_kq(Q);
|
offload_func_kq(Q);
|
||||||
ggml_set_name(Q, "Q");
|
ggml_set_name(Q, "Q");
|
||||||
|
@ -2902,11 +2924,13 @@ static bool llama_eval_internal(
|
||||||
|
|
||||||
if (lctx.logits_all) {
|
if (lctx.logits_all) {
|
||||||
logits_out.resize(n_vocab * N);
|
logits_out.resize(n_vocab * N);
|
||||||
|
GGML_ASSERT(ggml_nelements(res) == n_vocab * N);
|
||||||
memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*N);
|
memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*N);
|
||||||
} else {
|
} else {
|
||||||
// return result for just the last token
|
// return result for just the last token
|
||||||
|
GGML_ASSERT(ggml_nelements(res) == n_vocab);
|
||||||
logits_out.resize(n_vocab);
|
logits_out.resize(n_vocab);
|
||||||
memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
|
memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue