- restored breakage of the llama_copy_state_data API

- moved new logic for copying llama state data to internal function
This commit is contained in:
l3utterfly 2023-08-03 20:47:35 +08:00
parent 1ffa6be726
commit 7c5b2b57b0
3 changed files with 32 additions and 3 deletions

View file

@ -66,7 +66,11 @@ int main(int argc, char ** argv) {
uint8_t * state_mem = new uint8_t[state_size]; uint8_t * state_mem = new uint8_t[state_size];
// Save state (rng, logits, embedding and kv_cache) to file // Save state (rng, logits, embedding and kv_cache) to file
llama_save_session_file(ctx, "dump_state.bin", tokens.data(), tokens.size()); {
llama_file file("dump_state.bin", "wb");
llama_data_file_context data_ctx(&file);
llama_copy_state_data(ctx, &data_ctx); // can also copy to memory buffer
}
// save state (last tokens) // save state (last tokens)
const auto last_n_tokens_data_saved = std::vector<llama_token>(last_n_tokens_data); const auto last_n_tokens_data_saved = std::vector<llama_token>(last_n_tokens_data);

View file

@ -152,23 +152,40 @@ struct llama_file {
// llama_context_data // llama_context_data
struct llama_data_context { struct llama_data_context {
virtual void write(const void* src, size_t size) = 0; virtual void write(const void* src, size_t size) = 0;
virtual size_t get_size_written() = 0;
virtual ~llama_data_context() = default; virtual ~llama_data_context() = default;
}; };
struct llama_data_buffer_context : llama_data_context { struct llama_data_buffer_context : llama_data_context {
uint8_t* ptr; uint8_t* ptr;
size_t size_written = 0;
llama_data_buffer_context(uint8_t* p) : ptr(p) {} llama_data_buffer_context(uint8_t* p) : ptr(p) {}
void write(const void* src, size_t size) override { void write(const void* src, size_t size) override {
memcpy(ptr, src, size); memcpy(ptr, src, size);
ptr += size; ptr += size;
size_written += size;
}
size_t get_size_written() override {
return size_written;
} }
}; };
struct llama_data_file_context : llama_data_context { struct llama_data_file_context : llama_data_context {
llama_file* file; llama_file* file;
size_t size_written = 0;
llama_data_file_context(llama_file* f) : file(f) {} llama_data_file_context(llama_file* f) : file(f) {}
void write(const void* src, size_t size) override { void write(const void* src, size_t size) override {
file->write_raw(src, size); file->write_raw(src, size);
size_written += size;
}
size_t get_size_written() override {
return size_written;
} }
}; };

View file

@ -3743,6 +3743,14 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
return s_total; return s_total;
} }
size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst)
{
llama_data_buffer_context data_ctx(dst);
llama_copy_state_data_internal(ctx, &data_ctx);
return data_ctx.get_size_written();
}
/** copy state data into either a buffer or file depending on the passed in context /** copy state data into either a buffer or file depending on the passed in context
* *
* file context: * file context:
@ -3756,7 +3764,7 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
* llama_copy_state_data(ctx, &data_ctx); * llama_copy_state_data(ctx, &data_ctx);
* *
*/ */
void llama_copy_state_data(struct llama_context * ctx, llama_data_context * data_ctx) { void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) {
// copy rng // copy rng
{ {
std::stringstream rng_ss; std::stringstream rng_ss;
@ -4037,7 +4045,7 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi
// save the context state using stream saving // save the context state using stream saving
llama_data_file_context data_ctx(&file); llama_data_file_context data_ctx(&file);
llama_copy_state_data(ctx, &data_ctx); llama_copy_state_data_internal(ctx, &data_ctx);
return true; return true;
} }