From ebd4b91327cd8d7151ad652ce7264fb1d9b3d302 Mon Sep 17 00:00:00 2001 From: KerfuffleV2 Date: Sat, 28 Oct 2023 21:06:55 -0600 Subject: [PATCH] Extend llama_kv_cache_seq_rm to allow matching any sequence --- llama.cpp | 10 ++++++++-- llama.h | 5 +++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/llama.cpp b/llama.cpp index 3d431ee7b..f589ab860 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1492,8 +1492,14 @@ static void llama_kv_cache_seq_rm( if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max(); for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { - cache.cells[i].seq_id.erase(seq_id); + if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { + if (seq_id < 0) { + cache.cells[i].seq_id.clear(); + } else if (cache.cells[i].has_seq_id(seq_id)) { + cache.cells[i].seq_id.erase(seq_id); + } else { + continue; + } if (cache.cells[i].seq_id.empty()) { cache.cells[i].pos = -1; if (new_head == cache.size) new_head = i; diff --git a/llama.h b/llama.h index d901dcd91..1c50291be 100644 --- a/llama.h +++ b/llama.h @@ -342,8 +342,9 @@ extern "C" { int32_t c1); // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) - // p0 < 0 : [0, p1] - // p1 < 0 : [p0, inf) + // seq_id < 0 : match any sequence + // p0 < 0 : [0, p1] + // p1 < 0 : [p0, inf) LLAMA_API void llama_kv_cache_seq_rm( struct llama_context * ctx, llama_seq_id seq_id,