examples : save-load-state: save only required state
This commit is contained in:
parent
b7e7982953
commit
aee95df8f1
1 changed file with 10 additions and 11 deletions
|
@@ -45,13 +45,13 @@ int main(int argc, char ** argv) {
|
||||||
// save state (rng, logits, embedding and kv_cache) to file
|
// save state (rng, logits, embedding and kv_cache) to file
|
||||||
{
|
{
|
||||||
std::vector<uint8_t> state_mem(llama_get_state_size(ctx));
|
std::vector<uint8_t> state_mem(llama_get_state_size(ctx));
|
||||||
|
const size_t written = llama_copy_state_data(ctx, state_mem.data());
|
||||||
|
|
||||||
{
|
FILE *fp_write = fopen("dump_state.bin", "wb");
|
||||||
FILE *fp_write = fopen("dump_state.bin", "wb");
|
fwrite(state_mem.data(), 1, written, fp_write);
|
||||||
llama_copy_state_data(ctx, state_mem.data()); // could also copy directly to memory mapped file
|
fclose(fp_write);
|
||||||
fwrite(state_mem.data(), 1, state_mem.size(), fp_write);
|
|
||||||
fclose(fp_write);
|
fprintf(stderr, "%s : serialized state into %zd out of a maximum of %zd bytes\n", __func__, written, state_mem.size());
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// save state (last tokens)
|
// save state (last tokens)
|
||||||
|
@@ -100,18 +100,17 @@ int main(int argc, char ** argv) {
|
||||||
std::vector<uint8_t> state_mem(llama_get_state_size(ctx2));
|
std::vector<uint8_t> state_mem(llama_get_state_size(ctx2));
|
||||||
|
|
||||||
FILE * fp_read = fopen("dump_state.bin", "rb");
|
FILE * fp_read = fopen("dump_state.bin", "rb");
|
||||||
|
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
|
||||||
|
fclose(fp_read);
|
||||||
|
|
||||||
const size_t ret = fread(state_mem.data(), 1, state_mem.size(), fp_read);
|
if (read != llama_set_state_data(ctx2, state_mem.data())) {
|
||||||
if (ret != state_mem.size()) {
|
|
||||||
fprintf(stderr, "\n%s : failed to read state\n", __func__);
|
fprintf(stderr, "\n%s : failed to read state\n", __func__);
|
||||||
llama_free(ctx2);
|
llama_free(ctx2);
|
||||||
llama_free_model(model);
|
llama_free_model(model);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_set_state_data(ctx2, state_mem.data());
|
fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size());
|
||||||
|
|
||||||
fclose(fp_read);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// restore state (last tokens)
|
// restore state (last tokens)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue