From 1a0de0b781f727162f85a45c47a03275ae0f7f31 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Wed, 4 Sep 2024 15:37:51 +0300
Subject: [PATCH] constraint : add name API

ggml-ci
---
 common/sampling.cpp    |  8 ++++---
 common/sampling.h      |  2 +-
 include/llama.h        | 11 ++++-----
 src/llama-sampling.cpp | 51 +++++++++++++++++++++++++-----------------
 src/llama.cpp          | 22 +++++++++---------
 5 files changed, 52 insertions(+), 42 deletions(-)

diff --git a/common/sampling.cpp b/common/sampling.cpp
index 34371bc24..f5edd87a6 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -198,9 +198,11 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
     auto & grmr = gsmpl->grmr;
     auto & smpl = gsmpl->smpl;
 
-    auto * cur_p = llama_sampler_get_candidates(smpl);
+    const auto * logits = llama_get_logits_ith(ctx, idx);
 
-    llama_sampler_set_logits(smpl, llama_get_logits_ith(ctx, idx));
+    llama_sampler_set_logits(smpl, logits);
+
+    auto * cur_p = llama_sampler_get_candidates(smpl);
 
     llama_constraint_apply(bias, cur_p);
     llama_constraint_apply(pnlt, cur_p);
@@ -223,7 +225,7 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
     }
 
     // if the token is not valid, sample again, first apply the grammar constraints and then sample
-    llama_sampler_set_logits(smpl, llama_get_logits_ith(ctx, idx));
+    llama_sampler_set_logits(smpl, logits);
 
     llama_constraint_apply(bias, cur_p);
     llama_constraint_apply(pnlt, cur_p);
diff --git a/common/sampling.h b/common/sampling.h
index a04645a67..bab264937 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -63,7 +63,7 @@ struct gpt_sampler_params {
 // gpt_sampler extends llama_sampler with additional functionality:
 //
 // - grammar support
-// - custom sampler logic based on the paramerters
+// - custom sampler logic based on the parameters
 //
 struct gpt_sampler;
 
diff --git a/include/llama.h b/include/llama.h
index 0f08c44c0..bf67f24c7 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -386,11 +386,11 @@ extern "C" {
         double t_start_ms;
         double t_end_ms;
         double t_load_ms;
-        double t_sampling_ms;
+        double t_sampler_ms;
         double t_p_eval_ms;
         double t_eval_ms;
 
-        int32_t n_sampling;
+        int32_t n_sampler;
         int32_t n_p_eval;
         int32_t n_eval;
     };
@@ -1025,8 +1025,7 @@ extern "C" {
 
     // user code can implement the interface below in order to create custom llama_constraint
     struct llama_constraint_i {
-        // TODO: add name API
-
+        const char *           (*name)  (const struct llama_constraint * cnstr); // can be NULL
         void                   (*accept)(      struct llama_constraint * cnstr, llama_token token);              // can be NULL
         void                   (*apply) (      struct llama_constraint * cnstr, llama_token_data_array * cur_p); // required
         void                   (*reset) (      struct llama_constraint * cnstr);                                 // can be NULL
@@ -1035,8 +1034,6 @@ extern "C" {
 
         // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph
         //void (*apply_ggml) (struct llama_constraint * cnstr, ...);
-
-        // TODO: add API to get timing stats
     };
 
     struct llama_constraint {
@@ -1044,7 +1041,7 @@ extern "C" {
         llama_constraint_context_t ctx;
     };
 
-    LLAMA_API struct llama_constraint * llama_constraint_init_softmax   ();
+    LLAMA_API struct llama_constraint * llama_constraint_init_softmax   (void);
     LLAMA_API struct llama_constraint * llama_constraint_init_top_k     (int32_t k, int32_t min_keep);
     LLAMA_API struct llama_constraint * llama_constraint_init_top_p     (float p, int32_t min_keep);
     LLAMA_API struct llama_constraint * llama_constraint_init_min_p     (float p, int32_t min_keep);
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 99e0edfd9..733957fdd 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -428,6 +428,7 @@ void llama_constraint_penalties_impl(
 // softmax
 
 static struct llama_constraint_i llama_constraint_softmax_i = {
+    /* .name   = */ [](const struct llama_constraint * /*cnstr*/) { return "softmax"; },
     /* .accept = */ nullptr,
     /* .apply  = */ [](struct llama_constraint * /*cnstr*/, llama_token_data_array * cur_p) {
         llama_constraint_softmax_impl(cur_p);
@@ -454,9 +455,10 @@ struct llama_constraint_context_top_k {
 };
 
 static struct llama_constraint_i llama_constraint_top_k_i = {
+    /* .name   = */ [](const struct llama_constraint * /*cnstr*/) { return "top-k"; },
     /* .accept = */ nullptr,
     /* .apply  = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) {
-        auto * ctx = (llama_constraint_context_top_k *) cnstr->ctx;
+        const auto * ctx = (llama_constraint_context_top_k *) cnstr->ctx;
         llama_constraint_top_k_impl(cur_p, ctx->k, ctx->min_keep);
     },
     /* .reset  = */ nullptr,
@@ -466,7 +468,7 @@ static struct llama_constraint_i llama_constraint_top_k_i = {
     },
     /* .free   = */ [](struct llama_constraint * cnstr) {
         delete (llama_constraint_context_top_k *) cnstr->ctx;
-    }
+    },
 };
 
 struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k, size_t min_keep) {
@@ -489,9 +491,10 @@ struct llama_constraint_context_top_p {
 };
 
 static struct llama_constraint_i llama_constraint_top_p_i = {
+    /* .name   = */ [](const struct llama_constraint * /*cnstr*/) { return "top-p"; },
    /* .accept = */ nullptr,
     /* .apply  = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) {
-        auto * ctx = (llama_constraint_context_top_p *) cnstr->ctx;
+        const auto * ctx = (llama_constraint_context_top_p *) cnstr->ctx;
         llama_constraint_top_p_impl(cur_p, ctx->p, ctx->min_keep);
     },
     /* .reset  = */ nullptr,
@@ -501,7 +504,7 @@ static struct llama_constraint_i llama_constraint_top_p_i = {
     },
     /* .free   = */ [](struct llama_constraint * cnstr) {
         delete (llama_constraint_context_top_p *) cnstr->ctx;
-    }
+    },
 };
 
 struct llama_constraint * llama_constraint_init_top_p_impl(float p, size_t min_keep) {
@@ -524,9 +527,10 @@ struct llama_constraint_context_min_p {
 };
 
 static struct llama_constraint_i llama_constraint_min_p_i = {
+    /* .name   = */ [](const struct llama_constraint * /*cnstr*/) { return "min-p"; },
     /* .accept = */ nullptr,
     /* .apply  = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) {
-        auto * ctx = (llama_constraint_context_min_p *) cnstr->ctx;
+        const auto * ctx = (llama_constraint_context_min_p *) cnstr->ctx;
         llama_constraint_min_p_impl(cur_p, ctx->p, ctx->min_keep);
     },
     /* .reset  = */ nullptr,
@@ -536,7 +540,7 @@ static struct llama_constraint_i llama_constraint_min_p_i = {
     },
     /* .free   = */ [](struct llama_constraint * cnstr) {
         delete (llama_constraint_context_min_p *) cnstr->ctx;
-    }
+    },
 };
 
 struct llama_constraint * llama_constraint_init_min_p_impl(float p, size_t min_keep) {
@@ -559,9 +563,10 @@ struct llama_constraint_context_tail_free {
 };
 
 static struct llama_constraint_i llama_constraint_tail_free_i = {
+    /* .name   = */ [](const struct llama_constraint * /*cnstr*/) { return "tail-free"; },
     /* .accept = */ nullptr,
     /* .apply  = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) {
-        auto * ctx = (llama_constraint_context_tail_free *) cnstr->ctx;
+        const auto * ctx = (llama_constraint_context_tail_free *) cnstr->ctx;
         llama_constraint_tail_free_impl(cur_p, ctx->z, ctx->min_keep);
     },
     /* .reset  = */ nullptr,
@@ -571,7 +576,7 @@ static struct llama_constraint_i llama_constraint_tail_free_i = {
     },
     /* .free   = */ [](struct llama_constraint * cnstr) {
         delete (llama_constraint_context_tail_free *) cnstr->ctx;
-    }
+    },
 };
 
 struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t min_keep) {
@@ -594,9 +599,10 @@ struct llama_constraint_context_typical {
 };
 
 static struct llama_constraint_i llama_constraint_typical_i = {
+    /* .name   = */ [](const struct llama_constraint * /*cnstr*/) { return "typical"; },
     /* .accept = */ nullptr,
     /* .apply  = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) {
-        auto * ctx = (llama_constraint_context_typical *) cnstr->ctx;
+        const auto * ctx = (llama_constraint_context_typical *) cnstr->ctx;
         llama_constraint_typical_impl(cur_p, ctx->p, ctx->min_keep);
     },
     /* .reset  = */ nullptr,
@@ -606,7 +612,7 @@ static struct llama_constraint_i llama_constraint_typical_i = {
     },
     /* .free   = */ [](struct llama_constraint * cnstr) {
         delete (llama_constraint_context_typical *) cnstr->ctx;
-    }
+    },
 };
 
 struct llama_constraint * llama_constraint_init_typical_impl(float p, size_t min_keep) {
@@ -628,9 +634,10 @@ struct llama_constraint_context_temp {
 };
 
 static struct llama_constraint_i llama_constraint_temp_i = {
+    /* .name   = */ [](const struct llama_constraint * /*cnstr*/) { return "temp"; },
     /* .accept = */ nullptr,
     /* .apply  = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) {
-        auto * ctx = (llama_constraint_context_temp *) cnstr->ctx;
+        const auto * ctx = (llama_constraint_context_temp *) cnstr->ctx;
         llama_constraint_temp_impl(cur_p, ctx->temp);
     },
     /* .reset  = */ nullptr,
@@ -640,7 +647,7 @@ static struct llama_constraint_i llama_constraint_temp_i = {
     },
     /* .free   = */ [](struct llama_constraint * cnstr) {
         delete (llama_constraint_context_temp *) cnstr->ctx;
-    }
+    },
 };
 
 struct llama_constraint * llama_constraint_init_temp_impl(float temp) {
@@ -663,9 +670,10 @@ struct llama_constraint_context_temp_ext {
 };
 
 static struct llama_constraint_i llama_constraint_temp_ext_i = {
+    /* .name   = */ [](const struct llama_constraint * /*cnstr*/) { return "temp-ext"; },
     /* .accept = */ nullptr,
     /* .apply  = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) {
-        auto * ctx = (llama_constraint_context_temp_ext *) cnstr->ctx;
+        const auto * ctx = (llama_constraint_context_temp_ext *) cnstr->ctx;
         if (ctx->delta > 0) {
             const float temp_min = std::max(0.0f, ctx->temp - ctx->delta);
             const float temp_max = ctx->temp + ctx->delta;
@@ -682,7 +690,7 @@ static struct llama_constraint_i llama_constraint_temp_ext_i = {
     },
     /* .free   = */ [](struct llama_constraint * cnstr) {
         delete (llama_constraint_context_temp_ext *) cnstr->ctx;
-    }
+    },
 };
 
 struct llama_constraint * llama_constraint_init_temp_ext_impl(float temp, float delta, float exponent) {
@@ -708,14 +716,15 @@ struct llama_constraint_context_grammar {
 };
 
 static struct llama_constraint_i llama_constraint_grammar_i = {
+    /* .name   = */ [](const struct llama_constraint * /*cnstr*/) { return "grammar"; },
     /* .accept = */ [](struct llama_constraint * cnstr, llama_token token) {
-        auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx;
+        const auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx;
         if (ctx->grammar) {
             llama_grammar_accept_impl(*ctx->grammar, token);
         }
     },
     /* .apply  = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) {
-        auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx;
+        const auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx;
         if (ctx->grammar) {
             llama_constraint_grammar_impl(cur_p, *ctx->grammar);
         }
@@ -747,14 +756,14 @@ static struct llama_constraint_i llama_constraint_grammar_i = {
         return result;
     },
     /* .free   = */ [](struct llama_constraint * cnstr) {
-        auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx;
+        const auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx;
 
         if (ctx->grammar) {
             llama_grammar_free_impl(ctx->grammar);
         }
 
         delete ctx;
-    }
+    },
 };
 
 struct llama_constraint * llama_constraint_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) {
@@ -799,6 +808,7 @@ struct llama_constraint_context_penalties {
 };
 
 static struct llama_constraint_i llama_constraint_penalties_i = {
+    /* .name   = */ [](const struct llama_constraint * /*cnstr*/) { return "penalties"; },
     /* .accept = */ [](struct llama_constraint * cnstr, llama_token token) {
         auto * ctx = (llama_constraint_context_penalties *) cnstr->ctx;
         ctx->prev.push_back(token);
@@ -855,7 +865,7 @@ static struct llama_constraint_i llama_constraint_penalties_i = {
     },
     /* .free   = */ [](struct llama_constraint * cnstr) {
         delete (llama_constraint_context_penalties *) cnstr->ctx;
-    }
+    },
 };
 
 struct llama_constraint * llama_constraint_init_penalties_impl(const struct llama_vocab & vocab, int32_t penalty_last_n, float penalty_repeat, float penalty_freq, float penalty_present, bool penalize_nl, bool ignore_eos) {
@@ -888,6 +898,7 @@ struct llama_constraint_context_logit_bias {
 };
 
 static struct llama_constraint_i llama_constraint_logit_bias_i = {
+    /* .name   = */ [](const struct llama_constraint * /*cnstr*/) { return "logit-bias"; },
     /* .accept = */ nullptr,
     /* .apply  = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) {
         auto * ctx = (llama_constraint_context_logit_bias *) cnstr->ctx;
@@ -905,7 +916,7 @@ static struct llama_constraint_i llama_constraint_logit_bias_i = {
     },
     /* .free   = */ [](struct llama_constraint * cnstr) {
         delete (llama_constraint_context_logit_bias *) cnstr->ctx;
-    }
+    },
 };
 
 struct llama_constraint * llama_constraint_init_logit_bias_impl(
diff --git a/src/llama.cpp b/src/llama.cpp
index 28f406ce2..712d9cfb5 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -20609,7 +20609,7 @@ int32_t llama_chat_apply_template(
 // sampling
 //
 
-struct llama_constraint * llama_constraint_init_softmax() {
+struct llama_constraint * llama_constraint_init_softmax(void) {
     return llama_constraint_init_softmax_impl();
 }
 
@@ -20849,22 +20849,22 @@ int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int
 
 void llama_print_timings(struct llama_context * ctx, struct llama_sampler * smpl) {
     const llama_timings timings = {
-        /*.t_start_ms    =*/ 1e-3 * ctx->t_start_us,
-        /*.t_end_ms      =*/ 1.00 * ggml_time_ms(),
-        /*.t_load_ms     =*/ 1e-3 * ctx->t_load_us,
-        /*.t_sampling_ms =*/ 1e-3 * (smpl ? smpl->t_sample_us : 0.0),
-        /*.t_p_eval_ms   =*/ 1e-3 * ctx->t_p_eval_us,
-        /*.t_eval_ms     =*/ 1e-3 * ctx->t_eval_us,
+        /*.t_start_ms   =*/ 1e-3 * ctx->t_start_us,
+        /*.t_end_ms     =*/ 1.00 * ggml_time_ms(),
+        /*.t_load_ms    =*/ 1e-3 * ctx->t_load_us,
+        /*.t_sampler_ms =*/ 1e-3 * (smpl ? smpl->t_sample_us : 0.0),
+        /*.t_p_eval_ms  =*/ 1e-3 * ctx->t_p_eval_us,
+        /*.t_eval_ms    =*/ 1e-3 * ctx->t_eval_us,
 
-        /*.n_sampling =*/ std::max(0, smpl ? smpl->n_sample : 0),
-        /*.n_p_eval   =*/ std::max(0, ctx->n_p_eval),
-        /*.n_eval     =*/ std::max(1, ctx->n_eval),
+        /*.n_sampler =*/ std::max(0, smpl ? smpl->n_sample : 0),
+        /*.n_p_eval  =*/ std::max(0, ctx->n_p_eval),
+        /*.n_eval    =*/ std::max(1, ctx->n_eval),
     };
 
     LLAMA_LOG_INFO("\n");
     LLAMA_LOG_INFO("%s:        load time = %10.2f ms\n", __func__, timings.t_load_ms);
     LLAMA_LOG_INFO("%s:    sampling time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
-            __func__, timings.t_sampling_ms, timings.n_sampling, timings.t_sampling_ms / timings.n_sampling, 1e3 / timings.t_sampling_ms * timings.n_sampling);
+            __func__, timings.t_sampler_ms, timings.n_sampler, timings.t_sampler_ms / timings.n_sampler, 1e3 / timings.t_sampler_ms * timings.n_sampler);
     LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
             __func__, timings.t_p_eval_ms, timings.n_p_eval, timings.t_p_eval_ms / timings.n_p_eval, 1e3 / timings.t_p_eval_ms * timings.n_p_eval);
     LLAMA_LOG_INFO("%s:        eval time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
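
Note (illustrative sketch, not part of the patch): a minimal stateless user-side constraint showing how the new `name` callback slots into the llama_constraint_i interface from include/llama.h above. It mirrors llama_constraint_softmax_i; the `my_noop_i` table and the "my-noop" name are made-up examples, and the trailing copy/free slots are assumed to be optional (NULL) for a constraint that owns no ctx state, as in the softmax case.

    // illustrative sketch only -- assumes the llama.h declarations above are in scope
    static struct llama_constraint_i my_noop_i = {
        /* .name   = */ [](const struct llama_constraint * /*cnstr*/) { return "my-noop"; },
        /* .accept = */ nullptr, // can be NULL
        /* .apply  = */ [](struct llama_constraint * /*cnstr*/, llama_token_data_array * /*cur_p*/) {
            // apply is the one required callback; this one intentionally
            // leaves the candidate list unchanged
        },
        /* .reset  = */ nullptr, // can be NULL
        /* .copy   = */ nullptr, // assumed optional here: no ctx state to duplicate
        /* .free   = */ nullptr, // assumed optional here: no ctx state to release
    };

Since `name` itself can be NULL, callers reading it through the interface table would guard for that, e.g. `const char * name = my_noop_i.name ? my_noop_i.name(cnstr) : "?";`.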