llama : KV cache view API + better KV cache management (#4170)

* llama : keep track of used KV cells + better KV cache management

* llama : zero KV cache used upon clear

ggml-ci

* llama : allow exporting a view of the KV cache (#4180)

* Allow exporting a view of the KV cache

* Allow dumping the sequences per cell in common

* Track max contiguous cells value and position as well

* Fix max contiguous empty cells index calculation

Make dump functions deal with lengths or sequences counts > 10 better

* Fix off by one error in dump_kv_cache_view

* Add doc comments for KV cache view functions

Eliminate cell sequence struct; use llama_seq_id directly

Minor cleanups

* common : add -dkvc arg for enabling kv cache dumps

---------

Co-authored-by: Kerfuffle <44031344+KerfuffleV2@users.noreply.github.com>
This commit is contained in:
Georgi Gerganov 2023-11-23 19:07:56 +02:00 committed by GitHub
parent d103d935c0
commit 6b0a7420d0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 277 additions and 7 deletions

128
llama.cpp
View file

@ -1280,6 +1280,7 @@ struct llama_kv_cache {
// cannot be freely changed after a slot has been allocated.
uint32_t head = 0;
uint32_t size = 0;
uint32_t used = 0; // used cells (i.e. at least one seq_id)
// computed before each graph build
uint32_t n = 0;
@ -1504,6 +1505,7 @@ static bool llama_kv_cache_init(
cache.head = 0;
cache.size = n_ctx;
cache.used = 0;
cache.cells.clear();
cache.cells.resize(n_ctx);
@ -1605,6 +1607,8 @@ static bool llama_kv_cache_find_slot(
}
}
cache.used += n_tokens;
return true;
}
@ -1625,6 +1629,7 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
cache.cells[i].seq_id.clear();
}
cache.head = 0;
cache.used = 0;
}
static void llama_kv_cache_seq_rm(
@ -1647,6 +1652,9 @@ static void llama_kv_cache_seq_rm(
continue;
}
if (cache.cells[i].seq_id.empty()) {
// keep count of the number of used cells
if (cache.cells[i].pos >= 0) cache.used--;
cache.cells[i].pos = -1;
if (new_head == cache.size) new_head = i;
}
@ -1654,7 +1662,7 @@ static void llama_kv_cache_seq_rm(
}
// If we freed up a slot, set head to it so searching can start there.
if (new_head != cache.size) cache.head = new_head;
if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
}
static void llama_kv_cache_seq_cp(
@ -1680,6 +1688,7 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id
for (uint32_t i = 0; i < cache.size; ++i) {
if (!cache.cells[i].has_seq_id(seq_id)) {
if (cache.cells[i].pos >= 0) cache.used--;
cache.cells[i].pos = -1;
cache.cells[i].seq_id.clear();
if (new_head == cache.size) new_head = i;
@ -1690,7 +1699,7 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id
}
// If we freed up a slot, set head to it so searching can start there.
if (new_head != cache.size) cache.head = new_head;
if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
}
static void llama_kv_cache_seq_shift(
@ -1711,6 +1720,7 @@ static void llama_kv_cache_seq_shift(
cache.cells[i].delta += delta;
if (cache.cells[i].pos < 0) {
if (!cache.cells[i].seq_id.empty()) cache.used--;
cache.cells[i].pos = -1;
cache.cells[i].seq_id.clear();
if (new_head == cache.size) new_head = i;
@ -5469,6 +5479,12 @@ static int llama_decode_internal(
batch.seq_id = seq_id_arr.data();
}
// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
if (kv_self.head > kv_self.used + 2*n_tokens) {
kv_self.head = 0;
}
if (!llama_kv_cache_find_slot(kv_self, batch)) {
return 1;
}
@ -5479,7 +5495,7 @@ static int llama_decode_internal(
//kv_self.n = std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)); // TODO: this might be better for CUDA?
kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, llama_kv_cache_cell_max(kv_self)));
//printf("kv_self.n = %d\n", kv_self.n);
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
ggml_allocr_reset(lctx.alloc);
@ -8789,8 +8805,107 @@ int llama_model_apply_lora_from_file(const struct llama_model * model, const cha
}
}
struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq) {
struct llama_kv_cache_view result = {
/*.n_cells = */ 0,
/*.n_max_seq = */ n_max_seq,
/*.token_count = */ 0,
/*.used_cells = */ llama_get_kv_cache_used_cells(ctx),
/*.max_contiguous = */ 0,
/*.max_contiguous_idx = */ -1,
/*.cells = */ nullptr,
/*.cells_sequences = */ nullptr,
};
return result;
}
void llama_kv_cache_view_free(struct llama_kv_cache_view * view) {
if (view->cells != nullptr) {
free(view->cells);
view->cells = nullptr;
}
if (view->cells_sequences != nullptr) {
free(view->cells_sequences);
view->cells_sequences = nullptr;
}
}
void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view) {
if (uint32_t(view->n_cells) < ctx->kv_self.size || view->cells == nullptr) {
view->n_cells = int32_t(ctx->kv_self.size);
void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells);
GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
view->cells = (struct llama_kv_cache_view_cell *)p;
p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_max_seq * view->n_cells);
GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences");
view->cells_sequences = (llama_seq_id *)p;
}
const std::vector<llama_kv_cell> & kv_cells = ctx->kv_self.cells;
llama_kv_cache_view_cell * c_curr = view->cells;
llama_seq_id * cs_curr = view->cells_sequences;
int32_t used_cells = 0;
int32_t token_count = 0;
int32_t curr_contig_idx = -1;
uint32_t max_contig = 0;
int32_t max_contig_idx = -1;
for (int32_t i = 0; i < int32_t(ctx->kv_self.size); i++, c_curr++, cs_curr += view->n_max_seq) {
const size_t curr_size = kv_cells[i].seq_id.size();
token_count += curr_size;
c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;
if (curr_size > 0) {
if (curr_contig_idx >= 0 && uint32_t(i - curr_contig_idx) > max_contig) {
max_contig = i - curr_contig_idx;
max_contig_idx = curr_contig_idx;
}
curr_contig_idx = -1;
} else if (curr_contig_idx < 0) {
curr_contig_idx = i;
}
int seq_idx = 0;
for (const llama_seq_id it : kv_cells[i].seq_id) {
if (seq_idx >= view->n_max_seq) {
break;
}
cs_curr[seq_idx] = it;
seq_idx++;
}
if (seq_idx != 0) {
used_cells++;
}
for (; seq_idx < view->n_max_seq; seq_idx++) {
cs_curr[seq_idx] = -1;
}
}
if (curr_contig_idx >= 0 && kv_cells.size() - curr_contig_idx > max_contig) {
max_contig_idx = curr_contig_idx;
max_contig = kv_cells.size() - curr_contig_idx;
}
view->max_contiguous = max_contig;
view->max_contiguous_idx = max_contig_idx;
view->token_count = token_count;
view->used_cells = used_cells;
if (uint32_t(used_cells) != ctx->kv_self.used) {
LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n",
__func__, ctx->kv_self.used, used_cells);
}
}
int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
return ctx->kv_self.head;
int result = 0;
for (uint32_t i = 0; i < ctx->kv_self.size; i++) {
result += ctx->kv_self.cells[i].seq_id.size();
}
return result;
}
int llama_get_kv_cache_used_cells(const struct llama_context * ctx) {
return ctx->kv_self.used;
}
void llama_kv_cache_clear(struct llama_context * ctx) {
@ -8960,10 +9075,12 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
const size_t kv_buf_size = kv_self.buf.size;
const uint32_t kv_head = kv_self.head;
const uint32_t kv_size = kv_self.size;
const uint32_t kv_used = kv_self.used;
data_ctx->write(&kv_buf_size, sizeof(kv_buf_size));
data_ctx->write(&kv_head, sizeof(kv_head));
data_ctx->write(&kv_size, sizeof(kv_size));
data_ctx->write(&kv_used, sizeof(kv_used));
if (kv_buf_size) {
const size_t elt_size = ggml_element_size(kv_self.k);
@ -9086,10 +9203,12 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
size_t kv_buf_size;
uint32_t kv_head;
uint32_t kv_size;
uint32_t kv_used;
memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size);
memcpy(&kv_head, inp, sizeof(kv_head)); inp += sizeof(kv_head);
memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size);
memcpy(&kv_used, inp, sizeof(kv_used)); inp += sizeof(kv_used);
if (kv_buf_size) {
GGML_ASSERT(kv_self.buf.size == kv_buf_size);
@ -9124,6 +9243,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
ctx->kv_self.head = kv_head;
ctx->kv_self.size = kv_size;
ctx->kv_self.used = kv_used;
ctx->kv_self.cells.resize(kv_size);