From befcfe7a31dec28c7284bde9ff82847ca6578de9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 6 Sep 2024 14:02:17 +0300 Subject: [PATCH] common : simplify gpt_sampler ggml-ci --- common/sampling.cpp | 49 ++++++++++++++++++--------------------------- 1 file changed, 19 insertions(+), 30 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 553aefbf4..a4baf9db6 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -98,10 +98,7 @@ struct ring_buffer { struct gpt_sampler { gpt_sampler_params params; - struct llama_sampler * bias; - struct llama_sampler * pnlt; struct llama_sampler * grmr; - struct llama_sampler * chain; ring_buffer prev; @@ -140,11 +137,11 @@ std::string gpt_sampler_params::print() const { } std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) { - std::string result = "\tlogits"; + std::string result = "\tlogits "; for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) { const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i); - result += std::string(" -> ") + llama_sampler_name(smpl) + " "; + result += std::string("-> ") + llama_sampler_name(smpl) + " "; } return result; @@ -157,18 +154,6 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st auto * result = new gpt_sampler { /* .params = */ params, - /* .bias = */ llama_sampler_init_logit_bias( - model, - params.logit_bias.size(), - params.logit_bias.data()), - /* .pnlt = */ llama_sampler_init_penalties( - model, - params.penalty_last_n, - params.penalty_repeat, - params.penalty_freq, - params.penalty_present, - params.penalize_nl, - params.ignore_eos), /* .grmr = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"), /* .chain = */ llama_sampler_chain_init(lparams), /* .prev = */ ring_buffer(params.n_prev), @@ -176,6 +161,22 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st /* .cur_p = */ {}, }; + llama_sampler_chain_add(result->chain, + llama_sampler_init_logit_bias( + model, + params.logit_bias.size(), + params.logit_bias.data())); + + llama_sampler_chain_add(result->chain, + llama_sampler_init_penalties( + model, + params.penalty_last_n, + params.penalty_repeat, + params.penalty_freq, + params.penalty_present, + params.penalize_nl, + params.ignore_eos)); + if (params.temp > 0.0f) { if (params.mirostat == 0) { for (const auto & cnstr : params.samplers) { @@ -223,8 +224,6 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st void gpt_sampler_free(struct gpt_sampler * gsmpl) { if (gsmpl) { - llama_sampler_free(gsmpl->bias); - llama_sampler_free(gsmpl->pnlt); llama_sampler_free(gsmpl->grmr); llama_sampler_free(gsmpl->chain); @@ -236,8 +235,6 @@ void gpt_sampler_free(struct gpt_sampler * gsmpl) { struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) { return new gpt_sampler { /* .params = */ gsmpl->params, - /* .bias = */ llama_sampler_clone(gsmpl->bias), - /* .pnlt = */ llama_sampler_clone(gsmpl->pnlt), /* .grmr = */ llama_sampler_clone(gsmpl->grmr), /* .chain = */ llama_sampler_clone(gsmpl->chain), /* .prev = */ gsmpl->prev, @@ -282,8 +279,6 @@ void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler * } llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) { - auto & bias = gsmpl->bias; - auto & pnlt = gsmpl->pnlt; auto & grmr = gsmpl->grmr; auto & chain = gsmpl->chain; @@ -291,9 +286,6 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context auto & cur_p = gsmpl->cur_p; - llama_sampler_apply(bias, &cur_p); - llama_sampler_apply(pnlt, &cur_p); - if (grammar_first) { llama_sampler_apply(grmr, &cur_p); } @@ -325,10 +317,7 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context // if the token is not valid, sample again, first apply the grammar samplers and then sample gsmpl->set_logits(ctx, idx); - llama_sampler_apply(bias, &cur_p); - llama_sampler_apply(pnlt, &cur_p); - llama_sampler_apply(grmr, &cur_p); - + llama_sampler_apply(grmr, &cur_p); llama_sampler_apply(chain, &cur_p); GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");