add pp_threads support to other files
This commit is contained in:
parent
d854348992
commit
be26777a6a
3 changed files with 20 additions and 7 deletions
|
@ -50,8 +50,8 @@ int main(int argc, char ** argv) {
|
||||||
// print system information
|
// print system information
|
||||||
{
|
{
|
||||||
fprintf(stderr, "\n");
|
fprintf(stderr, "\n");
|
||||||
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
|
fprintf(stderr, "system_info: n_threads = %d / %d | pp_threads = %d / %d | %s\n",
|
||||||
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
|
params.n_threads, std::thread::hardware_concurrency(), params.pp_threads, std::thread::hardware_concurrency(), llama_print_system_info());
|
||||||
}
|
}
|
||||||
|
|
||||||
int n_past = 0;
|
int n_past = 0;
|
||||||
|
@ -74,7 +74,7 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
if (params.embedding){
|
if (params.embedding){
|
||||||
if (embd_inp.size() > 0) {
|
if (embd_inp.size() > 0) {
|
||||||
if (llama_eval(ctx, embd_inp.data(), embd_inp.size(), n_past, params.n_threads, params.n_threads)) {
|
if (llama_eval(ctx, embd_inp.data(), embd_inp.size(), n_past, params.n_threads, params.pp_threads)) {
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,6 +10,7 @@ int main(int argc, char ** argv) {
|
||||||
gpt_params params;
|
gpt_params params;
|
||||||
params.seed = 42;
|
params.seed = 42;
|
||||||
params.n_threads = 4;
|
params.n_threads = 4;
|
||||||
|
params.pp_threads = 4;
|
||||||
params.repeat_last_n = 64;
|
params.repeat_last_n = 64;
|
||||||
params.prompt = "The quick brown fox";
|
params.prompt = "The quick brown fox";
|
||||||
|
|
||||||
|
@ -56,7 +57,7 @@ int main(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// evaluate prompt
|
// evaluate prompt
|
||||||
llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params.n_threads, params.n_threads);
|
llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params.n_threads, params.pp_threads);
|
||||||
|
|
||||||
last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens);
|
last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens);
|
||||||
n_past += n_prompt_tokens;
|
n_past += n_prompt_tokens;
|
||||||
|
@ -93,7 +94,7 @@ int main(int argc, char ** argv) {
|
||||||
last_n_tokens_data.push_back(next_token);
|
last_n_tokens_data.push_back(next_token);
|
||||||
|
|
||||||
printf("%s", next_token_str);
|
printf("%s", next_token_str);
|
||||||
if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads, params.n_threads)) {
|
if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads, params.pp_threads)) {
|
||||||
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
||||||
llama_free(ctx);
|
llama_free(ctx);
|
||||||
llama_free_model(model);
|
llama_free_model(model);
|
||||||
|
@ -153,7 +154,7 @@ int main(int argc, char ** argv) {
|
||||||
last_n_tokens_data.push_back(next_token);
|
last_n_tokens_data.push_back(next_token);
|
||||||
|
|
||||||
printf("%s", next_token_str);
|
printf("%s", next_token_str);
|
||||||
if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads, params.n_threads)) {
|
if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads, params.pp_threads)) {
|
||||||
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
||||||
llama_free(ctx2);
|
llama_free(ctx2);
|
||||||
llama_free_model(model);
|
llama_free_model(model);
|
||||||
|
|
|
@ -382,7 +382,7 @@ struct llama_server_context
|
||||||
{
|
{
|
||||||
n_eval = params.n_batch;
|
n_eval = params.n_batch;
|
||||||
}
|
}
|
||||||
if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads, params.n_threads))
|
if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads, params.pp_threads))
|
||||||
{
|
{
|
||||||
LOG_ERROR("failed to eval", {
|
LOG_ERROR("failed to eval", {
|
||||||
{"n_eval", n_eval},
|
{"n_eval", n_eval},
|
||||||
|
@ -648,6 +648,8 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms,
|
||||||
fprintf(stdout, " -h, --help show this help message and exit\n");
|
fprintf(stdout, " -h, --help show this help message and exit\n");
|
||||||
fprintf(stdout, " -v, --verbose verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled");
|
fprintf(stdout, " -v, --verbose verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled");
|
||||||
fprintf(stdout, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
|
fprintf(stdout, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
|
||||||
|
fprintf(stdout, " -ppt N, --pp-threads N\n");
|
||||||
|
fprintf(stdout, " number of threads to use during prompt processing (default: %d)\n", params.pp_threads);
|
||||||
fprintf(stdout, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
|
fprintf(stdout, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
|
||||||
fprintf(stdout, " -gqa N, --gqa N grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa);
|
fprintf(stdout, " -gqa N, --gqa N grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa);
|
||||||
fprintf(stdout, " -eps N, --rms-norm-eps N rms norm eps (TEMP!!! use 1e-5 for LLaMAv2) (default: %.1e)\n", params.rms_norm_eps);
|
fprintf(stdout, " -eps N, --rms-norm-eps N rms norm eps (TEMP!!! use 1e-5 for LLaMAv2) (default: %.1e)\n", params.rms_norm_eps);
|
||||||
|
@ -818,6 +820,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
||||||
}
|
}
|
||||||
params.n_threads = std::stoi(argv[i]);
|
params.n_threads = std::stoi(argv[i]);
|
||||||
}
|
}
|
||||||
|
else if (arg == "-ppt" || arg == "--pp-threads")
|
||||||
|
{
|
||||||
|
if (++i >= argc)
|
||||||
|
{
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
params.pp_threads = std::stoi(argv[i]);
|
||||||
|
}
|
||||||
else if (arg == "-b" || arg == "--batch-size")
|
else if (arg == "-b" || arg == "--batch-size")
|
||||||
{
|
{
|
||||||
if (++i >= argc)
|
if (++i >= argc)
|
||||||
|
@ -1178,6 +1189,7 @@ int main(int argc, char **argv)
|
||||||
{"commit", BUILD_COMMIT}});
|
{"commit", BUILD_COMMIT}});
|
||||||
LOG_INFO("system info", {
|
LOG_INFO("system info", {
|
||||||
{"n_threads", params.n_threads},
|
{"n_threads", params.n_threads},
|
||||||
|
{"pp_threads", params.pp_threads},
|
||||||
{"total_threads", std::thread::hardware_concurrency()},
|
{"total_threads", std::thread::hardware_concurrency()},
|
||||||
{"system_info", llama_print_system_info()},
|
{"system_info", llama_print_system_info()},
|
||||||
});
|
});
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue