constraint : add name API
ggml-ci
parent c024fe45b0
commit 1a0de0b781

5 changed files with 52 additions and 42 deletions
@@ -198,9 +198,11 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
     auto & grmr = gsmpl->grmr;
     auto & smpl = gsmpl->smpl;
 
-    auto * cur_p = llama_sampler_get_candidates(smpl);
+    const auto * logits = llama_get_logits_ith(ctx, idx);
 
-    llama_sampler_set_logits(smpl, llama_get_logits_ith(ctx, idx));
+    llama_sampler_set_logits(smpl, logits);
+
+    auto * cur_p = llama_sampler_get_candidates(smpl);
 
     llama_constraint_apply(bias, cur_p);
     llama_constraint_apply(pnlt, cur_p);
@@ -223,7 +225,7 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
     }
 
     // if the token is not valid, sample again, first apply the grammar constraints and then sample
-    llama_sampler_set_logits(smpl, llama_get_logits_ith(ctx, idx));
+    llama_sampler_set_logits(smpl, logits);
 
     llama_constraint_apply(bias, cur_p);
     llama_constraint_apply(pnlt, cur_p);
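Note: the two hunks above fetch the logits pointer once and reuse it for both sampling passes, instead of calling llama_get_logits_ith twice. A minimal sketch of the resulting flow (identifiers are taken from the diff; the steps elided with comments are paraphrased):

    // both passes now feed the sampler from the same cached pointer
    const auto * logits = llama_get_logits_ith(ctx, idx); // fetched once

    llama_sampler_set_logits(smpl, logits);               // first pass
    auto * cur_p = llama_sampler_get_candidates(smpl);
    // ... apply constraints, sample, check the token against the grammar ...

    llama_sampler_set_logits(smpl, logits);               // resample pass reuses the pointer
    // ... apply the grammar constraints and sample again ...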
@@ -63,7 +63,7 @@ struct gpt_sampler_params {
 // gpt_sampler extends llama_sampler with additional functionality:
 //
 // - grammar support
-// - custom sampler logic based on the paramerters
+// - custom sampler logic based on the parameters
 //
 struct gpt_sampler;
 
@@ -386,11 +386,11 @@ extern "C" {
         double t_start_ms;
         double t_end_ms;
         double t_load_ms;
-        double t_sampling_ms;
+        double t_sampler_ms;
         double t_p_eval_ms;
         double t_eval_ms;
 
-        int32_t n_sampling;
+        int32_t n_sampler;
         int32_t n_p_eval;
         int32_t n_eval;
     };
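The t_sampling_ms/n_sampling pair is renamed to t_sampler_ms/n_sampler to match the llama_sampler naming. The throughput math over these fields is unchanged; for reference (a sketch, assuming a populated timings struct):

    // n_sampler tokens took t_sampler_ms milliseconds, so
    // tokens/s = n / (t_ms / 1000) = 1e3 / t_ms * n
    const double sampler_tps = 1e3 / timings.t_sampler_ms * timings.n_sampler;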
@@ -1025,8 +1025,7 @@ extern "C" {
 
     // user code can implement the interface below in order to create custom llama_constraint
     struct llama_constraint_i {
-        // TODO: add name API
-
+        const char * (*name)  (const struct llama_constraint * cnstr);                                 // can be NULL
         void         (*accept)(      struct llama_constraint * cnstr, llama_token token);              // can be NULL
         void         (*apply) (      struct llama_constraint * cnstr, llama_token_data_array * cur_p); // required
         void         (*reset) (      struct llama_constraint * cnstr);                                 // can be NULL
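The former TODO is resolved: constraints now expose an optional name callback. A minimal sketch of a user-defined constraint against this interface (only llama_constraint_i and llama_token_data_array come from the header; the argmax logic and the my_* names are invented for illustration, and the remaining members are left value-initialized):

    // hypothetical custom constraint: keep only the most likely token
    static struct llama_constraint_i my_argmax_i = {
        /* .name   = */ [](const struct llama_constraint * /*cnstr*/) { return "my-argmax"; },
        /* .accept = */ nullptr, // can be NULL
        /* .apply  = */ [](struct llama_constraint * /*cnstr*/, llama_token_data_array * cur_p) {
            size_t best = 0;
            for (size_t i = 1; i < cur_p->size; ++i) {
                if (cur_p->data[i].logit > cur_p->data[best].logit) {
                    best = i;
                }
            }
            cur_p->data[0] = cur_p->data[best];
            cur_p->size    = 1;
        },
        /* .reset  = */ nullptr, // can be NULL
        // copy/free omitted in this sketch
    };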
@@ -1035,8 +1034,6 @@ extern "C" {
 
         // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph
         //void (*apply_ggml) (struct llama_constraint * cnstr, ...);
-
-        // TODO: add API to get timing stats
     };
 
     struct llama_constraint {
@@ -1044,7 +1041,7 @@ extern "C" {
         llama_constraint_context_t ctx;
     };
 
-    LLAMA_API struct llama_constraint * llama_constraint_init_softmax   ();
+    LLAMA_API struct llama_constraint * llama_constraint_init_softmax   (void);
     LLAMA_API struct llama_constraint * llama_constraint_init_top_k     (int32_t k, int32_t min_keep);
     LLAMA_API struct llama_constraint * llama_constraint_init_top_p     (float   p, int32_t min_keep);
     LLAMA_API struct llama_constraint * llama_constraint_init_min_p     (float   p, int32_t min_keep);
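Only the softmax constructor changes here (an explicit (void) parameter list, required for a correct C prototype); the other declarations are context. A brief usage sketch, assuming cur_p is a candidate array obtained from a sampler as in the gpt_sampler_sample hunks above (no free call is shown because that API is not part of this diff):

    struct llama_constraint * top_k = llama_constraint_init_top_k(40, 1);
    struct llama_constraint * top_p = llama_constraint_init_top_p(0.95f, 1);

    llama_constraint_apply(top_k, cur_p); // truncate to the 40 best tokens
    llama_constraint_apply(top_p, cur_p); // then nucleus-filter to p = 0.95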
@@ -428,6 +428,7 @@ void llama_constraint_penalties_impl(
 // softmax
 
 static struct llama_constraint_i llama_constraint_softmax_i = {
+    /* .name   = */ [](const struct llama_constraint * /*cnstr*/) { return "softmax"; },
     /* .accept = */ nullptr,
     /* .apply  = */ [](struct llama_constraint * /*cnstr*/, llama_token_data_array * cur_p) {
         llama_constraint_softmax_impl(cur_p);
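Each built-in vtable below gains the same kind of .name entry. The lambdas are captureless, so they convert implicitly to the plain function pointers that the C interface expects, e.g.:

    // a captureless lambda decays to a C-compatible function pointer,
    // which is what lets it initialize the .name member
    const char * (*name_fn)(const struct llama_constraint *) =
        [](const struct llama_constraint * /*cnstr*/) { return "softmax"; };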
@@ -454,9 +455,10 @@ struct llama_constraint_context_top_k {
 };
 
 static struct llama_constraint_i llama_constraint_top_k_i = {
+    /* .name   = */ [](const struct llama_constraint * /*cnstr*/) { return "top-k"; },
     /* .accept = */ nullptr,
     /* .apply  = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) {
-        auto * ctx = (llama_constraint_context_top_k *) cnstr->ctx;
+        const auto * ctx = (llama_constraint_context_top_k *) cnstr->ctx;
         llama_constraint_top_k_impl(cur_p, ctx->k, ctx->min_keep);
     },
     /* .reset  = */ nullptr,
@@ -466,7 +468,7 @@ static struct llama_constraint_i llama_constraint_top_k_i = {
     },
     /* .free   = */ [](struct llama_constraint * cnstr) {
         delete (llama_constraint_context_top_k *) cnstr->ctx;
-    }
+    },
 };
 
 struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k, size_t min_keep) {
@@ -489,9 +491,10 @@ struct llama_constraint_context_top_p {
 };
 
 static struct llama_constraint_i llama_constraint_top_p_i = {
+    /* .name   = */ [](const struct llama_constraint * /*cnstr*/) { return "top-p"; },
     /* .accept = */ nullptr,
     /* .apply  = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) {
-        auto * ctx = (llama_constraint_context_top_p *) cnstr->ctx;
+        const auto * ctx = (llama_constraint_context_top_p *) cnstr->ctx;
         llama_constraint_top_p_impl(cur_p, ctx->p, ctx->min_keep);
     },
     /* .reset  = */ nullptr,
@@ -501,7 +504,7 @@ static struct llama_constraint_i llama_constraint_top_p_i = {
     },
     /* .free   = */ [](struct llama_constraint * cnstr) {
         delete (llama_constraint_context_top_p *) cnstr->ctx;
-    }
+    },
 };
 
 struct llama_constraint * llama_constraint_init_top_p_impl(float p, size_t min_keep) {
@@ -524,9 +527,10 @@ struct llama_constraint_context_min_p {
 };
 
 static struct llama_constraint_i llama_constraint_min_p_i = {
+    /* .name   = */ [](const struct llama_constraint * /*cnstr*/) { return "min-p"; },
     /* .accept = */ nullptr,
     /* .apply  = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) {
-        auto * ctx = (llama_constraint_context_min_p *) cnstr->ctx;
+        const auto * ctx = (llama_constraint_context_min_p *) cnstr->ctx;
         llama_constraint_min_p_impl(cur_p, ctx->p, ctx->min_keep);
     },
     /* .reset  = */ nullptr,
@@ -536,7 +540,7 @@ static struct llama_constraint_i llama_constraint_min_p_i = {
     },
     /* .free   = */ [](struct llama_constraint * cnstr) {
         delete (llama_constraint_context_min_p *) cnstr->ctx;
-    }
+    },
 };
 
 struct llama_constraint * llama_constraint_init_min_p_impl(float p, size_t min_keep) {
@@ -559,9 +563,10 @@ struct llama_constraint_context_tail_free {
 };
 
 static struct llama_constraint_i llama_constraint_tail_free_i = {
+    /* .name   = */ [](const struct llama_constraint * /*cnstr*/) { return "tail-free"; },
     /* .accept = */ nullptr,
     /* .apply  = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) {
-        auto * ctx = (llama_constraint_context_tail_free *) cnstr->ctx;
+        const auto * ctx = (llama_constraint_context_tail_free *) cnstr->ctx;
         llama_constraint_tail_free_impl(cur_p, ctx->z, ctx->min_keep);
     },
     /* .reset  = */ nullptr,
@@ -571,7 +576,7 @@ static struct llama_constraint_i llama_constraint_tail_free_i = {
     },
     /* .free   = */ [](struct llama_constraint * cnstr) {
         delete (llama_constraint_context_tail_free *) cnstr->ctx;
-    }
+    },
 };
 
 struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t min_keep) {
@@ -594,9 +599,10 @@ struct llama_constraint_context_typical {
 };
 
 static struct llama_constraint_i llama_constraint_typical_i = {
+    /* .name   = */ [](const struct llama_constraint * /*cnstr*/) { return "typical"; },
     /* .accept = */ nullptr,
     /* .apply  = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) {
-        auto * ctx = (llama_constraint_context_typical *) cnstr->ctx;
+        const auto * ctx = (llama_constraint_context_typical *) cnstr->ctx;
         llama_constraint_typical_impl(cur_p, ctx->p, ctx->min_keep);
     },
     /* .reset  = */ nullptr,
@@ -606,7 +612,7 @@ static struct llama_constraint_i llama_constraint_typical_i = {
     },
     /* .free   = */ [](struct llama_constraint * cnstr) {
         delete (llama_constraint_context_typical *) cnstr->ctx;
-    }
+    },
 };
 
 struct llama_constraint * llama_constraint_init_typical_impl(float p, size_t min_keep) {
@@ -628,9 +634,10 @@ struct llama_constraint_context_temp {
 };
 
 static struct llama_constraint_i llama_constraint_temp_i = {
+    /* .name   = */ [](const struct llama_constraint * /*cnstr*/) { return "temp"; },
     /* .accept = */ nullptr,
     /* .apply  = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) {
-        auto * ctx = (llama_constraint_context_temp *) cnstr->ctx;
+        const auto * ctx = (llama_constraint_context_temp *) cnstr->ctx;
         llama_constraint_temp_impl(cur_p, ctx->temp);
     },
     /* .reset  = */ nullptr,
@@ -640,7 +647,7 @@ static struct llama_constraint_i llama_constraint_temp_i = {
     },
     /* .free   = */ [](struct llama_constraint * cnstr) {
         delete (llama_constraint_context_temp *) cnstr->ctx;
-    }
+    },
 };
 
 struct llama_constraint * llama_constraint_init_temp_impl(float temp) {
@@ -663,9 +670,10 @@ struct llama_constraint_context_temp_ext {
 };
 
 static struct llama_constraint_i llama_constraint_temp_ext_i = {
+    /* .name   = */ [](const struct llama_constraint * /*cnstr*/) { return "temp-ext"; },
     /* .accept = */ nullptr,
     /* .apply  = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) {
-        auto * ctx = (llama_constraint_context_temp_ext *) cnstr->ctx;
+        const auto * ctx = (llama_constraint_context_temp_ext *) cnstr->ctx;
         if (ctx->delta > 0) {
             const float temp_min = std::max(0.0f, ctx->temp - ctx->delta);
             const float temp_max = ctx->temp + ctx->delta;
@@ -682,7 +690,7 @@ static struct llama_constraint_i llama_constraint_temp_ext_i = {
     },
     /* .free   = */ [](struct llama_constraint * cnstr) {
         delete (llama_constraint_context_temp_ext *) cnstr->ctx;
-    }
+    },
 };
 
 struct llama_constraint * llama_constraint_init_temp_ext_impl(float temp, float delta, float exponent) {
@@ -708,14 +716,15 @@ struct llama_constraint_context_grammar {
 };
 
 static struct llama_constraint_i llama_constraint_grammar_i = {
+    /* .name   = */ [](const struct llama_constraint * /*cnstr*/) { return "grammar"; },
     /* .accept = */ [](struct llama_constraint * cnstr, llama_token token) {
-        auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx;
+        const auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx;
         if (ctx->grammar) {
             llama_grammar_accept_impl(*ctx->grammar, token);
         }
     },
     /* .apply  = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) {
-        auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx;
+        const auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx;
         if (ctx->grammar) {
             llama_constraint_grammar_impl(cur_p, *ctx->grammar);
         }
@@ -747,14 +756,14 @@ static struct llama_constraint_i llama_constraint_grammar_i = {
         return result;
     },
     /* .free   = */ [](struct llama_constraint * cnstr) {
-        auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx;
+        const auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx;
 
         if (ctx->grammar) {
            llama_grammar_free_impl(ctx->grammar);
         }
 
         delete ctx;
-    }
+    },
 };
 
 struct llama_constraint * llama_constraint_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) {
@@ -799,6 +808,7 @@ struct llama_constraint_context_penalties {
 };
 
 static struct llama_constraint_i llama_constraint_penalties_i = {
+    /* .name   = */ [](const struct llama_constraint * /*cnstr*/) { return "penalties"; },
     /* .accept = */ [](struct llama_constraint * cnstr, llama_token token) {
         auto * ctx = (llama_constraint_context_penalties *) cnstr->ctx;
         ctx->prev.push_back(token);
@@ -855,7 +865,7 @@ static struct llama_constraint_i llama_constraint_penalties_i = {
     },
     /* .free   = */ [](struct llama_constraint * cnstr) {
         delete (llama_constraint_context_penalties *) cnstr->ctx;
-    }
+    },
 };
 
 struct llama_constraint * llama_constraint_init_penalties_impl(const struct llama_vocab & vocab, int32_t penalty_last_n, float penalty_repeat, float penalty_freq, float penalty_present, bool penalize_nl, bool ignore_eos) {
@@ -888,6 +898,7 @@ struct llama_constraint_context_logit_bias {
 };
 
 static struct llama_constraint_i llama_constraint_logit_bias_i = {
+    /* .name   = */ [](const struct llama_constraint * /*cnstr*/) { return "logit-bias"; },
     /* .accept = */ nullptr,
     /* .apply  = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) {
         auto * ctx = (llama_constraint_context_logit_bias *) cnstr->ctx;
@@ -905,7 +916,7 @@ static struct llama_constraint_i llama_constraint_logit_bias_i = {
     },
     /* .free   = */ [](struct llama_constraint * cnstr) {
         delete (llama_constraint_context_logit_bias *) cnstr->ctx;
-    }
+    },
 };
 
 struct llama_constraint * llama_constraint_init_logit_bias_impl(
@@ -20609,7 +20609,7 @@ int32_t llama_chat_apply_template(
 // sampling
 //
 
-struct llama_constraint * llama_constraint_init_softmax() {
+struct llama_constraint * llama_constraint_init_softmax(void) {
     return llama_constraint_init_softmax_impl();
 }
 
@@ -20852,11 +20852,11 @@ void llama_print_timings(struct llama_context * ctx, struct llama_sampler * smpl
         /*.t_start_ms   =*/ 1e-3 * ctx->t_start_us,
         /*.t_end_ms     =*/ 1.00 * ggml_time_ms(),
         /*.t_load_ms    =*/ 1e-3 * ctx->t_load_us,
-        /*.t_sampling_ms =*/ 1e-3 * (smpl ? smpl->t_sample_us : 0.0),
+        /*.t_sampler_ms  =*/ 1e-3 * (smpl ? smpl->t_sample_us : 0.0),
         /*.t_p_eval_ms  =*/ 1e-3 * ctx->t_p_eval_us,
         /*.t_eval_ms    =*/ 1e-3 * ctx->t_eval_us,
 
-        /*.n_sampling   =*/ std::max(0, smpl ? smpl->n_sample : 0),
+        /*.n_sampler    =*/ std::max(0, smpl ? smpl->n_sample : 0),
         /*.n_p_eval     =*/ std::max(0, ctx->n_p_eval),
         /*.n_eval       =*/ std::max(1, ctx->n_eval),
     };
@@ -20864,7 +20864,7 @@ void llama_print_timings(struct llama_context * ctx, struct llama_sampler * smpl
     LLAMA_LOG_INFO("\n");
     LLAMA_LOG_INFO("%s:        load time = %10.2f ms\n", __func__, timings.t_load_ms);
     LLAMA_LOG_INFO("%s:    sampling time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
-            __func__, timings.t_sampling_ms, timings.n_sampling, timings.t_sampling_ms / timings.n_sampling, 1e3 / timings.t_sampling_ms * timings.n_sampling);
+            __func__, timings.t_sampler_ms, timings.n_sampler, timings.t_sampler_ms / timings.n_sampler, 1e3 / timings.t_sampler_ms * timings.n_sampler);
     LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
             __func__, timings.t_p_eval_ms, timings.n_p_eval, timings.t_p_eval_ms / timings.n_p_eval, 1e3 / timings.t_p_eval_ms * timings.n_p_eval);
     LLAMA_LOG_INFO("%s:        eval time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
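Call sites are unaffected by the field rename; printing still goes through the same entry point (sketch, with ctx and smpl assumed to be valid):

    // prints load/sampling/eval timings; uses the sampler counters when smpl != NULL
    llama_print_timings(ctx, smpl);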