Allow exporting a view of the KV cache
commit 71fcb7e27e (parent 671f639c59)
5 changed files with 123 additions and 0 deletions
@@ -1386,3 +1386,25 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
    fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p);
    fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
}

//
// KV cache utils
//

void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size) {
    printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d\n",
        view.n_cells, view.n_max_seq, view.used_cells, view.token_count);
    llama_kv_cache_view_cell * c_curr = view.cells;
    struct llama_kv_cache_view_cell_sequence * cs_curr = view.cells_sequences;
    for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) {
        if (i % row_size == 0) {
            printf("\n%5d: ", i);
        }
        int seq_count = 0;
        for (int j = 0; j < view.n_max_seq; j++) {
            if (cs_curr[j].seq_id >= 0) { seq_count++; }
        }
        putchar(int('0' + (std::min(9, seq_count))));
    }
    printf("\n=== Done dumping\n");
}

@@ -218,3 +218,9 @@ std::string get_sortable_timestamp();
void dump_non_result_info_yaml(
    FILE * stream, const gpt_params & params, const llama_context * lctx,
    const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc);

//
// KV cache utils
//

void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80);

@@ -172,6 +172,8 @@ int main(int argc, char ** argv) {
    int32_t n_total_gen  = 0;
    int32_t n_cache_miss = 0;

    struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, n_seq);

    const auto t_main_start = ggml_time_us();

    LOG_TEE("%s: Simulating parallel requests from clients:\n", __func__);

@@ -201,6 +203,9 @@ int main(int argc, char ** argv) {
    LOG_TEE("Processing requests ...\n\n");

    while (true) {
        llama_kv_cache_view_update(ctx, &kvc_view);
        dump_kv_cache_view(kvc_view);

        llama_batch_clear(batch);

        // decode any currently ongoing sequences

llama.cpp (67 changed lines)
@@ -8805,6 +8805,73 @@ int llama_model_apply_lora_from_file(const struct llama_model * model, const cha
    }
}

struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq) {
    struct llama_kv_cache_view result = {
        /*.n_cells*/ 0,
        /*.n_max_seq*/ n_max_seq,
        /*.token_count*/ 0,
        /*.used_cells*/ llama_get_kv_cache_used_cells(ctx),
        /*.cells*/ nullptr,
        /*.cells_sequences*/ nullptr,
    };
    return result;
}

void llama_kv_cache_view_free(struct llama_kv_cache_view * view) {
    if (view->cells != nullptr) {
        free(view->cells);
        view->cells = nullptr;
    }
    if (view->cells_sequences != nullptr) {
        free(view->cells_sequences);
        view->cells_sequences = nullptr;
    }
}

void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view) {
    if (uint32_t(view->n_cells) < ctx->kv_self.size || view->cells == nullptr) {
        view->n_cells = int32_t(ctx->kv_self.size);
        void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells);
        GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
        view->cells = (struct llama_kv_cache_view_cell *)p;
        p = realloc(view->cells_sequences, sizeof(struct llama_kv_cache_view_cell_sequence) * view->n_max_seq * view->n_cells);
        GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences");
        view->cells_sequences = (struct llama_kv_cache_view_cell_sequence *)p;
    }

    const std::vector<llama_kv_cell> & kv_cells = ctx->kv_self.cells;
    llama_kv_cache_view_cell * c_curr = view->cells;
    struct llama_kv_cache_view_cell_sequence * cs_curr = view->cells_sequences;
    int32_t used_cells = 0;
    int32_t token_count = 0;

    for (uint32_t i = 0; i < ctx->kv_self.size; i++, c_curr++, cs_curr += view->n_max_seq) {
        token_count += ctx->kv_self.cells[i].seq_id.size();
        c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;

        int seq_idx = 0;
        for (const llama_seq_id it : kv_cells[i].seq_id) {
            if (seq_idx >= view->n_max_seq) {
                break;
            }
            cs_curr[seq_idx].seq_id = it;
            seq_idx++;
        }
        if (seq_idx != 0) {
            used_cells++;
        }
        for (; seq_idx < view->n_max_seq; seq_idx++) {
            cs_curr[seq_idx].seq_id = -1;
        }
    }
    view->token_count = token_count;
    view->used_cells = used_cells;
    if (uint32_t(used_cells) != ctx->kv_self.used) {
        LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n",
            __func__, ctx->kv_self.used, used_cells);
    }
}

int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
    int result = 0;

llama.h (23 changed lines)
@@ -361,6 +361,29 @@ extern "C" {
    // KV cache
    //

    struct llama_kv_cache_view_cell {
        llama_pos pos;
    };

    struct llama_kv_cache_view_cell_sequence {
        llama_seq_id seq_id;
    };

    struct llama_kv_cache_view {
        int32_t n_cells;
        int32_t n_max_seq;
        int32_t token_count;
        int32_t used_cells;
        struct llama_kv_cache_view_cell * cells;
        struct llama_kv_cache_view_cell_sequence * cells_sequences;
    };

    LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq);

    LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view);

    LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);

    // Returns the number of tokens in the KV cache (slow, use only for debug)
    // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
    LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx);
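
Taken together, the new calls are meant to be driven in an init → update → dump → free cycle, as the parallel-requests hunk above does on each pass of its main loop. Below is a minimal standalone sketch of that flow, not part of this commit: the context setup is omitted, the sequence limit of 4 is an arbitrary illustrative value, and it assumes dump_kv_cache_view is available via the common helpers shown above.

// Minimal usage sketch (assumptions: a valid llama_context * already exists,
// and dump_kv_cache_view comes from the common helpers added in this commit).
#include "common.h"
#include "llama.h"

static void inspect_kv_cache(llama_context * ctx) {
    // Create an empty view that can record up to 4 sequence ids per cell.
    struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, 4);

    // Copy the current cache state into the view; this reallocates the view's
    // buffers if the cache has grown since the last update.
    llama_kv_cache_view_update(ctx, &kvc_view);

    // Print one digit per cell: the number of sequences occupying it (capped at 9).
    dump_kv_cache_view(kvc_view);

    // The view owns heap-allocated buffers, so free it when done.
    llama_kv_cache_view_free(&kvc_view);
}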