builds fine
This commit is contained in:
parent
5f022185a1
commit
1de711d4f8
10 changed files with 40 additions and 21 deletions
|
@ -119,6 +119,15 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
||||||
if (params.n_threads <= 0) {
|
if (params.n_threads <= 0) {
|
||||||
params.n_threads = std::thread::hardware_concurrency();
|
params.n_threads = std::thread::hardware_concurrency();
|
||||||
}
|
}
|
||||||
|
} else if (arg == "-ppt" || arg == "--pp-threads") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
params.pp_threads = std::stoi(argv[i]);
|
||||||
|
if (params.pp_threads <= 0) {
|
||||||
|
params.pp_threads = params.n_threads;
|
||||||
|
}
|
||||||
} else if (arg == "-p" || arg == "--prompt") {
|
} else if (arg == "-p" || arg == "--prompt") {
|
||||||
if (++i >= argc) {
|
if (++i >= argc) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
|
@ -524,6 +533,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
||||||
fprintf(stdout, " --color colorise output to distinguish prompt and user input from generations\n");
|
fprintf(stdout, " --color colorise output to distinguish prompt and user input from generations\n");
|
||||||
fprintf(stdout, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n");
|
fprintf(stdout, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n");
|
||||||
fprintf(stdout, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
|
fprintf(stdout, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
|
||||||
|
fprintf(stdout, " -ppt N, --pp-threads N\n");
|
||||||
|
fprintf(stdout, " number of threads to use during prompt processing (default is equal to --threads)\n");
|
||||||
fprintf(stdout, " -p PROMPT, --prompt PROMPT\n");
|
fprintf(stdout, " -p PROMPT, --prompt PROMPT\n");
|
||||||
fprintf(stdout, " prompt to start generation with (default: empty)\n");
|
fprintf(stdout, " prompt to start generation with (default: empty)\n");
|
||||||
fprintf(stdout, " -e process prompt escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
|
fprintf(stdout, " -e process prompt escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
|
||||||
|
@ -657,6 +668,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
|
||||||
lparams.embedding = params.embedding;
|
lparams.embedding = params.embedding;
|
||||||
lparams.rope_freq_base = params.rope_freq_base;
|
lparams.rope_freq_base = params.rope_freq_base;
|
||||||
lparams.rope_freq_scale = params.rope_freq_scale;
|
lparams.rope_freq_scale = params.rope_freq_scale;
|
||||||
|
lparams.pp_threads = params.pp_threads;
|
||||||
|
|
||||||
return lparams;
|
return lparams;
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,7 @@ int32_t get_num_physical_cores();
|
||||||
struct gpt_params {
|
struct gpt_params {
|
||||||
uint32_t seed = -1; // RNG seed
|
uint32_t seed = -1; // RNG seed
|
||||||
int32_t n_threads = get_num_physical_cores();
|
int32_t n_threads = get_num_physical_cores();
|
||||||
|
int32_t pp_threads = get_num_physical_cores();
|
||||||
int32_t n_predict = -1; // new tokens to predict
|
int32_t n_predict = -1; // new tokens to predict
|
||||||
int32_t n_ctx = 512; // context size
|
int32_t n_ctx = 512; // context size
|
||||||
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
|
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
|
||||||
|
|
|
@ -83,7 +83,7 @@ bool eval_float(void * model, float * input, int N){
|
||||||
if (n_eval > n_batch) {
|
if (n_eval > n_batch) {
|
||||||
n_eval = n_batch;
|
n_eval = n_batch;
|
||||||
}
|
}
|
||||||
if (llama_eval_embd(ctx, (input+i*n_emb), n_eval, n_past, params.n_threads)) {
|
if (llama_eval_embd(ctx, (input+i*n_emb), n_eval, n_past, params.n_threads, params.pp_threads)) {
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -104,7 +104,7 @@ bool eval_tokens(void * model, std::vector<llama_token> tokens) {
|
||||||
if (n_eval > params.n_batch) {
|
if (n_eval > params.n_batch) {
|
||||||
n_eval = params.n_batch;
|
n_eval = params.n_batch;
|
||||||
}
|
}
|
||||||
if (llama_eval(ctx, &tokens[i], n_eval, n_past, params.n_threads)) {
|
if (llama_eval(ctx, &tokens[i], n_eval, n_past, params.n_threads, params.pp_threads)) {
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
|
@ -74,7 +74,7 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
if (params.embedding){
|
if (params.embedding){
|
||||||
if (embd_inp.size() > 0) {
|
if (embd_inp.size() > 0) {
|
||||||
if (llama_eval(ctx, embd_inp.data(), embd_inp.size(), n_past, params.n_threads)) {
|
if (llama_eval(ctx, embd_inp.data(), embd_inp.size(), n_past, params.n_threads, params.pp_threads)) {
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
|
@ -144,7 +144,7 @@ int main(int argc, char ** argv) {
|
||||||
fprintf(stderr, "%s: testing memory usage for n_batch = %d, n_ctx = %d\n", __func__, params.n_batch, params.n_ctx);
|
fprintf(stderr, "%s: testing memory usage for n_batch = %d, n_ctx = %d\n", __func__, params.n_batch, params.n_ctx);
|
||||||
|
|
||||||
const std::vector<llama_token> tmp(params.n_batch, llama_token_bos());
|
const std::vector<llama_token> tmp(params.n_batch, llama_token_bos());
|
||||||
llama_eval(ctx, tmp.data(), tmp.size(), params.n_ctx, params.n_threads);
|
llama_eval(ctx, tmp.data(), tmp.size(), params.n_ctx, params.n_threads, params.pp_threads);
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_print_timings(ctx);
|
llama_print_timings(ctx);
|
||||||
|
@ -406,7 +406,7 @@ int main(int argc, char ** argv) {
|
||||||
// do one empty run to warm up the model
|
// do one empty run to warm up the model
|
||||||
{
|
{
|
||||||
const std::vector<llama_token> tmp = { llama_token_bos(), };
|
const std::vector<llama_token> tmp = { llama_token_bos(), };
|
||||||
llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads);
|
llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads, params.pp_threads);
|
||||||
llama_reset_timings(ctx);
|
llama_reset_timings(ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -509,7 +509,7 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
for (int i = 0; i < input_size; i += params.n_batch) {
|
for (int i = 0; i < input_size; i += params.n_batch) {
|
||||||
int n_eval = std::min(input_size - i, params.n_batch);
|
int n_eval = std::min(input_size - i, params.n_batch);
|
||||||
if (llama_eval(ctx_guidance, input_buf + i, n_eval, n_past_guidance, params.n_threads)) {
|
if (llama_eval(ctx_guidance, input_buf + i, n_eval, n_past_guidance, params.n_threads, params.pp_threads)) {
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
@ -523,7 +523,7 @@ int main(int argc, char ** argv) {
|
||||||
if (n_eval > params.n_batch) {
|
if (n_eval > params.n_batch) {
|
||||||
n_eval = params.n_batch;
|
n_eval = params.n_batch;
|
||||||
}
|
}
|
||||||
if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads)) {
|
if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads, params.pp_threads)) {
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
|
@ -66,7 +66,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
|
||||||
tokens[batch_start] = llama_token_bos();
|
tokens[batch_start] = llama_token_bos();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
|
if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads, params.pp_threads)) {
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -233,7 +233,7 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Evaluate the query
|
// Evaluate the query
|
||||||
if (llama_eval(ctx, query_embd.data(), query_embd.size(), 0, params.n_threads)) {
|
if (llama_eval(ctx, query_embd.data(), query_embd.size(), 0, params.n_threads, params.pp_threads)) {
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
|
@ -350,7 +350,7 @@ struct llama_server_context
|
||||||
{
|
{
|
||||||
n_eval = params.n_batch;
|
n_eval = params.n_batch;
|
||||||
}
|
}
|
||||||
if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads))
|
if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads, params.pp_threads))
|
||||||
{
|
{
|
||||||
LOG_ERROR("failed to eval", {
|
LOG_ERROR("failed to eval", {
|
||||||
{"n_eval", n_eval},
|
{"n_eval", n_eval},
|
||||||
|
|
|
@ -123,7 +123,7 @@ int main(int argc, char ** argv)
|
||||||
// Evaluate the tokens :
|
// Evaluate the tokens :
|
||||||
//---------------------------------
|
//---------------------------------
|
||||||
|
|
||||||
if ( llama_eval( ctx , tokens_list.data() , int(tokens_list.size()) , llama_get_kv_cache_token_count( ctx ) , params.n_threads ) )
|
if ( llama_eval( ctx , tokens_list.data() , int(tokens_list.size()) , llama_get_kv_cache_token_count( ctx ) , params.n_threads , params.pp_threads ) )
|
||||||
{
|
{
|
||||||
fprintf( stderr, "%s : failed to eval\n" , __func__ );
|
fprintf( stderr, "%s : failed to eval\n" , __func__ );
|
||||||
return 1;
|
return 1;
|
||||||
|
|
19
llama.cpp
19
llama.cpp
|
@ -895,6 +895,7 @@ struct llama_context_params llama_context_default_params() {
|
||||||
/*.rms_norm_eps =*/ LLAMA_DEFAULT_RMS_EPS,
|
/*.rms_norm_eps =*/ LLAMA_DEFAULT_RMS_EPS,
|
||||||
/*.gpu_layers =*/ 0,
|
/*.gpu_layers =*/ 0,
|
||||||
/*.main_gpu =*/ 0,
|
/*.main_gpu =*/ 0,
|
||||||
|
/*.pp_threads =*/ GGML_DEFAULT_N_THREADS,
|
||||||
/*.tensor_split =*/ nullptr,
|
/*.tensor_split =*/ nullptr,
|
||||||
/*.rope_freq_base =*/ 10000.0f,
|
/*.rope_freq_base =*/ 10000.0f,
|
||||||
/*.rope_freq_scale =*/ 1.0f,
|
/*.rope_freq_scale =*/ 1.0f,
|
||||||
|
@ -1772,6 +1773,7 @@ static struct ggml_cgraph * llama_build_graph(
|
||||||
// - n_tokens number of tokens
|
// - n_tokens number of tokens
|
||||||
// - n_past: the context size so far
|
// - n_past: the context size so far
|
||||||
// - n_threads: number of threads to use
|
// - n_threads: number of threads to use
|
||||||
|
// - pp_threads: number of threads to use for prompt processing
|
||||||
//
|
//
|
||||||
static bool llama_eval_internal(
|
static bool llama_eval_internal(
|
||||||
llama_context & lctx,
|
llama_context & lctx,
|
||||||
|
@ -1780,6 +1782,7 @@ static bool llama_eval_internal(
|
||||||
int n_tokens,
|
int n_tokens,
|
||||||
int n_past,
|
int n_past,
|
||||||
int n_threads,
|
int n_threads,
|
||||||
|
int pp_threads,
|
||||||
const char * cgraph_fname) {
|
const char * cgraph_fname) {
|
||||||
|
|
||||||
LLAMA_ASSERT((!tokens && embd) || (tokens && !embd));
|
LLAMA_ASSERT((!tokens && embd) || (tokens && !embd));
|
||||||
|
@ -1814,8 +1817,6 @@ static bool llama_eval_internal(
|
||||||
|
|
||||||
// fprintf(stderr, "graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
|
// fprintf(stderr, "graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
|
||||||
|
|
||||||
int32_t pp_threads = 3;
|
|
||||||
|
|
||||||
// for big prompts, if BLAS is enabled, it is better to use only one thread
|
// for big prompts, if BLAS is enabled, it is better to use only one thread
|
||||||
// otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
|
// otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
|
||||||
pp_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : pp_threads;
|
pp_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : pp_threads;
|
||||||
|
@ -3365,7 +3366,7 @@ struct llama_context * llama_new_context_with_model(
|
||||||
if (ggml_mpi_rank(ctx->ctx_mpi) > 0) {
|
if (ggml_mpi_rank(ctx->ctx_mpi) > 0) {
|
||||||
// Enter a blocking eval loop with dummy input, letting rank=0 drive the process
|
// Enter a blocking eval loop with dummy input, letting rank=0 drive the process
|
||||||
const std::vector<llama_token> tmp(ctx->model.hparams.n_ctx, llama_token_bos());
|
const std::vector<llama_token> tmp(ctx->model.hparams.n_ctx, llama_token_bos());
|
||||||
while (!llama_eval(ctx, tmp.data(), tmp.size(), 0, 0)) {};
|
while (!llama_eval(ctx, tmp.data(), tmp.size(), 0, 0, 0)) {};
|
||||||
llama_backend_free();
|
llama_backend_free();
|
||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
|
@ -4057,8 +4058,9 @@ int llama_eval(
|
||||||
const llama_token * tokens,
|
const llama_token * tokens,
|
||||||
int n_tokens,
|
int n_tokens,
|
||||||
int n_past,
|
int n_past,
|
||||||
int n_threads) {
|
int n_threads,
|
||||||
if (!llama_eval_internal(*ctx, tokens, nullptr, n_tokens, n_past, n_threads, nullptr)) {
|
int pp_threads) {
|
||||||
|
if (!llama_eval_internal(*ctx, tokens, nullptr, n_tokens, n_past, n_threads, pp_threads, nullptr)) {
|
||||||
fprintf(stderr, "%s: failed to eval\n", __func__);
|
fprintf(stderr, "%s: failed to eval\n", __func__);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
@ -4079,8 +4081,9 @@ int llama_eval_embd(
|
||||||
const float * embd,
|
const float * embd,
|
||||||
int n_tokens,
|
int n_tokens,
|
||||||
int n_past,
|
int n_past,
|
||||||
int n_threads) {
|
int n_threads,
|
||||||
if (!llama_eval_internal(*ctx, nullptr, embd, n_tokens, n_past, n_threads, nullptr)) {
|
int pp_threads) {
|
||||||
|
if (!llama_eval_internal(*ctx, nullptr, embd, n_tokens, n_past, n_threads, pp_threads, nullptr)) {
|
||||||
fprintf(stderr, "%s: failed to eval\n", __func__);
|
fprintf(stderr, "%s: failed to eval\n", __func__);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
@ -4101,7 +4104,7 @@ int llama_eval_export(struct llama_context * ctx, const char * fname) {
|
||||||
|
|
||||||
const std::vector<llama_token> tmp(n_batch, llama_token_bos());
|
const std::vector<llama_token> tmp(n_batch, llama_token_bos());
|
||||||
|
|
||||||
if (!llama_eval_internal(*ctx, tmp.data(), nullptr, tmp.size(), n_ctx, 1, fname)) {
|
if (!llama_eval_internal(*ctx, tmp.data(), nullptr, tmp.size(), n_ctx, 1, 1, fname)) {
|
||||||
fprintf(stderr, "%s: failed to eval\n", __func__);
|
fprintf(stderr, "%s: failed to eval\n", __func__);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
7
llama.h
7
llama.h
|
@ -94,6 +94,7 @@ extern "C" {
|
||||||
float rms_norm_eps; // rms norm epsilon (TEMP - will be moved to model hparams)
|
float rms_norm_eps; // rms norm epsilon (TEMP - will be moved to model hparams)
|
||||||
int32_t n_gpu_layers; // number of layers to store in VRAM
|
int32_t n_gpu_layers; // number of layers to store in VRAM
|
||||||
int32_t main_gpu; // the GPU that is used for scratch and small tensors
|
int32_t main_gpu; // the GPU that is used for scratch and small tensors
|
||||||
|
int32_t pp_threads; // number of threads used for prompt processing only
|
||||||
|
|
||||||
const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
|
const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
|
||||||
|
|
||||||
|
@ -291,7 +292,8 @@ extern "C" {
|
||||||
const llama_token * tokens,
|
const llama_token * tokens,
|
||||||
int n_tokens,
|
int n_tokens,
|
||||||
int n_past,
|
int n_past,
|
||||||
int n_threads);
|
int n_threads,
|
||||||
|
int pp_threads);
|
||||||
|
|
||||||
// Same as llama_eval, but use float matrix input directly.
|
// Same as llama_eval, but use float matrix input directly.
|
||||||
LLAMA_API int llama_eval_embd(
|
LLAMA_API int llama_eval_embd(
|
||||||
|
@ -299,7 +301,8 @@ extern "C" {
|
||||||
const float * embd,
|
const float * embd,
|
||||||
int n_tokens,
|
int n_tokens,
|
||||||
int n_past,
|
int n_past,
|
||||||
int n_threads);
|
int n_threads,
|
||||||
|
int pp_threads);
|
||||||
|
|
||||||
// Export a static computation graph for context of 511 and batch size of 1
|
// Export a static computation graph for context of 511 and batch size of 1
|
||||||
// NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these
|
// NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue