diff --git a/common/common.cpp b/common/common.cpp index feb5ea60b..ad81dc7b1 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -129,6 +129,15 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { if (params.n_threads <= 0) { params.n_threads = std::thread::hardware_concurrency(); } + } else if (arg == "-tb" || arg == "--threads-batch") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_threads_batch = std::stoi(argv[i]); + if (params.n_threads_batch <= 0) { + params.n_threads_batch = std::thread::hardware_concurrency(); + } } else if (arg == "-p" || arg == "--prompt") { if (++i >= argc) { invalid_param = true; @@ -606,7 +615,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" (can be specified more than once for multiple prompts).\n"); printf(" --color colorise output to distinguish prompt and user input from generations\n"); printf(" -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n"); - printf(" -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); + printf(" -t N, --threads N number of threads to use during generation (default: %d)\n", params.n_threads); + printf(" -tb N, --threads-batch N\n"); + printf(" number of threads to use during batch evaluation and prompt processing (default: same as --threads)\n"); printf(" -p PROMPT, --prompt PROMPT\n"); printf(" prompt to start generation with (default: empty)\n"); printf(" -e, --escape process prompt escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n"); @@ -742,6 +753,8 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.n_ctx = params.n_ctx; cparams.n_batch = params.n_batch; + cparams.n_threads = params.n_threads; + cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; cparams.low_vram = params.low_vram; cparams.mul_mat_q = params.mul_mat_q; cparams.seed = params.seed; @@ -756,7 +769,6 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param std::tuple llama_init_from_gpt_params(gpt_params & params) { auto mparams = llama_model_params_from_gpt_params(params); - auto cparams = llama_context_params_from_gpt_params(params); llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams); if (model == NULL) { @@ -764,6 +776,8 @@ std::tuple llama_init_from_gpt_par return std::make_tuple(nullptr, nullptr); } + auto cparams = llama_context_params_from_gpt_params(params); + llama_context * lctx = llama_new_context_with_model(model, cparams); if (lctx == NULL) { fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str()); @@ -792,7 +806,7 @@ std::tuple llama_init_from_gpt_par LOG("warming up the model with an empty run\n"); const std::vector tmp = { llama_token_bos(lctx), llama_token_eos(lctx), }; - llama_eval(lctx, tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, params.n_threads); + llama_eval(lctx, tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0); llama_reset_timings(lctx); } diff --git a/common/common.h b/common/common.h index fa07fe51d..49a675425 100644 --- a/common/common.h +++ b/common/common.h @@ -36,6 +36,7 @@ int32_t get_num_physical_cores(); struct gpt_params { uint32_t seed = -1; // RNG seed int32_t n_threads = get_num_physical_cores(); + int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads) int32_t n_predict = -1; // new tokens to predict int32_t n_ctx = 512; // context size int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS) diff --git a/examples/beam-search/beam-search.cpp b/examples/beam-search/beam-search.cpp index 888ae9665..b28499e8d 100644 --- a/examples/beam-search/beam-search.cpp +++ b/examples/beam-search/beam-search.cpp @@ -159,7 +159,7 @@ int main(int argc, char ** argv) std::cout << std::flush; int n_past = llama_get_kv_cache_token_count(ctx); - if (llama_eval(ctx, tokens_list.data(), tokens_list.size(), n_past, params.n_threads)) + if (llama_eval(ctx, tokens_list.data(), tokens_list.size(), n_past)) { fprintf(stderr, "%s : failed to eval prompt.\n" , __func__ ); return 1; @@ -169,7 +169,7 @@ int main(int argc, char ** argv) beam_search_callback_data callback_data{ctx, {}}; size_t const beam_width = static_cast(params.n_beams); int const n_predict = 256; - llama_beam_search(ctx, beam_search_callback, &callback_data, beam_width, n_past, n_predict, params.n_threads); + llama_beam_search(ctx, beam_search_callback, &callback_data, beam_width, n_past, n_predict); std::cout << "\n\n"; for (llama_token const token_id : callback_data.response) { diff --git a/examples/embd-input/embd-input-lib.cpp b/examples/embd-input/embd-input-lib.cpp index c995eef35..7c44008b4 100644 --- a/examples/embd-input/embd-input-lib.cpp +++ b/examples/embd-input/embd-input-lib.cpp @@ -80,7 +80,7 @@ bool eval_float(void * model, float * input, int N){ if (n_eval > n_batch) { n_eval = n_batch; } - if (llama_eval_embd(ctx, (input+i*n_emb), n_eval, n_past, params.n_threads)) { + if (llama_eval_embd(ctx, (input+i*n_emb), n_eval, n_past)) { fprintf(stderr, "%s : failed to eval\n", __func__); return false; } @@ -101,7 +101,7 @@ bool eval_tokens(void * model, std::vector tokens) { if (n_eval > params.n_batch) { n_eval = params.n_batch; } - if (llama_eval(ctx, &tokens[i], n_eval, n_past, params.n_threads)) { + if (llama_eval(ctx, &tokens[i], n_eval, n_past)) { fprintf(stderr, "%s : failed to eval\n", __func__); return false; } diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 93495aece..4c3d0cc46 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -80,7 +80,7 @@ int main(int argc, char ** argv) { while (!embd_inp.empty()) { int n_tokens = std::min(params.n_batch, (int) embd_inp.size()); - if (llama_eval(ctx, embd_inp.data(), n_tokens, n_past, params.n_threads)) { + if (llama_eval(ctx, embd_inp.data(), n_tokens, n_past)) { fprintf(stderr, "%s : failed to eval\n", __func__); return 1; } diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 2fed36ef9..125c3eea3 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -432,6 +432,9 @@ static std::vector get_cmd_params_instances(const cmd_param for (const auto & mmq : params.mul_mat_q) for (const auto & nt : params.n_threads) { for (const auto & n_prompt : params.n_prompt) { + if (n_prompt == 0) { + continue; + } cmd_params_instance instance = { /* .model = */ m, /* .n_prompt = */ n_prompt, @@ -449,6 +452,9 @@ static std::vector get_cmd_params_instances(const cmd_param } for (const auto & n_gen : params.n_gen) { + if (n_gen == 0) { + continue; + } cmd_params_instance instance = { /* .model = */ m, /* .n_prompt = */ 0, @@ -953,17 +959,23 @@ struct sql_printer : public printer { static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) { std::vector tokens(n_batch, llama_token_bos(ctx)); int n_processed = 0; + + llama_set_n_threads(ctx, n_threads, n_threads); + while (n_processed < n_prompt) { int n_tokens = std::min(n_prompt - n_processed, n_batch); - llama_eval(ctx, tokens.data(), n_tokens, n_past + n_processed, n_threads); + llama_eval(ctx, tokens.data(), n_tokens, n_past + n_processed); n_processed += n_tokens; } } static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) { llama_token token = llama_token_bos(ctx); + + llama_set_n_threads(ctx, n_threads, n_threads); + for (int i = 0; i < n_gen; i++) { - llama_eval(ctx, &token, 1, n_past + i, n_threads); + llama_eval(ctx, &token, 1, n_past + i); } } diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 067f07286..7d9d54b00 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -582,7 +582,7 @@ int main(int argc, char ** argv) { for (int i = 0; i < input_size; i += params.n_batch) { int n_eval = std::min(input_size - i, params.n_batch); - if (llama_eval(ctx_guidance, input_buf + i, n_eval, n_past_guidance, params.n_threads)) { + if (llama_eval(ctx_guidance, input_buf + i, n_eval, n_past_guidance)) { LOG_TEE("%s : failed to eval\n", __func__); return 1; } @@ -599,7 +599,7 @@ int main(int argc, char ** argv) { LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd)); - if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads)) { + if (llama_eval(ctx, &embd[i], n_eval, n_past)) { LOG_TEE("%s : failed to eval\n", __func__); return 1; } diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 7ce5ee93f..acb0b9a1e 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -202,7 +202,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & const int batch_size = std::min(end - batch_start, n_batch); //fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch); - if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) { + if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch)) { //fprintf(stderr, "%s : failed to eval\n", __func__); return {tokens, -1, logit_history, prob_history}; } @@ -335,7 +335,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par tokens[batch_start] = llama_token_bos(ctx); } - if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) { + if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch)) { fprintf(stderr, "%s : failed to eval\n", __func__); return {tokens, -1, logit_history, prob_history}; } @@ -405,7 +405,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par } static std::vector hellaswag_evaluate_tokens( - llama_context * ctx, const std::vector& tokens, int n_past, int n_batch, int n_vocab, int n_thread + llama_context * ctx, const std::vector& tokens, int n_past, int n_batch, int n_vocab ) { std::vector result; result.reserve(tokens.size() * n_vocab); @@ -413,7 +413,7 @@ static std::vector hellaswag_evaluate_tokens( for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) { size_t n_tokens = tokens.size() - i_chunk * n_batch; n_tokens = std::min(n_tokens, size_t(n_batch)); - if (llama_eval(ctx, tokens.data() + i_chunk * n_batch, n_tokens, n_past, n_thread)) { + if (llama_eval(ctx, tokens.data() + i_chunk * n_batch, n_tokens, n_past)) { fprintf(stderr, "%s : failed to eval\n", __func__); return {}; } @@ -554,7 +554,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { query_embd.resize(32); } - auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab, params.n_threads); + auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab); if (logits.empty()) { fprintf(stderr, "%s : failed to eval\n", __func__); return; @@ -603,7 +603,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { //} // Evaluate the query - logits = hellaswag_evaluate_tokens(ctx, query_embd, context_size, params.n_batch, n_vocab, params.n_threads); + logits = hellaswag_evaluate_tokens(ctx, query_embd, context_size, params.n_batch, n_vocab); if (logits.empty()) { fprintf(stderr, "%s : failed to eval\n", __func__); return; diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 3d8adf013..b905eeeaf 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -48,7 +48,7 @@ int main(int argc, char ** argv) { } // evaluate prompt - llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params.n_threads); + llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past); last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens); n_past += n_prompt_tokens; @@ -85,7 +85,7 @@ int main(int argc, char ** argv) { last_n_tokens_data.push_back(next_token); printf("%s", next_token_str.c_str()); - if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads)) { + if (llama_eval(ctx, &next_token, 1, n_past)) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); llama_free(ctx); llama_free_model(model); @@ -145,7 +145,7 @@ int main(int argc, char ** argv) { last_n_tokens_data.push_back(next_token); printf("%s", next_token_str.c_str()); - if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads)) { + if (llama_eval(ctx2, &next_token, 1, n_past)) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); llama_free(ctx2); llama_free_model(model); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 801787a23..3c9ea47fa 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -435,12 +435,11 @@ struct llama_server_context { n_eval = params.n_batch; } - if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads)) + if (llama_eval(ctx, &embd[n_past], n_eval, n_past)) { LOG_ERROR("failed to eval", { {"n_eval", n_eval}, {"n_past", n_past}, - {"n_threads", params.n_threads}, {"embd", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())}, }); has_next_token = false; @@ -1359,7 +1358,7 @@ int main(int argc, char **argv) if (llama.params.n_beams) { // Fill llama.generated_token_probs vector with final beam. llama_beam_search(llama.ctx, beam_search_callback, &llama, llama.params.n_beams, - llama.n_past, llama.n_remain, llama.params.n_threads); + llama.n_past, llama.n_remain); // Translate llama.generated_token_probs to llama.generated_text. append_to_generated_text_from_generated_token_probs(llama); } else { diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index bfaee6fcb..b8f37e1b9 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -72,7 +72,7 @@ int main(int argc, char ** argv) { while (llama_get_kv_cache_token_count(ctx) < n_gen) { // evaluate the transformer - if (llama_eval(ctx, tokens_list.data(), int(tokens_list.size()), llama_get_kv_cache_token_count(ctx), params.n_threads)) { + if (llama_eval(ctx, tokens_list.data(), int(tokens_list.size()), llama_get_kv_cache_token_count(ctx))) { fprintf(stderr, "%s : failed to eval\n", __func__); return 1; } diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index aa904183f..c43cb9a20 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -70,9 +70,9 @@ int main(int argc, char ** argv) { const auto t_enc_start = ggml_time_us(); // eval the prompt with both models - llama_eval(ctx_tgt, inp.data(), int(inp.size() - 1), 0, params.n_threads); - llama_eval(ctx_tgt, &inp.back(), 1, inp.size() - 1, params.n_threads); - llama_eval(ctx_dft, inp.data(), int(inp.size()), 0, params.n_threads); + llama_eval(ctx_tgt, inp.data(), int(inp.size() - 1), 0); + llama_eval(ctx_tgt, &inp.back(), 1, inp.size() - 1); + llama_eval(ctx_dft, inp.data(), int(inp.size()), 0); const auto t_enc_end = ggml_time_us(); @@ -172,7 +172,7 @@ int main(int argc, char ** argv) { LOG("out of drafted tokens\n"); } - llama_eval(ctx_dft, &id, 1, n_past_dft, params.n_threads); + llama_eval(ctx_dft, &id, 1, n_past_dft); ++n_past_dft; // heuristic for n_draft @@ -256,7 +256,7 @@ int main(int argc, char ** argv) { } // evaluate the drafted token on the draft model - llama_eval(ctx_dft, &drafted.back(), 1, n_past_cur, params.n_threads); + llama_eval(ctx_dft, &drafted.back(), 1, n_past_cur); ++n_past_cur; if (grammar_dft != NULL) { @@ -265,7 +265,7 @@ int main(int argc, char ** argv) { } // evaluate the target model on the drafted tokens - llama_eval(ctx_tgt, drafted.data(), drafted.size(), n_past_tgt, params.n_threads); + llama_eval(ctx_tgt, drafted.data(), drafted.size(), n_past_tgt); ++n_past_tgt; // the first token is always proposed by the traget model before the speculation loop diff --git a/llama.cpp b/llama.cpp index 6a025b45e..aa8a36ff3 100644 --- a/llama.cpp +++ b/llama.cpp @@ -966,6 +966,8 @@ struct llama_hparams { struct llama_cparams { uint32_t n_ctx; // context size used during inference uint32_t n_batch; + uint32_t n_threads; // number of threads to use for generation + uint32_t n_threads_batch; // number of threads to use for batch processing float rope_freq_base; float rope_freq_scale; @@ -3673,7 +3675,6 @@ static bool llama_eval_internal( const float * embd, int n_tokens, int n_past, - int n_threads, const char * cgraph_fname) { GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT @@ -3687,14 +3688,14 @@ static bool llama_eval_internal( GGML_ASSERT(n_tokens <= n_batch); GGML_ASSERT(n_past + n_tokens <= n_ctx); + int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch; + const int64_t t_start_us = ggml_time_us(); #ifdef GGML_USE_MPI ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads); #endif - GGML_ASSERT(n_threads > 0); - const int N = n_tokens; const auto & model = lctx.model; @@ -5249,7 +5250,6 @@ struct llama_beam_search_data { size_t n_beams; int n_past; int n_predict; - int n_threads; std::vector beams; std::vector next_beams; @@ -5259,12 +5259,11 @@ struct llama_beam_search_data { // Used to communicate to/from callback on beams state. std::vector beam_views; - llama_beam_search_data(llama_context * ctx, size_t n_beams, int n_past, int n_predict, int n_threads) + llama_beam_search_data(llama_context * ctx, size_t n_beams, int n_past, int n_predict) : ctx(ctx) , n_beams(n_beams) , n_past(n_past) , n_predict(n_predict) - , n_threads(n_threads) , beam_views(n_beams) { beams.reserve(n_beams); next_beams.reserve(n_beams); @@ -5301,7 +5300,7 @@ struct llama_beam_search_data { } else { // beam is not at end-of-sentence, so branch with next top_k tokens. if (!beam.tokens.empty()) { - llama_eval(ctx, beam.tokens.data(), beam.tokens.size(), n_past, n_threads); + llama_eval(ctx, beam.tokens.data(), beam.tokens.size(), n_past); } llama_logit_info logit_info(ctx); std::vector next_tokens = logit_info.top_k(n_beams); @@ -5375,7 +5374,7 @@ struct llama_beam_search_data { callback(callback_data, get_beams_state(false)); // Sets common_prefix_length update_beams_from_beam_views(); // Update values (p,eob) that callback may have changed. if (common_prefix_length) { - llama_eval(ctx, beams[0].tokens.data(), common_prefix_length, n_past, n_threads); + llama_eval(ctx, beams[0].tokens.data(), common_prefix_length, n_past); n_past += common_prefix_length; } // Zero-out next_beam probabilities to place them last in following min-heap. @@ -5416,11 +5415,11 @@ struct llama_beam_search_data { void llama_beam_search(llama_context * ctx, llama_beam_search_callback_fn_t callback, void * callback_data, - size_t n_beams, int n_past, int n_predict, int n_threads) { + size_t n_beams, int n_past, int n_predict) { assert(ctx); const int64_t t_start_sample_us = ggml_time_us(); - llama_beam_search_data beam_search_data(ctx, n_beams, n_past, n_predict, n_threads); + llama_beam_search_data beam_search_data(ctx, n_beams, n_past, n_predict); beam_search_data.loop(callback, callback_data); @@ -5875,7 +5874,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } } -// TODO: after the GGUF PR, this likely won't work and needs to be updated static int llama_apply_lora_from_file_internal( const struct llama_model & model, const char * path_lora, const char * path_base_model, int n_threads ) { @@ -6178,6 +6176,8 @@ struct llama_context_params llama_context_default_params() { /*.seed =*/ LLAMA_DEFAULT_SEED, /*.n_ctx =*/ 512, /*.n_batch =*/ 512, + /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default + /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, /*.rope_freq_base =*/ 0.0f, /*.rope_freq_scale =*/ 0.0f, /*.low_vram =*/ false, @@ -6296,10 +6296,12 @@ struct llama_context * llama_new_context_with_model( auto & cparams = ctx->cparams; cparams.n_batch = params.n_batch; - cparams.mul_mat_q = params.mul_mat_q; cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; cparams.rope_freq_base = params.rope_freq_base == 0 ? hparams.rope_freq_base_train : params.rope_freq_base; cparams.rope_freq_scale = params.rope_freq_scale == 0 ? hparams.rope_freq_scale_train : params.rope_freq_scale; + cparams.n_threads = params.n_threads; + cparams.n_threads_batch = params.n_threads_batch; + cparams.mul_mat_q = params.mul_mat_q; if (params.seed == LLAMA_DEFAULT_SEED) { params.seed = time(NULL); @@ -6951,9 +6953,8 @@ int llama_eval( struct llama_context * ctx, const llama_token * tokens, int n_tokens, - int n_past, - int n_threads) { - if (!llama_eval_internal(*ctx, tokens, nullptr, n_tokens, n_past, n_threads, nullptr)) { + int n_past) { + if (!llama_eval_internal(*ctx, tokens, nullptr, n_tokens, n_past, nullptr)) { LLAMA_LOG_ERROR("%s: failed to eval\n", __func__); return 1; } @@ -6972,9 +6973,8 @@ int llama_eval_embd( struct llama_context * ctx, const float * embd, int n_tokens, - int n_past, - int n_threads) { - if (!llama_eval_internal(*ctx, nullptr, embd, n_tokens, n_past, n_threads, nullptr)) { + int n_past) { + if (!llama_eval_internal(*ctx, nullptr, embd, n_tokens, n_past, nullptr)) { LLAMA_LOG_ERROR("%s: failed to eval\n", __func__); return 1; } @@ -6989,13 +6989,18 @@ int llama_eval_embd( return 0; } +void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch) { + ctx->cparams.n_threads = n_threads; + ctx->cparams.n_threads_batch = n_threads_batch; +} + int llama_eval_export(struct llama_context * ctx, const char * fname) { const int n_batch = 1; const int n_ctx = 512 - n_batch; const std::vector tmp(n_batch, llama_token_bos(ctx)); - if (!llama_eval_internal(*ctx, tmp.data(), nullptr, tmp.size(), n_ctx, 1, fname)) { + if (!llama_eval_internal(*ctx, tmp.data(), nullptr, tmp.size(), n_ctx, fname)) { LLAMA_LOG_ERROR("%s: failed to eval\n", __func__); return 1; } diff --git a/llama.h b/llama.h index b36dbedc6..abe6e7f0b 100644 --- a/llama.h +++ b/llama.h @@ -140,9 +140,11 @@ extern "C" { }; struct llama_context_params { - uint32_t seed; // RNG seed, -1 for random - uint32_t n_ctx; // text context - uint32_t n_batch; // prompt processing batch size + uint32_t seed; // RNG seed, -1 for random + uint32_t n_ctx; // text context + uint32_t n_batch; // prompt processing batch size + uint32_t n_threads; // number of threads to use for generation + uint32_t n_threads_batch; // number of threads to use for batch processing // ref: https://github.com/ggerganov/llama.cpp/pull/2054 float rope_freq_base; // RoPE base frequency @@ -264,8 +266,10 @@ extern "C" { // Get a string describing the model type LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size); + // Returns the total size of all the tensors in the model in bytes LLAMA_API uint64_t llama_model_size(const struct llama_model * model); + // Returns the total number of parameters in the model LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model); @@ -325,16 +329,17 @@ extern "C" { struct llama_context * ctx, const llama_token * tokens, int n_tokens, - int n_past, - int n_threads); + int n_past); // Same as llama_eval, but use float matrix input directly. LLAMA_API int llama_eval_embd( struct llama_context * ctx, const float * embd, int n_tokens, - int n_past, - int n_threads); + int n_past); + + // Set the number of threads + LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch); // Export a static computation graph for context of 511 and batch size of 1 // NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these @@ -518,8 +523,7 @@ extern "C" { /// @param n_beams Number of beams to use. /// @param n_past Number of tokens already evaluated. /// @param n_predict Maximum number of tokens to predict. EOS may occur earlier. - /// @param n_threads Number of threads as passed to llama_eval(). - LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void * callback_data, size_t n_beams, int n_past, int n_predict, int n_threads); + LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void * callback_data, size_t n_beams, int n_past, int n_predict); // Performance information LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);