From 5ee581473e8bf5e04141d5f0ab0798df81a72328 Mon Sep 17 00:00:00 2001 From: David Friehs Date: Mon, 8 Jan 2024 08:56:20 +0100 Subject: [PATCH] llama : only save and restore used logits for batch sizes of 512 this reduces save state in the best case by around 62 MB, which can be a lot if planning to save on each message to allow regenerating messages. --- llama.cpp | 22 +++++----------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/llama.cpp b/llama.cpp index 089533a60..ab75ec93e 100644 --- a/llama.cpp +++ b/llama.cpp @@ -10145,8 +10145,8 @@ size_t llama_get_state_size(const struct llama_context * ctx) { // for reference, std::mt19937(1337) serializes to 6701 bytes. const size_t s_rng_size = sizeof(size_t); const size_t s_rng = LLAMA_MAX_RNG_STATE; - const size_t s_logits_capacity = sizeof(size_t); const size_t s_logits_size = sizeof(size_t); + // assume worst case for logits although only currently set ones are serialized const size_t s_logits = ctx->logits.capacity() * sizeof(float); const size_t s_embedding_size = sizeof(size_t); const size_t s_embedding = ctx->embedding.size() * sizeof(float); @@ -10157,7 +10157,6 @@ size_t llama_get_state_size(const struct llama_context * ctx) { const size_t s_total = ( + s_rng_size + s_rng - + s_logits_capacity + s_logits_size + s_logits + s_embedding_size @@ -10241,22 +10240,13 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat // copy logits { - const size_t logits_cap = ctx->logits.capacity(); const size_t logits_size = ctx->logits.size(); - data_ctx->write(&logits_cap, sizeof(logits_cap)); data_ctx->write(&logits_size, sizeof(logits_size)); if (logits_size) { data_ctx->write(ctx->logits.data(), logits_size * sizeof(float)); } - - // If there is a gap between the size and the capacity, write padding - size_t padding_size = (logits_cap - logits_size) * sizeof(float); - if (padding_size > 0) { - std::vector padding(padding_size, 0); // Create a buffer filled with zeros - data_ctx->write(padding.data(), padding_size); - } } // copy embeddings @@ -10380,20 +10370,18 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { // set logits { - size_t logits_cap; size_t logits_size; - memcpy(&logits_cap, inp, sizeof(logits_cap)); inp += sizeof(logits_cap); memcpy(&logits_size, inp, sizeof(logits_size)); inp += sizeof(logits_size); - GGML_ASSERT(ctx->logits.capacity() == logits_cap); + GGML_ASSERT(ctx->logits.capacity() >= logits_size); if (logits_size) { ctx->logits.resize(logits_size); - memcpy(ctx->logits.data(), inp, logits_size * sizeof(float)); - } - inp += logits_cap * sizeof(float); + memcpy(ctx->logits.data(), inp, logits_size * sizeof(float)); + inp += logits_size * sizeof(float); + } } // set embeddings