From c8b424fae55e9a2cbfd406d344789ab15e5101e5 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Fri, 26 Jul 2024 19:06:37 -0400 Subject: [PATCH] llama : remove _context suffix for llama_data_context --- src/llama.cpp | 65 +++++++++++++++++++++++++-------------------------- 1 file changed, 32 insertions(+), 33 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index ed58da6ca..62ebfd590 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -17296,12 +17296,11 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi return llama_state_save_file(ctx, path_session, tokens, n_token_count); } -// llama_context_data // TODO: replace all non-fatal assertions with returned errors or exceptions -struct llama_data_context { +struct llama_data_write { virtual void write(const void * src, size_t size) = 0; virtual size_t get_size_written() = 0; - virtual ~llama_data_context() = default; + virtual ~llama_data_write() = default; void write_string(const std::string & str, uint32_t max_size) { uint32_t str_size = str.size(); @@ -17522,11 +17521,11 @@ struct llama_data_context { } }; -struct llama_data_read_context { +struct llama_data_read { virtual const uint8_t * read(size_t size) = 0; virtual void read_to(void * dst, size_t size) = 0; virtual size_t get_size_read() = 0; - virtual ~llama_data_read_context() = default; + virtual ~llama_data_read() = default; void read_string(std::string & str, uint32_t max_size) { uint32_t str_size; @@ -17809,10 +17808,10 @@ struct llama_data_read_context { } }; -struct llama_data_dummy_context : llama_data_context { +struct llama_data_write_dummy : llama_data_write { size_t size_written = 0; - llama_data_dummy_context() {} + llama_data_write_dummy() {} // TODO: avoid unnecessary calls to ggml_backend_tensor_get in a dummy context @@ -17825,12 +17824,12 @@ struct llama_data_dummy_context : llama_data_context { } }; -struct llama_data_buffer_context : llama_data_context { +struct llama_data_write_buffer : llama_data_write { uint8_t * ptr; size_t buf_size = 0; size_t size_written = 0; - llama_data_buffer_context(uint8_t * p, size_t len) : ptr(p), buf_size(len) {} + llama_data_write_buffer(uint8_t * p, size_t len) : ptr(p), buf_size(len) {} void write(const void * src, size_t size) override { if (size > buf_size) { @@ -17847,12 +17846,12 @@ struct llama_data_buffer_context : llama_data_context { } }; -struct llama_data_read_buffer_context : llama_data_read_context { +struct llama_data_read_buffer : llama_data_read { const uint8_t * ptr; size_t buf_size = 0; size_t size_read = 0; - llama_data_read_buffer_context(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {} + llama_data_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {} const uint8_t * read(size_t size) override { const uint8_t * base_ptr = ptr; @@ -17874,11 +17873,11 @@ struct llama_data_read_buffer_context : llama_data_read_context { } }; -struct llama_data_file_context : llama_data_context { +struct llama_data_write_file : llama_data_write { llama_file * file; size_t size_written = 0; - llama_data_file_context(llama_file * f) : file(f) {} + llama_data_write_file(llama_file * f) : file(f) {} void write(const void * src, size_t size) override { file->write_raw(src, size); @@ -17890,12 +17889,12 @@ struct llama_data_file_context : llama_data_context { } }; -struct llama_data_read_file_context : llama_data_read_context { +struct llama_data_read_file : llama_data_read { llama_file * file; size_t size_read = 0; std::vector temp_buffer; - llama_data_read_file_context(llama_file * f) : file(f) {} + llama_data_read_file(llama_file * f) : file(f) {} void read_to(void * dst, size_t size) override { file->read_raw(dst, size); @@ -17917,16 +17916,16 @@ struct llama_data_read_file_context : llama_data_read_context { * * file context: * llama_file file("/path", "wb"); - * llama_data_file_context data_ctx(&file); - * llama_state_get_data(ctx, &data_ctx); + * llama_data_write_file data_ctx(&file); + * llama_state_get_data_internal(ctx, data_ctx); * * buffer context: * std::vector buf(max_size, 0); - * llama_data_buffer_context data_ctx(&buf.data()); - * llama_state_get_data(ctx, &data_ctx); + * llama_data_write_buffer data_ctx(buf.data(), max_size); + * llama_state_get_data_internal(ctx, data_ctx); * */ -static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_data_context & data_ctx) { +static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx) { llama_synchronize(ctx); data_ctx.write_model_info(ctx); @@ -17944,18 +17943,18 @@ static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_da } size_t llama_state_get_data(struct llama_context * ctx, uint8_t * dst, size_t size) { - llama_data_buffer_context data_ctx(dst, size); + llama_data_write_buffer data_ctx(dst, size); return llama_state_get_data_internal(ctx, data_ctx); } // Returns the *actual* size of the state. // Intended to be used when saving to state to a buffer. size_t llama_state_get_size(struct llama_context * ctx) { - llama_data_dummy_context data_ctx; + llama_data_write_dummy data_ctx; return llama_state_get_data_internal(ctx, data_ctx); } -static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_data_read_context & data_ctx) { +static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx) { llama_synchronize(ctx); data_ctx.read_model_info(ctx); @@ -17975,7 +17974,7 @@ static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_da // Sets the state reading from the specified source address size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src, size_t size) { - llama_data_read_buffer_context data_ctx(src, size); + llama_data_read_buffer data_ctx(src, size); return llama_state_set_data_internal(ctx, data_ctx); } @@ -18016,7 +18015,7 @@ static bool llama_state_load_file_internal(struct llama_context * ctx, const cha return false; } - llama_data_read_file_context data_ctx(&file); + llama_data_read_file data_ctx(&file); const size_t n_read = llama_state_set_data_internal(ctx, data_ctx); if (n_read != n_state_size_cur) { @@ -18047,7 +18046,7 @@ static bool llama_state_save_file_internal(struct llama_context * ctx, const cha file.write_raw(tokens, sizeof(llama_token) * n_token_count); // save the context state using stream saving - llama_data_file_context data_ctx(&file); + llama_data_write_file data_ctx(&file); llama_state_get_data_internal(ctx, data_ctx); return true; @@ -18062,7 +18061,7 @@ bool llama_state_save_file(struct llama_context * ctx, const char * path_session } } -static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_context & data_ctx, llama_seq_id seq_id) { +static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id) { llama_synchronize(ctx); data_ctx.write_kv_cache(ctx, seq_id); @@ -18071,12 +18070,12 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam } size_t llama_state_seq_get_size(struct llama_context * ctx, llama_seq_id seq_id) { - llama_data_dummy_context data_ctx; + llama_data_write_dummy data_ctx; return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id); } size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) { - llama_data_buffer_context data_ctx(dst, size); + llama_data_write_buffer data_ctx(dst, size); try { return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id); } catch (const std::exception & err) { @@ -18085,7 +18084,7 @@ size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_ } } -static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read_context & data_ctx, llama_seq_id dest_seq_id) { +static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id) { llama_synchronize(ctx); data_ctx.read_kv_cache(ctx, dest_seq_id); @@ -18094,7 +18093,7 @@ static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llam } size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id dest_seq_id) { - llama_data_read_buffer_context data_ctx(src, size); + llama_data_read_buffer data_ctx(src, size); try { return llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id); } catch (const std::exception & err) { @@ -18114,7 +18113,7 @@ static size_t llama_state_seq_save_file_internal(struct llama_context * ctx, con file.write_raw(tokens, sizeof(llama_token) * n_token_count); // save the context state using stream saving - llama_data_file_context data_ctx(&file); + llama_data_write_file data_ctx(&file); llama_state_seq_get_data_internal(ctx, data_ctx, seq_id); const size_t res = file.tell(); @@ -18152,7 +18151,7 @@ static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, con // restore the context state { const size_t state_size = file.size - file.tell(); - llama_data_read_file_context data_ctx(&file); + llama_data_read_file data_ctx(&file); const size_t nread = llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id); if (!nread) { LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);