add pp_threads support to other files

netrunnereve 2023-08-08 22:19:59 -04:00
parent d854348992
commit be26777a6a
3 changed files with 20 additions and 7 deletions

examples/embedding/embedding.cpp

@@ -50,8 +50,8 @@ int main(int argc, char ** argv) {
     // print system information
     {
         fprintf(stderr, "\n");
-        fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
-                params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
+        fprintf(stderr, "system_info: n_threads = %d / %d | pp_threads = %d / %d | %s\n",
+                params.n_threads, std::thread::hardware_concurrency(), params.pp_threads, std::thread::hardware_concurrency(), llama_print_system_info());
     }

     int n_past = 0;
@@ -74,7 +74,7 @@ int main(int argc, char ** argv) {
     if (params.embedding){
         if (embd_inp.size() > 0) {
-            if (llama_eval(ctx, embd_inp.data(), embd_inp.size(), n_past, params.n_threads, params.n_threads)) {
+            if (llama_eval(ctx, embd_inp.data(), embd_inp.size(), n_past, params.n_threads, params.pp_threads)) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 return 1;
             }
         }
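
Note: the parent commit already widened llama_eval() to take two thread counts; this change stops passing n_threads twice and forwards the dedicated field instead. A minimal sketch of the resulting prompt-ingestion call (only the six-argument llama_eval is confirmed by this diff; ctx, params, and embd_inp are taken from the surrounding example):

    // Sketch: ingest the whole tokenized prompt in one batch eval.
    // pp_threads is presumably the count applied to multi-token batches
    // like this one (inferred from the option's help text below, not
    // from llama_eval's internals, which this diff does not show).
    int n_past = 0;
    if (llama_eval(ctx, embd_inp.data(), (int) embd_inp.size(), n_past,
                   params.n_threads, params.pp_threads)) {
        fprintf(stderr, "%s : failed to eval\n", __func__);
        return 1;
    }
    n_past += (int) embd_inp.size();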

examples/save-load-state/save-load-state.cpp

@@ -10,6 +10,7 @@ int main(int argc, char ** argv) {
     gpt_params params;
     params.seed = 42;
     params.n_threads = 4;
+    params.pp_threads = 4;
     params.repeat_last_n = 64;
     params.prompt = "The quick brown fox";
@@ -56,7 +57,7 @@ int main(int argc, char ** argv) {
     }

     // evaluate prompt
-    llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params.n_threads, params.n_threads);
+    llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params.n_threads, params.pp_threads);

     last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens);
     n_past += n_prompt_tokens;
@@ -93,7 +94,7 @@ int main(int argc, char ** argv) {
         last_n_tokens_data.push_back(next_token);

         printf("%s", next_token_str);
-        if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads, params.n_threads)) {
+        if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads, params.pp_threads)) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
             llama_free(ctx);
             llama_free_model(model);
@@ -153,7 +154,7 @@ int main(int argc, char ** argv) {
         last_n_tokens_data.push_back(next_token);

         printf("%s", next_token_str);
-        if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads, params.n_threads)) {
+        if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads, params.pp_threads)) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
             llama_free(ctx2);
             llama_free_model(model);
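
The single-token calls above forward params.pp_threads as well; presumably llama_eval selects between the two counts based on how many tokens are evaluated at once, so a one-token step is governed by n_threads alone (an assumption, since the library side is not part of this diff). Sketched loop shape after the patch:

    // Per-token generation loop with both counts forwarded every call.
    // Sampling is elided; next_token stands in for a sampled token.
    for (int i = 0; i < 32; ++i) {
        llama_token next_token = 0; // placeholder for the sampled token
        if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads, params.pp_threads)) {
            fprintf(stderr, "failed to evaluate\n");
            break;
        }
        n_past += 1;
    }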

examples/server/server.cpp

@@ -382,7 +382,7 @@ struct llama_server_context
             {
                 n_eval = params.n_batch;
             }
-            if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads, params.n_threads))
+            if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads, params.pp_threads))
             {
                 LOG_ERROR("failed to eval", {
                     {"n_eval", n_eval},
@@ -648,6 +648,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     fprintf(stdout, "  -h, --help                show this help message and exit\n");
     fprintf(stdout, "  -v, --verbose             verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled");
     fprintf(stdout, "  -t N, --threads N         number of threads to use during computation (default: %d)\n", params.n_threads);
+    fprintf(stdout, "  -ppt N, --pp-threads N\n");
+    fprintf(stdout, "                            number of threads to use during prompt processing (default: %d)\n", params.pp_threads);
     fprintf(stdout, "  -c N, --ctx-size N        size of the prompt context (default: %d)\n", params.n_ctx);
     fprintf(stdout, "  -gqa N, --gqa N           grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa);
     fprintf(stdout, "  -eps N, --rms-norm-eps N  rms norm eps (TEMP!!! use 1e-5 for LLaMAv2) (default: %.1e)\n", params.rms_norm_eps);
@@ -818,6 +820,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
             }
             params.n_threads = std::stoi(argv[i]);
         }
+        else if (arg == "-ppt" || arg == "--pp-threads")
+        {
+            if (++i >= argc)
+            {
+                invalid_param = true;
+                break;
+            }
+            params.pp_threads = std::stoi(argv[i]);
+        }
         else if (arg == "-b" || arg == "--batch-size")
         {
             if (++i >= argc)
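
The new branch mirrors the existing -t/--threads handling, including the unguarded std::stoi (which throws on non-numeric input, consistent with the neighboring options). The field it assigns presumably sits in gpt_params next to n_threads, added by the parent commit; a sketch of that assumed declaration (defaults guessed from the save-load-state example above):

    // Assumed shape of the struct this diff writes into; only the
    // pp_threads field itself is confirmed by the diff.
    struct gpt_params {
        int32_t n_threads  = 4; // threads used while generating tokens
        int32_t pp_threads = 4; // threads used while processing the prompt
        // ... remaining fields unchanged ...
    };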
@@ -1178,6 +1189,7 @@ int main(int argc, char **argv)
                             {"commit", BUILD_COMMIT}});
     LOG_INFO("system info", {
                                 {"n_threads", params.n_threads},
+                                {"pp_threads", params.pp_threads},
                                 {"total_threads", std::thread::hardware_concurrency()},
                                 {"system_info", llama_print_system_info()},
                             });
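
Taken together, the call pattern this patch establishes looks like the sketch below. Only the six-argument llama_eval and the pp_threads parameter are confirmed by the diff; the rest (llama_backend_init, llama_load_model_from_file, llama_tokenize, and so on) follows the llama.cpp C API of the time and is assumed rather than shown:

    #include "llama.h"
    #include <cstdio>
    #include <vector>

    int main(int argc, char ** argv) {
        if (argc < 2) {
            fprintf(stderr, "usage: %s model.bin\n", argv[0]);
            return 1;
        }

        const int n_threads  = 4; // per-token (generation) evals
        const int pp_threads = 8; // bulk (prompt-processing) evals

        llama_backend_init(false /* numa */);
        llama_context_params cparams = llama_context_default_params();
        llama_model   * model = llama_load_model_from_file(argv[1], cparams);
        llama_context * ctx   = llama_new_context_with_model(model, cparams);

        // Ingest the prompt in one batch; this is where a separate,
        // larger pp_threads is meant to pay off.
        std::vector<llama_token> tokens(64);
        const int n_tok = llama_tokenize(ctx, "The quick brown fox", tokens.data(), (int) tokens.size(), true);
        int n_past = 0;
        if (llama_eval(ctx, tokens.data(), n_tok, n_past, n_threads, pp_threads)) {
            fprintf(stderr, "failed to eval prompt\n");
            return 1;
        }
        n_past += n_tok;

        // Generation then proceeds one token at a time (sampling elided).
        llama_token next_token = llama_token_bos();
        llama_eval(ctx, &next_token, 1, n_past, n_threads, pp_threads);

        llama_free(ctx);
        llama_free_model(model);
        llama_backend_free();
        return 0;
    }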