From 98914c0ed02f1503762712bbe58bfacfcbf48b60 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Fri, 15 Mar 2024 12:21:24 -0400 Subject: [PATCH] llama : more compact state saving and reloading --- llama.cpp | 173 +++++++++++++++++++++++++++++++++++++++--------------- llama.h | 24 ++++---- 2 files changed, 139 insertions(+), 58 deletions(-) diff --git a/llama.cpp b/llama.cpp index 3e528d96b..47071700b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2102,8 +2102,8 @@ struct llama_context { float * logits = nullptr; int32_t * output_ids = nullptr; // map token positions to ids of the logits and embd buffers - size_t output_size = 0; // capacity (of tokens positions) for the output buffer - int32_t n_outputs = 0; // number of actually-used outputs in the previous batch + size_t output_size = 0; // capacity (of tokens positions) for the output buffers + int32_t n_outputs = 0; // number of actually-used outputs in the current or previous batch bool logits_all = false; @@ -9192,15 +9192,18 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { static void llama_output_reserve(llama_context & lctx, int32_t n_outputs) { GGML_ASSERT(0 <= n_outputs); - const int32_t n_outputs_max = std::max((uint32_t) n_outputs, lctx.cparams.n_seq_max); + const auto & cparams = lctx.cparams; + const auto & hparams = lctx.model.hparams; - const auto n_batch = lctx.cparams.n_batch; - const auto n_vocab = lctx.model.hparams.n_vocab; - const auto n_embd = lctx.model.hparams.n_embd; + const int32_t n_outputs_max = std::max((uint32_t) n_outputs, cparams.n_seq_max); + + const auto n_batch = cparams.n_batch; + const auto n_vocab = hparams.n_vocab; + const auto n_embd = hparams.n_embd; const int64_t capacity = lctx.output_size; - const bool has_logits = lctx.cparams.causal_attn; - const bool has_embd = lctx.cparams.embeddings; + const bool has_logits = cparams.causal_attn; + const bool has_embd = cparams.embeddings && (!hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE); if (!lctx.output_ids) { // never resized afterwards @@ -9211,29 +9214,32 @@ static void llama_output_reserve(llama_context & lctx, int32_t n_outputs) { } // alloc only when more than the current logits capacity is required if (capacity < n_outputs_max) { + lctx.output_size = n_outputs_max; + lctx.logits_size = has_logits ? n_vocab*n_outputs_max : 0; + lctx.embd_size = has_embd ? n_embd*n_outputs_max : 0; + + const size_t buf_output_size = (lctx.logits_size + lctx.embd_size)*sizeof(float); + if (lctx.buf_output) { +#ifndef NDEBUG + const size_t prev_size = ggml_backend_buffer_get_size(lctx.buf_output); + fprintf(stderr, "%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, buf_output_size/ 1024.0 / 1024.0); +#endif ggml_backend_buffer_free(lctx.buf_output); lctx.buf_output = nullptr; lctx.logits = nullptr; lctx.embd = nullptr; } - { - lctx.output_size = n_outputs_max; - lctx.logits_size = has_logits ? n_vocab*n_outputs_max : 0; - lctx.embd_size = has_embd ? n_embd*n_outputs_max : 0; - const size_t buf_output_size = (lctx.logits_size + lctx.embd_size)*sizeof(float); - - lctx.buf_output = ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), buf_output_size); - if (lctx.buf_output == nullptr) { - throw std::runtime_error(format("failed to allocate output buffer of size %.2f MiB", buf_output_size / (1024.0 * 1024.0))); - } - - float * output_base = (float *) ggml_backend_buffer_get_base(lctx.buf_output); - - lctx.logits = has_logits ? 
output_base : nullptr; - lctx.embd = has_embd ? output_base + lctx.logits_size : nullptr; + lctx.buf_output = ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), buf_output_size); + if (lctx.buf_output == nullptr) { + throw std::runtime_error(format("failed to allocate output buffer of size %.2f MiB", buf_output_size / (1024.0 * 1024.0))); } + + float * output_base = (float *) ggml_backend_buffer_get_base(lctx.buf_output); + + lctx.logits = has_logits ? output_base : nullptr; + lctx.embd = has_embd ? output_base + lctx.logits_size : nullptr; } // set all ids as invalid (assume two's complement negative numbers) memset(lctx.output_ids, -1, n_batch*sizeof(int32_t)); @@ -14038,27 +14044,32 @@ void llama_kv_cache_update(struct llama_context * ctx) { // Returns the *maximum* size of the state size_t llama_get_state_size(const struct llama_context * ctx) { + const auto & cparams = ctx->cparams; + const auto & hparams = ctx->model.hparams; // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state. // for reference, std::mt19937(1337) serializes to 6701 bytes. const size_t s_rng_size = sizeof(size_t); const size_t s_rng = LLAMA_MAX_RNG_STATE; + const size_t s_n_outputs = sizeof(size_t); + // assume worst case for outputs although only currently set ones are serialized + const size_t s_output_pos = ctx->cparams.n_batch * sizeof(int32_t); const size_t s_logits_size = sizeof(size_t); - // assume worst case for logits although only currently set ones are serialized - const size_t s_logits = ctx->logits_size * sizeof(float); + const size_t s_logits = ctx->logits_size ? cparams.n_batch * hparams.n_vocab * sizeof(float) : 0; const size_t s_embedding_size = sizeof(size_t); - const size_t s_embedding = ctx->embd_size * sizeof(float); + const size_t s_embedding = ctx->embd_size ? 
cparams.n_batch * hparams.n_embd * sizeof(float) : 0;
     const size_t s_kv_buf_size    = sizeof(size_t);
     const size_t s_kv_head        = sizeof(uint32_t);
     const size_t s_kv_size        = sizeof(uint32_t);
     const size_t s_kv_used        = sizeof(uint32_t);
     const size_t s_kv             = ctx->kv_self.total_size();
-    // TODO: assume the max is more than 1 seq_id per KV cell
-    const size_t s_kv_cell        = sizeof(llama_pos) + sizeof(size_t) + sizeof(llama_seq_id);
+    const size_t s_kv_cell        = sizeof(llama_pos) + sizeof(size_t) + cparams.n_seq_max*sizeof(llama_seq_id);
     const size_t s_kv_cells       = ctx->kv_self.size * s_kv_cell;
     const size_t s_total = (
         + s_rng_size
         + s_rng
+        + s_n_outputs
+        + s_output_pos
         + s_logits_size
         + s_logits
         + s_embedding_size
@@ -14142,25 +14153,60 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
         data_ctx->write(rng_str.data(), rng_size);
     }
 
-    // copy logits
+    // copy outputs
     {
-        const size_t logits_size = ctx->logits_size;
+        size_t n_outputs = ctx->n_outputs;
 
-        data_ctx->write(&logits_size, sizeof(logits_size));
+        // copy output ids
+        {
+            std::vector<int32_t> output_pos;
+            const size_t    n_batch    = ctx->cparams.n_batch;
+            const int32_t * output_ids = ctx->output_ids;
 
-        if (logits_size) {
-            data_ctx->write(ctx->logits, logits_size * sizeof(float));
+            output_pos.resize(n_outputs);
+
+            // build a more compact representation of the output ids
+            for (size_t i = 0; i < n_batch; ++i) {
+                // map an output id to a position in the batch
+                int32_t pos = output_ids[i];
+                if (pos >= 0) {
+                    if ((size_t) pos >= output_pos.size()) {
+                        // TODO: maybe fail here instead
+                        LLAMA_LOG_WARN("%s: weird output buffer layout, possibly a bug\n", __func__);
+                        n_outputs = pos + 1;
+                        output_pos.resize(n_outputs);
+                    }
+                    output_pos[pos] = i;
+                }
+            }
+
+            data_ctx->write(&n_outputs, sizeof(n_outputs));
+
+            if (n_outputs) {
+                data_ctx->write(output_pos.data(), n_outputs * sizeof(int32_t));
+            }
         }
-    }
 
-    // copy embeddings
-    {
-        const size_t embeddings_size = ctx->embd_size;
+        // copy logits
+        {
+            const size_t logits_size = std::min(ctx->logits_size, n_outputs * ctx->model.hparams.n_vocab);
+
+            data_ctx->write(&logits_size, sizeof(logits_size));
 
-        data_ctx->write(&embeddings_size, sizeof(embeddings_size));
+            if (logits_size) {
+                data_ctx->write(ctx->logits, logits_size * sizeof(float));
+            }
+        }
 
-        if (embeddings_size) {
-            data_ctx->write(ctx->embd, embeddings_size * sizeof(float));
+        // copy embeddings
+        {
+            const size_t embeddings_size = std::min(ctx->embd_size, n_outputs * ctx->model.hparams.n_embd);
+
+            data_ctx->write(&embeddings_size, sizeof(embeddings_size));
+
+            if (embeddings_size) {
+                data_ctx->write(ctx->embd, embeddings_size * sizeof(float));
+            }
         }
     }
 
@@ -14257,6 +14303,28 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
         GGML_ASSERT(!rng_ss.fail());
     }
 
+    // set output ids
+    {
+        size_t n_outputs;
+        std::vector<int32_t> output_pos;
+
+        memcpy(&n_outputs, inp, sizeof(n_outputs)); inp += sizeof(n_outputs);
+
+        llama_output_reserve(*ctx, n_outputs);
+
+        if (n_outputs) {
+            output_pos.resize(n_outputs);
+            memcpy(output_pos.data(), inp, n_outputs * sizeof(int32_t));
+            inp += n_outputs * sizeof(int32_t);
+
+            for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
+                int32_t id = output_pos[i];
+                GGML_ASSERT((uint32_t) id < ctx->cparams.n_batch);
+                ctx->output_ids[id] = i;
+            }
+        }
+    }
+
     // set logits
     {
         size_t logits_size;
@@ -14277,7 +14345,7 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
         memcpy(&embeddings_size, inp, sizeof(embeddings_size)); inp += sizeof(embeddings_size);
 
-        GGML_ASSERT(ctx->embd_size == embeddings_size);
+        GGML_ASSERT(ctx->embd_size >= embeddings_size);
 
         if (embeddings_size) {
             memcpy(ctx->embd, inp, embeddings_size * sizeof(float));
@@ -14562,7 +14630,6 @@ void llama_synchronize(struct llama_context * ctx) {
 }
 
 float * llama_get_logits(struct llama_context * ctx) {
-    // TODO: assert that really all logits are in the output
     llama_synchronize(ctx);
 
     return ctx->logits;
@@ -14570,12 +14637,17 @@ float * llama_get_logits(struct llama_context * ctx) {
 
 float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
     const int32_t j = ctx->output_ids[i];
-    GGML_ASSERT(0 <= j);
 
     llama_synchronize(ctx);
 
-    // FIXME: check for nullptr
-    return ctx->logits + j*ctx->model.hparams.n_vocab;
+    if (ctx->logits && 0 <= j && j < ctx->n_outputs) {
+        return ctx->logits + j*ctx->model.hparams.n_vocab;
+    }
+    LLAMA_LOG_ERROR("%s: invalid logits id %i\n", __func__, i);
+#ifndef NDEBUG
+    GGML_ASSERT(false);
+#endif
+    return nullptr;
 }
 
 float * llama_get_embeddings(struct llama_context * ctx) {
@@ -14586,12 +14658,17 @@ float * llama_get_embeddings(struct llama_context * ctx) {
 
 float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
     const int32_t j = ctx->output_ids[i];
-    GGML_ASSERT(0 <= j);
 
     llama_synchronize(ctx);
 
-    // FIXME: check for nullptr
-    return ctx->embd + j*ctx->model.hparams.n_embd;
+    if (ctx->embd && 0 <= j && j < ctx->n_outputs) {
+        return ctx->embd + j*ctx->model.hparams.n_embd;
+    }
+    LLAMA_LOG_ERROR("%s: invalid embeddings id %i\n", __func__, i);
+#ifndef NDEBUG
+    GGML_ASSERT(false);
+#endif
+    return nullptr;
 }
 
 float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {
diff --git a/llama.h b/llama.h
index 5eaf07e8a..2e7de69b7 100644
--- a/llama.h
+++ b/llama.h
@@ -39,7 +39,7 @@
 #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
 
 #define LLAMA_SESSION_MAGIC   LLAMA_FILE_MAGIC_GGSN
-#define LLAMA_SESSION_VERSION 4
+#define LLAMA_SESSION_VERSION 5
 
 #ifdef __cplusplus
 extern "C" {
@@ -674,25 +674,29 @@ extern "C" {
     LLAMA_API void llama_synchronize(struct llama_context * ctx);
 
     // Token logits obtained from the last call to llama_decode()
-    // WARNING: the following layout is only valid when the batch outputs logits for all tokens
-    // The logits for the last token are stored in the last row
-    // Logits for which llama_batch.logits[i] == 0 are undefined
-    // Rows: n_tokens provided with llama_batch
+    // The logits for which llama_batch.logits[i] != 0 are stored contiguously
+    // in the order they appear in the batch.
+    // Rows: number of tokens for which llama_batch.logits[i] != 0
     // Cols: n_vocab
    LLAMA_API float * llama_get_logits(struct llama_context * ctx);
 
     // Logits for the ith token. Equivalent to:
     // llama_get_logits(ctx) + ctx->output_ids[i]*n_vocab
+    // returns NULL for invalid ids.
     LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
 
-    // Get all output token embeddings
-    // WARNING: only use when all outputs are requested
-    // shape: [n_tokens*n_embd] (1-dimensional)
+    // Get all output token embeddings.
+    // When pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model,
+    // the embeddings for which llama_batch.logits[i] != 0 are stored contiguously
+    // in the order they appear in the batch.
+    // shape: [n_outputs*n_embd]
+    // Otherwise, returns NULL.
     LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
 
-    // Get the embeddings for the ith token
-    // llama_get_embeddings(ctx) + i*n_embd
+    // Get the embeddings for the ith token. 
Equivalent to: + // llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd // shape: [n_embd] (1-dimensional) + // returns NULL for invalid ids. LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); // Get the embeddings for a sequence id
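
With this change, llama_get_logits() no longer returns one row per batch token: only positions with llama_batch.logits[i] != 0 get a row, and llama_get_logits_ith() returns NULL for positions that produced no output. Below is a minimal caller-side sketch of the new semantics; it uses only the public API from llama.h, and the ctx, model and batch variables are assumed to already exist (they are not part of the patch).

    // sketch: read the compactly-stored logits after a decode call (requires llama.h)
    if (llama_decode(ctx, batch) == 0) {
        const int32_t n_vocab = llama_n_vocab(model);

        for (int32_t i = 0; i < batch.n_tokens; ++i) {
            if (batch.logits && !batch.logits[i]) {
                continue; // no output was requested for this position
            }
            const float * row = llama_get_logits_ith(ctx, i); // NULL for invalid ids
            if (row == NULL) {
                continue; // would also assert in debug builds
            }
            // row points at n_vocab floats for the token at batch position i
            (void) n_vocab;
        }
    }

The same pattern applies to llama_get_embeddings_ith() when pooling_type == LLAMA_POOLING_TYPE_NONE.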
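The serialized state now stores output_pos, the inverse of output_ids, so only the outputs that were actually computed are written. For example, a 4-token batch with outputs requested at positions 1 and 3 has output_ids = {-1, 0, -1, 1} and serializes output_pos = {1, 3}. The top-level state API is unchanged, so a save/restore round trip still looks like the following sketch, assuming an existing ctx (not part of the patch).

    // sketch: state round trip with the session v5 layout (requires <vector> and llama.h)
    std::vector<uint8_t> state(llama_get_state_size(ctx)); // worst-case upper bound
    const size_t n_written = llama_copy_state_data(ctx, state.data());
    state.resize(n_written); // typically smaller now that outputs are stored compactly

    // ... later, e.g. in a context created with the same model and parameters:
    llama_set_state_data(ctx, state.data());

Sessions written with the previous layout are not compatible, hence the LLAMA_SESSION_VERSION bump to 5.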