From 8e80a1cf6ba2bf27a1ccf64847e066f70a495bc4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 4 Sep 2024 21:23:35 +0300 Subject: [PATCH] sampling : simplify sample API ggml-ci --- common/sampling.cpp | 124 +++++++------------ common/sampling.h | 11 +- examples/batched/batched.cpp | 6 +- examples/gritlm/gritlm.cpp | 8 +- examples/passkey/passkey.cpp | 8 +- examples/save-load-state/save-load-state.cpp | 6 +- examples/simple/simple.cpp | 8 +- examples/speculative/speculative.cpp | 8 +- include/llama.h | 20 ++- src/llama-sampling.cpp | 83 ++++++------- src/llama-sampling.h | 16 ++- src/llama.cpp | 48 +++---- 12 files changed, 147 insertions(+), 199 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 718001844..df2d1958c 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -40,8 +40,9 @@ std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) { struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params) { llama_sampler_params lparams = llama_sampler_default_params(); - lparams.seed = params.seed; - lparams.n_prev = params.n_prev; + lparams.seed = params.seed; + lparams.n_prev = params.n_prev; + lparams.type = params.temp <= 0.0f ? LLAMA_SAMPLER_TYPE_GREEDY : LLAMA_SAMPLER_TYPE_DIST; auto * result = new gpt_sampler { /* .params = */ params, @@ -61,39 +62,41 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st /* .smpl = */ llama_sampler_init(model, lparams) }; - if (params.mirostat == 0) { - for (const auto & cnstr : params.constraints) { - switch (cnstr) { - case GPT_CONSTRAINT_TYPE_TOP_K: - llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_k (params.top_k, params.min_keep)); - break; - case GPT_CONSTRAINT_TYPE_TOP_P: - llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_p (params.top_p, params.min_keep)); - break; - case GPT_CONSTRAINT_TYPE_MIN_P: - llama_sampler_constraint_add(result->smpl, llama_constraint_init_min_p (params.min_p, params.min_keep)); - break; - case GPT_CONSTRAINT_TYPE_TFS_Z: - llama_sampler_constraint_add(result->smpl, llama_constraint_init_tail_free(params.tfs_z, params.min_keep)); - break; - case GPT_CONSTRAINT_TYPE_TYPICAL_P: - llama_sampler_constraint_add(result->smpl, llama_constraint_init_typical (params.typ_p, params.min_keep)); - break; - case GPT_CONSTRAINT_TYPE_TEMPERATURE: - llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); - break; - default: - GGML_ASSERT(false && "unknown constraint type"); + if (params.temp > 0.0f) { + if (params.mirostat == 0) { + for (const auto & cnstr : params.constraints) { + switch (cnstr) { + case GPT_CONSTRAINT_TYPE_TOP_K: + llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_k (params.top_k, params.min_keep)); + break; + case GPT_CONSTRAINT_TYPE_TOP_P: + llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_p (params.top_p, params.min_keep)); + break; + case GPT_CONSTRAINT_TYPE_MIN_P: + llama_sampler_constraint_add(result->smpl, llama_constraint_init_min_p (params.min_p, params.min_keep)); + break; + case GPT_CONSTRAINT_TYPE_TFS_Z: + llama_sampler_constraint_add(result->smpl, llama_constraint_init_tail_free(params.tfs_z, params.min_keep)); + break; + case GPT_CONSTRAINT_TYPE_TYPICAL_P: + llama_sampler_constraint_add(result->smpl, llama_constraint_init_typical (params.typ_p, params.min_keep)); + break; + case GPT_CONSTRAINT_TYPE_TEMPERATURE: + llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); + break; + default: + GGML_ASSERT(false && "unknown constraint type"); + } } + } else if (params.mirostat == 1) { + llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp(params.temp)); + llama_sampler_constraint_add(result->smpl, llama_constraint_init_mirostat(model, params.mirostat_tau, params.mirostat_eta)); + } else if (params.mirostat == 2) { + llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp(params.temp)); + llama_sampler_constraint_add(result->smpl, llama_constraint_init_mirostat_v2(params.mirostat_tau, params.mirostat_eta)); + } else { + GGML_ASSERT(false && "unknown mirostat version"); } - } else if (params.mirostat == 1) { - llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp(params.temp)); - llama_sampler_constraint_add(result->smpl, llama_constraint_init_mirostat(model, params.mirostat_tau, params.mirostat_eta)); - } else if (params.mirostat == 2) { - llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp(params.temp)); - llama_sampler_constraint_add(result->smpl, llama_constraint_init_mirostat_v2(params.mirostat_tau, params.mirostat_eta)); - } else { - GGML_ASSERT(false && "unknown mirostat version"); } return result; @@ -151,45 +154,11 @@ void gpt_print_timings(struct llama_context * ctx, struct gpt_sampler * gsmpl) { llama_print_timings(ctx, gsmpl ? gsmpl->smpl : nullptr); } -static llama_token gpt_sampler_sample( - struct llama_sampler * smpl, - struct llama_token_data_array * cur_p, - float temp, - int n_probs) { - llama_token res = 0; - - if (temp < 0.0f || (temp == 0.0f && n_probs > 0)) { - // greedy sampling, with probs - res = llama_sampler_sample_greedy(smpl, cur_p, true); - } else if (temp == 0.0f) { - // greedy sampling, no probs - res = llama_sampler_sample_greedy(smpl, cur_p, false); - } else { - // apply all sampling constraints and then sample - llama_sampler_apply(smpl, cur_p); - - res = llama_sampler_sample_dist(smpl, cur_p); - - //{ - // const int n_top = 10; - // LOG("top %d candidates:\n", n_top); - - // for (int i = 0; i < n_top; i++) { - // const llama_token id = cur_p.data[i].id; - // (void)id; // To avoid a warning that id is unused when logging is disabled. - // LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl, id).c_str(), cur_p.data[i].p); - // } - //} - - //LOG("sampled token: %5d: '%s'\n", res, llama_token_to_piece(smpl, res).c_str()); - } - - return res; +llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_token_data_array * cur_p) { + return llama_sampler_sample(gsmpl->smpl, cur_p); } llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx) { - const auto & params = gsmpl->params; - auto & bias = gsmpl->bias; auto & pnlt = gsmpl->pnlt; auto & grmr = gsmpl->grmr; @@ -204,8 +173,9 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context llama_constraint_apply(bias, cur_p); llama_constraint_apply(pnlt, cur_p); - // first, sample the token without any grammar constraints - const llama_token id = gpt_sampler_sample(smpl, nullptr, params.temp, params.n_probs); + llama_sampler_apply(smpl, cur_p); + + const llama_token id = llama_sampler_sample(smpl, cur_p); // check if it the sampled token fits the grammar { @@ -228,7 +198,9 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context llama_constraint_apply(pnlt, cur_p); llama_constraint_apply(grmr, cur_p); - return gpt_sampler_sample(smpl, cur_p, params.temp, params.n_probs); + llama_sampler_apply(smpl, cur_p); + + return llama_sampler_sample(smpl, cur_p); } void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p) { @@ -237,14 +209,6 @@ void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_arra llama_constraint_apply(gsmpl->grmr, cur_p); } -llama_token gpt_sampler_sample_dist(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p) { - return llama_sampler_sample_dist(gsmpl->smpl, cur_p); -} - -llama_token gpt_sampler_sample_greedy(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p, bool probs) { - return llama_sampler_sample_greedy(gsmpl->smpl, cur_p, probs); -} - std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx_main, int n) { auto & smpl = gsmpl->smpl; diff --git a/common/sampling.h b/common/sampling.h index 8ec745999..9bdeadf78 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -73,15 +73,19 @@ struct gpt_sampler * gpt_sampler_cp(gpt_sampler * gsmpl); void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool apply_grammar); void gpt_sampler_reset (struct gpt_sampler * gsmpl); +void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p); + void gpt_sampler_set_logits(struct gpt_sampler * gsmpl, const float * logits); llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl); +llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_token_data_array * cur_p); + llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl); void gpt_print_timings(struct llama_context * ctx, struct gpt_sampler * gsmpl); -// common sampling implementation: +// extended sampling implementation: // // - set logits // - apply the configured sampling constraints @@ -90,11 +94,6 @@ void gpt_print_timings(struct llama_context * ctx, struct gpt_sampler * gsmpl); // llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx); -void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p); - -llama_token gpt_sampler_sample_dist (struct gpt_sampler * gsmpl, llama_token_data_array * cur_p); -llama_token gpt_sampler_sample_greedy(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p, bool probs); - // helpers // print the constraints into a string diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index cbab4b66b..9f0a40873 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -66,7 +66,7 @@ int main(int argc, char ** argv) { auto sparams = llama_sampler_default_params(); - sparams.seed = params.sparams.seed; + sparams.seed = params.sparams.seed; llama_sampler * smpl = llama_sampler_init(model, sparams); @@ -177,9 +177,7 @@ int main(int argc, char ** argv) { llama_sampler_set_logits(smpl, logits); - const llama_token new_token_id = llama_sampler_sample_dist(smpl, nullptr); - - //const llama_token new_token_id = llama_sampler_sample_greedy(smpl, nullptr, false); + const llama_token new_token_id = llama_sampler_sample(smpl, nullptr); // is it an end of generation? -> mark the stream as finished if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) { diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index 978642cc3..07475ecd3 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -124,7 +124,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std llama_sampler_set_logits(smpl, logits); - llama_token token = llama_sampler_sample_greedy(smpl, nullptr, false); + llama_token token = llama_sampler_sample(smpl, nullptr); if (token == eos_token) { break; } @@ -171,7 +171,11 @@ int main(int argc, char * argv[]) { // create generation context llama_context * ctx = llama_new_context_with_model(model, cparams); - llama_sampler * smpl = llama_sampler_init(model, llama_sampler_default_params()); + auto sparams = llama_sampler_default_params(); + + sparams.type = LLAMA_SAMPLER_TYPE_GREEDY; + + llama_sampler * smpl = llama_sampler_init(model, sparams); // ### Embedding/Representation ### // samples taken from: https://github.com/ContextualAI/gritlm#basic diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index b287d8403..b9800a917 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -83,7 +83,11 @@ int main(int argc, char ** argv) { return 1; } - llama_sampler * smpl = llama_sampler_init(model, llama_sampler_default_params()); + auto sparams = llama_sampler_default_params(); + + sparams.type = LLAMA_SAMPLER_TYPE_GREEDY; + + llama_sampler * smpl = llama_sampler_init(model, sparams); // tokenize the prompt std::vector tokens_list; @@ -221,7 +225,7 @@ int main(int argc, char ** argv) { llama_sampler_set_logits(smpl, logits); // sample the most likely token - const llama_token new_token_id = llama_sampler_sample_greedy(smpl, nullptr, false); + const llama_token new_token_id = llama_sampler_sample(smpl, nullptr); // is it an end of generation? if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 01f66886e..6f8c84137 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -73,7 +73,7 @@ int main(int argc, char ** argv) { llama_sampler_set_logits(smpl, logits); - auto next_token = llama_sampler_sample_dist(smpl, nullptr); + auto next_token = llama_sampler_sample(smpl, nullptr); auto next_token_str = llama_token_to_piece(ctx, next_token); printf("%s", next_token_str.c_str()); @@ -130,7 +130,7 @@ int main(int argc, char ** argv) { llama_sampler_set_logits(smpl2, logits); - auto next_token = llama_sampler_sample_dist(smpl2, nullptr); + auto next_token = llama_sampler_sample(smpl2, nullptr); auto next_token_str = llama_token_to_piece(ctx2, next_token); printf("%s", next_token_str.c_str()); @@ -219,7 +219,7 @@ int main(int argc, char ** argv) { llama_sampler_set_logits(smpl3, logits); - auto next_token = llama_sampler_sample_dist(smpl3, nullptr); + auto next_token = llama_sampler_sample(smpl3, nullptr); auto next_token_str = llama_token_to_piece(ctx3, next_token); printf("%s", next_token_str.c_str()); diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index ffaa609cb..7193f1ee4 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -55,7 +55,11 @@ int main(int argc, char ** argv) { return 1; } - llama_sampler * smpl = llama_sampler_init(model, llama_sampler_default_params()); + auto sparams = llama_sampler_default_params(); + + sparams.type = LLAMA_SAMPLER_TYPE_GREEDY; + + llama_sampler * smpl = llama_sampler_init(model, sparams); // tokenize the prompt @@ -117,7 +121,7 @@ int main(int argc, char ** argv) { llama_sampler_set_logits(smpl, logits); // sample the most likely token - const llama_token new_token_id = llama_sampler_sample_greedy(smpl, nullptr, false); + const llama_token new_token_id = llama_sampler_sample(smpl, nullptr); // is it an end of generation? if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) { diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index bb26f8eb5..0b49b2b06 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -182,6 +182,8 @@ int main(int argc, char ** argv) { // target model sampling context (reuse the llama_context's sampling instance) struct gpt_sampler * smpl = gpt_sampler_init(model_tgt, params.sparams); + struct llama_constraint * softmax = llama_constraint_init_softmax(); + // draft sequence data std::vector drafts(n_seq_dft); @@ -236,7 +238,7 @@ int main(int argc, char ** argv) { auto & dist_tgt = *gpt_sampler_get_candidates(smpl); gpt_sampler_apply_grammar(smpl, &dist_tgt); - gpt_sampler_sample_greedy(smpl, &dist_tgt, true); // applies softmax + llama_constraint_apply(softmax, &dist_tgt); float p_tgt = 0.0f; float p_dft = 0.0f; @@ -335,11 +337,10 @@ int main(int argc, char ** argv) { // all drafted tokens were rejected // sample from the target model LOG("all drafted tokens were rejected, sampling from residual distribution\n"); - token_id = gpt_sampler_sample_dist(smpl, &dist_tgt); + token_id = gpt_sampler_sample(smpl, &dist_tgt); gpt_sampler_accept(smpl, token_id, true); token_str = llama_token_to_piece(ctx_tgt, token_id); } - } else { // greedy verification @@ -615,6 +616,7 @@ int main(int argc, char ** argv) { gpt_sampler_free(drafts[s].smpl); } + llama_constraint_free(softmax); llama_batch_free(batch_dft); llama_free(ctx_tgt); diff --git a/include/llama.h b/include/llama.h index 813d854ef..763d32abb 100644 --- a/include/llama.h +++ b/include/llama.h @@ -370,8 +370,8 @@ extern "C" { } llama_logit_bias; enum llama_sampler_type { - LLAMA_SAMPLER_TYPE_GREEDY = 0, - LLAMA_SAMPLER_TYPE_DIST = 1, + LLAMA_SAMPLER_TYPE_GREEDY = 0, + LLAMA_SAMPLER_TYPE_DIST = 1, }; typedef struct llama_sampler_params { @@ -1092,10 +1092,12 @@ extern "C" { // samplers - LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_model * model, struct llama_sampler_params params); - LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl); - LLAMA_API struct llama_sampler * llama_sampler_cp (const struct llama_sampler * smpl); - LLAMA_API void llama_sampler_reset( struct llama_sampler * smpl); + LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_model * model, struct llama_sampler_params params); + LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl); + LLAMA_API struct llama_sampler * llama_sampler_cp (const struct llama_sampler * smpl); + LLAMA_API void llama_sampler_reset ( struct llama_sampler * smpl); + LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token); + LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p); LLAMA_API void llama_sampler_set_logits(struct llama_sampler * smpl, const float * logits); @@ -1107,11 +1109,7 @@ extern "C" { LLAMA_API struct llama_constraint * llama_sampler_constraint_get(const struct llama_sampler * smpl, int32_t i); - LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token); - LLAMA_API void llama_sampler_apply (struct llama_sampler * smpl, llama_token_data_array * cur_p); - - LLAMA_API llama_token llama_sampler_sample_dist (struct llama_sampler * smpl, llama_token_data_array * cur_p); - LLAMA_API llama_token llama_sampler_sample_greedy(struct llama_sampler * smpl, llama_token_data_array * cur_p, bool probs); + LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, llama_token_data_array * cur_p); /// @details Get the number of accepted tokens so far (max of n_prev) LLAMA_API int llama_sampler_n_prev(const struct llama_sampler * smpl); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 91a2cb5f5..dfd618c33 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1183,6 +1183,20 @@ void llama_sampler_reset_impl(struct llama_sampler & smpl) { // TODO: should we reset the timings? } +void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token) { + smpl.prev.push_back(token); + + for (auto * cnstr : smpl.constraints) { + llama_constraint_accept_impl(*cnstr, token); + } +} + +void llama_sampler_apply_impl(struct llama_sampler & smpl, struct llama_token_data_array * cur_p) { + for (auto * cnstr : smpl.constraints) { + llama_constraint_apply_impl(*cnstr, cur_p); + } +} + void llama_sampler_constraint_add_impl(struct llama_sampler & smpl, struct llama_constraint * cnstr) { smpl.constraints.push_back(cnstr); } @@ -1199,17 +1213,31 @@ struct llama_constraint * llama_sampler_constraint_get_impl(const struct llama_s return smpl.constraints[ith]; } -void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token) { - smpl.prev.push_back(token); +llama_token llama_sampler_sample_impl(struct llama_token_data_array * cur_p, std::mt19937 & rng, enum llama_sampler_type type) { + switch (type) { + case LLAMA_SAMPLER_TYPE_GREEDY: + { + llama_constraint_softmax_impl(cur_p); - for (auto * cnstr : smpl.constraints) { - llama_constraint_accept_impl(*cnstr, token); - } -} + return cur_p->data[0].id; + } + case LLAMA_SAMPLER_TYPE_DIST: + { + llama_constraint_softmax_impl(cur_p); -void llama_sampler_apply_impl(struct llama_sampler & smpl, struct llama_token_data_array * cur_p) { - for (auto * cnstr : smpl.constraints) { - llama_constraint_apply_impl(*cnstr, cur_p); + std::vector probs(cur_p->size); + for (size_t i = 0; i < cur_p->size; ++i) { + probs[i] = cur_p->data[i].p; + } + + std::discrete_distribution<> dist(probs.begin(), probs.end()); + + const int idx = dist(rng); + + return cur_p->data[idx].id; + } + default: + GGML_ABORT("invalid sampler type"); } } @@ -1224,40 +1252,3 @@ llama_token llama_sampler_prev_impl(const struct llama_sampler & smpl, int ith) int llama_sampler_n_prev_impl(const struct llama_sampler & smpl) { return smpl.prev.size(); } - -llama_token llama_sampler_sample_greedy_impl(llama_token_data_array * cur_p, bool probs) { - if (probs) { - // if probs are needed, we apply softmax to get the probabilities - llama_constraint_softmax_impl(cur_p); - - // the cur_p are sorted, so we can just return the first one - return cur_p->data[0].id; - } - - // return the token with the highest logit - auto * max_iter = std::max_element(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) { - return a.logit < b.logit; - }); - - llama_token result = max_iter->id; - - return result; -} - -llama_token llama_sampler_sample_dist_impl(struct llama_token_data_array * cur_p, std::mt19937 & rng) { - llama_constraint_softmax_impl(cur_p); - - std::vector probs; - probs.reserve(cur_p->size); - - for (size_t i = 0; i < cur_p->size; ++i) { - probs.push_back(cur_p->data[i].p); - } - - std::discrete_distribution<> dist(probs.begin(), probs.end()); - - const int idx = dist(rng); - llama_token result = cur_p->data[idx].id; - - return result; -} diff --git a/src/llama-sampling.h b/src/llama-sampling.h index bf5f596f7..acb1e04fe 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -104,20 +104,18 @@ struct llama_sampler { mutable int32_t n_sample; }; -struct llama_sampler * llama_sampler_init_impl (const struct llama_vocab & vocab, struct llama_sampler_params params); -void llama_sampler_free_impl ( struct llama_sampler * smpl); -struct llama_sampler * llama_sampler_cp_impl (const struct llama_sampler & smpl); -void llama_sampler_reset_impl( struct llama_sampler & smpl); +struct llama_sampler * llama_sampler_init_impl (const struct llama_vocab & vocab, struct llama_sampler_params params); +void llama_sampler_free_impl ( struct llama_sampler * smpl); +struct llama_sampler * llama_sampler_cp_impl (const struct llama_sampler & smpl); +void llama_sampler_reset_impl ( struct llama_sampler & smpl); +void llama_sampler_accept_impl( struct llama_sampler & smpl, llama_token token); +void llama_sampler_apply_impl ( struct llama_sampler & smpl, struct llama_token_data_array * cur_p); void llama_sampler_constraint_add_impl( struct llama_sampler & smpl, struct llama_constraint * cnstr); int llama_sampler_n_constraints_impl (const struct llama_sampler & smpl); struct llama_constraint * llama_sampler_constraint_get_impl(const struct llama_sampler & smpl, int ith); -void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token); -void llama_sampler_apply_impl (struct llama_sampler & smpl, struct llama_token_data_array * cur_p); +llama_token llama_sampler_sample_impl(struct llama_token_data_array * cur_p, std::mt19937 & rng, enum llama_sampler_type type); llama_token llama_sampler_prev_impl (const struct llama_sampler & smpl, int ith); int llama_sampler_n_prev_impl(const struct llama_sampler & smpl); - -llama_token llama_sampler_sample_greedy_impl(struct llama_token_data_array * cur_p, bool probs); -llama_token llama_sampler_sample_dist_impl (struct llama_token_data_array * cur_p, std::mt19937 & rng); diff --git a/src/llama.cpp b/src/llama.cpp index 6426073eb..2f5df2433 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -17939,7 +17939,7 @@ struct llama_sampler_params llama_sampler_default_params() { struct llama_sampler_params result = { /*.seed =*/ LLAMA_DEFAULT_SEED, /*.n_prev =*/ 256, - /*.type =*/ LLAMA_SAMPLER_TYPE_GREEDY, + /*.type =*/ LLAMA_SAMPLER_TYPE_DIST, }; return result; @@ -20713,6 +20713,20 @@ void llama_sampler_reset(struct llama_sampler * smpl) { llama_sampler_reset_impl(*smpl); } +void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) { + llama_sampler_accept_impl(*smpl, token); +} + +void llama_sampler_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + time_meas tm(smpl->t_sample_us); + + if (cur_p == nullptr) { + cur_p = &smpl->cur_p; + } + + llama_sampler_apply_impl(*smpl, cur_p); +} + void llama_sampler_set_logits(struct llama_sampler * smpl, const float * logits) { const int n_vocab = smpl->vocab->n_vocab; @@ -20741,42 +20755,14 @@ struct llama_constraint * llama_sampler_constraint_get(const struct llama_sample return llama_sampler_constraint_get_impl(*smpl, i); } -void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) { - llama_sampler_accept_impl(*smpl, token); -} - -void llama_sampler_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { +llama_token llama_sampler_sample(struct llama_sampler * smpl, llama_token_data_array * cur_p) { time_meas tm(smpl->t_sample_us); if (cur_p == nullptr) { cur_p = &smpl->cur_p; } - llama_sampler_apply_impl(*smpl, cur_p); -} - -llama_token llama_sampler_sample_greedy(struct llama_sampler * smpl, llama_token_data_array * cur_p, bool probs) { - time_meas tm(smpl->t_sample_us); - - if (cur_p == nullptr) { - cur_p = &smpl->cur_p; - } - - auto res = llama_sampler_sample_greedy_impl(cur_p, probs); - - smpl->n_sample++; - - return res; -} - -llama_token llama_sampler_sample_dist(struct llama_sampler * smpl, llama_token_data_array * cur_p) { - time_meas tm(smpl->t_sample_us); - - if (cur_p == nullptr) { - cur_p = &smpl->cur_p; - } - - auto res = llama_sampler_sample_dist_impl(cur_p, smpl->rng); + auto res = llama_sampler_sample_impl(cur_p, smpl->rng, smpl->params.type); smpl->n_sample++;