diff --git a/llama.cpp b/llama.cpp index 4225f9555..4f92488d1 100644 --- a/llama.cpp +++ b/llama.cpp @@ -12496,6 +12496,14 @@ int32_t llama_n_embd(const struct llama_model * model) { return model->hparams.n_embd; } +int32_t llama_n_layers(const struct llama_model * model) { + return model->hparams.n_layer; +} + +int32_t llama_n_heads(const struct llama_model * model) { + return model->hparams.n_head; +} + float llama_rope_freq_scale_train(const struct llama_model * model) { return model->hparams.rope_freq_scale_train; } @@ -13153,6 +13161,13 @@ void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_ ctx->cparams.n_threads_batch = n_threads_batch; } +void llama_get_n_threads(struct llama_context * ctx, uint32_t * n_threads, uint32_t * n_threads_batch) { + assert(n_threads); + assert(n_threads_batch); + *n_threads = ctx->cparams.n_threads; + *n_threads_batch = ctx->cparams.n_threads_batch; +} + void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) { ctx->abort_callback = abort_callback; ctx->abort_callback_data = abort_callback_data; diff --git a/llama.h b/llama.h index 3dc162b07..3eabaa9b6 100644 --- a/llama.h +++ b/llama.h @@ -383,6 +383,8 @@ extern "C" { LLAMA_API int32_t llama_n_vocab (const struct llama_model * model); LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model); LLAMA_API int32_t llama_n_embd (const struct llama_model * model); + LLAMA_API int32_t llama_n_layers (const struct llama_model * model); + LLAMA_API int32_t llama_n_heads (const struct llama_model * model); // Get the model's RoPE frequency scaling factor LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model); @@ -640,6 +642,8 @@ extern "C" { // n_threads is the number of threads used for generation (single token) // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens) LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch); + // Get the number of threads used for decoding + LLAMA_API void llama_get_n_threads(struct llama_context * ctx, uint32_t * n_threads, uint32_t * n_threads_batch); // Set abort callback LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);