From 71fcb7e27e28296bcd614b1802c8eaa1ac6f947e Mon Sep 17 00:00:00 2001
From: KerfuffleV2
Date: Thu, 23 Nov 2023 03:30:08 -0700
Subject: [PATCH] Allow exporting a view of the KV cache

---
 common/common.cpp              | 22 +++++++++++
 common/common.h                |  6 +++
 examples/parallel/parallel.cpp |  5 +++
 llama.cpp                      | 67 ++++++++++++++++++++++++++++++++++
 llama.h                        | 23 ++++++++++++
 5 files changed, 123 insertions(+)

diff --git a/common/common.cpp b/common/common.cpp
index eec704b99..b40a74cf4 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1386,3 +1386,25 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p);
     fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
 }
+
+//
+// KV cache utils
+//
+
+void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size) {
+    printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d\n",
+        view.n_cells, view.n_max_seq, view.used_cells, view.token_count);
+    llama_kv_cache_view_cell * c_curr = view.cells;
+    struct llama_kv_cache_view_cell_sequence * cs_curr = view.cells_sequences;
+    for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) {
+        if (i % row_size == 0) {
+            printf("\n%5d: ", i);
+        }
+        int seq_count = 0;
+        for (int j = 0; j < view.n_max_seq; j++) {
+            if (cs_curr[j].seq_id >= 0) { seq_count++; }
+        }
+        putchar(int('0' + (std::min(9, seq_count))));
+    }
+    printf("\n=== Done dumping\n");
+}
diff --git a/common/common.h b/common/common.h
index 88fa13fc0..58a153203 100644
--- a/common/common.h
+++ b/common/common.h
@@ -218,3 +218,9 @@ std::string get_sortable_timestamp();
 void dump_non_result_info_yaml(
     FILE * stream, const gpt_params & params, const llama_context * lctx,
     const std::string & timestamp, const std::vector<llama_token> & prompt_tokens, const char * model_desc);
+
+//
+// KV cache utils
+//
+
+void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80);
diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp
index a78df305f..439e6d0a7 100644
--- a/examples/parallel/parallel.cpp
+++ b/examples/parallel/parallel.cpp
@@ -172,6 +172,8 @@ int main(int argc, char ** argv) {
     int32_t n_total_gen    = 0;
     int32_t n_cache_miss   = 0;
 
+    struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, n_seq);
+
     const auto t_main_start = ggml_time_us();
 
     LOG_TEE("%s: Simulating parallel requests from clients:\n", __func__);
@@ -201,6 +203,9 @@ int main(int argc, char ** argv) {
     LOG_TEE("Processing requests ...\n\n");
 
     while (true) {
+        llama_kv_cache_view_update(ctx, &kvc_view);
+        dump_kv_cache_view(kvc_view);
+
         llama_batch_clear(batch);
 
         // decode any currently ongoing sequences
diff --git a/llama.cpp b/llama.cpp
index 5679c7050..e23d820ea 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -8805,6 +8805,73 @@ int llama_model_apply_lora_from_file(const struct llama_model * model, const cha
     }
 }
 
+struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq) {
+    struct llama_kv_cache_view result = {
+        /*.n_cells*/ 0,
+        /*.n_max_seq*/ n_max_seq,
+        /*.token_count*/ 0,
+        /*.used_cells*/ llama_get_kv_cache_used_cells(ctx),
+        /*.cells*/ nullptr,
+        /*.cells_sequences*/ nullptr,
+    };
+    return result;
+}
+
+void llama_kv_cache_view_free(struct llama_kv_cache_view * view) {
+    if (view->cells != nullptr) {
+        free(view->cells);
+        view->cells = nullptr;
+    }
+    if (view->cells_sequences != nullptr) {
+        free(view->cells_sequences);
+        view->cells_sequences = nullptr;
+    }
+}
+
+void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view) {
+    if (uint32_t(view->n_cells) < ctx->kv_self.size || view->cells == nullptr) {
+        view->n_cells = int32_t(ctx->kv_self.size);
+        void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells);
+        GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
+        view->cells = (struct llama_kv_cache_view_cell *)p;
+        p = realloc(view->cells_sequences, sizeof(struct llama_kv_cache_view_cell_sequence) * view->n_max_seq * view->n_cells);
+        GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences");
+        view->cells_sequences = (struct llama_kv_cache_view_cell_sequence *)p;
+    }
+
+    const std::vector<llama_kv_cache_cell> & kv_cells = ctx->kv_self.cells;
+    llama_kv_cache_view_cell * c_curr = view->cells;
+    struct llama_kv_cache_view_cell_sequence * cs_curr = view->cells_sequences;
+    int32_t used_cells = 0;
+    int32_t token_count = 0;
+
+    for (uint32_t i = 0; i < ctx->kv_self.size; i++, c_curr++, cs_curr += view->n_max_seq) {
+        token_count += ctx->kv_self.cells[i].seq_id.size();
+        c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;
+
+        int seq_idx = 0;
+        for (const llama_seq_id it : kv_cells[i].seq_id) {
+            if (seq_idx >= view->n_max_seq) {
+                break;
+            }
+            cs_curr[seq_idx].seq_id = it;
+            seq_idx++;
+        }
+        if (seq_idx != 0) {
+            used_cells++;
+        }
+        for (; seq_idx < view->n_max_seq; seq_idx++) {
+            cs_curr[seq_idx].seq_id = -1;
+        }
+    }
+    view->token_count = token_count;
+    view->used_cells = used_cells;
+    if (uint32_t(used_cells) != ctx->kv_self.used) {
+        LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n",
+            __func__, ctx->kv_self.used, used_cells);
+    }
+}
+
 int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
     int result = 0;
 
diff --git a/llama.h b/llama.h
index 06b982ee0..72a156ac5 100644
--- a/llama.h
+++ b/llama.h
@@ -361,6 +361,29 @@ extern "C" {
     // KV cache
     //
 
+    struct llama_kv_cache_view_cell {
+        llama_pos pos;
+    };
+
+    struct llama_kv_cache_view_cell_sequence {
+        llama_seq_id seq_id;
+    };
+
+    struct llama_kv_cache_view {
+        int32_t n_cells;
+        int32_t n_max_seq;
+        int32_t token_count;
+        int32_t used_cells;
+        struct llama_kv_cache_view_cell * cells;
+        struct llama_kv_cache_view_cell_sequence * cells_sequences;
+    };
+
+    LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq);
+
+    LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view);
+
+    LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);
+
     // Returns the number of tokens in the KV cache (slow, use only for debug)
     // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
     LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx);
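
Usage sketch (not part of the diff): a minimal standalone example of how the new
view API is meant to be driven, mirroring the parallel.cpp hunk above: init once,
update and dump per step, free at shutdown. The model filename and the
load/teardown boilerplate are placeholder assumptions, not part of this change;
only llama_kv_cache_view_init/_update/_free and dump_kv_cache_view come from
this patch.

    #include "common.h"
    #include "llama.h"

    int main() {
        llama_backend_init(false);

        // Placeholder setup; real code would load a model and create a
        // context the same way the other examples do.
        llama_model * model = llama_load_model_from_file(
            "model.gguf", llama_model_default_params());
        llama_context * ctx = llama_new_context_with_model(
            model, llama_context_default_params());

        // Allocate a view that tracks at most 4 sequences per cell.
        struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, 4);

        // ... llama_decode() calls would go here ...

        // Snapshot the cache (buffers are reallocated only if the cache grew)
        // and print one digit per cell: the number of sequences using it.
        llama_kv_cache_view_update(ctx, &kvc_view);
        dump_kv_cache_view(kvc_view);

        // The view owns heap buffers, so it must be freed explicitly.
        llama_kv_cache_view_free(&kvc_view);

        llama_free(ctx);
        llama_free_model(model);
        llama_backend_free();
        return 0;
    }

Keeping llama_kv_cache_view_update separate from init is what makes polling in a
loop cheap: when the cache size has not changed, an update only rewrites the
existing buffers instead of reallocating them.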