additional methods to read model and ctx parameters

This commit is contained in:
Michael Podvitskiy 2024-02-26 11:58:25 +01:00
parent 89fb735fcf
commit e700b44217
2 changed files with 19 additions and 0 deletions

View file

@ -12496,6 +12496,14 @@ int32_t llama_n_embd(const struct llama_model * model) {
return model->hparams.n_embd;
}
int32_t llama_n_layers(const struct llama_model * model) {
return model->hparams.n_layer;
}
int32_t llama_n_heads(const struct llama_model * model) {
return model->hparams.n_head;
}
float llama_rope_freq_scale_train(const struct llama_model * model) {
return model->hparams.rope_freq_scale_train;
}
@ -13153,6 +13161,13 @@ void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_
ctx->cparams.n_threads_batch = n_threads_batch;
}
void llama_get_n_threads(struct llama_context * ctx, uint32_t * n_threads, uint32_t * n_threads_batch) {
assert(n_threads);
assert(n_threads_batch);
*n_threads = ctx->cparams.n_threads;
*n_threads_batch = ctx->cparams.n_threads_batch;
}
void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) {
ctx->abort_callback = abort_callback;
ctx->abort_callback_data = abort_callback_data;

View file

@ -383,6 +383,8 @@ extern "C" {
LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
LLAMA_API int32_t llama_n_embd (const struct llama_model * model);
LLAMA_API int32_t llama_n_layers (const struct llama_model * model);
LLAMA_API int32_t llama_n_heads (const struct llama_model * model);
// Get the model's RoPE frequency scaling factor
LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model);
@ -640,6 +642,8 @@ extern "C" {
// n_threads is the number of threads used for generation (single token)
// n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);
// Get the number of threads used for decoding
LLAMA_API void llama_get_n_threads(struct llama_context * ctx, uint32_t * n_threads, uint32_t * n_threads_batch);
// Set abort callback
LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);