From 3cebd6e4b70c7b976e1c2e7982972d9b2ee837ff Mon Sep 17 00:00:00 2001
From: l3utterfly
Date: Thu, 3 Aug 2023 17:02:29 +0800
Subject: [PATCH] generalised copying state data to file or buffer

---
 llama-util.h |  23 +++++++++
 llama.cpp    | 129 +++++++--------------------------------------------
 2 files changed, 39 insertions(+), 113 deletions(-)

diff --git a/llama-util.h b/llama-util.h
index 042ebe43c..da1cdf8d5 100644
--- a/llama-util.h
+++ b/llama-util.h
@@ -149,6 +149,29 @@ struct llama_file {
     }
 };
 
+// llama_data_context: interface for writing state data to a buffer or file
+struct llama_data_context {
+    virtual void write(const void* src, size_t size) = 0;
+    virtual ~llama_data_context() = default;
+};
+
+struct llama_data_buffer_context : llama_data_context {
+    uint8_t* ptr;
+    llama_data_buffer_context(uint8_t* p) : ptr(p) {}
+    void write(const void* src, size_t size) override {
+        memcpy(ptr, src, size);
+        ptr += size;
+    }
+};
+
+struct llama_data_file_context : llama_data_context {
+    llama_file* file;
+    llama_data_file_context(llama_file* f) : file(f) {}
+    void write(const void* src, size_t size) override {
+        file->write_raw(src, size);
+    }
+};
+
 #if defined(_WIN32)
 static std::string llama_format_win_err(DWORD err) {
     LPSTR buf;
diff --git a/llama.cpp b/llama.cpp
index 543d85e2c..83baa941b 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -3743,10 +3743,8 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
     return s_total;
 }
 
-// Copies the state to the specified destination address
-size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
-    uint8_t * out = dst;
-
+// copy state data into either a buffer or file depending on the passed in context
+void llama_copy_state_data(struct llama_context * ctx, llama_data_context * data_ctx) {
     // copy rng
     {
         std::stringstream rng_ss;
@@ -3758,8 +3756,8 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
         memset(&rng_buf[0], 0, LLAMA_MAX_RNG_STATE);
         memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size());
 
-        memcpy(out, &rng_size, sizeof(rng_size)); out += sizeof(rng_size);
-        memcpy(out, &rng_buf[0], LLAMA_MAX_RNG_STATE); out += LLAMA_MAX_RNG_STATE;
+        data_ctx->write(&rng_size, sizeof(rng_size));
+        data_ctx->write(&rng_buf[0], LLAMA_MAX_RNG_STATE);
     }
 
     // copy logits
@@ -3767,114 +3765,18 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
         const size_t logits_cap = ctx->logits.capacity();
         const size_t logits_size = ctx->logits.size();
 
-        memcpy(out, &logits_cap, sizeof(logits_cap)); out += sizeof(logits_cap);
-        memcpy(out, &logits_size, sizeof(logits_size)); out += sizeof(logits_size);
+        data_ctx->write(&logits_cap, sizeof(logits_cap));
+        data_ctx->write(&logits_size, sizeof(logits_size));
 
         if (logits_size) {
-            memcpy(out, ctx->logits.data(), logits_size * sizeof(float));
-        }
-
-        out += logits_cap * sizeof(float);
-    }
-
-    // copy embeddings
-    {
-        const size_t embedding_size = ctx->embedding.size();
-
-        memcpy(out, &embedding_size, sizeof(embedding_size)); out += sizeof(embedding_size);
-
-        if (embedding_size) {
-            memcpy(out, ctx->embedding.data(), embedding_size * sizeof(float));
-            out += embedding_size * sizeof(float);
-        }
-    }
-
-    // copy kv cache
-    {
-        const auto & kv_self = ctx->kv_self;
-        const auto & hparams = ctx->model.hparams;
-        const int n_layer = hparams.n_layer;
-        const int n_embd = hparams.n_embd_gqa();
-        const int n_ctx = hparams.n_ctx;
-
-        const size_t kv_size = kv_self.buf.size;
-        const int kv_ntok = llama_get_kv_cache_token_count(ctx);
-
-        memcpy(out, &kv_size, sizeof(kv_size)); out += sizeof(kv_size);
-        memcpy(out, &kv_ntok, sizeof(kv_ntok)); out += sizeof(kv_ntok);
-
-        if (kv_size) {
-            const size_t elt_size = ggml_element_size(kv_self.k);
-
-            ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true });
-            ggml_cgraph gf{};
-
-            ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
-            kout3d->data = out;
-            out += ggml_nbytes(kout3d);
-
-            ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer);
-            vout3d->data = out;
-            out += ggml_nbytes(vout3d);
-
-            ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
-                n_embd, kv_ntok, n_layer,
-                elt_size*n_embd, elt_size*n_embd*n_ctx, 0);
-
-            ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
-                kv_ntok, n_embd, n_layer,
-                elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
-
-            ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d));
-            ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d));
-            ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1);
-
-            ggml_free(cpy_ctx);
-        }
-    }
-
-    const size_t written = out - dst;
-    const size_t max_size = llama_get_state_size(ctx);
-
-    LLAMA_ASSERT(written <= max_size);
-
-    return written;
-}
-
-// writes state data directly to file instead of copying it into a buffer
-void llama_write_state_data_to_file(struct llama_context * ctx, llama_file * dst_file) {
-    // copy rng
-    {
-        std::stringstream rng_ss;
-        rng_ss << ctx->rng;
-
-        const size_t rng_size = rng_ss.str().size();
-        char rng_buf[LLAMA_MAX_RNG_STATE];
-
-        memset(&rng_buf[0], 0, LLAMA_MAX_RNG_STATE);
-        memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size());
-
-        dst_file->write_raw(&rng_size, sizeof(rng_size));
-        dst_file->write_raw(&rng_buf[0], LLAMA_MAX_RNG_STATE);
-    }
-
-    // copy logits
-    {
-        const size_t logits_cap = ctx->logits.capacity();
-        const size_t logits_size = ctx->logits.size();
-
-        dst_file->write_raw(&logits_cap, sizeof(logits_cap));
-        dst_file->write_raw(&logits_size, sizeof(logits_size));
-
-        if (logits_size) {
-            dst_file->write_raw(ctx->logits.data(), logits_size * sizeof(float));
+            data_ctx->write(ctx->logits.data(), logits_size * sizeof(float));
         }
 
         // If there is a gap between the size and the capacity, write padding
         size_t padding_size = (logits_cap - logits_size) * sizeof(float);
         if (padding_size > 0) {
             std::vector<uint8_t> padding(padding_size, 0); // Create a buffer filled with zeros
-            dst_file->write_raw(padding.data(), padding_size);
+            data_ctx->write(padding.data(), padding_size);
         }
     }
 
@@ -3882,10 +3784,10 @@ void llama_write_state_data_to_file(struct llama_context * ctx, llama_file * dst
     {
         const size_t embedding_size = ctx->embedding.size();
 
-        dst_file->write_raw(&embedding_size, sizeof(embedding_size));
+        data_ctx->write(&embedding_size, sizeof(embedding_size));
 
         if (embedding_size) {
-            dst_file->write_raw(ctx->embedding.data(), embedding_size * sizeof(float));
+            data_ctx->write(ctx->embedding.data(), embedding_size * sizeof(float));
         }
     }
 
@@ -3900,8 +3802,8 @@ void llama_write_state_data_to_file(struct llama_context * ctx, llama_file * dst
         const size_t kv_size = kv_self.buf.size;
         const int kv_ntok = llama_get_kv_cache_token_count(ctx);
 
-        dst_file->write_raw(&kv_size, sizeof(kv_size));
-        dst_file->write_raw(&kv_ntok, sizeof(kv_ntok));
+        data_ctx->write(&kv_size, sizeof(kv_size));
+        data_ctx->write(&kv_ntok, sizeof(kv_ntok));
 
         if (kv_size) {
             const size_t elt_size = ggml_element_size(kv_self.k);
@@ -3933,8 +3835,8 @@ void llama_write_state_data_to_file(struct llama_context * ctx, llama_file * dst
 
             // our data is now in the kout3d_data and vout3d_data buffers
             // write them to file
-            dst_file->write_raw(kout3d_data.data(), kout3d_data.size());
-            dst_file->write_raw(vout3d_data.data(), vout3d_data.size());
+            data_ctx->write(kout3d_data.data(), kout3d_data.size());
+            data_ctx->write(vout3d_data.data(), vout3d_data.size());
         }
     }
 }
@@ -4122,7 +4024,8 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi
     file.write_raw(tokens, sizeof(llama_token) * n_token_count);
 
     // save the context state using stream saving
-    llama_write_state_data_to_file(ctx, &file);
+    llama_data_file_context data_ctx(&file);
+    llama_copy_state_data(ctx, &data_ctx);
 
     return true;
 }
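
Note: with this change both save paths funnel through llama_copy_state_data(),
and the caller picks the destination by passing the matching llama_data_context
implementation. A minimal sketch of the two intended call patterns follows; it
is illustrative only, assumes the llama.h declaration of llama_copy_state_data
is updated to the new signature, and uses "state.bin" as a placeholder file
name:

    // buffer destination: size the buffer with llama_get_state_size() first,
    // since llama_data_buffer_context does no bounds checking of its own
    std::vector<uint8_t> state_buf(llama_get_state_size(ctx));
    llama_data_buffer_context buf_ctx(state_buf.data());
    llama_copy_state_data(ctx, &buf_ctx);

    // file destination: wrap an open llama_file in a llama_data_file_context,
    // exactly as llama_save_session_file() does above
    llama_file state_file("state.bin", "wb");
    llama_data_file_context file_ctx(&state_file);
    llama_copy_state_data(ctx, &file_ctx);

Design note: the removed LLAMA_ASSERT(written <= max_size) checked the buffer
bound after the fact; with the writer interface, the buffer caller takes on
that responsibility by allocating llama_get_state_size(ctx) bytes up front.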