- fixed breakage of the llama_copy_state_data API (restored the buffer-based signature)
- moved new logic for copying llama state data to internal function
This commit is contained in:
parent
1ffa6be726
commit
7c5b2b57b0
3 changed files with 32 additions and 3 deletions
|
@ -66,7 +66,11 @@ int main(int argc, char ** argv) {
|
|||
uint8_t * state_mem = new uint8_t[state_size];
|
||||
|
||||
// Save state (rng, logits, embedding and kv_cache) to file
|
||||
llama_save_session_file(ctx, "dump_state.bin", tokens.data(), tokens.size());
|
||||
{
|
||||
llama_file file("dump_state.bin", "wb");
|
||||
llama_data_file_context data_ctx(&file);
|
||||
llama_copy_state_data(ctx, &data_ctx); // can also copy to memory buffer
|
||||
}
|
||||
|
||||
// save state (last tokens)
|
||||
const auto last_n_tokens_data_saved = std::vector<llama_token>(last_n_tokens_data);
|
||||
|
|
17
llama-util.h
17
llama-util.h
|
@ -152,23 +152,40 @@ struct llama_file {
|
|||
// llama_context_data
|
||||
// Abstract destination for serialized state data.
//
// Concrete contexts decide where the bytes go; callers only push data through
// write() and ask for the running byte count through get_size_written().
struct llama_data_context {
    // Virtual so that a derived context can be destroyed through a base pointer.
    virtual ~llama_data_context() = default;

    // Accept `size` bytes starting at `src`.
    virtual void write(const void* src, size_t size) = 0;

    // Report how many bytes this context has accepted (as tracked by the implementation).
    virtual size_t get_size_written() = 0;
};
|
||||
|
||||
struct llama_data_buffer_context : llama_data_context {
|
||||
uint8_t* ptr;
|
||||
size_t size_written = 0;
|
||||
|
||||
llama_data_buffer_context(uint8_t* p) : ptr(p) {}
|
||||
|
||||
void write(const void* src, size_t size) override {
|
||||
memcpy(ptr, src, size);
|
||||
ptr += size;
|
||||
size_written += size;
|
||||
}
|
||||
|
||||
size_t get_size_written() override {
|
||||
return size_written;
|
||||
}
|
||||
};
|
||||
|
||||
struct llama_data_file_context : llama_data_context {
|
||||
llama_file* file;
|
||||
size_t size_written = 0;
|
||||
|
||||
llama_data_file_context(llama_file* f) : file(f) {}
|
||||
|
||||
void write(const void* src, size_t size) override {
|
||||
file->write_raw(src, size);
|
||||
size_written += size;
|
||||
}
|
||||
|
||||
size_t get_size_written() override {
|
||||
return size_written;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
12
llama.cpp
12
llama.cpp
|
@ -3743,6 +3743,14 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
|
|||
return s_total;
|
||||
}
|
||||
|
||||
size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst)
|
||||
{
|
||||
llama_data_buffer_context data_ctx(dst);
|
||||
llama_copy_state_data_internal(ctx, &data_ctx);
|
||||
|
||||
return data_ctx.get_size_written();
|
||||
}
|
||||
|
||||
/** copy state data into either a buffer or file depending on the passed in context
|
||||
*
|
||||
* file context:
|
||||
|
@ -3756,7 +3764,7 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
|
|||
* llama_copy_state_data_internal(ctx, &data_ctx);
|
||||
*
|
||||
*/
|
||||
void llama_copy_state_data(struct llama_context * ctx, llama_data_context * data_ctx) {
|
||||
void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) {
|
||||
// copy rng
|
||||
{
|
||||
std::stringstream rng_ss;
|
||||
|
@ -4037,7 +4045,7 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi
|
|||
|
||||
// save the context state using stream saving
|
||||
llama_data_file_context data_ctx(&file);
|
||||
llama_copy_state_data(ctx, &data_ctx);
|
||||
llama_copy_state_data_internal(ctx, &data_ctx);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue