diff --git a/include/llama.h b/include/llama.h index bd756fc5c..76c1aaf98 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1176,11 +1176,11 @@ extern "C" { typedef void * llama_constraint_context_t; struct llama_constraint_i { - void (*accept)(struct llama_constraint * cnstr, llama_token token); // can be NULL - void (*apply) (struct llama_constraint * cnstr, llama_token_data_array * candidates); - void (*reset) (struct llama_constraint * cnstr); // e.g. for grammar and penalty constraints, can be NULL - void (*copy) (struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src); - void (*free) (struct llama_constraint * cnstr); // can be NULL + void (*accept)(struct llama_constraint * cnstr, llama_token token); // can be NULL + void (*apply) (struct llama_constraint * cnstr, llama_token_data_array * candidates); // required + void (*reset) (struct llama_constraint * cnstr); // can be NULL + void (*copy) (struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src); // can be NULL if ctx is NULL + void (*free) (struct llama_constraint * cnstr); // can be NULL // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph //void (*apply_ggml) (struct llama_constraint * cnstr, ...); @@ -1191,9 +1191,14 @@ extern "C" { llama_constraint_context_t ctx; }; - LLAMA_API struct llama_constraint * llama_constraint_init_top_k(int32_t k, int32_t min_keep); - LLAMA_API struct llama_constraint * llama_constraint_init_top_p(float p, int32_t min_keep); - // ... + LLAMA_API struct llama_constraint * llama_constraint_init_top_k (int32_t k, int32_t min_keep); + LLAMA_API struct llama_constraint * llama_constraint_init_top_p (float p, int32_t min_keep); + LLAMA_API struct llama_constraint * llama_constraint_init_min_p (float p, int32_t min_keep); + LLAMA_API struct llama_constraint * llama_constraint_init_tail_free(float z, int32_t min_keep); + LLAMA_API struct llama_constraint * llama_constraint_init_typical (float p, int32_t min_keep); + LLAMA_API struct llama_constraint * llama_constraint_init_temp (float t); + LLAMA_API struct llama_constraint * llama_constraint_init_temp_ext (float t, float delta, float exponent); + LLAMA_API struct llama_constraint * llama_constraint_init_grammar (struct llama_model * model, const char * grammar_str, const char * grammar_root); // do not call if used with llama_sampler_add_constraint LLAMA_API void llama_constraint_free(struct llama_constraint * cnstr); diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 8cd98bae4..092a738aa 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -1043,6 +1043,10 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, } void llama_grammar_free_impl(struct llama_grammar * grammar) { + if (grammar == nullptr) { + return; + } + delete grammar; } diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 5ccbf029f..becc949dc 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -718,6 +718,269 @@ struct llama_constraint * llama_constraint_init_top_p_impl(float p, size_t min_k return result; } +// min-p + +struct llama_constraint_context_min_p { + float p; + size_t min_keep; +}; + +static struct llama_constraint_i llama_constraint_min_p_i = { + /* .accept = */ nullptr, + /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { + llama_constraint_context_min_p * ctx = (llama_constraint_context_min_p *) cnstr->ctx; + llama_sampling_min_p_impl(candidates, ctx->p, ctx->min_keep); + }, + /* .reset = */ nullptr, + /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { + cnstr->ctx = new llama_constraint_context_min_p; + const auto * ctx_src = (const llama_constraint_context_min_p *) cnstr_src->ctx; + auto * ctx_dst = ( llama_constraint_context_min_p *) cnstr->ctx; + *ctx_dst = *ctx_src; + }, + /* .free = */ [](struct llama_constraint * cnstr) { + if (cnstr->ctx) { + delete (llama_constraint_context_min_p *) cnstr->ctx; + } + delete cnstr; + } +}; + +struct llama_constraint * llama_constraint_init_min_p_impl(float p, size_t min_keep) { + struct llama_constraint * result = new llama_constraint; + + result->iface = &llama_constraint_min_p_i; + result->ctx = new llama_constraint_context_min_p{p, min_keep}; + + return result; +} + +// tail-free + +struct llama_constraint_context_tail_free { + float z; + size_t min_keep; +}; + +static struct llama_constraint_i llama_constraint_tail_free_i = { + /* .accept = */ nullptr, + /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { + llama_constraint_context_tail_free * ctx = (llama_constraint_context_tail_free *) cnstr->ctx; + llama_sampling_tail_free_impl(candidates, ctx->z, ctx->min_keep); + }, + /* .reset = */ nullptr, + /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { + cnstr->ctx = new llama_constraint_context_tail_free; + const auto * ctx_src = (const llama_constraint_context_tail_free *) cnstr_src->ctx; + auto * ctx_dst = ( llama_constraint_context_tail_free *) cnstr->ctx; + *ctx_dst = *ctx_src; + }, + /* .free = */ [](struct llama_constraint * cnstr) { + if (cnstr->ctx) { + delete (llama_constraint_context_tail_free *) cnstr->ctx; + } + delete cnstr; + } +}; + +struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t min_keep) { + struct llama_constraint * result = new llama_constraint; + + result->iface = &llama_constraint_tail_free_i; + result->ctx = new llama_constraint_context_tail_free{z, min_keep}; + + return result; +} + +// typical + +struct llama_constraint_context_typical { + float p; + size_t min_keep; +}; + +static struct llama_constraint_i llama_constraint_typical_i = { + /* .accept = */ nullptr, + /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { + llama_constraint_context_typical * ctx = (llama_constraint_context_typical *) cnstr->ctx; + llama_sampling_typical_impl(candidates, ctx->p, ctx->min_keep); + }, + /* .reset = */ nullptr, + /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { + cnstr->ctx = new llama_constraint_context_typical; + const auto * ctx_src = (const llama_constraint_context_typical *) cnstr_src->ctx; + auto * ctx_dst = ( llama_constraint_context_typical *) cnstr->ctx; + *ctx_dst = *ctx_src; + }, + /* .free = */ [](struct llama_constraint * cnstr) { + if (cnstr->ctx) { + delete (llama_constraint_context_typical *) cnstr->ctx; + } + delete cnstr; + } +}; + +struct llama_constraint * llama_constraint_init_typical_impl(float p, size_t min_keep) { + struct llama_constraint * result = new llama_constraint; + + result->iface = &llama_constraint_typical_i; + result->ctx = new llama_constraint_context_typical{p, min_keep}; + + return result; +} + +// temp + +struct llama_constraint_context_temp { + float temp; +}; + +static struct llama_constraint_i llama_constraint_temp_i = { + /* .accept = */ nullptr, + /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { + llama_constraint_context_temp * ctx = (llama_constraint_context_temp *) cnstr->ctx; + llama_sampling_temp_impl(candidates, ctx->temp); + }, + /* .reset = */ nullptr, + /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { + cnstr->ctx = new llama_constraint_context_temp; + const auto * ctx_src = (const llama_constraint_context_temp *) cnstr_src->ctx; + auto * ctx_dst = ( llama_constraint_context_temp *) cnstr->ctx; + *ctx_dst = *ctx_src; + }, + /* .free = */ [](struct llama_constraint * cnstr) { + if (cnstr->ctx) { + delete (llama_constraint_context_temp *) cnstr->ctx; + } + delete cnstr; + } +}; + +struct llama_constraint * llama_constraint_init_temp_impl(float temp) { + struct llama_constraint * result = new llama_constraint; + + result->iface = &llama_constraint_temp_i; + result->ctx = new llama_constraint_context_temp{temp}; + + return result; +} + +// temp-ext + +struct llama_constraint_context_temp_ext { + float temp; + float delta; + float exponent; +}; + +static struct llama_constraint_i llama_constraint_temp_ext_i = { + /* .accept = */ nullptr, + /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { + llama_constraint_context_temp_ext * ctx = (llama_constraint_context_temp_ext *) cnstr->ctx; + if (ctx->delta > 0) { + const float temp_min = std::max(0.0f, ctx->temp - ctx->delta); + const float temp_max = ctx->temp + ctx->delta; + + llama_sampling_entropy_impl(candidates, temp_min, temp_max, ctx->exponent); + } else { + llama_sampling_temp_impl(candidates, ctx->temp); + } + }, + /* .reset = */ nullptr, + /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { + cnstr->ctx = new llama_constraint_context_temp_ext; + const auto * ctx_src = (const llama_constraint_context_temp_ext *) cnstr_src->ctx; + auto * ctx_dst = ( llama_constraint_context_temp_ext *) cnstr->ctx; + *ctx_dst = *ctx_src; + }, + /* .free = */ [](struct llama_constraint * cnstr) { + if (cnstr->ctx) { + delete (llama_constraint_context_temp_ext *) cnstr->ctx; + } + delete cnstr; + } +}; + +struct llama_constraint * llama_constraint_init_temp_ext_impl(float temp, float delta, float exponent) { + struct llama_constraint * result = new llama_constraint; + + result->iface = &llama_constraint_temp_ext_i; + result->ctx = new llama_constraint_context_temp_ext{temp, delta, exponent}; + + return result; +} + +// grammar + +struct llama_constraint_context_grammar { + std::string grammar_str; + std::string grammar_root; + + struct llama_grammar * grammar; +}; + +static struct llama_constraint_i llama_constraint_grammar_i = { + /* .accept = */ nullptr, + /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { + llama_constraint_context_grammar * ctx = (llama_constraint_context_grammar *) cnstr->ctx; + if (ctx->grammar) { + llama_sampling_grammar_impl(candidates, *ctx->grammar); + } + }, + /* .reset = */ [](struct llama_constraint * cnstr) { + llama_constraint_context_grammar * ctx = (llama_constraint_context_grammar *) cnstr->ctx; + if (ctx->grammar) { + llama_grammar_free_impl(ctx->grammar); + ctx->grammar = nullptr; + } + + if (!ctx->grammar_str.empty()) { + ctx->grammar = llama_grammar_init_impl(nullptr, ctx->grammar_str.c_str(), ctx->grammar_root.c_str()); + } + }, + /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { + cnstr->ctx = new llama_constraint_context_grammar; + const auto * ctx_src = (const llama_constraint_context_grammar *) cnstr_src->ctx; + auto * ctx_dst = ( llama_constraint_context_grammar *) cnstr->ctx; + + *ctx_dst = *ctx_src; + + if (ctx_src->grammar) { + ctx_dst->grammar = llama_grammar_cp_impl(*ctx_src->grammar); + } else { + ctx_dst->grammar = nullptr; + } + }, + /* .free = */ [](struct llama_constraint * cnstr) { + if (cnstr->ctx) { + { + auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx; + llama_grammar_free_impl(ctx->grammar); + } + + delete (llama_constraint_context_grammar *) cnstr->ctx; + } + delete cnstr; + } +}; + +struct llama_constraint * llama_constraint_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) { + struct llama_constraint * result = new llama_constraint; + + result->iface = &llama_constraint_grammar_i; + result->ctx = new llama_constraint_context_grammar; + + auto * ctx = (llama_constraint_context_grammar *) result->ctx; + + if (grammar_str != nullptr && grammar_str[0] != '\0') { + ctx->grammar = llama_grammar_init_impl(&vocab, grammar_str, grammar_root); + } else { + ctx->grammar = nullptr; + } + + return result; +} + void llama_constraint_free_impl(struct llama_constraint * cnstr) { if (cnstr->iface->free && cnstr) { cnstr->iface->free(cnstr); diff --git a/src/llama-sampling.h b/src/llama-sampling.h index 2a98d103e..ed18da10d 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -113,8 +113,14 @@ int llama_sampling_n_prev_impl(const struct llama_sampling & smpl); // constraints -struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k, size_t min_keep); -struct llama_constraint * llama_constraint_init_top_p_impl(float p, size_t min_keep); +struct llama_constraint * llama_constraint_init_top_k_impl (int32_t k, size_t min_keep); +struct llama_constraint * llama_constraint_init_top_p_impl (float p, size_t min_keep); +struct llama_constraint * llama_constraint_init_min_p_impl (float p, size_t min_keep); +struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t min_keep); +struct llama_constraint * llama_constraint_init_typical_impl (float p, size_t min_keep); +struct llama_constraint * llama_constraint_init_temp_impl (float t); +struct llama_constraint * llama_constraint_init_temp_ext_impl (float t, float delta, float exponent); +struct llama_constraint * llama_constraint_init_grammar_impl (const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root); void llama_constraint_free_impl(struct llama_constraint * cnstr); diff --git a/src/llama.cpp b/src/llama.cpp index 4d41c9262..a47ad0103 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -20983,6 +20983,30 @@ struct llama_constraint * llama_constraint_init_top_p(float p, int32_t min_keep) return llama_constraint_init_top_p_impl(p, min_keep); } +struct llama_constraint * llama_constraint_init_min_p(float p, int32_t min_keep) { + return llama_constraint_init_min_p_impl(p, min_keep); +} + +struct llama_constraint * llama_constraint_init_tail_free(float z, int32_t min_keep) { + return llama_constraint_init_tail_free_impl(z, min_keep); +} + +struct llama_constraint * llama_constraint_init_typical(float p, int32_t min_keep) { + return llama_constraint_init_typical_impl(p, min_keep); +} + +struct llama_constraint * llama_constraint_init_temp(float temp) { + return llama_constraint_init_temp_impl(temp); +} + +struct llama_constraint * llama_constraint_init_temp_ext(float temp, float delta, float exponent) { + return llama_constraint_init_temp_ext_impl(temp, delta, exponent); +} + +struct llama_constraint * llama_constraint_init_grammar(struct llama_model * model, const char * grammar_str, const char * grammar_root) { + return llama_constraint_init_grammar_impl(model->vocab, grammar_str, grammar_root); +} + void llama_constraint_free(struct llama_constraint * cnstr) { if (cnstr == nullptr) { return;