llama : remove _context suffix for llama_data_context

This commit is contained in:
Francis Couture-Harpin 2024-07-26 19:06:37 -04:00
parent cddc899b85
commit c8b424fae5

View file

@ -17296,12 +17296,11 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi
return llama_state_save_file(ctx, path_session, tokens, n_token_count);
}
// llama_context_data
// TODO: replace all non-fatal assertions with returned errors or exceptions
struct llama_data_context {
struct llama_data_write {
virtual void write(const void * src, size_t size) = 0;
virtual size_t get_size_written() = 0;
virtual ~llama_data_context() = default;
virtual ~llama_data_write() = default;
void write_string(const std::string & str, uint32_t max_size) {
uint32_t str_size = str.size();
@ -17522,11 +17521,11 @@ struct llama_data_context {
}
};
struct llama_data_read_context {
struct llama_data_read {
virtual const uint8_t * read(size_t size) = 0;
virtual void read_to(void * dst, size_t size) = 0;
virtual size_t get_size_read() = 0;
virtual ~llama_data_read_context() = default;
virtual ~llama_data_read() = default;
void read_string(std::string & str, uint32_t max_size) {
uint32_t str_size;
@ -17809,10 +17808,10 @@ struct llama_data_read_context {
}
};
struct llama_data_dummy_context : llama_data_context {
struct llama_data_write_dummy : llama_data_write {
size_t size_written = 0;
llama_data_dummy_context() {}
llama_data_write_dummy() {}
// TODO: avoid unnecessary calls to ggml_backend_tensor_get in a dummy context
@ -17825,12 +17824,12 @@ struct llama_data_dummy_context : llama_data_context {
}
};
struct llama_data_buffer_context : llama_data_context {
struct llama_data_write_buffer : llama_data_write {
uint8_t * ptr;
size_t buf_size = 0;
size_t size_written = 0;
llama_data_buffer_context(uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
llama_data_write_buffer(uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
void write(const void * src, size_t size) override {
if (size > buf_size) {
@ -17847,12 +17846,12 @@ struct llama_data_buffer_context : llama_data_context {
}
};
struct llama_data_read_buffer_context : llama_data_read_context {
struct llama_data_read_buffer : llama_data_read {
const uint8_t * ptr;
size_t buf_size = 0;
size_t size_read = 0;
llama_data_read_buffer_context(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
llama_data_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
const uint8_t * read(size_t size) override {
const uint8_t * base_ptr = ptr;
@ -17874,11 +17873,11 @@ struct llama_data_read_buffer_context : llama_data_read_context {
}
};
struct llama_data_file_context : llama_data_context {
struct llama_data_write_file : llama_data_write {
llama_file * file;
size_t size_written = 0;
llama_data_file_context(llama_file * f) : file(f) {}
llama_data_write_file(llama_file * f) : file(f) {}
void write(const void * src, size_t size) override {
file->write_raw(src, size);
@ -17890,12 +17889,12 @@ struct llama_data_file_context : llama_data_context {
}
};
struct llama_data_read_file_context : llama_data_read_context {
struct llama_data_read_file : llama_data_read {
llama_file * file;
size_t size_read = 0;
std::vector<uint8_t> temp_buffer;
llama_data_read_file_context(llama_file * f) : file(f) {}
llama_data_read_file(llama_file * f) : file(f) {}
void read_to(void * dst, size_t size) override {
file->read_raw(dst, size);
@ -17917,16 +17916,16 @@ struct llama_data_read_file_context : llama_data_read_context {
*
* file context:
* llama_file file("/path", "wb");
* llama_data_file_context data_ctx(&file);
* llama_state_get_data(ctx, &data_ctx);
* llama_data_write_file data_ctx(&file);
* llama_state_get_data_internal(ctx, data_ctx);
*
* buffer context:
* std::vector<uint8_t> buf(max_size, 0);
* llama_data_buffer_context data_ctx(&buf.data());
* llama_state_get_data(ctx, &data_ctx);
* llama_data_write_buffer data_ctx(buf.data(), max_size);
* llama_state_get_data_internal(ctx, data_ctx);
*
*/
static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_data_context & data_ctx) {
static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx) {
llama_synchronize(ctx);
data_ctx.write_model_info(ctx);
@ -17944,18 +17943,18 @@ static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_da
}
size_t llama_state_get_data(struct llama_context * ctx, uint8_t * dst, size_t size) {
llama_data_buffer_context data_ctx(dst, size);
llama_data_write_buffer data_ctx(dst, size);
return llama_state_get_data_internal(ctx, data_ctx);
}
// Returns the *actual* size of the state.
// Intended to be used when saving to state to a buffer.
size_t llama_state_get_size(struct llama_context * ctx) {
llama_data_dummy_context data_ctx;
llama_data_write_dummy data_ctx;
return llama_state_get_data_internal(ctx, data_ctx);
}
static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_data_read_context & data_ctx) {
static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx) {
llama_synchronize(ctx);
data_ctx.read_model_info(ctx);
@ -17975,7 +17974,7 @@ static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_da
// Sets the state reading from the specified source address
size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src, size_t size) {
llama_data_read_buffer_context data_ctx(src, size);
llama_data_read_buffer data_ctx(src, size);
return llama_state_set_data_internal(ctx, data_ctx);
}
@ -18016,7 +18015,7 @@ static bool llama_state_load_file_internal(struct llama_context * ctx, const cha
return false;
}
llama_data_read_file_context data_ctx(&file);
llama_data_read_file data_ctx(&file);
const size_t n_read = llama_state_set_data_internal(ctx, data_ctx);
if (n_read != n_state_size_cur) {
@ -18047,7 +18046,7 @@ static bool llama_state_save_file_internal(struct llama_context * ctx, const cha
file.write_raw(tokens, sizeof(llama_token) * n_token_count);
// save the context state using stream saving
llama_data_file_context data_ctx(&file);
llama_data_write_file data_ctx(&file);
llama_state_get_data_internal(ctx, data_ctx);
return true;
@ -18062,7 +18061,7 @@ bool llama_state_save_file(struct llama_context * ctx, const char * path_session
}
}
static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_context & data_ctx, llama_seq_id seq_id) {
static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id) {
llama_synchronize(ctx);
data_ctx.write_kv_cache(ctx, seq_id);
@ -18071,12 +18070,12 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
}
size_t llama_state_seq_get_size(struct llama_context * ctx, llama_seq_id seq_id) {
llama_data_dummy_context data_ctx;
llama_data_write_dummy data_ctx;
return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
}
size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
llama_data_buffer_context data_ctx(dst, size);
llama_data_write_buffer data_ctx(dst, size);
try {
return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
} catch (const std::exception & err) {
@ -18085,7 +18084,7 @@ size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_
}
}
static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read_context & data_ctx, llama_seq_id dest_seq_id) {
static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id) {
llama_synchronize(ctx);
data_ctx.read_kv_cache(ctx, dest_seq_id);
@ -18094,7 +18093,7 @@ static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llam
}
size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id dest_seq_id) {
llama_data_read_buffer_context data_ctx(src, size);
llama_data_read_buffer data_ctx(src, size);
try {
return llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
} catch (const std::exception & err) {
@ -18114,7 +18113,7 @@ static size_t llama_state_seq_save_file_internal(struct llama_context * ctx, con
file.write_raw(tokens, sizeof(llama_token) * n_token_count);
// save the context state using stream saving
llama_data_file_context data_ctx(&file);
llama_data_write_file data_ctx(&file);
llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
const size_t res = file.tell();
@ -18152,7 +18151,7 @@ static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, con
// restore the context state
{
const size_t state_size = file.size - file.tell();
llama_data_read_file_context data_ctx(&file);
llama_data_read_file data_ctx(&file);
const size_t nread = llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
if (!nread) {
LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);