move n_threads to llama_context_params, add n_threads_batch
This commit is contained in:
parent 96f6dcdeae
commit 7f95379295

14 changed files with 94 additions and 59 deletions
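In short: the thread counts move out of the per-call API and into `llama_context_params` (and the context state), with a separate `n_threads_batch` used for prompt/batch evaluation. Below is a minimal sketch of the new calling convention; it is not part of the commit, and the thread counts and error handling are illustrative only. All calls used here (`llama_context_default_params`, `llama_new_context_with_model`, `llama_token_bos`, the 4-argument `llama_eval`, `llama_free`) appear in the diff below.

    #include "llama.h"

    // Sketch: thread counts are configured once on the context;
    // llama_eval no longer takes an n_threads argument.
    int run_with_new_api(llama_model * model) {
        llama_context_params cparams = llama_context_default_params();
        cparams.n_threads       = 8;  // used when decoding a single token
        cparams.n_threads_batch = 16; // used when evaluating more than one token

        llama_context * ctx = llama_new_context_with_model(model, cparams);
        if (ctx == NULL) {
            return 1;
        }

        // single-token eval, analogous to the model warm-up run in the diff below
        const llama_token tok = llama_token_bos(ctx);
        if (llama_eval(ctx, &tok, 1, 0)) {
            llama_free(ctx);
            return 1;
        }

        llama_free(ctx);
        return 0;
    }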
@@ -129,6 +129,15 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             if (params.n_threads <= 0) {
                 params.n_threads = std::thread::hardware_concurrency();
             }
+        } else if (arg == "-tb" || arg == "--threads-batch") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.n_threads_batch = std::stoi(argv[i]);
+            if (params.n_threads_batch <= 0) {
+                params.n_threads_batch = std::thread::hardware_concurrency();
+            }
         } else if (arg == "-p" || arg == "--prompt") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -606,7 +615,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("                        (can be specified more than once for multiple prompts).\n");
     printf("  --color               colorise output to distinguish prompt and user input from generations\n");
     printf("  -s SEED, --seed SEED  RNG seed (default: -1, use random seed for < 0)\n");
-    printf("  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
+    printf("  -t N, --threads N     number of threads to use during generation (default: %d)\n", params.n_threads);
+    printf("  -tb N, --threads-batch N\n");
+    printf("                        number of threads to use during batch evaluation and prompt processing (default: same as --threads)\n");
     printf("  -p PROMPT, --prompt PROMPT\n");
     printf("                        prompt to start generation with (default: empty)\n");
     printf("  -e, --escape          process prompt escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
@@ -742,6 +753,8 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param

     cparams.n_ctx           = params.n_ctx;
     cparams.n_batch         = params.n_batch;
+    cparams.n_threads       = params.n_threads;
+    cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
     cparams.low_vram        = params.low_vram;
     cparams.mul_mat_q       = params.mul_mat_q;
     cparams.seed            = params.seed;
@@ -756,7 +769,6 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param

 std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
     auto mparams = llama_model_params_from_gpt_params(params);
-    auto cparams = llama_context_params_from_gpt_params(params);

     llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams);
     if (model == NULL) {
@@ -764,6 +776,8 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
         return std::make_tuple(nullptr, nullptr);
     }

+    auto cparams = llama_context_params_from_gpt_params(params);
+
     llama_context * lctx = llama_new_context_with_model(model, cparams);
     if (lctx == NULL) {
         fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str());
@@ -792,7 +806,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
         LOG("warming up the model with an empty run\n");

         const std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
-        llama_eval(lctx, tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, params.n_threads);
+        llama_eval(lctx, tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0);
         llama_reset_timings(lctx);
     }
@@ -36,6 +36,7 @@ int32_t get_num_physical_cores();
 struct gpt_params {
     uint32_t seed            = -1;  // RNG seed
     int32_t  n_threads       = get_num_physical_cores();
+    int32_t  n_threads_batch = -1;  // number of threads to use for batch processing (-1 = use n_threads)
     int32_t  n_predict       = -1;  // new tokens to predict
     int32_t  n_ctx           = 512; // context size
     int32_t  n_batch         = 512; // batch size for prompt processing (must be >=32 to use BLAS)
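The `-1` default means "same as `--threads`"; `llama_context_params_from_gpt_params` resolves it as shown above. The rule in isolation, as a hypothetical helper that is not part of this commit:

    // Hypothetical helper mirroring the fallback in llama_context_params_from_gpt_params:
    // a batch thread count of -1 falls back to the generation thread count.
    static int32_t resolve_n_threads_batch(int32_t n_threads, int32_t n_threads_batch) {
        return n_threads_batch == -1 ? n_threads : n_threads_batch;
    }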
@@ -159,7 +159,7 @@ int main(int argc, char ** argv)
     std::cout << std::flush;

     int n_past = llama_get_kv_cache_token_count(ctx);
-    if (llama_eval(ctx, tokens_list.data(), tokens_list.size(), n_past, params.n_threads))
+    if (llama_eval(ctx, tokens_list.data(), tokens_list.size(), n_past))
     {
         fprintf(stderr, "%s : failed to eval prompt.\n" , __func__ );
         return 1;
@@ -169,7 +169,7 @@ int main(int argc, char ** argv)
     beam_search_callback_data callback_data{ctx, {}};
     size_t const beam_width = static_cast<size_t>(params.n_beams);
     int const n_predict = 256;
-    llama_beam_search(ctx, beam_search_callback, &callback_data, beam_width, n_past, n_predict, params.n_threads);
+    llama_beam_search(ctx, beam_search_callback, &callback_data, beam_width, n_past, n_predict);

     std::cout << "\n\n";
     for (llama_token const token_id : callback_data.response) {
@@ -80,7 +80,7 @@ bool eval_float(void * model, float * input, int N){
         if (n_eval > n_batch) {
             n_eval = n_batch;
         }
-        if (llama_eval_embd(ctx, (input+i*n_emb), n_eval, n_past, params.n_threads)) {
+        if (llama_eval_embd(ctx, (input+i*n_emb), n_eval, n_past)) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
             return false;
         }
@@ -101,7 +101,7 @@ bool eval_tokens(void * model, std::vector<llama_token> tokens) {
         if (n_eval > params.n_batch) {
             n_eval = params.n_batch;
         }
-        if (llama_eval(ctx, &tokens[i], n_eval, n_past, params.n_threads)) {
+        if (llama_eval(ctx, &tokens[i], n_eval, n_past)) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
             return false;
         }
@@ -80,7 +80,7 @@ int main(int argc, char ** argv) {

     while (!embd_inp.empty()) {
         int n_tokens = std::min(params.n_batch, (int) embd_inp.size());
-        if (llama_eval(ctx, embd_inp.data(), n_tokens, n_past, params.n_threads)) {
+        if (llama_eval(ctx, embd_inp.data(), n_tokens, n_past)) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
             return 1;
         }
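The same mechanical change repeats across the examples: the chunked eval loops keep their batching logic and only drop the trailing thread argument. A generic sketch of such a loop under the new signature (the names and error handling are illustrative and not taken from any one example):

    #include <algorithm>
    #include <cstdio>
    #include <vector>
    #include "llama.h"

    // Sketch: evaluate a tokenized prompt in n_batch-sized chunks with the
    // 4-argument llama_eval; threads now come from the context's parameters.
    static bool eval_prompt_chunked(llama_context * ctx, const std::vector<llama_token> & tokens, int n_batch) {
        int n_past = 0;
        for (int i = 0; i < (int) tokens.size(); i += n_batch) {
            const int n_eval = std::min((int) tokens.size() - i, n_batch);
            if (llama_eval(ctx, tokens.data() + i, n_eval, n_past)) {
                fprintf(stderr, "%s : failed to eval\n", __func__);
                return false;
            }
            n_past += n_eval;
        }
        return true;
    }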
@@ -432,6 +432,9 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
     for (const auto & mmq : params.mul_mat_q)
     for (const auto & nt : params.n_threads) {
         for (const auto & n_prompt : params.n_prompt) {
+            if (n_prompt == 0) {
+                continue;
+            }
             cmd_params_instance instance = {
                 /* .model    = */ m,
                 /* .n_prompt = */ n_prompt,
@@ -449,6 +452,9 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
         }

         for (const auto & n_gen : params.n_gen) {
+            if (n_gen == 0) {
+                continue;
+            }
             cmd_params_instance instance = {
                 /* .model    = */ m,
                 /* .n_prompt = */ 0,
@@ -953,17 +959,23 @@ struct sql_printer : public printer {
 static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
     std::vector<llama_token> tokens(n_batch, llama_token_bos(ctx));
     int n_processed = 0;
+
+    llama_set_n_threads(ctx, n_threads, n_threads);
+
     while (n_processed < n_prompt) {
         int n_tokens = std::min(n_prompt - n_processed, n_batch);
-        llama_eval(ctx, tokens.data(), n_tokens, n_past + n_processed, n_threads);
+        llama_eval(ctx, tokens.data(), n_tokens, n_past + n_processed);
         n_processed += n_tokens;
     }
 }

 static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
     llama_token token = llama_token_bos(ctx);
+
+    llama_set_n_threads(ctx, n_threads, n_threads);
+
     for (int i = 0; i < n_gen; i++) {
-        llama_eval(ctx, &token, 1, n_past + i, n_threads);
+        llama_eval(ctx, &token, 1, n_past + i);
     }
 }
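Because thread counts now live on the context, the benchmark code above can sweep them at runtime with `llama_set_n_threads` instead of passing a count into every eval call. A stripped-down sketch of that pattern (the candidate counts and the timed body are placeholders, not code from the commit):

    #include <vector>
    #include "llama.h"

    // Sketch: reuse one context across benchmark passes, overriding both the
    // generation and the batch thread counts before each pass.
    static void sweep_thread_counts(llama_context * ctx, const std::vector<int> & counts) {
        for (const int nt : counts) {
            llama_set_n_threads(ctx, nt, nt);
            llama_token tok = llama_token_bos(ctx);
            llama_eval(ctx, &tok, 1, 0); // a timed test_prompt/test_gen run would go here
        }
    }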
@@ -582,7 +582,7 @@ int main(int argc, char ** argv) {

             for (int i = 0; i < input_size; i += params.n_batch) {
                 int n_eval = std::min(input_size - i, params.n_batch);
-                if (llama_eval(ctx_guidance, input_buf + i, n_eval, n_past_guidance, params.n_threads)) {
+                if (llama_eval(ctx_guidance, input_buf + i, n_eval, n_past_guidance)) {
                     LOG_TEE("%s : failed to eval\n", __func__);
                     return 1;
                 }
@@ -599,7 +599,7 @@ int main(int argc, char ** argv) {

             LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));

-            if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads)) {
+            if (llama_eval(ctx, &embd[i], n_eval, n_past)) {
                 LOG_TEE("%s : failed to eval\n", __func__);
                 return 1;
             }
@@ -202,7 +202,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
         const int batch_size = std::min(end - batch_start, n_batch);

         //fprintf(stderr, "    Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
-        if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
+        if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch)) {
             //fprintf(stderr, "%s : failed to eval\n", __func__);
             return {tokens, -1, logit_history, prob_history};
         }
@@ -335,7 +335,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             tokens[batch_start] = llama_token_bos(ctx);
         }

-        if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
+        if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch)) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
             return {tokens, -1, logit_history, prob_history};
         }
@@ -405,7 +405,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
 }

 static std::vector<float> hellaswag_evaluate_tokens(
-    llama_context * ctx, const std::vector<int>& tokens, int n_past, int n_batch, int n_vocab, int n_thread
+    llama_context * ctx, const std::vector<int>& tokens, int n_past, int n_batch, int n_vocab
 ) {
     std::vector<float> result;
     result.reserve(tokens.size() * n_vocab);
@@ -413,7 +413,7 @@ static std::vector<float> hellaswag_evaluate_tokens(
     for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) {
         size_t n_tokens = tokens.size() - i_chunk * n_batch;
         n_tokens = std::min(n_tokens, size_t(n_batch));
-        if (llama_eval(ctx, tokens.data() + i_chunk * n_batch, n_tokens, n_past, n_thread)) {
+        if (llama_eval(ctx, tokens.data() + i_chunk * n_batch, n_tokens, n_past)) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
             return {};
         }
@@ -554,7 +554,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
             query_embd.resize(32);
         }

-        auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab, params.n_threads);
+        auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab);
         if (logits.empty()) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
             return;
@@ -603,7 +603,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         //}

         // Evaluate the query
-        logits = hellaswag_evaluate_tokens(ctx, query_embd, context_size, params.n_batch, n_vocab, params.n_threads);
+        logits = hellaswag_evaluate_tokens(ctx, query_embd, context_size, params.n_batch, n_vocab);
         if (logits.empty()) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
             return;
@@ -48,7 +48,7 @@ int main(int argc, char ** argv) {
     }

     // evaluate prompt
-    llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params.n_threads);
+    llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past);

     last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens);
     n_past += n_prompt_tokens;
@@ -85,7 +85,7 @@ int main(int argc, char ** argv) {
         last_n_tokens_data.push_back(next_token);

         printf("%s", next_token_str.c_str());
-        if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads)) {
+        if (llama_eval(ctx, &next_token, 1, n_past)) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
             llama_free(ctx);
             llama_free_model(model);
@@ -145,7 +145,7 @@ int main(int argc, char ** argv) {
         last_n_tokens_data.push_back(next_token);

         printf("%s", next_token_str.c_str());
-        if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads)) {
+        if (llama_eval(ctx2, &next_token, 1, n_past)) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
             llama_free(ctx2);
             llama_free_model(model);
@@ -435,12 +435,11 @@ struct llama_server_context
         {
             n_eval = params.n_batch;
         }
-        if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads))
+        if (llama_eval(ctx, &embd[n_past], n_eval, n_past))
         {
             LOG_ERROR("failed to eval", {
                 {"n_eval", n_eval},
                 {"n_past", n_past},
-                {"n_threads", params.n_threads},
                 {"embd", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())},
             });
             has_next_token = false;
@@ -1359,7 +1358,7 @@ int main(int argc, char **argv)
         if (llama.params.n_beams) {
             // Fill llama.generated_token_probs vector with final beam.
             llama_beam_search(llama.ctx, beam_search_callback, &llama, llama.params.n_beams,
-                              llama.n_past, llama.n_remain, llama.params.n_threads);
+                              llama.n_past, llama.n_remain);
             // Translate llama.generated_token_probs to llama.generated_text.
             append_to_generated_text_from_generated_token_probs(llama);
         } else {
@@ -72,7 +72,7 @@ int main(int argc, char ** argv) {
     while (llama_get_kv_cache_token_count(ctx) < n_gen) {
         // evaluate the transformer

-        if (llama_eval(ctx, tokens_list.data(), int(tokens_list.size()), llama_get_kv_cache_token_count(ctx), params.n_threads)) {
+        if (llama_eval(ctx, tokens_list.data(), int(tokens_list.size()), llama_get_kv_cache_token_count(ctx))) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
             return 1;
         }
@@ -70,9 +70,9 @@ int main(int argc, char ** argv) {
     const auto t_enc_start = ggml_time_us();

     // eval the prompt with both models
-    llama_eval(ctx_tgt, inp.data(), int(inp.size() - 1), 0, params.n_threads);
-    llama_eval(ctx_tgt, &inp.back(), 1, inp.size() - 1, params.n_threads);
-    llama_eval(ctx_dft, inp.data(), int(inp.size()), 0, params.n_threads);
+    llama_eval(ctx_tgt, inp.data(), int(inp.size() - 1), 0);
+    llama_eval(ctx_tgt, &inp.back(), 1, inp.size() - 1);
+    llama_eval(ctx_dft, inp.data(), int(inp.size()), 0);

     const auto t_enc_end = ggml_time_us();

@@ -172,7 +172,7 @@ int main(int argc, char ** argv) {
             LOG("out of drafted tokens\n");
         }

-        llama_eval(ctx_dft, &id, 1, n_past_dft, params.n_threads);
+        llama_eval(ctx_dft, &id, 1, n_past_dft);
         ++n_past_dft;

         // heuristic for n_draft
@@ -256,7 +256,7 @@ int main(int argc, char ** argv) {
             }

             // evaluate the drafted token on the draft model
-            llama_eval(ctx_dft, &drafted.back(), 1, n_past_cur, params.n_threads);
+            llama_eval(ctx_dft, &drafted.back(), 1, n_past_cur);
             ++n_past_cur;

             if (grammar_dft != NULL) {
@@ -265,7 +265,7 @@ int main(int argc, char ** argv) {
         }

         // evaluate the target model on the drafted tokens
-        llama_eval(ctx_tgt, drafted.data(), drafted.size(), n_past_tgt, params.n_threads);
+        llama_eval(ctx_tgt, drafted.data(), drafted.size(), n_past_tgt);
        ++n_past_tgt;

         // the first token is always proposed by the traget model before the speculation loop
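A side effect worth noting for the speculative-decoding example: with thread counts attached to each context, the target and draft models can be tuned independently at creation time. A hedged sketch (the numbers are illustrative; the commit itself does not add such per-model options):

    #include "llama.h"

    // Sketch: independent thread configuration for the target and draft contexts.
    static void make_speculative_contexts(llama_model * model_tgt, llama_model * model_dft,
                                          llama_context ** ctx_tgt, llama_context ** ctx_dft) {
        llama_context_params cp_tgt = llama_context_default_params();
        cp_tgt.n_threads       = 8;
        cp_tgt.n_threads_batch = 16;

        llama_context_params cp_dft = llama_context_default_params();
        cp_dft.n_threads       = 4; // the small draft model may need fewer threads
        cp_dft.n_threads_batch = 4;

        *ctx_tgt = llama_new_context_with_model(model_tgt, cp_tgt);
        *ctx_dft = llama_new_context_with_model(model_dft, cp_dft);
    }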
llama.cpp (43 changes)
@@ -966,6 +966,8 @@ struct llama_hparams {
 struct llama_cparams {
     uint32_t n_ctx;    // context size used during inference
     uint32_t n_batch;
+    uint32_t n_threads;       // number of threads to use for generation
+    uint32_t n_threads_batch; // number of threads to use for batch processing

     float rope_freq_base;
     float rope_freq_scale;
@@ -3673,7 +3675,6 @@ static bool llama_eval_internal(
           const float * embd,
                     int n_tokens,
                     int n_past,
-                    int n_threads,
            const char * cgraph_fname) {

     GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT
@@ -3687,14 +3688,14 @@ static bool llama_eval_internal(
     GGML_ASSERT(n_tokens <= n_batch);
     GGML_ASSERT(n_past + n_tokens <= n_ctx);

+    int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+
     const int64_t t_start_us = ggml_time_us();

 #ifdef GGML_USE_MPI
     ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads);
 #endif

     GGML_ASSERT(n_threads > 0);

     const int N = n_tokens;

     const auto & model = lctx.model;
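This is the core behavioural rule of the change: `llama_eval_internal` picks `n_threads` for single-token calls and `n_threads_batch` for anything larger. A small sketch of what that means for a typical prompt-then-decode sequence (illustrative only; the sampled token is a placeholder):

    #include <vector>
    #include "llama.h"

    // Sketch: the prompt eval below runs with cparams.n_threads_batch (n_tokens > 1),
    // while each per-token decode call runs with cparams.n_threads (n_tokens == 1).
    static void prompt_then_decode(llama_context * ctx, const std::vector<llama_token> & prompt, int n_gen) {
        llama_eval(ctx, prompt.data(), (int) prompt.size(), 0); // batch threads

        llama_token tok = llama_token_bos(ctx); // placeholder for a sampled token
        for (int i = 0; i < n_gen; i++) {
            llama_eval(ctx, &tok, 1, (int) prompt.size() + i);  // generation threads
        }
    }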
@@ -5249,7 +5250,6 @@ struct llama_beam_search_data {
     size_t n_beams;
     int n_past;
     int n_predict;
-    int n_threads;
     std::vector<llama_beam> beams;
     std::vector<llama_beam> next_beams;

@@ -5259,12 +5259,11 @@ struct llama_beam_search_data {
     // Used to communicate to/from callback on beams state.
     std::vector<llama_beam_view> beam_views;

-    llama_beam_search_data(llama_context * ctx, size_t n_beams, int n_past, int n_predict, int n_threads)
+    llama_beam_search_data(llama_context * ctx, size_t n_beams, int n_past, int n_predict)
         : ctx(ctx)
         , n_beams(n_beams)
         , n_past(n_past)
         , n_predict(n_predict)
-        , n_threads(n_threads)
         , beam_views(n_beams) {
         beams.reserve(n_beams);
         next_beams.reserve(n_beams);
@@ -5301,7 +5300,7 @@ struct llama_beam_search_data {
         } else {
             // beam is not at end-of-sentence, so branch with next top_k tokens.
             if (!beam.tokens.empty()) {
-                llama_eval(ctx, beam.tokens.data(), beam.tokens.size(), n_past, n_threads);
+                llama_eval(ctx, beam.tokens.data(), beam.tokens.size(), n_past);
             }
             llama_logit_info logit_info(ctx);
             std::vector<llama_token_data> next_tokens = logit_info.top_k(n_beams);
@@ -5375,7 +5374,7 @@ struct llama_beam_search_data {
             callback(callback_data, get_beams_state(false)); // Sets common_prefix_length
             update_beams_from_beam_views();                  // Update values (p,eob) that callback may have changed.
             if (common_prefix_length) {
-                llama_eval(ctx, beams[0].tokens.data(), common_prefix_length, n_past, n_threads);
+                llama_eval(ctx, beams[0].tokens.data(), common_prefix_length, n_past);
                 n_past += common_prefix_length;
             }
             // Zero-out next_beam probabilities to place them last in following min-heap.
@@ -5416,11 +5415,11 @@ struct llama_beam_search_data {

 void llama_beam_search(llama_context * ctx,
                        llama_beam_search_callback_fn_t callback, void * callback_data,
-                       size_t n_beams, int n_past, int n_predict, int n_threads) {
+                       size_t n_beams, int n_past, int n_predict) {
     assert(ctx);
     const int64_t t_start_sample_us = ggml_time_us();

-    llama_beam_search_data beam_search_data(ctx, n_beams, n_past, n_predict, n_threads);
+    llama_beam_search_data beam_search_data(ctx, n_beams, n_past, n_predict);

     beam_search_data.loop(callback, callback_data);

@@ -5875,7 +5874,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     }
 }

 // TODO: after the GGUF PR, this likely won't work and needs to be updated
 static int llama_apply_lora_from_file_internal(
     const struct llama_model & model, const char * path_lora, const char * path_base_model, int n_threads
 ) {
@@ -6178,6 +6176,8 @@ struct llama_context_params llama_context_default_params() {
         /*.seed                        =*/ LLAMA_DEFAULT_SEED,
         /*.n_ctx                       =*/ 512,
         /*.n_batch                     =*/ 512,
+        /*.n_threads                   =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
+        /*.n_threads_batch             =*/ GGML_DEFAULT_N_THREADS,
         /*.rope_freq_base              =*/ 0.0f,
         /*.rope_freq_scale             =*/ 0.0f,
         /*.low_vram                    =*/ false,
@@ -6296,10 +6296,12 @@ struct llama_context * llama_new_context_with_model(
     auto & cparams = ctx->cparams;

     cparams.n_batch         = params.n_batch;
-    cparams.mul_mat_q       = params.mul_mat_q;
     cparams.n_ctx           = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
     cparams.rope_freq_base  = params.rope_freq_base == 0 ? hparams.rope_freq_base_train : params.rope_freq_base;
     cparams.rope_freq_scale = params.rope_freq_scale == 0 ? hparams.rope_freq_scale_train : params.rope_freq_scale;
+    cparams.n_threads       = params.n_threads;
+    cparams.n_threads_batch = params.n_threads_batch;
+    cparams.mul_mat_q       = params.mul_mat_q;

     if (params.seed == LLAMA_DEFAULT_SEED) {
         params.seed = time(NULL);
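For reference on how the defaults interact: `llama_context_default_params()` seeds both thread fields with `GGML_DEFAULT_N_THREADS`, while a zero `n_ctx`, `rope_freq_base`, or `rope_freq_scale` is resolved against the model's training-time values inside `llama_new_context_with_model`, as shown above. A hedged sketch of overriding only what you need (the chosen values are illustrative):

    #include "llama.h"

    // Sketch: start from the defaults and override selectively; the zero values
    // are resolved to the model's training-time hparams as shown above.
    static llama_context * make_context(llama_model * model) {
        llama_context_params cparams = llama_context_default_params();
        cparams.n_ctx           = 0;    // use the model's training context size
        cparams.rope_freq_base  = 0.0f; // use the training RoPE base frequency
        cparams.rope_freq_scale = 0.0f; // use the training RoPE scale
        cparams.n_threads_batch = 12;   // illustrative override; n_threads keeps its default
        return llama_new_context_with_model(model, cparams);
    }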
@@ -6951,9 +6953,8 @@ int llama_eval(
         struct llama_context * ctx,
            const llama_token * tokens,
                            int n_tokens,
-                           int n_past,
-                           int n_threads) {
-    if (!llama_eval_internal(*ctx, tokens, nullptr, n_tokens, n_past, n_threads, nullptr)) {
+                           int n_past) {
+    if (!llama_eval_internal(*ctx, tokens, nullptr, n_tokens, n_past, nullptr)) {
         LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
         return 1;
     }
@@ -6972,9 +6973,8 @@ int llama_eval_embd(
         struct llama_context * ctx,
                  const float * embd,
                            int n_tokens,
-                           int n_past,
-                           int n_threads) {
-    if (!llama_eval_internal(*ctx, nullptr, embd, n_tokens, n_past, n_threads, nullptr)) {
+                           int n_past) {
+    if (!llama_eval_internal(*ctx, nullptr, embd, n_tokens, n_past, nullptr)) {
         LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
         return 1;
     }
@@ -6989,13 +6989,18 @@ int llama_eval_embd(
     return 0;
 }

+void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch) {
+    ctx->cparams.n_threads       = n_threads;
+    ctx->cparams.n_threads_batch = n_threads_batch;
+}
+
 int llama_eval_export(struct llama_context * ctx, const char * fname) {
     const int n_batch = 1;
     const int n_ctx = 512 - n_batch;

     const std::vector<llama_token> tmp(n_batch, llama_token_bos(ctx));

-    if (!llama_eval_internal(*ctx, tmp.data(), nullptr, tmp.size(), n_ctx, 1, fname)) {
+    if (!llama_eval_internal(*ctx, tmp.data(), nullptr, tmp.size(), n_ctx, fname)) {
         LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
         return 1;
     }
llama.h (16 changes)
@@ -143,6 +143,8 @@ extern "C" {
         uint32_t seed;            // RNG seed, -1 for random
         uint32_t n_ctx;           // text context
         uint32_t n_batch;         // prompt processing batch size
+        uint32_t n_threads;       // number of threads to use for generation
+        uint32_t n_threads_batch; // number of threads to use for batch processing

         // ref: https://github.com/ggerganov/llama.cpp/pull/2054
         float rope_freq_base; // RoPE base frequency
@@ -264,8 +266,10 @@ extern "C" {

     // Get a string describing the model type
     LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size);

     // Returns the total size of all the tensors in the model in bytes
     LLAMA_API uint64_t llama_model_size(const struct llama_model * model);

     // Returns the total number of parameters in the model
     LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);

@@ -325,16 +329,17 @@ extern "C" {
             struct llama_context * ctx,
                const llama_token * tokens,
                                int n_tokens,
-                               int n_past,
-                               int n_threads);
+                               int n_past);

     // Same as llama_eval, but use float matrix input directly.
     LLAMA_API int llama_eval_embd(
             struct llama_context * ctx,
                      const float * embd,
                                int n_tokens,
-                               int n_past,
-                               int n_threads);
+                               int n_past);

+    // Set the number of threads
+    LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);
+
     // Export a static computation graph for context of 511 and batch size of 1
     // NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these
@@ -518,8 +523,7 @@ extern "C" {
     /// @param n_beams Number of beams to use.
     /// @param n_past Number of tokens already evaluated.
     /// @param n_predict Maximum number of tokens to predict. EOS may occur earlier.
-    /// @param n_threads Number of threads as passed to llama_eval().
-    LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void * callback_data, size_t n_beams, int n_past, int n_predict, int n_threads);
+    LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void * callback_data, size_t n_beams, int n_past, int n_predict);

     // Performance information
     LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);