From 37264707c2ba7a780f43fa66b157bc9231e0dd22 Mon Sep 17 00:00:00 2001
From: ml6
Date: Mon, 3 Apr 2023 11:17:07 -0700
Subject: [PATCH] Add "-e"/"--eval-threads" command-line parameter to set a
 different number of threads for single-token eval than for prompt eval.

---
 examples/common.cpp                | 14 ++++++++++++++
 examples/common.h                  |  1 +
 examples/embedding/embedding.cpp   |  2 +-
 examples/main/main.cpp             |  6 +++---
 examples/perplexity/perplexity.cpp |  2 +-
 llama.cpp                          | 11 +++++++----
 llama.h                            |  3 ++-
 7 files changed, 29 insertions(+), 10 deletions(-)

diff --git a/examples/common.cpp b/examples/common.cpp
index 5400f6b01..6b2d62168 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -37,6 +37,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
         params.n_threads = std::max(1, (int32_t) std::thread::hardware_concurrency());
     }
 
+    bool ethreads_set = false;
     bool invalid_param = false;
     std::string arg;
     gpt_params default_params;
@@ -56,6 +57,13 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_threads = std::stoi(argv[i]);
+        } else if (arg == "-e" || arg == "--eval-threads") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.n_ethreads = std::stoi(argv[i]);
+            ethreads_set = true;
         } else if (arg == "-p" || arg == "--prompt") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -191,6 +199,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             exit(1);
         }
     }
+
+    // if --eval-threads was not given, default n_ethreads to the same value as n_threads
+    if (!ethreads_set) {
+        params.n_ethreads = params.n_threads;
+    }
+
     if (invalid_param) {
         fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
         gpt_print_usage(argc, argv, default_params);
diff --git a/examples/common.h b/examples/common.h
index 1505aa927..8022c6268 100644
--- a/examples/common.h
+++ b/examples/common.h
@@ -16,6 +16,7 @@
 struct gpt_params {
     int32_t seed          = -1;  // RNG seed
     int32_t n_threads     = std::min(4, (int32_t) std::thread::hardware_concurrency());
+    int32_t n_ethreads    = std::min(4, (int32_t) std::thread::hardware_concurrency()); // threads for single-token eval
     int32_t n_predict     = 128; // new tokens to predict
     int32_t repeat_last_n = 64;  // last n tokens to penalize
     int32_t n_parts       = -1;  // amount of model parts (-1 = determine from model dimensions)
diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index d397f35fd..950b4478c 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -79,7 +79,7 @@ int main(int argc, char ** argv) {
 
     if (params.embedding){
         if (embd_inp.size() > 0) {
-            if (llama_eval(ctx, embd_inp.data(), embd_inp.size(), n_past, params.n_threads)) {
+            if (llama_eval(ctx, embd_inp.data(), embd_inp.size(), n_past, params.n_threads, params.n_ethreads)) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 return 1;
             }
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 453450a41..b8cd33b75 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -119,12 +119,12 @@ int main(int argc, char ** argv) {
     if (params.mem_test) {
         {
             const std::vector<llama_token> tmp(params.n_batch, 0);
-            llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads);
+            llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads, params.n_ethreads);
         }
 
         {
             const std::vector<llama_token> tmp = { 0, };
-            llama_eval(ctx, tmp.data(), tmp.size(), params.n_predict - 1, params.n_threads);
+            llama_eval(ctx, tmp.data(), tmp.size(), params.n_predict - 1, params.n_threads, params.n_ethreads);
         }
 
         llama_print_timings(ctx);
@@ -264,7 +264,7 @@ int main(int argc, char ** argv) {
             //printf("\n---\n");
         }
 
-        if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) {
+        if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads, params.n_ethreads)) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
             return 1;
         }
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index 07ed0a829..88a14fbcd 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -38,7 +38,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
         // it is better to always be power of 2 for better performance
         std::vector<llama_token> embd(tokens.begin() + start, tokens.begin() + end);
         auto start_t = std::chrono::high_resolution_clock::now();
-        if (llama_eval(ctx, embd.data(), embd.size(), 0, params.n_threads)) {
+        if (llama_eval(ctx, embd.data(), embd.size(), 0, params.n_threads, params.n_ethreads)) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
             return;
         }
diff --git a/llama.cpp b/llama.cpp
index 854bb8993..79e4bc7fd 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -745,13 +745,15 @@ static bool llama_model_load(
 //   - tokens:    new batch of tokens to process
 //   - n_past:    the context size so far
 //   - n_threads: number of threads to use
+//   - n_ethreads: number of threads to use for single-token eval
 //
 static bool llama_eval_internal(
         llama_context & lctx,
     const llama_token * tokens,
             const int   n_tokens,
             const int   n_past,
-            const int   n_threads) {
+            const int   n_threads,
+            const int   n_ethreads) {
     const int64_t t_start_us = ggml_time_us();
 
     const int N = n_tokens;
@@ -784,7 +786,7 @@ static bool llama_eval_internal(
     // for big prompts, if BLAS is enabled, it is better to use only one thread
     // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
     ggml_cgraph gf = {};
-    gf.n_threads = N >= 32 && ggml_cpu_has_blas() ? 1 : n_threads;
+    gf.n_threads = N >= 32 && ggml_cpu_has_blas() ? 1 : (N == 1 ? n_ethreads : n_threads);
 
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
     memcpy(embd->data, tokens, N*ggml_element_size(embd));
@@ -1714,8 +1716,9 @@ int llama_eval(
         const llama_token * tokens,
                       int   n_tokens,
                       int   n_past,
-                      int   n_threads) {
-    if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads)) {
+                      int   n_threads,
+                      int   n_ethreads) {
+    if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads, n_ethreads)) {
         fprintf(stderr, "%s: failed to eval\n", __func__);
         return 1;
     }
diff --git a/llama.h b/llama.h
index 04e2bf71c..14443adca 100644
--- a/llama.h
+++ b/llama.h
@@ -109,7 +109,8 @@ extern "C" {
             const llama_token * tokens,
                           int   n_tokens,
                           int   n_past,
-                          int   n_threads);
+                          int   n_threads,
+                          int   n_ethreads);
 
     // Convert the provided text into tokens.
     // The tokens pointer must be large enough to hold the resulting tokens.
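
Usage note (editor's illustration, not part of the patch): with the patch applied, llama_eval takes two thread counts. The first is used when a batch of prompt tokens is evaluated, the second when a single token is evaluated (N == 1), as seen in the gf.n_threads change above. The sketch below shows how a caller might drive the updated API; the helper name eval_prompt_then_token and the placeholder token id are hypothetical.

#include "llama.h"

#include <cstdio>
#include <vector>

// Hypothetical helper: evaluate a prompt batch, then a single token,
// passing separate thread counts for the two phases.
static bool eval_prompt_then_token(llama_context * ctx,
                                   const std::vector<llama_token> & prompt,
                                   int n_threads,    // threads for prompt (batch) eval
                                   int n_ethreads) { // threads for single-token eval
    int n_past = 0;

    // prompt eval: n_tokens > 1, so the compute graph uses n_threads
    // (or a single thread when BLAS handles a prompt of 32+ tokens)
    if (llama_eval(ctx, prompt.data(), (int) prompt.size(), n_past, n_threads, n_ethreads)) {
        fprintf(stderr, "prompt eval failed\n");
        return false;
    }
    n_past += (int) prompt.size();

    // single-token eval: n_tokens == 1, so the compute graph uses n_ethreads
    const llama_token tok = 0; // placeholder token id for illustration
    if (llama_eval(ctx, &tok, 1, n_past, n_threads, n_ethreads)) {
        fprintf(stderr, "single-token eval failed\n");
        return false;
    }
    return true;
}

On the command line the same split is exposed through the new flag, e.g. something like ./main -m <model> -t 8 -e 4 -p "...", where -t sets the prompt-eval thread count and the new -e sets the per-token eval thread count.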