llama : fix wrong n_outputs in llama_set_inputs
A mismatch happened when n_ubatch was smaller than n_batch and the batch came from llama_batch_get_one(). The decision of what n_outputs should be now depends almost entirely on how lctx.n_outputs is set in llama_decode_internal, which keeps the conditions simpler.

* llama : when saving the state, recalculate n_outputs

This ensures the correct number of outputs for the entire previous batch is stored in the session file, even when n_ubatch is smaller than n_batch.
This commit is contained in:
parent 408fcb0f91
commit e19cb3aeb7

1 changed file with 20 additions and 20 deletions:
llama.cpp
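For context on the scenario described in the commit message (this note and the sketch below are editorial, not part of the commit): a minimal sketch of a decode call where one logical batch is split into several ubatches, assuming the llama.cpp C API of this period (the four-argument llama_batch_get_one() and the n_ubatch field of llama_context_params). The helper name and the sizes are illustrative.

// Editorial sketch only; exact signatures differ between llama.cpp versions.
#include "llama.h"
#include <vector>

static void decode_split_into_ubatches(llama_model * model, std::vector<llama_token> & tokens) {
    llama_context_params cparams = llama_context_default_params();
    cparams.n_batch  = 512; // logical batch size (tokens.size() is assumed to be <= 512)
    cparams.n_ubatch = 256; // physical batch size: a 512-token batch is processed as two ubatches

    llama_context * ctx = llama_new_context_with_model(model, cparams);

    // llama_batch_get_one() leaves batch.logits == NULL, so only the last token of the
    // logical batch should produce an output: one output for the whole batch, and zero
    // outputs for every ubatch except the final one.
    llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size(),
                                            /*pos_0 =*/ 0, /*seq_id =*/ 0);

    // Before this commit, llama_set_inputs() recomputed n_outputs per ubatch from
    // batch.logits / logits_all and could disagree with the per-ubatch value chosen in
    // llama_decode_internal(); it now relies on lctx.n_outputs instead.
    llama_decode(ctx, batch);

    llama_free(ctx);
}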
@@ -2103,7 +2103,7 @@ struct llama_context {
     int32_t * output_ids = nullptr; // map token positions to ids of the logits and embd buffers
     size_t    output_size = 0;      // capacity (of tokens positions) for the output buffers
-    int32_t   n_outputs   = 0;      // number of actually-used outputs in the current or previous ubatch
+    int32_t   n_outputs   = 0;      // number of actually-used outputs in the current ubatch
 
     bool logits_all = false;
 
@@ -8985,25 +8985,25 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_out_ids->buffer));
         int32_t * data = (int32_t *) lctx.inp_out_ids->data;
 
-        int32_t n_outputs = 0;
-        if (batch.logits) {
+        if (lctx.n_outputs == n_tokens) {
             for (int i = 0; i < n_tokens; ++i) {
-                if (batch.logits[i]) {
-                    data[n_outputs++] = i;
-                }
+                data[i] = i;
             }
-        } else if (lctx.logits_all || (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE)) {
+        } else if (batch.logits) {
+            int32_t n_outputs = 0;
             for (int i = 0; i < n_tokens; ++i) {
-                data[i] = i;
+                if (batch.logits[i]) {
+                    data[n_outputs++] = i;
+                }
             }
-            n_outputs = n_tokens;
-        } else {
+            // the graph needs to have been passed the correct number of outputs
+            GGML_ASSERT(lctx.n_outputs == n_outputs);
+        } else if (lctx.n_outputs == 1) {
             // only keep last output
             data[0] = n_tokens - 1;
-            n_outputs = 1;
+        } else {
+            GGML_ASSERT(lctx.n_outputs == 0);
         }
-        // the graph needs to have been passed the correct number of outputs
-        GGML_ASSERT(lctx.n_outputs == n_outputs);
     }
 
     GGML_ASSERT(
@@ -9386,7 +9386,7 @@ static int llama_decode_internal(
                 for (uint32_t i = 0; i < n_tokens; i++) {
                     n_outputs_new += u_batch.logits[i] != 0;
                 }
-            } else if (lctx.logits_all) {
+            } else if ((uint32_t) n_outputs == n_tokens_all) {
                 n_outputs_new = n_tokens;
             } else {
                 // keep last output only
@@ -14166,7 +14166,9 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
 
     // copy outputs
     {
-        size_t n_outputs = ctx->n_outputs;
+        // Can't use ctx->n_outputs because it's not for the
+        // entire last batch when n_ubatch is smaller than n_batch
+        size_t n_outputs = 0;
 
         // copy output ids
         {
@@ -14174,19 +14176,17 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
             const size_t    n_batch    = ctx->cparams.n_batch;
             const int32_t * output_ids = ctx->output_ids;
 
-            output_pos.resize(n_outputs);
+            output_pos.resize(ctx->output_size);
 
             // build a more compact representation of the output ids
             for (size_t i = 0; i < n_batch; ++i) {
                 // map an output id to a position in the batch
                 int32_t pos = output_ids[i];
                 if (pos >= 0) {
-                    if ((size_t) pos >= output_pos.size()) {
-                        // TODO: maybe fail here instead
-                        LLAMA_LOG_WARN("%s: weird output buffer layout, possibly a bug\n", __func__);
+                    if ((size_t) pos >= n_outputs) {
                         n_outputs = pos + 1;
-                        output_pos.resize(n_outputs);
                     }
+                    GGML_ASSERT((size_t) pos < ctx->output_size);
                     output_pos[pos] = i;
                 }
             }
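A simplified, self-contained model of the recalculation the changed save path performs (editorial; the struct output_map and the function compact_output_ids are hypothetical names, not part of llama.cpp): output_ids maps each token position of the last batch to the row of the logits/embd buffers that holds its output, or -1 if that position produced none, so scanning it yields both the true output count for the whole batch and the compact inverse map.

#include <cassert>
#include <cstdint>
#include <vector>

struct output_map {
    size_t n_outputs;                // number of outputs in the entire last batch
    std::vector<int32_t> output_pos; // inverse map: output row -> token position in the batch
};

static output_map compact_output_ids(const std::vector<int32_t> & output_ids, size_t output_size) {
    output_map res;
    res.n_outputs = 0;
    res.output_pos.resize(output_size);

    for (size_t i = 0; i < output_ids.size(); ++i) {
        const int32_t pos = output_ids[i];
        if (pos >= 0) {
            if ((size_t) pos >= res.n_outputs) {
                res.n_outputs = pos + 1; // highest used row + 1 is the true output count
            }
            assert((size_t) pos < output_size);
            res.output_pos[pos] = (int32_t) i;
        }
    }

    res.output_pos.resize(res.n_outputs); // keep only the rows that are actually used
    return res;
}

With n_ubatch smaller than n_batch, ctx->n_outputs only reflects the last ubatch, while a count derived this way covers the whole previous batch, which is what the session file needs to store.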