diff --git a/common/sampling.cpp b/common/sampling.cpp index 4dfbe9021..4e8843224 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -2,6 +2,16 @@ #include "common.h" +struct gpt_sampler { + gpt_sampler_params params; + + struct llama_constraint * bias; + struct llama_constraint * pnlt; + struct llama_constraint * grmr; + + struct llama_sampler * smpl; +}; + std::string gpt_sampler_params::print_all() const { char result[1024]; @@ -33,8 +43,6 @@ std::string gpt_sampler_params::print_constraints() const { } struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params) { - gpt_sampler * result = new gpt_sampler(); - llama_sampler_params lparams = llama_sampler_default_params(); lparams.seed = params.seed; @@ -43,21 +51,23 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st lparams.mirostat_tau = params.mirostat_tau; lparams.mirostat_eta = params.mirostat_eta; - result->smpl = llama_sampler_init(model, lparams); - - llama_sampler_add_constraint(result->smpl, llama_constraint_init_logit_bias( - model, - params.logit_bias.size(), - params.logit_bias.data())); - - llama_sampler_add_constraint(result->smpl, llama_constraint_init_penalties( - model, - params.penalty_last_n, - params.penalty_repeat, - params.penalty_freq, - params.penalty_present, - params.penalize_nl, - params.ignore_eos)); + auto * result = new gpt_sampler { + .params = params, + .bias = llama_constraint_init_logit_bias( + model, + params.logit_bias.size(), + params.logit_bias.data()), + .pnlt = llama_constraint_init_penalties( + model, + params.penalty_last_n, + params.penalty_repeat, + params.penalty_freq, + params.penalty_present, + params.penalize_nl, + params.ignore_eos), + .grmr = llama_constraint_init_grammar(model, params.grammar.c_str(), "root"), + .smpl = llama_sampler_init(model, lparams) + }; for (const auto & cnstr : params.constraints) { switch (cnstr) { @@ -84,14 +94,15 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st } } - result->grmr = llama_constraint_init_grammar(model, params.grammar.c_str(), "root"); - return result; } void gpt_sampler_free(struct gpt_sampler * gsmpl) { if (gsmpl) { + llama_constraint_free(gsmpl->bias); + llama_constraint_free(gsmpl->pnlt); llama_constraint_free(gsmpl->grmr); + llama_sampler_free(gsmpl->smpl); delete gsmpl; @@ -121,18 +132,28 @@ void gpt_sampler_reset (struct gpt_sampler * gsmpl) { llama_sampler_reset(gsmpl->smpl); } +void gpt_sampler_set_logits(struct gpt_sampler * gsmpl, const float * logits) { + llama_sampler_set_logits(gsmpl->smpl, logits); +} + +llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl) { + return llama_sampler_get_candidates(gsmpl->smpl); +} + llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl) { return llama_sampler_last(gsmpl->smpl); } +void gpt_print_timings(struct llama_context * ctx, struct gpt_sampler * gsmpl) { + llama_print_timings(ctx, gsmpl->smpl); +} + static llama_token gpt_sampler_sample( struct llama_sampler * smpl, struct llama_token_data_array * cur_p, float temp, int mirostat, int n_probs) { - GGML_ASSERT(cur_p != nullptr && "candidates array must be provided"); - llama_token res = 0; if (temp < 0.0f || (temp == 0.0f && n_probs > 0)) { @@ -142,6 +163,7 @@ static llama_token gpt_sampler_sample( // greedy sampling, no probs res = llama_sampler_sample_greedy(smpl, cur_p, false); } else { + // apply all sampling constraints and then sample llama_sampler_apply(smpl, cur_p); if 
(mirostat != 0) { @@ -167,42 +189,62 @@ static llama_token gpt_sampler_sample( return res; } -llama_token gpt_sampler_sample( - struct gpt_sampler * gsmpl, - struct llama_context * ctx, - int idx) { +llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx) { const auto & params = gsmpl->params; + auto & bias = gsmpl->bias; + auto & pnlt = gsmpl->pnlt; auto & grmr = gsmpl->grmr; auto & smpl = gsmpl->smpl; - llama_sampler_set_logits(smpl, llama_get_logits_ith(ctx, idx)); - auto * cur_p = llama_sampler_get_candidates(smpl); - // first, sample the token without any grammar constraints - const llama_token id = gpt_sampler_sample(smpl, cur_p, params.temp, params.mirostat, params.n_probs); - - // create an array with a single token data element for the sampled id - llama_token_data single_token_data = { id, 1.0f, 0.0f }; - llama_token_data_array single_token_data_array = { &single_token_data, 1, false }; - - llama_constraint_apply(grmr, &single_token_data_array); - - // check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY - const bool is_valid = single_token_data_array.data[0].logit != -INFINITY; - if (is_valid) { - return id; - } - - // if the token is not valid, sample again, after applying the grammar constraints llama_sampler_set_logits(smpl, llama_get_logits_ith(ctx, idx)); + llama_constraint_apply(bias, cur_p); + llama_constraint_apply(pnlt, cur_p); + + // first, sample the token without any grammar constraints + const llama_token id = gpt_sampler_sample(smpl, nullptr, params.temp, params.mirostat, params.n_probs); + + // check if it the sampled token fits the grammar + { + llama_token_data single_token_data = { id, 1.0f, 0.0f }; + llama_token_data_array single_token_data_array = { &single_token_data, 1, false }; + + llama_constraint_apply(grmr, &single_token_data_array); + + // check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY + const bool is_valid = single_token_data_array.data[0].logit != -INFINITY; + if (is_valid) { + return id; + } + } + + // if the token is not valid, sample again, first apply the grammar constraints and then sample + llama_sampler_set_logits(smpl, llama_get_logits_ith(ctx, idx)); + + llama_constraint_apply(bias, cur_p); + llama_constraint_apply(pnlt, cur_p); llama_constraint_apply(grmr, cur_p); return gpt_sampler_sample(smpl, cur_p, params.temp, params.mirostat, params.n_probs); } +void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_array * candidates) { + GGML_ASSERT(candidates != nullptr); + + llama_constraint_apply(gsmpl->grmr, candidates); +} + +llama_token gpt_sampler_sample_dist(struct gpt_sampler * gsmpl, llama_token_data_array * candidates) { + return llama_sampler_sample_dist(gsmpl->smpl, candidates); +} + +llama_token gpt_sampler_sample_greedy(struct gpt_sampler * gsmpl, llama_token_data_array * candidates, bool probs) { + return llama_sampler_sample_greedy(gsmpl->smpl, candidates, probs); +} + std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx_main, int n) { auto & smpl = gsmpl->smpl; diff --git a/common/sampling.h b/common/sampling.h index 4efa4a17c..8cb3da762 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -60,15 +60,12 @@ struct gpt_sampler_params { std::string print_constraints() const; }; -struct gpt_sampler { - gpt_sampler_params params; - - struct llama_constraint * grmr = nullptr; - - struct llama_sampler * smpl = nullptr; -}; - -// llama_sampler API 
overload +// gpt_sampler extends llama_sampler with additional functionality: +// +// - grammar support +// - custom sampler logic based on the parameters +// +struct gpt_sampler; struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params); @@ -79,8 +76,14 @@ struct gpt_sampler * gpt_sampler_cp(gpt_sampler * gsmpl); void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool apply_grammar); void gpt_sampler_reset (struct gpt_sampler * gsmpl); +void gpt_sampler_set_logits(struct gpt_sampler * gsmpl, const float * logits); + +llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl); + llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl); +void gpt_print_timings(struct llama_context * ctx, struct gpt_sampler * gsmpl); + // common sampling implementation: // // - set logits @@ -88,10 +91,12 @@ llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl); // - check if the token fits the grammar (if any) // - if not: resample by first applying the grammar constraints and then sampling again (slower path) // -llama_token gpt_sampler_sample( - struct gpt_sampler * gsmpl, - struct llama_context * ctx, - int idx); +llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx); + +void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_array * candidates); + +llama_token gpt_sampler_sample_dist (struct gpt_sampler * gsmpl, llama_token_data_array * candidates); +llama_token gpt_sampler_sample_greedy(struct gpt_sampler * gsmpl, llama_token_data_array * candidates, bool probs); // helpers diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 4dfa19ce8..3052b96ae 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -64,14 +64,15 @@ int main(int argc, char ** argv) { llama_context * ctx = llama_new_context_with_model(model, ctx_params); - auto sparams = llama_sampling_default_params(); + auto sparams = llama_sampler_default_params(); sparams.seed = params.sparams.seed; - sparams.top_k = 40; - sparams.top_p = 0.9f; - sparams.temp = 0.4f; - llama_sampling * smpl = llama_sampling_init(model, sparams); + llama_sampler * smpl = llama_sampler_init(model, sparams); + + llama_sampler_add_constraint(smpl, llama_constraint_init_top_k(params.sparams.top_k, params.sparams.min_keep)); + llama_sampler_add_constraint(smpl, llama_constraint_init_top_p(params.sparams.top_p, params.sparams.min_p)); + llama_sampler_add_constraint(smpl, llama_constraint_init_temp (params.sparams.temp)); if (ctx == NULL) { fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); @@ -174,15 +175,11 @@ int main(int argc, char ** argv) { const auto * logits = llama_get_logits_ith(ctx, i_batch[i]); - llama_sampling_set_logits(smpl, logits); + llama_sampler_set_logits(smpl, logits); - llama_sampling_top_k(smpl, nullptr); - llama_sampling_top_p(smpl, nullptr); - llama_sampling_temp (smpl, nullptr); + const llama_token new_token_id = llama_sampler_sample_dist(smpl, nullptr); - const llama_token new_token_id = llama_sampling_sample_dist(smpl, nullptr); - - //const llama_token new_token_id = llama_sampling_sample_greedy(smpl, nullptr); + //const llama_token new_token_id = llama_sampler_sample_greedy(smpl, nullptr); // is it an end of generation?
-> mark the stream as finished if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) { @@ -246,7 +243,7 @@ int main(int argc, char ** argv) { llama_batch_free(batch); - llama_sampling_free(smpl); + llama_sampler_free(smpl); llama_free(ctx); llama_free_model(model); diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index 7d2ae7713..978642cc3 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -92,7 +92,7 @@ static std::vector> encode(llama_context * ctx, const std::ve return result; } -static std::string generate(llama_context * ctx, llama_sampling * smpl, const std::string & prompt, bool stream) { +static std::string generate(llama_context * ctx, llama_sampler * smpl, const std::string & prompt, bool stream) { std::string result; const llama_model * model = llama_get_model(ctx); @@ -122,9 +122,9 @@ static std::string generate(llama_context * ctx, llama_sampling * smpl, const st const auto * logits = llama_get_logits_ith(ctx, bat.n_tokens - 1); - llama_sampling_set_logits(smpl, logits); + llama_sampler_set_logits(smpl, logits); - llama_token token = llama_sampling_sample_greedy(smpl, nullptr); + llama_token token = llama_sampler_sample_greedy(smpl, nullptr, false); if (token == eos_token) { break; } @@ -171,7 +171,7 @@ int main(int argc, char * argv[]) { // create generation context llama_context * ctx = llama_new_context_with_model(model, cparams); - llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params()); + llama_sampler * smpl = llama_sampler_init(model, llama_sampler_default_params()); // ### Embedding/Representation ### // samples taken from: https://github.com/ContextualAI/gritlm#basic @@ -212,7 +212,7 @@ int main(int argc, char * argv[]) { std::string response = generate(ctx, smpl, prompt, true); } - llama_sampling_free(smpl); + llama_sampler_free(smpl); llama_free(ctx); llama_free_model(model); llama_backend_free(); diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 371232421..9f9f81a7f 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -33,7 +33,7 @@ static llama_context ** g_ctx; static llama_model ** g_model; -static llama_sampling ** g_smpl; +static gpt_sampler ** g_smpl; static gpt_params * g_params; static std::vector * g_input_tokens; static std::ostringstream * g_output_ss; @@ -93,7 +93,7 @@ static void sigint_handler(int signo) { } else { console::cleanup(); printf("\n"); - llama_print_timings(*g_ctx, *g_smpl); + gpt_print_timings(*g_ctx, *g_smpl); write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens); _exit(130); } @@ -167,7 +167,7 @@ int main(int argc, char ** argv) { llama_model * model = nullptr; llama_context * ctx = nullptr; - llama_sampling * smpl = nullptr; + gpt_sampler * smpl = nullptr; g_model = &model; g_ctx = &ctx; @@ -345,7 +345,7 @@ int main(int argc, char ** argv) { std::vector embd; - smpl = llama_sampling_init(model, sparams); + smpl = gpt_sampler_init(model, sparams); while (n_remain != 0 || params.interactive) { // predict @@ -417,9 +417,9 @@ int main(int argc, char ** argv) { embd.clear(); if ((int) embd_inp.size() <= n_consumed && !is_interacting) { - const llama_token id = llama_sampling_sample(smpl, ctx, -1); + const llama_token id = gpt_sampler_sample(smpl, ctx, -1); - llama_sampling_accept(smpl, id, true); + gpt_sampler_accept(smpl, id, true); // LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, smpl->prev.to_vector()).c_str()); @@ -440,7 +440,7 @@ int main(int argc, char ** 
argv) { // push the prompt in the sampling context in order to apply repetition penalties later // for the prompt, we don't apply grammar rules - llama_sampling_accept(smpl, embd_inp[n_consumed], false); + gpt_sampler_accept(smpl, embd_inp[n_consumed], false); ++n_consumed; if ((int) embd.size() >= params.n_batch) { @@ -472,7 +472,7 @@ int main(int argc, char ** argv) { // if not currently processing queued inputs; if ((int) embd_inp.size() <= n_consumed) { // deal with eot token in infill mode - if ((llama_sampling_last(smpl) == llama_token_eot(model) || is_interacting) && params.interactive){ + if ((gpt_sampler_last(smpl) == llama_token_eot(model) || is_interacting) && params.interactive){ if (is_interacting && !params.interactive_first) { // print an eot token printf("%s", llama_token_to_piece(ctx, llama_token_eot(model)).c_str()); @@ -538,7 +538,7 @@ int main(int argc, char ** argv) { is_interacting = false; } // deal with end of generation tokens in interactive mode - else if (llama_token_is_eog(model, llama_sampling_last(smpl))) { + else if (llama_token_is_eog(model, gpt_sampler_last(smpl))) { LOG("found EOS token\n"); if (params.interactive) { @@ -611,7 +611,7 @@ int main(int argc, char ** argv) { if (n_past > 0) { if (is_interacting) { - llama_sampling_reset(smpl); + gpt_sampler_reset(smpl); } is_interacting = false; } @@ -634,13 +634,13 @@ int main(int argc, char ** argv) { fflush(stdout); } - llama_print_timings(ctx, smpl); + gpt_print_timings(ctx, smpl); write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens); llama_free(ctx); llama_free_model(model); - llama_sampling_free(smpl); + gpt_sampler_free(smpl); llama_backend_free(); #ifndef LOG_DISABLE_LOGS diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index d01c5909e..63a75c4a3 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -40,11 +40,11 @@ static bool eval_string(struct llama_context * ctx_llama, const char* str, int n return true; } -static const char * sample(struct llama_sampling * smpl, +static const char * sample(struct gpt_sampler * smpl, struct llama_context * ctx_llama, int * n_past) { - const llama_token id = llama_sampling_sample(smpl, ctx_llama, -1); - llama_sampling_accept(smpl, id, true); + const llama_token id = gpt_sampler_sample(smpl, ctx_llama, -1); + gpt_sampler_accept(smpl, id, true); static std::string ret; if (llama_token_is_eog(llama_get_model(ctx_llama), id)) { ret = ""; @@ -191,7 +191,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_ LOG_TEE("\n"); - struct llama_sampling * smpl = llama_sampling_init(ctx_llava->model, params->sparams); + struct gpt_sampler * smpl = gpt_sampler_init(ctx_llava->model, params->sparams); if (!smpl) { fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__); exit(1); @@ -211,7 +211,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_ fflush(stdout); } - llama_sampling_free(smpl); + gpt_sampler_free(smpl); printf("\n"); } diff --git a/examples/llava/minicpmv-cli.cpp b/examples/llava/minicpmv-cli.cpp index c041fe530..15f258b91 100644 --- a/examples/llava/minicpmv-cli.cpp +++ b/examples/llava/minicpmv-cli.cpp @@ -163,11 +163,11 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e LOG_TEE("%s: image token past: %d\n", __func__, n_past); } -static const char * sample(struct llama_sampling * smpl, +static const char * sample(struct gpt_sampler * smpl, struct llama_context * ctx_llama, int 
* n_past) { - const llama_token id = llama_sampling_sample(smpl, ctx_llama, -1); - llama_sampling_accept(smpl, id, true); + const llama_token id = gpt_sampler_sample(smpl, ctx_llama, -1); + gpt_sampler_accept(smpl, id, true); static std::string ret; if (llama_token_is_eog(llama_get_model(ctx_llama), id)) { ret = ""; @@ -214,7 +214,7 @@ static struct llava_context * minicpmv_init(gpt_params * params, const std::stri return ctx_llava; } -static struct llama_sampling * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){ +static struct gpt_sampler * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){ std::string user_prompt = prompt; int has_minicpmv_projector = clip_is_minicpmv(ctx_llava->ctx_clip); if (!is_first) { @@ -238,11 +238,11 @@ static struct llama_sampling * llama_init(struct llava_context * ctx_llava, gpt_ LOG_TEE("\n"); - struct llama_sampling * smpl = llama_sampling_init(ctx_llava->model, params->sparams); + struct gpt_sampler * smpl = gpt_sampler_init(ctx_llava->model, params->sparams); return smpl; } -static const char * llama_loop(struct llava_context * ctx_llava,struct llama_sampling * smpl, int &n_past){ +static const char * llama_loop(struct llava_context * ctx_llava,struct gpt_sampler * smpl, int &n_past){ const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past); return tmp; @@ -296,7 +296,7 @@ int main(int argc, char ** argv) { fflush(stdout); } - llama_sampling_free(smpl); + gpt_sampler_free(smpl); }else { while (true) { LOG_TEE(""); @@ -315,7 +315,7 @@ int main(int argc, char ** argv) { if (strstr(response.c_str(), "")) break; // minicpm-v fflush(stdout); } - llama_sampling_free(smpl); + gpt_sampler_free(smpl); } } printf("\n"); diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index 2bd31d002..8b461555b 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -117,7 +117,7 @@ int main(int argc, char ** argv) { llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1); // target model sampling context - struct llama_sampling * smpl = llama_sampling_init(model, params.sparams); + struct gpt_sampler * smpl = gpt_sampler_init(model, params.sparams); // verification n-grams std::vector ngrams_cur(G); @@ -158,9 +158,9 @@ int main(int argc, char ** argv) { // sample first token { - id = llama_sampling_sample(smpl, ctx, 0); + id = gpt_sampler_sample(smpl, ctx, 0); - llama_sampling_accept(smpl, id, true); + gpt_sampler_accept(smpl, id, true); { const std::string token_str = llama_token_to_piece(ctx, id); @@ -283,9 +283,9 @@ int main(int argc, char ** argv) { } // sample the next token - id = llama_sampling_sample(smpl, ctx, i_batch); + id = gpt_sampler_sample(smpl, ctx, i_batch); - llama_sampling_accept(smpl, id, true); + gpt_sampler_accept(smpl, id, true); // print { @@ -360,7 +360,7 @@ int main(int argc, char ** argv) { if (v == 0) { // sample from the last level for (int i = 0; i < W; i++) { - tokens_j[N - 2][i] = llama_sampling_sample(smpl, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i); + tokens_j[N - 2][i] = gpt_sampler_sample(smpl, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i); } } else { for (int i = 0; i < W; i++) { @@ -467,10 +467,11 @@ int main(int argc, char ** argv) { LOG_TEE("n_predict = %d\n", n_predict); LOG_TEE("n_accept = %d\n", n_accept); - llama_print_timings(ctx, smpl); + gpt_print_timings(ctx, smpl); + + gpt_sampler_free(smpl); 
llama_kv_cache_view_free(&kvc_view); - llama_sampling_free(smpl); llama_batch_free(batch); diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index da4d57a51..da3583f3c 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -104,7 +104,7 @@ int main(int argc, char ** argv){ bool has_eos = false; - struct llama_sampling * smpl = llama_sampling_init(model, params.sparams); + struct gpt_sampler * smpl = gpt_sampler_init(model, params.sparams); std::vector draft; @@ -128,9 +128,9 @@ int main(int argc, char ** argv){ int i_dft = 0; while (true) { // sample from the target model - llama_token id = llama_sampling_sample(smpl, ctx, i_dft); + llama_token id = gpt_sampler_sample(smpl, ctx, i_dft); - llama_sampling_accept(smpl, id, true); + gpt_sampler_accept(smpl, id, true); const std::string token_str = llama_token_to_piece(ctx, id); @@ -239,9 +239,10 @@ int main(int argc, char ** argv){ LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted); LOG_TEE("\ntarget:\n"); - llama_print_timings(ctx, smpl); + gpt_print_timings(ctx, smpl); + + gpt_sampler_free(smpl); - llama_sampling_free(smpl); llama_batch_free(batch_tgt); llama_free(ctx); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 88202b800..1b706efbc 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -106,7 +106,7 @@ static void sigint_handler(int signo) { } else { console::cleanup(); printf("\n"); - llama_print_timings(*g_ctx, (*g_smpl)->smpl); + gpt_print_timings(*g_ctx, *g_smpl); write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens); _exit(130); } @@ -928,7 +928,7 @@ int main(int argc, char ** argv) { llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size()); } - llama_print_timings(ctx, smpl->smpl); + gpt_print_timings(ctx, smpl); write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens); gpt_sampler_free(smpl); diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 7ce982a92..7422042db 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -51,7 +51,7 @@ static std::vector k_prompts = { struct client { ~client() { if (smpl) { - llama_sampling_free(smpl); + gpt_sampler_free(smpl); } } @@ -72,7 +72,7 @@ struct client { std::string prompt; std::string response; - struct llama_sampling * smpl = nullptr; + struct gpt_sampler * smpl = nullptr; }; static void print_date_time() { @@ -161,7 +161,7 @@ int main(int argc, char ** argv) { for (size_t i = 0; i < clients.size(); ++i) { auto & client = clients[i]; client.id = i; - client.smpl = llama_sampling_init(model, params.sparams); + client.smpl = gpt_sampler_init(model, params.sparams); } std::vector tokens_system; @@ -253,7 +253,7 @@ int main(int argc, char ** argv) { client.prompt = client.input + "\nAssistant:"; client.response = ""; - llama_sampling_reset(client.smpl); + gpt_sampler_reset(client.smpl); // do not prepend BOS because we have a system prompt! 
std::vector tokens_prompt; @@ -341,9 +341,9 @@ int main(int argc, char ** argv) { //printf("client %d, seq %d, token %d, pos %d, batch %d\n", // client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch); - const llama_token id = llama_sampling_sample(client.smpl, ctx, client.i_batch - i); + const llama_token id = gpt_sampler_sample(client.smpl, ctx, client.i_batch - i); - llama_sampling_accept(client.smpl, id, true); + gpt_sampler_accept(client.smpl, id, true); if (client.n_decoded == 1) { // start measuring generation time after the first token to make sure all concurrent clients diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index 0992ccc3c..b287d8403 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -83,7 +83,7 @@ int main(int argc, char ** argv) { return 1; } - llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params()); + llama_sampler * smpl = llama_sampler_init(model, llama_sampler_default_params()); // tokenize the prompt std::vector tokens_list; @@ -218,10 +218,10 @@ int main(int argc, char ** argv) { { const auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); - llama_sampling_set_logits(smpl, logits); + llama_sampler_set_logits(smpl, logits); // sample the most likely token - const llama_token new_token_id = llama_sampling_sample_greedy(smpl, nullptr); + const llama_token new_token_id = llama_sampler_sample_greedy(smpl, nullptr, false); // is it an end of generation? if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { @@ -262,9 +262,10 @@ int main(int argc, char ** argv) { fprintf(stderr, "\n"); + llama_sampler_free(smpl); + llama_batch_free(batch); - llama_sampling_free(smpl); llama_free(ctx); llama_free_model(model); diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 02f7a93eb..01f66886e 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -38,10 +38,10 @@ int main(int argc, char ** argv) { return 1; } - llama_sampling_params sparams = llama_sampling_default_params(); + llama_sampler_params sparams = llama_sampler_default_params(); sparams.seed = params.sparams.seed; - llama_sampling * smpl = llama_sampling_init(model, sparams); + llama_sampler * smpl = llama_sampler_init(model, sparams); // tokenize prompt auto tokens = llama_tokenize(ctx, params.prompt, true); @@ -71,9 +71,9 @@ int main(int argc, char ** argv) { for (auto i = 0; i < params.n_predict; i++) { const auto * logits = llama_get_logits(ctx); - llama_sampling_set_logits(smpl, logits); + llama_sampler_set_logits(smpl, logits); - auto next_token = llama_sampling_sample_dist(smpl, nullptr); + auto next_token = llama_sampler_sample_dist(smpl, nullptr); auto next_token_str = llama_token_to_piece(ctx, next_token); printf("%s", next_token_str.c_str()); @@ -96,7 +96,7 @@ int main(int argc, char ** argv) { // make new context auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params)); - llama_sampling * smpl2 = llama_sampling_init(model, sparams); + llama_sampler * smpl2 = llama_sampler_init(model, sparams); printf("\nsecond run: %s", params.prompt.c_str()); @@ -128,9 +128,9 @@ int main(int argc, char ** argv) { for (auto i = 0; i < params.n_predict; i++) { const auto * logits = llama_get_logits(ctx2); - llama_sampling_set_logits(smpl2, logits); + llama_sampler_set_logits(smpl2, logits); - auto next_token = llama_sampling_sample_dist(smpl2, nullptr); + 
auto next_token = llama_sampler_sample_dist(smpl2, nullptr); auto next_token_str = llama_token_to_piece(ctx2, next_token); printf("%s", next_token_str.c_str()); @@ -157,7 +157,7 @@ int main(int argc, char ** argv) { // make new context auto * ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params)); - llama_sampling * smpl3 = llama_sampling_init(model, sparams); + llama_sampler * smpl3 = llama_sampler_init(model, sparams); printf("\nsingle seq run: %s", params.prompt.c_str()); @@ -217,9 +217,9 @@ int main(int argc, char ** argv) { for (auto i = 0; i < params.n_predict; i++) { const auto * logits = llama_get_logits(ctx3); - llama_sampling_set_logits(smpl3, logits); + llama_sampler_set_logits(smpl3, logits); - auto next_token = llama_sampling_sample_dist(smpl3, nullptr); + auto next_token = llama_sampler_sample_dist(smpl3, nullptr); auto next_token_str = llama_token_to_piece(ctx3, next_token); printf("%s", next_token_str.c_str()); @@ -236,9 +236,9 @@ int main(int argc, char ** argv) { printf("\n"); - llama_sampling_free(smpl); - llama_sampling_free(smpl2); - llama_sampling_free(smpl3); + llama_sampler_free(smpl); + llama_sampler_free(smpl2); + llama_sampler_free(smpl3); llama_free(ctx3); llama_free_model(model); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 139e503b9..03e512e03 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -170,10 +170,10 @@ struct server_slot { // sampling json json_schema; - struct gpt_sampling_params sparams; + struct gpt_sampler_params sparams; + struct gpt_sampler * smpl = nullptr; llama_token sampled; - llama_sampling * smpl = nullptr; int32_t ga_i = 0; // group-attention state int32_t ga_n = 1; // group-attention factor @@ -653,7 +653,7 @@ struct server_context { // Clear any sampling context for (server_slot & slot : slots) { if (slot.smpl != nullptr) { - llama_sampling_free(slot.smpl); + gpt_sampler_free(slot.smpl); } } @@ -1027,26 +1027,26 @@ struct server_context { } { - const auto & samplers = data.find("samplers"); - if (samplers != data.end() && samplers->is_array()) { - std::vector sampler_names; - for (const auto & sampler_name : *samplers) { - if (sampler_name.is_string()) { - sampler_names.emplace_back(sampler_name); + const auto & constraints = data.find("samplers"); + if (constraints != data.end() && constraints->is_array()) { + std::vector constraint_names; + for (const auto & name : *constraints) { + if (name.is_string()) { + constraint_names.emplace_back(name); } } - slot.sparams.samplers = llama_sampling_types_from_names(sampler_names, false); + slot.sparams.constraints = gpt_constraint_types_from_names(constraint_names, false); } else { - slot.sparams.samplers = default_sparams.samplers; + slot.sparams.constraints = default_sparams.constraints; } } { if (slot.smpl != nullptr) { - llama_sampling_free(slot.smpl); + gpt_sampler_free(slot.smpl); } - slot.smpl = llama_sampling_init(model, slot.sparams); + slot.smpl = gpt_sampler_init(model, slot.sparams); if (slot.smpl == nullptr) { // for now, the only error that may happen here is invalid grammar send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); @@ -1253,10 +1253,10 @@ struct server_context { } json get_formated_generation(const server_slot & slot) const { - std::vector samplers; - samplers.reserve(slot.sparams.samplers.size()); - for (const auto & sampler : slot.sparams.samplers) { - samplers.emplace_back(llama_sampling_type_to_str(sampler)); + std::vector constraints; + 
constraints.reserve(slot.sparams.constraints.size()); + for (const auto & constraint : slot.sparams.constraints) { + constraints.emplace_back(gpt_constraint_type_to_str(constraint)); } return json { @@ -1290,7 +1290,7 @@ struct server_context { {"n_probs", slot.sparams.n_probs}, {"min_keep", slot.sparams.min_keep}, {"grammar", slot.sparams.grammar}, - {"samplers", samplers}, + {"samplers", constraints}, }; } @@ -2084,7 +2084,7 @@ struct server_context { GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx); } - llama_sampling_reset(slot.smpl); + gpt_sampler_reset(slot.smpl); if (!slot.params.cache_prompt) { slot.n_past_se = 0; @@ -2097,7 +2097,7 @@ struct server_context { // push the prompt into the sampling context (do not apply grammar) for (int i = 0; i < slot.n_past; ++i) { - llama_sampling_accept(slot.smpl, slot.cache_tokens[i], false); + gpt_sampler_accept(slot.smpl, slot.cache_tokens[i], false); } } } @@ -2150,7 +2150,7 @@ struct server_context { slot.n_past_se = 0; slot.ga_i = 0; // TODO: is the system prompt ever in the sampling context? - llama_sampling_reset(slot.smpl); + gpt_sampler_reset(slot.smpl); } // remove the non-common part from the cache @@ -2332,9 +2332,9 @@ struct server_context { } completion_token_output result; - const llama_token id = llama_sampling_sample(slot.smpl, ctx, slot.i_batch - i); + const llama_token id = gpt_sampler_sample(slot.smpl, ctx, slot.i_batch - i); - llama_sampling_accept(slot.smpl, id, true); + gpt_sampler_accept(slot.smpl, id, true); slot.n_decoded += 1; if (slot.n_decoded == 1) { @@ -2345,7 +2345,7 @@ struct server_context { result.tok = id; - const auto * cur_p = llama_sampling_get_candidates(slot.smpl); + const auto * cur_p = gpt_sampler_get_candidates(slot.smpl); // TODO: this logic might have been broken during https://github.com/ggerganov/llama.cpp/pull/8643 // fix if necessary diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index 674158b85..ffaa609cb 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -55,7 +55,7 @@ int main(int argc, char ** argv) { return 1; } - llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params()); + llama_sampler * smpl = llama_sampler_init(model, llama_sampler_default_params()); // tokenize the prompt @@ -114,10 +114,10 @@ int main(int argc, char ** argv) { { const auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); - llama_sampling_set_logits(smpl, logits); + llama_sampler_set_logits(smpl, logits); // sample the most likely token - const llama_token new_token_id = llama_sampling_sample_greedy(smpl, nullptr); + const llama_token new_token_id = llama_sampler_sample_greedy(smpl, nullptr, false); // is it an end of generation? 
if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) { @@ -159,8 +159,7 @@ int main(int argc, char ** argv) { fprintf(stderr, "\n"); llama_batch_free(batch); - - llama_sampling_free(smpl); + llama_sampler_free(smpl); llama_free(ctx); llama_free_model(model); diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index e95073366..bb26f8eb5 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -21,7 +21,7 @@ struct seq_draft { std::vector tokens; std::vector> dists; - struct llama_sampling * smpl; + struct gpt_sampler * smpl = nullptr; }; int main(int argc, char ** argv) { @@ -180,14 +180,14 @@ int main(int argc, char ** argv) { bool has_eos = false; // target model sampling context (reuse the llama_context's sampling instance) - struct llama_sampling * smpl = llama_sampling_init(model_tgt, params.sparams); + struct gpt_sampler * smpl = gpt_sampler_init(model_tgt, params.sparams); // draft sequence data std::vector drafts(n_seq_dft); for (int s = 0; s < n_seq_dft; ++s) { - // allocate llama_sampling for each draft sequence - drafts[s].smpl = llama_sampling_init(model_dft, params.sparams); + // allocate gpt_sampler for each draft sequence + drafts[s].smpl = gpt_sampler_init(model_dft, params.sparams); } llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1); @@ -229,13 +229,14 @@ int main(int argc, char ** argv) { bool accept = false; if (params.sparams.temp > 0) { // stochastic verification + const float * logits = llama_get_logits_ith(ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]); - llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft])); + gpt_sampler_set_logits(smpl, logits); - auto & dist_tgt = *llama_sampling_get_candidates(smpl); + auto & dist_tgt = *gpt_sampler_get_candidates(smpl); - llama_sampling_grammar(smpl, &dist_tgt); - llama_sampling_softmax(smpl, &dist_tgt); + gpt_sampler_apply_grammar(smpl, &dist_tgt); + gpt_sampler_sample_greedy(smpl, &dist_tgt, true); // applies softmax float p_tgt = 0.0f; float p_dft = 0.0f; @@ -280,7 +281,7 @@ int main(int argc, char ** argv) { accept = true; token_id = drafts[s].tokens[i_dft]; token_str = llama_token_to_piece(ctx_tgt, token_id); - llama_sampling_accept(smpl, token_id, true); + gpt_sampler_accept(smpl, token_id, true); LOG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str()); break; @@ -334,8 +335,8 @@ int main(int argc, char ** argv) { // all drafted tokens were rejected // sample from the target model LOG("all drafted tokens were rejected, sampling from residual distribution\n"); - token_id = llama_sampling_sample_dist(smpl, &dist_tgt); - llama_sampling_accept(smpl, token_id, true); + token_id = gpt_sampler_sample_dist(smpl, &dist_tgt); + gpt_sampler_accept(smpl, token_id, true); token_str = llama_token_to_piece(ctx_tgt, token_id); } @@ -344,9 +345,9 @@ int main(int argc, char ** argv) { // sample from the target model LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]); - token_id = llama_sampling_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]); + token_id = gpt_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]); - llama_sampling_accept(smpl, token_id, true); + gpt_sampler_accept(smpl, token_id, true); //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, smpl->prev).c_str()); @@ -436,7 +437,10 @@ int main(int argc, char ** argv) { break; } - llama_sampling_cp(smpl, 
drafts[0].smpl); + if (drafts[0].smpl) { + gpt_sampler_free(drafts[0].smpl); + } + drafts[0].smpl = gpt_sampler_cp(smpl); int n_seq_cur = 1; int n_past_cur = n_past_dft; @@ -465,9 +469,9 @@ int main(int argc, char ** argv) { continue; } - llama_sampling_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft); + gpt_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft); - const auto * cur_p = llama_sampling_get_candidates(drafts[s].smpl); + const auto * cur_p = gpt_sampler_get_candidates(drafts[s].smpl); for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p->size); ++k) { LOG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n", @@ -505,7 +509,11 @@ int main(int argc, char ** argv) { drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft; drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt; - llama_sampling_cp(drafts[s].smpl, drafts[n_seq_cur].smpl); + if (drafts[n_seq_cur].smpl) { + gpt_sampler_free(drafts[n_seq_cur].smpl); + } + drafts[n_seq_cur].smpl = gpt_sampler_cp(drafts[s].smpl); + sa.push_back(n_seq_cur); @@ -521,7 +529,7 @@ int main(int argc, char ** argv) { const int s = sa[is]; - llama_sampling_accept(drafts[s].smpl, id, true); + gpt_sampler_accept(drafts[s].smpl, id, true); drafts[s].tokens.push_back(id); // save cur_p.data into drafts[s].dists @@ -597,14 +605,14 @@ int main(int argc, char ** argv) { LOG_TEE("\ndraft:\n"); // TODO: print sampling/grammar timings for all drafts - llama_print_timings(ctx_dft, nullptr); + gpt_print_timings(ctx_dft, nullptr); LOG_TEE("\ntarget:\n"); - llama_print_timings(ctx_tgt, smpl); + gpt_print_timings(ctx_tgt, smpl); - llama_sampling_free(smpl); + gpt_sampler_free(smpl); for (int s = 0; s < n_seq_dft; ++s) { - llama_sampling_free(drafts[s].smpl); + gpt_sampler_free(drafts[s].smpl); } llama_batch_free(batch_dft); diff --git a/include/llama.h b/include/llama.h index 49011b7fe..8a02800ce 100644 --- a/include/llama.h +++ b/include/llama.h @@ -46,9 +46,6 @@ #define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ #define LLAMA_STATE_SEQ_VERSION 2 -// TODO: remove before merge -#define LLAMA_MAX_SAMPLERS 16 - #ifdef __cplusplus extern "C" { #endif @@ -1001,133 +998,6 @@ extern "C" { // // Sampling API - // TODO: remove before merge - // - - // TODO: llama_model should become llama_vocab - //LLAMA_API struct llama_sampling * llama_sampling_init(const struct llama_model * model, struct llama_sampling_params params); - - //LLAMA_API void llama_sampling_free(struct llama_sampling * smpl); - - //// Copies the internal state of the sampler (rng, prev, params, grammar, etc.) - //LLAMA_API struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl); - - //// - clear prev token - //// - reset grammar state - //LLAMA_API void llama_sampling_reset(struct llama_sampling * smpl); - - //// Sampling parameter mutation - //// TODO: not sure if we want to keep these. Maybe it's better to keep llama_sampling immutable - //LLAMA_API void llama_sampling_set_grammar (struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root); - //LLAMA_API void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias); - - //// Set the logits from which to sample. - //// This call initializes the internal token candidates array. - //// The internal candidates are implicitly used by the sampling API below when no candidates are provided. 
- //LLAMA_API void llama_sampling_set_logits( - // struct llama_sampling * smpl, - // const float * logits); - - ///// @details Returns the current candidate tokens. - //LLAMA_API llama_token_data_array * llama_sampling_get_candidates( - // struct llama_sampling * smpl); - - //// The llama_sampling_ API below uses the parameters passed during the creation of the llama_sampling object. - //// Each function can accept an array of token candidates. If the candidates are not provided, the internal - //// candidates are used. The internal candidates are initialized by llama_sampling_set_logits(). - - ///// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. - //LLAMA_API void llama_sampling_softmax( - // struct llama_sampling * smpl, - // llama_token_data_array * candidates); - - ///// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 - //LLAMA_API void llama_sampling_top_k( - // struct llama_sampling * smpl, - // llama_token_data_array * candidates); - - ///// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 - //LLAMA_API void llama_sampling_top_p( - // struct llama_sampling * smpl, - // llama_token_data_array * candidates); - - ///// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 - //LLAMA_API void llama_sampling_min_p( - // struct llama_sampling * smpl, - // llama_token_data_array * candidates); - - ///// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. - //LLAMA_API void llama_sampling_tail_free( - // struct llama_sampling * smpl, - // llama_token_data_array * candidates); - - ///// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. - //LLAMA_API void llama_sampling_typical( - // struct llama_sampling * smpl, - // llama_token_data_array * candidates); - - ///// @details Apply temperature and entropy - //LLAMA_API void llama_sampling_temp( - // struct llama_sampling * smpl, - // llama_token_data_array * candidates); - - ///// @details Apply constraints from grammar - //LLAMA_API void llama_sampling_grammar( - // struct llama_sampling * smpl, - // llama_token_data_array * candidates); - - ///// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. - ///// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. - //LLAMA_API void llama_sampling_penalties( - // struct llama_sampling * smpl, - // llama_token_data_array * candidates); - - ///// @details Mirostat algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. - //LLAMA_API llama_token llama_sampling_sample_mirostat( - // struct llama_sampling * smpl, - // llama_token_data_array * candidates); - - ///// @details Selects the token with the highest probability. - ///// Does not compute the token probabilities. Use llama_sampling_softmax() instead. - //LLAMA_API llama_token llama_sampling_sample_greedy( - // struct llama_sampling * smpl, - // llama_token_data_array * candidates); - - ///// @details Randomly selects a token from the candidates based on their probability distribution. 
- //LLAMA_API llama_token llama_sampling_sample_dist( - // struct llama_sampling * smpl, - // llama_token_data_array * candidates); - - ///// @details Sample a token using the configured samplers (see "llama_sampling_params.samplers"). - //LLAMA_API llama_token llama_sampling_sample( - // struct llama_sampling * smpl, - // llama_token_data_array * candidates); - - ///// @details Accepts the sampled token into the sampling context. - ///// - adds it to "prev" tokens - ///// - updates the grammar state (if apply_grammar is true) - //LLAMA_API void llama_sampling_accept( - // struct llama_sampling * smpl, - // llama_token token, - // bool apply_grammar); - - ///// @details Get the number of accepted tokens so far (max of n_prev) - //LLAMA_API int llama_sampling_n_prev(const struct llama_sampling * smpl); - - ///// @details Get the ith accepted token - ///// @param ith [0, n_prev), ith == 0 is the last accepted token. - ///// returns LLAMA_TOKEN_NULL if ith is out of bounds - //LLAMA_API llama_token llama_sampling_prev( - // const struct llama_sampling * smpl, - // int32_t ith); - - ///// @details Get the last accepted token - ///// Same as llama_sampling_prev(smpl, 0) - ///// returns LLAMA_TOKEN_NULL if there are no accepted tokens - //LLAMA_API llama_token llama_sampling_last(const struct llama_sampling * smpl); - - // - // Sampling v2 API // // - Constraints // The llama_constraint object works on a set of candidate tokens (llama_token_data_array), by modifying their @@ -1203,7 +1073,7 @@ extern "C" { LLAMA_API struct llama_constraint * llama_constraint_cp(const struct llama_constraint * cnstr); - // do not call if used with llama_sampler_add_constraint + // important: do not call if the constraint has been added to a llama_sampler (via llama_sampler_add_constraint) LLAMA_API void llama_constraint_free(struct llama_constraint * cnstr); LLAMA_API void llama_constraint_accept(struct llama_constraint * cnstr, llama_token token); @@ -1221,11 +1091,7 @@ extern "C" { LLAMA_API llama_token_data_array * llama_sampler_get_candidates(struct llama_sampler * smpl); - - // TODO: should this take ownership so the user does not need to call llama_constraint_free - // or should just make a reference to the constraint so that it can be reused in multiple llama_sampler? 
- // - // seems better to take the ownership, otherwise the copying of the sampler will be more complicated + // important: takes ownership of the constraint object and will free it in llama_sampler_free LLAMA_API void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr); LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index abf9d5a8e..eca2adb2b 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -421,113 +421,8 @@ void llama_constraint_penalties_impl( candidates->sorted = false; } -llama_token llama_sampler_sample_mirostat_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu) { - llama_constraint_softmax_impl(candidates); - - // Estimate s_hat using the most probable m tokens - float s_hat = 0.0; - float sum_ti_bi = 0.0; - float sum_ti_sq = 0.0; - for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) { - float t_i = logf(float(i + 2) / float(i + 1)); - float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p); - sum_ti_bi += t_i * b_i; - sum_ti_sq += t_i * t_i; - } - s_hat = sum_ti_bi / sum_ti_sq; - - // Compute k from the estimated s_hat and target surprise value - float epsilon_hat = s_hat - 1; - float k = powf((epsilon_hat * powf(2, mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat); - - // Sample the next word X using top-k sampling - llama_constraint_top_k_impl(candidates, int(k), 1); - llama_token X = llama_sampler_sample_dist_impl(candidates, rng); - - // Compute error as the difference between observed surprise and target surprise value - size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { - return candidate.id == X; - })); - float observed_surprise = -log2f(candidates->data[X_idx].p); - float e = observed_surprise - tau; - - // Update mu using the learning rate and error - mu = mu - eta * e; - - return X; -} - -llama_token llama_sampler_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu) { - llama_constraint_softmax_impl(candidates); - - // Truncate the words with surprise values greater than mu - candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { - return -log2f(candidate.p) > mu; - })); - - if (candidates->size == 0) { - candidates->size = 1; - } - - // Normalize the probabilities of the remaining words - llama_constraint_softmax_impl(candidates); - - // Sample the next word X from the remaining words - llama_token X = llama_sampler_sample_dist_impl(candidates, rng); - - // Compute error as the difference between observed surprise and target surprise value - size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { - return candidate.id == X; - })); - - float observed_surprise = -log2f(candidates->data[X_idx].p); - float e = observed_surprise - tau; - - // Update mu using the learning rate and error - mu = mu - eta * e; - - return X; -} - -llama_token llama_sampler_sample_greedy_impl(llama_token_data_array * candidates, bool probs) { - if (probs) { - // if probs are needed, we apply softmax to get the probabilities - 
llama_constraint_softmax_impl(candidates); - - // the candidates are sorted, so we can just return the first one - return candidates->data[0].id; - } - - // return the token with the highest logit - auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { - return a.logit < b.logit; - }); - - llama_token result = max_iter->id; - - return result; -} - -llama_token llama_sampler_sample_dist_impl(struct llama_token_data_array * candidates, std::mt19937 & rng) { - llama_constraint_softmax_impl(candidates); - - std::vector probs; - probs.reserve(candidates->size); - - for (size_t i = 0; i < candidates->size; ++i) { - probs.push_back(candidates->data[i].p); - } - - std::discrete_distribution<> dist(probs.begin(), probs.end()); - - const int idx = dist(rng); - llama_token result = candidates->data[idx].id; - - return result; -} - // -// sampling v2 +// sampling // // constraints @@ -1172,3 +1067,107 @@ int llama_sampler_n_prev_impl(const struct llama_sampler & smpl) { return smpl.prev.size(); } +llama_token llama_sampler_sample_mirostat_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu) { + llama_constraint_softmax_impl(candidates); + + // Estimate s_hat using the most probable m tokens + float s_hat = 0.0; + float sum_ti_bi = 0.0; + float sum_ti_sq = 0.0; + for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) { + float t_i = logf(float(i + 2) / float(i + 1)); + float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p); + sum_ti_bi += t_i * b_i; + sum_ti_sq += t_i * t_i; + } + s_hat = sum_ti_bi / sum_ti_sq; + + // Compute k from the estimated s_hat and target surprise value + float epsilon_hat = s_hat - 1; + float k = powf((epsilon_hat * powf(2, mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat); + + // Sample the next word X using top-k sampling + llama_constraint_top_k_impl(candidates, int(k), 1); + llama_token X = llama_sampler_sample_dist_impl(candidates, rng); + + // Compute error as the difference between observed surprise and target surprise value + size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { + return candidate.id == X; + })); + float observed_surprise = -log2f(candidates->data[X_idx].p); + float e = observed_surprise - tau; + + // Update mu using the learning rate and error + mu = mu - eta * e; + + return X; +} + +llama_token llama_sampler_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu) { + llama_constraint_softmax_impl(candidates); + + // Truncate the words with surprise values greater than mu + candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { + return -log2f(candidate.p) > mu; + })); + + if (candidates->size == 0) { + candidates->size = 1; + } + + // Normalize the probabilities of the remaining words + llama_constraint_softmax_impl(candidates); + + // Sample the next word X from the remaining words + llama_token X = llama_sampler_sample_dist_impl(candidates, rng); + + // Compute error as the difference between observed surprise and target surprise value + size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const 
llama_token_data & candidate) { + return candidate.id == X; + })); + + float observed_surprise = -log2f(candidates->data[X_idx].p); + float e = observed_surprise - tau; + + // Update mu using the learning rate and error + mu = mu - eta * e; + + return X; +} + +llama_token llama_sampler_sample_greedy_impl(llama_token_data_array * candidates, bool probs) { + if (probs) { + // if probs are needed, we apply softmax to get the probabilities + llama_constraint_softmax_impl(candidates); + + // the candidates are sorted, so we can just return the first one + return candidates->data[0].id; + } + + // return the token with the highest logit + auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { + return a.logit < b.logit; + }); + + llama_token result = max_iter->id; + + return result; +} + +llama_token llama_sampler_sample_dist_impl(struct llama_token_data_array * candidates, std::mt19937 & rng) { + llama_constraint_softmax_impl(candidates); + + std::vector probs; + probs.reserve(candidates->size); + + for (size_t i = 0; i < candidates->size; ++i) { + probs.push_back(candidates->data[i].p); + } + + std::discrete_distribution<> dist(probs.begin(), probs.end()); + + const int idx = dist(rng); + llama_token result = candidates->data[idx].id; + + return result; +} diff --git a/src/llama-sampling.h b/src/llama-sampling.h index 501f11de8..f60d5b95f 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -10,6 +10,7 @@ struct llama_grammar; using llama_token_cnt = std::unordered_map; +// TODO: tmp exposed, until tests start using llama_constraint void llama_constraint_softmax_impl (struct llama_token_data_array * candidates); void llama_constraint_top_k_impl (struct llama_token_data_array * candidates, int32_t k, size_t min_keep); void llama_constraint_top_p_impl (struct llama_token_data_array * candidates, float p, size_t min_keep); @@ -27,30 +28,6 @@ void llama_constraint_penalties_impl( float penalty_freq, float penalty_present); -/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. -/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. -/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. -/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. -/// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. -/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. 
diff --git a/src/llama-sampling.h b/src/llama-sampling.h
index 501f11de8..f60d5b95f 100644
--- a/src/llama-sampling.h
+++ b/src/llama-sampling.h
@@ -10,6 +10,7 @@ struct llama_grammar;
 
 using llama_token_cnt = std::unordered_map<llama_token, int>;
 
+// TODO: tmp exposed, until tests start using llama_constraint
 void llama_constraint_softmax_impl (struct llama_token_data_array * candidates);
 void llama_constraint_top_k_impl   (struct llama_token_data_array * candidates, int32_t k, size_t min_keep);
 void llama_constraint_top_p_impl   (struct llama_token_data_array * candidates, float p, size_t min_keep);
@@ -27,30 +28,6 @@ void llama_constraint_penalties_impl(
         float penalty_freq,
         float penalty_present);
 
-/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
-/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
-/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
-/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
-/// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
-/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
-llama_token llama_sampler_sample_mirostat_impl   (struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu);
-
-/// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
-/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
-/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
-/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
-/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
-llama_token llama_sampler_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu);
-
-llama_token llama_sampler_sample_greedy_impl(struct llama_token_data_array * candidates, bool probs);
-llama_token llama_sampler_sample_dist_impl  (struct llama_token_data_array * candidates, std::mt19937 & rng);
-
-
-
-//
-// sampling v2
-//
-
 // constraints
 
 struct llama_constraint * llama_constraint_init_top_k_impl (int32_t k, size_t min_keep);
@@ -128,3 +105,21 @@ void llama_sampler_apply_impl (struct llama_sampler & smpl, struct llama_token_d
 llama_token llama_sampler_prev_impl  (const struct llama_sampler & smpl, int ith);
 int         llama_sampler_n_prev_impl(const struct llama_sampler & smpl);
 
+
+/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
+/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
+/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
+/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
+/// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
+/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
+llama_token llama_sampler_sample_mirostat_impl   (struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu);
+
+/// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
+/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
+/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
+/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
+/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
+llama_token llama_sampler_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu);
+
+llama_token llama_sampler_sample_greedy_impl(struct llama_token_data_array * candidates, bool probs);
+llama_token llama_sampler_sample_dist_impl  (struct llama_token_data_array * candidates, std::mt19937 & rng);
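These `_impl` helpers stay exposed so that the tests can call them directly (see the TODO above), and they all operate on a caller-owned llama_token_data_array. A minimal sketch of that pattern, close to what tests/test-sampling.cpp does below (the probabilities are made-up example values, and the helper name is made up for illustration):

    #include <cmath>
    #include <random>
    #include <vector>

    static llama_token sample_top_k_from_probs(const std::vector<float> & probs, std::mt19937 & rng) {
        // build the candidates array the same way the sampling tests do
        std::vector<llama_token_data> candidates;
        candidates.reserve(probs.size());
        for (llama_token id = 0; id < (llama_token) probs.size(); ++id) {
            candidates.push_back({ id, logf(probs[id]), 0.0f });
        }

        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

        // keep the 5 most likely tokens, then draw from the renormalized distribution
        llama_constraint_top_k_impl(&candidates_p, 5, 1);

        return llama_sampler_sample_dist_impl(&candidates_p, rng);
    }

Replacing the last call with llama_sampler_sample_greedy_impl(&candidates_p, false) would instead return the highest-logit token without computing probabilities.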
diff --git a/src/llama.cpp b/src/llama.cpp
index 903be4575..2b54a1ff3 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -20609,346 +20609,6 @@ int32_t llama_chat_apply_template(
 // sampling
 //
 
-//struct llama_sampling * llama_sampling_init(const struct llama_model * model, struct llama_sampling_params params) {
-//    return llama_sampling_init_impl(model->vocab, params);
-//}
-
-//void llama_sampling_free(struct llama_sampling * smpl) {
-//    if (smpl == nullptr) {
-//        return;
-//    }
-
-//    llama_sampling_free_impl(smpl);
-//}
-
-//struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl) {
-//    return llama_sampling_cp_impl(*smpl);
-//}
-
-//void llama_sampling_reset(struct llama_sampling * smpl) {
-//    llama_sampling_reset_impl(*smpl);
-//}
-
-//void llama_sampling_set_grammar(struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root) {
-//    llama_sampling_set_grammar_impl(*smpl, grammar_str, grammar_root);
-//}
-
-//void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias) {
-//    llama_sampling_set_logit_bias_impl(*smpl, n_logit_bias, logit_bias);
-//}
-
-//void llama_sampling_set_logits(struct llama_sampling * smpl, const float * logits) {
-//    const int n_vocab = smpl->vocab.n_vocab;
-
-//    smpl->cur.resize(n_vocab);
-
-//    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-//        smpl->cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
-//    }
-
-//    for (const auto & lb : smpl->logit_bias) {
-//        smpl->cur[lb.token].logit += lb.bias;
-//    }
-
-//    if (smpl->params.ignore_eos) {
-//        smpl->cur[llama_token_eos_impl(smpl->vocab)].logit = -INFINITY;
-//    }
-
-//    smpl->cur_p = { smpl->cur.data(), smpl->cur.size(), false };
-
-//    // apply penalties
-//    {
-//        const float nl_logit = smpl->cur[llama_token_nl_impl(smpl->vocab)].logit;
-
-//        llama_sampling_penalties(smpl, &smpl->cur_p);
-
-//        if (!smpl->params.penalize_nl) {
-//            for (size_t idx = 0; idx < smpl->cur_p.size; idx++) {
-//                if (smpl->cur_p.data[idx].id == llama_token_nl_impl(smpl->vocab)) {
-//                    smpl->cur_p.data[idx].logit = nl_logit;
-//                    break;
-//                }
-//            }
-//        }
-//    }
-//}
-
-//llama_token_data_array * llama_sampling_get_candidates(struct llama_sampling * smpl) {
-//    return &smpl->cur_p;
-//}
-
-//void llama_sampling_softmax(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-//    time_meas tm(smpl->t_sample_us);
-
-//    if (candidates == nullptr) {
-//        candidates = &smpl->cur_p;
-//    }
-
-//    llama_sampling_softmax_impl(candidates);
-//}
-
-//void llama_sampling_top_k(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-//    time_meas tm(smpl->t_sample_us);
-
-//    if (candidates == nullptr) {
-//        candidates = &smpl->cur_p;
-//    }
-
-//    llama_sampling_top_k_impl(candidates, smpl->params.top_k, smpl->params.min_keep);
-//}
-
-//void llama_sampling_top_p(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-//    time_meas tm(smpl->t_sample_us);
-
-//    if (candidates == nullptr) {
-//        candidates = &smpl->cur_p;
-//    }
-
-//    llama_sampling_top_p_impl(candidates, smpl->params.top_p, smpl->params.min_keep);
-//}
-
-//void llama_sampling_min_p(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-//    time_meas tm(smpl->t_sample_us);
-
-//    if (candidates == nullptr) {
-//        candidates = &smpl->cur_p;
-//    }
-
-//    llama_sampling_min_p_impl(candidates, smpl->params.min_p, smpl->params.min_keep);
-//}
-
-//void llama_sampling_tail_free(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-//    time_meas tm(smpl->t_sample_us);
-
-//    if (candidates == nullptr) {
-//        candidates = &smpl->cur_p;
-//    }
-
-//    llama_sampling_tail_free_impl(candidates, smpl->params.tfs_z, smpl->params.min_keep);
-//}
-
-//void llama_sampling_typical(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-//    time_meas tm(smpl->t_sample_us);
-
-//    if (candidates == nullptr) {
-//        candidates = &smpl->cur_p;
-//    }
-
-//    llama_sampling_typical_impl(candidates, smpl->params.typ_p, smpl->params.min_keep);
-//}
-
-//void llama_sampling_temp(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-//    time_meas tm(smpl->t_sample_us);
-
-//    if (candidates == nullptr) {
-//        candidates = &smpl->cur_p;
-//    }
-
-//    if (smpl->params.dynatemp_range > 0) {
-//        const float dynatemp_min = std::max(0.0f, smpl->params.temp - smpl->params.dynatemp_range);
-//        const float dynatemp_max = std::max(0.0f, smpl->params.temp + smpl->params.dynatemp_range);
-
-//        llama_sampling_entropy_impl(candidates, dynatemp_min, dynatemp_max, smpl->params.dynatemp_exponent);
-//    } else {
-//        llama_sampling_temp_impl(candidates, smpl->params.temp);
-//    }
-//}
-
-//void llama_sampling_grammar(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-//    time_meas tm(smpl->t_grammar_us);
-
-//    if (candidates == nullptr) {
-//        candidates = &smpl->cur_p;
-//    }
-
-//    if (smpl->grammar) {
-//        llama_sampling_grammar_impl(candidates, *smpl->grammar);
-
-//        smpl->n_grammar++;
-//    }
-//}
-
-//void llama_sampling_penalties(
-//        struct llama_sampling * smpl,
-//        llama_token_data_array * candidates) {
-//    time_meas tm(smpl->t_sample_us);
-
-//    if (candidates == nullptr) {
-//        candidates = &smpl->cur_p;
-//    }
-
-//    const size_t penalty_last_n = std::min(smpl->params.penalty_last_n, smpl->prev.size());
-
-//    const float penalty_repeat  = smpl->params.penalty_repeat;
-//    const float penalty_freq    = smpl->params.penalty_freq;
-//    const float penalty_present = smpl->params.penalty_present;
-
-//    if ((penalty_last_n == 0) ||
-//        (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
-//        return;
-//    }
-
-//    // Create a frequency map to count occurrences of each token in last_tokens
-//    // TODO: move to sampling state and avoid reallocation
-//    llama_token_cnt token_count;
-//    for (size_t i = 0; i < penalty_last_n; ++i) {
-//        token_count[smpl->prev.rat(i)]++;
-//    }
-
-//    llama_sampling_penalties_impl(candidates, token_count, penalty_repeat, penalty_freq, penalty_present);
-//}
-
-//llama_token llama_sampling_sample_mirostat(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-//    time_meas tm(smpl->t_sample_us);
-
-//    if (candidates == nullptr) {
-//        candidates = &smpl->cur_p;
-//    }
-
-//    const auto type = smpl->params.mirostat;
-
-//    llama_token res;
-
-//    if (type == 1) {
-//        res = llama_sampling_sample_mirostat_impl(candidates,
-//                smpl->rng,
-//                smpl->params.mirostat_tau,
-//                smpl->params.mirostat_eta,
-//                100,
-//                smpl->vocab.n_vocab,
-//                smpl->mirostat_mu);
-//    } else if (type == 2) {
-//        res = llama_sampling_sample_mirostat_v2_impl(candidates,
-//                smpl->rng,
-//                smpl->params.mirostat_tau,
-//                smpl->params.mirostat_eta,
-//                smpl->mirostat_mu);
-//    } else {
-//        GGML_ABORT("invalid mirostat type: %d", type);
-//    }
-
-//    smpl->n_sample++;
-
-//    return res;
-//}
-
-//llama_token llama_sampling_sample_greedy(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-//    time_meas tm(smpl->t_sample_us);
-
-//    if (candidates == nullptr) {
-//        candidates = &smpl->cur_p;
-//    }
-
-//    auto res = llama_sampling_sample_greedy_impl(candidates);
-
-//    smpl->n_sample++;
-
-//    return res;
-//}
-
-//llama_token llama_sampling_sample_dist(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-//    time_meas tm(smpl->t_sample_us);
-
-//    if (candidates == nullptr) {
-//        candidates = &smpl->cur_p;
-//    }
-
-//    auto res = llama_sampling_sample_dist_impl(candidates, smpl->rng);
-
-//    smpl->n_sample++;
-
-//    return res;
-//}
-
-//llama_token llama_sampling_sample(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-//    time_meas tm(smpl->t_sample_us);
-
-//    if (candidates == nullptr) {
-//        candidates = &smpl->cur_p;
-//    }
-
-//    const auto & params = smpl->params;
-
-//    const float temp     = params.temp;
-//    const int   mirostat = params.mirostat;
-
-//    auto & cur_p = candidates;
-
-//    llama_token res = 0;
-
-//    if (temp < 0.0f || (temp == 0.0f && params.n_probs > 0)) {
-//        // greedy sampling, with probs
-//        llama_sampling_softmax_impl(cur_p);
-//        res = cur_p->data[0].id;
-//    } else if (temp == 0.0f) {
-//        // greedy sampling, no probs
-//        res = llama_sampling_sample_greedy(smpl, cur_p);
-//    } else {
-//        if (mirostat != 0) {
-//            llama_sampling_temp(smpl, cur_p);
-//            res = llama_sampling_sample_mirostat(smpl, cur_p);
-//        } else {
-//            for (const auto & sampler : smpl->samplers) {
-//                switch (sampler) {
-//                    case LLAMA_CONSTRAINT_TYPE_TOP_K:       llama_sampling_top_k_impl    (cur_p, smpl->params.top_k, smpl->params.min_keep); break;
-//                    case LLAMA_CONSTRAINT_TYPE_TFS_Z:       llama_sampling_tail_free_impl(cur_p, smpl->params.tfs_z, smpl->params.min_keep); break;
-//                    case LLAMA_CONSTRAINT_TYPE_TYPICAL_P:   llama_sampling_typical_impl  (cur_p, smpl->params.typ_p, smpl->params.min_keep); break;
-//                    case LLAMA_CONSTRAINT_TYPE_TOP_P:       llama_sampling_top_p_impl    (cur_p, smpl->params.top_p, smpl->params.min_keep); break;
-//                    case LLAMA_CONSTRAINT_TYPE_MIN_P:       llama_sampling_min_p_impl    (cur_p, smpl->params.min_p, smpl->params.min_keep); break;
-//                    case LLAMA_CONSTRAINT_TYPE_TEMPERATURE: llama_sampling_temp_impl     (cur_p, temp); break;
-//                    default : break;
-//                }
-//            }
-
-//            res = llama_sampling_sample_dist(smpl, cur_p);
-
-//            //{
-//            //    const int n_top = 10;
-//            //    LOG("top %d candidates:\n", n_top);
-
-//            //    for (int i = 0; i < n_top; i++) {
-//            //        const llama_token id = cur_p.data[i].id;
-//            //        (void)id; // To avoid a warning that id is unused when logging is disabled.
-//            //        LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl, id).c_str(), cur_p.data[i].p);
-//            //    }
-//            //}
-
-//            //LOG("sampled token: %5d: '%s'\n", res, llama_token_to_piece(smpl, res).c_str());
-//        }
-//    }
-
-//    smpl->n_sample++;
-
-//    return res;
-//}
-
-//void llama_sampling_accept(
-//        struct llama_sampling * smpl,
-//        llama_token token,
-//        bool apply_grammar) {
-//    time_meas tm(smpl->t_accept_us);
-
-//    llama_sampling_accept_impl(*smpl, token, apply_grammar);
-
-//    smpl->n_accept++;
-//}
-
-//int llama_sampling_n_prev(const struct llama_sampling * smpl) {
-//    return llama_sampling_n_prev_impl(*smpl);
-//}
-
-//llama_token llama_sampling_prev(const struct llama_sampling * smpl, int32_t ith) {
-//    return llama_sampling_prev_impl(*smpl, ith);
-//}
-
-//llama_token llama_sampling_last(const struct llama_sampling * smpl) {
-//    return llama_sampling_prev_impl(*smpl, 0);
-//}
-
-//
-// sampling v2
-//
-
 struct llama_constraint * llama_constraint_init_top_k(int32_t k, int32_t min_keep) {
     return llama_constraint_init_top_k_impl(k, min_keep);
 }
@@ -21070,6 +20730,10 @@ void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
 
 void llama_sampler_apply(struct llama_sampler * smpl, llama_token_data_array * candidates) {
     time_meas tm(smpl->t_sample_us);
 
+    if (candidates == nullptr) {
+        candidates = &smpl->cur_p;
+    }
+
     llama_sampler_apply_impl(*smpl, candidates);
 }
 
diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp
index f5e32a741..16eeaa1c8 100644
--- a/tests/test-sampling.cpp
+++ b/tests/test-sampling.cpp
@@ -30,9 +30,9 @@ static void test_top_k(const std::vector & probs, const std::vector & probs, const std::vector & probs, const std::vector
     llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
 
     DUMP(&candidates_p);
-    llama_sampling_tail_free_impl(&candidates_p, z, 1);
+    llama_constraint_tail_free_impl(&candidates_p, z, 1);
    DUMP(&candidates_p);
 
     GGML_ASSERT(candidates_p.size == expected_probs.size());
@@ -96,9 +96,9 @@ static void test_min_p(const std::vector & probs, const std::vector & probs, const std::vector & probs, const std::vector