llama : fix wrong n_outputs in llama_set_inputs

A mismatch happened when using a smaller n_ubatch than n_batch and then using
llama_batch_get_one(). The decision of what n_outputs should be now almost
fully depends on how lctx.n_outputs is set in llama_decode_internal.
The conditions are simpler this way.

* llama : when saving the state, recalculate n_outputs

This ensures the correct number of outputs for the entire previous batch
is stored in the session file, even when n_ubatch is smaller than n_batch.
This commit is contained in:
Francis Couture-Harpin 2024-03-17 16:31:19 -04:00
parent 408fcb0f91
commit e19cb3aeb7

View file

@ -2103,7 +2103,7 @@ struct llama_context {
int32_t * output_ids = nullptr; // map token positions to ids of the logits and embd buffers int32_t * output_ids = nullptr; // map token positions to ids of the logits and embd buffers
size_t output_size = 0; // capacity (of tokens positions) for the output buffers size_t output_size = 0; // capacity (of tokens positions) for the output buffers
int32_t n_outputs = 0; // number of actually-used outputs in the current or previous ubatch int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch
bool logits_all = false; bool logits_all = false;
@ -8985,25 +8985,25 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_out_ids->buffer)); GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_out_ids->buffer));
int32_t * data = (int32_t *) lctx.inp_out_ids->data; int32_t * data = (int32_t *) lctx.inp_out_ids->data;
int32_t n_outputs = 0; if (lctx.n_outputs == n_tokens) {
if (batch.logits) { for (int i = 0; i < n_tokens; ++i) {
data[i] = i;
}
} else if (batch.logits) {
int32_t n_outputs = 0;
for (int i = 0; i < n_tokens; ++i) { for (int i = 0; i < n_tokens; ++i) {
if (batch.logits[i]) { if (batch.logits[i]) {
data[n_outputs++] = i; data[n_outputs++] = i;
} }
} }
} else if (lctx.logits_all || (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE)) { // the graph needs the have been passed the correct number of outputs
for (int i = 0; i < n_tokens; ++i) { GGML_ASSERT(lctx.n_outputs == n_outputs);
data[i] = i; } else if (lctx.n_outputs == 1) {
}
n_outputs = n_tokens;
} else {
// only keep last output // only keep last output
data[0] = n_tokens - 1; data[0] = n_tokens - 1;
n_outputs = 1; } else {
GGML_ASSERT(lctx.n_outputs == 0);
} }
// the graph needs the have been passed the correct number of outputs
GGML_ASSERT(lctx.n_outputs == n_outputs);
} }
GGML_ASSERT( GGML_ASSERT(
@ -9386,7 +9386,7 @@ static int llama_decode_internal(
for (uint32_t i = 0; i < n_tokens; i++) { for (uint32_t i = 0; i < n_tokens; i++) {
n_outputs_new += u_batch.logits[i] != 0; n_outputs_new += u_batch.logits[i] != 0;
} }
} else if (lctx.logits_all) { } else if ((uint32_t) n_outputs == n_tokens_all) {
n_outputs_new = n_tokens; n_outputs_new = n_tokens;
} else { } else {
// keep last output only // keep last output only
@ -14166,7 +14166,9 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
// copy outputs // copy outputs
{ {
size_t n_outputs = ctx->n_outputs; // Can't use ctx->n_outputs because it's not for the
// entire last batch when n_ubatch is smaller than n_batch
size_t n_outputs = 0;
// copy output ids // copy output ids
{ {
@ -14174,19 +14176,17 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
const size_t n_batch = ctx->cparams.n_batch; const size_t n_batch = ctx->cparams.n_batch;
const int32_t * output_ids = ctx->output_ids; const int32_t * output_ids = ctx->output_ids;
output_pos.resize(n_outputs); output_pos.resize(ctx->output_size);
// build a more compact representation of the output ids // build a more compact representation of the output ids
for (size_t i = 0; i < n_batch; ++i) { for (size_t i = 0; i < n_batch; ++i) {
// map an output id to a position in the batch // map an output id to a position in the batch
int32_t pos = output_ids[i]; int32_t pos = output_ids[i];
if (pos >= 0) { if (pos >= 0) {
if ((size_t) pos >= output_pos.size()) { if ((size_t) pos >= n_outputs) {
// TODO: maybe fail here instead
LLAMA_LOG_WARN("%s: weird output buffer layout, possibly a bug\n", __func__);
n_outputs = pos + 1; n_outputs = pos + 1;
output_pos.resize(n_outputs);
} }
GGML_ASSERT((size_t) pos < ctx->output_size);
output_pos[pos] = i; output_pos[pos] = i;
} }
} }