diff --git a/include/llama.h b/include/llama.h
index b969689be..3c7f89fe0 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -304,8 +304,8 @@ extern "C" {
         uint32_t n_batch;         // logical maximum batch size that can be submitted to llama_decode
         uint32_t n_ubatch;        // physical maximum batch size
         uint32_t n_seq_max;       // max number of sequences (i.e. distinct states for recurrent models)
-        int      n_threads;       // number of threads to use for generation
-        int      n_threads_batch; // number of threads to use for batch processing
+        int32_t  n_threads;       // number of threads to use for generation
+        int32_t  n_threads_batch; // number of threads to use for batch processing
 
         enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
         enum llama_pooling_type      pooling_type;      // whether to pool (sum) embedding results by sequence id
@@ -844,7 +844,7 @@ extern "C" {
     // Set the number of threads used for decoding
     // n_threads is the number of threads used for generation (single token)
     // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
-    LLAMA_API void llama_set_n_threads(struct llama_context * ctx, int n_threads, int n_threads_batch);
+    LLAMA_API void llama_set_n_threads(struct llama_context * ctx, int32_t n_threads, int32_t n_threads_batch);
 
     // Get the number of threads used for generation of a single token.
     LLAMA_API int llama_n_threads(struct llama_context * ctx);
diff --git a/src/llama.cpp b/src/llama.cpp
index fe1942c5d..fb5b76ebf 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -19389,7 +19389,7 @@ size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepa
     }
 }
 
-void llama_set_n_threads(struct llama_context * ctx, int n_threads, int n_threads_batch) {
+void llama_set_n_threads(struct llama_context * ctx, int32_t n_threads, int32_t n_threads_batch) {
     ctx->cparams.n_threads = n_threads;
     ctx->cparams.n_threads_batch = n_threads_batch;
 }
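
A minimal caller-side sketch (not part of the patch) of how the widened `int32_t` thread parameters are used, both in `llama_context_params` and via `llama_set_n_threads`. The model path and thread counts are illustrative placeholders; the sketch assumes the standard llama.h entry points of this version (`llama_backend_init`, `llama_load_model_from_file`, `llama_new_context_with_model`).

```c
#include <stdint.h>
#include "llama.h"

int main(void) {
    llama_backend_init();

    struct llama_model_params mparams = llama_model_default_params();
    // "model.gguf" is a placeholder path for this sketch
    struct llama_model * model = llama_load_model_from_file("model.gguf", mparams);
    if (model == NULL) {
        return 1;
    }

    struct llama_context_params cparams = llama_context_default_params();
    cparams.n_threads       = (int32_t) 8;  // generation (single-token) threads
    cparams.n_threads_batch = (int32_t) 16; // prompt/batch processing threads

    struct llama_context * ctx = llama_new_context_with_model(model, cparams);

    // The setter now takes int32_t as well, matching the struct fields
    llama_set_n_threads(ctx, (int32_t) 4, (int32_t) 8);

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}
```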