llama : fix logits_all parameter being ignored

`batch.logits` is always true since the new batch API mallocs it. This
made the other control flows dead code.

I am unsure what the last else statement was intended to do, and removed
it because it has been dead code for a while anyway.
This commit is contained in:
crasm 2023-12-09 01:00:13 -05:00
parent 5a7d3125e7
commit f141504938

View file

@ -5855,7 +5855,11 @@ static int llama_decode_internal(
{
auto & logits_out = lctx.logits;
if (batch.logits) {
if (lctx.logits_all) {
logits_out.resize(n_vocab * n_tokens);
memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens);
} else {
GGML_ASSERT(batch.logits);
logits_out.resize(n_vocab * n_tokens);
for (uint32_t i = 0; i < n_tokens; i++) {
if (batch.logits[i] == 0) {
@ -5863,12 +5867,6 @@ static int llama_decode_internal(
}
memcpy(logits_out.data() + (n_vocab*i), (float *) ggml_get_data(res) + (n_vocab*i), sizeof(float)*n_vocab);
}
} else if (lctx.logits_all) {
logits_out.resize(n_vocab * n_tokens);
memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens);
} else {
logits_out.resize(n_vocab);
memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(n_tokens - 1)), sizeof(float)*n_vocab);
}
}