diff --git a/common/common.cpp b/common/common.cpp
index b40a74cf4..e9c338028 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -12,6 +12,7 @@
 #include <regex>
 #include <sstream>
 #include <string>
+#include <unordered_map>
 #include <unordered_set>
 #include <vector>
 #include <cinttypes>
@@ -1408,3 +1409,60 @@ void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size) {
     }
     printf("\n=== Done dumping\n");
 }
+
+// Print a per-cell map of the sequences stored in the KV cache.
+//
+// The first 10 distinct sequence ids (in order of first appearance) are given the
+// digits 0-9 and listed in a legend. Each cell then prints one character per
+// sequence slot: the digit of a legend entry, '+' for a sequence that did not fit
+// in the legend, and '.' for an unused slot (negative seq_id).
+//
+// row_size is the number of cells printed per output row; non-positive values
+// fall back to 80.
+void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) {
+    // only the single digits '0'..'9' are available as legend symbols
+    const size_t max_legend = 10;
+    // `i % 0` is undefined behavior, so never divide by a non-positive row size
+    const int cells_per_row = row_size > 0 ? row_size : 80;
+
+    printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d\n",
+        view.n_cells, view.n_max_seq, view.used_cells, view.token_count);
+
+    std::unordered_map<llama_seq_id, size_t> seqs; // seq_id -> legend digit
+    llama_seq_id legend[max_legend];               // legend digit -> seq_id
+    const struct llama_kv_cache_view_cell_sequence * cs_curr = view.cells_sequences;
+    for (int i = 0; i < view.n_cells && seqs.size() < max_legend; i++, cs_curr += view.n_max_seq) {
+        for (int j = 0; j < view.n_max_seq && seqs.size() < max_legend; j++) {
+            const llama_seq_id seq_id = cs_curr[j].seq_id;
+            if (seq_id < 0 || seqs.find(seq_id) != seqs.end()) { continue; }
+            // read the size before inserting: the evaluation order of
+            // `seqs[id] = seqs.size()` is unspecified before C++17
+            const size_t slot = seqs.size();
+            seqs[seq_id] = slot;
+            legend[slot] = seq_id;
+        }
+    }
+
+    // print the legend in digit order rather than in unordered_map iteration order
+    printf("=== Sequence legend: ");
+    for (size_t slot = 0; slot < seqs.size(); slot++) {
+        printf("%zu=%d, ", slot, legend[slot]);
+    }
+
+    cs_curr = view.cells_sequences;
+    for (int i = 0; i < view.n_cells; i++, cs_curr += view.n_max_seq) {
+        if (i % cells_per_row == 0) {
+            printf("\n%5d: ", i);
+        }
+        for (int j = 0; j < view.n_max_seq; j++) {
+            if (cs_curr[j].seq_id < 0) {
+                putchar('.');
+                continue;
+            }
+            const auto it = seqs.find(cs_curr[j].seq_id);
+            putchar(it != seqs.end() ? int('0' + it->second) : '+');
+        }
+        putchar(' ');
+    }
+    printf("\n=== Done dumping\n");
+}
diff --git a/common/common.h b/common/common.h
index 58a153203..45bd0e43d 100644
--- a/common/common.h
+++ b/common/common.h
@@ -224,3 +224,6 @@ void dump_non_result_info_yaml(
 //
 
 void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80);
+
+// Dump the KV cache view showing which sequences occupy each cell.
+void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 80);
diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp
index 439e6d0a7..8cc20b422 100644
--- a/examples/parallel/parallel.cpp
+++ b/examples/parallel/parallel.cpp
@@ -172,7 +172,7 @@ int main(int argc, char ** argv) {
     int32_t n_total_gen    = 0;
     int32_t n_cache_miss   = 0;
 
-    struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, n_seq);
+    struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, n_clients);
 
     const auto t_main_start = ggml_time_us();
 
@@ -204,7 +204,7 @@ int main(int argc, char ** argv) {
 
     while (true) {
         llama_kv_cache_view_update(ctx, &kvc_view);
-        dump_kv_cache_view(kvc_view);
+        dump_kv_cache_view_seqs(kvc_view, 40);
 
         llama_batch_clear(batch);
 