llama : fix logits_all parameter being ignored
`batch.logits` is always true since the new batch API mallocs it. This made the other control flows dead code. I am unsure what the last else statement was intended to do, and removed it because it has been dead code for a while anyway.
This commit is contained in:
parent
5a7d3125e7
commit
f141504938
1 changed files with 5 additions and 7 deletions
12
llama.cpp
12
llama.cpp
|
@ -5855,7 +5855,11 @@ static int llama_decode_internal(
|
||||||
{
|
{
|
||||||
auto & logits_out = lctx.logits;
|
auto & logits_out = lctx.logits;
|
||||||
|
|
||||||
if (batch.logits) {
|
if (lctx.logits_all) {
|
||||||
|
logits_out.resize(n_vocab * n_tokens);
|
||||||
|
memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens);
|
||||||
|
} else {
|
||||||
|
GGML_ASSERT(batch.logits);
|
||||||
logits_out.resize(n_vocab * n_tokens);
|
logits_out.resize(n_vocab * n_tokens);
|
||||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||||
if (batch.logits[i] == 0) {
|
if (batch.logits[i] == 0) {
|
||||||
|
@ -5863,12 +5867,6 @@ static int llama_decode_internal(
|
||||||
}
|
}
|
||||||
memcpy(logits_out.data() + (n_vocab*i), (float *) ggml_get_data(res) + (n_vocab*i), sizeof(float)*n_vocab);
|
memcpy(logits_out.data() + (n_vocab*i), (float *) ggml_get_data(res) + (n_vocab*i), sizeof(float)*n_vocab);
|
||||||
}
|
}
|
||||||
} else if (lctx.logits_all) {
|
|
||||||
logits_out.resize(n_vocab * n_tokens);
|
|
||||||
memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens);
|
|
||||||
} else {
|
|
||||||
logits_out.resize(n_vocab);
|
|
||||||
memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(n_tokens - 1)), sizeof(float)*n_vocab);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue