llama : use a vector for ctx->output_ids

* llama : rework reallocation logic for llama_output_reserve

The reallocation check now compares the actual size of the existing buffer
with the new total size of the output buffer, which will allow more
efficient enabling and disabling of the embeddings and/or logits output
in the future.
author Francis Couture-Harpin 2024-03-18 20:51:32 -04:00
parent 09bb15a66a
commit 4551e7eba8
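
As a rough standalone sketch of the logic described in the commit message (the OutputBuffer type, its reserve() function, and the sizes used in main() are illustrative assumptions, not the llama.cpp API): the buffer is reallocated only when the newly requested total exceeds what is already allocated, so disabling the logits or embeddings output and later re-enabling it does not force a reallocation as long as the old buffer still fits.

// Hypothetical, minimal sketch of the grow-only size check; OutputBuffer
// and reserve() are illustrative names, not llama.cpp API.
#include <cstddef>
#include <cstdio>
#include <vector>

struct OutputBuffer {
    std::vector<float> buf;  // backing storage shared by logits and embeddings
    size_t logits_size = 0;  // floats currently reserved for logits
    size_t embd_size   = 0;  // floats currently reserved for embeddings

    void reserve(size_t new_logits, size_t new_embd) {
        const size_t prev_size = buf.size();
        const size_t new_size  = new_logits + new_embd;
        // alloc only when more than the current capacity is required;
        // disabling an output later never shrinks the buffer, so
        // re-enabling it usually costs no reallocation
        if (prev_size < new_size) {
            std::printf("reallocating: %zu -> %zu floats\n", prev_size, new_size);
            buf.resize(new_size);
        }
        // bookkeeping is updated unconditionally, independent of allocation
        logits_size = new_logits;
        embd_size   = new_embd;
    }
};

int main() {
    OutputBuffer out;
    out.reserve(32000, 4096); // first call: allocates
    out.reserve(32000, 0);    // embeddings disabled: buffer kept as-is
    out.reserve(32000, 4096); // re-enabled: still fits, no reallocation
    return 0;
}

The real code applies the same comparison to ggml_backend_buffer_get_size(lctx.buf_output), while lctx.logits_size and lctx.embd_size become bookkeeping that can shrink independently of the allocation.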

@@ -2055,8 +2055,6 @@ struct llama_context {
             ggml_backend_free(backend);
         }
 
-        free(output_ids);
-
 #ifdef GGML_USE_VULKAN
         ggml_vk_free_cpu_assist();
 #endif
@@ -2101,7 +2099,7 @@ struct llama_context {
     size_t logits_size = 0; // capacity (of floats) for logits
     float * logits = nullptr;
 
-    int32_t * output_ids = nullptr; // map token positions to ids of the logits and embd buffers
+    std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
     size_t  output_size = 0; // capacity (of tokens positions) for the output buffers
     int32_t n_outputs   = 0; // number of actually-used outputs in the current ubatch
 
@@ -9179,32 +9177,29 @@ static void llama_output_reserve(llama_context & lctx, int32_t n_outputs) {
     const auto n_batch = cparams.n_batch;
     const auto n_vocab = hparams.n_vocab;
     const auto n_embd  = hparams.n_embd;
-    const int64_t capacity = lctx.output_size;
 
     // TODO: use a per-batch flag for logits presence instead
     const bool has_logits = cparams.causal_attn;
     const bool has_embd   = cparams.embeddings && (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
 
-    if (!lctx.output_ids) {
-        // never resized afterwards
-        lctx.output_ids = (int32_t *) malloc(n_batch*sizeof(int32_t));
-        if (lctx.output_ids == nullptr) {
-            throw std::runtime_error("failed to allocate output_ids buffer");
-        }
-    }
-
-    // alloc only when more than the current logits capacity is required
-    if (capacity < n_outputs_max) {
-        lctx.output_size = n_outputs_max;
-        lctx.logits_size = has_logits ? n_vocab*n_outputs_max : 0;
-        lctx.embd_size   = has_embd   ?  n_embd*n_outputs_max : 0;
+    const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
+    const size_t embd_size   = has_embd   ?  n_embd*n_outputs_max : 0;
 
-        const size_t buf_output_size = (lctx.logits_size + lctx.embd_size)*sizeof(float);
+    if (lctx.output_ids.empty()) {
+        // init, never resized afterwards
+        lctx.output_ids.resize(n_batch);
+    }
 
+    const size_t prev_size = lctx.buf_output ? ggml_backend_buffer_get_size(lctx.buf_output) : 0;
+    const size_t new_size  = (logits_size + embd_size) * sizeof(float);
+
+    // alloc only when more than the current capacity is required
+    // TODO: also consider shrinking the buffer
+    if (prev_size < new_size) {
         if (lctx.buf_output) {
 #ifndef NDEBUG
             // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
-            const size_t prev_size = ggml_backend_buffer_get_size(lctx.buf_output);
-            LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, buf_output_size / 1024.0 / 1024.0);
+            LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
 #endif
             ggml_backend_buffer_free(lctx.buf_output);
             lctx.buf_output = nullptr;
@@ -9212,18 +9207,21 @@ static void llama_output_reserve(llama_context & lctx, int32_t n_outputs) {
             lctx.embd = nullptr;
         }
 
-        lctx.buf_output = ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), buf_output_size);
+        lctx.buf_output = ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), new_size);
         if (lctx.buf_output == nullptr) {
-            throw std::runtime_error(format("failed to allocate output buffer of size %.2f MiB", buf_output_size / (1024.0 * 1024.0)));
+            throw std::runtime_error(format("failed to allocate output buffer of size %.2f MiB", new_size / (1024.0 * 1024.0)));
         }
+    }
 
-        float * output_base = (float *) ggml_backend_buffer_get_base(lctx.buf_output);
+    float * output_base = (float *) ggml_backend_buffer_get_base(lctx.buf_output);
 
-        lctx.logits = has_logits ? output_base : nullptr;
-        lctx.embd   = has_embd   ? output_base + lctx.logits_size : nullptr;
-    }
-
-    // set all ids as invalid (assume two's complement negative numbers)
-    memset(lctx.output_ids, -1, n_batch*sizeof(int32_t));
+    lctx.output_size = n_outputs_max;
+    lctx.logits = has_logits ? output_base : nullptr;
+    lctx.embd   = has_embd   ? output_base + logits_size : nullptr;
+    lctx.logits_size = logits_size;
+    lctx.embd_size   = embd_size;
+
+    // set all ids as invalid (negative)
+    std::fill(lctx.output_ids.begin(), lctx.output_ids.end(), -1);
 
     ggml_backend_buffer_clear(lctx.buf_output, 0);
@@ -14152,7 +14150,7 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
     {
         std::vector<int32_t> output_pos;
 
         const size_t n_batch = ctx->cparams.n_batch;
-        const int32_t * output_ids = ctx->output_ids;
+        const auto & output_ids = ctx->output_ids;
 
         output_pos.resize(ctx->output_size);