sampling : remove top-k min_keep, fix mirostat init and state

commit 69551ffd60
parent b2b36e9e95
Author: Georgi Gerganov
Date:   2024-09-05 10:18:04 +03:00
8 changed files with 76 additions and 73 deletions
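In short: `llama_constraint_init_top_k` loses its `min_keep` parameter, since k itself already states exactly how many tokens survive, and the two mirostat constraints now initialize and reset their `mu` state to `2.0f*tau` instead of `0.0f`. A minimal sketch of the call-site migration (the `smpl` handle and the literal values are illustrative, taken from the example hunks below):

    // before this commit: top-k took a min_keep argument like the other constraints
    //llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(40, 1));

    // after: top-k keeps exactly k tokens, no second lower bound
    llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(40));

    // probability-mass constraints such as top-p keep their min_keep parameter
    llama_sampler_constraint_add(smpl, llama_constraint_init_top_p(0.9, 1));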

@@ -67,7 +67,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
     for (const auto & cnstr : params.constraints) {
         switch (cnstr) {
             case GPT_CONSTRAINT_TYPE_TOP_K:
-                llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_k (params.top_k, params.min_keep));
+                llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_k (params.top_k));
                 break;
             case GPT_CONSTRAINT_TYPE_TOP_P:
                 llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_p (params.top_p, params.min_keep));

@@ -61,7 +61,7 @@ defer {
     llama_sampler_free(smpl)
 }
 
-llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(40, 1));
+llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(40));
 llama_sampler_constraint_add(smpl, llama_constraint_init_top_p(0.9, 1));
 llama_sampler_constraint_add(smpl, llama_constraint_init_temp (0.4));

@@ -70,7 +70,7 @@ int main(int argc, char ** argv) {
     llama_sampler * smpl = llama_sampler_init(model, sparams);
 
-    llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(params.sparams.top_k, params.sparams.min_keep));
+    llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(params.sparams.top_k));
     llama_sampler_constraint_add(smpl, llama_constraint_init_top_p(params.sparams.top_p, params.sparams.min_keep));
     llama_sampler_constraint_add(smpl, llama_constraint_init_temp (params.sparams.temp));

@@ -1045,7 +1045,7 @@ extern "C" {
     };
 
     LLAMA_API struct llama_constraint * llama_constraint_init_softmax (void);
-    LLAMA_API struct llama_constraint * llama_constraint_init_top_k (int32_t k, int32_t min_keep);
+    LLAMA_API struct llama_constraint * llama_constraint_init_top_k (int32_t k);
     LLAMA_API struct llama_constraint * llama_constraint_init_top_p (float p, int32_t min_keep);
     LLAMA_API struct llama_constraint * llama_constraint_init_min_p (float p, int32_t min_keep);
     LLAMA_API struct llama_constraint * llama_constraint_init_tail_free (float z, int32_t min_keep);

@@ -49,7 +49,7 @@ static void llama_constraint_softmax_impl(llama_token_data_array * cur_p) {
     }
 }
 
-static void llama_constraint_top_k_impl(llama_token_data_array * cur_p, int32_t k, size_t min_keep) {
+static void llama_constraint_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
     // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
     // if (k >= (int32_t)cur_p->size) {
     //     return;
@@ -59,7 +59,6 @@ static void llama_constraint_top_k_impl(llama_token_data_array * cur_p, int32_t
         k = cur_p->size;
     }
 
-    k = std::max(k, (int) min_keep);
     k = std::min(k, (int) cur_p->size);
 
     // Sort scores in descending order
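With `min_keep` gone, the clamp in `llama_constraint_top_k_impl` reduces to a single upper bound; a sketch of the resulting control flow, assuming the usual non-positive-k guard that opens the conditional shown above:

    if (k <= 0) {
        k = cur_p->size;                    // k <= 0 means keep everything
    }
    k = std::min(k, (int) cur_p->size);     // only the upper clamp remains

The one caller that relied on the old floor, mirostat, now states it explicitly by passing `std::max(int(k), 1)`, as the hunk at old line 745 below shows.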
@@ -449,7 +448,6 @@ struct llama_constraint * llama_constraint_init_softmax_impl() {
 
 struct llama_constraint_context_top_k {
     const int32_t k;
-    const size_t  min_keep;
 };
 
 static struct llama_constraint_i llama_constraint_top_k_i = {
@@ -457,24 +455,23 @@ static struct llama_constraint_i llama_constraint_top_k_i = {
     /* .accept = */ nullptr,
     /* .apply  = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) {
         const auto * ctx = (llama_constraint_context_top_k *) cnstr->ctx;
-        llama_constraint_top_k_impl(cur_p, ctx->k, ctx->min_keep);
+        llama_constraint_top_k_impl(cur_p, ctx->k);
     },
     /* .reset  = */ nullptr,
     /* .copy   = */ [](const struct llama_constraint * cnstr) {
         const auto * ctx = (const llama_constraint_context_top_k *) cnstr->ctx;
-        return llama_constraint_init_top_k_impl(ctx->k, ctx->min_keep);
+        return llama_constraint_init_top_k_impl(ctx->k);
     },
     /* .free   = */ [](struct llama_constraint * cnstr) {
         delete (llama_constraint_context_top_k *) cnstr->ctx;
     },
 };
 
-struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k, size_t min_keep) {
+struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k) {
     return new llama_constraint {
         /* .iface = */ &llama_constraint_top_k_i,
         /* .ctx   = */ new llama_constraint_context_top_k {
             /* .k = */ k,
-            /*.min_keep =*/ min_keep,
         },
     };
 }
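The hunk above also documents the constraint plumbing: a static `llama_constraint_i` table of lambdas (`accept`, `apply`, `reset`, `copy`, `free`) paired with a heap-allocated context. A stateless constraint following the same pattern needs no context at all; this hypothetical sketch (the `keep_one` names are illustrative, not part of the diff) reuses the top-k kernel with a fixed k:

    static struct llama_constraint_i llama_constraint_keep_one_i = {
        /* .accept = */ nullptr,
        /* .apply  = */ [](struct llama_constraint * /*cnstr*/, llama_token_data_array * cur_p) {
            llama_constraint_top_k_impl(cur_p, 1); // keep only the highest-scoring token
        },
        /* .reset  = */ nullptr,
        /* .copy   = */ [](const struct llama_constraint * /*cnstr*/) {
            return new llama_constraint { &llama_constraint_keep_one_i, /* .ctx = */ nullptr };
        },
        /* .free   = */ [](struct llama_constraint * /*cnstr*/) {
            // no context to delete
        },
    };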
@@ -745,7 +742,7 @@ static struct llama_constraint_i llama_constraint_mirostat_i = {
         float epsilon_hat = s_hat - 1;
         float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->vocab->n_vocab, -epsilon_hat)), 1 / s_hat);
 
-        llama_constraint_top_k_impl(cur_p, int(k), 1);
+        llama_constraint_top_k_impl(cur_p, std::max(int(k), 1));
 
         // remember the order to be able to compute the distance later when accepting the token
         ctx->cur.resize(cur_p->size);
@@ -755,7 +752,7 @@ static struct llama_constraint_i llama_constraint_mirostat_i = {
     },
     /* .reset  = */ [](struct llama_constraint * cnstr) {
         auto * ctx = (llama_constraint_context_mirostat *) cnstr->ctx;
-        ctx->mu = 0.0f;
+        ctx->mu = 2.0f*ctx->tau;
     },
     /* .copy   = */ [](const struct llama_constraint * cnstr) {
         const auto * ctx = (const llama_constraint_context_mirostat *) cnstr->ctx;
@@ -774,7 +771,7 @@ struct llama_constraint * llama_constraint_init_mirostat_impl(const struct llama
             /* .tau = */ tau,
             /* .eta = */ eta,
             /* .m   = */ m,
-            /*.mu =*/ 0.0f,
+            /* .mu  = */ 2.0f*tau,
             /* .cur = */ {},
         },
     };
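Seeding `mu` with `2.0f*tau` matches the Mirostat algorithm (Basu et al.), where `tau` is the target surprise and `mu` the evolving upper bound on it: starting at twice the target lets the controller converge from above. The old `0.0f` made `powf(2, ctx->mu)` equal 1 on the first step, so the computed k collapsed and generation started out near-greedy regardless of `tau`; fixing `reset` the same way means a reset sampler now behaves like a freshly constructed one. For reference, the feedback update that `mu` feeds is, in sketch form (`p_accepted` is illustrative; the accept body is outside this diff):

    const float s = -log2f(p_accepted);            // observed surprise of the sampled token, in bits
    ctx->mu = ctx->mu - ctx->eta * (s - ctx->tau); // steer surprise toward the target tau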
@@ -826,10 +823,16 @@ static struct llama_constraint_i llama_constraint_mirostat_v2_i = {
         // Normalize the probabilities of the remaining words
         llama_constraint_softmax_impl(cur_p);
+
+        // remember the order to be able to compute the distance later when accepting the token
+        ctx->cur.resize(cur_p->size);
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            ctx->cur[i] = cur_p->data[i];
+        }
     },
     /* .reset  = */ [](struct llama_constraint * cnstr) {
         auto * ctx = (llama_constraint_context_mirostat_v2 *) cnstr->ctx;
-        ctx->mu = 0.0f;
+        ctx->mu = 2.0f*ctx->tau;
     },
     /* .copy   = */ [](const struct llama_constraint * cnstr) {
         const auto * ctx = (const llama_constraint_context_mirostat_v2 *) cnstr->ctx;
@@ -846,7 +849,7 @@ struct llama_constraint * llama_constraint_init_mirostat_v2_impl(float tau, floa
         /* .ctx   = */ new llama_constraint_context_mirostat_v2 {
             /* .tau = */ tau,
             /* .eta = */ eta,
-            /*.mu =*/ 0.0f,
+            /* .mu  = */ 2.0f*tau,
             /* .cur = */ {},
         },
     };
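The v2 hunks also fix a state gap: unlike v1, mirostat v2 never snapshotted the candidate array after applying, leaving `accept` nothing to look the chosen token up in. With `ctx->cur` populated, the accept step can recover the token's probability and run the `mu` update; a hedged sketch of that lookup (the actual accept body is not part of this diff):

    for (size_t i = 0; i < ctx->cur.size(); ++i) {
        if (ctx->cur[i].id == token) {
            const float s = -log2f(ctx->cur[i].p); // observed surprise
            ctx->mu -= ctx->eta * (s - ctx->tau);  // mirostat feedback
            break;
        }
    }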

@@ -21,7 +21,7 @@ void llama_constraint_penalties_impl(
 // constraints
 
 struct llama_constraint * llama_constraint_init_softmax_impl ();
-struct llama_constraint * llama_constraint_init_top_k_impl (int32_t k, size_t min_keep);
+struct llama_constraint * llama_constraint_init_top_k_impl (int32_t k);
 struct llama_constraint * llama_constraint_init_top_p_impl (float p, size_t min_keep);
 struct llama_constraint * llama_constraint_init_min_p_impl (float p, size_t min_keep);
 struct llama_constraint * llama_constraint_init_tail_free_impl (float z, size_t min_keep);

@@ -20611,8 +20611,8 @@ struct llama_constraint * llama_constraint_init_softmax(void) {
     return llama_constraint_init_softmax_impl();
 }
 
-struct llama_constraint * llama_constraint_init_top_k(int32_t k, int32_t min_keep) {
-    return llama_constraint_init_top_k_impl(k, min_keep);
+struct llama_constraint * llama_constraint_init_top_k(int32_t k) {
+    return llama_constraint_init_top_k_impl(k);
 }
 
 struct llama_constraint * llama_constraint_init_top_p(float p, int32_t min_keep) {

@@ -19,7 +19,7 @@ static void dump(const llama_token_data_array * cur_p) {
 
 #define DUMP(__cur_p) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__cur_p)); printf("-\n"); } while(0)
 
-#define TEST(__cnstr, __cur_p) do { \
+#define APPLY(__cnstr, __cur_p) do { \
     auto * cnstr = (__cnstr); \
     llama_constraint_apply(cnstr, (__cur_p)); \
    llama_constraint_free(cnstr); \
@@ -36,9 +36,9 @@ static void test_top_k(const std::vector<float> & probs, const std::vector<float
     }
 
     llama_token_data_array cur_p = { cur.data(), cur.size(), false };
-    TEST(llama_constraint_init_softmax(), &cur_p);
+    APPLY(llama_constraint_init_softmax(), &cur_p);
     DUMP(&cur_p);
-    TEST(llama_constraint_init_top_k(k, 1), &cur_p);
+    APPLY(llama_constraint_init_top_k(k), &cur_p);
     DUMP(&cur_p);
 
     GGML_ASSERT(cur_p.size == expected_probs.size());
@@ -58,9 +58,9 @@ static void test_top_p(const std::vector<float> & probs, const std::vector<float
     }
 
     llama_token_data_array cur_p = { cur.data(), cur.size(), false };
-    TEST(llama_constraint_init_softmax(), &cur_p);
+    APPLY(llama_constraint_init_softmax(), &cur_p);
     DUMP(&cur_p);
-    TEST(llama_constraint_init_top_p(p, 1), &cur_p);
+    APPLY(llama_constraint_init_top_p(p, 1), &cur_p);
     DUMP(&cur_p);
 
     GGML_ASSERT(cur_p.size == expected_probs.size());
@@ -81,7 +81,7 @@ static void test_tfs(const std::vector<float> & probs, const std::vector<float>
     llama_token_data_array cur_p = { cur.data(), cur.size(), false };
 
     DUMP(&cur_p);
-    TEST(llama_constraint_init_tail_free(z, 1), &cur_p);
+    APPLY(llama_constraint_init_tail_free(z, 1), &cur_p);
     DUMP(&cur_p);
 
     GGML_ASSERT(cur_p.size == expected_probs.size());
@@ -102,9 +102,9 @@ static void test_min_p(const std::vector<float> & probs, const std::vector<float
     llama_token_data_array cur_p = { cur.data(), cur.size(), false };
 
     DUMP(&cur_p);
-    TEST(llama_constraint_init_min_p(p, 1), &cur_p);
+    APPLY(llama_constraint_init_min_p(p, 1), &cur_p);
     DUMP(&cur_p);
-    TEST(llama_constraint_init_softmax(), &cur_p);
+    APPLY(llama_constraint_init_softmax(), &cur_p);
 
     GGML_ASSERT(cur_p.size == expected_probs.size());
     for (size_t i = 0; i < cur_p.size; i++) {
@@ -124,7 +124,7 @@ static void test_typical(const std::vector<float> & probs, const std::vector<flo
     llama_token_data_array cur_p = { cur.data(), cur.size(), false };
 
     DUMP(&cur_p);
-    TEST(llama_constraint_init_typical(p, 1), &cur_p);
+    APPLY(llama_constraint_init_typical(p, 1), &cur_p);
     DUMP(&cur_p);
 
     GGML_ASSERT(cur_p.size == expected_probs.size());
@@ -154,10 +154,10 @@ static void test_penalties(
     }
 
     llama_token_data_array cur_p = { cur.data(), cur.size(), false };
-    TEST(llama_constraint_init_softmax(), &cur_p);
+    APPLY(llama_constraint_init_softmax(), &cur_p);
     DUMP(&cur_p);
     llama_constraint_penalties_impl(&cur_p, token_count, repeat_penalty, alpha_frequency, alpha_presence); // TODO: avoid
-    TEST(llama_constraint_init_softmax(), &cur_p);
+    APPLY(llama_constraint_init_softmax(), &cur_p);
     DUMP(&cur_p);
 
     GGML_ASSERT(cur_p.size == expected_probs.size());
@@ -182,16 +182,16 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler
     for (auto s : samplers_sequence) {
         switch (s){
-            case 'k': TEST(llama_constraint_init_top_k(top_k, 1), &cur_p); break;
+            case 'k': APPLY(llama_constraint_init_top_k(top_k), &cur_p); break;
             case 'f': GGML_ABORT("tail_free test not implemented");
             case 'y': GGML_ABORT("typical test not implemented");
-            case 'p': TEST(llama_constraint_init_top_p(top_p, 1), &cur_p); break;
-            case 'm': TEST(llama_constraint_init_min_p(min_p, 1), &cur_p); break;
+            case 'p': APPLY(llama_constraint_init_top_p(top_p, 1), &cur_p); break;
+            case 'm': APPLY(llama_constraint_init_min_p(min_p, 1), &cur_p); break;
             case 't': GGML_ABORT("temperature test not implemented");
             default : GGML_ABORT("Unknown sampler");
         }
 
-        TEST(llama_constraint_init_softmax(), &cur_p); // make sure tokens are sorted for tests
+        APPLY(llama_constraint_init_softmax(), &cur_p); // make sure tokens are sorted for tests
 
         const int size = cur_p.size;
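The `TEST` → `APPLY` rename reflects what the macro actually does: construct a constraint, apply it, free it; the assertions live in the surrounding test functions. A minimal sketch of a test body exercising the new one-argument top-k through these macros (the candidate values are illustrative):

    std::vector<llama_token_data> cur;
    for (llama_token token_id = 0; token_id < 4; token_id++) {
        cur.emplace_back(llama_token_data{ token_id, logf(0.25f), 0.0f });
    }

    llama_token_data_array cur_p = { cur.data(), cur.size(), false };
    APPLY(llama_constraint_init_softmax(), &cur_p); // sort and normalize the probabilities
    APPLY(llama_constraint_init_top_k(2), &cur_p);  // note: no min_keep argument anymore
    GGML_ASSERT(cur_p.size == 2);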