diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 1628a42a9..5e560289d 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -230,24 +230,31 @@ int main(int argc, char ** argv) {
 
     struct ggml_threadpool_params tpp = ggml_threadpool_params_from_cpu_params(params.cpuparams);
 
-    struct ggml_compute_threadpool * threadpool_batch = ggml_create_threadpool(&tpp_batch);
-    if (!threadpool_batch) {
-        LOG_TEE("%s: batch threadpool create failed : n_threads %d\n", __func__, tpp_batch.n_threads);
-        exit(1);
-    }
     struct ggml_compute_threadpool * threadpool = ggml_create_threadpool(&tpp);
     if (!threadpool) {
         LOG_TEE("%s: threadpool create failed : n_threads %d\n", __func__, tpp.n_threads);
         exit(1);
     }
 
-    llama_attach_batch_threadpool(ctx, threadpool_batch);
     llama_attach_threadpool(ctx, threadpool);
     if (ctx_guidance) {
-        llama_attach_batch_threadpool(ctx_guidance, threadpool_batch);
         llama_attach_threadpool(ctx_guidance, threadpool);
     }
 
+    struct ggml_compute_threadpool * threadpool_batch = NULL;
+    if (!ggml_threadpool_params_match(&tpp, &tpp_batch)) {
+        threadpool_batch = ggml_create_threadpool(&tpp_batch);
+        if (!threadpool_batch) {
+            LOG_TEE("%s: batch threadpool create failed : n_threads %d\n", __func__, tpp_batch.n_threads);
+            exit(1);
+        }
+
+        llama_attach_batch_threadpool(ctx, threadpool_batch);
+        if (ctx_guidance) {
+            llama_attach_batch_threadpool(ctx_guidance, threadpool_batch);
+        }
+    }
+
     const int n_ctx_train = llama_n_ctx_train(model);
     const int n_ctx = llama_n_ctx(ctx);
     LOG("n_ctx: %d\n", n_ctx);
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 9b29a3af7..173b3b22e 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -2024,6 +2024,7 @@ extern "C" {
     GGML_API size_t ggml_graph_overhead(void);
     GGML_API size_t ggml_graph_overhead_custom(size_t size, bool grads);
 
+    GGML_API bool ggml_threadpool_params_match (const struct ggml_threadpool_params *p0, const struct ggml_threadpool_params *p1);
     GGML_API struct ggml_compute_threadpool* ggml_create_threadpool (struct ggml_threadpool_params * params);
     GGML_API void ggml_release_threadpool (struct ggml_compute_threadpool * threadpool);
     GGML_API int32_t ggml_threadpool_get_n_threads(struct ggml_compute_threadpool * threadpool);
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 3dbc1244c..5a6d313ae 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -19222,6 +19222,19 @@ static thread_ret_t ggml_graph_compute_secondary_thread(void* data) {
 
 #endif // GGML_USE_OPENMP
 
+bool ggml_threadpool_params_match(const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1) {
+    if (p0->n_threads      != p1->n_threads      ) return false;
+    if (p0->prio           != p1->prio           ) return false;
+    if (p0->poll           != p1->poll           ) return false;
+    if (p0->strict_cpu     != p1->strict_cpu     ) return false;
+    if (p0->mask_specified != p1->mask_specified ) return false;
+    if (p0->mask_specified) {
+        return memcmp(p0->cpumask, p1->cpumask, GGML_MAX_N_THREADS) == 0;
+    }
+
+    return true;
+}
+
 static struct ggml_compute_threadpool * ggml_create_threadpool_impl(
     struct ggml_threadpool_params * tpp,
     bool disposable,
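Not part of the diff: a minimal caller-side sketch of the pattern the main.cpp hunk introduces, where one threadpool is always created and attached for generation, and a second, dedicated batch threadpool is created only when ggml_threadpool_params_match() reports that the batch configuration differs. The helper name attach_threadpools and the stderr error handling are illustrative only; the ggml_*/llama_* calls are the ones added or used in this change.

```c
#include "ggml.h"
#include "llama.h"

#include <stdio.h>
#include <stdlib.h>

// Hypothetical helper (not in the PR): create and attach threadpools for a
// context, creating a separate batch pool only when the batch params differ
// from the generation params. Both pools are handed back to the caller, which
// releases them with ggml_release_threadpool() at shutdown.
static void attach_threadpools(
        struct llama_context            * ctx,
        struct ggml_threadpool_params   * tpp,        // generation params
        struct ggml_threadpool_params   * tpp_batch,  // prompt-processing params
        struct ggml_compute_threadpool ** out_tp,
        struct ggml_compute_threadpool ** out_tp_batch) {
    *out_tp = ggml_create_threadpool(tpp);
    if (!*out_tp) {
        fprintf(stderr, "threadpool create failed : n_threads %d\n", tpp->n_threads);
        exit(1);
    }
    llama_attach_threadpool(ctx, *out_tp);

    // Only spin up a second pool when the batch configuration really differs;
    // otherwise the generation threadpool serves both phases.
    *out_tp_batch = NULL;
    if (!ggml_threadpool_params_match(tpp, tpp_batch)) {
        *out_tp_batch = ggml_create_threadpool(tpp_batch);
        if (!*out_tp_batch) {
            fprintf(stderr, "batch threadpool create failed : n_threads %d\n", tpp_batch->n_threads);
            exit(1);
        }
        llama_attach_batch_threadpool(ctx, *out_tp_batch);
    }
}
```

The point of the match check is that with identical generation and batch CPU parameters (the common case), no extra threads are spawned and the single threadpool serves both prompt processing and generation.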