add pp_threads support to other files
This commit is contained in:
parent
d854348992
commit
be26777a6a
3 changed files with 20 additions and 7 deletions
|
@ -50,8 +50,8 @@ int main(int argc, char ** argv) {
|
|||
// print system information
|
||||
{
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
|
||||
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
|
||||
fprintf(stderr, "system_info: n_threads = %d / %d | pp_threads = %d / %d | %s\n",
|
||||
params.n_threads, std::thread::hardware_concurrency(), params.pp_threads, std::thread::hardware_concurrency(), llama_print_system_info());
|
||||
}
|
||||
|
||||
int n_past = 0;
|
||||
|
@ -74,7 +74,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
if (params.embedding){
|
||||
if (embd_inp.size() > 0) {
|
||||
if (llama_eval(ctx, embd_inp.data(), embd_inp.size(), n_past, params.n_threads, params.n_threads)) {
|
||||
if (llama_eval(ctx, embd_inp.data(), embd_inp.size(), n_past, params.n_threads, params.pp_threads)) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ int main(int argc, char ** argv) {
|
|||
gpt_params params;
|
||||
params.seed = 42;
|
||||
params.n_threads = 4;
|
||||
params.pp_threads = 4;
|
||||
params.repeat_last_n = 64;
|
||||
params.prompt = "The quick brown fox";
|
||||
|
||||
|
@ -56,7 +57,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// evaluate prompt
|
||||
llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params.n_threads, params.n_threads);
|
||||
llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params.n_threads, params.pp_threads);
|
||||
|
||||
last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens);
|
||||
n_past += n_prompt_tokens;
|
||||
|
@ -93,7 +94,7 @@ int main(int argc, char ** argv) {
|
|||
last_n_tokens_data.push_back(next_token);
|
||||
|
||||
printf("%s", next_token_str);
|
||||
if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads, params.n_threads)) {
|
||||
if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads, params.pp_threads)) {
|
||||
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
||||
llama_free(ctx);
|
||||
llama_free_model(model);
|
||||
|
@ -153,7 +154,7 @@ int main(int argc, char ** argv) {
|
|||
last_n_tokens_data.push_back(next_token);
|
||||
|
||||
printf("%s", next_token_str);
|
||||
if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads, params.n_threads)) {
|
||||
if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads, params.pp_threads)) {
|
||||
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
||||
llama_free(ctx2);
|
||||
llama_free_model(model);
|
||||
|
|
|
@ -382,7 +382,7 @@ struct llama_server_context
|
|||
{
|
||||
n_eval = params.n_batch;
|
||||
}
|
||||
if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads, params.n_threads))
|
||||
if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads, params.pp_threads))
|
||||
{
|
||||
LOG_ERROR("failed to eval", {
|
||||
{"n_eval", n_eval},
|
||||
|
@ -648,6 +648,8 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms,
|
|||
fprintf(stdout, " -h, --help show this help message and exit\n");
|
||||
fprintf(stdout, " -v, --verbose verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled");
|
||||
fprintf(stdout, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
|
||||
fprintf(stdout, " -ppt N, --pp-threads N\n");
|
||||
fprintf(stdout, " number of threads to use during prompt processing (default: %d)\n", params.pp_threads);
|
||||
fprintf(stdout, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
|
||||
fprintf(stdout, " -gqa N, --gqa N grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa);
|
||||
fprintf(stdout, " -eps N, --rms-norm-eps N rms norm eps (TEMP!!! use 1e-5 for LLaMAv2) (default: %.1e)\n", params.rms_norm_eps);
|
||||
|
@ -818,6 +820,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
|||
}
|
||||
params.n_threads = std::stoi(argv[i]);
|
||||
}
|
||||
else if (arg == "-ppt" || arg == "--pp-threads")
|
||||
{
|
||||
if (++i >= argc)
|
||||
{
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.pp_threads = std::stoi(argv[i]);
|
||||
}
|
||||
else if (arg == "-b" || arg == "--batch-size")
|
||||
{
|
||||
if (++i >= argc)
|
||||
|
@ -1178,6 +1189,7 @@ int main(int argc, char **argv)
|
|||
{"commit", BUILD_COMMIT}});
|
||||
LOG_INFO("system info", {
|
||||
{"n_threads", params.n_threads},
|
||||
{"pp_threads", params.pp_threads},
|
||||
{"total_threads", std::thread::hardware_concurrency()},
|
||||
{"system_info", llama_print_system_info()},
|
||||
});
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue