llama : add llama_kv_cache_seq_pos_max()
This commit is contained in:
parent
18da970e1c
commit
b75ec64ed2
3 changed files with 24 additions and 3 deletions
|
@ -150,7 +150,7 @@ int main(int argc, char ** argv) {
|
|||
llama_kv_cache_seq_div(ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
|
||||
llama_kv_cache_update (ctx);
|
||||
|
||||
n_past -= bd;
|
||||
n_past = llama_kv_cache_seq_pos_max(ctx, 0);
|
||||
}
|
||||
|
||||
llama_batch_clear(batch);
|
||||
|
@ -184,7 +184,7 @@ int main(int argc, char ** argv) {
|
|||
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
|
||||
llama_kv_cache_update (ctx);
|
||||
|
||||
n_past -= n_discard;
|
||||
n_past = llama_kv_cache_seq_pos_max(ctx, 0);
|
||||
|
||||
llama_batch_clear(batch);
|
||||
|
||||
|
@ -214,7 +214,7 @@ int main(int argc, char ** argv) {
|
|||
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
|
||||
llama_kv_cache_update (ctx);
|
||||
|
||||
n_past -= n_discard;
|
||||
n_past = llama_kv_cache_seq_pos_max(ctx, 0);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
16
llama.cpp
16
llama.cpp
|
@ -2241,6 +2241,18 @@ static void llama_kv_cache_seq_div(
|
|||
}
|
||||
}
|
||||
|
||||
static llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama_seq_id seq_id) {
|
||||
llama_pos result = 0;
|
||||
|
||||
for (uint32_t i = 0; i < cache.size; ++i) {
|
||||
if (cache.cells[i].has_seq_id(seq_id)) {
|
||||
result = std::max(result, cache.cells[i].pos);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
//
|
||||
// model loading and saving
|
||||
//
|
||||
|
@ -12056,6 +12068,10 @@ void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, lla
|
|||
llama_kv_cache_seq_div(ctx->kv_self, seq_id, p0, p1, d);
|
||||
}
|
||||
|
||||
llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) {
|
||||
return llama_kv_cache_seq_pos_max(ctx->kv_self, seq_id);
|
||||
}
|
||||
|
||||
void llama_kv_cache_update(struct llama_context * ctx) {
|
||||
llama_kv_cache_update_internal(*ctx);
|
||||
}
|
||||
|
|
5
llama.h
5
llama.h
|
@ -547,6 +547,11 @@ extern "C" {
|
|||
llama_pos p1,
|
||||
int d);
|
||||
|
||||
// Returns the largest position present in the KV cache for the specified sequence
|
||||
LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id);
|
||||
|
||||
// Apply the KV cache updates (such as K-shifts) to the KV data
|
||||
LLAMA_API void llama_kv_cache_update(struct llama_context * ctx);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue