From 02a184065a55552145919b466609f3b8f7ce6a05 Mon Sep 17 00:00:00 2001 From: Jan Boon Date: Wed, 27 Mar 2024 22:39:28 +0800 Subject: [PATCH] add kv seq save restore to test case --- examples/save-load-state/save-load-state.cpp | 93 +++++++++++++++++++- 1 file changed, 91 insertions(+), 2 deletions(-) diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index ef952e2bd..6a1966712 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -24,6 +24,7 @@ int main(int argc, char ** argv) { std::string result0; std::string result1; + std::string result2; // init llama_model * model; @@ -141,16 +142,104 @@ int main(int argc, char ** argv) { n_past += 1; } - printf("\n"); + printf("\n\n"); llama_free(ctx2); - llama_free_model(model); if (result0 != result1) { fprintf(stderr, "\n%s : error : the 2 generations are different\n", __func__); return 1; } + // make new context + auto* ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params)); + + printf("\nsingle seq run: %s", params.prompt.c_str()); + + // load state (rng, logits, embedding and kv_cache) from file + { + std::vector state_mem(llama_get_state_size(ctx3)); + + FILE * fp_read = fopen("dump_state.bin", "rb"); + const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read); + fclose(fp_read); + + if (read != llama_set_state_data(ctx3, state_mem.data())) { + fprintf(stderr, "\n%s : failed to read state\n", __func__); + llama_free(ctx3); + llama_free_model(model); + return 1; + } + + fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size()); + } + + // restore state (last tokens) + n_past = n_past_saved; + + // save seq 0 and load into seq 1 + { + // save kv of seq 0 + std::vector seq_store(llama_get_seq_size(ctx3, 0)); + const size_t ncopy = llama_copy_seq_data(ctx3, seq_store.data(), 0); + if (ncopy != seq_store.size()) { + fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size()); + llama_free(ctx3); + llama_free_model(model); + return 1; + } + fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy); + + // erase whole kv + llama_kv_cache_clear(ctx3); + fprintf(stderr, "%s : kv cache cleared\n", __func__); + + // restore kv into seq 1 + const size_t nset = llama_set_seq_data(ctx3, seq_store.data(), 1); + if (nset != seq_store.size()) { + fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size()); + llama_free(ctx3); + llama_free_model(model); + return 1; + } + fprintf(stderr, "%s : seq 1 restored, %zd bytes\n", __func__, nset); + } + + // third run with seq 1 instead of 0 + for (auto i = 0; i < params.n_predict; i++) { + auto * logits = llama_get_logits(ctx3); + auto n_vocab = llama_n_vocab(model); + std::vector candidates; + candidates.reserve(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); + } + llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; + auto next_token = llama_sample_token(ctx3, &candidates_p); + auto next_token_str = llama_token_to_piece(ctx3, next_token); + + printf("%s", next_token_str.c_str()); + result2 += next_token_str; + + if (llama_decode(ctx3, llama_batch_get_one(&next_token, 1, n_past, 1))) { + fprintf(stderr, "\n%s : failed to evaluate\n", __func__); + llama_free(ctx3); + llama_free_model(model); + return 1; + } + n_past += 1; + } + + printf("\n"); + + llama_free(ctx3); + llama_free_model(model); + + if (result0 != result2) { + fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__); + return 1; + } + fprintf(stderr, "\n%s : success\n", __func__); return 0;