threadpool: do not create two threadpools if their params are identical

This commit is contained in:
Max Krasnyansky 2024-08-08 16:26:49 -07:00 committed by fmz
parent 2e18f0d4c9
commit 48aa8eec07
3 changed files with 28 additions and 7 deletions

View file

@ -230,24 +230,31 @@ int main(int argc, char ** argv) {
struct ggml_threadpool_params tpp =
ggml_threadpool_params_from_cpu_params(params.cpuparams);
struct ggml_compute_threadpool * threadpool_batch = ggml_create_threadpool(&tpp_batch);
if (!threadpool_batch) {
LOG_TEE("%s: batch threadpool create failed : n_threads %d\n", __func__, tpp_batch.n_threads);
exit(1);
}
struct ggml_compute_threadpool * threadpool = ggml_create_threadpool(&tpp);
if (!threadpool) {
LOG_TEE("%s: threadpool create failed : n_threads %d\n", __func__, tpp.n_threads);
exit(1);
}
llama_attach_batch_threadpool(ctx, threadpool_batch);
llama_attach_threadpool(ctx, threadpool);
if (ctx_guidance) {
llama_attach_batch_threadpool(ctx_guidance, threadpool_batch);
llama_attach_threadpool(ctx_guidance, threadpool);
}
struct ggml_compute_threadpool * threadpool_batch = NULL;
if (!ggml_threadpool_params_match(&tpp, &tpp_batch)) {
threadpool_batch = ggml_create_threadpool(&tpp_batch);
if (!threadpool_batch) {
LOG_TEE("%s: batch threadpool create failed : n_threads %d\n", __func__, tpp_batch.n_threads);
exit(1);
}
llama_attach_batch_threadpool(ctx, threadpool_batch);
if (ctx_guidance) {
llama_attach_batch_threadpool(ctx_guidance, threadpool_batch);
}
}
const int n_ctx_train = llama_n_ctx_train(model);
const int n_ctx = llama_n_ctx(ctx);
LOG("n_ctx: %d\n", n_ctx);

View file

@ -2024,6 +2024,7 @@ extern "C" {
GGML_API size_t ggml_graph_overhead(void);
GGML_API size_t ggml_graph_overhead_custom(size_t size, bool grads);
GGML_API bool ggml_threadpool_params_match (const struct ggml_threadpool_params *p0, const struct ggml_threadpool_params *p1);
GGML_API struct ggml_compute_threadpool* ggml_create_threadpool (struct ggml_threadpool_params * params);
GGML_API void ggml_release_threadpool (struct ggml_compute_threadpool * threadpool);
GGML_API int32_t ggml_threadpool_get_n_threads(struct ggml_compute_threadpool * threadpool);

View file

@ -19222,6 +19222,19 @@ static thread_ret_t ggml_graph_compute_secondary_thread(void* data) {
#endif // GGML_USE_OPENMP
bool ggml_threadpool_params_match(const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1) {
if (p0->n_threads != p1->n_threads ) return false;
if (p0->prio != p1->prio ) return false;
if (p0->poll != p1->poll ) return false;
if (p0->strict_cpu != p1->strict_cpu ) return false;
if (p0->mask_specified != p1->mask_specified) return false;
if (p0->mask_specified) {
return memcmp(p0->cpumask, p1->cpumask, GGML_MAX_N_THREADS) == 0;
}
return true;
}
static struct ggml_compute_threadpool * ggml_create_threadpool_impl(
struct ggml_threadpool_params * tpp,
bool disposable,