This commit is contained in:
Eve 2023-08-18 21:50:59 +00:00 committed by GitHub
commit 95f2c5d475
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 66 additions and 32 deletions

View file

@ -119,6 +119,15 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
if (params.n_threads <= 0) { if (params.n_threads <= 0) {
params.n_threads = std::thread::hardware_concurrency(); params.n_threads = std::thread::hardware_concurrency();
} }
} else if (arg == "-ppt" || arg == "--pp-threads") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.pp_threads = std::stoi(argv[i]);
if (params.pp_threads <= 0) {
params.pp_threads = std::thread::hardware_concurrency();
}
} else if (arg == "-p" || arg == "--prompt") { } else if (arg == "-p" || arg == "--prompt") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
@ -545,6 +554,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stdout, " --color colorise output to distinguish prompt and user input from generations\n"); fprintf(stdout, " --color colorise output to distinguish prompt and user input from generations\n");
fprintf(stdout, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n"); fprintf(stdout, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n");
fprintf(stdout, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); fprintf(stdout, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
fprintf(stdout, " -ppt N, --pp-threads N\n");
fprintf(stdout, " number of threads to use during prompt processing (default: %d)\n", params.pp_threads);
fprintf(stdout, " -p PROMPT, --prompt PROMPT\n"); fprintf(stdout, " -p PROMPT, --prompt PROMPT\n");
fprintf(stdout, " prompt to start generation with (default: empty)\n"); fprintf(stdout, " prompt to start generation with (default: empty)\n");
fprintf(stdout, " -e process prompt escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n"); fprintf(stdout, " -e process prompt escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");

View file

@ -19,6 +19,7 @@ int32_t get_num_physical_cores();
struct gpt_params { struct gpt_params {
uint32_t seed = -1; // RNG seed uint32_t seed = -1; // RNG seed
int32_t n_threads = get_num_physical_cores(); int32_t n_threads = get_num_physical_cores();
int32_t pp_threads = get_num_physical_cores();
int32_t n_predict = -1; // new tokens to predict int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 512; // context size int32_t n_ctx = 512; // context size
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS) int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)

View file

@ -83,7 +83,7 @@ bool eval_float(void * model, float * input, int N){
if (n_eval > n_batch) { if (n_eval > n_batch) {
n_eval = n_batch; n_eval = n_batch;
} }
if (llama_eval_embd(ctx, (input+i*n_emb), n_eval, n_past, params.n_threads)) { if (llama_eval_embd(ctx, (input+i*n_emb), n_eval, n_past, params.n_threads, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__); fprintf(stderr, "%s : failed to eval\n", __func__);
return false; return false;
} }
@ -104,7 +104,7 @@ bool eval_tokens(void * model, std::vector<llama_token> tokens) {
if (n_eval > params.n_batch) { if (n_eval > params.n_batch) {
n_eval = params.n_batch; n_eval = params.n_batch;
} }
if (llama_eval(ctx, &tokens[i], n_eval, n_past, params.n_threads)) { if (llama_eval(ctx, &tokens[i], n_eval, n_past, params.n_threads, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__); fprintf(stderr, "%s : failed to eval\n", __func__);
return false; return false;
} }

View file

@ -50,8 +50,8 @@ int main(int argc, char ** argv) {
// print system information // print system information
{ {
fprintf(stderr, "\n"); fprintf(stderr, "\n");
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", fprintf(stderr, "system_info: n_threads = %d / %d | pp_threads = %d / %d | %s\n",
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info()); params.n_threads, std::thread::hardware_concurrency(), params.pp_threads, std::thread::hardware_concurrency(), llama_print_system_info());
} }
int n_past = 0; int n_past = 0;
@ -74,7 +74,7 @@ int main(int argc, char ** argv) {
if (params.embedding){ if (params.embedding){
if (embd_inp.size() > 0) { if (embd_inp.size() > 0) {
if (llama_eval(ctx, embd_inp.data(), embd_inp.size(), n_past, params.n_threads)) { if (llama_eval(ctx, embd_inp.data(), embd_inp.size(), n_past, params.n_threads, params.pp_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__); fprintf(stderr, "%s : failed to eval\n", __func__);
return 1; return 1;
} }

View file

@ -853,7 +853,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
int n_processed = 0; int n_processed = 0;
while (n_processed < n_prompt) { while (n_processed < n_prompt) {
int n_tokens = std::min(n_prompt - n_processed, n_batch); int n_tokens = std::min(n_prompt - n_processed, n_batch);
llama_eval(ctx, tokens.data(), n_tokens, n_past + n_processed, n_threads); llama_eval(ctx, tokens.data(), n_tokens, n_past + n_processed, n_threads, n_threads);
n_processed += n_tokens; n_processed += n_tokens;
} }
} }
@ -861,7 +861,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) { static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
llama_token token = llama_token_bos(); llama_token token = llama_token_bos();
for (int i = 0; i < n_gen; i++) { for (int i = 0; i < n_gen; i++) {
llama_eval(ctx, &token, 1, n_past + i, n_threads); llama_eval(ctx, &token, 1, n_past + i, n_threads, n_threads);
} }
} }

View file

@ -263,6 +263,7 @@ These options help improve the performance and memory usage of the LLaMA models.
### Number of Threads ### Number of Threads
- `-t N, --threads N`: Set the number of threads to use during computation. For optimal performance, it is recommended to set this value to the number of physical CPU cores your system has (as opposed to the logical number of cores). Using the correct number of threads can greatly improve performance. - `-t N, --threads N`: Set the number of threads to use during computation. For optimal performance, it is recommended to set this value to the number of physical CPU cores your system has (as opposed to the logical number of cores). Using the correct number of threads can greatly improve performance.
- `-ppt N, --pp-threads N`: Set the number of threads to use during prompt processing only.
### Mlock ### Mlock

View file

@ -133,8 +133,8 @@ int main(int argc, char ** argv) {
// print system information // print system information
{ {
fprintf(stderr, "\n"); fprintf(stderr, "\n");
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", fprintf(stderr, "system_info: n_threads = %d / %d | pp_threads = %d / %d | %s\n",
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info()); params.n_threads, std::thread::hardware_concurrency(), params.pp_threads, std::thread::hardware_concurrency(), llama_print_system_info());
} }
// determine the maximum memory usage needed to do inference for the given n_batch and n_ctx parameters // determine the maximum memory usage needed to do inference for the given n_batch and n_ctx parameters
@ -144,7 +144,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s: testing memory usage for n_batch = %d, n_ctx = %d\n", __func__, params.n_batch, params.n_ctx); fprintf(stderr, "%s: testing memory usage for n_batch = %d, n_ctx = %d\n", __func__, params.n_batch, params.n_ctx);
const std::vector<llama_token> tmp(params.n_batch, llama_token_bos()); const std::vector<llama_token> tmp(params.n_batch, llama_token_bos());
llama_eval(ctx, tmp.data(), tmp.size(), params.n_ctx, params.n_threads); llama_eval(ctx, tmp.data(), tmp.size(), params.n_ctx, params.n_threads, params.pp_threads);
} }
llama_print_timings(ctx); llama_print_timings(ctx);
@ -406,7 +406,7 @@ int main(int argc, char ** argv) {
// do one empty run to warm up the model // do one empty run to warm up the model
{ {
const std::vector<llama_token> tmp = { llama_token_bos(), }; const std::vector<llama_token> tmp = { llama_token_bos(), };
llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads); llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads, params.pp_threads);
llama_reset_timings(ctx); llama_reset_timings(ctx);
} }
@ -513,7 +513,7 @@ int main(int argc, char ** argv) {
for (int i = 0; i < input_size; i += params.n_batch) { for (int i = 0; i < input_size; i += params.n_batch) {
int n_eval = std::min(input_size - i, params.n_batch); int n_eval = std::min(input_size - i, params.n_batch);
if (llama_eval(ctx_guidance, input_buf + i, n_eval, n_past_guidance, params.n_threads)) { if (llama_eval(ctx_guidance, input_buf + i, n_eval, n_past_guidance, params.n_threads, params.pp_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__); fprintf(stderr, "%s : failed to eval\n", __func__);
return 1; return 1;
} }
@ -527,7 +527,7 @@ int main(int argc, char ** argv) {
if (n_eval > params.n_batch) { if (n_eval > params.n_batch) {
n_eval = params.n_batch; n_eval = params.n_batch;
} }
if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads)) { if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads, params.pp_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__); fprintf(stderr, "%s : failed to eval\n", __func__);
return 1; return 1;
} }

View file

@ -66,7 +66,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
tokens[batch_start] = llama_token_bos(); tokens[batch_start] = llama_token_bos();
} }
if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) { if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads, params.pp_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__); fprintf(stderr, "%s : failed to eval\n", __func__);
return; return;
} }
@ -233,7 +233,7 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
} }
// Evaluate the query // Evaluate the query
if (llama_eval(ctx, query_embd.data(), query_embd.size(), 0, params.n_threads)) { if (llama_eval(ctx, query_embd.data(), query_embd.size(), 0, params.n_threads, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__); fprintf(stderr, "%s : failed to eval\n", __func__);
return; return;
} }
@ -337,8 +337,8 @@ int main(int argc, char ** argv) {
// print system information // print system information
{ {
fprintf(stderr, "\n"); fprintf(stderr, "\n");
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", fprintf(stderr, "system_info: pp_threads = %d / %d | %s\n",
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info()); params.pp_threads, std::thread::hardware_concurrency(), llama_print_system_info());
} }
if (params.hellaswag) { if (params.hellaswag) {

View file

@ -10,6 +10,7 @@ int main(int argc, char ** argv) {
gpt_params params; gpt_params params;
params.seed = 42; params.seed = 42;
params.n_threads = 4; params.n_threads = 4;
params.pp_threads = 4;
params.repeat_last_n = 64; params.repeat_last_n = 64;
params.prompt = "The quick brown fox"; params.prompt = "The quick brown fox";
@ -56,7 +57,7 @@ int main(int argc, char ** argv) {
} }
// evaluate prompt // evaluate prompt
llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params.n_threads); llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params.n_threads, params.pp_threads);
last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens); last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens);
n_past += n_prompt_tokens; n_past += n_prompt_tokens;
@ -93,7 +94,7 @@ int main(int argc, char ** argv) {
last_n_tokens_data.push_back(next_token); last_n_tokens_data.push_back(next_token);
printf("%s", next_token_str); printf("%s", next_token_str);
if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads)) { if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads, params.pp_threads)) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__); fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_free(ctx); llama_free(ctx);
llama_free_model(model); llama_free_model(model);
@ -153,7 +154,7 @@ int main(int argc, char ** argv) {
last_n_tokens_data.push_back(next_token); last_n_tokens_data.push_back(next_token);
printf("%s", next_token_str); printf("%s", next_token_str);
if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads)) { if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads, params.pp_threads)) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__); fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_free(ctx2); llama_free(ctx2);
llama_free_model(model); llama_free_model(model);

View file

@ -5,6 +5,7 @@ This example demonstrates a simple HTTP API server and a simple web front end to
Command line options: Command line options:
- `--threads N`, `-t N`: Set the number of threads to use during computation. - `--threads N`, `-t N`: Set the number of threads to use during computation.
- `-ppt N`, `--pp-threads N`: Set the number of threads to use during prompt processing only.
- `-m FNAME`, `--model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.bin`). - `-m FNAME`, `--model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.bin`).
- `-m ALIAS`, `--alias ALIAS`: Set an alias for the model. The alias will be returned in API responses. - `-m ALIAS`, `--alias ALIAS`: Set an alias for the model. The alias will be returned in API responses.
- `-c N`, `--ctx-size N`: Set the size of the prompt context. The default is 512, but LLaMA models were built with a context of 2048, which will provide better results for longer input/inference. The size may differ in other models, for example, baichuan models were build with a context of 4096. - `-c N`, `--ctx-size N`: Set the size of the prompt context. The default is 512, but LLaMA models were built with a context of 2048, which will provide better results for longer input/inference. The size may differ in other models, for example, baichuan models were build with a context of 4096.

View file

@ -385,7 +385,7 @@ struct llama_server_context
{ {
n_eval = params.n_batch; n_eval = params.n_batch;
} }
if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads)) if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads, params.pp_threads))
{ {
LOG_ERROR("failed to eval", { LOG_ERROR("failed to eval", {
{"n_eval", n_eval}, {"n_eval", n_eval},
@ -651,6 +651,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
fprintf(stdout, " -h, --help show this help message and exit\n"); fprintf(stdout, " -h, --help show this help message and exit\n");
fprintf(stdout, " -v, --verbose verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled"); fprintf(stdout, " -v, --verbose verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled");
fprintf(stdout, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); fprintf(stdout, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
fprintf(stdout, " -ppt N, --pp-threads N\n");
fprintf(stdout, " number of threads to use during prompt processing (default: %d)\n", params.pp_threads);
fprintf(stdout, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx); fprintf(stdout, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
fprintf(stdout, " -gqa N, --gqa N grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa); fprintf(stdout, " -gqa N, --gqa N grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa);
fprintf(stdout, " -eps N, --rms-norm-eps N rms norm eps (TEMP!!! use 1e-5 for LLaMAv2) (default: %.1e)\n", params.rms_norm_eps); fprintf(stdout, " -eps N, --rms-norm-eps N rms norm eps (TEMP!!! use 1e-5 for LLaMAv2) (default: %.1e)\n", params.rms_norm_eps);
@ -822,6 +824,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
} }
params.n_threads = std::stoi(argv[i]); params.n_threads = std::stoi(argv[i]);
} }
else if (arg == "-ppt" || arg == "--pp-threads")
{
if (++i >= argc)
{
invalid_param = true;
break;
}
params.pp_threads = std::stoi(argv[i]);
}
else if (arg == "-b" || arg == "--batch-size") else if (arg == "-b" || arg == "--batch-size")
{ {
if (++i >= argc) if (++i >= argc)
@ -1185,6 +1196,7 @@ int main(int argc, char **argv)
{"commit", BUILD_COMMIT}}); {"commit", BUILD_COMMIT}});
LOG_INFO("system info", { LOG_INFO("system info", {
{"n_threads", params.n_threads}, {"n_threads", params.n_threads},
{"pp_threads", params.pp_threads},
{"total_threads", std::thread::hardware_concurrency()}, {"total_threads", std::thread::hardware_concurrency()},
{"system_info", llama_print_system_info()}, {"system_info", llama_print_system_info()},
}); });

View file

@ -123,7 +123,7 @@ int main(int argc, char ** argv)
// Evaluate the tokens : // Evaluate the tokens :
//--------------------------------- //---------------------------------
if ( llama_eval( ctx , tokens_list.data() , int(tokens_list.size()) , llama_get_kv_cache_token_count( ctx ) , params.n_threads ) ) if ( llama_eval( ctx , tokens_list.data() , int(tokens_list.size()) , llama_get_kv_cache_token_count( ctx ) , params.n_threads , params.n_threads ) )
{ {
fprintf( stderr, "%s : failed to eval\n" , __func__ ); fprintf( stderr, "%s : failed to eval\n" , __func__ );
return 1; return 1;

View file

@ -1786,7 +1786,8 @@ static struct ggml_cgraph * llama_build_graph(
// - embd embeddings input // - embd embeddings input
// - n_tokens number of tokens // - n_tokens number of tokens
// - n_past: the context size so far // - n_past: the context size so far
// - n_threads: number of threads to use // - n_threads: number of threads to use for inference
// - pp_threads: number of threads to use for prompt processing
// //
static bool llama_eval_internal( static bool llama_eval_internal(
llama_context & lctx, llama_context & lctx,
@ -1795,6 +1796,7 @@ static bool llama_eval_internal(
int n_tokens, int n_tokens,
int n_past, int n_past,
int n_threads, int n_threads,
int pp_threads,
const char * cgraph_fname) { const char * cgraph_fname) {
LLAMA_ASSERT((!tokens && embd) || (tokens && !embd)); LLAMA_ASSERT((!tokens && embd) || (tokens && !embd));
@ -1838,7 +1840,8 @@ static bool llama_eval_internal(
// for big prompts, if BLAS is enabled, it is better to use only one thread // for big prompts, if BLAS is enabled, it is better to use only one thread
// otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : n_threads; pp_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : pp_threads;
n_threads = N > 1 ? pp_threads : n_threads;
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2]; struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
@ -3484,7 +3487,7 @@ struct llama_context * llama_new_context_with_model(
if (ggml_mpi_rank(ctx->ctx_mpi) > 0) { if (ggml_mpi_rank(ctx->ctx_mpi) > 0) {
// Enter a blocking eval loop with dummy input, letting rank=0 drive the process // Enter a blocking eval loop with dummy input, letting rank=0 drive the process
const std::vector<llama_token> tmp(ctx->model.hparams.n_ctx, llama_token_bos()); const std::vector<llama_token> tmp(ctx->model.hparams.n_ctx, llama_token_bos());
while (!llama_eval(ctx, tmp.data(), tmp.size(), 0, 0)) {}; while (!llama_eval(ctx, tmp.data(), tmp.size(), 0, 0, 0)) {};
llama_backend_free(); llama_backend_free();
exit(1); exit(1);
} }
@ -4176,8 +4179,9 @@ int llama_eval(
const llama_token * tokens, const llama_token * tokens,
int n_tokens, int n_tokens,
int n_past, int n_past,
int n_threads) { int n_threads,
if (!llama_eval_internal(*ctx, tokens, nullptr, n_tokens, n_past, n_threads, nullptr)) { int pp_threads) {
if (!llama_eval_internal(*ctx, tokens, nullptr, n_tokens, n_past, n_threads, pp_threads, nullptr)) {
LLAMA_LOG_ERROR("%s: failed to eval\n", __func__); LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
return 1; return 1;
} }
@ -4198,8 +4202,9 @@ int llama_eval_embd(
const float * embd, const float * embd,
int n_tokens, int n_tokens,
int n_past, int n_past,
int n_threads) { int n_threads,
if (!llama_eval_internal(*ctx, nullptr, embd, n_tokens, n_past, n_threads, nullptr)) { int pp_threads) {
if (!llama_eval_internal(*ctx, nullptr, embd, n_tokens, n_past, n_threads, pp_threads, nullptr)) {
LLAMA_LOG_ERROR("%s: failed to eval\n", __func__); LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
return 1; return 1;
} }
@ -4220,7 +4225,7 @@ int llama_eval_export(struct llama_context * ctx, const char * fname) {
const std::vector<llama_token> tmp(n_batch, llama_token_bos()); const std::vector<llama_token> tmp(n_batch, llama_token_bos());
if (!llama_eval_internal(*ctx, tmp.data(), nullptr, tmp.size(), n_ctx, 1, fname)) { if (!llama_eval_internal(*ctx, tmp.data(), nullptr, tmp.size(), n_ctx, 1, 1, fname)) {
LLAMA_LOG_ERROR("%s: failed to eval\n", __func__); LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
return 1; return 1;
} }

View file

@ -308,7 +308,8 @@ extern "C" {
const llama_token * tokens, const llama_token * tokens,
int n_tokens, int n_tokens,
int n_past, int n_past,
int n_threads); int n_threads,
int pp_threads);
// Same as llama_eval, but use float matrix input directly. // Same as llama_eval, but use float matrix input directly.
LLAMA_API int llama_eval_embd( LLAMA_API int llama_eval_embd(
@ -316,7 +317,8 @@ extern "C" {
const float * embd, const float * embd,
int n_tokens, int n_tokens,
int n_past, int n_past,
int n_threads); int n_threads,
int pp_threads);
// Export a static computation graph for context of 511 and batch size of 1 // Export a static computation graph for context of 511 and batch size of 1
// NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these // NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these