llama : fix logits_all parameter being ignored
`batch.logits` is always true since the new batch API mallocs it. This made the other control flows dead code. I am unsure what the last else statement was intended to do, and removed it because it has been dead code for a while anyway.
This commit is contained in:
parent
5a7d3125e7
commit
f141504938
1 changed files with 5 additions and 7 deletions
12
llama.cpp
12
llama.cpp
|
@ -5855,7 +5855,11 @@ static int llama_decode_internal(
|
|||
{
|
||||
auto & logits_out = lctx.logits;
|
||||
|
||||
if (batch.logits) {
|
||||
if (lctx.logits_all) {
|
||||
logits_out.resize(n_vocab * n_tokens);
|
||||
memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens);
|
||||
} else {
|
||||
GGML_ASSERT(batch.logits);
|
||||
logits_out.resize(n_vocab * n_tokens);
|
||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||
if (batch.logits[i] == 0) {
|
||||
|
@ -5863,12 +5867,6 @@ static int llama_decode_internal(
|
|||
}
|
||||
memcpy(logits_out.data() + (n_vocab*i), (float *) ggml_get_data(res) + (n_vocab*i), sizeof(float)*n_vocab);
|
||||
}
|
||||
} else if (lctx.logits_all) {
|
||||
logits_out.resize(n_vocab * n_tokens);
|
||||
memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens);
|
||||
} else {
|
||||
logits_out.resize(n_vocab);
|
||||
memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(n_tokens - 1)), sizeof(float)*n_vocab);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue