From f14150493811392e6c99e697529d2915a13e193d Mon Sep 17 00:00:00 2001 From: crasm Date: Sat, 9 Dec 2023 01:00:13 -0500 Subject: [PATCH] llama : fix logits_all parameter being ignored `batch.logits` is always true since the new batch API mallocs it. This made the other control flows dead code. I am unsure what the last else statement was intended to do, and removed it because it has been dead code for a while anyway. --- llama.cpp | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/llama.cpp b/llama.cpp index 3f5d663cf..9dae1f4e2 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5855,7 +5855,11 @@ static int llama_decode_internal( { auto & logits_out = lctx.logits; - if (batch.logits) { + if (lctx.logits_all) { + logits_out.resize(n_vocab * n_tokens); + memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens); + } else { + GGML_ASSERT(batch.logits); logits_out.resize(n_vocab * n_tokens); for (uint32_t i = 0; i < n_tokens; i++) { if (batch.logits[i] == 0) { @@ -5863,12 +5867,6 @@ static int llama_decode_internal( } memcpy(logits_out.data() + (n_vocab*i), (float *) ggml_get_data(res) + (n_vocab*i), sizeof(float)*n_vocab); } - } else if (lctx.logits_all) { - logits_out.resize(n_vocab * n_tokens); - memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens); - } else { - logits_out.resize(n_vocab); - memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(n_tokens - 1)), sizeof(float)*n_vocab); } }