From 7c5b2b57b0c6ecd9a5dbb078385d2dc430a7579f Mon Sep 17 00:00:00 2001
From: l3utterfly
Date: Thu, 3 Aug 2023 20:47:35 +0800
Subject: [PATCH] - restored breakage of the llama_copy_state_data API - moved new logic for copying llama state data to internal function

---
 examples/save-load-state/save-load-state.cpp |  6 +++++-
 llama-util.h                                 | 17 +++++++++++++++++
 llama.cpp                                    | 16 ++++++++++++++--
 3 files changed, 36 insertions(+), 3 deletions(-)

diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp
index 5376c13c7..b1856e592 100644
--- a/examples/save-load-state/save-load-state.cpp
+++ b/examples/save-load-state/save-load-state.cpp
@@ -66,7 +66,11 @@ int main(int argc, char ** argv) {
     uint8_t * state_mem = new uint8_t[state_size];
 
     // Save state (rng, logits, embedding and kv_cache) to file
-    llama_save_session_file(ctx, "dump_state.bin", tokens.data(), tokens.size());
+    {
+        llama_file file("dump_state.bin", "wb");
+        const size_t n_written = llama_copy_state_data(ctx, state_mem); // public API fills a memory buffer
+        file.write_raw(state_mem, n_written);
+    }
 
     // save state (last tokens)
     const auto last_n_tokens_data_saved = std::vector(last_n_tokens_data);
diff --git a/llama-util.h b/llama-util.h
index da1cdf8d5..e58d60482 100644
--- a/llama-util.h
+++ b/llama-util.h
@@ -152,23 +152,40 @@ struct llama_file {
 // llama_context_data
 struct llama_data_context {
     virtual void write(const void* src, size_t size) = 0;
+    virtual size_t get_size_written() = 0;
     virtual ~llama_data_context() = default;
 };
 
 struct llama_data_buffer_context : llama_data_context {
     uint8_t* ptr;
+    size_t size_written = 0;
+
     llama_data_buffer_context(uint8_t* p) : ptr(p) {}
+
     void write(const void* src, size_t size) override {
         memcpy(ptr, src, size);
         ptr += size;
+        size_written += size;
+    }
+
+    size_t get_size_written() override {
+        return size_written;
     }
 };
 
 struct llama_data_file_context : llama_data_context {
     llama_file* file;
+    size_t size_written = 0;
+
     llama_data_file_context(llama_file* f) : file(f) {}
+
     void write(const void* src, size_t size) override {
         file->write_raw(src, size);
+        size_written += size;
+    }
+
+    size_t get_size_written() override {
+        return size_written;
     }
 };
 
diff --git a/llama.cpp b/llama.cpp
index e88ac9342..5f834691f 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -3743,6 +3743,18 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
     return s_total;
 }
 
+// forward declaration: the streaming implementation is defined further below
+void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx);
+
+// copies the context state into dst and returns the number of bytes written
+size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst)
+{
+    llama_data_buffer_context data_ctx(dst);
+    llama_copy_state_data_internal(ctx, &data_ctx);
+
+    return data_ctx.get_size_written();
+}
+
 /** copy state data into either a buffer or file depending on the passed in context
  *
  * file context:
@@ -3756,7 +3768,7 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
  * llama_copy_state_data(ctx, &data_ctx);
  *
 */
-void llama_copy_state_data(struct llama_context * ctx, llama_data_context * data_ctx) {
+void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) {
     // copy rng
     {
         std::stringstream rng_ss;
@@ -4037,7 +4049,7 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi
 
     // save the context state using stream saving
     llama_data_file_context data_ctx(&file);
-    llama_copy_state_data(ctx, &data_ctx);
+    llama_copy_state_data_internal(ctx, &data_ctx);
 
     return true;
 }