diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index b20101353..b6b1efe02 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -1537,7 +1537,7 @@ int main(int argc, char ** argv) {
             exit(1);
         }
 
-        llama_attach_threadpool(ctx, threadpool);
+        llama_attach_threadpool(ctx, threadpool, NULL);
 
         // warmup run
         if (t.n_prompt > 0) {
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index bdaf0dbb6..0ccd0558f 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -240,11 +240,6 @@ int main(int argc, char ** argv) {
             exit(1);
         }
 
-        llama_attach_batch_threadpool(ctx, threadpool_batch);
-        if (ctx_guidance) {
-            llama_attach_batch_threadpool(ctx_guidance, threadpool_batch);
-        }
-
         // Start the non-batch threadpool in the paused state
         tpp.paused = true;
     }
@@ -255,9 +250,9 @@ int main(int argc, char ** argv) {
         exit(1);
     }
 
-    llama_attach_threadpool(ctx, threadpool);
+    llama_attach_threadpool(ctx, threadpool, threadpool_batch);
     if (ctx_guidance) {
-        llama_attach_threadpool(ctx_guidance, threadpool);
+        llama_attach_threadpool(ctx_guidance, threadpool, threadpool_batch);
     }
 
     const int n_ctx_train = llama_n_ctx_train(model);
diff --git a/ggml/src/ggml-backend.c b/ggml/src/ggml-backend.c
index 826b99ac0..03e41a09c 100644
--- a/ggml/src/ggml-backend.c
+++ b/ggml/src/ggml-backend.c
@@ -910,6 +910,11 @@ void ggml_backend_cpu_set_threadpool(ggml_backend_t backend_cpu, ggml_compute_th
     GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
 
     struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
+
+    if (ctx->threadpool && ctx->threadpool != threadpool) {
+        // already had a different threadpool, pause/suspend it before switching
+        ggml_pause_threadpool(ctx->threadpool);
+    }
     ctx->threadpool = threadpool;
 }
 
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index dd08b77f8..f05f89a27 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -19198,9 +19198,6 @@ static thread_ret_t ggml_graph_compute_secondary_thread(void* data) {
             state->pending = false;
 
             ggml_graph_compute_thread(state);
-            if (state->threadpool->ec != GGML_STATUS_SUCCESS) {
-                break;
-            }
         }
     }
 
diff --git a/include/llama.h b/include/llama.h
index 7b103261d..c03c4929b 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -431,16 +431,9 @@ extern "C" {
     // Optional: an auto threadpool gets created in ggml if not passed explicitly
     LLAMA_API void llama_attach_threadpool(
             struct llama_context * ctx,
-            ggml_compute_threadpool_t threadpool);
-    LLAMA_API void llama_attach_batch_threadpool(
-            struct llama_context * ctx,
-            ggml_compute_threadpool_t threadpool);
+            ggml_compute_threadpool_t threadpool,
+            ggml_compute_threadpool_t threadpool_batch);
     LLAMA_API void llama_detach_threadpool(struct llama_context * ctx);
-    LLAMA_API void llama_detach_batch_threadpool(struct llama_context * ctx);
-    LLAMA_API void llama_detach_threadpools(struct llama_context * ctx);
-
-    // Pauses all attached threadpools
-    LLAMA_API void llama_pause_threadpools(struct llama_context * ctx);
 
     // Call once at the end of the program - currently only used for MPI
     LLAMA_API void llama_backend_free(void);
diff --git a/src/llama.cpp b/src/llama.cpp
index 916d0f8c1..57e765ce0 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -15523,39 +15523,6 @@ static void llama_graph_compute(
     // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
 }
 
-// Optionally swaps the batch and single-tok threadpools.
-// Returns the number of threads, and if a valid threadpool exists, returns it too.
-static std::pair<int32_t, ggml_compute_threadpool_t> llama_swap_threadpools(
-        llama_context & lctx,
-        int32_t n_tokens) {
-
-    const auto & cparams = lctx.cparams;
-    int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
-
-    ggml_compute_threadpool_t threadpool = nullptr; // nullptr -> disposable threadpool
-
-    // A batch threadpool without a non-batch threadpool isn't supported.
-    GGML_ASSERT(!lctx.threadpool_batch || lctx.threadpool);
-
-    if (lctx.threadpool_batch && lctx.threadpool) {
-        // Switch between the 2 threadpools as needed
-        if (n_tokens > 1) {
-            ggml_pause_threadpool(lctx.threadpool);
-            threadpool = lctx.threadpool_batch;
-            n_threads = cparams.n_threads_batch;
-        } else {
-            ggml_pause_threadpool(lctx.threadpool_batch);
-            threadpool = lctx.threadpool;
-            n_threads = cparams.n_threads;
-        }
-    } else if (lctx.threadpool) {
-        threadpool = lctx.threadpool;
-        n_threads = cparams.n_threads;
-    }
-    return std::make_pair(n_threads, threadpool);
-}
-
-
 // decode a batch of tokens by evaluating the transformer
 //
 // - lctx:      llama context
@@ -15662,11 +15629,8 @@ static int llama_decode_internal(
             lctx.n_outputs = n_outputs_new;
         }
 
-        std::pair<int32_t, ggml_compute_threadpool_t> threads =
-            llama_swap_threadpools(lctx, n_tokens);
-
-        int n_threads = threads.first;
-        ggml_compute_threadpool_t threadpool = threads.second;
+        int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+        ggml_compute_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
 
         GGML_ASSERT(n_threads > 0);
 
@@ -15906,11 +15870,9 @@ static int llama_encode_internal(
     lctx.inp_embd_enc = NULL;
     lctx.n_outputs = n_tokens;
 
-    std::pair<int32_t, ggml_compute_threadpool_t> threads =
-        llama_swap_threadpools(lctx, n_tokens);
+    int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+    ggml_compute_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
 
-    int n_threads = threads.first;
-    ggml_compute_threadpool_t threadpool = threads.second;
     GGML_ASSERT(n_threads > 0);
 
     ggml_backend_sched_reset(lctx.sched);
@@ -17500,36 +17462,15 @@ void llama_numa_init(enum ggml_numa_strategy numa) {
 
 void llama_attach_threadpool(
         struct llama_context * ctx,
-        ggml_compute_threadpool_t threadpool) {
-    ctx->threadpool = threadpool;
-}
-
-void llama_attach_batch_threadpool(
-        struct llama_context * ctx,
+        ggml_compute_threadpool_t threadpool,
         ggml_compute_threadpool_t threadpool_batch) {
-    ctx->threadpool_batch = threadpool_batch;
+    ctx->threadpool = threadpool;
+    ctx->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool;
 }
 
 void llama_detach_threadpool(struct llama_context * ctx) {
-    ctx->threadpool = nullptr;
-}
-
-void llama_detach_batch_threadpool(struct llama_context * ctx) {
-    ctx->threadpool = nullptr;
-}
-
-void llama_detach_threadpools(struct llama_context * ctx) {
-    llama_detach_threadpool(ctx);
-    llama_detach_batch_threadpool(ctx);
-}
-
-void llama_pause_threadpools(struct llama_context * ctx) {
-    if (ctx->threadpool) {
-        ggml_pause_threadpool(ctx->threadpool);
-    }
-    if (ctx->threadpool_batch) {
-        ggml_pause_threadpool(ctx->threadpool_batch);
-    }
+    ctx->threadpool = nullptr;
+    ctx->threadpool_batch = nullptr;
 }
 
 void llama_backend_free(void) {
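For reference, a minimal caller-side sketch of the consolidated API after this patch. The attach/detach calls, the `ggml_compute_threadpool_t` type, and the NULL-fallback semantics are taken directly from the diff above; the creation/teardown helpers (`ggml_threadpool_params_default`, `ggml_create_threadpool`, `ggml_release_threadpool`) are assumptions based on the examples touched by this PR and may be named differently in the final API.

#include "ggml.h"
#include "llama.h"

// Sketch only: the creation/release helpers used here are assumed, not part of this diff.
static void run_with_threadpools(struct llama_context * ctx, int32_t n_threads, int32_t n_threads_batch) {
    struct ggml_threadpool_params tpp       = ggml_threadpool_params_default(n_threads);       // assumed helper
    struct ggml_threadpool_params tpp_batch = ggml_threadpool_params_default(n_threads_batch); // assumed helper

    ggml_compute_threadpool_t threadpool       = ggml_create_threadpool(&tpp);       // assumed helper
    ggml_compute_threadpool_t threadpool_batch = ggml_create_threadpool(&tpp_batch); // assumed helper

    // Single attach point now installs both pools. Passing NULL as the batch pool
    // makes the context reuse the non-batch pool for batch decoding as well.
    llama_attach_threadpool(ctx, threadpool, threadpool_batch);

    // ... llama_decode(): n_tokens == 1 uses threadpool, larger batches use threadpool_batch ...

    llama_detach_threadpool(ctx);
    ggml_release_threadpool(threadpool_batch); // assumed helper
    ggml_release_threadpool(threadpool);       // assumed helper
}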