From be26777a6aee651131c7cca2e0aa3e2b9a14a772 Mon Sep 17 00:00:00 2001 From: netrunnereve <139727413+netrunnereve@users.noreply.github.com> Date: Tue, 8 Aug 2023 22:19:59 -0400 Subject: [PATCH] add pp_threads support to other files --- examples/embedding/embedding.cpp | 6 +++--- examples/save-load-state/save-load-state.cpp | 7 ++++--- examples/server/server.cpp | 14 +++++++++++++- 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 6f732ce46..58fa2753e 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -50,8 +50,8 @@ int main(int argc, char ** argv) { // print system information { fprintf(stderr, "\n"); - fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", - params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info()); + fprintf(stderr, "system_info: n_threads = %d / %d | pp_threads = %d / %d | %s\n", + params.n_threads, std::thread::hardware_concurrency(), params.pp_threads, std::thread::hardware_concurrency(), llama_print_system_info()); } int n_past = 0; @@ -74,7 +74,7 @@ int main(int argc, char ** argv) { if (params.embedding){ if (embd_inp.size() > 0) { - if (llama_eval(ctx, embd_inp.data(), embd_inp.size(), n_past, params.n_threads, params.n_threads)) { + if (llama_eval(ctx, embd_inp.data(), embd_inp.size(), n_past, params.n_threads, params.pp_threads)) { fprintf(stderr, "%s : failed to eval\n", __func__); return 1; } diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index c8586718e..2a466c2e2 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -10,6 +10,7 @@ int main(int argc, char ** argv) { gpt_params params; params.seed = 42; params.n_threads = 4; + params.pp_threads = 4; params.repeat_last_n = 64; params.prompt = "The quick brown fox"; @@ -56,7 +57,7 @@ int main(int argc, char ** argv) { } // evaluate prompt - llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params.n_threads, params.n_threads); + llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params.n_threads, params.pp_threads); last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens); n_past += n_prompt_tokens; @@ -93,7 +94,7 @@ int main(int argc, char ** argv) { last_n_tokens_data.push_back(next_token); printf("%s", next_token_str); - if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads, params.n_threads)) { + if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads, params.pp_threads)) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); llama_free(ctx); llama_free_model(model); @@ -153,7 +154,7 @@ int main(int argc, char ** argv) { last_n_tokens_data.push_back(next_token); printf("%s", next_token_str); - if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads, params.n_threads)) { + if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads, params.pp_threads)) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); llama_free(ctx2); llama_free_model(model); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index fe91b2343..d86bab5b3 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -382,7 +382,7 @@ struct llama_server_context { n_eval = params.n_batch; } - if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads, params.n_threads)) + if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads, params.pp_threads)) { LOG_ERROR("failed to eval", { {"n_eval", n_eval}, @@ -648,6 +648,8 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, fprintf(stdout, " -h, --help show this help message and exit\n"); fprintf(stdout, " -v, --verbose verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled"); fprintf(stdout, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); + fprintf(stdout, " -ppt N, --pp-threads N\n"); + fprintf(stdout, " number of threads to use during prompt processing (default: %d)\n", params.pp_threads); fprintf(stdout, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx); fprintf(stdout, " -gqa N, --gqa N grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa); fprintf(stdout, " -eps N, --rms-norm-eps N rms norm eps (TEMP!!! use 1e-5 for LLaMAv2) (default: %.1e)\n", params.rms_norm_eps); @@ -818,6 +820,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } params.n_threads = std::stoi(argv[i]); } + else if (arg == "-ppt" || arg == "--pp-threads") + { + if (++i >= argc) + { + invalid_param = true; + break; + } + params.pp_threads = std::stoi(argv[i]); + } else if (arg == "-b" || arg == "--batch-size") { if (++i >= argc) @@ -1178,6 +1189,7 @@ int main(int argc, char **argv) {"commit", BUILD_COMMIT}}); LOG_INFO("system info", { {"n_threads", params.n_threads}, + {"pp_threads", params.pp_threads}, {"total_threads", std::thread::hardware_concurrency()}, {"system_info", llama_print_system_info()}, });