constraint : add name API

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-09-04 15:37:51 +03:00
parent c024fe45b0
commit 1a0de0b781
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
5 changed files with 52 additions and 42 deletions

View file

@ -198,9 +198,11 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
auto & grmr = gsmpl->grmr;
auto & smpl = gsmpl->smpl;
auto * cur_p = llama_sampler_get_candidates(smpl);
const auto * logits = llama_get_logits_ith(ctx, idx);
llama_sampler_set_logits(smpl, llama_get_logits_ith(ctx, idx));
llama_sampler_set_logits(smpl, logits);
auto * cur_p = llama_sampler_get_candidates(smpl);
llama_constraint_apply(bias, cur_p);
llama_constraint_apply(pnlt, cur_p);
@ -223,7 +225,7 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
}
// if the token is not valid, sample again, first apply the grammar constraints and then sample
llama_sampler_set_logits(smpl, llama_get_logits_ith(ctx, idx));
llama_sampler_set_logits(smpl, logits);
llama_constraint_apply(bias, cur_p);
llama_constraint_apply(pnlt, cur_p);

View file

@ -63,7 +63,7 @@ struct gpt_sampler_params {
// gpt_sampler extends llama_sampler with additional functionality:
//
// - grammar support
// - custom sampler logic based on the paramerters
// - custom sampler logic based on the parameters
//
struct gpt_sampler;

View file

@ -386,11 +386,11 @@ extern "C" {
double t_start_ms;
double t_end_ms;
double t_load_ms;
double t_sampling_ms;
double t_sampler_ms;
double t_p_eval_ms;
double t_eval_ms;
int32_t n_sampling;
int32_t n_sampler;
int32_t n_p_eval;
int32_t n_eval;
};
@ -1025,8 +1025,7 @@ extern "C" {
// user code can implement the interface below in order to create custom llama_constraint
struct llama_constraint_i {
// TODO: add name API
const char * (*name) (const struct llama_constraint * cnstr); // can be NULL
void (*accept)( struct llama_constraint * cnstr, llama_token token); // can be NULL
void (*apply) ( struct llama_constraint * cnstr, llama_token_data_array * cur_p); // required
void (*reset) ( struct llama_constraint * cnstr); // can be NULL
@ -1035,8 +1034,6 @@ extern "C" {
// TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph
//void (*apply_ggml) (struct llama_constraint * cnstr, ...);
// TODO: add API to get timing stats
};
struct llama_constraint {
@ -1044,7 +1041,7 @@ extern "C" {
llama_constraint_context_t ctx;
};
LLAMA_API struct llama_constraint * llama_constraint_init_softmax ();
LLAMA_API struct llama_constraint * llama_constraint_init_softmax (void);
LLAMA_API struct llama_constraint * llama_constraint_init_top_k (int32_t k, int32_t min_keep);
LLAMA_API struct llama_constraint * llama_constraint_init_top_p (float p, int32_t min_keep);
LLAMA_API struct llama_constraint * llama_constraint_init_min_p (float p, int32_t min_keep);

View file

@ -428,6 +428,7 @@ void llama_constraint_penalties_impl(
// softmax
static struct llama_constraint_i llama_constraint_softmax_i = {
/* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "softmax"; },
/* .accept = */ nullptr,
/* .apply = */ [](struct llama_constraint * /*cnstr*/, llama_token_data_array * cur_p) {
llama_constraint_softmax_impl(cur_p);
@ -454,9 +455,10 @@ struct llama_constraint_context_top_k {
};
static struct llama_constraint_i llama_constraint_top_k_i = {
/* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "top-k"; },
/* .accept = */ nullptr,
/* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) {
auto * ctx = (llama_constraint_context_top_k *) cnstr->ctx;
const auto * ctx = (llama_constraint_context_top_k *) cnstr->ctx;
llama_constraint_top_k_impl(cur_p, ctx->k, ctx->min_keep);
},
/* .reset = */ nullptr,
@ -466,7 +468,7 @@ static struct llama_constraint_i llama_constraint_top_k_i = {
},
/* .free = */ [](struct llama_constraint * cnstr) {
delete (llama_constraint_context_top_k *) cnstr->ctx;
}
},
};
struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k, size_t min_keep) {
@ -489,9 +491,10 @@ struct llama_constraint_context_top_p {
};
static struct llama_constraint_i llama_constraint_top_p_i = {
/* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "top-p"; },
/* .accept = */ nullptr,
/* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) {
auto * ctx = (llama_constraint_context_top_p *) cnstr->ctx;
const auto * ctx = (llama_constraint_context_top_p *) cnstr->ctx;
llama_constraint_top_p_impl(cur_p, ctx->p, ctx->min_keep);
},
/* .reset = */ nullptr,
@ -501,7 +504,7 @@ static struct llama_constraint_i llama_constraint_top_p_i = {
},
/* .free = */ [](struct llama_constraint * cnstr) {
delete (llama_constraint_context_top_p *) cnstr->ctx;
}
},
};
struct llama_constraint * llama_constraint_init_top_p_impl(float p, size_t min_keep) {
@ -524,9 +527,10 @@ struct llama_constraint_context_min_p {
};
static struct llama_constraint_i llama_constraint_min_p_i = {
/* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "min-p"; },
/* .accept = */ nullptr,
/* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) {
auto * ctx = (llama_constraint_context_min_p *) cnstr->ctx;
const auto * ctx = (llama_constraint_context_min_p *) cnstr->ctx;
llama_constraint_min_p_impl(cur_p, ctx->p, ctx->min_keep);
},
/* .reset = */ nullptr,
@ -536,7 +540,7 @@ static struct llama_constraint_i llama_constraint_min_p_i = {
},
/* .free = */ [](struct llama_constraint * cnstr) {
delete (llama_constraint_context_min_p *) cnstr->ctx;
}
},
};
struct llama_constraint * llama_constraint_init_min_p_impl(float p, size_t min_keep) {
@ -559,9 +563,10 @@ struct llama_constraint_context_tail_free {
};
static struct llama_constraint_i llama_constraint_tail_free_i = {
/* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "tail-free"; },
/* .accept = */ nullptr,
/* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) {
auto * ctx = (llama_constraint_context_tail_free *) cnstr->ctx;
const auto * ctx = (llama_constraint_context_tail_free *) cnstr->ctx;
llama_constraint_tail_free_impl(cur_p, ctx->z, ctx->min_keep);
},
/* .reset = */ nullptr,
@ -571,7 +576,7 @@ static struct llama_constraint_i llama_constraint_tail_free_i = {
},
/* .free = */ [](struct llama_constraint * cnstr) {
delete (llama_constraint_context_tail_free *) cnstr->ctx;
}
},
};
struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t min_keep) {
@ -594,9 +599,10 @@ struct llama_constraint_context_typical {
};
static struct llama_constraint_i llama_constraint_typical_i = {
/* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "typical"; },
/* .accept = */ nullptr,
/* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) {
auto * ctx = (llama_constraint_context_typical *) cnstr->ctx;
const auto * ctx = (llama_constraint_context_typical *) cnstr->ctx;
llama_constraint_typical_impl(cur_p, ctx->p, ctx->min_keep);
},
/* .reset = */ nullptr,
@ -606,7 +612,7 @@ static struct llama_constraint_i llama_constraint_typical_i = {
},
/* .free = */ [](struct llama_constraint * cnstr) {
delete (llama_constraint_context_typical *) cnstr->ctx;
}
},
};
struct llama_constraint * llama_constraint_init_typical_impl(float p, size_t min_keep) {
@ -628,9 +634,10 @@ struct llama_constraint_context_temp {
};
static struct llama_constraint_i llama_constraint_temp_i = {
/* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "temp"; },
/* .accept = */ nullptr,
/* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) {
auto * ctx = (llama_constraint_context_temp *) cnstr->ctx;
const auto * ctx = (llama_constraint_context_temp *) cnstr->ctx;
llama_constraint_temp_impl(cur_p, ctx->temp);
},
/* .reset = */ nullptr,
@ -640,7 +647,7 @@ static struct llama_constraint_i llama_constraint_temp_i = {
},
/* .free = */ [](struct llama_constraint * cnstr) {
delete (llama_constraint_context_temp *) cnstr->ctx;
}
},
};
struct llama_constraint * llama_constraint_init_temp_impl(float temp) {
@ -663,9 +670,10 @@ struct llama_constraint_context_temp_ext {
};
static struct llama_constraint_i llama_constraint_temp_ext_i = {
/* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "temp-ext"; },
/* .accept = */ nullptr,
/* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) {
auto * ctx = (llama_constraint_context_temp_ext *) cnstr->ctx;
const auto * ctx = (llama_constraint_context_temp_ext *) cnstr->ctx;
if (ctx->delta > 0) {
const float temp_min = std::max(0.0f, ctx->temp - ctx->delta);
const float temp_max = ctx->temp + ctx->delta;
@ -682,7 +690,7 @@ static struct llama_constraint_i llama_constraint_temp_ext_i = {
},
/* .free = */ [](struct llama_constraint * cnstr) {
delete (llama_constraint_context_temp_ext *) cnstr->ctx;
}
},
};
struct llama_constraint * llama_constraint_init_temp_ext_impl(float temp, float delta, float exponent) {
@ -708,14 +716,15 @@ struct llama_constraint_context_grammar {
};
static struct llama_constraint_i llama_constraint_grammar_i = {
/* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "grammar"; },
/* .accept = */ [](struct llama_constraint * cnstr, llama_token token) {
auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx;
const auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx;
if (ctx->grammar) {
llama_grammar_accept_impl(*ctx->grammar, token);
}
},
/* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) {
auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx;
const auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx;
if (ctx->grammar) {
llama_constraint_grammar_impl(cur_p, *ctx->grammar);
}
@ -747,14 +756,14 @@ static struct llama_constraint_i llama_constraint_grammar_i = {
return result;
},
/* .free = */ [](struct llama_constraint * cnstr) {
auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx;
const auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx;
if (ctx->grammar) {
llama_grammar_free_impl(ctx->grammar);
}
delete ctx;
}
},
};
struct llama_constraint * llama_constraint_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) {
@ -799,6 +808,7 @@ struct llama_constraint_context_penalties {
};
static struct llama_constraint_i llama_constraint_penalties_i = {
/* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "penalties"; },
/* .accept = */ [](struct llama_constraint * cnstr, llama_token token) {
auto * ctx = (llama_constraint_context_penalties *) cnstr->ctx;
ctx->prev.push_back(token);
@ -855,7 +865,7 @@ static struct llama_constraint_i llama_constraint_penalties_i = {
},
/* .free = */ [](struct llama_constraint * cnstr) {
delete (llama_constraint_context_penalties *) cnstr->ctx;
}
},
};
struct llama_constraint * llama_constraint_init_penalties_impl(const struct llama_vocab & vocab, int32_t penalty_last_n, float penalty_repeat, float penalty_freq, float penalty_present, bool penalize_nl, bool ignore_eos) {
@ -888,6 +898,7 @@ struct llama_constraint_context_logit_bias {
};
static struct llama_constraint_i llama_constraint_logit_bias_i = {
/* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "logit-bias"; },
/* .accept = */ nullptr,
/* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) {
auto * ctx = (llama_constraint_context_logit_bias *) cnstr->ctx;
@ -905,7 +916,7 @@ static struct llama_constraint_i llama_constraint_logit_bias_i = {
},
/* .free = */ [](struct llama_constraint * cnstr) {
delete (llama_constraint_context_logit_bias *) cnstr->ctx;
}
},
};
struct llama_constraint * llama_constraint_init_logit_bias_impl(

View file

@ -20609,7 +20609,7 @@ int32_t llama_chat_apply_template(
// sampling
//
struct llama_constraint * llama_constraint_init_softmax() {
struct llama_constraint * llama_constraint_init_softmax(void) {
return llama_constraint_init_softmax_impl();
}
@ -20852,11 +20852,11 @@ void llama_print_timings(struct llama_context * ctx, struct llama_sampler * smpl
/*.t_start_ms =*/ 1e-3 * ctx->t_start_us,
/*.t_end_ms =*/ 1.00 * ggml_time_ms(),
/*.t_load_ms =*/ 1e-3 * ctx->t_load_us,
/*.t_sampling_ms =*/ 1e-3 * (smpl ? smpl->t_sample_us : 0.0),
/*.t_sampler_ms =*/ 1e-3 * (smpl ? smpl->t_sample_us : 0.0),
/*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us,
/*.t_eval_ms =*/ 1e-3 * ctx->t_eval_us,
/*.n_sampling =*/ std::max(0, smpl ? smpl->n_sample : 0),
/*.n_sampler =*/ std::max(0, smpl ? smpl->n_sample : 0),
/*.n_p_eval =*/ std::max(0, ctx->n_p_eval),
/*.n_eval =*/ std::max(1, ctx->n_eval),
};
@ -20864,7 +20864,7 @@ void llama_print_timings(struct llama_context * ctx, struct llama_sampler * smpl
LLAMA_LOG_INFO("\n");
LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, timings.t_load_ms);
LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
__func__, timings.t_sampling_ms, timings.n_sampling, timings.t_sampling_ms / timings.n_sampling, 1e3 / timings.t_sampling_ms * timings.n_sampling);
__func__, timings.t_sampler_ms, timings.n_sampler, timings.t_sampler_ms / timings.n_sampler, 1e3 / timings.t_sampler_ms * timings.n_sampler);
LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
__func__, timings.t_p_eval_ms, timings.n_p_eval, timings.t_p_eval_ms / timings.n_p_eval, 1e3 / timings.t_p_eval_ms * timings.n_p_eval);
LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",