use int32_t for n_thread type in public llama.cpp API

Max Krasnyansky 2024-08-28 21:17:11 -07:00
parent b97bd67e2b
commit cae35b9fb9
2 changed files with 4 additions and 4 deletions

include/llama.h

@@ -304,8 +304,8 @@ extern "C" {
         uint32_t n_batch;    // logical maximum batch size that can be submitted to llama_decode
         uint32_t n_ubatch;   // physical maximum batch size
         uint32_t n_seq_max;  // max number of sequences (i.e. distinct states for recurrent models)
-        int      n_threads;       // number of threads to use for generation
-        int      n_threads_batch; // number of threads to use for batch processing
+        int32_t  n_threads;       // number of threads to use for generation
+        int32_t  n_threads_batch; // number of threads to use for batch processing
 
         enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
         enum llama_pooling_type      pooling_type;      // whether to pool (sum) embedding results by sequence id
@@ -844,7 +844,7 @@ extern "C" {
     // Set the number of threads used for decoding
     // n_threads is the number of threads used for generation (single token)
     // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
-    LLAMA_API void llama_set_n_threads(struct llama_context * ctx, int n_threads, int n_threads_batch);
+    LLAMA_API void llama_set_n_threads(struct llama_context * ctx, int32_t n_threads, int32_t n_threads_batch);
 
     // Get the number of threads used for generation of a single token.
     LLAMA_API int llama_n_threads(struct llama_context * ctx);

src/llama.cpp

@@ -19389,7 +19389,7 @@ size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepath
     }
 }
 
-void llama_set_n_threads(struct llama_context * ctx, int n_threads, int n_threads_batch) {
+void llama_set_n_threads(struct llama_context * ctx, int32_t n_threads, int32_t n_threads_batch) {
     ctx->cparams.n_threads       = n_threads;
     ctx->cparams.n_threads_batch = n_threads_batch;
 }
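
For illustration, a minimal caller-side sketch of the updated API (not part of this commit; helper names make_ctx and retune_threads are hypothetical, and model loading and error handling are elided). Both the context-creation parameters and the runtime setter now take explicitly sized 32-bit thread counts:

#include <stdint.h>
#include "llama.h"

// Set thread counts at context creation time via llama_context_params.
// The values 4 and 8 are placeholders; tune them for the host CPU.
static struct llama_context * make_ctx(struct llama_model * model) {
    struct llama_context_params cparams = llama_context_default_params();
    cparams.n_threads       = 4;  // int32_t after this change: single-token generation
    cparams.n_threads_batch = 8;  // int32_t after this change: prompt/batch processing
    return llama_new_context_with_model(model, cparams);
}

// Or retune the counts later through the setter changed above.
static void retune_threads(struct llama_context * ctx) {
    llama_set_n_threads(ctx, /*n_threads =*/ 2, /*n_threads_batch =*/ 4);
}

Using a fixed-width type here matches the surrounding uint32_t fields and keeps the public ABI unambiguous for foreign-function bindings generated from the C header, where the width of plain int would otherwise be platform-defined.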