llama : remove _context suffix for llama_data_context
This commit is contained in:
parent
cddc899b85
commit
c8b424fae5
1 changed files with 32 additions and 33 deletions
|
@ -17296,12 +17296,11 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi
|
||||||
return llama_state_save_file(ctx, path_session, tokens, n_token_count);
|
return llama_state_save_file(ctx, path_session, tokens, n_token_count);
|
||||||
}
|
}
|
||||||
|
|
||||||
// llama_context_data
|
|
||||||
// TODO: replace all non-fatal assertions with returned errors or exceptions
|
// TODO: replace all non-fatal assertions with returned errors or exceptions
|
||||||
struct llama_data_context {
|
struct llama_data_write {
|
||||||
virtual void write(const void * src, size_t size) = 0;
|
virtual void write(const void * src, size_t size) = 0;
|
||||||
virtual size_t get_size_written() = 0;
|
virtual size_t get_size_written() = 0;
|
||||||
virtual ~llama_data_context() = default;
|
virtual ~llama_data_write() = default;
|
||||||
|
|
||||||
void write_string(const std::string & str, uint32_t max_size) {
|
void write_string(const std::string & str, uint32_t max_size) {
|
||||||
uint32_t str_size = str.size();
|
uint32_t str_size = str.size();
|
||||||
|
@ -17522,11 +17521,11 @@ struct llama_data_context {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_data_read_context {
|
struct llama_data_read {
|
||||||
virtual const uint8_t * read(size_t size) = 0;
|
virtual const uint8_t * read(size_t size) = 0;
|
||||||
virtual void read_to(void * dst, size_t size) = 0;
|
virtual void read_to(void * dst, size_t size) = 0;
|
||||||
virtual size_t get_size_read() = 0;
|
virtual size_t get_size_read() = 0;
|
||||||
virtual ~llama_data_read_context() = default;
|
virtual ~llama_data_read() = default;
|
||||||
|
|
||||||
void read_string(std::string & str, uint32_t max_size) {
|
void read_string(std::string & str, uint32_t max_size) {
|
||||||
uint32_t str_size;
|
uint32_t str_size;
|
||||||
|
@ -17809,10 +17808,10 @@ struct llama_data_read_context {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_data_dummy_context : llama_data_context {
|
struct llama_data_write_dummy : llama_data_write {
|
||||||
size_t size_written = 0;
|
size_t size_written = 0;
|
||||||
|
|
||||||
llama_data_dummy_context() {}
|
llama_data_write_dummy() {}
|
||||||
|
|
||||||
// TODO: avoid unnecessary calls to ggml_backend_tensor_get in a dummy context
|
// TODO: avoid unnecessary calls to ggml_backend_tensor_get in a dummy context
|
||||||
|
|
||||||
|
@ -17825,12 +17824,12 @@ struct llama_data_dummy_context : llama_data_context {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_data_buffer_context : llama_data_context {
|
struct llama_data_write_buffer : llama_data_write {
|
||||||
uint8_t * ptr;
|
uint8_t * ptr;
|
||||||
size_t buf_size = 0;
|
size_t buf_size = 0;
|
||||||
size_t size_written = 0;
|
size_t size_written = 0;
|
||||||
|
|
||||||
llama_data_buffer_context(uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
|
llama_data_write_buffer(uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
|
||||||
|
|
||||||
void write(const void * src, size_t size) override {
|
void write(const void * src, size_t size) override {
|
||||||
if (size > buf_size) {
|
if (size > buf_size) {
|
||||||
|
@ -17847,12 +17846,12 @@ struct llama_data_buffer_context : llama_data_context {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_data_read_buffer_context : llama_data_read_context {
|
struct llama_data_read_buffer : llama_data_read {
|
||||||
const uint8_t * ptr;
|
const uint8_t * ptr;
|
||||||
size_t buf_size = 0;
|
size_t buf_size = 0;
|
||||||
size_t size_read = 0;
|
size_t size_read = 0;
|
||||||
|
|
||||||
llama_data_read_buffer_context(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
|
llama_data_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
|
||||||
|
|
||||||
const uint8_t * read(size_t size) override {
|
const uint8_t * read(size_t size) override {
|
||||||
const uint8_t * base_ptr = ptr;
|
const uint8_t * base_ptr = ptr;
|
||||||
|
@ -17874,11 +17873,11 @@ struct llama_data_read_buffer_context : llama_data_read_context {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_data_file_context : llama_data_context {
|
struct llama_data_write_file : llama_data_write {
|
||||||
llama_file * file;
|
llama_file * file;
|
||||||
size_t size_written = 0;
|
size_t size_written = 0;
|
||||||
|
|
||||||
llama_data_file_context(llama_file * f) : file(f) {}
|
llama_data_write_file(llama_file * f) : file(f) {}
|
||||||
|
|
||||||
void write(const void * src, size_t size) override {
|
void write(const void * src, size_t size) override {
|
||||||
file->write_raw(src, size);
|
file->write_raw(src, size);
|
||||||
|
@ -17890,12 +17889,12 @@ struct llama_data_file_context : llama_data_context {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_data_read_file_context : llama_data_read_context {
|
struct llama_data_read_file : llama_data_read {
|
||||||
llama_file * file;
|
llama_file * file;
|
||||||
size_t size_read = 0;
|
size_t size_read = 0;
|
||||||
std::vector<uint8_t> temp_buffer;
|
std::vector<uint8_t> temp_buffer;
|
||||||
|
|
||||||
llama_data_read_file_context(llama_file * f) : file(f) {}
|
llama_data_read_file(llama_file * f) : file(f) {}
|
||||||
|
|
||||||
void read_to(void * dst, size_t size) override {
|
void read_to(void * dst, size_t size) override {
|
||||||
file->read_raw(dst, size);
|
file->read_raw(dst, size);
|
||||||
|
@ -17917,16 +17916,16 @@ struct llama_data_read_file_context : llama_data_read_context {
|
||||||
*
|
*
|
||||||
* file context:
|
* file context:
|
||||||
* llama_file file("/path", "wb");
|
* llama_file file("/path", "wb");
|
||||||
* llama_data_file_context data_ctx(&file);
|
* llama_data_write_file data_ctx(&file);
|
||||||
* llama_state_get_data(ctx, &data_ctx);
|
* llama_state_get_data_internal(ctx, data_ctx);
|
||||||
*
|
*
|
||||||
* buffer context:
|
* buffer context:
|
||||||
* std::vector<uint8_t> buf(max_size, 0);
|
* std::vector<uint8_t> buf(max_size, 0);
|
||||||
* llama_data_buffer_context data_ctx(&buf.data());
|
* llama_data_write_buffer data_ctx(buf.data(), max_size);
|
||||||
* llama_state_get_data(ctx, &data_ctx);
|
* llama_state_get_data_internal(ctx, data_ctx);
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_data_context & data_ctx) {
|
static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx) {
|
||||||
llama_synchronize(ctx);
|
llama_synchronize(ctx);
|
||||||
|
|
||||||
data_ctx.write_model_info(ctx);
|
data_ctx.write_model_info(ctx);
|
||||||
|
@ -17944,18 +17943,18 @@ static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_da
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t llama_state_get_data(struct llama_context * ctx, uint8_t * dst, size_t size) {
|
size_t llama_state_get_data(struct llama_context * ctx, uint8_t * dst, size_t size) {
|
||||||
llama_data_buffer_context data_ctx(dst, size);
|
llama_data_write_buffer data_ctx(dst, size);
|
||||||
return llama_state_get_data_internal(ctx, data_ctx);
|
return llama_state_get_data_internal(ctx, data_ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns the *actual* size of the state.
|
// Returns the *actual* size of the state.
|
||||||
// Intended to be used when saving to state to a buffer.
|
// Intended to be used when saving to state to a buffer.
|
||||||
size_t llama_state_get_size(struct llama_context * ctx) {
|
size_t llama_state_get_size(struct llama_context * ctx) {
|
||||||
llama_data_dummy_context data_ctx;
|
llama_data_write_dummy data_ctx;
|
||||||
return llama_state_get_data_internal(ctx, data_ctx);
|
return llama_state_get_data_internal(ctx, data_ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_data_read_context & data_ctx) {
|
static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx) {
|
||||||
llama_synchronize(ctx);
|
llama_synchronize(ctx);
|
||||||
|
|
||||||
data_ctx.read_model_info(ctx);
|
data_ctx.read_model_info(ctx);
|
||||||
|
@ -17975,7 +17974,7 @@ static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_da
|
||||||
|
|
||||||
// Sets the state reading from the specified source address
|
// Sets the state reading from the specified source address
|
||||||
size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src, size_t size) {
|
size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src, size_t size) {
|
||||||
llama_data_read_buffer_context data_ctx(src, size);
|
llama_data_read_buffer data_ctx(src, size);
|
||||||
return llama_state_set_data_internal(ctx, data_ctx);
|
return llama_state_set_data_internal(ctx, data_ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -18016,7 +18015,7 @@ static bool llama_state_load_file_internal(struct llama_context * ctx, const cha
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_data_read_file_context data_ctx(&file);
|
llama_data_read_file data_ctx(&file);
|
||||||
const size_t n_read = llama_state_set_data_internal(ctx, data_ctx);
|
const size_t n_read = llama_state_set_data_internal(ctx, data_ctx);
|
||||||
|
|
||||||
if (n_read != n_state_size_cur) {
|
if (n_read != n_state_size_cur) {
|
||||||
|
@ -18047,7 +18046,7 @@ static bool llama_state_save_file_internal(struct llama_context * ctx, const cha
|
||||||
file.write_raw(tokens, sizeof(llama_token) * n_token_count);
|
file.write_raw(tokens, sizeof(llama_token) * n_token_count);
|
||||||
|
|
||||||
// save the context state using stream saving
|
// save the context state using stream saving
|
||||||
llama_data_file_context data_ctx(&file);
|
llama_data_write_file data_ctx(&file);
|
||||||
llama_state_get_data_internal(ctx, data_ctx);
|
llama_state_get_data_internal(ctx, data_ctx);
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
|
@ -18062,7 +18061,7 @@ bool llama_state_save_file(struct llama_context * ctx, const char * path_session
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_context & data_ctx, llama_seq_id seq_id) {
|
static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id) {
|
||||||
llama_synchronize(ctx);
|
llama_synchronize(ctx);
|
||||||
|
|
||||||
data_ctx.write_kv_cache(ctx, seq_id);
|
data_ctx.write_kv_cache(ctx, seq_id);
|
||||||
|
@ -18071,12 +18070,12 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t llama_state_seq_get_size(struct llama_context * ctx, llama_seq_id seq_id) {
|
size_t llama_state_seq_get_size(struct llama_context * ctx, llama_seq_id seq_id) {
|
||||||
llama_data_dummy_context data_ctx;
|
llama_data_write_dummy data_ctx;
|
||||||
return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
|
return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
|
size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
|
||||||
llama_data_buffer_context data_ctx(dst, size);
|
llama_data_write_buffer data_ctx(dst, size);
|
||||||
try {
|
try {
|
||||||
return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
|
return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
|
||||||
} catch (const std::exception & err) {
|
} catch (const std::exception & err) {
|
||||||
|
@ -18085,7 +18084,7 @@ size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read_context & data_ctx, llama_seq_id dest_seq_id) {
|
static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id) {
|
||||||
llama_synchronize(ctx);
|
llama_synchronize(ctx);
|
||||||
|
|
||||||
data_ctx.read_kv_cache(ctx, dest_seq_id);
|
data_ctx.read_kv_cache(ctx, dest_seq_id);
|
||||||
|
@ -18094,7 +18093,7 @@ static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llam
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id dest_seq_id) {
|
size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id dest_seq_id) {
|
||||||
llama_data_read_buffer_context data_ctx(src, size);
|
llama_data_read_buffer data_ctx(src, size);
|
||||||
try {
|
try {
|
||||||
return llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
|
return llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
|
||||||
} catch (const std::exception & err) {
|
} catch (const std::exception & err) {
|
||||||
|
@ -18114,7 +18113,7 @@ static size_t llama_state_seq_save_file_internal(struct llama_context * ctx, con
|
||||||
file.write_raw(tokens, sizeof(llama_token) * n_token_count);
|
file.write_raw(tokens, sizeof(llama_token) * n_token_count);
|
||||||
|
|
||||||
// save the context state using stream saving
|
// save the context state using stream saving
|
||||||
llama_data_file_context data_ctx(&file);
|
llama_data_write_file data_ctx(&file);
|
||||||
llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
|
llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
|
||||||
|
|
||||||
const size_t res = file.tell();
|
const size_t res = file.tell();
|
||||||
|
@ -18152,7 +18151,7 @@ static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, con
|
||||||
// restore the context state
|
// restore the context state
|
||||||
{
|
{
|
||||||
const size_t state_size = file.size - file.tell();
|
const size_t state_size = file.size - file.tell();
|
||||||
llama_data_read_file_context data_ctx(&file);
|
llama_data_read_file data_ctx(&file);
|
||||||
const size_t nread = llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
|
const size_t nread = llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
|
||||||
if (!nread) {
|
if (!nread) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
|
LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue