kv cache slot search improvements

KerfuffleV2 2023-10-05 16:06:55 -06:00
parent 019ba1dcd0
commit abafd01ec8


@@ -1316,8 +1316,8 @@ static bool llama_kv_cache_find_slot(
     while (true) {
         if (cache.head + n_tokens > n_ctx) {
+            n_tested += cache.size - cache.head;
             cache.head = 0;
-            n_tested += n_ctx - cache.head;
             continue;
         }
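For readers following the change, here is a minimal standalone sketch of the search loop this hunk fixes. The names are illustrative (a toy `find_slot` over an `int` vector, with `head` passed by reference), not the real llama.cpp API. The bug: the old code added `n_ctx - cache.head` after `cache.head` had already been reset to 0, so every wraparound charged a full `n_ctx` to `n_tested`; counting the skipped tail before the reset keeps the give-up check accurate.

    #include <cstdint>
    #include <vector>

    // Toy model: cells[i] >= 0 means cell i is occupied, -1 means free.
    // Returns the first index of n_tokens contiguous free cells, or -1.
    static int64_t find_slot(std::vector<int> & cells, uint32_t & head, uint32_t n_tokens) {
        const uint32_t n_ctx = (uint32_t) cells.size();
        if (n_tokens > n_ctx) {
            return -1; // the request can never fit
        }
        uint32_t n_tested = 0;
        while (true) {
            if (head + n_tokens > n_ctx) {
                n_tested += n_ctx - head; // count the skipped tail BEFORE the reset
                head = 0;
                continue;
            }
            bool found = true;
            for (uint32_t i = 0; i < n_tokens; i++) {
                if (cells[head + i] >= 0) {
                    found = false;
                    head     += i + 1; // skip past the occupied cell
                    n_tested += i + 1;
                    break;
                }
            }
            if (found) {
                return head;
            }
            if (n_tested >= n_ctx) {
                return -1; // everything examined; no contiguous run exists
            }
        }
    }

With the old ordering, a single wrap was enough to push `n_tested` past `n_ctx`, so the search could report failure without ever examining the cells before the position where it started.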
@@ -1368,6 +1368,9 @@ static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0,
         cache.cells[i].pos = -1;
         cache.cells[i].seq_id.clear();
     }
+
+    // Searching for a free slot can start here since we know it will be empty.
+    cache.head = uint32_t(c0);
 }
 
 static void llama_kv_cache_seq_rm(
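The new comment's claim can be illustrated in the same toy model. This hypothetical `tokens_rm` counterpart assumes the real function's clamping of negative bounds to the full range (not shown in this hunk): after clearing `[c0, c1)`, cell `c0` is free by construction, so it is always a safe place to start the next search.

    // Toy counterpart of llama_kv_cache_tokens_rm: clear [c0, c1) and
    // publish c0 as the next search start, since it is known to be free.
    static void tokens_rm(std::vector<int> & cells, uint32_t & head, int32_t c0, int32_t c1) {
        if (c0 < 0) c0 = 0;
        if (c1 < 0) c1 = (int32_t) cells.size();
        for (int32_t i = c0; i < c1; ++i) {
            cells[i] = -1; // mark free
        }
        head = (uint32_t) c0;
    }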
@@ -1375,6 +1378,8 @@ static void llama_kv_cache_seq_rm(
                  llama_seq_id   seq_id,
                     llama_pos   p0,
                     llama_pos   p1) {
+    uint32_t new_head = cache.size;
+
     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
@@ -1383,9 +1388,13 @@ static void llama_kv_cache_seq_rm(
             cache.cells[i].seq_id.erase(seq_id);
             if (cache.cells[i].seq_id.empty()) {
                 cache.cells[i].pos = -1;
+                if (new_head == cache.size) new_head = i;
             }
         }
     }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    if (new_head != cache.size) cache.head = new_head;
 }
 
 static void llama_kv_cache_seq_cp(
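The `new_head` idiom introduced here, and repeated in `llama_kv_cache_seq_keep` and `llama_kv_cache_seq_shift` below, is a sentinel pattern: `cache.size` is never a valid cell index, so it doubles as "nothing freed yet". A compact sketch of just the pattern, with a hypothetical `should_free` predicate standing in for each function's removal condition:

    #include <cstdint>
    #include <vector>

    // Sentinel pattern: record the first cell freed by this call and only
    // move head if something was actually freed. should_free is hypothetical;
    // each real function has its own condition (seq match, position range, ...).
    template <typename Pred>
    static void free_and_track_head(std::vector<int> & cells, uint32_t & head, Pred should_free) {
        const uint32_t size = (uint32_t) cells.size();
        uint32_t new_head = size; // sentinel: no valid index yet

        for (uint32_t i = 0; i < size; ++i) {
            if (cells[i] >= 0 && should_free(cells[i])) {
                cells[i] = -1;                      // free the cell
                if (new_head == size) new_head = i; // remember only the first
            }
        }

        // If we freed up a slot, set head to it so searching can start there;
        // otherwise leave head alone, the previous position is still valid.
        if (new_head != size) head = new_head;
    }

`llama_kv_cache_seq_shift` is the one variant that differs: per its comment, when nothing was freed it falls back to `cache.head = 0` rather than leaving `head` untouched.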
@@ -1397,6 +1406,8 @@ static void llama_kv_cache_seq_cp(
     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 
+    cache.head = 0;
+
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
             cache.cells[i].seq_id.insert(seq_id_dst);
@@ -1405,12 +1416,18 @@ static void llama_kv_cache_seq_cp(
 }
 
 static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) {
+    uint32_t new_head = cache.size;
+
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (!cache.cells[i].has_seq_id(seq_id)) {
             cache.cells[i].pos = -1;
             cache.cells[i].seq_id.clear();
+            if (new_head == cache.size) new_head = i;
         }
     }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    if (new_head != cache.size) cache.head = new_head;
 }
 
 static void llama_kv_cache_seq_shift(
@@ -1419,6 +1436,8 @@ static void llama_kv_cache_seq_shift(
                     llama_pos   p0,
                     llama_pos   p1,
                     llama_pos   delta) {
+    uint32_t new_head = cache.size;
+
     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
@@ -1428,12 +1447,17 @@ static void llama_kv_cache_seq_shift(
             if (cache.cells[i].pos < 0) {
                 cache.cells[i].pos = -1;
                 cache.cells[i].seq_id.clear();
+                if (new_head == cache.size) new_head = i;
             } else {
                 cache.has_shift = true;
                 cache.cells[i].delta = delta;
             }
         }
     }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    // Otherwise we just start the next search from the beginning.
+    cache.head = new_head != cache.size ? new_head : 0;
 }
 
 //
@@ -4454,10 +4478,6 @@ static int llama_decode_internal(
         batch.seq_id = seq_id.data();
     }
 
-    // we always start to search for a free slot from the start of the cache
-    // TODO: better strategies can be implemented
-    kv_self.head = 0;
-
     if (!llama_kv_cache_find_slot(kv_self, batch)) {
         return 1;
     }
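The net effect of the removal above: `llama_kv_cache_find_slot` now resumes from wherever the previous operation left `cache.head` instead of rescanning from cell 0 on every decode, and the mutation functions earlier in the diff are what keep `head` pointing somewhere useful. A sketch of the resulting call pattern, continuing the toy `find_slot` model from the first example (the `head += n_tokens` advance mirrors, in simplified form, the decoder moving head past a successfully placed batch):

    int main() {
        std::vector<int> cells(8, -1); // 8-cell cache, everything free
        uint32_t head = 0;

        // First batch of 4 tokens lands at cell 0.
        int64_t s0 = find_slot(cells, head, 4);            // s0 == 0
        for (int64_t i = 0; i < 4; ++i) cells[s0 + i] = 0; // occupy [0, 4)
        head += 4; // decoder advances head past the batch

        // Second batch starts scanning at cell 4; [0, 4) is never re-examined.
        int64_t s1 = find_slot(cells, head, 3);            // s1 == 4
        return s0 == 0 && s1 == 4 ? 0 : 1;
    }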