diff --git a/common/common.cpp b/common/common.cpp index f81f4d354..c187128d6 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -889,7 +889,7 @@ std::tuple llama_init_from_gpt_par std::vector tmp = { llama_token_bos(model), llama_token_eos(model), }; llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); - llama_kv_cache_tokens_rm(lctx, -1, -1); + llama_kv_cache_clear(lctx); llama_reset_timings(lctx); } diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 43f9c971d..533c55c17 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -185,7 +185,7 @@ int main(int argc, char ** argv) { const auto t_pp_start = ggml_time_us(); - llama_kv_cache_tokens_rm(ctx, -1, -1); + llama_kv_cache_clear(ctx); if (!decode_helper(ctx, batch, ctx_params.n_batch)) { LOG_TEE("%s: llama_decode() failed\n", __func__); diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 20767d555..780398184 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -1037,7 +1037,7 @@ int main(int argc, char ** argv) { test t(inst, lmodel, ctx); - llama_kv_cache_tokens_rm(ctx, -1, -1); + llama_kv_cache_clear(ctx); // warmup run if (t.n_prompt > 0) { @@ -1048,7 +1048,7 @@ int main(int argc, char ** argv) { } for (int i = 0; i < params.reps; i++) { - llama_kv_cache_tokens_rm(ctx, -1, -1); + llama_kv_cache_clear(ctx); uint64_t t_start = get_time_ns(); if (t.n_prompt > 0) { diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 3d9f670b9..8a43b6ab8 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -298,7 +298,7 @@ int main(int argc, char ** argv) { } // remove any "future" tokens that we might have inherited from the previous session - llama_kv_cache_tokens_rm(ctx, n_matching_session_tokens, -1); + llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1); } LOGLN( diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 3c2542e8c..bd2c73d87 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -210,7 +210,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & const auto t_start = std::chrono::high_resolution_clock::now(); // clear the KV cache - llama_kv_cache_tokens_rm(ctx, -1, -1); + llama_kv_cache_clear(ctx); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; @@ -339,7 +339,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par const auto t_start = std::chrono::high_resolution_clock::now(); // clear the KV cache - llama_kv_cache_tokens_rm(ctx, -1, -1); + llama_kv_cache_clear(ctx); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; @@ -573,7 +573,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { } // clear the KV cache - llama_kv_cache_tokens_rm(ctx, -1, -1); + llama_kv_cache_clear(ctx); auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab); if (logits.empty()) { diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 5b7e4139d..c163c7f8e 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -857,7 +857,7 @@ struct llama_server_context void kv_cache_clear() { // clear the entire KV cache - llama_kv_cache_tokens_rm(ctx, -1, -1); + llama_kv_cache_clear(ctx); clean_kv_cache = false; } diff --git a/llama.cpp b/llama.cpp index f589ab860..68eb7aaca 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1468,17 +1468,12 @@ static int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) { return 0; } -static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0, int32_t c1) { - if (c0 < 0) c0 = 0; - if (c1 < 0) c1 = cache.size; - - for (int32_t i = c0; i < c1; ++i) { +static void llama_kv_cache_clear(struct llama_kv_cache & cache) { + for (int32_t i = 0; i < cache.size; ++i) { cache.cells[i].pos = -1; cache.cells[i].seq_id.clear(); } - - // Searching for a free slot can start here since we know it will be empty. - cache.head = uint32_t(c0); + cache.head = 0; } static void llama_kv_cache_seq_rm( @@ -9210,8 +9205,8 @@ int llama_get_kv_cache_token_count(const struct llama_context * ctx) { return ctx->kv_self.head; } -void llama_kv_cache_tokens_rm(struct llama_context * ctx, int32_t c0, int32_t c1) { - llama_kv_cache_tokens_rm(ctx->kv_self, c0, c1); +void llama_kv_cache_clear(struct llama_context * ctx) { + llama_kv_cache_clear(ctx->kv_self); } void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { @@ -9657,7 +9652,7 @@ int llama_eval( llama_token * tokens, int32_t n_tokens, int n_past) { - llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1); + llama_kv_cache_seq_rm(ctx->kv_self, -1, n_past, -1); const int ret = llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0)); if (ret < 0) { @@ -9672,7 +9667,7 @@ int llama_eval_embd( float * embd, int32_t n_tokens, int n_past) { - llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1); + llama_kv_cache_seq_rm(ctx->kv_self, -1, n_past, -1); llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, }; diff --git a/llama.h b/llama.h index 1c50291be..d59f64891 100644 --- a/llama.h +++ b/llama.h @@ -333,13 +333,9 @@ extern "C" { LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx), "avoid using this, it will be removed in the future, instead - count the tokens in user code"); - // Remove all tokens data of cells in [c0, c1) - // c0 < 0 : [0, c1] - // c1 < 0 : [c0, inf) - LLAMA_API void llama_kv_cache_tokens_rm( - struct llama_context * ctx, - int32_t c0, - int32_t c1); + // Clear the KV cache + LLAMA_API void llama_kv_cache_clear( + struct llama_context * ctx); // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) // seq_id < 0 : match any sequence