llama : update llama_kv_self API

ggml-ci
This commit is contained in:
Georgi Gerganov 2025-01-14 16:47:34 +02:00
parent fd05ab87aa
commit 17b363afd3
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
30 changed files with 387 additions and 205 deletions

View file

@ -469,7 +469,7 @@ extern "C" {
DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx); // TODO: remove const?
LLAMA_API struct llama_kv_cache * llama_get_kv_cache( struct llama_context * ctx);
LLAMA_API struct llama_kv_cache * llama_get_kv_self ( struct llama_context * ctx);
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
@ -641,28 +641,28 @@ extern "C" {
// Returns the number of tokens in the KV cache (slow, use only for debug)
// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
LLAMA_API int32_t llama_kv_cache_n_tokens(const struct llama_kv_cache * kv);
LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx);
DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx),
"use llama_kv_cache_n_tokens instead");
"use llama_kv_self_n_tokens instead");
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
LLAMA_API int32_t llama_kv_cache_used_cells(const struct llama_kv_cache * kv);
LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx);
DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx),
"use llama_kv_cache_used_cells instead");
"use llama_kv_self_used_cells instead");
// Clear the KV cache - both cell info is erased and KV data is zeroed
LLAMA_API void llama_kv_cache_clear(
struct llama_kv_cache * kv);
LLAMA_API void llama_kv_self_clear(
struct llama_context * ctx);
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
// Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
// seq_id < 0 : match any sequence
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API bool llama_kv_cache_seq_rm(
struct llama_kv_cache * kv,
LLAMA_API bool llama_kv_self_seq_rm(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1);
@ -671,26 +671,26 @@ extern "C" {
// Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API void llama_kv_cache_seq_cp(
struct llama_kv_cache * kv,
LLAMA_API void llama_kv_self_seq_cp(
struct llama_context * ctx,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1);
// Removes all tokens that do not belong to the specified sequence
LLAMA_API void llama_kv_cache_seq_keep(
struct llama_kv_cache * kv,
LLAMA_API void llama_kv_self_seq_keep(
struct llama_context * ctx,
llama_seq_id seq_id);
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
// If the KV cache is RoPEd, the KV data is updated accordingly:
// - lazily on next llama_decode()
// - explicitly with llama_kv_cache_update()
// - explicitly with llama_kv_self_update()
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API void llama_kv_cache_seq_add(
struct llama_kv_cache * kv,
LLAMA_API void llama_kv_self_seq_add(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
@ -699,32 +699,87 @@ extern "C" {
// Integer division of the positions by factor of `d > 1`
// If the KV cache is RoPEd, the KV data is updated accordingly:
// - lazily on next llama_decode()
// - explicitly with llama_kv_cache_update()
// - explicitly with llama_kv_self_update()
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API void llama_kv_cache_seq_div(
struct llama_kv_cache * kv,
LLAMA_API void llama_kv_self_seq_div(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d);
// Returns the largest position present in the KV cache for the specified sequence
LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
struct llama_kv_cache * kv,
LLAMA_API llama_pos llama_kv_self_seq_pos_max(
struct llama_context * ctx,
llama_seq_id seq_id);
// Defragment the KV cache
// This will be applied:
// - lazily on next llama_decode()
// - explicitly with llama_kv_cache_update()
LLAMA_API void llama_kv_cache_defrag(struct llama_kv_cache * kv);
// - explicitly with llama_kv_self_update()
LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx);
// Check if the context supports KV cache shifting
LLAMA_API bool llama_kv_cache_can_shift(const struct llama_kv_cache * kv);
LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
LLAMA_API void llama_update_kv_cache(struct llama_context * ctx, struct llama_kv_cache * kv);
LLAMA_API void llama_kv_self_update(struct llama_context * ctx);
DEPRECATED(LLAMA_API void llama_kv_cache_clear(
struct llama_context * ctx),
"use llama_kv_self_clear instead");
DEPRECATED(LLAMA_API bool llama_kv_cache_seq_rm(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1),
"use llama_kv_self_seq_rm instead");
DEPRECATED(LLAMA_API void llama_kv_cache_seq_cp(
struct llama_context * ctx,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1),
"use llama_kv_self_seq_cp instead");
DEPRECATED(LLAMA_API void llama_kv_cache_seq_keep(
struct llama_context * ctx,
llama_seq_id seq_id),
"use llama_kv_self_seq_keep instead");
DEPRECATED(LLAMA_API void llama_kv_cache_seq_add(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta),
"use llama_kv_self_seq_add instead");
DEPRECATED(LLAMA_API void llama_kv_cache_seq_div(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d),
"use llama_kv_self_seq_div instead");
DEPRECATED(LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
struct llama_context * ctx,
llama_seq_id seq_id),
"use llama_kv_self_seq_pos_max instead");
DEPRECATED(LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx),
"use llama_kv_self_defrag instead");
DEPRECATED(LLAMA_API bool llama_kv_cache_can_shift(const struct llama_context * ctx),
"use llama_kv_self_can_shift instead");
DEPRECATED(LLAMA_API void llama_kv_cache_update(struct llama_context * ctx),
"use llama_kv_self_update instead");
//
// State / sessions