common : simplify gpt_sampler
ggml-ci
parent 757a9bf868
commit befcfe7a31
1 changed file with 19 additions and 30 deletions
@@ -98,10 +98,7 @@ struct ring_buffer {
 struct gpt_sampler {
     gpt_sampler_params params;
 
-    struct llama_sampler * bias;
-    struct llama_sampler * pnlt;
     struct llama_sampler * grmr;
-
     struct llama_sampler * chain;
 
     ring_buffer<llama_token> prev;
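
After this hunk the struct keeps only the grammar sampler and the chain. Reconstructed from the context lines above (members past prev are outside this hunk, so the trailing ellipsis is a placeholder, not actual code):

    struct gpt_sampler {
        gpt_sampler_params params;

        struct llama_sampler * grmr;   // grammar sampler, kept separate (see the resample path below)
        struct llama_sampler * chain;  // everything else, applied in one pass

        ring_buffer<llama_token> prev;
        // ... remaining members not shown in this hunk ...
    };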
@@ -140,11 +137,11 @@ std::string gpt_sampler_params::print() const {
 }
 
 std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) {
-    std::string result = "\tlogits";
+    std::string result = "\tlogits ";
 
     for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
         const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
-        result += std::string(" -> ") + llama_sampler_name(smpl) + " ";
+        result += std::string("-> ") + llama_sampler_name(smpl) + " ";
     }
 
     return result;
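
Net effect of the two string tweaks: the separator space moves from before each arrow to after the "logits" prefix, so the spacing stays single throughout. A minimal usage sketch (not part of the diff; the printed sampler names are whatever llama_sampler_name returns at runtime):

    // sketch: dump the chain right after building the sampler
    struct gpt_sampler * smpl = gpt_sampler_init(model, params);

    // prints something like: "\tlogits -> logit-bias -> penalties -> ..."
    fprintf(stderr, "%s\n", gpt_sampler_print(smpl).c_str());

    gpt_sampler_free(smpl);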
@@ -157,18 +154,6 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
 
     auto * result = new gpt_sampler {
         /* .params = */ params,
-        /* .bias = */ llama_sampler_init_logit_bias(
-            model,
-            params.logit_bias.size(),
-            params.logit_bias.data()),
-        /* .pnlt = */ llama_sampler_init_penalties(
-            model,
-            params.penalty_last_n,
-            params.penalty_repeat,
-            params.penalty_freq,
-            params.penalty_present,
-            params.penalize_nl,
-            params.ignore_eos),
         /* .grmr = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"),
         /* .chain = */ llama_sampler_chain_init(lparams),
         /* .prev = */ ring_buffer<llama_token>(params.n_prev),
@@ -176,6 +161,22 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
         /* .cur_p = */ {},
     };
 
+    llama_sampler_chain_add(result->chain,
+            llama_sampler_init_logit_bias(
+                model,
+                params.logit_bias.size(),
+                params.logit_bias.data()));
+
+    llama_sampler_chain_add(result->chain,
+            llama_sampler_init_penalties(
+                model,
+                params.penalty_last_n,
+                params.penalty_repeat,
+                params.penalty_freq,
+                params.penalty_present,
+                params.penalize_nl,
+                params.ignore_eos));
+
     if (params.temp > 0.0f) {
         if (params.mirostat == 0) {
             for (const auto & cnstr : params.samplers) {
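
This is the core of the change: the logit-bias and penalty samplers become ordinary links at the front of the chain instead of standalone struct members, so a single llama_sampler_apply on the chain now covers them. A condensed sketch of the new construction order, assembled from the calls in this hunk (llama_sampler_chain_default_params() is assumed to come from llama.h; everything else is taken verbatim from the diff):

    // sketch (not part of the diff): build a chain the way the new code does
    static struct llama_sampler * make_chain_sketch(
            const struct llama_model * model, const gpt_sampler_params & params) {
        auto lparams = llama_sampler_chain_default_params();

        struct llama_sampler * chain = llama_sampler_chain_init(lparams);

        // the same two samplers the diff moves into the chain, in the same order
        llama_sampler_chain_add(chain,
                llama_sampler_init_logit_bias(
                    model,
                    params.logit_bias.size(),
                    params.logit_bias.data()));

        llama_sampler_chain_add(chain,
                llama_sampler_init_penalties(
                    model,
                    params.penalty_last_n,
                    params.penalty_repeat,
                    params.penalty_freq,
                    params.penalty_present,
                    params.penalize_nl,
                    params.ignore_eos));

        // ... temperature / mirostat / user-configured samplers follow, as in the hunk ...
        return chain;
    }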
@@ -223,8 +224,6 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
 
 void gpt_sampler_free(struct gpt_sampler * gsmpl) {
     if (gsmpl) {
-        llama_sampler_free(gsmpl->bias);
-        llama_sampler_free(gsmpl->pnlt);
         llama_sampler_free(gsmpl->grmr);
 
         llama_sampler_free(gsmpl->chain);
@@ -236,8 +235,6 @@ void gpt_sampler_free(struct gpt_sampler * gsmpl) {
 struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) {
     return new gpt_sampler {
         /* .params = */ gsmpl->params,
-        /* .bias = */ llama_sampler_clone(gsmpl->bias),
-        /* .pnlt = */ llama_sampler_clone(gsmpl->pnlt),
         /* .grmr = */ llama_sampler_clone(gsmpl->grmr),
         /* .chain = */ llama_sampler_clone(gsmpl->chain),
         /* .prev = */ gsmpl->prev,
@@ -282,8 +279,6 @@ void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler *
 }
 
 llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
-    auto & bias = gsmpl->bias;
-    auto & pnlt = gsmpl->pnlt;
     auto & grmr = gsmpl->grmr;
     auto & chain = gsmpl->chain;
 
@@ -291,9 +286,6 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
 
     auto & cur_p = gsmpl->cur_p;
 
-    llama_sampler_apply(bias, &cur_p);
-    llama_sampler_apply(pnlt, &cur_p);
-
     if (grammar_first) {
         llama_sampler_apply(grmr, &cur_p);
    }
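
With the explicit bias/pnlt applies gone, the hot path of gpt_sampler_sample reduces to roughly the following. This is a sketch condensed from this hunk and its context, not the full function body:

    // condensed per-token flow after this change
    gsmpl->set_logits(ctx, idx);                    // fill cur_p with the raw logits

    auto & cur_p = gsmpl->cur_p;

    if (grammar_first) {
        llama_sampler_apply(gsmpl->grmr, &cur_p);   // constrain to the grammar before sampling
    }

    // one call now covers logit-bias, penalties and all configured samplers
    llama_sampler_apply(gsmpl->chain, &cur_p);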
@@ -325,10 +317,7 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
     // if the token is not valid, sample again, first apply the grammar samplers and then sample
     gsmpl->set_logits(ctx, idx);
 
-    llama_sampler_apply(bias, &cur_p);
-    llama_sampler_apply(pnlt, &cur_p);
-    llama_sampler_apply(grmr, &cur_p);
-
+    llama_sampler_apply(grmr,  &cur_p);
     llama_sampler_apply(chain, &cur_p);
 
     GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
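
For orientation: grmr stays a separate member precisely so this resample path can re-apply the grammar on its own when the first sampled token is rejected. A hedged sketch of the validity check that precedes the resample; the surrounding code is outside this hunk, so the details here are assumptions about the rest of the file:

    // assumed shape of the check that triggers the resample path above
    const llama_token id = cur_p.data[cur_p.selected].id;

    // run the grammar sampler over just the sampled token
    llama_token_data       single_token_data       = { id, 1.0f, 0.0f };
    llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };

    llama_sampler_apply(grmr, &single_token_data_array);

    // a grammar sampler rejects a token by forcing its logit to -INFINITY
    const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;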