diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index 574728f89..1e483edc0 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -150,7 +150,7 @@ int main(int argc, char ** argv) { llama_kv_cache_seq_div(ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp); llama_kv_cache_update (ctx); - n_past -= bd; + n_past = llama_kv_cache_seq_pos_max(ctx, 0); } llama_batch_clear(batch); @@ -184,7 +184,7 @@ int main(int argc, char ** argv) { llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); llama_kv_cache_update (ctx); - n_past -= n_discard; + n_past = llama_kv_cache_seq_pos_max(ctx, 0); llama_batch_clear(batch); @@ -214,7 +214,7 @@ int main(int argc, char ** argv) { llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); llama_kv_cache_update (ctx); - n_past -= n_discard; + n_past = llama_kv_cache_seq_pos_max(ctx, 0); } } diff --git a/llama.cpp b/llama.cpp index 263fdf13e..46c82b4ad 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2241,6 +2241,18 @@ static void llama_kv_cache_seq_div( } } +static llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama_seq_id seq_id) { + llama_pos result = 0; + + for (uint32_t i = 0; i < cache.size; ++i) { + if (cache.cells[i].has_seq_id(seq_id)) { + result = std::max(result, cache.cells[i].pos); + } + } + + return result; +} + // // model loading and saving // @@ -12056,6 +12068,10 @@ void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, lla llama_kv_cache_seq_div(ctx->kv_self, seq_id, p0, p1, d); } +llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) { + return llama_kv_cache_seq_pos_max(ctx->kv_self, seq_id); +} + void llama_kv_cache_update(struct llama_context * ctx) { llama_kv_cache_update_internal(*ctx); } diff --git a/llama.h b/llama.h index b1621d6a3..faea891e4 100644 --- a/llama.h +++ b/llama.h @@ -547,6 +547,11 @@ extern "C" { llama_pos p1, int d); + // Returns the largest position present in the KV cache for the specified sequence + LLAMA_API llama_pos llama_kv_cache_seq_pos_max( + struct llama_context * ctx, + llama_seq_id seq_id); + // Apply the KV cache updates (such as K-shifts) to the KV data LLAMA_API void llama_kv_cache_update(struct llama_context * ctx);