From 705d3937eaa1f1f370fda188564405e358905d8c Mon Sep 17 00:00:00 2001
From: Francis Couture-Harpin
Date: Sat, 16 Mar 2024 17:24:05 -0400
Subject: [PATCH] llama : fix lctx.n_outputs not being set before building graph

---
 llama.cpp | 147 +++++++++++++++++++++++++++++-------------------------
 1 file changed, 79 insertions(+), 68 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index 47071700b..f397ffed4 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -2103,7 +2103,7 @@ struct llama_context {
 
     int32_t * output_ids = nullptr; // map token positions to ids of the logits and embd buffers
     size_t    output_size = 0; // capacity (of tokens positions) for the output buffers
-    int32_t   n_outputs   = 0; // number of actually-used outputs in the current or previous batch
+    int32_t   n_outputs   = 0; // number of actually-used outputs in the current or previous ubatch
 
     bool logits_all = false;
@@ -8985,24 +8985,25 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_out_ids->buffer));
         int32_t * data = (int32_t *) lctx.inp_out_ids->data;
 
+        int32_t n_outputs = 0;
         if (batch.logits) {
-            int32_t n_outputs = 0;
             for (int i = 0; i < n_tokens; ++i) {
                 if (batch.logits[i]) {
                     data[n_outputs++] = i;
                 }
             }
-            lctx.n_outputs = n_outputs;
         } else if (lctx.logits_all || (cparams.embeddings && hparams.pooling_type != LLAMA_POOLING_TYPE_NONE)) {
             for (int i = 0; i < n_tokens; ++i) {
                 data[i] = i;
             }
-            lctx.n_outputs = n_tokens;
+            n_outputs = n_tokens;
         } else {
             // only keep last output
             data[0] = n_tokens - 1;
-            lctx.n_outputs = 1;
+            n_outputs = 1;
         }
+        // the graph needs to have been passed the correct number of outputs
+        GGML_ASSERT(lctx.n_outputs == n_outputs);
     }
 
     GGML_ASSERT(
@@ -9202,6 +9203,7 @@ static void llama_output_reserve(llama_context & lctx, int32_t n_outputs) {
     const auto n_embd  = hparams.n_embd;
 
     const int64_t capacity = lctx.output_size;
 
+    // TODO: use a per-batch flag for logits presence instead
     const bool has_logits = cparams.causal_attn;
     const bool has_embd   = cparams.embeddings && (!hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
@@ -9221,10 +9223,11 @@ static void llama_output_reserve(llama_context & lctx, int32_t n_outputs) {
     const size_t buf_output_size = (lctx.logits_size + lctx.embd_size)*sizeof(float);
 
     if (lctx.buf_output) {
-#ifndef NDEBUG
+        // This doesn't happen often
+// #ifndef NDEBUG
         const size_t prev_size = ggml_backend_buffer_get_size(lctx.buf_output);
-        fprintf(stderr, "%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, buf_output_size/ 1024.0 / 1024.0);
-#endif
+        LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, buf_output_size / 1024.0 / 1024.0);
+// #endif
         ggml_backend_buffer_free(lctx.buf_output);
         lctx.buf_output = nullptr;
         lctx.logits = nullptr;
@@ -9246,7 +9249,7 @@ static void llama_output_reserve(llama_context & lctx, int32_t n_outputs) {
 
     ggml_backend_buffer_clear(lctx.buf_output, 0);
 
-    lctx.n_outputs = n_outputs; // also set in llama_set_inputs() before a batch
+    lctx.n_outputs = 0;
 }
@@ -9325,8 +9328,8 @@ static int llama_decode_internal(
     const int64_t n_embd  = hparams.n_embd;
     const int64_t n_vocab = hparams.n_vocab;
 
-    int32_t n_logits      = 0;
-    int32_t n_logits_prev = 0;
+    int32_t n_outputs      = 0;
+    int32_t n_outputs_prev = 0;
 
     const auto n_ubatch = cparams.n_ubatch;
@@ -9338,11 +9341,9 @@ static int llama_decode_internal(
     // reserve output buffer
     if (batch_all.logits) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            if (batch_all.logits[i]) {
-                n_logits++;
-            }
+            n_outputs += batch_all.logits[i] != 0;
         }
-        llama_output_reserve(lctx, n_logits);
+        llama_output_reserve(lctx, n_outputs);
         int32_t i_logits = 0;
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
             if (batch_all.logits[i]) {
@@ -9350,15 +9351,15 @@ static int llama_decode_internal(
             }
         }
     } else if (lctx.logits_all || (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE)) {
-        n_logits = n_tokens_all;
-        llama_output_reserve(lctx, n_logits);
+        n_outputs = n_tokens_all;
+        llama_output_reserve(lctx, n_outputs);
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
             lctx.output_ids[i] = i;
         }
     } else {
-        // keep last logits only
-        n_logits = 1;
-        llama_output_reserve(lctx, n_logits);
+        // keep last output only
+        n_outputs = 1;
+        llama_output_reserve(lctx, n_outputs);
         lctx.output_ids[0] = 0;
     }
 
@@ -9377,6 +9378,27 @@ static int llama_decode_internal(
             /* .all_seq_id = */ batch_all.all_seq_id,
         };
 
+        // count the outputs in this u_batch
+        {
+            int32_t n_outputs_new = 0;
+
+            if (u_batch.logits) {
+                for (uint32_t i = 0; i < n_tokens; i++) {
+                    n_outputs_new += u_batch.logits[i] != 0;
+                }
+            } else if (lctx.logits_all) {
+                n_outputs_new = n_tokens;
+            } else {
+                // keep last output only
+                if (cur_token + n_tokens >= n_tokens_all) {
+                    n_outputs_new = 1;
+                }
+            }
+
+            // needs to happen before the graph is built
+            lctx.n_outputs = n_outputs_new;
+        }
+
         int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
         GGML_ASSERT(n_threads > 0);
@@ -9451,18 +9473,26 @@ static int llama_decode_internal(
             embd = gf->nodes[gf->n_nodes - 1];
             GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0);
-            // TODO: graph view to ignore the logits when not needed
-        } else {
-            if (strcmp(res->name, "result_output") == 0) {
-                // the token embeddings could be the second to last tensor, or any of the previous tensors
-                // NOTE: see build_result_output() for an idea of up to how many tensors to skip
-                for (int i = 3; strcmp(embd->name, "result_norm") != 0 && i <= 10; ++i) {
-                    embd = gf->nodes[gf->n_nodes - i];
-                }
-                GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
-            } else {
-                GGML_ASSERT(false && "missing result_output tensor");
+        } else if (cparams.embeddings) {
+            // the embeddings could be in the second to last tensor, or any of the previous tensors
+            int i_embd = gf->n_nodes - 2;
+            for (int i = 3; strcmp(embd->name, "result_norm") != 0; ++i) {
+                i_embd = gf->n_nodes - i;
+                if (i_embd < 0) { break; }
+                embd = gf->nodes[i_embd];
             }
+            GGML_ASSERT(i_embd >= 0 && "missing result_norm tensor");
+
+            // TODO: use a per-batch flag to know when to skip logits while keeping embeddings
+            if (!cparams.causal_attn) {
+                res = nullptr; // do not extract logits when not needed
+                // skip computing logits
+                // TODO: is this safe?
+                gf->n_nodes = i_embd + 1;
+            }
+        } else {
+            embd = nullptr; // do not extract embeddings when not needed
+            GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
         }
 
         // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
@@ -9505,38 +9535,22 @@ static int llama_decode_internal(
         //}
 
         // extract logits
-        // TODO: do not compute and extract logits if only embeddings are needed
-        //       update the graphs to skip "result_output" if logits are not needed
         if (res) {
             ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
             GGML_ASSERT(backend_res != nullptr);
 
-            int32_t new_logits = 0;
-            if (u_batch.logits) {
-                for (uint32_t i = 0; i < n_tokens; i++) {
-                    if (u_batch.logits[i]) {
-                        new_logits++;
-                    }
-                }
-            } else if (lctx.logits_all) {
-                new_logits += n_tokens;
-            } else {
-                // keep last logits only
-                if (cur_token + n_tokens >= n_tokens_all) {
-                    new_logits += 1;
-                }
-            }
+            float * logits_out = lctx.logits + n_outputs_prev*n_vocab;
+            const int32_t n_outputs_new = lctx.n_outputs;
 
-            if (new_logits) {
-                GGML_ASSERT(new_logits <= n_logits);
-                GGML_ASSERT((n_logits_prev+new_logits)*n_vocab <= (int64_t) lctx.logits_size);
-                ggml_backend_tensor_get_async(backend_res, res, lctx.logits, n_logits_prev*n_vocab*sizeof(float), new_logits*n_vocab*sizeof(float));
-                n_logits_prev += new_logits;
+            if (n_outputs_new) {
+                GGML_ASSERT(n_outputs_prev+n_outputs_new <= n_outputs);
+                GGML_ASSERT((n_outputs_prev+n_outputs_new)*n_vocab <= (int64_t) lctx.logits_size);
+                ggml_backend_tensor_get_async(backend_res, res, logits_out, 0, n_outputs_new*n_vocab*sizeof(float));
             }
         }
 
         // extract embeddings
-        if (cparams.embeddings && embd) {
+        if (embd) {
             ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
             GGML_ASSERT(backend_embd != nullptr);
@@ -9544,17 +9558,13 @@ static int llama_decode_internal(
             case LLAMA_POOLING_TYPE_NONE:
                 {
                     // extract token embeddings
-                    auto & embd_out = lctx.embd;
+                    float * embd_out = lctx.embd + n_outputs_prev*n_embd;
+                    const int32_t n_outputs_new = lctx.n_outputs;
 
-                    if (u_batch.logits) {
-                        //embd_out.resize(n_embd * n_tokens);
-                        for (uint32_t i = 0; i < n_tokens; i++) {
-                            if (u_batch.logits[i] == 0) {
-                                continue;
-                            }
-                            // FIXME
-                            ggml_backend_tensor_get_async(backend_embd, embd, embd_out + n_embd*(i + cur_token), (n_embd*i)*sizeof(float), n_embd*sizeof(float));
-                        }
+                    if (n_outputs_new) {
+                        GGML_ASSERT(n_outputs_prev+n_outputs_new <= n_outputs);
+                        GGML_ASSERT((n_outputs_prev+n_outputs_new)*n_embd <= (int64_t) lctx.embd_size);
+                        ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float));
                     }
                 } break;
             case LLAMA_POOLING_TYPE_CLS:
@@ -9581,6 +9591,7 @@ static int llama_decode_internal(
                 } break;
             }
         }
+        n_outputs_prev += lctx.n_outputs;
     }
 
     // wait for the computation to finish (automatically done when obtaining the model output)
@@ -14639,11 +14650,11 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
     const int32_t j = ctx->output_ids[i];
 
     llama_synchronize(ctx);
-
-    if (ctx->logits && 0 <= j && j < ctx->n_outputs) {
+    if (ctx->logits && 0 <= j && (size_t) j < ctx->output_size) {
         return ctx->logits + j*ctx->model.hparams.n_vocab;
     }
-    LLAMA_LOG_ERROR("%s: invalid logits id %i\n", __func__, i);
+    LLAMA_LOG_ERROR("%s: invalid logits id %i, reason: %s (j=%i, output_size=%li)\n",
+        __func__, i, !ctx->logits ? "no logits" : j < 0 ? "batch.logits[i] wasn't true" : "too big", j, ctx->output_size);
 #ifndef NDEBUG
     GGML_ASSERT(false);
 #endif
@@ -14661,7 +14672,7 @@ float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
 
     llama_synchronize(ctx);
 
-    if (ctx->embd && 0 < j && j < ctx->n_outputs) {
+    if (ctx->embd && 0 < j && (size_t) j < ctx->output_size) {
         return ctx->embd + j*ctx->model.hparams.n_embd;
     }
     LLAMA_LOG_ERROR("%s: invalid embeddings id %i\n", __func__, i);