common : simplify gpt_sampler
ggml-ci
This commit is contained in:
parent
757a9bf868
commit
befcfe7a31
1 changed files with 19 additions and 30 deletions
|
@ -98,10 +98,7 @@ struct ring_buffer {
|
||||||
struct gpt_sampler {
|
struct gpt_sampler {
|
||||||
gpt_sampler_params params;
|
gpt_sampler_params params;
|
||||||
|
|
||||||
struct llama_sampler * bias;
|
|
||||||
struct llama_sampler * pnlt;
|
|
||||||
struct llama_sampler * grmr;
|
struct llama_sampler * grmr;
|
||||||
|
|
||||||
struct llama_sampler * chain;
|
struct llama_sampler * chain;
|
||||||
|
|
||||||
ring_buffer<llama_token> prev;
|
ring_buffer<llama_token> prev;
|
||||||
|
@ -140,11 +137,11 @@ std::string gpt_sampler_params::print() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) {
|
std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) {
|
||||||
std::string result = "\tlogits";
|
std::string result = "\tlogits ";
|
||||||
|
|
||||||
for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
|
for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
|
||||||
const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
|
const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
|
||||||
result += std::string(" -> ") + llama_sampler_name(smpl) + " ";
|
result += std::string("-> ") + llama_sampler_name(smpl) + " ";
|
||||||
}
|
}
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
|
@ -157,18 +154,6 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
|
||||||
|
|
||||||
auto * result = new gpt_sampler {
|
auto * result = new gpt_sampler {
|
||||||
/* .params = */ params,
|
/* .params = */ params,
|
||||||
/* .bias = */ llama_sampler_init_logit_bias(
|
|
||||||
model,
|
|
||||||
params.logit_bias.size(),
|
|
||||||
params.logit_bias.data()),
|
|
||||||
/* .pnlt = */ llama_sampler_init_penalties(
|
|
||||||
model,
|
|
||||||
params.penalty_last_n,
|
|
||||||
params.penalty_repeat,
|
|
||||||
params.penalty_freq,
|
|
||||||
params.penalty_present,
|
|
||||||
params.penalize_nl,
|
|
||||||
params.ignore_eos),
|
|
||||||
/* .grmr = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"),
|
/* .grmr = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"),
|
||||||
/* .chain = */ llama_sampler_chain_init(lparams),
|
/* .chain = */ llama_sampler_chain_init(lparams),
|
||||||
/* .prev = */ ring_buffer<llama_token>(params.n_prev),
|
/* .prev = */ ring_buffer<llama_token>(params.n_prev),
|
||||||
|
@ -176,6 +161,22 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
|
||||||
/* .cur_p = */ {},
|
/* .cur_p = */ {},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
llama_sampler_chain_add(result->chain,
|
||||||
|
llama_sampler_init_logit_bias(
|
||||||
|
model,
|
||||||
|
params.logit_bias.size(),
|
||||||
|
params.logit_bias.data()));
|
||||||
|
|
||||||
|
llama_sampler_chain_add(result->chain,
|
||||||
|
llama_sampler_init_penalties(
|
||||||
|
model,
|
||||||
|
params.penalty_last_n,
|
||||||
|
params.penalty_repeat,
|
||||||
|
params.penalty_freq,
|
||||||
|
params.penalty_present,
|
||||||
|
params.penalize_nl,
|
||||||
|
params.ignore_eos));
|
||||||
|
|
||||||
if (params.temp > 0.0f) {
|
if (params.temp > 0.0f) {
|
||||||
if (params.mirostat == 0) {
|
if (params.mirostat == 0) {
|
||||||
for (const auto & cnstr : params.samplers) {
|
for (const auto & cnstr : params.samplers) {
|
||||||
|
@ -223,8 +224,6 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
|
||||||
|
|
||||||
void gpt_sampler_free(struct gpt_sampler * gsmpl) {
|
void gpt_sampler_free(struct gpt_sampler * gsmpl) {
|
||||||
if (gsmpl) {
|
if (gsmpl) {
|
||||||
llama_sampler_free(gsmpl->bias);
|
|
||||||
llama_sampler_free(gsmpl->pnlt);
|
|
||||||
llama_sampler_free(gsmpl->grmr);
|
llama_sampler_free(gsmpl->grmr);
|
||||||
|
|
||||||
llama_sampler_free(gsmpl->chain);
|
llama_sampler_free(gsmpl->chain);
|
||||||
|
@ -236,8 +235,6 @@ void gpt_sampler_free(struct gpt_sampler * gsmpl) {
|
||||||
struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) {
|
struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) {
|
||||||
return new gpt_sampler {
|
return new gpt_sampler {
|
||||||
/* .params = */ gsmpl->params,
|
/* .params = */ gsmpl->params,
|
||||||
/* .bias = */ llama_sampler_clone(gsmpl->bias),
|
|
||||||
/* .pnlt = */ llama_sampler_clone(gsmpl->pnlt),
|
|
||||||
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
|
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
|
||||||
/* .chain = */ llama_sampler_clone(gsmpl->chain),
|
/* .chain = */ llama_sampler_clone(gsmpl->chain),
|
||||||
/* .prev = */ gsmpl->prev,
|
/* .prev = */ gsmpl->prev,
|
||||||
|
@ -282,8 +279,6 @@ void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler *
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
|
llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
|
||||||
auto & bias = gsmpl->bias;
|
|
||||||
auto & pnlt = gsmpl->pnlt;
|
|
||||||
auto & grmr = gsmpl->grmr;
|
auto & grmr = gsmpl->grmr;
|
||||||
auto & chain = gsmpl->chain;
|
auto & chain = gsmpl->chain;
|
||||||
|
|
||||||
|
@ -291,9 +286,6 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
|
||||||
|
|
||||||
auto & cur_p = gsmpl->cur_p;
|
auto & cur_p = gsmpl->cur_p;
|
||||||
|
|
||||||
llama_sampler_apply(bias, &cur_p);
|
|
||||||
llama_sampler_apply(pnlt, &cur_p);
|
|
||||||
|
|
||||||
if (grammar_first) {
|
if (grammar_first) {
|
||||||
llama_sampler_apply(grmr, &cur_p);
|
llama_sampler_apply(grmr, &cur_p);
|
||||||
}
|
}
|
||||||
|
@ -325,10 +317,7 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
|
||||||
// if the token is not valid, sample again, first apply the grammar samplers and then sample
|
// if the token is not valid, sample again, first apply the grammar samplers and then sample
|
||||||
gsmpl->set_logits(ctx, idx);
|
gsmpl->set_logits(ctx, idx);
|
||||||
|
|
||||||
llama_sampler_apply(bias, &cur_p);
|
llama_sampler_apply(grmr, &cur_p);
|
||||||
llama_sampler_apply(pnlt, &cur_p);
|
|
||||||
llama_sampler_apply(grmr, &cur_p);
|
|
||||||
|
|
||||||
llama_sampler_apply(chain, &cur_p);
|
llama_sampler_apply(chain, &cur_p);
|
||||||
|
|
||||||
GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
|
GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue