From 6c798db0413d8701e9986cfde1c76ec888f434a9 Mon Sep 17 00:00:00 2001
From: l3utterfly
Date: Wed, 2 Aug 2023 16:41:25 +0800
Subject: [PATCH] added stream saving context data to file to avoid allocating
 unnecessary amounts of memory

---
 llama.cpp | 109 +++++++++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 100 insertions(+), 9 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index d427054dd..543d85e2c 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -3841,6 +3841,104 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
     return written;
 }
 
+// writes state data directly to file instead of copying it into a buffer
+void llama_write_state_data_to_file(struct llama_context * ctx, llama_file * dst_file) {
+    // copy rng
+    {
+        std::stringstream rng_ss;
+        rng_ss << ctx->rng;
+
+        const size_t rng_size = rng_ss.str().size();
+        char rng_buf[LLAMA_MAX_RNG_STATE];
+
+        memset(&rng_buf[0], 0, LLAMA_MAX_RNG_STATE);
+        memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size());
+
+        dst_file->write_raw(&rng_size, sizeof(rng_size));
+        dst_file->write_raw(&rng_buf[0], LLAMA_MAX_RNG_STATE);
+    }
+
+    // copy logits
+    {
+        const size_t logits_cap  = ctx->logits.capacity();
+        const size_t logits_size = ctx->logits.size();
+
+        dst_file->write_raw(&logits_cap,  sizeof(logits_cap));
+        dst_file->write_raw(&logits_size, sizeof(logits_size));
+
+        if (logits_size) {
+            dst_file->write_raw(ctx->logits.data(), logits_size * sizeof(float));
+        }
+
+        // If there is a gap between the size and the capacity, write padding
+        size_t padding_size = (logits_cap - logits_size) * sizeof(float);
+        if (padding_size > 0) {
+            std::vector<uint8_t> padding(padding_size, 0); // Create a buffer filled with zeros
+            dst_file->write_raw(padding.data(), padding_size);
+        }
+    }
+
+    // copy embeddings
+    {
+        const size_t embedding_size = ctx->embedding.size();
+
+        dst_file->write_raw(&embedding_size, sizeof(embedding_size));
+
+        if (embedding_size) {
+            dst_file->write_raw(ctx->embedding.data(), embedding_size * sizeof(float));
+        }
+    }
+
+    // copy kv cache
+    {
+        const auto & kv_self = ctx->kv_self;
+        const auto & hparams = ctx->model.hparams;
+        const int    n_layer = hparams.n_layer;
+        const int    n_embd  = hparams.n_embd_gqa();
+        const int    n_ctx   = hparams.n_ctx;
+
+        const size_t kv_size = kv_self.buf.size;
+        const int    kv_ntok = llama_get_kv_cache_token_count(ctx);
+
+        dst_file->write_raw(&kv_size, sizeof(kv_size));
+        dst_file->write_raw(&kv_ntok, sizeof(kv_ntok));
+
+        if (kv_size) {
+            const size_t elt_size = ggml_element_size(kv_self.k);
+
+            ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true });
+            ggml_cgraph gf{};
+
+            ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
+            std::vector<uint8_t> kout3d_data(ggml_nbytes(kout3d), 0);
+            kout3d->data = kout3d_data.data();
+
+            ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer);
+            std::vector<uint8_t> vout3d_data(ggml_nbytes(vout3d), 0);
+            vout3d->data = vout3d_data.data();
+
+            ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
+                n_embd, kv_ntok, n_layer,
+                elt_size*n_embd, elt_size*n_embd*n_ctx, 0);
+
+            ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
+                kv_ntok, n_embd, n_layer,
+                elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
+
+            ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d));
+            ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d));
+            ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1);
+
+            ggml_free(cpy_ctx);
+
+            // our data is now in the kout3d_data and vout3d_data buffers
+            // write them to file
+            dst_file->write_raw(kout3d_data.data(), kout3d_data.size());
+            dst_file->write_raw(vout3d_data.data(), vout3d_data.size());
+        }
+    }
+}
+
 // Sets the state reading from the specified source address
 size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
     uint8_t * inp = src;
@@ -4023,15 +4121,8 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi
     file.write_u32((uint32_t) n_token_count);
     file.write_raw(tokens, sizeof(llama_token) * n_token_count);
 
-    // save the context state
-    {
-        const size_t n_state_size_max = llama_get_state_size(ctx);
-
-        std::vector<uint8_t> state_data(n_state_size_max);
-        const size_t n_state_size_cur = llama_copy_state_data(ctx, state_data.data());
-
-        file.write_raw(state_data.data(), n_state_size_cur);
-    }
+    // save the context state using stream saving
+    llama_write_state_data_to_file(ctx, &file);
 
     return true;
 }
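
For reference, a minimal caller-side sketch of the path this change optimizes, assuming the public llama.h API of this llama.cpp revision; save_kv_session is an illustrative helper name, not part of the patch. Before the patch, llama_save_session_file staged the whole state in a heap buffer of llama_get_state_size(ctx) bytes; after it, the state is streamed to disk section by section.

// illustrative helper (not in the patch); assumes `ctx` has already evaluated `tokens`
#include <vector>
#include "llama.h"

static bool save_kv_session(llama_context * ctx,
                            const char * path,
                            const std::vector<llama_token> & tokens) {
    // writes the token history, then the context state; with this patch the
    // state (rng / logits / embeddings / kv cache) is streamed to `path`
    // instead of being copied into a llama_get_state_size()-sized buffer first
    return llama_save_session_file(ctx, path, tokens.data(), tokens.size());
}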