From 662aaea8c9c624f9a6622229e0ed01b7d37248d1 Mon Sep 17 00:00:00 2001
From: Jan Boon
Date: Wed, 27 Mar 2024 16:56:35 +0800
Subject: [PATCH] llama : save and restore kv cache for single seq id

---
 examples/server/server.cpp | 223 +++++++++++++++++++++++++++++++-
 llama.cpp                  | 253 +++++++++++++++++++++++++++++++++++++
 llama.h                    |  14 ++
 3 files changed, 489 insertions(+), 1 deletion(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 53ad9239e..4e9a0e9e3 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -61,7 +61,10 @@ enum server_task_type {
     SERVER_TASK_TYPE_COMPLETION,
     SERVER_TASK_TYPE_CANCEL,
     SERVER_TASK_TYPE_NEXT_RESPONSE,
-    SERVER_TASK_TYPE_METRICS
+    SERVER_TASK_TYPE_METRICS,
+    SERVER_TASK_TYPE_SLOT_SAVE,
+    SERVER_TASK_TYPE_SLOT_RESTORE,
+    SERVER_TASK_TYPE_SLOT_ERASE,
 };
 
 struct server_task {
@@ -1612,6 +1615,142 @@ struct server_context {
                     }
                     queue_results.send(res);
                 } break;
+            case SERVER_TASK_TYPE_SLOT_SAVE:
+                {
+                    int id_slot = task.data["id_slot"];
+                    server_slot * slot = get_slot(id_slot);
+                    if (slot == nullptr) {
+                        send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
+                        break;
+                    }
+
+                    const size_t token_count = slot->cache_tokens.size();
+                    const int64_t t_start = ggml_time_us();
+
+                    std::string filename = task.data["filename"];
+                    size_t state_size = llama_get_seq_size(ctx, slot->id + 1);
+                    std::vector<uint8_t> state_data(state_size + sizeof(size_t) + token_count * sizeof(llama_token));
+                    size_t nwrite = llama_copy_seq_data(ctx, state_data.data(), slot->id + 1);
+                    GGML_ASSERT(nwrite <= state_size);
+
+                    // write the cached token count of the slot (slot->cache_tokens.size())
+                    memcpy(state_data.data() + nwrite, &token_count, sizeof(size_t));
+                    nwrite += sizeof(size_t);
+
+                    // write the cached tokens, one llama_token at a time
+                    for (size_t i = 0; i < token_count; i++) {
+                        const llama_token token = slot->cache_tokens[i];
+                        memcpy(state_data.data() + nwrite, &token, sizeof(llama_token));
+                        nwrite += sizeof(llama_token);
+                    }
+                    GGML_ASSERT(nwrite <= state_data.size());
+
+                    std::ofstream outfile(filename, std::ios::binary);
+                    outfile.write(reinterpret_cast<const char *>(state_data.data()), nwrite);
+                    outfile.close();
+
+                    const int64_t t_end = ggml_time_us();
+                    const double t_save_ms = (t_end - t_start) / 1000.0;
+
+                    server_task_result result;
+                    result.id = task.id;
+                    result.stop = true;
+                    result.error = false;
+                    result.data = json {
+                        { "id_slot", id_slot },
+                        { "filename", filename },
+                        { "n_saved", token_count }, // tokens saved
+                        { "n_written", nwrite },    // bytes written
+                        { "timings", {
+                            { "save_ms", t_save_ms }
+                        } }
+                    };
+                    queue_results.send(result);
+                } break;
+            case SERVER_TASK_TYPE_SLOT_RESTORE:
+                {
+                    int id_slot = task.data["id_slot"];
+                    server_slot * slot = get_slot(id_slot);
+                    if (slot == nullptr) {
+                        send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
+                        break;
+                    }
+
+                    const int64_t t_start = ggml_time_us();
+
+                    std::string filename = task.data["filename"]; // TODO: restrict to files in path specified in server params?
+                    std::ifstream infile(filename, std::ios::binary);
+                    if (!infile.is_open()) {
+                        send_error(task, "Failed to open file", ERROR_TYPE_INVALID_REQUEST);
+                        break;
+                    }
+
+                    std::vector<uint8_t> state_data((std::istreambuf_iterator<char>(infile)), std::istreambuf_iterator<char>());
+                    infile.close();
+
+                    size_t nread = llama_set_seq_data(ctx, state_data.data(), slot->id + 1);
+                    GGML_ASSERT(nread <= state_data.size());
+
+                    // restore cached token values
+                    size_t token_count = 0;
+                    if (nread + sizeof(size_t) <= state_data.size()) {
+                        token_count = *reinterpret_cast<const size_t *>(state_data.data() + nread);
+                        nread += sizeof(size_t);
+                    }
+                    slot->cache_tokens.resize(token_count);
+                    GGML_ASSERT(nread + (token_count * sizeof(llama_token)) <= state_data.size());
+
+                    // tokens are of type llama_token (an integer)
+                    for (size_t i = 0; i < token_count; i++) {
+                        if (nread + sizeof(llama_token) <= state_data.size()) {
+                            slot->cache_tokens[i] = *reinterpret_cast<const llama_token *>(state_data.data() + nread);
+                            nread += sizeof(llama_token);
+                        }
+                    }
+                    GGML_ASSERT(nread <= state_data.size());
+
+                    const int64_t t_end = ggml_time_us();
+                    const double t_restore_ms = (t_end - t_start) / 1000.0;
+
+                    server_task_result result;
+                    result.id = task.id;
+                    result.stop = true;
+                    result.error = false;
+                    result.data = json {
+                        { "id_slot", id_slot },
+                        { "filename", filename },
+                        { "n_restored", token_count }, // tokens restored
+                        { "n_read", nread },           // bytes read
+                        { "timings", {
+                            { "restore_ms", t_restore_ms }
+                        } }
+                    };
+                    queue_results.send(result);
+                } break;
+            case SERVER_TASK_TYPE_SLOT_ERASE:
+                {
+                    int id_slot = task.data["id_slot"];
+                    server_slot * slot = get_slot(id_slot);
+                    if (slot == nullptr) {
+                        send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
+                        break;
+                    }
+
+                    // Erase token cache
+                    const size_t n_erased = slot->cache_tokens.size();
+                    llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1);
+                    slot->cache_tokens.clear();
+
+                    server_task_result result;
+                    result.id = task.id;
+                    result.stop = true;
+                    result.error = false;
+                    result.data = json {
+                        { "id_slot", id_slot },
+                        { "n_erased", n_erased }
+                    };
+                    queue_results.send(result);
+                } break;
         }
     }
@@ -3157,6 +3296,85 @@ int main(int argc, char ** argv) {
         res.status = 200; // HTTP OK
     };
 
+    const auto handle_slot_save = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
+        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+
+        json request_data = json::parse(req.body);
+        int id_slot = request_data["id_slot"];
+        std::string filename = request_data["filename"];
+
+        server_task task;
+        task.type = SERVER_TASK_TYPE_SLOT_SAVE;
+        task.data = {
+            { "id_slot", id_slot },
+            { "filename", filename },
+        };
+
+        const int id_task = ctx_server.queue_tasks.post(task);
+        ctx_server.queue_results.add_waiting_task_id(id_task);
+
+        server_task_result result = ctx_server.queue_results.recv(id_task);
+        ctx_server.queue_results.remove_waiting_task_id(id_task);
+
+        if (result.error) {
+            res_error(res, result.data);
+        } else {
+            res.set_content(result.data.dump(), "application/json");
+        }
+    };
+
+    const auto handle_slot_restore = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
+        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+
+        json request_data = json::parse(req.body);
+        int id_slot = request_data["id_slot"];
+        std::string filename = request_data["filename"];
+
+        server_task task;
+        task.type = SERVER_TASK_TYPE_SLOT_RESTORE;
+        task.data = {
+            { "id_slot", id_slot },
+            { "filename", filename },
+        };
+
+        const int id_task = ctx_server.queue_tasks.post(task);
+        ctx_server.queue_results.add_waiting_task_id(id_task);
+
+        server_task_result result = ctx_server.queue_results.recv(id_task);
+        ctx_server.queue_results.remove_waiting_task_id(id_task);
+
+        if (result.error) {
+            res_error(res, result.data);
+        } else {
+            res.set_content(result.data.dump(), "application/json");
+        }
+    };
+
+    const auto handle_slot_erase = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
+        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+
+        json request_data = json::parse(req.body);
+        int id_slot = request_data["id_slot"];
+
+        server_task task;
+        task.type = SERVER_TASK_TYPE_SLOT_ERASE;
+        task.data = {
+            { "id_slot", id_slot },
+        };
+
+        const int id_task = ctx_server.queue_tasks.post(task);
+        ctx_server.queue_results.add_waiting_task_id(id_task);
+
+        server_task_result result = ctx_server.queue_results.recv(id_task);
+        ctx_server.queue_results.remove_waiting_task_id(id_task);
+
+        if (result.error) {
+            res_error(res, result.data);
+        } else {
+            res.set_content(result.data.dump(), "application/json");
+        }
+    };
+
     const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         json data = {
@@ -3519,6 +3737,9 @@ int main(int argc, char ** argv) {
     svr->Post("/v1/embeddings", handle_embeddings);
     svr->Post("/tokenize", handle_tokenize);
     svr->Post("/detokenize", handle_detokenize);
+    svr->Post("/slot/save", handle_slot_save);
+    svr->Post("/slot/restore", handle_slot_restore);
+    svr->Post("/slot/erase", handle_slot_erase);
 
     //
     // Start the server
diff --git a/llama.cpp b/llama.cpp
index 892d46fbc..e3a9eea4c 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -15059,6 +15059,259 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi
     return true;
 }
 
+size_t llama_get_seq_size(struct llama_context * ctx, llama_seq_id seq_id) {
+    // save the size of size_t as a uint32_t for safety check
+    const size_t size_t_size_size = sizeof(uint32_t);
+
+    // other values
+    const size_t s_cell_count_size = sizeof(uint32_t);
+    const size_t s_layer_count_size = sizeof(uint32_t);
+    const size_t n_embd_v_gqa_size = sizeof(uint32_t);
+
+    size_t s_cell_count = 0;
+    size_t s_cell_data_size = 0;
+    const auto & kv_self = ctx->kv_self;
+    const auto & hparams = ctx->model.hparams;
+
+    const uint32_t n_layer = hparams.n_layer;
+    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
+    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
+
+    for (uint32_t i = 0; i < kv_self.size; ++i) {
+        const auto & cell = kv_self.cells[i];
+        if (cell.seq_id.count(seq_id) > 0) {
+            ++s_cell_count;
+            s_cell_data_size += sizeof(llama_pos);
+        }
+    }
+
+    for (int il = 0; il < (int)n_layer; ++il) {
+        // k_size_row and v_size_el values of layer
+        s_cell_data_size += sizeof(size_t) * 2;
+
+        // keys
+        const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
+        s_cell_data_size += k_size_row * s_cell_count;
+
+        // values (transposed)
+        const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
+        s_cell_data_size += v_size_el * s_cell_count * n_embd_v_gqa;
+    }
+
+    const size_t s_total = (
+        size_t_size_size +
+        s_cell_count_size +
+        s_layer_count_size +
+        n_embd_v_gqa_size +
+        s_cell_data_size
+    );
+
+    return s_total;
+}
+
+size_t llama_copy_seq_data(struct llama_context * ctx, uint8_t * dst, llama_seq_id seq_id) {
+    llama_data_buffer_context data_ctx(dst);
+
+    // Save the size of size_t as a uint32_t for safety check
+    const uint32_t size_t_size = sizeof(size_t);
+    data_ctx.write(&size_t_size, sizeof(size_t_size));
+
+    const auto & kv_self = ctx->kv_self;
+    std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
+    uint32_t cell_count = 0;
+
+    // Count the number of cells with the specified seq_id
+    // Find all the ranges of cells with this seq id
+    {
+        uint32_t cell_range_begin = kv_self.size;
+        for (uint32_t i = 0; i < kv_self.size; ++i) {
+            const auto & cell = kv_self.cells[i];
+            if (cell.has_seq_id(seq_id)) {
+                ++cell_count;
+                if (cell_range_begin == kv_self.size) {
+                    cell_range_begin = i;
+                }
+            } else {
+                if (cell_range_begin != kv_self.size) {
+                    cell_ranges.push_back({ cell_range_begin, i });
+                    cell_range_begin = kv_self.size;
+                }
+            }
+        }
+        if (cell_range_begin != kv_self.size) {
+            cell_ranges.push_back({ cell_range_begin, kv_self.size });
+        }
+
+        // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
+        uint32_t cell_count_check = 0;
+        for (const auto & range : cell_ranges) {
+            cell_count_check += range.second - range.first;
+        }
+        GGML_ASSERT(cell_count == cell_count_check);
+    }
+
+    // Write the cell count
+    data_ctx.write(&cell_count, sizeof(cell_count));
+
+    const auto & hparams = ctx->model.hparams;
+    const uint32_t n_layer = hparams.n_layer;
+    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
+    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
+
+    // Write the layer count
+    data_ctx.write(&n_layer, sizeof(n_layer));
+
+    // Write n_embd_v_gqa
+    data_ctx.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
+
+    // Iterate the ranges and write all the pos (this is the token position in the prompt)
+    for (const auto & range : cell_ranges) {
+        for (uint32_t i = range.first; i < range.second; ++i) {
+            const auto & cell = kv_self.cells[i];
+            data_ctx.write(&cell.pos, sizeof(cell.pos));
+        }
+    }
+
+    // Iterate and write all the keys first, each row is a cell
+    // Get whole range at a time
+    std::vector<uint8_t> tmp_buf;
+    for (int il = 0; il < (int)n_layer; ++il) {
+        // Write row size of key
+        const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
+        data_ctx.write(&k_size_row, sizeof(k_size_row));
+
+        // Read each range of cells of k_size length each into tmp_buf and write out
+        for (const auto & range : cell_ranges) {
+            const size_t range_size = range.second - range.first;
+            tmp_buf.resize(range_size * k_size_row);
+            ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), range.first * k_size_row, range_size * k_size_row);
+            data_ctx.write(&tmp_buf[0], tmp_buf.size());
+        }
+    }
+
+    // For the values, they are transposed, so we also need the element size and get the element ranges from each row
+    const uint32_t kv_size = kv_self.size;
+    for (int il = 0; il < (int)n_layer; ++il) {
+        // Write element size
+        const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
+        data_ctx.write(&v_size_el, sizeof(v_size_el));
+
+        // For each row, we get the element values of each cell
+        for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+            // Read each range of cells of v_size_el length each into tmp_buf and write out
+            for (const auto & range : cell_ranges) {
+                const size_t range_size = range.second - range.first;
+                const size_t src_offset = (range.first + j * kv_size) * v_size_el;
+                tmp_buf.resize(range_size * v_size_el);
+                ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size());
+                data_ctx.write(&tmp_buf[0], tmp_buf.size());
+            }
+        }
+    }
+
+    return data_ctx.get_size_written();
+}
+
+size_t llama_set_seq_data(struct llama_context * ctx, const uint8_t * src, llama_seq_id dest_seq_id) {
+    auto & kv_self = ctx->kv_self;
+
+    // Wipe the slot
+    llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
+
+    const uint8_t * inp = src;
+
+    // Read size of size_t
+    uint32_t size_t_size;
+    memcpy(&size_t_size, inp, sizeof(size_t_size));
+    inp += sizeof(size_t_size);
+    GGML_ASSERT(size_t_size == sizeof(size_t));
+
+    // Read the cell count
+    uint32_t cell_count;
+    memcpy(&cell_count, inp, sizeof(cell_count));
+    inp += sizeof(cell_count);
+
+    // Read the layer count
+    uint32_t n_layer_ref;
+    memcpy(&n_layer_ref, inp, sizeof(n_layer_ref));
+    inp += sizeof(n_layer_ref);
+
+    // Read n_embd_v_gqa
+    uint32_t n_embd_v_gqa_ref;
+    memcpy(&n_embd_v_gqa_ref, inp, sizeof(n_embd_v_gqa_ref));
+    inp += sizeof(n_embd_v_gqa_ref);
+
+    // Allocate the new cells for the slot
+    llama_batch batch = llama_batch_init(cell_count, 0, 1);
+    batch.n_tokens = cell_count;
+    for (uint32_t i = 0; i < cell_count; ++i) {
+        llama_pos pos;
+        memcpy(&pos, inp, sizeof(pos));
+        inp += sizeof(pos);
+
+        batch.pos[i] = pos;
+        batch.n_seq_id[i] = 1;
+        batch.seq_id[i][0] = dest_seq_id;
+    }
+    llama_kv_cache_find_slot(kv_self, batch);
+
+    // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
+    // Assume that this is one contiguous block of cells
+    GGML_ASSERT(kv_self.head + cell_count <= kv_self.size);
+    GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]);
+    GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
+    GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
+    GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
+
+    const auto & hparams = ctx->model.hparams;
+    const uint32_t n_layer = hparams.n_layer;
+    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
+    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
+    const uint32_t kv_size = kv_self.size;
+    const uint32_t kv_head = kv_self.head;
+    GGML_ASSERT(n_layer == n_layer_ref);
+    GGML_ASSERT(n_embd_v_gqa == n_embd_v_gqa_ref);
+
+    // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
+    for (int il = 0; il < (int)n_layer; ++il) {
+        // Read row size of key
+        size_t k_size_row_ref;
+        memcpy(&k_size_row_ref, inp, sizeof(k_size_row_ref));
+        inp += sizeof(k_size_row_ref);
+        const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
+        GGML_ASSERT(k_size_row == k_size_row_ref);
+
+        // Read and set the keys for the whole cell range
+        ggml_backend_tensor_set(kv_self.k_l[il], inp, kv_head * k_size_row, cell_count * k_size_row);
+        inp += cell_count * k_size_row;
+    }
+
+    // For each layer, read the values for each cell (transposed)
+    for (int il = 0; il < (int)n_layer; ++il) {
+        // Read element size of value
+        size_t v_size_el_ref;
+        memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref));
+        inp += sizeof(v_size_el_ref);
+
+        const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
+        GGML_ASSERT(v_size_el == v_size_el_ref);
+
+        // For each row in the transposed matrix, read the values for the whole cell range
+        for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+            const size_t dst_offset = (kv_head + j * kv_size) * v_size_el;
+            ggml_backend_tensor_set(kv_self.v_l[il], inp, dst_offset, cell_count * v_size_el);
+            inp += cell_count * v_size_el;
+        }
+    }
+
+    // Cleanup
+    llama_batch_free(batch);
+
+    const size_t nread = inp - src;
+    return nread;
+}
+
 void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch) {
     ctx->cparams.n_threads = n_threads;
     ctx->cparams.n_threads_batch = n_threads_batch;
diff --git a/llama.h b/llama.h
index 1fe4af495..33164a33a 100644
--- a/llama.h
+++ b/llama.h
@@ -623,6 +623,20 @@ extern "C" {
             const llama_token * tokens,
             size_t n_token_count);
 
+    LLAMA_API size_t llama_get_seq_size(
+            struct llama_context * ctx,
+            llama_seq_id seq_id);
+
+    LLAMA_API size_t llama_copy_seq_data(
+            struct llama_context * ctx,
+            uint8_t * dst,
+            llama_seq_id seq_id);
+
+    LLAMA_API size_t llama_set_seq_data(
+            struct llama_context * ctx,
+            const uint8_t * src,
+            llama_seq_id dest_seq_id);
+
     //
     // Decoding
     //
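
Usage sketch (not part of the diff): the three new llama.h entry points are meant to be used together, mirroring the server slot handlers above — query the size, copy the sequence data out, and later feed it back in. The sketch below shows a minimal save/load round trip for a single sequence; the helper names and the file path are illustrative only, model/context setup and most error handling are elided, and llama_set_seq_data wipes the destination sequence before repopulating it. The server additionally appends the slot's cached token list after the KV state blob so it can refill slot.cache_tokens on restore; this sketch persists only the KV state.

    #include "llama.h"

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Save the KV cache of one sequence to a file using the new API.
    static bool save_seq(struct llama_context * ctx, llama_seq_id seq_id, const char * path) {
        // upper bound on the serialized size of this sequence
        const size_t state_size = llama_get_seq_size(ctx, seq_id);

        std::vector<uint8_t> buf(state_size);
        const size_t n_written = llama_copy_seq_data(ctx, buf.data(), seq_id);

        FILE * f = std::fopen(path, "wb");
        if (!f) {
            return false;
        }
        const size_t n = std::fwrite(buf.data(), 1, n_written, f);
        std::fclose(f);
        return n == n_written;
    }

    // Restore a previously saved sequence into seq_id (the destination is wiped first).
    static bool restore_seq(struct llama_context * ctx, llama_seq_id seq_id, const char * path) {
        FILE * f = std::fopen(path, "rb");
        if (!f) {
            return false;
        }
        std::fseek(f, 0, SEEK_END);
        const size_t size = (size_t) std::ftell(f);
        std::fseek(f, 0, SEEK_SET);

        std::vector<uint8_t> buf(size);
        const size_t n = std::fread(buf.data(), 1, size, f);
        std::fclose(f);
        if (n != size) {
            return false;
        }

        // returns the number of bytes consumed from the buffer
        const size_t n_read = llama_set_seq_data(ctx, buf.data(), seq_id);
        return n_read <= size;
    }

Over HTTP, the same operations are exposed by the server as POST /slot/save and POST /slot/restore with a JSON body such as {"id_slot": 0, "filename": "state.bin"}, and POST /slot/erase with just {"id_slot": 0}.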