Track max contiguous cells value and position as well

This commit is contained in:
KerfuffleV2 2023-11-23 05:36:41 -07:00
parent cb137d8bfc
commit 22d0485a7a
3 changed files with 39 additions and 14 deletions

View file

@ -1393,8 +1393,8 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
//
void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size) {
printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d\n",
view.n_cells, view.n_max_seq, view.used_cells, view.token_count);
printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, max contiguous cells=%d @ %d\n",
view.n_cells, view.n_max_seq, view.used_cells, view.token_count, view.max_contiguous_cells, view.max_contiguous_cells_idx);
llama_kv_cache_view_cell * c_curr = view.cells;
struct llama_kv_cache_view_cell_sequence * cs_curr = view.cells_sequences;
for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) {
@ -1405,14 +1405,14 @@ void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size) {
for (int j = 0; j < view.n_max_seq; j++) {
if (cs_curr[j].seq_id >= 0) { seq_count++; }
}
putchar(int('0' + (std::min(9, seq_count))));
putchar(seq_count == 0 ? '.' : ('0' + (std::min(9, seq_count))));
}
printf("\n=== Done dumping\n");
}
void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) {
printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d\n",
view.n_cells, view.n_max_seq, view.used_cells, view.token_count);
printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, max contiguous cells=%d @ %d\n",
view.n_cells, view.n_max_seq, view.used_cells, view.token_count, view.max_contiguous_cells, view.max_contiguous_cells_idx);
std::unordered_map<llama_seq_id, size_t> seqs;
llama_kv_cache_view_cell * c_curr = view.cells;

View file

@ -8807,12 +8807,14 @@ int llama_model_apply_lora_from_file(const struct llama_model * model, const cha
struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq) {
struct llama_kv_cache_view result = {
/*.n_cells*/ 0,
/*.n_max_seq*/ n_max_seq,
/*.token_count*/ 0,
/*.used_cells*/ llama_get_kv_cache_used_cells(ctx),
/*.cells*/ nullptr,
/*.cells_sequences*/ nullptr,
/*.n_cells*/ 0,
/*.n_max_seq*/ n_max_seq,
/*.token_count*/ 0,
/*.used_cells*/ llama_get_kv_cache_used_cells(ctx),
/*max_contiguous*/ 0,
/*max_contiguous_idx*/ -1,
/*.cells*/ nullptr,
/*.cells_sequences*/ nullptr,
};
return result;
}
@ -8844,11 +8846,25 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k
struct llama_kv_cache_view_cell_sequence * cs_curr = view->cells_sequences;
int32_t used_cells = 0;
int32_t token_count = 0;
int32_t curr_contig_idx = -1;
uint32_t max_contig = 0;
int32_t max_contig_idx = -1;
for (uint32_t i = 0; i < ctx->kv_self.size; i++, c_curr++, cs_curr += view->n_max_seq) {
token_count += ctx->kv_self.cells[i].seq_id.size();
for (int32_t i = 0; i < int32_t(ctx->kv_self.size); i++, c_curr++, cs_curr += view->n_max_seq) {
const size_t curr_size = kv_cells[i].seq_id.size();
token_count += curr_size;
c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;
if (curr_size > 0) {
if (curr_contig_idx >= 0 && uint32_t(i - curr_contig_idx) > max_contig) {
max_contig_idx = i;
max_contig = i - curr_contig_idx;
}
curr_contig_idx = -1;
} else if (curr_contig_idx < 0) {
curr_contig_idx = i;
}
int seq_idx = 0;
for (const llama_seq_id it : kv_cells[i].seq_id) {
if (seq_idx >= view->n_max_seq) {
@ -8864,6 +8880,12 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k
cs_curr[seq_idx].seq_id = -1;
}
}
if (curr_contig_idx >= 0 && kv_cells.size() - curr_contig_idx > max_contig) {
max_contig_idx = curr_contig_idx;
max_contig = kv_cells.size() - curr_contig_idx;
}
view->max_contiguous_cells = max_contig;
view->max_contiguous_cells_idx = max_contig_idx;
view->token_count = token_count;
view->used_cells = used_cells;
if (uint32_t(used_cells) != ctx->kv_self.used) {

View file

@ -366,6 +366,7 @@ extern "C" {
};
struct llama_kv_cache_view_cell_sequence {
// Would like to have token_id here as well.
llama_seq_id seq_id;
};
@ -374,7 +375,9 @@ extern "C" {
int32_t n_max_seq;
int32_t token_count;
int32_t used_cells;
struct llama_kv_cache_view_cell *cells;
int32_t max_contiguous_cells;
int32_t max_contiguous_cells_idx;
struct llama_kv_cache_view_cell * cells;
struct llama_kv_cache_view_cell_sequence * cells_sequences;
};