common : simplify gpt_sampler

ggml-ci

commit befcfe7a31 (parent 757a9bf868)
Author: Georgi Gerganov
Date:   2024-09-06 14:02:17 +03:00

common/sampling.cpp

@@ -98,10 +98,7 @@ struct ring_buffer {
 struct gpt_sampler {
     gpt_sampler_params params;

-    struct llama_sampler * bias;
-    struct llama_sampler * pnlt;
     struct llama_sampler * grmr;
     struct llama_sampler * chain;

     ring_buffer<llama_token> prev;
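
Note: the two samplers removed here are not deleted outright; they are re-added as members of the chain in gpt_sampler_init below. A sketch of the slimmed-down struct, where the comments are mine rather than the patch's:

    struct gpt_sampler {
        gpt_sampler_params params;

        struct llama_sampler * grmr;  // grammar sampler, kept separate so it can be
                                      // applied before or after the chain
        struct llama_sampler * chain; // now begins with logit-bias and penalties

        ring_buffer<llama_token> prev; // recent tokens, unchanged by this commit
        // remaining fields (cur_p, ...) unchanged
    };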
@@ -140,11 +137,11 @@ std::string gpt_sampler_params::print() const {
 }

 std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) {
-    std::string result = "\tlogits";
+    std::string result = "\tlogits ";

     for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
         const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
-        result += std::string(" -> ") + llama_sampler_name(smpl) + " ";
+        result += std::string("-> ") + llama_sampler_name(smpl) + " ";
     }

     return result;
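
Because bias and penalties are now chain members, they also show up in this printed pipeline. A standalone sketch of the string building, with illustrative sampler names (the real ones come from llama_sampler_name()):

    #include <cstdio>
    #include <string>
    #include <vector>

    int main() {
        // hypothetical chain contents after this commit
        const std::vector<std::string> names = { "logit-bias", "penalties", "top-k", "temp", "dist" };

        std::string result = "\tlogits ";
        for (const auto & name : names) {
            result += "-> " + name + " ";
        }
        printf("%s\n", result.c_str()); // e.g. "    logits -> logit-bias -> penalties -> top-k -> temp -> dist "
        return 0;
    }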
@@ -157,18 +154,6 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
     auto * result = new gpt_sampler {
         /* .params = */ params,
-        /* .bias = */ llama_sampler_init_logit_bias(
-            model,
-            params.logit_bias.size(),
-            params.logit_bias.data()),
-        /* .pnlt = */ llama_sampler_init_penalties(
-            model,
-            params.penalty_last_n,
-            params.penalty_repeat,
-            params.penalty_freq,
-            params.penalty_present,
-            params.penalize_nl,
-            params.ignore_eos),
         /* .grmr = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"),
         /* .chain = */ llama_sampler_chain_init(lparams),
         /* .prev = */ ring_buffer<llama_token>(params.n_prev),
@@ -176,6 +161,22 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
         /* .cur_p = */ {},
     };

+    llama_sampler_chain_add(result->chain,
+            llama_sampler_init_logit_bias(
+                model,
+                params.logit_bias.size(),
+                params.logit_bias.data()));
+
+    llama_sampler_chain_add(result->chain,
+            llama_sampler_init_penalties(
+                model,
+                params.penalty_last_n,
+                params.penalty_repeat,
+                params.penalty_freq,
+                params.penalty_present,
+                params.penalize_nl,
+                params.ignore_eos));
+
     if (params.temp > 0.0f) {
         if (params.mirostat == 0) {
             for (const auto & cnstr : params.samplers) {
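
llama_sampler_chain_add appends, so logit-bias and penalties land at the front of the chain and run before the truncation and stochastic samplers added below (top-k, temperature, dist, or mirostat). A hedged sketch of composing an equivalent chain by hand; the parameter values are placeholders, and the init signatures follow the ones visible in this diff:

    #include "llama.h"

    static llama_sampler * make_chain(const llama_model * model) {
        llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

        // deterministic logit adjustments first ...
        llama_sampler_chain_add(chain, llama_sampler_init_logit_bias(model, /*n_bias=*/0, /*bias=*/nullptr));
        llama_sampler_chain_add(chain, llama_sampler_init_penalties(model,
                /*last_n=*/64, /*repeat=*/1.0f, /*freq=*/0.0f, /*present=*/0.0f,
                /*penalize_nl=*/false, /*ignore_eos=*/false));

        // ... then the samplers that reshape or pick from the candidate set
        llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
        llama_sampler_chain_add(chain, llama_sampler_init_temp(0.8f));
        llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

        return chain;
    }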
@@ -223,8 +224,6 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
 void gpt_sampler_free(struct gpt_sampler * gsmpl) {
     if (gsmpl) {
-        llama_sampler_free(gsmpl->bias);
-        llama_sampler_free(gsmpl->pnlt);
         llama_sampler_free(gsmpl->grmr);

         llama_sampler_free(gsmpl->chain);
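
The per-member frees for bias and pnlt disappear because the chain owns the samplers added to it: freeing the chain frees its members too. A minimal illustration of that ownership rule:

    #include "llama.h"

    int main() {
        llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
        llama_sampler_chain_add(chain, llama_sampler_init_top_k(40)); // chain takes ownership

        llama_sampler_free(chain); // frees the chain and every sampler it owns
        return 0;
    }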
@@ -236,8 +235,6 @@ void gpt_sampler_free(struct gpt_sampler * gsmpl) {
 struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) {
     return new gpt_sampler {
         /* .params = */ gsmpl->params,
-        /* .bias = */ llama_sampler_clone(gsmpl->bias),
-        /* .pnlt = */ llama_sampler_clone(gsmpl->pnlt),
         /* .grmr = */ llama_sampler_clone(gsmpl->grmr),
         /* .chain = */ llama_sampler_clone(gsmpl->chain),
         /* .prev = */ gsmpl->prev,
@@ -282,8 +279,6 @@ void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler *
 }

 llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
-    auto & bias = gsmpl->bias;
-    auto & pnlt = gsmpl->pnlt;
     auto & grmr = gsmpl->grmr;
     auto & chain = gsmpl->chain;
@@ -291,9 +286,6 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
     auto & cur_p = gsmpl->cur_p;

-    llama_sampler_apply(bias, &cur_p);
-    llama_sampler_apply(pnlt, &cur_p);
-
     if (grammar_first) {
         llama_sampler_apply(grmr, &cur_p);
     }
@@ -325,10 +317,7 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
     // if the token is not valid, sample again, first apply the grammar samplers and then sample
     gsmpl->set_logits(ctx, idx);

-    llama_sampler_apply(bias, &cur_p);
-    llama_sampler_apply(pnlt, &cur_p);
-    llama_sampler_apply(grmr, &cur_p);
+    llama_sampler_apply(grmr, &cur_p);
     llama_sampler_apply(chain, &cur_p);

     GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
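
Net effect on the resample path: with bias and penalties folded into the chain, validating a grammar-rejected token needs only two applies, grammar first and then the full chain. Callers are unaffected since gpt_sampler_sample keeps its signature. A hedged usage sketch (model/context setup elided; the header name and idx convention are assumptions from the surrounding common library, not part of this diff):

    #include "sampling.h" // assumed: common/sampling.h

    void sample_one(llama_model * model, llama_context * ctx, const gpt_sampler_params & params) {
        gpt_sampler * smpl = gpt_sampler_init(model, params);

        // after llama_decode(), sample the token for the last position;
        // the chain now applies logit-bias -> penalties -> configured samplers internally
        const llama_token id = gpt_sampler_sample(smpl, ctx, /*idx=*/-1, /*grammar_first=*/false);
        (void) id;

        gpt_sampler_free(smpl); // frees grmr and the chain (which owns bias/penalties)
    }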