llama : add llama_kv_cache_seq_pos_max()

This commit is contained in:
Georgi Gerganov 2024-02-24 12:54:29 +02:00
parent 18da970e1c
commit b75ec64ed2
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
3 changed files with 24 additions and 3 deletions

View file

@ -150,7 +150,7 @@ int main(int argc, char ** argv) {
llama_kv_cache_seq_div(ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
llama_kv_cache_update (ctx);
n_past -= bd;
n_past = llama_kv_cache_seq_pos_max(ctx, 0);
}
llama_batch_clear(batch);
@ -184,7 +184,7 @@ int main(int argc, char ** argv) {
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
llama_kv_cache_update (ctx);
n_past -= n_discard;
n_past = llama_kv_cache_seq_pos_max(ctx, 0);
llama_batch_clear(batch);
@ -214,7 +214,7 @@ int main(int argc, char ** argv) {
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
llama_kv_cache_update (ctx);
n_past -= n_discard;
n_past = llama_kv_cache_seq_pos_max(ctx, 0);
}
}

View file

@ -2241,6 +2241,18 @@ static void llama_kv_cache_seq_div(
}
}
static llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama_seq_id seq_id) {
llama_pos result = 0;
for (uint32_t i = 0; i < cache.size; ++i) {
if (cache.cells[i].has_seq_id(seq_id)) {
result = std::max(result, cache.cells[i].pos);
}
}
return result;
}
//
// model loading and saving
//
@ -12056,6 +12068,10 @@ void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, lla
llama_kv_cache_seq_div(ctx->kv_self, seq_id, p0, p1, d);
}
llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) {
return llama_kv_cache_seq_pos_max(ctx->kv_self, seq_id);
}
void llama_kv_cache_update(struct llama_context * ctx) {
llama_kv_cache_update_internal(*ctx);
}

View file

@ -547,6 +547,11 @@ extern "C" {
llama_pos p1,
int d);
// Returns the largest position present in the KV cache for the specified sequence
LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
struct llama_context * ctx,
llama_seq_id seq_id);
// Apply the KV cache updates (such as K-shifts) to the KV data
LLAMA_API void llama_kv_cache_update(struct llama_context * ctx);