Extend llama_kv_cache_seq_rm to allow matching any sequence (#3843)
* Extend llama_kv_cache_seq_rm to allow matichng any sequence * Replace llama_kv_cache_tokens_rm with llama_kv_cache_clear Use llama_kv_cache_clear for cache clearing Change calls to llama_kv_cache_tokens_rm that want to delete by position to use llama_kv_cache_seq_rm functionality
This commit is contained in:
parent
2046eb4345
commit
6e08281e58
8 changed files with 30 additions and 32 deletions
29
llama.cpp
29
llama.cpp
|
@ -1466,17 +1466,12 @@ static int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
|
|||
return 0;
|
||||
}
|
||||
|
||||
static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0, int32_t c1) {
|
||||
if (c0 < 0) c0 = 0;
|
||||
if (c1 < 0) c1 = cache.size;
|
||||
|
||||
for (int32_t i = c0; i < c1; ++i) {
|
||||
static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
|
||||
for (int32_t i = 0; i < cache.size; ++i) {
|
||||
cache.cells[i].pos = -1;
|
||||
cache.cells[i].seq_id.clear();
|
||||
}
|
||||
|
||||
// Searching for a free slot can start here since we know it will be empty.
|
||||
cache.head = uint32_t(c0);
|
||||
cache.head = 0;
|
||||
}
|
||||
|
||||
static void llama_kv_cache_seq_rm(
|
||||
|
@ -1490,8 +1485,14 @@ static void llama_kv_cache_seq_rm(
|
|||
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
|
||||
|
||||
for (uint32_t i = 0; i < cache.size; ++i) {
|
||||
if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
|
||||
cache.cells[i].seq_id.erase(seq_id);
|
||||
if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
|
||||
if (seq_id < 0) {
|
||||
cache.cells[i].seq_id.clear();
|
||||
} else if (cache.cells[i].has_seq_id(seq_id)) {
|
||||
cache.cells[i].seq_id.erase(seq_id);
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
if (cache.cells[i].seq_id.empty()) {
|
||||
cache.cells[i].pos = -1;
|
||||
if (new_head == cache.size) new_head = i;
|
||||
|
@ -9207,8 +9208,8 @@ int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
|
|||
return ctx->kv_self.head;
|
||||
}
|
||||
|
||||
void llama_kv_cache_tokens_rm(struct llama_context * ctx, int32_t c0, int32_t c1) {
|
||||
llama_kv_cache_tokens_rm(ctx->kv_self, c0, c1);
|
||||
void llama_kv_cache_clear(struct llama_context * ctx) {
|
||||
llama_kv_cache_clear(ctx->kv_self);
|
||||
}
|
||||
|
||||
void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
||||
|
@ -9654,7 +9655,7 @@ int llama_eval(
|
|||
llama_token * tokens,
|
||||
int32_t n_tokens,
|
||||
int n_past) {
|
||||
llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1);
|
||||
llama_kv_cache_seq_rm(ctx->kv_self, -1, n_past, -1);
|
||||
|
||||
const int ret = llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0));
|
||||
if (ret < 0) {
|
||||
|
@ -9669,7 +9670,7 @@ int llama_eval_embd(
|
|||
float * embd,
|
||||
int32_t n_tokens,
|
||||
int n_past) {
|
||||
llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1);
|
||||
llama_kv_cache_seq_rm(ctx->kv_self, -1, n_past, -1);
|
||||
|
||||
llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, };
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue