kv cache slot search improvements
This commit is contained in:
parent
019ba1dcd0
commit
abafd01ec8
1 changed files with 25 additions and 5 deletions
30
llama.cpp
30
llama.cpp
|
@ -1316,8 +1316,8 @@ static bool llama_kv_cache_find_slot(
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
if (cache.head + n_tokens > n_ctx) {
|
if (cache.head + n_tokens > n_ctx) {
|
||||||
|
n_tested += cache.size - cache.head;
|
||||||
cache.head = 0;
|
cache.head = 0;
|
||||||
n_tested += n_ctx - cache.head;
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1368,6 +1368,9 @@ static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0,
|
||||||
cache.cells[i].pos = -1;
|
cache.cells[i].pos = -1;
|
||||||
cache.cells[i].seq_id.clear();
|
cache.cells[i].seq_id.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Searching for a free slot can start here since we know it will be empty.
|
||||||
|
cache.head = uint32_t(c0);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_kv_cache_seq_rm(
|
static void llama_kv_cache_seq_rm(
|
||||||
|
@ -1375,6 +1378,8 @@ static void llama_kv_cache_seq_rm(
|
||||||
llama_seq_id seq_id,
|
llama_seq_id seq_id,
|
||||||
llama_pos p0,
|
llama_pos p0,
|
||||||
llama_pos p1) {
|
llama_pos p1) {
|
||||||
|
uint32_t new_head = cache.size;
|
||||||
|
|
||||||
if (p0 < 0) p0 = 0;
|
if (p0 < 0) p0 = 0;
|
||||||
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
|
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
|
||||||
|
|
||||||
|
@ -1383,9 +1388,13 @@ static void llama_kv_cache_seq_rm(
|
||||||
cache.cells[i].seq_id.erase(seq_id);
|
cache.cells[i].seq_id.erase(seq_id);
|
||||||
if (cache.cells[i].seq_id.empty()) {
|
if (cache.cells[i].seq_id.empty()) {
|
||||||
cache.cells[i].pos = -1;
|
cache.cells[i].pos = -1;
|
||||||
|
if (new_head == cache.size) new_head = i;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If we freed up a slot, set head to it so searching can start there.
|
||||||
|
if (new_head != cache.size) cache.head = new_head;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_kv_cache_seq_cp(
|
static void llama_kv_cache_seq_cp(
|
||||||
|
@ -1397,6 +1406,8 @@ static void llama_kv_cache_seq_cp(
|
||||||
if (p0 < 0) p0 = 0;
|
if (p0 < 0) p0 = 0;
|
||||||
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
|
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
|
||||||
|
|
||||||
|
cache.head = 0;
|
||||||
|
|
||||||
for (uint32_t i = 0; i < cache.size; ++i) {
|
for (uint32_t i = 0; i < cache.size; ++i) {
|
||||||
if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
|
if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
|
||||||
cache.cells[i].seq_id.insert(seq_id_dst);
|
cache.cells[i].seq_id.insert(seq_id_dst);
|
||||||
|
@ -1405,12 +1416,18 @@ static void llama_kv_cache_seq_cp(
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) {
|
static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) {
|
||||||
|
uint32_t new_head = cache.size;
|
||||||
|
|
||||||
for (uint32_t i = 0; i < cache.size; ++i) {
|
for (uint32_t i = 0; i < cache.size; ++i) {
|
||||||
if (!cache.cells[i].has_seq_id(seq_id)) {
|
if (!cache.cells[i].has_seq_id(seq_id)) {
|
||||||
cache.cells[i].pos = -1;
|
cache.cells[i].pos = -1;
|
||||||
cache.cells[i].seq_id.clear();
|
cache.cells[i].seq_id.clear();
|
||||||
|
if (new_head == cache.size) new_head = i;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If we freed up a slot, set head to it so searching can start there.
|
||||||
|
if (new_head != cache.size) cache.head = new_head;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_kv_cache_seq_shift(
|
static void llama_kv_cache_seq_shift(
|
||||||
|
@ -1419,6 +1436,8 @@ static void llama_kv_cache_seq_shift(
|
||||||
llama_pos p0,
|
llama_pos p0,
|
||||||
llama_pos p1,
|
llama_pos p1,
|
||||||
llama_pos delta) {
|
llama_pos delta) {
|
||||||
|
uint32_t new_head = cache.size;
|
||||||
|
|
||||||
if (p0 < 0) p0 = 0;
|
if (p0 < 0) p0 = 0;
|
||||||
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
|
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
|
||||||
|
|
||||||
|
@ -1428,12 +1447,17 @@ static void llama_kv_cache_seq_shift(
|
||||||
if (cache.cells[i].pos < 0) {
|
if (cache.cells[i].pos < 0) {
|
||||||
cache.cells[i].pos = -1;
|
cache.cells[i].pos = -1;
|
||||||
cache.cells[i].seq_id.clear();
|
cache.cells[i].seq_id.clear();
|
||||||
|
if (new_head == cache.size) new_head = i;
|
||||||
} else {
|
} else {
|
||||||
cache.has_shift = true;
|
cache.has_shift = true;
|
||||||
cache.cells[i].delta = delta;
|
cache.cells[i].delta = delta;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If we freed up a slot, set head to it so searching can start there.
|
||||||
|
// Otherwise we just start the next search from the beginning.
|
||||||
|
cache.head = new_head != cache.size ? new_head : 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
|
@ -4454,10 +4478,6 @@ static int llama_decode_internal(
|
||||||
batch.seq_id = seq_id.data();
|
batch.seq_id = seq_id.data();
|
||||||
}
|
}
|
||||||
|
|
||||||
// we always start to search for a free slot from the start of the cache
|
|
||||||
// TODO: better strategies can be implemented
|
|
||||||
kv_self.head = 0;
|
|
||||||
|
|
||||||
if (!llama_kv_cache_find_slot(kv_self, batch)) {
|
if (!llama_kv_cache_find_slot(kv_self, batch)) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue