add llama API functions to get grouped-query-attention n_head parameter 'n_head_kv'.

This commit is contained in:
xaedes 2023-09-09 17:07:54 +02:00
parent d7aade7d8a
commit 833a56c144
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1
2 changed files with 17 additions and 7 deletions

View file

@ -5669,6 +5669,10 @@ int llama_n_head(const struct llama_context * ctx) {
return ctx->model.hparams.n_head; return ctx->model.hparams.n_head;
} }
int llama_n_head_kv(const struct llama_context * ctx) {
return ctx->model.hparams.n_head_kv;
}
int llama_n_rot(const struct llama_context * ctx) { int llama_n_rot(const struct llama_context * ctx) {
return ctx->model.hparams.n_rot; return ctx->model.hparams.n_rot;
} }
@ -5701,6 +5705,10 @@ int llama_model_n_head(const struct llama_model * model) {
return model->hparams.n_head; return model->hparams.n_head;
} }
int llama_model_n_head_kv(const struct llama_model * model) {
return model->hparams.n_head_kv;
}
int llama_model_n_rot(const struct llama_model * model) { int llama_model_n_rot(const struct llama_model * model) {
return model->hparams.n_rot; return model->hparams.n_rot;
} }

View file

@ -250,6 +250,7 @@ extern "C" {
LLAMA_API int llama_n_embd (const struct llama_context * ctx); LLAMA_API int llama_n_embd (const struct llama_context * ctx);
LLAMA_API int llama_n_ff (const struct llama_context * ctx); LLAMA_API int llama_n_ff (const struct llama_context * ctx);
LLAMA_API int llama_n_head (const struct llama_context * ctx); LLAMA_API int llama_n_head (const struct llama_context * ctx);
LLAMA_API int llama_n_head_kv(const struct llama_context * ctx);
LLAMA_API int llama_n_rot (const struct llama_context * ctx); LLAMA_API int llama_n_rot (const struct llama_context * ctx);
LLAMA_API int llama_n_layer (const struct llama_context * ctx); LLAMA_API int llama_n_layer (const struct llama_context * ctx);
@ -260,6 +261,7 @@ extern "C" {
LLAMA_API int llama_model_n_embd (const struct llama_model * model); LLAMA_API int llama_model_n_embd (const struct llama_model * model);
LLAMA_API int llama_model_n_ff (const struct llama_model * model); LLAMA_API int llama_model_n_ff (const struct llama_model * model);
LLAMA_API int llama_model_n_head (const struct llama_model * model); LLAMA_API int llama_model_n_head (const struct llama_model * model);
LLAMA_API int llama_model_n_head_kv(const struct llama_model * model);
LLAMA_API int llama_model_n_rot (const struct llama_model * model); LLAMA_API int llama_model_n_rot (const struct llama_model * model);
LLAMA_API int llama_model_n_layer(const struct llama_model * model); LLAMA_API int llama_model_n_layer(const struct llama_model * model);