From b448c753b939a515fa0afa74ad05e89a8efca1e9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 6 Sep 2024 12:39:43 +0300 Subject: [PATCH] sampling : remove redundant indirection calls ggml-ci --- include/llama.h | 12 +- src/llama-sampling.cpp | 341 ++++++++++++++++++++++------------------- src/llama-sampling.h | 32 +--- src/llama.cpp | 112 +------------- 4 files changed, 193 insertions(+), 304 deletions(-) diff --git a/include/llama.h b/include/llama.h index 02c565a3d..29c216f2d 100644 --- a/include/llama.h +++ b/include/llama.h @@ -992,9 +992,9 @@ extern "C" { // // Sampling API // - // In the future, it will be utilized offload the sampling to the backends (e.g. GPU). + // In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU). // - // TODO: in the future, the entire API should be changed to accept llama_vocab, instead of llama_model + // TODO: in the future, the entire API that uses llama_model should start using llama_vocab typedef void * llama_sampler_context_t; @@ -1045,16 +1045,16 @@ extern "C" { LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k); /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 - LLAMA_API struct llama_sampler * llama_sampler_init_top_p (float p, int32_t min_keep); + LLAMA_API struct llama_sampler * llama_sampler_init_top_p (float p, size_t min_keep); /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 - LLAMA_API struct llama_sampler * llama_sampler_init_min_p (float p, int32_t min_keep); + LLAMA_API struct llama_sampler * llama_sampler_init_min_p (float p, size_t min_keep); /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. - LLAMA_API struct llama_sampler * llama_sampler_init_tail_free (float z, int32_t min_keep); + LLAMA_API struct llama_sampler * llama_sampler_init_tail_free (float z, size_t min_keep); /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. - LLAMA_API struct llama_sampler * llama_sampler_init_typical (float p, int32_t min_keep); + LLAMA_API struct llama_sampler * llama_sampler_init_typical (float p, size_t min_keep); LLAMA_API struct llama_sampler * llama_sampler_init_temp (float t); /// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772. diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 0084fe0b7..92d39222d 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -432,6 +432,167 @@ void llama_sampler_penalties_impl( cur_p->sorted = false; } +// llama_sampler API + +const char * llama_sampler_name(const struct llama_sampler * smpl) { + if (!smpl->iface) { + return "(null)"; + } + + return smpl->iface->name(smpl); +} + +void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) { + if (smpl->iface->accept) { + smpl->iface->accept(smpl, token); + } +} + +void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) { + GGML_ASSERT(smpl->iface->apply); + smpl->iface->apply(smpl, cur_p); +} + +void llama_sampler_reset(struct llama_sampler * smpl) { + if (smpl->iface->reset) { + smpl->iface->reset(smpl); + } +} + +struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) { + if (smpl->iface->clone) { + return smpl->iface->clone(smpl); + } + + if (smpl->ctx == nullptr) { + return new llama_sampler { + /* .iface = */ smpl->iface, + /* .ctx = */ nullptr, + }; + } + + GGML_ABORT("the sampler does not support cloning"); +} + +void llama_sampler_free(struct llama_sampler * smpl) { + if (smpl == nullptr) { + return; + } + + if (smpl->iface->free) { + smpl->iface->free(smpl); + } + + delete smpl; +} + +llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) { + const auto * logits = llama_get_logits_ith(ctx, idx); + + const int n_vocab = llama_n_vocab(llama_get_model(ctx)); + + // TODO: do not allocate each time + std::vector cur(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + } + + llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; + + llama_sampler_apply(smpl, &cur_p); + + return cur_p.data[cur_p.selected].id; +} + +// sampler chain + +static struct llama_sampler_i llama_sampler_chain_i = { + /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "chain"; }, + /* .accept = */ [](struct llama_sampler * smpl, llama_token token) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + time_meas tm(chain->t_sample_us, chain->params.no_timing); + + for (auto * smpl : chain->samplers) { + llama_sampler_accept(smpl, token); + } + + chain->n_sample++; + }, + /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + time_meas tm(chain->t_sample_us, chain->params.no_timing); + + for (auto * smpl : chain->samplers) { + llama_sampler_apply(smpl, cur_p); + } + }, + /* .reset = */ [](struct llama_sampler * smpl) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + for (auto * smpl : chain->samplers) { + llama_sampler_reset(smpl); + } + + chain->t_sample_us = 0; + chain->n_sample = 0; + }, + /* .clone = */ [](const struct llama_sampler * smpl) { + const auto * chain_src = (const llama_sampler_chain *) smpl->ctx; + + auto * result = llama_sampler_chain_init(chain_src->params); + + for (auto * smpl : chain_src->samplers) { + llama_sampler_chain_add(result, llama_sampler_clone(smpl)); + } + + return result; + }, + /* .free = */ [](struct llama_sampler * smpl) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + for (auto * smpl : chain->samplers) { + llama_sampler_free(smpl); + } + + delete chain; + }, +}; + +struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) { + return new llama_sampler { + /* .iface = */ &llama_sampler_chain_i, + /* .ctx = */ new llama_sampler_chain { + /* .params = */ params, + /* .samplers = */ {}, + /* .t_sample_us = */ 0, + /* .n_sample = */ 0, + }, + }; +} + +void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) { + auto * p = (llama_sampler_chain *) chain->ctx; + p->samplers.push_back(smpl); +} + +struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) { + const auto * p = (const llama_sampler_chain *) chain->ctx; + + if (i < 0 || i >= (int32_t) p->samplers.size()) { + return nullptr; + } + + return p->samplers[i]; +} + +int llama_sampler_chain_n(const struct llama_sampler * chain) { + const auto * p = (const llama_sampler_chain *) chain->ctx; + + return p->samplers.size(); +} + // // samplers // @@ -454,7 +615,7 @@ static struct llama_sampler_i llama_sampler_greedy_i = { /* .free = */ nullptr, }; -struct llama_sampler * llama_sampler_init_greedy_impl() { +struct llama_sampler * llama_sampler_init_greedy() { return new llama_sampler { /* .iface = */ &llama_sampler_greedy_i, /* .ctx = */ nullptr, @@ -481,14 +642,14 @@ static struct llama_sampler_i llama_sampler_dist_i = { /* .reset = */ nullptr, /* .clone = */ [](const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_context_dist *) smpl->ctx; - return llama_sampler_init_dist_impl(ctx->seed); + return llama_sampler_init_dist(ctx->seed); }, /* .free = */ [](struct llama_sampler * smpl) { delete (llama_sampler_context_dist *) smpl->ctx; }, }; -struct llama_sampler * llama_sampler_init_dist_impl(uint32_t seed) { +struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { return new llama_sampler { /* .iface = */ &llama_sampler_dist_i, /* .ctx = */ new llama_sampler_context_dist { @@ -512,7 +673,7 @@ static struct llama_sampler_i llama_sampler_softmax_i = { /* .free = */ nullptr, }; -struct llama_sampler * llama_sampler_init_softmax_impl() { +struct llama_sampler * llama_sampler_init_softmax() { return new llama_sampler { /* .iface = */ &llama_sampler_softmax_i, /* .ctx = */ nullptr, @@ -535,14 +696,14 @@ static struct llama_sampler_i llama_sampler_top_k_i = { /* .reset = */ nullptr, /* .clone = */ [](const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_context_top_k *) smpl->ctx; - return llama_sampler_init_top_k_impl(ctx->k); + return llama_sampler_init_top_k(ctx->k); }, /* .free = */ [](struct llama_sampler * smpl) { delete (llama_sampler_context_top_k *) smpl->ctx; }, }; -struct llama_sampler * llama_sampler_init_top_k_impl(int32_t k) { +struct llama_sampler * llama_sampler_init_top_k(int32_t k) { return new llama_sampler { /* .iface = */ &llama_sampler_top_k_i, /* .ctx = */ new llama_sampler_context_top_k { @@ -568,14 +729,14 @@ static struct llama_sampler_i llama_sampler_top_p_i = { /* .reset = */ nullptr, /* .clone = */ [](const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_context_top_p *) smpl->ctx; - return llama_sampler_init_top_p_impl(ctx->p, ctx->min_keep); + return llama_sampler_init_top_p(ctx->p, ctx->min_keep); }, /* .free = */ [](struct llama_sampler * smpl) { delete (llama_sampler_context_top_p *) smpl->ctx; }, }; -struct llama_sampler * llama_sampler_init_top_p_impl(float p, size_t min_keep) { +struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) { return new llama_sampler { /* .iface = */ &llama_sampler_top_p_i, /* .ctx = */ new llama_sampler_context_top_p { @@ -602,14 +763,14 @@ static struct llama_sampler_i llama_sampler_min_p_i = { /* .reset = */ nullptr, /* .clone = */ [](const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_context_min_p *) smpl->ctx; - return llama_sampler_init_min_p_impl(ctx->p, ctx->min_keep); + return llama_sampler_init_min_p(ctx->p, ctx->min_keep); }, /* .free = */ [](struct llama_sampler * smpl) { delete (llama_sampler_context_min_p *) smpl->ctx; }, }; -struct llama_sampler * llama_sampler_init_min_p_impl(float p, size_t min_keep) { +struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) { return new llama_sampler { /* .iface = */ &llama_sampler_min_p_i, /* .ctx = */ new llama_sampler_context_min_p { @@ -636,14 +797,14 @@ static struct llama_sampler_i llama_sampler_tail_free_i = { /* .reset = */ nullptr, /* .clone = */ [](const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_context_tail_free *) smpl->ctx; - return llama_sampler_init_tail_free_impl(ctx->z, ctx->min_keep); + return llama_sampler_init_tail_free(ctx->z, ctx->min_keep); }, /* .free = */ [](struct llama_sampler * smpl) { delete (llama_sampler_context_tail_free *) smpl->ctx; }, }; -struct llama_sampler * llama_sampler_init_tail_free_impl(float z, size_t min_keep) { +struct llama_sampler * llama_sampler_init_tail_free(float z, size_t min_keep) { return new llama_sampler { /* .iface = */ &llama_sampler_tail_free_i, /* .ctx = */ new llama_sampler_context_tail_free { @@ -670,14 +831,14 @@ static struct llama_sampler_i llama_sampler_typical_i = { /* .reset = */ nullptr, /* .clone = */ [](const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_context_typical *) smpl->ctx; - return llama_sampler_init_typical_impl(ctx->p, ctx->min_keep); + return llama_sampler_init_typical(ctx->p, ctx->min_keep); }, /* .free = */ [](struct llama_sampler * smpl) { delete (llama_sampler_context_typical *) smpl->ctx; }, }; -struct llama_sampler * llama_sampler_init_typical_impl(float p, size_t min_keep) { +struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) { return new llama_sampler { /* .iface = */ &llama_sampler_typical_i, /* .ctx = */ new llama_sampler_context_typical { @@ -703,14 +864,14 @@ static struct llama_sampler_i llama_sampler_temp_i = { /* .reset = */ nullptr, /* .clone = */ [](const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_context_temp *) smpl->ctx; - return llama_sampler_init_temp_impl(ctx->temp); + return llama_sampler_init_temp(ctx->temp); }, /* .free = */ [](struct llama_sampler * smpl) { delete (llama_sampler_context_temp *) smpl->ctx; }, }; -struct llama_sampler * llama_sampler_init_temp_impl(float temp) { +struct llama_sampler * llama_sampler_init_temp(float temp) { return new llama_sampler { /* .iface = */ &llama_sampler_temp_i, /* .ctx = */ new llama_sampler_context_temp { @@ -744,14 +905,14 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = { /* .reset = */ nullptr, /* .clone = */ [](const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_context_temp_ext *) smpl->ctx; - return llama_sampler_init_temp_ext_impl(ctx->temp, ctx->delta, ctx->exponent); + return llama_sampler_init_temp_ext(ctx->temp, ctx->delta, ctx->exponent); }, /* .free = */ [](struct llama_sampler * smpl) { delete (llama_sampler_context_temp_ext *) smpl->ctx; }, }; -struct llama_sampler * llama_sampler_init_temp_ext_impl(float temp, float delta, float exponent) { +struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) { return new llama_sampler { /* .iface = */ &llama_sampler_temp_ext_i, /* .ctx = */ new llama_sampler_context_temp_ext { @@ -900,14 +1061,14 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = { }, /* .clone = */ [](const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_context_mirostat_v2 *) smpl->ctx; - return llama_sampler_init_mirostat_v2_impl(ctx->seed, ctx->tau, ctx->eta); + return llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta); }, /* .free = */ [](struct llama_sampler * smpl) { delete (llama_sampler_context_mirostat_v2 *) smpl->ctx; }, }; -struct llama_sampler * llama_sampler_init_mirostat_v2_impl(uint32_t seed, float tau, float eta) { +struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) { return new llama_sampler { /* .iface = */ &llama_sampler_mirostat_v2_i, /* .ctx = */ new llama_sampler_context_mirostat_v2 { @@ -1146,143 +1307,3 @@ struct llama_sampler * llama_sampler_init_logit_bias_impl( }, }; } - -// sampler chain - -static struct llama_sampler_i llama_sampler_chain_i = { - /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "chain"; }, - /* .accept = */ [](struct llama_sampler * smpl, llama_token token) { - auto * chain = (llama_sampler_chain *) smpl->ctx; - - time_meas tm(chain->t_sample_us, chain->params.no_timing); - - for (auto * smpl : chain->samplers) { - llama_sampler_accept_impl(*smpl, token); - } - - chain->n_sample++; - }, - /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { - auto * chain = (llama_sampler_chain *) smpl->ctx; - - time_meas tm(chain->t_sample_us, chain->params.no_timing); - - for (auto * smpl : chain->samplers) { - llama_sampler_apply_impl(*smpl, cur_p); - } - }, - /* .reset = */ [](struct llama_sampler * smpl) { - auto * chain = (llama_sampler_chain *) smpl->ctx; - - for (auto * smpl : chain->samplers) { - llama_sampler_reset_impl(*smpl); - } - - chain->t_sample_us = 0; - chain->n_sample = 0; - }, - /* .clone = */ [](const struct llama_sampler * smpl) { - const auto * chain_src = (const llama_sampler_chain *) smpl->ctx; - - auto * result = llama_sampler_chain_init_impl(chain_src->params); - - auto * chain_dst = (llama_sampler_chain *) result->ctx; - for (auto * smpl : chain_src->samplers) { - llama_sampler_chain_add_impl(*chain_dst, llama_sampler_clone_impl(*smpl)); - } - - return result; - }, - /* .free = */ [](struct llama_sampler * smpl) { - auto * chain = (llama_sampler_chain *) smpl->ctx; - - for (auto * smpl : chain->samplers) { - llama_sampler_free_impl(smpl); - } - - delete chain; - }, -}; - -struct llama_sampler * llama_sampler_chain_init_impl(struct llama_sampler_chain_params params) { - return new llama_sampler { - /* .iface = */ &llama_sampler_chain_i, - /* .ctx = */ new llama_sampler_chain { - /* .params = */ params, - /* .samplers = */ {}, - /* .t_sample_us = */ 0, - /* .n_sample = */ 0, - }, - }; -} - -void llama_sampler_chain_add_impl(struct llama_sampler_chain & chain, struct llama_sampler * smpl) { - chain.samplers.push_back(smpl); -} - -struct llama_sampler * llama_sampler_chain_get_impl(const struct llama_sampler_chain & chain, int32_t i) { - if (i < 0 || i >= (int32_t) chain.samplers.size()) { - return nullptr; - } - - return chain.samplers[i]; -} - -int llama_sampler_chain_n_impl(const struct llama_sampler_chain & chain) { - return chain.samplers.size(); -} - - -//////////////////////////////////////// - -const char * llama_sampler_name_impl(const struct llama_sampler & smpl) { - if (!smpl.iface) { - return "(null)"; - } - - return smpl.iface->name(&smpl); -} - -void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token) { - if (smpl.iface->accept) { - smpl.iface->accept(&smpl, token); - } -} - -void llama_sampler_apply_impl(struct llama_sampler & smpl, struct llama_token_data_array * cur_p) { - GGML_ASSERT(smpl.iface->apply); - smpl.iface->apply(&smpl, cur_p); -} - -void llama_sampler_reset_impl(struct llama_sampler & smpl) { - if (smpl.iface->reset) { - smpl.iface->reset(&smpl); - } -} - -struct llama_sampler * llama_sampler_clone_impl(const struct llama_sampler & smpl) { - if (smpl.iface->clone) { - return smpl.iface->clone(&smpl); - } - - if (smpl.ctx == nullptr) { - return new llama_sampler { - /* .iface = */ smpl.iface, - /* .ctx = */ nullptr, - }; - } - - GGML_ABORT("the sampler does not support cloning"); -} - -void llama_sampler_free_impl(struct llama_sampler * smpl) { - if (smpl == nullptr) { - return; - } - - if (smpl->iface->free) { - smpl->iface->free(smpl); - } - - delete smpl; -} diff --git a/src/llama-sampling.h b/src/llama-sampling.h index 0088060c8..05bb294a1 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -7,15 +7,6 @@ struct llama_vocab; struct llama_grammar; -// samplers - -const char * llama_sampler_name_impl (const struct llama_sampler & smpl); -void llama_sampler_accept_impl( struct llama_sampler & smpl, llama_token token); -void llama_sampler_apply_impl ( struct llama_sampler & smpl, struct llama_token_data_array * cur_p); -void llama_sampler_reset_impl ( struct llama_sampler & smpl); -struct llama_sampler * llama_sampler_clone_impl (const struct llama_sampler & smpl); -void llama_sampler_free_impl ( struct llama_sampler * smpl); - // sampler chain struct llama_sampler_chain { @@ -30,11 +21,6 @@ struct llama_sampler_chain { mutable int32_t n_sample; }; -struct llama_sampler * llama_sampler_chain_init_impl( struct llama_sampler_chain_params params); -void llama_sampler_chain_add_impl ( struct llama_sampler_chain & chain, struct llama_sampler * smpl); -struct llama_sampler * llama_sampler_chain_get_impl (const struct llama_sampler_chain & chain, int32_t i); -int llama_sampler_chain_n_impl (const struct llama_sampler_chain & chain); - using llama_token_cnt = std::unordered_map; // TODO: tmp exposed until test-sampling is fixed @@ -45,17 +31,6 @@ void llama_sampler_penalties_impl( float penalty_freq, float penalty_present); -struct llama_sampler * llama_sampler_init_greedy_impl (); -struct llama_sampler * llama_sampler_init_dist_impl (uint32_t seed); -struct llama_sampler * llama_sampler_init_softmax_impl (); -struct llama_sampler * llama_sampler_init_top_k_impl (int32_t k); -struct llama_sampler * llama_sampler_init_top_p_impl (float p, size_t min_keep); -struct llama_sampler * llama_sampler_init_min_p_impl (float p, size_t min_keep); -struct llama_sampler * llama_sampler_init_tail_free_impl(float z, size_t min_keep); -struct llama_sampler * llama_sampler_init_typical_impl (float p, size_t min_keep); -struct llama_sampler * llama_sampler_init_temp_impl (float t); -struct llama_sampler * llama_sampler_init_temp_ext_impl (float t, float delta, float exponent); - struct llama_sampler * llama_sampler_init_mirostat_impl( const struct llama_vocab & vocab, uint32_t seed, @@ -63,11 +38,6 @@ struct llama_sampler * llama_sampler_init_mirostat_impl( float eta, int32_t m); -struct llama_sampler * llama_sampler_init_mirostat_v2_impl( - uint32_t seed, - float tau, - float eta); - struct llama_sampler * llama_sampler_init_grammar_impl( const struct llama_vocab & vocab, const char * grammar_str, @@ -82,7 +52,7 @@ struct llama_sampler * llama_sampler_init_penalties_impl( bool penalize_nl, bool ignore_eos); - LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias_impl( +LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias_impl( const struct llama_vocab & vocab, int32_t n_logit_bias, const llama_logit_bias * logit_bias); diff --git a/src/llama.cpp b/src/llama.cpp index db50a0332..f5e01004f 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -20592,102 +20592,17 @@ int32_t llama_chat_apply_template( // sampling // -const char * llama_sampler_name(const struct llama_sampler * smpl) { - return llama_sampler_name_impl(*smpl); -} - -void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) { - llama_sampler_accept_impl(*smpl, token); -} - -void llama_sampler_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { - llama_sampler_apply_impl(*smpl, cur_p); -} - -void llama_sampler_reset(struct llama_sampler * smpl) { - llama_sampler_reset_impl(*smpl); -} - -struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) { - return llama_sampler_clone_impl(*smpl); -} - -void llama_sampler_free(struct llama_sampler * smpl) { - if (smpl == nullptr) { - return; - } - - llama_sampler_free_impl(smpl); -} - -struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) { - return llama_sampler_chain_init_impl(params); -} - -void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) { - llama_sampler_chain_add_impl(*(struct llama_sampler_chain *) chain->ctx, smpl); -} - -struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) { - return llama_sampler_chain_get_impl(*(const struct llama_sampler_chain *) chain->ctx, i); -} - -int llama_sampler_chain_n(const struct llama_sampler * chain) { - return llama_sampler_chain_n_impl(*(const struct llama_sampler_chain *) chain->ctx); -} - -struct llama_sampler * llama_sampler_init_greedy(void) { - return llama_sampler_init_greedy_impl(); -} - -struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { - return llama_sampler_init_dist_impl(seed); -} - -struct llama_sampler * llama_sampler_init_softmax(void) { - return llama_sampler_init_softmax_impl(); -} - -struct llama_sampler * llama_sampler_init_top_k(int32_t k) { - return llama_sampler_init_top_k_impl(k); -} - -struct llama_sampler * llama_sampler_init_top_p(float p, int32_t min_keep) { - return llama_sampler_init_top_p_impl(p, min_keep); -} - -struct llama_sampler * llama_sampler_init_min_p(float p, int32_t min_keep) { - return llama_sampler_init_min_p_impl(p, min_keep); -} - -struct llama_sampler * llama_sampler_init_tail_free(float z, int32_t min_keep) { - return llama_sampler_init_tail_free_impl(z, min_keep); -} - -struct llama_sampler * llama_sampler_init_typical(float p, int32_t min_keep) { - return llama_sampler_init_typical_impl(p, min_keep); -} - -struct llama_sampler * llama_sampler_init_temp(float temp) { - return llama_sampler_init_temp_impl(temp); -} - -struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) { - return llama_sampler_init_temp_ext_impl(temp, delta, exponent); -} - +// TODO: remove indirection when vocab becomes accesible in llama-sampling.cpp struct llama_sampler * llama_sampler_init_mirostat(const struct llama_model * model, uint32_t seed, float tau, float eta, int32_t m) { return llama_sampler_init_mirostat_impl(model->vocab, seed, tau, eta, m); } -struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) { - return llama_sampler_init_mirostat_v2_impl(seed, tau, eta); -} - +// TODO: remove indirection when vocab becomes accesible in llama-sampling.cpp struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * model, const char * grammar_str, const char * grammar_root) { return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root); } +// TODO: remove indirection when vocab becomes accesible in llama-sampling.cpp struct llama_sampler * llama_sampler_init_penalties( const struct llama_model * model, int32_t penalty_last_n, @@ -20699,31 +20614,14 @@ struct llama_sampler * llama_sampler_init_penalties( return llama_sampler_init_penalties_impl(model->vocab, penalty_last_n, penalty_repeat, penalty_freq, penalty_present, penalize_nl, ignore_eos); } -LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias( +// TODO: remove indirection when vocab becomes accesible in llama-sampling.cpp +struct llama_sampler * llama_sampler_init_logit_bias( const struct llama_model * model, int32_t n_logit_bias, const llama_logit_bias * logit_bias) { return llama_sampler_init_logit_bias_impl(model->vocab, n_logit_bias, logit_bias); } -llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) { - const auto * logits = llama_get_logits_ith(ctx, idx); - - const int n_vocab = llama_n_vocab(llama_get_model(ctx)); - - // TODO: do not allocate each time - std::vector cur(n_vocab); - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; - } - - llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; - - llama_sampler_apply(smpl, &cur_p); - - return cur_p.data[cur_p.selected].id; -} - // // model split //