From a2ce91cbef4117d0d5d91671db2c6525d80d516c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 3 Sep 2024 16:04:22 +0300 Subject: [PATCH] cont : add penalties and logit-bias constraints [no ci] --- common/sampling.cpp | 88 ++++++------- common/sampling.h | 35 ++++-- include/llama.h | 95 ++++++++++---- src/llama-sampling.cpp | 277 +++++++++++++++++++++++++++++++++++++---- src/llama-sampling.h | 41 +++++- src/llama.cpp | 139 +++++++++++++++++++-- 6 files changed, 556 insertions(+), 119 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 96cfbe0ef..a98117cf8 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -128,57 +128,57 @@ std::string llama_sampling_prev_str(llama_sampling * smpl, llama_context * ctx_m return result; } -char llama_sampling_type_to_chr(llama_sampler_type sampler) { +char llama_sampling_type_to_chr(llama_constraint_type sampler) { switch (sampler) { - case LLAMA_SAMPLER_TYPE_TOP_K: return 'k'; - case LLAMA_SAMPLER_TYPE_TFS_Z: return 'f'; - case LLAMA_SAMPLER_TYPE_TYPICAL_P: return 'y'; - case LLAMA_SAMPLER_TYPE_TOP_P: return 'p'; - case LLAMA_SAMPLER_TYPE_MIN_P: return 'm'; - case LLAMA_SAMPLER_TYPE_TEMPERATURE: return 't'; + case LLAMA_CONSTRAINT_TYPE_TOP_K: return 'k'; + case LLAMA_CONSTRAINT_TYPE_TFS_Z: return 'f'; + case LLAMA_CONSTRAINT_TYPE_TYPICAL_P: return 'y'; + case LLAMA_CONSTRAINT_TYPE_TOP_P: return 'p'; + case LLAMA_CONSTRAINT_TYPE_MIN_P: return 'm'; + case LLAMA_CONSTRAINT_TYPE_TEMPERATURE: return 't'; default : return '?'; } } -std::string llama_sampling_type_to_str(llama_sampler_type sampler) { +std::string llama_sampling_type_to_str(llama_constraint_type sampler) { switch (sampler) { - case LLAMA_SAMPLER_TYPE_TOP_K: return "top_k"; - case LLAMA_SAMPLER_TYPE_TFS_Z: return "tfs_z"; - case LLAMA_SAMPLER_TYPE_TYPICAL_P: return "typ_p"; - case LLAMA_SAMPLER_TYPE_TOP_P: return "top_p"; - case LLAMA_SAMPLER_TYPE_MIN_P: return "min_p"; - case LLAMA_SAMPLER_TYPE_TEMPERATURE: return "temperature"; + case LLAMA_CONSTRAINT_TYPE_TOP_K: return "top_k"; + case LLAMA_CONSTRAINT_TYPE_TFS_Z: return "tfs_z"; + case LLAMA_CONSTRAINT_TYPE_TYPICAL_P: return "typ_p"; + case LLAMA_CONSTRAINT_TYPE_TOP_P: return "top_p"; + case LLAMA_CONSTRAINT_TYPE_MIN_P: return "min_p"; + case LLAMA_CONSTRAINT_TYPE_TEMPERATURE: return "temperature"; default : return ""; } } -std::vector llama_sampling_types_from_names(const std::vector & names, bool allow_alt_names) { - std::unordered_map sampler_canonical_name_map { - { "top_k", LLAMA_SAMPLER_TYPE_TOP_K }, - { "top_p", LLAMA_SAMPLER_TYPE_TOP_P }, - { "typ_p", LLAMA_SAMPLER_TYPE_TYPICAL_P }, - { "min_p", LLAMA_SAMPLER_TYPE_MIN_P }, - { "tfs_z", LLAMA_SAMPLER_TYPE_TFS_Z }, - { "temperature", LLAMA_SAMPLER_TYPE_TEMPERATURE }, +std::vector llama_sampling_types_from_names(const std::vector & names, bool allow_alt_names) { + std::unordered_map sampler_canonical_name_map { + { "top_k", LLAMA_CONSTRAINT_TYPE_TOP_K }, + { "top_p", LLAMA_CONSTRAINT_TYPE_TOP_P }, + { "typ_p", LLAMA_CONSTRAINT_TYPE_TYPICAL_P }, + { "min_p", LLAMA_CONSTRAINT_TYPE_MIN_P }, + { "tfs_z", LLAMA_CONSTRAINT_TYPE_TFS_Z }, + { "temperature", LLAMA_CONSTRAINT_TYPE_TEMPERATURE }, }; // since samplers names are written multiple ways // make it ready for both system names and input names - std::unordered_map sampler_alt_name_map { - { "top-k", LLAMA_SAMPLER_TYPE_TOP_K }, - { "top-p", LLAMA_SAMPLER_TYPE_TOP_P }, - { "nucleus", LLAMA_SAMPLER_TYPE_TOP_P }, - { "typical-p", LLAMA_SAMPLER_TYPE_TYPICAL_P }, - { "typical", LLAMA_SAMPLER_TYPE_TYPICAL_P }, - { "typ-p", LLAMA_SAMPLER_TYPE_TYPICAL_P }, - { "typ", LLAMA_SAMPLER_TYPE_TYPICAL_P }, - { "min-p", LLAMA_SAMPLER_TYPE_MIN_P }, - { "tfs-z", LLAMA_SAMPLER_TYPE_TFS_Z }, - { "tfs", LLAMA_SAMPLER_TYPE_TFS_Z }, - { "temp", LLAMA_SAMPLER_TYPE_TEMPERATURE }, + std::unordered_map sampler_alt_name_map { + { "top-k", LLAMA_CONSTRAINT_TYPE_TOP_K }, + { "top-p", LLAMA_CONSTRAINT_TYPE_TOP_P }, + { "nucleus", LLAMA_CONSTRAINT_TYPE_TOP_P }, + { "typical-p", LLAMA_CONSTRAINT_TYPE_TYPICAL_P }, + { "typical", LLAMA_CONSTRAINT_TYPE_TYPICAL_P }, + { "typ-p", LLAMA_CONSTRAINT_TYPE_TYPICAL_P }, + { "typ", LLAMA_CONSTRAINT_TYPE_TYPICAL_P }, + { "min-p", LLAMA_CONSTRAINT_TYPE_MIN_P }, + { "tfs-z", LLAMA_CONSTRAINT_TYPE_TFS_Z }, + { "tfs", LLAMA_CONSTRAINT_TYPE_TFS_Z }, + { "temp", LLAMA_CONSTRAINT_TYPE_TEMPERATURE }, }; - std::vector samplers; + std::vector samplers; samplers.reserve(names.size()); for (const auto & name : names) { @@ -198,17 +198,17 @@ std::vector llama_sampling_types_from_names(const std::vecto return samplers; } -std::vector llama_sampling_types_from_chars(const std::string & chars) { - std::unordered_map sampler_name_map { - { llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TOP_K), LLAMA_SAMPLER_TYPE_TOP_K }, - { llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TFS_Z), LLAMA_SAMPLER_TYPE_TFS_Z }, - { llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TYPICAL_P), LLAMA_SAMPLER_TYPE_TYPICAL_P }, - { llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TOP_P), LLAMA_SAMPLER_TYPE_TOP_P }, - { llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_MIN_P), LLAMA_SAMPLER_TYPE_MIN_P }, - { llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TEMPERATURE), LLAMA_SAMPLER_TYPE_TEMPERATURE } +std::vector llama_sampling_types_from_chars(const std::string & chars) { + std::unordered_map sampler_name_map { + { llama_sampling_type_to_chr(LLAMA_CONSTRAINT_TYPE_TOP_K), LLAMA_CONSTRAINT_TYPE_TOP_K }, + { llama_sampling_type_to_chr(LLAMA_CONSTRAINT_TYPE_TFS_Z), LLAMA_CONSTRAINT_TYPE_TFS_Z }, + { llama_sampling_type_to_chr(LLAMA_CONSTRAINT_TYPE_TYPICAL_P), LLAMA_CONSTRAINT_TYPE_TYPICAL_P }, + { llama_sampling_type_to_chr(LLAMA_CONSTRAINT_TYPE_TOP_P), LLAMA_CONSTRAINT_TYPE_TOP_P }, + { llama_sampling_type_to_chr(LLAMA_CONSTRAINT_TYPE_MIN_P), LLAMA_CONSTRAINT_TYPE_MIN_P }, + { llama_sampling_type_to_chr(LLAMA_CONSTRAINT_TYPE_TEMPERATURE), LLAMA_CONSTRAINT_TYPE_TEMPERATURE } }; - std::vector samplers; + std::vector samplers; samplers.reserve(chars.size()); for (const auto & c : chars) { diff --git a/common/sampling.h b/common/sampling.h index b96bbce1c..365b7639a 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -6,7 +6,7 @@ #include // sampling parameters -typedef struct gpt_sampling_params { +struct gpt_sampling_params { uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling int32_t n_prev = 64; // number of previous tokens to remember @@ -30,13 +30,13 @@ typedef struct gpt_sampling_params { bool penalize_nl = false; // consider newlines as a repeatable token bool ignore_eos = false; - std::vector samplers = { - LLAMA_SAMPLER_TYPE_TOP_K, - LLAMA_SAMPLER_TYPE_TFS_Z, - LLAMA_SAMPLER_TYPE_TYPICAL_P, - LLAMA_SAMPLER_TYPE_TOP_P, - LLAMA_SAMPLER_TYPE_MIN_P, - LLAMA_SAMPLER_TYPE_TEMPERATURE + std::vector samplers = { + LLAMA_CONSTRAINT_TYPE_TOP_K, + LLAMA_CONSTRAINT_TYPE_TFS_Z, + LLAMA_CONSTRAINT_TYPE_TYPICAL_P, + LLAMA_CONSTRAINT_TYPE_TOP_P, + LLAMA_CONSTRAINT_TYPE_MIN_P, + LLAMA_CONSTRAINT_TYPE_TEMPERATURE }; std::string grammar; // optional BNF-like grammar to constrain sampling @@ -48,7 +48,16 @@ typedef struct gpt_sampling_params { // print the samplers into a string std::string print_samplers() const; -} gpt_sampling_params; +}; + +// TODO: implement +struct gpt_sampler { + gpt_sampling_params params; + + struct llama_constraint * grmr = nullptr; + + struct llama_sampler * smpl = nullptr; +}; // overload of llama_sampling_init using gpt_sampling_params struct llama_sampling * llama_sampling_init(const struct llama_model * model, const struct gpt_sampling_params & params); @@ -72,8 +81,8 @@ llama_token llama_sampling_sample( // get a string representation of the last accepted tokens std::string llama_sampling_prev_str(llama_sampling * smpl, llama_context * ctx, int n); -char llama_sampling_type_to_chr(enum llama_sampler_type sampler_type); -std::string llama_sampling_type_to_str(enum llama_sampler_type sampler_type); +char llama_sampling_type_to_chr(enum llama_constraint_type sampler_type); +std::string llama_sampling_type_to_str(enum llama_constraint_type sampler_type); -std::vector llama_sampling_types_from_names(const std::vector & names, bool allow_alt_names); -std::vector llama_sampling_types_from_chars(const std::string & chars); +std::vector llama_sampling_types_from_names(const std::vector & names, bool allow_alt_names); +std::vector llama_sampling_types_from_chars(const std::string & chars); diff --git a/include/llama.h b/include/llama.h index 76c1aaf98..7225874f7 100644 --- a/include/llama.h +++ b/include/llama.h @@ -46,6 +46,7 @@ #define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ #define LLAMA_STATE_SEQ_VERSION 2 +// TODO: remove before merge #define LLAMA_MAX_SAMPLERS 16 #ifdef __cplusplus @@ -209,14 +210,15 @@ extern "C" { LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs }; - enum llama_sampler_type { - LLAMA_SAMPLER_TYPE_NONE = 0, - LLAMA_SAMPLER_TYPE_TOP_K = 1, - LLAMA_SAMPLER_TYPE_TOP_P = 2, - LLAMA_SAMPLER_TYPE_MIN_P = 3, - LLAMA_SAMPLER_TYPE_TFS_Z = 4, - LLAMA_SAMPLER_TYPE_TYPICAL_P = 5, - LLAMA_SAMPLER_TYPE_TEMPERATURE = 6, + // TODO: move to common, rename to gpt_constraint_type + enum llama_constraint_type { + LLAMA_CONSTRAINT_TYPE_NONE = 0, + LLAMA_CONSTRAINT_TYPE_TOP_K = 1, + LLAMA_CONSTRAINT_TYPE_TOP_P = 2, + LLAMA_CONSTRAINT_TYPE_MIN_P = 3, + LLAMA_CONSTRAINT_TYPE_TFS_Z = 4, + LLAMA_CONSTRAINT_TYPE_TYPICAL_P = 5, + LLAMA_CONSTRAINT_TYPE_TEMPERATURE = 6, }; typedef struct llama_token_data { @@ -382,6 +384,7 @@ extern "C" { float bias; } llama_logit_bias; + // TODO: remove before merge // parameters for sampling the logits typedef struct llama_sampling_params { uint32_t seed; // the seed used to initialize llama_sampling_context @@ -406,7 +409,7 @@ extern "C" { // samplers int32_t n_samplers; - enum llama_sampler_type samplers[LLAMA_MAX_SAMPLERS]; + enum llama_constraint_type samplers[LLAMA_MAX_SAMPLERS]; // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value. bool penalize_nl; // consider newlines as a repeatable token @@ -414,7 +417,11 @@ extern "C" { } llama_sampling_params; typedef struct llama_sampler_params { - uint32_t seed; // the seed used to initialize the rng of the sampler + uint32_t seed; // the seed used to initialize the rng of the sampler + + int32_t mirostat; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + float mirostat_tau; // target entropy + float mirostat_eta; // learning rate // TODO: add type of sampler: greedy, dist, mirostat, etc. } llama_sampler_params; @@ -1176,6 +1183,8 @@ extern "C" { typedef void * llama_constraint_context_t; struct llama_constraint_i { + // TODO: add name API + void (*accept)(struct llama_constraint * cnstr, llama_token token); // can be NULL void (*apply) (struct llama_constraint * cnstr, llama_token_data_array * candidates); // required void (*reset) (struct llama_constraint * cnstr); // can be NULL @@ -1184,6 +1193,8 @@ extern "C" { // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph //void (*apply_ggml) (struct llama_constraint * cnstr, ...); + + // TODO: add API to get timing stats }; struct llama_constraint { @@ -1191,14 +1202,28 @@ extern "C" { llama_constraint_context_t ctx; }; - LLAMA_API struct llama_constraint * llama_constraint_init_top_k (int32_t k, int32_t min_keep); - LLAMA_API struct llama_constraint * llama_constraint_init_top_p (float p, int32_t min_keep); - LLAMA_API struct llama_constraint * llama_constraint_init_min_p (float p, int32_t min_keep); - LLAMA_API struct llama_constraint * llama_constraint_init_tail_free(float z, int32_t min_keep); - LLAMA_API struct llama_constraint * llama_constraint_init_typical (float p, int32_t min_keep); - LLAMA_API struct llama_constraint * llama_constraint_init_temp (float t); - LLAMA_API struct llama_constraint * llama_constraint_init_temp_ext (float t, float delta, float exponent); - LLAMA_API struct llama_constraint * llama_constraint_init_grammar (struct llama_model * model, const char * grammar_str, const char * grammar_root); + LLAMA_API struct llama_constraint * llama_constraint_init_top_k (int32_t k, int32_t min_keep); + LLAMA_API struct llama_constraint * llama_constraint_init_top_p (float p, int32_t min_keep); + LLAMA_API struct llama_constraint * llama_constraint_init_min_p (float p, int32_t min_keep); + LLAMA_API struct llama_constraint * llama_constraint_init_tail_free (float z, int32_t min_keep); + LLAMA_API struct llama_constraint * llama_constraint_init_typical (float p, int32_t min_keep); + LLAMA_API struct llama_constraint * llama_constraint_init_temp (float t); + LLAMA_API struct llama_constraint * llama_constraint_init_temp_ext (float t, float delta, float exponent); + LLAMA_API struct llama_constraint * llama_constraint_init_grammar (struct llama_model * model, const char * grammar_str, const char * grammar_root); + + LLAMA_API struct llama_constraint * llama_constraint_init_penalties( + struct llama_model * model, + int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size) + float penalty_repeat, // 1.0 = disabled + float penalty_freq, // 0.0 = disabled + float penalty_present, // 0.0 = disabled + bool penalize_nl, // consider newlines as a repeatable token + bool ignore_eos); // ignore the end-of-sequence token + + LLAMA_API struct llama_constraint * llama_constraint_init_logit_bias( + struct llama_model * model, + int32_t n_logit_bias, + const llama_logit_bias * logit_bias); // do not call if used with llama_sampler_add_constraint LLAMA_API void llama_constraint_free(struct llama_constraint * cnstr); @@ -1209,19 +1234,47 @@ extern "C" { // samplers - LLAMA_API struct llama_sampler * llama_sampler_init ( struct llama_sampler_params params); + LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_model * model, struct llama_sampler_params params); LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl); LLAMA_API struct llama_sampler * llama_sampler_cp (const struct llama_sampler * smpl); LLAMA_API void llama_sampler_reset( struct llama_sampler * smpl); + LLAMA_API void llama_sampler_set_logits(struct llama_sampler * smpl, const float * logits); + + LLAMA_API llama_token_data_array * llama_sampler_get_candidates(struct llama_sampler * smpl); + + // TODO: should this take ownership so the user does not need to call llama_constraint_free // or should just make a reference to the constraint so that it can be reused in multiple llama_sampler? // // seems better to take the ownership, otherwise the copying of the sampler will be more complicated LLAMA_API void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr); - LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token); - LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, const struct llama_context * ctx, int32_t i); + LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token); + LLAMA_API void llama_sampler_apply (struct llama_sampler * smpl, llama_token_data_array * candidates); + + LLAMA_API llama_token llama_sampler_sample_dist (struct llama_sampler * smpl, llama_token_data_array * candidates); + LLAMA_API llama_token llama_sampler_sample_greedy (struct llama_sampler * smpl, llama_token_data_array * candidates); + LLAMA_API llama_token llama_sampler_sample_mirostat(struct llama_sampler * smpl, llama_token_data_array * candidates); + + /// @details Get the number of accepted tokens so far (max of n_prev) + LLAMA_API int llama_sampler_n_prev(const struct llama_sampler * smpl); + + /// @details Get the ith accepted token + /// @param ith [0, n_prev), ith == 0 is the last accepted token. + /// returns LLAMA_TOKEN_NULL if ith is out of bounds + LLAMA_API llama_token llama_sampler_prev( + const struct llama_sampler * smpl, + int32_t ith); + + /// @details Get the last accepted token + /// Same as llama_sampler_prev(smpl, 0) + /// returns LLAMA_TOKEN_NULL if there are no accepted tokens + LLAMA_API llama_token llama_sampler_last(const struct llama_sampler * smpl); + + // TODO: extend in the future + //LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t i); + //LLAMA_API void llama_decode_with_sampler(struct llama_context * ctx, struct llama_sampler * smpl, struct llama_batch batch, ...); // // Model split diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index f6328b601..36bbc0c1b 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -676,7 +676,14 @@ struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k, size_t min struct llama_constraint * result = new llama_constraint; result->iface = &llama_constraint_top_k_i; - result->ctx = new llama_constraint_context_top_k{k, min_keep}; + result->ctx = new llama_constraint_context_top_k; + + auto * ctx = (llama_constraint_context_top_k *) result->ctx; + + *ctx = { + /*.k =*/ k, + /*.min_keep =*/ min_keep, + }; return result; } @@ -691,7 +698,7 @@ struct llama_constraint_context_top_p { static struct llama_constraint_i llama_constraint_top_p_i = { /* .accept = */ nullptr, /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { - llama_constraint_context_top_p * ctx = (llama_constraint_context_top_p *) cnstr->ctx; + auto * ctx = (llama_constraint_context_top_p *) cnstr->ctx; llama_sampling_top_p_impl(candidates, ctx->p, ctx->min_keep); }, /* .reset = */ nullptr, @@ -713,7 +720,14 @@ struct llama_constraint * llama_constraint_init_top_p_impl(float p, size_t min_k struct llama_constraint * result = new llama_constraint; result->iface = &llama_constraint_top_p_i; - result->ctx = new llama_constraint_context_top_p{p, min_keep}; + result->ctx = new llama_constraint_context_top_p; + + auto * ctx = (llama_constraint_context_top_p *) result->ctx; + + *ctx = { + /*.p =*/ p, + /*.min_keep =*/ min_keep, + }; return result; } @@ -728,7 +742,7 @@ struct llama_constraint_context_min_p { static struct llama_constraint_i llama_constraint_min_p_i = { /* .accept = */ nullptr, /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { - llama_constraint_context_min_p * ctx = (llama_constraint_context_min_p *) cnstr->ctx; + auto * ctx = (llama_constraint_context_min_p *) cnstr->ctx; llama_sampling_min_p_impl(candidates, ctx->p, ctx->min_keep); }, /* .reset = */ nullptr, @@ -750,7 +764,14 @@ struct llama_constraint * llama_constraint_init_min_p_impl(float p, size_t min_k struct llama_constraint * result = new llama_constraint; result->iface = &llama_constraint_min_p_i; - result->ctx = new llama_constraint_context_min_p{p, min_keep}; + result->ctx = new llama_constraint_context_min_p; + + auto * ctx = (llama_constraint_context_min_p *) result->ctx; + + *ctx = { + /*.p =*/ p, + /*.min_keep =*/ min_keep, + }; return result; } @@ -765,7 +786,7 @@ struct llama_constraint_context_tail_free { static struct llama_constraint_i llama_constraint_tail_free_i = { /* .accept = */ nullptr, /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { - llama_constraint_context_tail_free * ctx = (llama_constraint_context_tail_free *) cnstr->ctx; + auto * ctx = (llama_constraint_context_tail_free *) cnstr->ctx; llama_sampling_tail_free_impl(candidates, ctx->z, ctx->min_keep); }, /* .reset = */ nullptr, @@ -787,7 +808,14 @@ struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t m struct llama_constraint * result = new llama_constraint; result->iface = &llama_constraint_tail_free_i; - result->ctx = new llama_constraint_context_tail_free{z, min_keep}; + result->ctx = new llama_constraint_context_tail_free; + + auto * ctx = (llama_constraint_context_tail_free *) result->ctx; + + *ctx = { + /*.z =*/ z, + /*.min_keep =*/ min_keep, + }; return result; } @@ -802,7 +830,7 @@ struct llama_constraint_context_typical { static struct llama_constraint_i llama_constraint_typical_i = { /* .accept = */ nullptr, /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { - llama_constraint_context_typical * ctx = (llama_constraint_context_typical *) cnstr->ctx; + auto * ctx = (llama_constraint_context_typical *) cnstr->ctx; llama_sampling_typical_impl(candidates, ctx->p, ctx->min_keep); }, /* .reset = */ nullptr, @@ -824,7 +852,14 @@ struct llama_constraint * llama_constraint_init_typical_impl(float p, size_t min struct llama_constraint * result = new llama_constraint; result->iface = &llama_constraint_typical_i; - result->ctx = new llama_constraint_context_typical{p, min_keep}; + result->ctx = new llama_constraint_context_typical; + + auto * ctx = (llama_constraint_context_typical *) result->ctx; + + *ctx = { + /*.p =*/ p, + /*.min_keep =*/ min_keep, + }; return result; } @@ -838,7 +873,7 @@ struct llama_constraint_context_temp { static struct llama_constraint_i llama_constraint_temp_i = { /* .accept = */ nullptr, /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { - llama_constraint_context_temp * ctx = (llama_constraint_context_temp *) cnstr->ctx; + auto * ctx = (llama_constraint_context_temp *) cnstr->ctx; llama_sampling_temp_impl(candidates, ctx->temp); }, /* .reset = */ nullptr, @@ -860,7 +895,13 @@ struct llama_constraint * llama_constraint_init_temp_impl(float temp) { struct llama_constraint * result = new llama_constraint; result->iface = &llama_constraint_temp_i; - result->ctx = new llama_constraint_context_temp{temp}; + result->ctx = new llama_constraint_context_temp; + + auto * ctx = (llama_constraint_context_temp *) result->ctx; + + *ctx = { + /*.temp =*/ temp, + }; return result; } @@ -876,7 +917,7 @@ struct llama_constraint_context_temp_ext { static struct llama_constraint_i llama_constraint_temp_ext_i = { /* .accept = */ nullptr, /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { - llama_constraint_context_temp_ext * ctx = (llama_constraint_context_temp_ext *) cnstr->ctx; + auto * ctx = (llama_constraint_context_temp_ext *) cnstr->ctx; if (ctx->delta > 0) { const float temp_min = std::max(0.0f, ctx->temp - ctx->delta); const float temp_max = ctx->temp + ctx->delta; @@ -905,7 +946,15 @@ struct llama_constraint * llama_constraint_init_temp_ext_impl(float temp, float struct llama_constraint * result = new llama_constraint; result->iface = &llama_constraint_temp_ext_i; - result->ctx = new llama_constraint_context_temp_ext{temp, delta, exponent}; + result->ctx = new llama_constraint_context_temp_ext; + + auto * ctx = (llama_constraint_context_temp_ext *) result->ctx; + + *ctx = { + /*.temp =*/ temp, + /*.delta =*/ delta, + /*.exponent =*/ exponent, + }; return result; } @@ -920,15 +969,20 @@ struct llama_constraint_context_grammar { }; static struct llama_constraint_i llama_constraint_grammar_i = { - /* .accept = */ nullptr, + /* .accept = */ [](struct llama_constraint * cnstr, llama_token token) { + auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx; + if (ctx->grammar) { + llama_grammar_accept_impl(*ctx->grammar, token); + } + }, /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { - llama_constraint_context_grammar * ctx = (llama_constraint_context_grammar *) cnstr->ctx; + auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx; if (ctx->grammar) { llama_sampling_grammar_impl(candidates, *ctx->grammar); } }, /* .reset = */ [](struct llama_constraint * cnstr) { - llama_constraint_context_grammar * ctx = (llama_constraint_context_grammar *) cnstr->ctx; + auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx; if (ctx->grammar) { llama_grammar_free_impl(ctx->grammar); ctx->grammar = nullptr; @@ -973,20 +1027,173 @@ struct llama_constraint * llama_constraint_init_grammar_impl(const struct llama_ auto * ctx = (llama_constraint_context_grammar *) result->ctx; if (grammar_str != nullptr && grammar_str[0] != '\0') { - ctx->grammar_str = grammar_str; - ctx->grammar_root = grammar_root; - - ctx->grammar = llama_grammar_init_impl(&vocab, grammar_str, grammar_root); + *ctx = { + /*.grammar_str = */ grammar_str, + /*.grammar_root = */ grammar_root, + /*.grammar = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root), + }; } else { - ctx->grammar_str.clear(); - ctx->grammar_root.clear(); - - ctx->grammar = nullptr; + *ctx = { + /*.grammar_str = */ {}, + /*.grammar_root = */ {}, + /*.grammar = */ nullptr, + }; } return result; } +// penalties + +struct llama_constraint_context_penalties { + const struct llama_vocab * vocab; + + int32_t penalty_last_n; + float penalty_repeat; + float penalty_freq; + float penalty_present; + + bool penalize_nl; + bool ignore_eos; + + ring_buffer prev; +}; + +static struct llama_constraint_i llama_constraint_penalties_i = { + /* .accept = */ [](struct llama_constraint * cnstr, llama_token token) { + auto * ctx = (llama_constraint_context_penalties *) cnstr->ctx; + ctx->prev.push_back(token); + }, + /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { + auto * ctx = (llama_constraint_context_penalties *) cnstr->ctx; + + GGML_ASSERT(candidates->size == ctx->vocab->n_vocab && candidates->sorted == false && "the 'penalties' constraint must be applied on the full vocabulary"); + + if (ctx->ignore_eos) { + candidates->data[ctx->vocab->special_eos_id].logit = -INFINITY; + } + + if ((ctx->penalty_last_n == 0) || + (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) { + return; + } + + const float nl_logit = !ctx->penalize_nl ? candidates->data[ctx->vocab->linefeed_id].logit : -INFINITY; + + // Create a frequency map to count occurrences of each token in last_tokens + // TODO: optimize this by maintaining the token count in the constraint context + llama_token_cnt token_count; + for (int i = 0; i < ctx->penalty_last_n; ++i) { + token_count[ctx->prev.rat(i)]++; + } + + llama_sampling_penalties_impl(candidates, token_count, ctx->penalty_repeat, ctx->penalty_freq, ctx->penalty_present); + + if (!ctx->penalize_nl) { + // restore the logit of the newline token if it was penalized + candidates->data[ctx->vocab->linefeed_id].logit = nl_logit; + } + }, + /* .reset = */ [](struct llama_constraint * cnstr) { + auto * ctx = (llama_constraint_context_penalties *) cnstr->ctx; + ctx->prev.clear(); + }, + /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { + cnstr->ctx = new llama_constraint_context_penalties; + const auto * ctx_src = (const llama_constraint_context_penalties *) cnstr_src->ctx; + auto * ctx_dst = ( llama_constraint_context_penalties *) cnstr->ctx; + + *ctx_dst = *ctx_src; + }, + /* .free = */ [](struct llama_constraint * cnstr) { + if (cnstr->ctx) { + delete (llama_constraint_context_penalties *) cnstr->ctx; + } + delete cnstr; + } +}; + +struct llama_constraint * llama_constraint_init_penalties_impl(const struct llama_vocab & vocab, int32_t penalty_last_n, float penalty_repeat, float penalty_freq, float penalty_present, bool penalize_nl, bool ignore_eos) { + GGML_ASSERT(penalize_nl || vocab.linefeed_id != LLAMA_TOKEN_NULL); + GGML_ASSERT(!ignore_eos || vocab.special_eos_id != LLAMA_TOKEN_NULL); + + struct llama_constraint * result = new llama_constraint; + + result->iface = &llama_constraint_penalties_i; + result->ctx = new llama_constraint_context_penalties; + + auto * ctx = (llama_constraint_context_penalties *) result->ctx; + + *ctx = { + /*.vocab = */ &vocab, + /*.penalty_last_n = */ penalty_last_n, + /*.penalty_repeat = */ penalty_repeat, + /*.penalty_freq = */ penalty_freq, + /*.penalty_present = */ penalty_present, + /*.penalize_nl = */ penalize_nl, + /*.ignore_eos = */ ignore_eos, + /*.prev = */ {}, + }; + + return result; +} + +// logit-bias + +struct llama_constraint_context_logit_bias { + const struct llama_vocab * vocab; + + std::vector logit_bias; +}; + +static struct llama_constraint_i llama_constraint_logit_bias_i = { + /* .accept = */ nullptr, + /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { + auto * ctx = (llama_constraint_context_logit_bias *) cnstr->ctx; + + GGML_ASSERT(candidates->size == ctx->vocab->n_vocab && candidates->sorted == false && "the 'logit_bias' constraint must be applied on the full vocabulary"); + + for (const auto & lb : ctx->logit_bias) { + candidates->data[lb.token].logit += lb.bias; + } + }, + /* .reset = */ nullptr, + /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { + cnstr->ctx = new llama_constraint_context_logit_bias; + const auto * ctx_src = (const llama_constraint_context_logit_bias *) cnstr_src->ctx; + auto * ctx_dst = ( llama_constraint_context_logit_bias *) cnstr->ctx; + + *ctx_dst = *ctx_src; + }, + /* .free = */ [](struct llama_constraint * cnstr) { + if (cnstr->ctx) { + delete (llama_constraint_context_logit_bias *) cnstr->ctx; + } + delete cnstr; + } +}; + +struct llama_constraint * llama_constraint_init_logit_bias_impl( + const struct llama_vocab & vocab, + int32_t n_logit_bias, + const llama_logit_bias * logit_bias) { + struct llama_constraint * result = new llama_constraint; + + result->iface = &llama_constraint_logit_bias_i; + result->ctx = new llama_constraint_context_logit_bias; + + auto * ctx = (llama_constraint_context_logit_bias *) result->ctx; + + *ctx = { + /*.vocab = */ &vocab, + /*.logit_bias = */ std::vector(logit_bias, logit_bias + n_logit_bias), + }; + + return result; +} + +//////////////////////////////////////// + void llama_constraint_free_impl(struct llama_constraint * cnstr) { if (cnstr->iface->free && cnstr) { cnstr->iface->free(cnstr); @@ -1012,10 +1219,11 @@ void llama_constraint_reset_impl(struct llama_constraint & cnstr) { // samplers -struct llama_sampler * llama_sampler_init_impl(struct llama_sampler_params params) { +struct llama_sampler * llama_sampler_init_impl(const struct llama_vocab & vocab, struct llama_sampler_params params) { auto * result = new llama_sampler; result->params = params; + result->vocab = &vocab; result->rng.seed(params.seed); @@ -1075,3 +1283,22 @@ void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token) { llama_constraint_accept_impl(*cnstr, token); } } + +void llama_sampler_apply_impl(struct llama_sampler & smpl, struct llama_token_data_array * candidates) { + for (auto * cnstr : smpl.constraints) { + llama_constraint_apply_impl(*cnstr, candidates); + } +} + +llama_token llama_sampler_prev_impl(const struct llama_sampler & smpl, int ith) { + if (ith < 0 || ith >= (int) smpl.prev.size()) { + return LLAMA_TOKEN_NULL; + } + + return smpl.prev.rat(ith); +} + +int llama_sampler_n_prev_impl(const struct llama_sampler & smpl) { + return smpl.prev.size(); +} + diff --git a/src/llama-sampling.h b/src/llama-sampling.h index ed18da10d..7de37c89e 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -10,6 +10,7 @@ struct llama_grammar; using llama_token_cnt = std::unordered_map; +// TODO: remove before merge struct llama_sampling { llama_sampling(const struct llama_vocab & vocab); ~llama_sampling(); @@ -27,7 +28,7 @@ struct llama_sampling { const struct llama_vocab & vocab; - std::vector samplers; + std::vector samplers; ring_buffer prev; @@ -120,7 +121,25 @@ struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t struct llama_constraint * llama_constraint_init_typical_impl (float p, size_t min_keep); struct llama_constraint * llama_constraint_init_temp_impl (float t); struct llama_constraint * llama_constraint_init_temp_ext_impl (float t, float delta, float exponent); -struct llama_constraint * llama_constraint_init_grammar_impl (const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root); + +struct llama_constraint * llama_constraint_init_grammar_impl ( + const struct llama_vocab & vocab, + const char * grammar_str, + const char * grammar_root); + +struct llama_constraint * llama_constraint_init_penalties_impl( + const struct llama_vocab & vocab, + int32_t penalty_last_n, + float penalty_repeat, + float penalty_freq, + float penalty_present, + bool penalize_nl, + bool ignore_eos); + + LLAMA_API struct llama_constraint * llama_constraint_init_logit_bias_impl( + const struct llama_vocab & vocab, + int32_t n_logit_bias, + const llama_logit_bias * logit_bias); void llama_constraint_free_impl(struct llama_constraint * cnstr); @@ -133,15 +152,22 @@ void llama_constraint_reset_impl (struct llama_constraint & cnstr); struct llama_sampler { llama_sampler_params params; + const struct llama_vocab * vocab; + // state std::mt19937 rng; - // TODO: move to a standalone penalty constraint? + float mirostat_mu; + ring_buffer prev; std::vector constraints; + std::vector cur; + + llama_token_data_array cur_p; + // timing mutable int64_t t_sample_us = 0; @@ -149,10 +175,15 @@ struct llama_sampler { mutable int32_t n_sample = 0; }; -struct llama_sampler * llama_sampler_init_impl ( struct llama_sampler_params params); +struct llama_sampler * llama_sampler_init_impl (const struct llama_vocab & vocab, struct llama_sampler_params params); void llama_sampler_free_impl ( struct llama_sampler * smpl); struct llama_sampler * llama_sampler_cp_impl (const struct llama_sampler & smpl); void llama_sampler_reset_impl( struct llama_sampler & smpl); void llama_sampler_add_constraint_impl(struct llama_sampler & smpl, struct llama_constraint * cnstr); -void llama_sampler_accept_impl (struct llama_sampler & smpl, llama_token token); + +void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token); +void llama_sampler_apply_impl (struct llama_sampler & smpl, struct llama_token_data_array * candidates); + +llama_token llama_sampler_prev_impl (const struct llama_sampler & smpl, int ith); +int llama_sampler_n_prev_impl(const struct llama_sampler & smpl); diff --git a/src/llama.cpp b/src/llama.cpp index a47ad0103..4060fa1de 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -17938,6 +17938,9 @@ struct llama_context_params llama_context_default_params() { struct llama_sampler_params llama_sampler_default_params() { struct llama_sampler_params result = { /*.seed =*/ LLAMA_DEFAULT_SEED, + /*.mirostat =*/ 0, + /*.mirostat_tau =*/ 5.00f, + /*.mirostat_eta =*/ 0.10f, }; return result; @@ -17965,7 +17968,7 @@ struct llama_sampling_params llama_sampling_default_params() { /*.mirostat_tau =*/ 5.00f, /*.mirostat_eta =*/ 0.10f, /*.n_samplers =*/ 3, - /*.samplers =*/ { LLAMA_SAMPLER_TYPE_TEMPERATURE, LLAMA_SAMPLER_TYPE_TOP_K, LLAMA_SAMPLER_TYPE_TOP_P, }, + /*.samplers =*/ { LLAMA_CONSTRAINT_TYPE_TEMPERATURE, LLAMA_CONSTRAINT_TYPE_TOP_K, LLAMA_CONSTRAINT_TYPE_TOP_P, }, /*.penalize_nl =*/ false, /*.ignore_eos =*/ false, }; @@ -20916,12 +20919,12 @@ llama_token llama_sampling_sample(struct llama_sampling * smpl, llama_token_data } else { for (const auto & sampler : smpl->samplers) { switch (sampler) { - case LLAMA_SAMPLER_TYPE_TOP_K: llama_sampling_top_k_impl (cur_p, smpl->params.top_k, smpl->params.min_keep); break; - case LLAMA_SAMPLER_TYPE_TFS_Z: llama_sampling_tail_free_impl(cur_p, smpl->params.tfs_z, smpl->params.min_keep); break; - case LLAMA_SAMPLER_TYPE_TYPICAL_P: llama_sampling_typical_impl (cur_p, smpl->params.typ_p, smpl->params.min_keep); break; - case LLAMA_SAMPLER_TYPE_TOP_P: llama_sampling_top_p_impl (cur_p, smpl->params.top_p, smpl->params.min_keep); break; - case LLAMA_SAMPLER_TYPE_MIN_P: llama_sampling_min_p_impl (cur_p, smpl->params.min_p, smpl->params.min_keep); break; - case LLAMA_SAMPLER_TYPE_TEMPERATURE: llama_sampling_temp_impl (cur_p, temp); break; + case LLAMA_CONSTRAINT_TYPE_TOP_K: llama_sampling_top_k_impl (cur_p, smpl->params.top_k, smpl->params.min_keep); break; + case LLAMA_CONSTRAINT_TYPE_TFS_Z: llama_sampling_tail_free_impl(cur_p, smpl->params.tfs_z, smpl->params.min_keep); break; + case LLAMA_CONSTRAINT_TYPE_TYPICAL_P: llama_sampling_typical_impl (cur_p, smpl->params.typ_p, smpl->params.min_keep); break; + case LLAMA_CONSTRAINT_TYPE_TOP_P: llama_sampling_top_p_impl (cur_p, smpl->params.top_p, smpl->params.min_keep); break; + case LLAMA_CONSTRAINT_TYPE_MIN_P: llama_sampling_min_p_impl (cur_p, smpl->params.min_p, smpl->params.min_keep); break; + case LLAMA_CONSTRAINT_TYPE_TEMPERATURE: llama_sampling_temp_impl (cur_p, temp); break; default : break; } } @@ -21007,6 +21010,24 @@ struct llama_constraint * llama_constraint_init_grammar(struct llama_model * mod return llama_constraint_init_grammar_impl(model->vocab, grammar_str, grammar_root); } +struct llama_constraint * llama_constraint_init_penalties( + struct llama_model * model, + int32_t penalty_last_n, + float penalty_repeat, + float penalty_freq, + float penalty_present, + bool penalize_nl, + bool ignore_eos) { + return llama_constraint_init_penalties_impl(model->vocab, penalty_last_n, penalty_repeat, penalty_freq, penalty_present, penalize_nl, ignore_eos); +} + +LLAMA_API struct llama_constraint * llama_constraint_init_logit_bias( + struct llama_model * model, + int32_t n_logit_bias, + const llama_logit_bias * logit_bias) { + return llama_constraint_init_logit_bias_impl(model->vocab, n_logit_bias, logit_bias); +} + void llama_constraint_free(struct llama_constraint * cnstr) { if (cnstr == nullptr) { return; @@ -21027,8 +21048,8 @@ void llama_constraint_reset(struct llama_constraint * cnstr) { llama_constraint_reset_impl(*cnstr); } -struct llama_sampler * llama_sampler_init(struct llama_sampler_params params) { - return llama_sampler_init_impl(params); +struct llama_sampler * llama_sampler_init(const struct llama_model * model, struct llama_sampler_params params) { + return llama_sampler_init_impl(model->vocab, params); } void llama_sampler_free(struct llama_sampler * smpl) { @@ -21047,6 +21068,22 @@ void llama_sampler_reset(struct llama_sampler * smpl) { llama_sampler_reset_impl(*smpl); } +void llama_sampler_set_logits(struct llama_sampler * smpl, const float * logits) { + const int n_vocab = smpl->vocab->n_vocab; + + smpl->cur.resize(n_vocab); + + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + smpl->cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + } + + smpl->cur_p = { smpl->cur.data(), smpl->cur.size(), false }; +} + +llama_token_data_array * llama_sampler_get_candidates(struct llama_sampler * smpl) { + return &smpl->cur_p; +} + void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr) { llama_sampler_add_constraint_impl(*smpl, cnstr); } @@ -21055,10 +21092,90 @@ void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) { llama_sampler_accept_impl(*smpl, token); } -llama_token llama_sampler_sample(struct llama_sampler * smpl, const struct llama_context * ctx, int32_t i) { - GGML_ABORT("not implemented"); +void llama_sampler_apply(struct llama_sampler * smpl, llama_token_data_array * candidates) { + time_meas tm(smpl->t_sample_us); + + llama_sampler_apply_impl(*smpl, candidates); } +llama_token llama_sampler_sample_mirostat(struct llama_sampler * smpl, llama_token_data_array * candidates) { + time_meas tm(smpl->t_sample_us); + + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + + const auto type = smpl->params.mirostat; + + llama_token res; + + if (type == 1) { + res = llama_sampling_sample_mirostat_impl(candidates, + smpl->rng, + smpl->params.mirostat_tau, + smpl->params.mirostat_eta, + 100, + smpl->vocab->n_vocab, + smpl->mirostat_mu); + } else if (type == 2) { + res = llama_sampling_sample_mirostat_v2_impl(candidates, + smpl->rng, + smpl->params.mirostat_tau, + smpl->params.mirostat_eta, + smpl->mirostat_mu); + } else { + GGML_ABORT("invalid mirostat type: %d", type); + } + + smpl->n_sample++; + + return res; +} + +llama_token llama_sampler_sample_greedy(struct llama_sampler * smpl, llama_token_data_array * candidates) { + time_meas tm(smpl->t_sample_us); + + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + + auto res = llama_sampling_sample_greedy_impl(candidates); + + smpl->n_sample++; + + return res; +} + +llama_token llama_sampler_sample_dist(struct llama_sampler * smpl, llama_token_data_array * candidates) { + time_meas tm(smpl->t_sample_us); + + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + + auto res = llama_sampling_sample_dist_impl(candidates, smpl->rng); + + smpl->n_sample++; + + return res; +} + +int llama_sampler_n_prev(const struct llama_sampler * smpl) { + return llama_sampler_n_prev_impl(*smpl); +} + +llama_token llama_sampler_prev(const struct llama_sampler * smpl, int32_t ith) { + return llama_sampler_prev_impl(*smpl, ith); +} + +llama_token llama_sampler_last(const struct llama_sampler * smpl) { + return llama_sampler_prev_impl(*smpl, 0); +} + +//llama_token llama_sampler_sample(struct llama_sampler * smpl, const struct llama_context * ctx, int32_t i) { +// GGML_ABORT("not implemented"); +//} + // // model split //