diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index adcfa79f9..de05f47b9 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1630,26 +1630,8 @@ struct server_context {
                     std::string filename = task.data["filename"];
                     std::string filepath = task.data["filepath"];
 
-                    size_t state_size = llama_state_seq_get_size(ctx, slot->id + 1);
-                    std::vector<uint8_t> state_data(state_size + sizeof(size_t) + token_count * sizeof(llama_token));
-                    size_t nwrite = llama_state_seq_get_data(ctx, state_data.data(), slot->id + 1);
-                    GGML_ASSERT(nwrite <= state_size);
-                    // write the cached token count of the slot->cache_tokens.size()
-                    memcpy(state_data.data() + nwrite, &token_count, sizeof(size_t));
-                    nwrite += sizeof(size_t);
-
-                    // write the cached tokens (loop)
-                    for (size_t i = 0; i < token_count; i++) {
-                        const llama_token token = slot->cache_tokens[i];
-                        memcpy(state_data.data() + nwrite, &token, sizeof(llama_token));
-                        nwrite += sizeof(llama_token);
-                    }
-                    GGML_ASSERT(nwrite <= state_data.size());
-
-                    std::ofstream outfile(filepath, std::ios::binary);
-                    outfile.write(reinterpret_cast<const char *>(state_data.data()), nwrite);
-                    outfile.close();
+                    const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), token_count);
 
                     const int64_t t_end = ggml_time_us();
                     const double t_save_ms = (t_end - t_start) / 1000.0;
 
@@ -1682,39 +1664,16 @@ struct server_context {
                     std::string filename = task.data["filename"];
                     std::string filepath = task.data["filepath"];
 
-                    std::ifstream infile(filepath, std::ios::binary);
-                    if (!infile.is_open()) {
-                        send_error(task, "Failed to open file", ERROR_TYPE_INVALID_REQUEST);
-                        break;
-                    }
-                    std::vector<uint8_t> state_data((std::istreambuf_iterator<char>(infile)), std::istreambuf_iterator<char>());
-                    infile.close();
-
-                    size_t nread = llama_state_seq_set_data(ctx, state_data.data(), slot->id + 1);
+                    slot->cache_tokens.resize(slot->n_ctx);
+                    size_t token_count = 0;
+                    size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
                     if (nread == 0) {
+                        slot->cache_tokens.resize(0);
                         send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
                         break;
                     }
-                    GGML_ASSERT(nread <= state_data.size());
-
-                    // restore cached token values
-                    size_t token_count = 0;
-                    if (nread + sizeof(size_t) <= state_data.size()) {
-                        token_count = *reinterpret_cast<size_t *>(state_data.data() + nread);
-                        nread += sizeof(size_t);
-                    }
                     slot->cache_tokens.resize(token_count);
-                    GGML_ASSERT(nread + (token_count * sizeof(llama_token)) <= state_data.size());
-
-                    // tokens are of type llama_token (an integer)
-                    for (size_t i = 0; i < token_count; i++) {
-                        if (nread + sizeof(llama_token) <= state_data.size()) {
-                            slot->cache_tokens[i] = *reinterpret_cast<llama_token *>(state_data.data() + nread);
-                            nread += sizeof(llama_token);
-                        }
-                    }
-                    GGML_ASSERT(nread <= state_data.size());
 
                     const int64_t t_end = ggml_time_us();
                     const double t_restore_ms = (t_end - t_start) / 1000.0;
 
diff --git a/llama.cpp b/llama.cpp
index 145942078..5e1747842 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -15133,8 +15133,7 @@ size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id)
     return s_total;
 }
 
-size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, llama_seq_id seq_id) {
-    llama_data_buffer_context data_ctx(dst);
+static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_context & data_ctx, llama_seq_id seq_id) {
     const auto & kv_self = ctx->kv_self;
     GGML_ASSERT(!kv_self.recurrent); // not implemented
 
@@ -15238,6 +15237,11 @@ size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, llama
     return data_ctx.get_size_written();
 }
 
+size_t llama_state_seq_get_data(struct llama_context* ctx, uint8_t* dst, llama_seq_id seq_id) {
+    llama_data_buffer_context data_ctx(dst);
+    return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
+}
+
 size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, llama_seq_id dest_seq_id) {
     auto & kv_self = ctx->kv_self;
     GGML_ASSERT(!kv_self.recurrent); // not implemented
 
@@ -15361,6 +15365,78 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
     return nread;
 }
 
+size_t llama_state_seq_save_file(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
+    llama_file file(filepath, "wb");
+
+    file.write_u32(LLAMA_STATE_SEQ_MAGIC);
+    file.write_u32(LLAMA_STATE_SEQ_VERSION);
+
+    // save the prompt
+    file.write_u32((uint32_t)n_token_count);
+    file.write_raw(tokens, sizeof(llama_token) * n_token_count);
+
+    // save the context state using stream saving
+    llama_data_file_context data_ctx(&file);
+    llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
+
+    const size_t res = file.tell();
+    GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + data_ctx.get_size_written());
+    return res;
+}
+
+static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
+    llama_file file(filepath, "rb");
+
+    // version checks
+    {
+        const uint32_t magic   = file.read_u32();
+        const uint32_t version = file.read_u32();
+
+        if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) {
+            LLAMA_LOG_ERROR("%s : unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version);
+            return 0;
+        }
+    }
+
+    // load the prompt
+    {
+        const uint32_t n_token_count = file.read_u32();
+
+        if (n_token_count > n_token_capacity) {
+            LLAMA_LOG_ERROR("%s : token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
+            return 0;
+        }
+
+        file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
+        *n_token_count_out = n_token_count;
+    }
+
+    // restore the context state
+    {
+        const size_t state_size = file.size - file.tell();
+        std::vector<uint8_t> state_data(state_size);
+        file.read_raw(state_data.data(), state_size);
+        const size_t nread = llama_state_seq_set_data(ctx, state_data.data(), dest_seq_id);
+        if (!nread) {
+            LLAMA_LOG_ERROR("%s : failed to restore sequence state\n", __func__);
+            return 0;
+        }
+        GGML_ASSERT(nread <= state_size);
+        GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell());
+    }
+
+    return file.tell();
+}
+
+size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
+    try {
+        return llama_state_seq_load_file_internal(ctx, filepath, dest_seq_id, tokens_out, n_token_capacity, n_token_count_out);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("error loading sequence state file: %s\n", err.what());
+        return 0;
+    }
+}
+
 void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch) {
     ctx->cparams.n_threads       = n_threads;
     ctx->cparams.n_threads_batch = n_threads_batch;
diff --git a/llama.h b/llama.h
index 0473f726a..d6e8b2ca6 100644
--- a/llama.h
+++ b/llama.h
@@ -37,10 +37,14 @@
 #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
 #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
+#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
 
 #define LLAMA_SESSION_MAGIC   LLAMA_FILE_MAGIC_GGSN
 #define LLAMA_SESSION_VERSION 5
 
+#define LLAMA_STATE_SEQ_MAGIC   LLAMA_FILE_MAGIC_GGSQ
+#define LLAMA_STATE_SEQ_VERSION 1
+
 #ifdef __cplusplus
 extern "C" {
 #endif
 
@@ -667,6 +671,21 @@ extern "C" {
                             const uint8_t * src,
                              llama_seq_id   dest_seq_id);
 
+    LLAMA_API size_t llama_state_seq_save_file(
+            struct llama_context * ctx,
+                      const char * filepath,
+                    llama_seq_id   seq_id,
+               const llama_token * tokens,
+                          size_t   n_token_count);
+
+    LLAMA_API size_t llama_state_seq_load_file(
+            struct llama_context * ctx,
+                      const char * filepath,
+                    llama_seq_id   dest_seq_id,
+                     llama_token * tokens_out,
+                          size_t   n_token_capacity,
+                          size_t * n_token_count_out);
+
     //
     // Decoding
     //
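For reference, a minimal usage sketch of the two new entry points. This is not part of the patch: the helper function, the file name `seq0.bin`, and the sequence ids are made up for illustration, and real code would also handle decode errors. The semantics follow the diff above: both functions return 0 on failure, the loader rejects files whose token count exceeds the caller's capacity, and a state saved from one sequence id can be restored into a different one.

```cpp
// Hypothetical round-trip of a single sequence through the new API.
// Assumes `ctx` is an initialized llama_context whose sequence 0 has
// already decoded `prompt_tokens`.
#include "llama.h"

#include <cstdio>
#include <vector>

static bool roundtrip_seq_state(llama_context * ctx, const std::vector<llama_token> & prompt_tokens) {
    // persist sequence 0: KV cache state plus the tokens that produced it
    const size_t nwritten = llama_state_seq_save_file(
        ctx, "seq0.bin", /*seq_id=*/0, prompt_tokens.data(), prompt_tokens.size());
    if (nwritten == 0) {
        return false;
    }

    // restore into sequence 1; the callee reports how many tokens it loaded
    std::vector<llama_token> tokens(prompt_tokens.size());
    size_t n_loaded = 0;
    const size_t nread = llama_state_seq_load_file(
        ctx, "seq0.bin", /*dest_seq_id=*/1, tokens.data(), tokens.size(), &n_loaded);
    if (nread == 0) {
        // fails on a bad magic/version, insufficient token capacity,
        // or no available space in the KV cache
        return false;
    }

    printf("saved %zu bytes, restored %zu tokens\n", nwritten, n_loaded);
    return true;
}
```

On disk, per the `llama_state_seq_save_file` implementation, a sequence state file is three u32 fields (LLAMA_STATE_SEQ_MAGIC, LLAMA_STATE_SEQ_VERSION, the token count), followed by the raw llama_token array and then the streamed sequence state blob, which is exactly what the size assertion at the end of that function checks.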