additional methods to read model and ctx parameters
This commit is contained in:
parent
89fb735fcf
commit
e700b44217
2 changed files with 19 additions and 0 deletions
15
llama.cpp
15
llama.cpp
|
@ -12496,6 +12496,14 @@ int32_t llama_n_embd(const struct llama_model * model) {
|
|||
return model->hparams.n_embd;
|
||||
}
|
||||
|
||||
int32_t llama_n_layers(const struct llama_model * model) {
|
||||
return model->hparams.n_layer;
|
||||
}
|
||||
|
||||
int32_t llama_n_heads(const struct llama_model * model) {
|
||||
return model->hparams.n_head;
|
||||
}
|
||||
|
||||
float llama_rope_freq_scale_train(const struct llama_model * model) {
|
||||
return model->hparams.rope_freq_scale_train;
|
||||
}
|
||||
|
@ -13153,6 +13161,13 @@ void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_
|
|||
ctx->cparams.n_threads_batch = n_threads_batch;
|
||||
}
|
||||
|
||||
void llama_get_n_threads(struct llama_context * ctx, uint32_t * n_threads, uint32_t * n_threads_batch) {
|
||||
assert(n_threads);
|
||||
assert(n_threads_batch);
|
||||
*n_threads = ctx->cparams.n_threads;
|
||||
*n_threads_batch = ctx->cparams.n_threads_batch;
|
||||
}
|
||||
|
||||
void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) {
|
||||
ctx->abort_callback = abort_callback;
|
||||
ctx->abort_callback_data = abort_callback_data;
|
||||
|
|
4
llama.h
4
llama.h
|
@ -383,6 +383,8 @@ extern "C" {
|
|||
LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_n_embd (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_n_layers (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_n_heads (const struct llama_model * model);
|
||||
|
||||
// Get the model's RoPE frequency scaling factor
|
||||
LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model);
|
||||
|
@ -640,6 +642,8 @@ extern "C" {
|
|||
// n_threads is the number of threads used for generation (single token)
|
||||
// n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
|
||||
LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);
|
||||
// Get the number of threads used for decoding
|
||||
LLAMA_API void llama_get_n_threads(struct llama_context * ctx, uint32_t * n_threads, uint32_t * n_threads_batch);
|
||||
|
||||
// Set abort callback
|
||||
LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue