diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 2afdb3abd..81b84eb44 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -1210,8 +1210,7 @@ struct sql_printer : public printer { } }; -static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) { - llama_set_n_threads(ctx, n_threads, n_threads); +static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch) { const llama_model * model = llama_get_model(ctx); const int32_t n_vocab = llama_n_vocab(model); @@ -1233,9 +1232,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat llama_synchronize(ctx); } -static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) { - llama_set_n_threads(ctx, n_threads, n_threads); - +static void test_gen(llama_context * ctx, int n_gen, int n_past) { const llama_model * model = llama_get_model(ctx); const int32_t n_vocab = llama_n_vocab(model); @@ -1332,13 +1329,31 @@ int main(int argc, char ** argv) { llama_kv_cache_clear(ctx); + struct ggml_threadpool_params tpp; + tpp.n_threads = t.n_threads; + + // TODO: expose these via cli opts + tpp.mask_specified = false; + tpp.strict_cpu = false; + tpp.prio = 1; + tpp.poll = false; + + struct ggml_compute_threadpool * threadpool = ggml_create_threadpool(&tpp); + if (!threadpool) { + LOG_TEE("%s: threadpool create failed : n_threads %d\n", __func__, tpp.n_threads); + exit(1); + } + + llama_set_n_threads(ctx, t.n_threads, t.n_threads); + llama_attach_threadpool(ctx, threadpool); + // warmup run if (t.n_prompt > 0) { - //test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads); - test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads); + //test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch); + test_prompt(ctx, t.n_prompt, 0, t.n_batch); } if (t.n_gen > 0) { - test_gen(ctx, 1, 0, t.n_threads); + test_gen(ctx, 1, 0); } for (int i = 0; i < params.reps; i++) { @@ -1347,10 +1362,10 @@ int main(int argc, char ** argv) { uint64_t t_start = get_time_ns(); if (t.n_prompt > 0) { - test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads); + test_prompt(ctx, t.n_prompt, 0, t.n_batch); } if (t.n_gen > 0) { - test_gen(ctx, t.n_gen, t.n_prompt, t.n_threads); + test_gen(ctx, t.n_gen, t.n_prompt); } uint64_t t_ns = get_time_ns() - t_start; @@ -1362,6 +1377,8 @@ int main(int argc, char ** argv) { llama_print_timings(ctx); llama_free(ctx); + + ggml_release_threadpool(threadpool); } llama_free_model(lmodel);