diff --git a/llama.cpp b/llama.cpp index 705b19651..55aa955d7 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5669,6 +5669,10 @@ int llama_n_head(const struct llama_context * ctx) { return ctx->model.hparams.n_head; } +int llama_n_head_kv(const struct llama_context * ctx) { + return ctx->model.hparams.n_head_kv; +} + int llama_n_rot(const struct llama_context * ctx) { return ctx->model.hparams.n_rot; } @@ -5701,6 +5705,10 @@ int llama_model_n_head(const struct llama_model * model) { return model->hparams.n_head; } +int llama_model_n_head_kv(const struct llama_model * model) { + return model->hparams.n_head_kv; +} + int llama_model_n_rot(const struct llama_model * model) { return model->hparams.n_rot; } diff --git a/llama.h b/llama.h index f702c54d8..c930a48d0 100644 --- a/llama.h +++ b/llama.h @@ -245,13 +245,14 @@ extern "C" { LLAMA_API bool llama_mmap_supported (void); LLAMA_API bool llama_mlock_supported(void); - LLAMA_API int llama_n_vocab(const struct llama_context * ctx); - LLAMA_API int llama_n_ctx (const struct llama_context * ctx); - LLAMA_API int llama_n_embd (const struct llama_context * ctx); - LLAMA_API int llama_n_ff (const struct llama_context * ctx); - LLAMA_API int llama_n_head (const struct llama_context * ctx); - LLAMA_API int llama_n_rot (const struct llama_context * ctx); - LLAMA_API int llama_n_layer(const struct llama_context * ctx); + LLAMA_API int llama_n_vocab (const struct llama_context * ctx); + LLAMA_API int llama_n_ctx (const struct llama_context * ctx); + LLAMA_API int llama_n_embd (const struct llama_context * ctx); + LLAMA_API int llama_n_ff (const struct llama_context * ctx); + LLAMA_API int llama_n_head (const struct llama_context * ctx); + LLAMA_API int llama_n_head_kv(const struct llama_context * ctx); + LLAMA_API int llama_n_rot (const struct llama_context * ctx); + LLAMA_API int llama_n_layer (const struct llama_context * ctx); LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx); @@ -260,6 +261,7 @@ extern "C" { LLAMA_API int llama_model_n_embd (const struct llama_model * model); LLAMA_API int llama_model_n_ff (const struct llama_model * model); LLAMA_API int llama_model_n_head (const struct llama_model * model); + LLAMA_API int llama_model_n_head_kv(const struct llama_model * model); LLAMA_API int llama_model_n_rot (const struct llama_model * model); LLAMA_API int llama_model_n_layer(const struct llama_model * model);