threadpool: add persistent threadpool for llama-bench
parent e77167446a
commit c1e7f488c3

1 changed file with 27 additions and 10 deletions
examples/llama-bench/llama-bench.cpp

@@ -1210,8 +1210,7 @@ struct sql_printer : public printer {
     }
 };
 
-static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
-    llama_set_n_threads(ctx, n_threads, n_threads);
+static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch) {
 
     const llama_model * model = llama_get_model(ctx);
     const int32_t n_vocab = llama_n_vocab(model);
@@ -1233,9 +1232,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
     llama_synchronize(ctx);
 }
 
-static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
-    llama_set_n_threads(ctx, n_threads, n_threads);
-
+static void test_gen(llama_context * ctx, int n_gen, int n_past) {
     const llama_model * model = llama_get_model(ctx);
     const int32_t n_vocab = llama_n_vocab(model);
 
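These two hunks drop the per-call n_threads parameter from the benchmark helpers; the thread count is now configured once on the context before the helpers run. A minimal sketch of the resulting call pattern, using the same names as the hunks above:

    // threads are set once per test configuration ...
    llama_set_n_threads(ctx, t.n_threads, t.n_threads);

    // ... so the helpers no longer take n_threads
    test_prompt(ctx, t.n_prompt, 0, t.n_batch); // prompt processing
    test_gen(ctx, t.n_gen, t.n_prompt);         // token generation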
@@ -1332,13 +1329,31 @@ int main(int argc, char ** argv) {
 
         llama_kv_cache_clear(ctx);
 
+        struct ggml_threadpool_params tpp;
+        tpp.n_threads = t.n_threads;
+
+        // TODO: expose these via cli opts
+        tpp.mask_specified = false;
+        tpp.strict_cpu = false;
+        tpp.prio = 1;
+        tpp.poll = false;
+
+        struct ggml_compute_threadpool * threadpool = ggml_create_threadpool(&tpp);
+        if (!threadpool) {
+            LOG_TEE("%s: threadpool create failed : n_threads %d\n", __func__, tpp.n_threads);
+            exit(1);
+        }
+
+        llama_set_n_threads(ctx, t.n_threads, t.n_threads);
+        llama_attach_threadpool(ctx, threadpool);
+
         // warmup run
         if (t.n_prompt > 0) {
-            //test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
-            test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
+            //test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch);
+            test_prompt(ctx, t.n_prompt, 0, t.n_batch);
         }
         if (t.n_gen > 0) {
-            test_gen(ctx, 1, 0, t.n_threads);
+            test_gen(ctx, 1, 0);
         }
 
         for (int i = 0; i < params.reps; i++) {
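This is the core of the change: each test configuration builds one persistent threadpool sized to t.n_threads and attaches it to the context, so every decode in the run reuses the same worker threads. The setup could equally be factored into a helper; the sketch below is illustrative only (make_benchmark_threadpool is not part of the patch, and the ggml_threadpool_params fields follow this branch of ggml rather than any later upstream API):

    // Hypothetical helper wrapping the setup added above. Returns NULL on
    // failure; the caller owns the pool and must ggml_release_threadpool() it.
    static struct ggml_compute_threadpool * make_benchmark_threadpool(llama_context * ctx, int n_threads) {
        struct ggml_threadpool_params tpp;
        tpp.n_threads      = n_threads;
        tpp.mask_specified = false; // no explicit CPU affinity mask
        tpp.strict_cpu     = false; // allow threads to migrate between cores
        tpp.prio           = 1;     // thread priority hint
        tpp.poll           = false; // sleep between graphs instead of busy-polling

        struct ggml_compute_threadpool * threadpool = ggml_create_threadpool(&tpp);
        if (threadpool) {
            llama_set_n_threads(ctx, n_threads, n_threads);
            llama_attach_threadpool(ctx, threadpool);
        }
        return threadpool;
    }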
@@ -1347,10 +1362,10 @@ int main(int argc, char ** argv) {
             uint64_t t_start = get_time_ns();
 
             if (t.n_prompt > 0) {
-                test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
+                test_prompt(ctx, t.n_prompt, 0, t.n_batch);
             }
             if (t.n_gen > 0) {
-                test_gen(ctx, t.n_gen, t.n_prompt, t.n_threads);
+                test_gen(ctx, t.n_gen, t.n_prompt);
             }
 
             uint64_t t_ns = get_time_ns() - t_start;
@@ -1362,6 +1377,8 @@ int main(int argc, char ** argv) {
         llama_print_timings(ctx);
 
         llama_free(ctx);
+
+        ggml_release_threadpool(threadpool);
     }
 
     llama_free_model(lmodel);
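Taken together, the per-configuration lifecycle now looks like the sketch below (a condensed view of the hunks above, with the timed loop body reduced to a comment; tpp, ctx, t, and params are as in the patch). Note the teardown order: the context is freed before the threadpool it was attached to is released.

    struct ggml_compute_threadpool * threadpool = ggml_create_threadpool(&tpp);
    if (!threadpool) {
        exit(1); // the patch logs via LOG_TEE before exiting
    }

    llama_set_n_threads(ctx, t.n_threads, t.n_threads);
    llama_attach_threadpool(ctx, threadpool);

    for (int i = 0; i < params.reps; i++) {
        // timed test_prompt() / test_gen() calls, reusing the same workers
    }

    llama_print_timings(ctx);
    llama_free(ctx);                     // free the attached context first...
    ggml_release_threadpool(threadpool); // ...then release the persistent pool

This is what makes the pool "persistent": the worker threads are created once per test configuration and reused across the warmup and all params.reps repetitions, rather than being spawned anew for every graph compute.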