common : simplify gpt_sampler

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-09-06 14:02:17 +03:00
parent 757a9bf868
commit befcfe7a31
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -98,10 +98,7 @@ struct ring_buffer {
struct gpt_sampler {
gpt_sampler_params params;
struct llama_sampler * bias;
struct llama_sampler * pnlt;
struct llama_sampler * grmr;
struct llama_sampler * chain;
ring_buffer<llama_token> prev;
@ -157,18 +154,6 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
auto * result = new gpt_sampler {
/* .params = */ params,
/* .bias = */ llama_sampler_init_logit_bias(
model,
params.logit_bias.size(),
params.logit_bias.data()),
/* .pnlt = */ llama_sampler_init_penalties(
model,
params.penalty_last_n,
params.penalty_repeat,
params.penalty_freq,
params.penalty_present,
params.penalize_nl,
params.ignore_eos),
/* .grmr = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"),
/* .chain = */ llama_sampler_chain_init(lparams),
/* .prev = */ ring_buffer<llama_token>(params.n_prev),
@ -176,6 +161,22 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
/* .cur_p = */ {},
};
llama_sampler_chain_add(result->chain,
llama_sampler_init_logit_bias(
model,
params.logit_bias.size(),
params.logit_bias.data()));
llama_sampler_chain_add(result->chain,
llama_sampler_init_penalties(
model,
params.penalty_last_n,
params.penalty_repeat,
params.penalty_freq,
params.penalty_present,
params.penalize_nl,
params.ignore_eos));
if (params.temp > 0.0f) {
if (params.mirostat == 0) {
for (const auto & cnstr : params.samplers) {
@ -223,8 +224,6 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
void gpt_sampler_free(struct gpt_sampler * gsmpl) {
if (gsmpl) {
llama_sampler_free(gsmpl->bias);
llama_sampler_free(gsmpl->pnlt);
llama_sampler_free(gsmpl->grmr);
llama_sampler_free(gsmpl->chain);
@ -236,8 +235,6 @@ void gpt_sampler_free(struct gpt_sampler * gsmpl) {
struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) {
return new gpt_sampler {
/* .params = */ gsmpl->params,
/* .bias = */ llama_sampler_clone(gsmpl->bias),
/* .pnlt = */ llama_sampler_clone(gsmpl->pnlt),
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
/* .chain = */ llama_sampler_clone(gsmpl->chain),
/* .prev = */ gsmpl->prev,
@ -282,8 +279,6 @@ void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler *
}
llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
auto & bias = gsmpl->bias;
auto & pnlt = gsmpl->pnlt;
auto & grmr = gsmpl->grmr;
auto & chain = gsmpl->chain;
@ -291,9 +286,6 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
auto & cur_p = gsmpl->cur_p;
llama_sampler_apply(bias, &cur_p);
llama_sampler_apply(pnlt, &cur_p);
if (grammar_first) {
llama_sampler_apply(grmr, &cur_p);
}
@ -325,10 +317,7 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
// if the token is not valid, sample again, first apply the grammar samplers and then sample
gsmpl->set_logits(ctx, idx);
llama_sampler_apply(bias, &cur_p);
llama_sampler_apply(pnlt, &cur_p);
llama_sampler_apply(grmr, &cur_p);
llama_sampler_apply(chain, &cur_p);
GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");