Track max contiguous cells value and position as well

This commit is contained in:
KerfuffleV2 2023-11-23 05:36:41 -07:00
parent cb137d8bfc
commit 22d0485a7a
3 changed files with 39 additions and 14 deletions

View file

@ -1393,8 +1393,8 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
// //
void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size) { void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size) {
printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d\n", printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, max contiguous cells=%d @ %d\n",
view.n_cells, view.n_max_seq, view.used_cells, view.token_count); view.n_cells, view.n_max_seq, view.used_cells, view.token_count, view.max_contiguous_cells, view.max_contiguous_cells_idx);
llama_kv_cache_view_cell * c_curr = view.cells; llama_kv_cache_view_cell * c_curr = view.cells;
struct llama_kv_cache_view_cell_sequence * cs_curr = view.cells_sequences; struct llama_kv_cache_view_cell_sequence * cs_curr = view.cells_sequences;
for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) { for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) {
@ -1405,14 +1405,14 @@ void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size) {
for (int j = 0; j < view.n_max_seq; j++) { for (int j = 0; j < view.n_max_seq; j++) {
if (cs_curr[j].seq_id >= 0) { seq_count++; } if (cs_curr[j].seq_id >= 0) { seq_count++; }
} }
putchar(int('0' + (std::min(9, seq_count)))); putchar(seq_count == 0 ? '.' : ('0' + (std::min(9, seq_count))));
} }
printf("\n=== Done dumping\n"); printf("\n=== Done dumping\n");
} }
void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) { void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) {
printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d\n", printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, max contiguous cells=%d @ %d\n",
view.n_cells, view.n_max_seq, view.used_cells, view.token_count); view.n_cells, view.n_max_seq, view.used_cells, view.token_count, view.max_contiguous_cells, view.max_contiguous_cells_idx);
std::unordered_map<llama_seq_id, size_t> seqs; std::unordered_map<llama_seq_id, size_t> seqs;
llama_kv_cache_view_cell * c_curr = view.cells; llama_kv_cache_view_cell * c_curr = view.cells;

View file

@ -8811,6 +8811,8 @@ struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context *
/*.n_max_seq*/ n_max_seq, /*.n_max_seq*/ n_max_seq,
/*.token_count*/ 0, /*.token_count*/ 0,
/*.used_cells*/ llama_get_kv_cache_used_cells(ctx), /*.used_cells*/ llama_get_kv_cache_used_cells(ctx),
/*max_contiguous*/ 0,
/*max_contiguous_idx*/ -1,
/*.cells*/ nullptr, /*.cells*/ nullptr,
/*.cells_sequences*/ nullptr, /*.cells_sequences*/ nullptr,
}; };
@ -8844,11 +8846,25 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k
struct llama_kv_cache_view_cell_sequence * cs_curr = view->cells_sequences; struct llama_kv_cache_view_cell_sequence * cs_curr = view->cells_sequences;
int32_t used_cells = 0; int32_t used_cells = 0;
int32_t token_count = 0; int32_t token_count = 0;
int32_t curr_contig_idx = -1;
uint32_t max_contig = 0;
int32_t max_contig_idx = -1;
for (uint32_t i = 0; i < ctx->kv_self.size; i++, c_curr++, cs_curr += view->n_max_seq) { for (int32_t i = 0; i < int32_t(ctx->kv_self.size); i++, c_curr++, cs_curr += view->n_max_seq) {
token_count += ctx->kv_self.cells[i].seq_id.size(); const size_t curr_size = kv_cells[i].seq_id.size();
token_count += curr_size;
c_curr->pos = kv_cells[i].pos + kv_cells[i].delta; c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;
if (curr_size > 0) {
if (curr_contig_idx >= 0 && uint32_t(i - curr_contig_idx) > max_contig) {
max_contig_idx = i;
max_contig = i - curr_contig_idx;
}
curr_contig_idx = -1;
} else if (curr_contig_idx < 0) {
curr_contig_idx = i;
}
int seq_idx = 0; int seq_idx = 0;
for (const llama_seq_id it : kv_cells[i].seq_id) { for (const llama_seq_id it : kv_cells[i].seq_id) {
if (seq_idx >= view->n_max_seq) { if (seq_idx >= view->n_max_seq) {
@ -8864,6 +8880,12 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k
cs_curr[seq_idx].seq_id = -1; cs_curr[seq_idx].seq_id = -1;
} }
} }
if (curr_contig_idx >= 0 && kv_cells.size() - curr_contig_idx > max_contig) {
max_contig_idx = curr_contig_idx;
max_contig = kv_cells.size() - curr_contig_idx;
}
view->max_contiguous_cells = max_contig;
view->max_contiguous_cells_idx = max_contig_idx;
view->token_count = token_count; view->token_count = token_count;
view->used_cells = used_cells; view->used_cells = used_cells;
if (uint32_t(used_cells) != ctx->kv_self.used) { if (uint32_t(used_cells) != ctx->kv_self.used) {

View file

@ -366,6 +366,7 @@ extern "C" {
}; };
struct llama_kv_cache_view_cell_sequence { struct llama_kv_cache_view_cell_sequence {
// Would like to have token_id here as well.
llama_seq_id seq_id; llama_seq_id seq_id;
}; };
@ -374,7 +375,9 @@ extern "C" {
int32_t n_max_seq; int32_t n_max_seq;
int32_t token_count; int32_t token_count;
int32_t used_cells; int32_t used_cells;
struct llama_kv_cache_view_cell *cells; int32_t max_contiguous_cells;
int32_t max_contiguous_cells_idx;
struct llama_kv_cache_view_cell * cells;
struct llama_kv_cache_view_cell_sequence * cells_sequences; struct llama_kv_cache_view_cell_sequence * cells_sequences;
}; };