llama : refactor session file management

* llama : saving and restoring state checks for overflow

The size of the buffers should now be given to the functions working
with them, otherwise a truncated file could cause out of bound reads.

* llama : stream from session file instead of copying into a big buffer

Loading session files should no longer cause a memory usage spike.

* llama : llama_state_get_size returns the actual size instead of max

This is a breaking change, but makes that function *much* easier
to keep up to date, and it also makes it reflect the behavior
of llama_state_seq_get_size.

* llama : share code between whole and seq_id-specific state saving

Both session file types now use a more similar format.

* llama : no longer store all hparams in session files

Instead, the model arch name is stored.
The layer count and the embedding dimensions of the KV cache
are still verified when loading.
Storing all the hparams is not necessary.
This commit is contained in:
Francis Couture-Harpin 2024-07-25 18:33:54 -04:00
parent de280085e7
commit 8e39037b86
3 changed files with 679 additions and 772 deletions

View file

@ -47,7 +47,7 @@ int main(int argc, char ** argv) {
// save state (rng, logits, embedding and kv_cache) to file
{
std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
const size_t written = llama_state_get_data(ctx, state_mem.data());
const size_t written = llama_state_get_data(ctx, state_mem.data(), state_mem.size());
FILE *fp_write = fopen("dump_state.bin", "wb");
fwrite(state_mem.data(), 1, written, fp_write);
@ -99,13 +99,16 @@ int main(int argc, char ** argv) {
// load state (rng, logits, embedding and kv_cache) from file
{
std::vector<uint8_t> state_mem(llama_state_get_size(ctx2));
std::vector<uint8_t> state_mem;
FILE * fp_read = fopen("dump_state.bin", "rb");
fseek(fp_read, 0, SEEK_END);
state_mem.resize(ftell(fp_read));
fseek(fp_read, 0, SEEK_SET);
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
fclose(fp_read);
if (read != llama_state_set_data(ctx2, state_mem.data())) {
if (read != llama_state_set_data(ctx2, state_mem.data(), state_mem.size())) {
fprintf(stderr, "\n%s : failed to read state\n", __func__);
llama_free(ctx2);
llama_free_model(model);
@ -159,13 +162,16 @@ int main(int argc, char ** argv) {
// load state (rng, logits, embedding and kv_cache) from file
{
std::vector<uint8_t> state_mem(llama_state_get_size(ctx3));
std::vector<uint8_t> state_mem;
FILE * fp_read = fopen("dump_state.bin", "rb");
fseek(fp_read, 0, SEEK_END);
state_mem.resize(ftell(fp_read));
fseek(fp_read, 0, SEEK_SET);
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
fclose(fp_read);
if (read != llama_state_set_data(ctx3, state_mem.data())) {
if (read != llama_state_set_data(ctx3, state_mem.data(), state_mem.size())) {
fprintf(stderr, "\n%s : failed to read state\n", __func__);
llama_free(ctx3);
llama_free_model(model);
@ -182,7 +188,7 @@ int main(int argc, char ** argv) {
{
// save kv of seq 0
std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx3, 0));
const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), 0);
const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), seq_store.size(), 0);
if (ncopy != seq_store.size()) {
fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
llama_free(ctx3);
@ -196,7 +202,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s : kv cache cleared\n", __func__);
// restore kv into seq 1
const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), 1);
const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), seq_store.size(), 1);
if (nset != seq_store.size()) {
fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
llama_free(ctx3);