- restored breakage of the llama_copy_state_data API
- moved new logic for copying llama state data to internal function
This commit is contained in:
parent
1ffa6be726
commit
7c5b2b57b0
3 changed files with 32 additions and 3 deletions
|
@ -66,7 +66,11 @@ int main(int argc, char ** argv) {
|
||||||
uint8_t * state_mem = new uint8_t[state_size];
|
uint8_t * state_mem = new uint8_t[state_size];
|
||||||
|
|
||||||
// Save state (rng, logits, embedding and kv_cache) to file
|
// Save state (rng, logits, embedding and kv_cache) to file
|
||||||
llama_save_session_file(ctx, "dump_state.bin", tokens.data(), tokens.size());
|
{
|
||||||
|
llama_file file("dump_state.bin", "wb");
|
||||||
|
llama_data_file_context data_ctx(&file);
|
||||||
|
llama_copy_state_data(ctx, &data_ctx); // can also copy to memory buffer
|
||||||
|
}
|
||||||
|
|
||||||
// save state (last tokens)
|
// save state (last tokens)
|
||||||
const auto last_n_tokens_data_saved = std::vector<llama_token>(last_n_tokens_data);
|
const auto last_n_tokens_data_saved = std::vector<llama_token>(last_n_tokens_data);
|
||||||
|
|
17
llama-util.h
17
llama-util.h
|
@ -152,23 +152,40 @@ struct llama_file {
|
||||||
// llama_context_data
|
// llama_context_data
|
||||||
struct llama_data_context {
|
struct llama_data_context {
|
||||||
virtual void write(const void* src, size_t size) = 0;
|
virtual void write(const void* src, size_t size) = 0;
|
||||||
|
virtual size_t get_size_written() = 0;
|
||||||
virtual ~llama_data_context() = default;
|
virtual ~llama_data_context() = default;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_data_buffer_context : llama_data_context {
|
struct llama_data_buffer_context : llama_data_context {
|
||||||
uint8_t* ptr;
|
uint8_t* ptr;
|
||||||
|
size_t size_written = 0;
|
||||||
|
|
||||||
llama_data_buffer_context(uint8_t* p) : ptr(p) {}
|
llama_data_buffer_context(uint8_t* p) : ptr(p) {}
|
||||||
|
|
||||||
void write(const void* src, size_t size) override {
|
void write(const void* src, size_t size) override {
|
||||||
memcpy(ptr, src, size);
|
memcpy(ptr, src, size);
|
||||||
ptr += size;
|
ptr += size;
|
||||||
|
size_written += size;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t get_size_written() override {
|
||||||
|
return size_written;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_data_file_context : llama_data_context {
|
struct llama_data_file_context : llama_data_context {
|
||||||
llama_file* file;
|
llama_file* file;
|
||||||
|
size_t size_written = 0;
|
||||||
|
|
||||||
llama_data_file_context(llama_file* f) : file(f) {}
|
llama_data_file_context(llama_file* f) : file(f) {}
|
||||||
|
|
||||||
void write(const void* src, size_t size) override {
|
void write(const void* src, size_t size) override {
|
||||||
file->write_raw(src, size);
|
file->write_raw(src, size);
|
||||||
|
size_written += size;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t get_size_written() override {
|
||||||
|
return size_written;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
12
llama.cpp
12
llama.cpp
|
@ -3743,6 +3743,14 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
|
||||||
return s_total;
|
return s_total;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst)
|
||||||
|
{
|
||||||
|
llama_data_buffer_context data_ctx(dst);
|
||||||
|
llama_copy_state_data_internal(ctx, &data_ctx);
|
||||||
|
|
||||||
|
return data_ctx.get_size_written();
|
||||||
|
}
|
||||||
|
|
||||||
/** copy state data into either a buffer or file depending on the passed in context
|
/** copy state data into either a buffer or file depending on the passed in context
|
||||||
*
|
*
|
||||||
* file context:
|
* file context:
|
||||||
|
@ -3756,7 +3764,7 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
|
||||||
* llama_copy_state_data(ctx, &data_ctx);
|
* llama_copy_state_data(ctx, &data_ctx);
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
void llama_copy_state_data(struct llama_context * ctx, llama_data_context * data_ctx) {
|
void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) {
|
||||||
// copy rng
|
// copy rng
|
||||||
{
|
{
|
||||||
std::stringstream rng_ss;
|
std::stringstream rng_ss;
|
||||||
|
@ -4037,7 +4045,7 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi
|
||||||
|
|
||||||
// save the context state using stream saving
|
// save the context state using stream saving
|
||||||
llama_data_file_context data_ctx(&file);
|
llama_data_file_context data_ctx(&file);
|
||||||
llama_copy_state_data(ctx, &data_ctx);
|
llama_copy_state_data_internal(ctx, &data_ctx);
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue