From e19cb3aeb728987aa0a58a119d800c06fdd6aad7 Mon Sep 17 00:00:00 2001
From: Francis Couture-Harpin
Date: Sun, 17 Mar 2024 16:31:19 -0400
Subject: [PATCH] llama : fix wrong n_outputs in llama_set_inputs

A mismatch happened when using a smaller n_ubatch than n_batch and then
using llama_batch_get_one(). The decision of what n_outputs should be now
depends almost entirely on how lctx.n_outputs is set in
llama_decode_internal, which keeps the conditions simpler.

* llama : when saving the state, recalculate n_outputs

This ensures the correct number of outputs for the entire previous batch
is stored in the session file, even when n_ubatch is smaller than n_batch.
---
 llama.cpp | 40 ++++++++++++++++++++--------------------
 1 file changed, 20 insertions(+), 20 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index b82b56059..e606bdda4 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -2103,7 +2103,7 @@ struct llama_context {
 
     int32_t * output_ids = nullptr; // map token positions to ids of the logits and embd buffers
     size_t  output_size = 0; // capacity (of tokens positions) for the output buffers
-    int32_t n_outputs   = 0; // number of actually-used outputs in the current or previous ubatch
+    int32_t n_outputs   = 0; // number of actually-used outputs in the current ubatch
 
     bool logits_all = false;
 
@@ -8985,25 +8985,25 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_out_ids->buffer));
         int32_t * data = (int32_t *) lctx.inp_out_ids->data;
 
-        int32_t n_outputs = 0;
-        if (batch.logits) {
+        if (lctx.n_outputs == n_tokens) {
+            for (int i = 0; i < n_tokens; ++i) {
+                data[i] = i;
+            }
+        } else if (batch.logits) {
+            int32_t n_outputs = 0;
             for (int i = 0; i < n_tokens; ++i) {
                 if (batch.logits[i]) {
                     data[n_outputs++] = i;
                 }
             }
-        } else if (lctx.logits_all || (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE)) {
-            for (int i = 0; i < n_tokens; ++i) {
-                data[i] = i;
-            }
-            n_outputs = n_tokens;
-        } else {
+            // the graph needs to have been passed the correct number of outputs
+            GGML_ASSERT(lctx.n_outputs == n_outputs);
+        } else if (lctx.n_outputs == 1) {
             // only keep last output
             data[0] = n_tokens - 1;
-            n_outputs = 1;
+        } else {
+            GGML_ASSERT(lctx.n_outputs == 0);
         }
-        // the graph needs the have been passed the correct number of outputs
-        GGML_ASSERT(lctx.n_outputs == n_outputs);
     }
 
     GGML_ASSERT(
@@ -9386,7 +9386,7 @@ static int llama_decode_internal(
                 for (uint32_t i = 0; i < n_tokens; i++) {
                     n_outputs_new += u_batch.logits[i] != 0;
                 }
-            } else if (lctx.logits_all) {
+            } else if ((uint32_t) n_outputs == n_tokens_all) {
                 n_outputs_new = n_tokens;
             } else {
                 // keep last output only
@@ -14166,7 +14166,9 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
 
     // copy outputs
     {
-        size_t n_outputs = ctx->n_outputs;
+        // Can't use ctx->n_outputs because it doesn't cover the entire last
+        // batch when n_ubatch is smaller than n_batch
+        size_t n_outputs = 0;
 
         // copy output ids
         {
@@ -14174,19 +14176,17 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
             const size_t n_batch = ctx->cparams.n_batch;
             const int32_t * output_ids = ctx->output_ids;
 
-            output_pos.resize(n_outputs);
+            output_pos.resize(ctx->output_size);
 
             // build a more compact representation of the output ids
             for (size_t i = 0; i < n_batch; ++i) {
                 // map an output id to a position in the batch
                 int32_t pos = output_ids[i];
                 if (pos >= 0) {
-                    if ((size_t) pos >= output_pos.size()) {
-                        // TODO: maybe fail here instead
-                        LLAMA_LOG_WARN("%s: weird output buffer layout, possibly a bug\n", __func__);
+                    if ((size_t) pos >= n_outputs) {
                         n_outputs = pos + 1;
-                        output_pos.resize(n_outputs);
                     }
+                    GGML_ASSERT((size_t) pos < ctx->output_size);
                     output_pos[pos] = i;
                 }
             }
@@ -14201,7 +14201,7 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
         // copy logits
        {
             const size_t logits_size = std::min(ctx->logits_size, n_outputs * ctx->model.hparams.n_vocab);
-            
+
             data_ctx->write(&logits_size, sizeof(logits_size));
 
             if (logits_size) {
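
Note (illustration, not part of the patch): the state-saving hunks above recompute n_outputs from the output_ids mapping itself instead of trusting ctx->n_outputs, which only describes the last ubatch. The standalone sketch below uses made-up sizes and local copies of those names to show the same compaction: output_ids maps each batch position to its output id (or -1), and a single pass over it yields both the recomputed n_outputs and the inverse output_pos map that ends up in the session file.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // hypothetical batch of 8 positions; only positions 3 and 7 produced
    // outputs, with output ids 0 and 1 respectively
    const size_t n_batch = 8;
    std::vector<int32_t> output_ids(n_batch, -1);
    output_ids[3] = 0;
    output_ids[7] = 1;

    const size_t output_size = 2; // capacity of the output buffers

    size_t n_outputs = 0;         // recomputed from the mapping, not taken from a per-ubatch counter
    std::vector<uint32_t> output_pos(output_size);

    for (size_t i = 0; i < n_batch; ++i) {
        const int32_t pos = output_ids[i]; // output id of batch position i, or -1
        if (pos >= 0) {
            if ((size_t) pos >= n_outputs) {
                n_outputs = pos + 1;       // highest output id seen so far, plus one
            }
            assert((size_t) pos < output_size);
            output_pos[pos] = (uint32_t) i; // inverse map: output id -> batch position
        }
    }

    printf("n_outputs  = %zu\n", n_outputs);                         // 2
    printf("output_pos = [%u, %u]\n", output_pos[0], output_pos[1]); // [3, 7]
    return 0;
}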