diff --git a/common/sampling.cpp b/common/sampling.cpp index edc6cd05b..2887207f1 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -67,7 +67,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st for (const auto & cnstr : params.constraints) { switch (cnstr) { case GPT_CONSTRAINT_TYPE_TOP_K: - llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_k (params.top_k, params.min_keep)); + llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_k (params.top_k)); break; case GPT_CONSTRAINT_TYPE_TOP_P: llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_p (params.top_p, params.min_keep)); diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index 6ff62ae06..a02fa4da9 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -61,7 +61,7 @@ defer { llama_sampler_free(smpl) } -llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(40, 1)); +llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(40)); llama_sampler_constraint_add(smpl, llama_constraint_init_top_p(0.9, 1)); llama_sampler_constraint_add(smpl, llama_constraint_init_temp (0.4)); diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 9f0a40873..5896526ab 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -70,7 +70,7 @@ int main(int argc, char ** argv) { llama_sampler * smpl = llama_sampler_init(model, sparams); - llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(params.sparams.top_k, params.sparams.min_keep)); + llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(params.sparams.top_k)); llama_sampler_constraint_add(smpl, llama_constraint_init_top_p(params.sparams.top_p, params.sparams.min_keep)); llama_sampler_constraint_add(smpl, llama_constraint_init_temp (params.sparams.temp)); diff --git a/include/llama.h b/include/llama.h index 763d32abb..02f7a8491 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1045,7 +1045,7 @@ extern "C" { }; LLAMA_API struct llama_constraint * llama_constraint_init_softmax (void); - LLAMA_API struct llama_constraint * llama_constraint_init_top_k (int32_t k, int32_t min_keep); + LLAMA_API struct llama_constraint * llama_constraint_init_top_k (int32_t k); LLAMA_API struct llama_constraint * llama_constraint_init_top_p (float p, int32_t min_keep); LLAMA_API struct llama_constraint * llama_constraint_init_min_p (float p, int32_t min_keep); LLAMA_API struct llama_constraint * llama_constraint_init_tail_free (float z, int32_t min_keep); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 4e44ec417..c07c509bc 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -49,7 +49,7 @@ static void llama_constraint_softmax_impl(llama_token_data_array * cur_p) { } } -static void llama_constraint_top_k_impl(llama_token_data_array * cur_p, int32_t k, size_t min_keep) { +static void llama_constraint_top_k_impl(llama_token_data_array * cur_p, int32_t k) { // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast // if (k >= (int32_t)cur_p->size) { // return; @@ -59,7 +59,6 @@ static void llama_constraint_top_k_impl(llama_token_data_array * cur_p, int32_t k = cur_p->size; } - k = std::max(k, (int) min_keep); k = std::min(k, (int) cur_p->size); // Sort scores in descending order @@ -449,7 +448,6 @@ struct llama_constraint * llama_constraint_init_softmax_impl() { struct llama_constraint_context_top_k { const int32_t k; - const size_t min_keep; }; static struct llama_constraint_i llama_constraint_top_k_i = { @@ -457,24 +455,23 @@ static struct llama_constraint_i llama_constraint_top_k_i = { /* .accept = */ nullptr, /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { const auto * ctx = (llama_constraint_context_top_k *) cnstr->ctx; - llama_constraint_top_k_impl(cur_p, ctx->k, ctx->min_keep); + llama_constraint_top_k_impl(cur_p, ctx->k); }, /* .reset = */ nullptr, /* .copy = */ [](const struct llama_constraint * cnstr) { const auto * ctx = (const llama_constraint_context_top_k *) cnstr->ctx; - return llama_constraint_init_top_k_impl(ctx->k, ctx->min_keep); + return llama_constraint_init_top_k_impl(ctx->k); }, /* .free = */ [](struct llama_constraint * cnstr) { delete (llama_constraint_context_top_k *) cnstr->ctx; }, }; -struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k, size_t min_keep) { +struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k) { return new llama_constraint { /* .iface = */ &llama_constraint_top_k_i, /* .ctx = */ new llama_constraint_context_top_k { - /*.k =*/ k, - /*.min_keep =*/ min_keep, + /* .k = */ k, }, }; } @@ -507,8 +504,8 @@ struct llama_constraint * llama_constraint_init_top_p_impl(float p, size_t min_k return new llama_constraint { /* .iface = */ &llama_constraint_top_p_i, /* .ctx = */ new llama_constraint_context_top_p { - /*.p =*/ p, - /*.min_keep =*/ min_keep, + /* .p = */ p, + /* .min_keep = */ min_keep, }, }; } @@ -541,8 +538,8 @@ struct llama_constraint * llama_constraint_init_min_p_impl(float p, size_t min_k return new llama_constraint { /* .iface = */ &llama_constraint_min_p_i, /* .ctx = */ new llama_constraint_context_min_p { - /*.p =*/ p, - /*.min_keep =*/ min_keep, + /* .p = */ p, + /* .min_keep = */ min_keep, }, }; } @@ -575,8 +572,8 @@ struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t m return new llama_constraint { /* .iface = */ &llama_constraint_tail_free_i, /* .ctx = */ new llama_constraint_context_tail_free { - /*.z =*/ z, - /*.min_keep =*/ min_keep, + /* .z = */ z, + /*. min_keep = */ min_keep, }, }; } @@ -609,8 +606,8 @@ struct llama_constraint * llama_constraint_init_typical_impl(float p, size_t min return new llama_constraint { /* .iface = */ &llama_constraint_typical_i, /* .ctx = */ new llama_constraint_context_typical { - /*.p =*/ p, - /*.min_keep =*/ min_keep, + /* .p = */ p, + /* .min_keep = */ min_keep, }, }; } @@ -642,7 +639,7 @@ struct llama_constraint * llama_constraint_init_temp_impl(float temp) { return new llama_constraint { /* .iface = */ &llama_constraint_temp_i, /* .ctx = */ new llama_constraint_context_temp { - /*.temp =*/ temp, + /*.temp = */ temp, }, }; } @@ -683,9 +680,9 @@ struct llama_constraint * llama_constraint_init_temp_ext_impl(float temp, float return new llama_constraint { /* .iface = */ &llama_constraint_temp_ext_i, /* .ctx = */ new llama_constraint_context_temp_ext { - /*.temp =*/ temp, - /*.delta =*/ delta, - /*.exponent =*/ exponent, + /* .temp = */ temp, + /* .delta = */ delta, + /* .exponent = */ exponent, }, }; } @@ -745,7 +742,7 @@ static struct llama_constraint_i llama_constraint_mirostat_i = { float epsilon_hat = s_hat - 1; float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->vocab->n_vocab, -epsilon_hat)), 1 / s_hat); - llama_constraint_top_k_impl(cur_p, int(k), 1); + llama_constraint_top_k_impl(cur_p, std::max(int(k), 1)); // remember the order to be able to compute the distance later when accepting the token ctx->cur.resize(cur_p->size); @@ -755,7 +752,7 @@ static struct llama_constraint_i llama_constraint_mirostat_i = { }, /* .reset = */ [](struct llama_constraint * cnstr) { auto * ctx = (llama_constraint_context_mirostat *) cnstr->ctx; - ctx->mu = 0.0f; + ctx->mu = 2.0f*ctx->tau; }, /* .copy = */ [](const struct llama_constraint * cnstr) { const auto * ctx = (const llama_constraint_context_mirostat *) cnstr->ctx; @@ -770,12 +767,12 @@ struct llama_constraint * llama_constraint_init_mirostat_impl(const struct llama return new llama_constraint { /* .iface = */ &llama_constraint_mirostat_i, /* .ctx = */ new llama_constraint_context_mirostat { - /*.vocab =*/ &vocab, - /*.tau =*/ tau, - /*.eta =*/ eta, - /*.m =*/ m, - /*.mu =*/ 0.0f, - /*.cur =*/ {}, + /* .vocab = */ &vocab, + /* .tau = */ tau, + /* .eta = */ eta, + /* .m = */ m, + /* .mu = */ 2.0f*tau, + /* .cur = */ {}, }, }; } @@ -826,10 +823,16 @@ static struct llama_constraint_i llama_constraint_mirostat_v2_i = { // Normalize the probabilities of the remaining words llama_constraint_softmax_impl(cur_p); + + // remember the order to be able to compute the distance later when accepting the token + ctx->cur.resize(cur_p->size); + for (size_t i = 0; i < cur_p->size; ++i) { + ctx->cur[i] = cur_p->data[i]; + } }, /* .reset = */ [](struct llama_constraint * cnstr) { auto * ctx = (llama_constraint_context_mirostat_v2 *) cnstr->ctx; - ctx->mu = 0.0f; + ctx->mu = 2.0f*ctx->tau; }, /* .copy = */ [](const struct llama_constraint * cnstr) { const auto * ctx = (const llama_constraint_context_mirostat_v2 *) cnstr->ctx; @@ -844,10 +847,10 @@ struct llama_constraint * llama_constraint_init_mirostat_v2_impl(float tau, floa return new llama_constraint { /* .iface = */ &llama_constraint_mirostat_v2_i, /* .ctx = */ new llama_constraint_context_mirostat_v2 { - /*.tau =*/ tau, - /*.eta =*/ eta, - /*.mu =*/ 0.0f, - /*.cur =*/ {}, + /* .tau = */ tau, + /* .eta = */ eta, + /* .mu = */ 2.0f*tau, + /* .cur = */ {}, }, }; } @@ -919,17 +922,17 @@ struct llama_constraint * llama_constraint_init_grammar_impl(const struct llama_ if (grammar_str != nullptr && grammar_str[0] != '\0') { *ctx = { - /*.vocab = */ &vocab, - /*.grammar_str = */ grammar_str, - /*.grammar_root = */ grammar_root, - /*.grammar = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root), + /* .vocab = */ &vocab, + /* .grammar_str = */ grammar_str, + /* .grammar_root = */ grammar_root, + /* .grammar = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root), }; } else { *ctx = { - /*.vocab = */ &vocab, - /*.grammar_str = */ {}, - /*.grammar_root = */ {}, - /*.grammar = */ nullptr, + /* .vocab = */ &vocab, + /* .grammar_str = */ {}, + /* .grammar_root = */ {}, + /* .grammar = */ nullptr, }; } @@ -1023,14 +1026,14 @@ struct llama_constraint * llama_constraint_init_penalties_impl(const struct llam return new llama_constraint { /* .iface = */ &llama_constraint_penalties_i, /* .ctx = */ new llama_constraint_context_penalties { - /*.vocab =*/ &vocab, - /*.penalty_last_n =*/ penalty_last_n, - /*.penalty_repeat =*/ penalty_repeat, - /*.penalty_freq =*/ penalty_freq, - /*.penalty_present =*/ penalty_present, - /*.penalize_nl =*/ penalize_nl, - /*.ignore_eos =*/ ignore_eos, - /*.prev =*/ ring_buffer(penalty_last_n), + /* .vocab = */ &vocab, + /* .penalty_last_n = */ penalty_last_n, + /* .penalty_repeat = */ penalty_repeat, + /* .penalty_freq = */ penalty_freq, + /* .penalty_present = */ penalty_present, + /* .penalize_nl = */ penalize_nl, + /* .ignore_eos = */ ignore_eos, + /* .prev = */ ring_buffer(penalty_last_n), }, }; } @@ -1072,8 +1075,8 @@ struct llama_constraint * llama_constraint_init_logit_bias_impl( return new llama_constraint { /* .iface = */ &llama_constraint_logit_bias_i, /* .ctx = */ new llama_constraint_context_logit_bias { - /*.vocab =*/ &vocab, - /*.logit_bias=*/ std::vector(logit_bias, logit_bias + n_logit_bias), + /* .vocab = */ &vocab, + /* .logit_bias = */ std::vector(logit_bias, logit_bias + n_logit_bias), }, }; } diff --git a/src/llama-sampling.h b/src/llama-sampling.h index acb1e04fe..1295bc823 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -21,7 +21,7 @@ void llama_constraint_penalties_impl( // constraints struct llama_constraint * llama_constraint_init_softmax_impl (); -struct llama_constraint * llama_constraint_init_top_k_impl (int32_t k, size_t min_keep); +struct llama_constraint * llama_constraint_init_top_k_impl (int32_t k); struct llama_constraint * llama_constraint_init_top_p_impl (float p, size_t min_keep); struct llama_constraint * llama_constraint_init_min_p_impl (float p, size_t min_keep); struct llama_constraint * llama_constraint_init_tail_free_impl (float z, size_t min_keep); diff --git a/src/llama.cpp b/src/llama.cpp index 2f5df2433..6a30daf39 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -20611,8 +20611,8 @@ struct llama_constraint * llama_constraint_init_softmax(void) { return llama_constraint_init_softmax_impl(); } -struct llama_constraint * llama_constraint_init_top_k(int32_t k, int32_t min_keep) { - return llama_constraint_init_top_k_impl(k, min_keep); +struct llama_constraint * llama_constraint_init_top_k(int32_t k) { + return llama_constraint_init_top_k_impl(k); } struct llama_constraint * llama_constraint_init_top_p(float p, int32_t min_keep) { diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index 0c9b46429..74bb4a3a3 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -19,7 +19,7 @@ static void dump(const llama_token_data_array * cur_p) { #define DUMP(__cur_p) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__cur_p)); printf("-\n"); } while(0) -#define TEST(__cnstr, __cur_p) do { \ +#define APPLY(__cnstr, __cur_p) do { \ auto * cnstr = (__cnstr); \ llama_constraint_apply(cnstr, (__cur_p)); \ llama_constraint_free(cnstr); \ @@ -36,9 +36,9 @@ static void test_top_k(const std::vector & probs, const std::vector & probs, const std::vector & probs, const std::vector llama_token_data_array cur_p = { cur.data(), cur.size(), false }; DUMP(&cur_p); - TEST(llama_constraint_init_tail_free(z, 1), &cur_p); + APPLY(llama_constraint_init_tail_free(z, 1), &cur_p); DUMP(&cur_p); GGML_ASSERT(cur_p.size == expected_probs.size()); @@ -102,9 +102,9 @@ static void test_min_p(const std::vector & probs, const std::vector & probs, const std::vector