sampling : remove top-k min_keep, fix mirostat init and state

This commit is contained in:
Georgi Gerganov 2024-09-05 10:18:04 +03:00
parent b2b36e9e95
commit 69551ffd60
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
8 changed files with 76 additions and 73 deletions

View file

@ -67,7 +67,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
for (const auto & cnstr : params.constraints) { for (const auto & cnstr : params.constraints) {
switch (cnstr) { switch (cnstr) {
case GPT_CONSTRAINT_TYPE_TOP_K: case GPT_CONSTRAINT_TYPE_TOP_K:
llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_k (params.top_k, params.min_keep)); llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_k (params.top_k));
break; break;
case GPT_CONSTRAINT_TYPE_TOP_P: case GPT_CONSTRAINT_TYPE_TOP_P:
llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_p (params.top_p, params.min_keep)); llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_p (params.top_p, params.min_keep));

View file

@ -61,7 +61,7 @@ defer {
llama_sampler_free(smpl) llama_sampler_free(smpl)
} }
llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(40, 1)); llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(40));
llama_sampler_constraint_add(smpl, llama_constraint_init_top_p(0.9, 1)); llama_sampler_constraint_add(smpl, llama_constraint_init_top_p(0.9, 1));
llama_sampler_constraint_add(smpl, llama_constraint_init_temp (0.4)); llama_sampler_constraint_add(smpl, llama_constraint_init_temp (0.4));

View file

@ -70,7 +70,7 @@ int main(int argc, char ** argv) {
llama_sampler * smpl = llama_sampler_init(model, sparams); llama_sampler * smpl = llama_sampler_init(model, sparams);
llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(params.sparams.top_k, params.sparams.min_keep)); llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(params.sparams.top_k));
llama_sampler_constraint_add(smpl, llama_constraint_init_top_p(params.sparams.top_p, params.sparams.min_keep)); llama_sampler_constraint_add(smpl, llama_constraint_init_top_p(params.sparams.top_p, params.sparams.min_keep));
llama_sampler_constraint_add(smpl, llama_constraint_init_temp (params.sparams.temp)); llama_sampler_constraint_add(smpl, llama_constraint_init_temp (params.sparams.temp));

View file

@ -1045,7 +1045,7 @@ extern "C" {
}; };
LLAMA_API struct llama_constraint * llama_constraint_init_softmax (void); LLAMA_API struct llama_constraint * llama_constraint_init_softmax (void);
LLAMA_API struct llama_constraint * llama_constraint_init_top_k (int32_t k, int32_t min_keep); LLAMA_API struct llama_constraint * llama_constraint_init_top_k (int32_t k);
LLAMA_API struct llama_constraint * llama_constraint_init_top_p (float p, int32_t min_keep); LLAMA_API struct llama_constraint * llama_constraint_init_top_p (float p, int32_t min_keep);
LLAMA_API struct llama_constraint * llama_constraint_init_min_p (float p, int32_t min_keep); LLAMA_API struct llama_constraint * llama_constraint_init_min_p (float p, int32_t min_keep);
LLAMA_API struct llama_constraint * llama_constraint_init_tail_free (float z, int32_t min_keep); LLAMA_API struct llama_constraint * llama_constraint_init_tail_free (float z, int32_t min_keep);

View file

@ -49,7 +49,7 @@ static void llama_constraint_softmax_impl(llama_token_data_array * cur_p) {
} }
} }
static void llama_constraint_top_k_impl(llama_token_data_array * cur_p, int32_t k, size_t min_keep) { static void llama_constraint_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
// TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
// if (k >= (int32_t)cur_p->size) { // if (k >= (int32_t)cur_p->size) {
// return; // return;
@ -59,7 +59,6 @@ static void llama_constraint_top_k_impl(llama_token_data_array * cur_p, int32_t
k = cur_p->size; k = cur_p->size;
} }
k = std::max(k, (int) min_keep);
k = std::min(k, (int) cur_p->size); k = std::min(k, (int) cur_p->size);
// Sort scores in descending order // Sort scores in descending order
@ -449,7 +448,6 @@ struct llama_constraint * llama_constraint_init_softmax_impl() {
struct llama_constraint_context_top_k { struct llama_constraint_context_top_k {
const int32_t k; const int32_t k;
const size_t min_keep;
}; };
static struct llama_constraint_i llama_constraint_top_k_i = { static struct llama_constraint_i llama_constraint_top_k_i = {
@ -457,24 +455,23 @@ static struct llama_constraint_i llama_constraint_top_k_i = {
/* .accept = */ nullptr, /* .accept = */ nullptr,
/* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) {
const auto * ctx = (llama_constraint_context_top_k *) cnstr->ctx; const auto * ctx = (llama_constraint_context_top_k *) cnstr->ctx;
llama_constraint_top_k_impl(cur_p, ctx->k, ctx->min_keep); llama_constraint_top_k_impl(cur_p, ctx->k);
}, },
/* .reset = */ nullptr, /* .reset = */ nullptr,
/* .copy = */ [](const struct llama_constraint * cnstr) { /* .copy = */ [](const struct llama_constraint * cnstr) {
const auto * ctx = (const llama_constraint_context_top_k *) cnstr->ctx; const auto * ctx = (const llama_constraint_context_top_k *) cnstr->ctx;
return llama_constraint_init_top_k_impl(ctx->k, ctx->min_keep); return llama_constraint_init_top_k_impl(ctx->k);
}, },
/* .free = */ [](struct llama_constraint * cnstr) { /* .free = */ [](struct llama_constraint * cnstr) {
delete (llama_constraint_context_top_k *) cnstr->ctx; delete (llama_constraint_context_top_k *) cnstr->ctx;
}, },
}; };
struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k, size_t min_keep) { struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k) {
return new llama_constraint { return new llama_constraint {
/* .iface = */ &llama_constraint_top_k_i, /* .iface = */ &llama_constraint_top_k_i,
/* .ctx = */ new llama_constraint_context_top_k { /* .ctx = */ new llama_constraint_context_top_k {
/*.k =*/ k, /* .k = */ k,
/*.min_keep =*/ min_keep,
}, },
}; };
} }
@ -507,8 +504,8 @@ struct llama_constraint * llama_constraint_init_top_p_impl(float p, size_t min_k
return new llama_constraint { return new llama_constraint {
/* .iface = */ &llama_constraint_top_p_i, /* .iface = */ &llama_constraint_top_p_i,
/* .ctx = */ new llama_constraint_context_top_p { /* .ctx = */ new llama_constraint_context_top_p {
/*.p =*/ p, /* .p = */ p,
/*.min_keep =*/ min_keep, /* .min_keep = */ min_keep,
}, },
}; };
} }
@ -541,8 +538,8 @@ struct llama_constraint * llama_constraint_init_min_p_impl(float p, size_t min_k
return new llama_constraint { return new llama_constraint {
/* .iface = */ &llama_constraint_min_p_i, /* .iface = */ &llama_constraint_min_p_i,
/* .ctx = */ new llama_constraint_context_min_p { /* .ctx = */ new llama_constraint_context_min_p {
/*.p =*/ p, /* .p = */ p,
/*.min_keep =*/ min_keep, /* .min_keep = */ min_keep,
}, },
}; };
} }
@ -575,8 +572,8 @@ struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t m
return new llama_constraint { return new llama_constraint {
/* .iface = */ &llama_constraint_tail_free_i, /* .iface = */ &llama_constraint_tail_free_i,
/* .ctx = */ new llama_constraint_context_tail_free { /* .ctx = */ new llama_constraint_context_tail_free {
/*.z =*/ z, /* .z = */ z,
/*.min_keep =*/ min_keep, /*. min_keep = */ min_keep,
}, },
}; };
} }
@ -609,8 +606,8 @@ struct llama_constraint * llama_constraint_init_typical_impl(float p, size_t min
return new llama_constraint { return new llama_constraint {
/* .iface = */ &llama_constraint_typical_i, /* .iface = */ &llama_constraint_typical_i,
/* .ctx = */ new llama_constraint_context_typical { /* .ctx = */ new llama_constraint_context_typical {
/*.p =*/ p, /* .p = */ p,
/*.min_keep =*/ min_keep, /* .min_keep = */ min_keep,
}, },
}; };
} }
@ -642,7 +639,7 @@ struct llama_constraint * llama_constraint_init_temp_impl(float temp) {
return new llama_constraint { return new llama_constraint {
/* .iface = */ &llama_constraint_temp_i, /* .iface = */ &llama_constraint_temp_i,
/* .ctx = */ new llama_constraint_context_temp { /* .ctx = */ new llama_constraint_context_temp {
/*.temp =*/ temp, /*.temp = */ temp,
}, },
}; };
} }
@ -683,9 +680,9 @@ struct llama_constraint * llama_constraint_init_temp_ext_impl(float temp, float
return new llama_constraint { return new llama_constraint {
/* .iface = */ &llama_constraint_temp_ext_i, /* .iface = */ &llama_constraint_temp_ext_i,
/* .ctx = */ new llama_constraint_context_temp_ext { /* .ctx = */ new llama_constraint_context_temp_ext {
/*.temp =*/ temp, /* .temp = */ temp,
/*.delta =*/ delta, /* .delta = */ delta,
/*.exponent =*/ exponent, /* .exponent = */ exponent,
}, },
}; };
} }
@ -745,7 +742,7 @@ static struct llama_constraint_i llama_constraint_mirostat_i = {
float epsilon_hat = s_hat - 1; float epsilon_hat = s_hat - 1;
float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->vocab->n_vocab, -epsilon_hat)), 1 / s_hat); float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->vocab->n_vocab, -epsilon_hat)), 1 / s_hat);
llama_constraint_top_k_impl(cur_p, int(k), 1); llama_constraint_top_k_impl(cur_p, std::max(int(k), 1));
// remember the order to be able to compute the distance later when accepting the token // remember the order to be able to compute the distance later when accepting the token
ctx->cur.resize(cur_p->size); ctx->cur.resize(cur_p->size);
@ -755,7 +752,7 @@ static struct llama_constraint_i llama_constraint_mirostat_i = {
}, },
/* .reset = */ [](struct llama_constraint * cnstr) { /* .reset = */ [](struct llama_constraint * cnstr) {
auto * ctx = (llama_constraint_context_mirostat *) cnstr->ctx; auto * ctx = (llama_constraint_context_mirostat *) cnstr->ctx;
ctx->mu = 0.0f; ctx->mu = 2.0f*ctx->tau;
}, },
/* .copy = */ [](const struct llama_constraint * cnstr) { /* .copy = */ [](const struct llama_constraint * cnstr) {
const auto * ctx = (const llama_constraint_context_mirostat *) cnstr->ctx; const auto * ctx = (const llama_constraint_context_mirostat *) cnstr->ctx;
@ -770,12 +767,12 @@ struct llama_constraint * llama_constraint_init_mirostat_impl(const struct llama
return new llama_constraint { return new llama_constraint {
/* .iface = */ &llama_constraint_mirostat_i, /* .iface = */ &llama_constraint_mirostat_i,
/* .ctx = */ new llama_constraint_context_mirostat { /* .ctx = */ new llama_constraint_context_mirostat {
/*.vocab =*/ &vocab, /* .vocab = */ &vocab,
/*.tau =*/ tau, /* .tau = */ tau,
/*.eta =*/ eta, /* .eta = */ eta,
/*.m =*/ m, /* .m = */ m,
/*.mu =*/ 0.0f, /* .mu = */ 2.0f*tau,
/*.cur =*/ {}, /* .cur = */ {},
}, },
}; };
} }
@ -826,10 +823,16 @@ static struct llama_constraint_i llama_constraint_mirostat_v2_i = {
// Normalize the probabilities of the remaining words // Normalize the probabilities of the remaining words
llama_constraint_softmax_impl(cur_p); llama_constraint_softmax_impl(cur_p);
// remember the order to be able to compute the distance later when accepting the token
ctx->cur.resize(cur_p->size);
for (size_t i = 0; i < cur_p->size; ++i) {
ctx->cur[i] = cur_p->data[i];
}
}, },
/* .reset = */ [](struct llama_constraint * cnstr) { /* .reset = */ [](struct llama_constraint * cnstr) {
auto * ctx = (llama_constraint_context_mirostat_v2 *) cnstr->ctx; auto * ctx = (llama_constraint_context_mirostat_v2 *) cnstr->ctx;
ctx->mu = 0.0f; ctx->mu = 2.0f*ctx->tau;
}, },
/* .copy = */ [](const struct llama_constraint * cnstr) { /* .copy = */ [](const struct llama_constraint * cnstr) {
const auto * ctx = (const llama_constraint_context_mirostat_v2 *) cnstr->ctx; const auto * ctx = (const llama_constraint_context_mirostat_v2 *) cnstr->ctx;
@ -844,10 +847,10 @@ struct llama_constraint * llama_constraint_init_mirostat_v2_impl(float tau, floa
return new llama_constraint { return new llama_constraint {
/* .iface = */ &llama_constraint_mirostat_v2_i, /* .iface = */ &llama_constraint_mirostat_v2_i,
/* .ctx = */ new llama_constraint_context_mirostat_v2 { /* .ctx = */ new llama_constraint_context_mirostat_v2 {
/*.tau =*/ tau, /* .tau = */ tau,
/*.eta =*/ eta, /* .eta = */ eta,
/*.mu =*/ 0.0f, /* .mu = */ 2.0f*tau,
/*.cur =*/ {}, /* .cur = */ {},
}, },
}; };
} }
@ -919,17 +922,17 @@ struct llama_constraint * llama_constraint_init_grammar_impl(const struct llama_
if (grammar_str != nullptr && grammar_str[0] != '\0') { if (grammar_str != nullptr && grammar_str[0] != '\0') {
*ctx = { *ctx = {
/*.vocab = */ &vocab, /* .vocab = */ &vocab,
/*.grammar_str = */ grammar_str, /* .grammar_str = */ grammar_str,
/*.grammar_root = */ grammar_root, /* .grammar_root = */ grammar_root,
/*.grammar = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root), /* .grammar = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root),
}; };
} else { } else {
*ctx = { *ctx = {
/*.vocab = */ &vocab, /* .vocab = */ &vocab,
/*.grammar_str = */ {}, /* .grammar_str = */ {},
/*.grammar_root = */ {}, /* .grammar_root = */ {},
/*.grammar = */ nullptr, /* .grammar = */ nullptr,
}; };
} }
@ -1023,14 +1026,14 @@ struct llama_constraint * llama_constraint_init_penalties_impl(const struct llam
return new llama_constraint { return new llama_constraint {
/* .iface = */ &llama_constraint_penalties_i, /* .iface = */ &llama_constraint_penalties_i,
/* .ctx = */ new llama_constraint_context_penalties { /* .ctx = */ new llama_constraint_context_penalties {
/*.vocab =*/ &vocab, /* .vocab = */ &vocab,
/*.penalty_last_n =*/ penalty_last_n, /* .penalty_last_n = */ penalty_last_n,
/*.penalty_repeat =*/ penalty_repeat, /* .penalty_repeat = */ penalty_repeat,
/*.penalty_freq =*/ penalty_freq, /* .penalty_freq = */ penalty_freq,
/*.penalty_present =*/ penalty_present, /* .penalty_present = */ penalty_present,
/*.penalize_nl =*/ penalize_nl, /* .penalize_nl = */ penalize_nl,
/*.ignore_eos =*/ ignore_eos, /* .ignore_eos = */ ignore_eos,
/*.prev =*/ ring_buffer<llama_token>(penalty_last_n), /* .prev = */ ring_buffer<llama_token>(penalty_last_n),
}, },
}; };
} }
@ -1072,8 +1075,8 @@ struct llama_constraint * llama_constraint_init_logit_bias_impl(
return new llama_constraint { return new llama_constraint {
/* .iface = */ &llama_constraint_logit_bias_i, /* .iface = */ &llama_constraint_logit_bias_i,
/* .ctx = */ new llama_constraint_context_logit_bias { /* .ctx = */ new llama_constraint_context_logit_bias {
/*.vocab =*/ &vocab, /* .vocab = */ &vocab,
/*.logit_bias=*/ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias), /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
}, },
}; };
} }

View file

@ -21,7 +21,7 @@ void llama_constraint_penalties_impl(
// constraints // constraints
struct llama_constraint * llama_constraint_init_softmax_impl (); struct llama_constraint * llama_constraint_init_softmax_impl ();
struct llama_constraint * llama_constraint_init_top_k_impl (int32_t k, size_t min_keep); struct llama_constraint * llama_constraint_init_top_k_impl (int32_t k);
struct llama_constraint * llama_constraint_init_top_p_impl (float p, size_t min_keep); struct llama_constraint * llama_constraint_init_top_p_impl (float p, size_t min_keep);
struct llama_constraint * llama_constraint_init_min_p_impl (float p, size_t min_keep); struct llama_constraint * llama_constraint_init_min_p_impl (float p, size_t min_keep);
struct llama_constraint * llama_constraint_init_tail_free_impl (float z, size_t min_keep); struct llama_constraint * llama_constraint_init_tail_free_impl (float z, size_t min_keep);

View file

@ -20611,8 +20611,8 @@ struct llama_constraint * llama_constraint_init_softmax(void) {
return llama_constraint_init_softmax_impl(); return llama_constraint_init_softmax_impl();
} }
struct llama_constraint * llama_constraint_init_top_k(int32_t k, int32_t min_keep) { struct llama_constraint * llama_constraint_init_top_k(int32_t k) {
return llama_constraint_init_top_k_impl(k, min_keep); return llama_constraint_init_top_k_impl(k);
} }
struct llama_constraint * llama_constraint_init_top_p(float p, int32_t min_keep) { struct llama_constraint * llama_constraint_init_top_p(float p, int32_t min_keep) {

View file

@ -19,7 +19,7 @@ static void dump(const llama_token_data_array * cur_p) {
#define DUMP(__cur_p) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__cur_p)); printf("-\n"); } while(0) #define DUMP(__cur_p) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__cur_p)); printf("-\n"); } while(0)
#define TEST(__cnstr, __cur_p) do { \ #define APPLY(__cnstr, __cur_p) do { \
auto * cnstr = (__cnstr); \ auto * cnstr = (__cnstr); \
llama_constraint_apply(cnstr, (__cur_p)); \ llama_constraint_apply(cnstr, (__cur_p)); \
llama_constraint_free(cnstr); \ llama_constraint_free(cnstr); \
@ -36,9 +36,9 @@ static void test_top_k(const std::vector<float> & probs, const std::vector<float
} }
llama_token_data_array cur_p = { cur.data(), cur.size(), false }; llama_token_data_array cur_p = { cur.data(), cur.size(), false };
TEST(llama_constraint_init_softmax(), &cur_p); APPLY(llama_constraint_init_softmax(), &cur_p);
DUMP(&cur_p); DUMP(&cur_p);
TEST(llama_constraint_init_top_k(k, 1), &cur_p); APPLY(llama_constraint_init_top_k(k), &cur_p);
DUMP(&cur_p); DUMP(&cur_p);
GGML_ASSERT(cur_p.size == expected_probs.size()); GGML_ASSERT(cur_p.size == expected_probs.size());
@ -58,9 +58,9 @@ static void test_top_p(const std::vector<float> & probs, const std::vector<float
} }
llama_token_data_array cur_p = { cur.data(), cur.size(), false }; llama_token_data_array cur_p = { cur.data(), cur.size(), false };
TEST(llama_constraint_init_softmax(), &cur_p); APPLY(llama_constraint_init_softmax(), &cur_p);
DUMP(&cur_p); DUMP(&cur_p);
TEST(llama_constraint_init_top_p(p, 1), &cur_p); APPLY(llama_constraint_init_top_p(p, 1), &cur_p);
DUMP(&cur_p); DUMP(&cur_p);
GGML_ASSERT(cur_p.size == expected_probs.size()); GGML_ASSERT(cur_p.size == expected_probs.size());
@ -81,7 +81,7 @@ static void test_tfs(const std::vector<float> & probs, const std::vector<float>
llama_token_data_array cur_p = { cur.data(), cur.size(), false }; llama_token_data_array cur_p = { cur.data(), cur.size(), false };
DUMP(&cur_p); DUMP(&cur_p);
TEST(llama_constraint_init_tail_free(z, 1), &cur_p); APPLY(llama_constraint_init_tail_free(z, 1), &cur_p);
DUMP(&cur_p); DUMP(&cur_p);
GGML_ASSERT(cur_p.size == expected_probs.size()); GGML_ASSERT(cur_p.size == expected_probs.size());
@ -102,9 +102,9 @@ static void test_min_p(const std::vector<float> & probs, const std::vector<float
llama_token_data_array cur_p = { cur.data(), cur.size(), false }; llama_token_data_array cur_p = { cur.data(), cur.size(), false };
DUMP(&cur_p); DUMP(&cur_p);
TEST(llama_constraint_init_min_p(p, 1), &cur_p); APPLY(llama_constraint_init_min_p(p, 1), &cur_p);
DUMP(&cur_p); DUMP(&cur_p);
TEST(llama_constraint_init_softmax(), &cur_p); APPLY(llama_constraint_init_softmax(), &cur_p);
GGML_ASSERT(cur_p.size == expected_probs.size()); GGML_ASSERT(cur_p.size == expected_probs.size());
for (size_t i = 0; i < cur_p.size; i++) { for (size_t i = 0; i < cur_p.size; i++) {
@ -124,7 +124,7 @@ static void test_typical(const std::vector<float> & probs, const std::vector<flo
llama_token_data_array cur_p = { cur.data(), cur.size(), false }; llama_token_data_array cur_p = { cur.data(), cur.size(), false };
DUMP(&cur_p); DUMP(&cur_p);
TEST(llama_constraint_init_typical(p, 1), &cur_p); APPLY(llama_constraint_init_typical(p, 1), &cur_p);
DUMP(&cur_p); DUMP(&cur_p);
GGML_ASSERT(cur_p.size == expected_probs.size()); GGML_ASSERT(cur_p.size == expected_probs.size());
@ -154,10 +154,10 @@ static void test_penalties(
} }
llama_token_data_array cur_p = { cur.data(), cur.size(), false }; llama_token_data_array cur_p = { cur.data(), cur.size(), false };
TEST(llama_constraint_init_softmax(), &cur_p); APPLY(llama_constraint_init_softmax(), &cur_p);
DUMP(&cur_p); DUMP(&cur_p);
llama_constraint_penalties_impl(&cur_p, token_count, repeat_penalty, alpha_frequency, alpha_presence); // TODO: avoid llama_constraint_penalties_impl(&cur_p, token_count, repeat_penalty, alpha_frequency, alpha_presence); // TODO: avoid
TEST(llama_constraint_init_softmax(), &cur_p); APPLY(llama_constraint_init_softmax(), &cur_p);
DUMP(&cur_p); DUMP(&cur_p);
GGML_ASSERT(cur_p.size == expected_probs.size()); GGML_ASSERT(cur_p.size == expected_probs.size());
@ -182,16 +182,16 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler
for (auto s : samplers_sequence) { for (auto s : samplers_sequence) {
switch (s){ switch (s){
case 'k': TEST(llama_constraint_init_top_k(top_k, 1), &cur_p); break; case 'k': APPLY(llama_constraint_init_top_k(top_k), &cur_p); break;
case 'f': GGML_ABORT("tail_free test not implemented"); case 'f': GGML_ABORT("tail_free test not implemented");
case 'y': GGML_ABORT("typical test not implemented"); case 'y': GGML_ABORT("typical test not implemented");
case 'p': TEST(llama_constraint_init_top_p(top_p, 1), &cur_p); break; case 'p': APPLY(llama_constraint_init_top_p(top_p, 1), &cur_p); break;
case 'm': TEST(llama_constraint_init_min_p(min_p, 1), &cur_p); break; case 'm': APPLY(llama_constraint_init_min_p(min_p, 1), &cur_p); break;
case 't': GGML_ABORT("temperature test not implemented"); case 't': GGML_ABORT("temperature test not implemented");
default : GGML_ABORT("Unknown sampler"); default : GGML_ABORT("Unknown sampler");
} }
TEST(llama_constraint_init_softmax(), &cur_p); // make sure tokens are sorted for tests APPLY(llama_constraint_init_softmax(), &cur_p); // make sure tokens are sorted for tests
const int size = cur_p.size; const int size = cur_p.size;