sampling : remove top-k min_keep, fix mirostat init and state
commit 69551ffd60 (parent b2b36e9e95)
8 changed files with 76 additions and 73 deletions
@@ -67,7 +67,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
     for (const auto & cnstr : params.constraints) {
         switch (cnstr) {
             case GPT_CONSTRAINT_TYPE_TOP_K:
-                llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_k(params.top_k, params.min_keep));
+                llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_k(params.top_k));
                 break;
             case GPT_CONSTRAINT_TYPE_TOP_P:
                 llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_p(params.top_p, params.min_keep));
@@ -61,7 +61,7 @@ defer {
     llama_sampler_free(smpl)
 }
 
-llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(40, 1));
+llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(40));
 llama_sampler_constraint_add(smpl, llama_constraint_init_top_p(0.9, 1));
 llama_sampler_constraint_add(smpl, llama_constraint_init_temp (0.4));
 
@@ -70,7 +70,7 @@ int main(int argc, char ** argv) {
 
     llama_sampler * smpl = llama_sampler_init(model, sparams);
 
-    llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(params.sparams.top_k, params.sparams.min_keep));
+    llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(params.sparams.top_k));
     llama_sampler_constraint_add(smpl, llama_constraint_init_top_p(params.sparams.top_p, params.sparams.min_keep));
     llama_sampler_constraint_add(smpl, llama_constraint_init_temp (params.sparams.temp));
 
@@ -1045,7 +1045,7 @@ extern "C" {
     };
 
     LLAMA_API struct llama_constraint * llama_constraint_init_softmax  (void);
-    LLAMA_API struct llama_constraint * llama_constraint_init_top_k    (int32_t k, int32_t min_keep);
+    LLAMA_API struct llama_constraint * llama_constraint_init_top_k    (int32_t k);
     LLAMA_API struct llama_constraint * llama_constraint_init_top_p    (float p, int32_t min_keep);
     LLAMA_API struct llama_constraint * llama_constraint_init_min_p    (float p, int32_t min_keep);
     LLAMA_API struct llama_constraint * llama_constraint_init_tail_free(float z, int32_t min_keep);
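After this hunk, top-k is the only truncation whose keep count is the parameter itself; the probability-mass constraints (`top_p`, `min_p`, `tail_free`) keep their `min_keep` floor. A minimal sketch of a mixed chain against the new declarations (the `smpl` handle and the constant values are illustrative, not taken from the diff):

```cpp
// Sketch only: chain constraints using the post-change public API.
// Assumes `smpl` was created earlier, e.g. via llama_sampler_init(model, sparams).
llama_sampler_constraint_add(smpl, llama_constraint_init_top_k    (40));        // keep count is k itself
llama_sampler_constraint_add(smpl, llama_constraint_init_tail_free(0.95f, 1));  // min_keep floor of 1
llama_sampler_constraint_add(smpl, llama_constraint_init_min_p    (0.05f, 1));  // min_keep floor of 1
```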
@@ -49,7 +49,7 @@ static void llama_constraint_softmax_impl(llama_token_data_array * cur_p) {
     }
 }
 
-static void llama_constraint_top_k_impl(llama_token_data_array * cur_p, int32_t k, size_t min_keep) {
+static void llama_constraint_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
     // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
     // if (k >= (int32_t)cur_p->size) {
     //     return;
@@ -59,7 +59,6 @@ static void llama_constraint_top_k_impl(llama_token_data_array * cur_p, int32_t
         k = cur_p->size;
     }
 
-    k = std::max(k, (int) min_keep);
     k = std::min(k, (int) cur_p->size);
 
     // Sort scores in descending order
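With `min_keep` gone, the clamping reduces to two rules: `k <= 0` keeps everything, and `k` can never exceed the candidate count. A standalone sketch of that truncation logic (the `token_data` type and the plain `partial_sort` are illustrative; the library uses a faster bucket sort, per the TODO above):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

struct token_data { int32_t id; float logit; };

// Keep the k highest-logit candidates; k <= 0 keeps all of them.
static void top_k_sketch(std::vector<token_data> & cur, int32_t k) {
    if (k <= 0) {
        k = (int32_t) cur.size();
    }
    k = std::min(k, (int32_t) cur.size());
    std::partial_sort(cur.begin(), cur.begin() + k, cur.end(),
        [](const token_data & a, const token_data & b) { return a.logit > b.logit; });
    cur.resize(k);
}
```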
@@ -449,7 +448,6 @@ struct llama_constraint * llama_constraint_init_softmax_impl() {
 
 struct llama_constraint_context_top_k {
     const int32_t k;
-    const size_t  min_keep;
 };
 
 static struct llama_constraint_i llama_constraint_top_k_i = {
@@ -457,24 +455,23 @@ static struct llama_constraint_i llama_constraint_top_k_i = {
     /* .accept = */ nullptr,
     /* .apply  = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) {
         const auto * ctx = (llama_constraint_context_top_k *) cnstr->ctx;
-        llama_constraint_top_k_impl(cur_p, ctx->k, ctx->min_keep);
+        llama_constraint_top_k_impl(cur_p, ctx->k);
     },
     /* .reset  = */ nullptr,
     /* .copy   = */ [](const struct llama_constraint * cnstr) {
         const auto * ctx = (const llama_constraint_context_top_k *) cnstr->ctx;
-        return llama_constraint_init_top_k_impl(ctx->k, ctx->min_keep);
+        return llama_constraint_init_top_k_impl(ctx->k);
     },
     /* .free   = */ [](struct llama_constraint * cnstr) {
         delete (llama_constraint_context_top_k *) cnstr->ctx;
     },
 };
 
-struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k, size_t min_keep) {
+struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k) {
     return new llama_constraint {
         /* .iface = */ &llama_constraint_top_k_i,
         /* .ctx   = */ new llama_constraint_context_top_k {
-            /*.k        =*/ k,
-            /*.min_keep =*/ min_keep,
+            /* .k = */ k,
         },
     };
 }
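The hunk above shows the pattern every constraint in this file follows: a static interface struct of function pointers plus a heap-allocated context, where `copy` rebuilds the constraint from its context and `free` disposes of the context. A self-contained miniature of the same pattern (types renamed `my_*`; this is an illustration, not the library's code):

```cpp
#include <cstdint>

struct my_constraint;

// Interface: a vtable of operations, mirroring the shape of llama_constraint_i.
struct my_constraint_i {
    void            (*apply)(my_constraint * cnstr, float * logits, int n);
    my_constraint * (*copy) (const my_constraint * cnstr);
    void            (*free) (my_constraint * cnstr);
};

struct my_constraint {
    const my_constraint_i * iface;
    void                  * ctx;
};

struct my_ctx_top_k { const int32_t k; };

my_constraint * my_init_top_k(int32_t k);

static my_constraint_i my_top_k_i = {
    /* .apply = */ [](my_constraint * cnstr, float * logits, int n) {
        const auto * ctx = (const my_ctx_top_k *) cnstr->ctx;
        (void) logits; (void) n; (void) ctx; // truncation elided in this sketch
    },
    /* .copy  = */ [](const my_constraint * cnstr) {
        return my_init_top_k(((const my_ctx_top_k *) cnstr->ctx)->k);
    },
    /* .free  = */ [](my_constraint * cnstr) {
        delete (my_ctx_top_k *) cnstr->ctx;
        delete cnstr; // the real library frees the wrapper separately, in llama_constraint_free
    },
};

my_constraint * my_init_top_k(int32_t k) {
    return new my_constraint {
        /* .iface = */ &my_top_k_i,
        /* .ctx   = */ new my_ctx_top_k { /* .k = */ k },
    };
}
```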
@@ -507,8 +504,8 @@ struct llama_constraint * llama_constraint_init_top_p_impl(float p, size_t min_k
     return new llama_constraint {
         /* .iface = */ &llama_constraint_top_p_i,
         /* .ctx   = */ new llama_constraint_context_top_p {
-            /*.p        =*/ p,
-            /*.min_keep =*/ min_keep,
+            /* .p        = */ p,
+            /* .min_keep = */ min_keep,
         },
     };
 }
@@ -541,8 +538,8 @@ struct llama_constraint * llama_constraint_init_min_p_impl(float p, size_t min_k
     return new llama_constraint {
         /* .iface = */ &llama_constraint_min_p_i,
         /* .ctx   = */ new llama_constraint_context_min_p {
-            /*.p        =*/ p,
-            /*.min_keep =*/ min_keep,
+            /* .p        = */ p,
+            /* .min_keep = */ min_keep,
         },
     };
 }
@@ -575,8 +572,8 @@ struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t m
     return new llama_constraint {
         /* .iface = */ &llama_constraint_tail_free_i,
         /* .ctx   = */ new llama_constraint_context_tail_free {
-            /*.z        =*/ z,
-            /*.min_keep =*/ min_keep,
+            /* .z        = */ z,
+            /* .min_keep = */ min_keep,
         },
     };
 }
@@ -609,8 +606,8 @@ struct llama_constraint * llama_constraint_init_typical_impl(float p, size_t min
     return new llama_constraint {
         /* .iface = */ &llama_constraint_typical_i,
         /* .ctx   = */ new llama_constraint_context_typical {
-            /*.p        =*/ p,
-            /*.min_keep =*/ min_keep,
+            /* .p        = */ p,
+            /* .min_keep = */ min_keep,
         },
     };
 }
@@ -642,7 +639,7 @@ struct llama_constraint * llama_constraint_init_temp_impl(float temp) {
     return new llama_constraint {
         /* .iface = */ &llama_constraint_temp_i,
         /* .ctx   = */ new llama_constraint_context_temp {
-            /*.temp =*/ temp,
+            /* .temp = */ temp,
         },
     };
 }
@@ -683,9 +680,9 @@ struct llama_constraint * llama_constraint_init_temp_ext_impl(float temp, float
     return new llama_constraint {
         /* .iface = */ &llama_constraint_temp_ext_i,
         /* .ctx   = */ new llama_constraint_context_temp_ext {
-            /*.temp     =*/ temp,
-            /*.delta    =*/ delta,
-            /*.exponent =*/ exponent,
+            /* .temp     = */ temp,
+            /* .delta    = */ delta,
+            /* .exponent = */ exponent,
         },
     };
 }
@@ -745,7 +742,7 @@ static struct llama_constraint_i llama_constraint_mirostat_i = {
         float epsilon_hat = s_hat - 1;
         float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->vocab->n_vocab, -epsilon_hat)), 1 / s_hat);
 
-        llama_constraint_top_k_impl(cur_p, int(k), 1);
+        llama_constraint_top_k_impl(cur_p, std::max(int(k), 1));
 
         // remember the order to be able to compute the distance later when accepting the token
        ctx->cur.resize(cur_p->size);
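The old call relied on `min_keep = 1` to guarantee a non-empty candidate set; with that parameter removed, the same guarantee moves into the argument itself. Spelled out (a sketch of the equivalent steps, not a line from the diff):

```cpp
// The estimated cutoff can round down to 0 for small mu; clamp it so the
// truncation always keeps at least one candidate, as min_keep = 1 used to.
const int32_t k_eff = std::max((int32_t) k, 1);
llama_constraint_top_k_impl(cur_p, k_eff);
```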
@@ -755,7 +752,7 @@ static struct llama_constraint_i llama_constraint_mirostat_i = {
     },
     /* .reset = */ [](struct llama_constraint * cnstr) {
         auto * ctx = (llama_constraint_context_mirostat *) cnstr->ctx;
-        ctx->mu = 0.0f;
+        ctx->mu = 2.0f*ctx->tau;
     },
     /* .copy = */ [](const struct llama_constraint * cnstr) {
         const auto * ctx = (const llama_constraint_context_mirostat *) cnstr->ctx;
@@ -770,12 +767,12 @@ struct llama_constraint * llama_constraint_init_mirostat_impl(const struct llama
     return new llama_constraint {
         /* .iface = */ &llama_constraint_mirostat_i,
         /* .ctx   = */ new llama_constraint_context_mirostat {
-            /*.vocab =*/ &vocab,
-            /*.tau   =*/ tau,
-            /*.eta   =*/ eta,
-            /*.m     =*/ m,
-            /*.mu    =*/ 0.0f,
-            /*.cur   =*/ {},
+            /* .vocab = */ &vocab,
+            /* .tau   = */ tau,
+            /* .eta   = */ eta,
+            /* .m     = */ m,
+            /* .mu    = */ 2.0f*tau,
+            /* .cur   = */ {},
         },
     };
 }
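Both this initializer and the `reset` fix above give `mu` the same starting point, `2*tau`. That matches the Mirostat algorithm, where `mu` is the maximum-surprise threshold and is seeded at twice the target cross-entropy `tau`; starting from `0.0f` would truncate to almost nothing until the feedback loop recovers. For reference, the accept-side feedback that `mu` participates in looks like this (a sketch of the Mirostat update rule from the paper, with assumed variable names; it is not a hunk from this commit):

```cpp
#include <cmath>

// After sampling a token with probability p_token, nudge mu so the
// observed surprise tracks the target tau (eta is the learning rate).
static void mirostat_update(float & mu, float p_token, float tau, float eta) {
    const float surprise = -std::log2(p_token); // observed surprise, in bits
    mu -= eta * (surprise - tau);
}
```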
@@ -826,10 +823,16 @@ static struct llama_constraint_i llama_constraint_mirostat_v2_i = {
 
         // Normalize the probabilities of the remaining words
         llama_constraint_softmax_impl(cur_p);
+
+        // remember the order to be able to compute the distance later when accepting the token
+        ctx->cur.resize(cur_p->size);
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            ctx->cur[i] = cur_p->data[i];
+        }
     },
     /* .reset = */ [](struct llama_constraint * cnstr) {
         auto * ctx = (llama_constraint_context_mirostat_v2 *) cnstr->ctx;
-        ctx->mu = 0.0f;
+        ctx->mu = 2.0f*ctx->tau;
     },
     /* .copy = */ [](const struct llama_constraint * cnstr) {
         const auto * ctx = (const llama_constraint_context_mirostat_v2 *) cnstr->ctx;
@@ -844,10 +847,10 @@ struct llama_constraint * llama_constraint_init_mirostat_v2_impl(float tau, floa
     return new llama_constraint {
         /* .iface = */ &llama_constraint_mirostat_v2_i,
         /* .ctx   = */ new llama_constraint_context_mirostat_v2 {
-            /*.tau =*/ tau,
-            /*.eta =*/ eta,
-            /*.mu  =*/ 0.0f,
-            /*.cur =*/ {},
+            /* .tau = */ tau,
+            /* .eta = */ eta,
+            /* .mu  = */ 2.0f*tau,
+            /* .cur = */ {},
         },
     };
 }
@@ -919,17 +922,17 @@ struct llama_constraint * llama_constraint_init_grammar_impl(const struct llama_
 
     if (grammar_str != nullptr && grammar_str[0] != '\0') {
         *ctx = {
-            /*.vocab        = */ &vocab,
-            /*.grammar_str  = */ grammar_str,
-            /*.grammar_root = */ grammar_root,
-            /*.grammar      = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root),
+            /* .vocab        = */ &vocab,
+            /* .grammar_str  = */ grammar_str,
+            /* .grammar_root = */ grammar_root,
+            /* .grammar      = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root),
         };
     } else {
         *ctx = {
-            /*.vocab        = */ &vocab,
-            /*.grammar_str  = */ {},
-            /*.grammar_root = */ {},
-            /*.grammar      = */ nullptr,
+            /* .vocab        = */ &vocab,
+            /* .grammar_str  = */ {},
+            /* .grammar_root = */ {},
+            /* .grammar      = */ nullptr,
         };
     }
 
@@ -1023,14 +1026,14 @@ struct llama_constraint * llama_constraint_init_penalties_impl(const struct llam
     return new llama_constraint {
         /* .iface = */ &llama_constraint_penalties_i,
         /* .ctx   = */ new llama_constraint_context_penalties {
-            /*.vocab           =*/ &vocab,
-            /*.penalty_last_n  =*/ penalty_last_n,
-            /*.penalty_repeat  =*/ penalty_repeat,
-            /*.penalty_freq    =*/ penalty_freq,
-            /*.penalty_present =*/ penalty_present,
-            /*.penalize_nl     =*/ penalize_nl,
-            /*.ignore_eos      =*/ ignore_eos,
-            /*.prev            =*/ ring_buffer<llama_token>(penalty_last_n),
+            /* .vocab           = */ &vocab,
+            /* .penalty_last_n  = */ penalty_last_n,
+            /* .penalty_repeat  = */ penalty_repeat,
+            /* .penalty_freq    = */ penalty_freq,
+            /* .penalty_present = */ penalty_present,
+            /* .penalize_nl     = */ penalize_nl,
+            /* .ignore_eos      = */ ignore_eos,
+            /* .prev            = */ ring_buffer<llama_token>(penalty_last_n),
         },
     };
 }
@@ -1072,8 +1075,8 @@ struct llama_constraint * llama_constraint_init_logit_bias_impl(
     return new llama_constraint {
         /* .iface = */ &llama_constraint_logit_bias_i,
         /* .ctx   = */ new llama_constraint_context_logit_bias {
-            /*.vocab      =*/ &vocab,
-            /*.logit_bias =*/ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
+            /* .vocab      = */ &vocab,
+            /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
         },
     };
 }
@@ -21,7 +21,7 @@ void llama_constraint_penalties_impl(
 // constraints
 
 struct llama_constraint * llama_constraint_init_softmax_impl  ();
-struct llama_constraint * llama_constraint_init_top_k_impl    (int32_t k, size_t min_keep);
+struct llama_constraint * llama_constraint_init_top_k_impl    (int32_t k);
 struct llama_constraint * llama_constraint_init_top_p_impl    (float p, size_t min_keep);
 struct llama_constraint * llama_constraint_init_min_p_impl    (float p, size_t min_keep);
 struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t min_keep);
@@ -20611,8 +20611,8 @@ struct llama_constraint * llama_constraint_init_softmax(void) {
     return llama_constraint_init_softmax_impl();
 }
 
-struct llama_constraint * llama_constraint_init_top_k(int32_t k, int32_t min_keep) {
-    return llama_constraint_init_top_k_impl(k, min_keep);
+struct llama_constraint * llama_constraint_init_top_k(int32_t k) {
+    return llama_constraint_init_top_k_impl(k);
 }
 
 struct llama_constraint * llama_constraint_init_top_p(float p, int32_t min_keep) {
@@ -19,7 +19,7 @@ static void dump(const llama_token_data_array * cur_p) {
 
 #define DUMP(__cur_p) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__cur_p)); printf("-\n"); } while(0)
 
-#define TEST(__cnstr, __cur_p) do { \
+#define APPLY(__cnstr, __cur_p) do { \
     auto * cnstr = (__cnstr); \
     llama_constraint_apply(cnstr, (__cur_p)); \
     llama_constraint_free(cnstr); \
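The rename from `TEST` to `APPLY` describes what the macro actually does: build a constraint, apply it to the candidate array, and free it immediately. A usage sketch with illustrative values, assuming a populated `cur` vector as in the tests below:

```cpp
llama_token_data_array cur_p = { cur.data(), cur.size(), false };

APPLY(llama_constraint_init_top_k(40), &cur_p);  // truncate to the top 40 candidates
APPLY(llama_constraint_init_softmax(), &cur_p);  // renormalize what is left
```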
@@ -36,9 +36,9 @@ static void test_top_k(const std::vector<float> & probs, const std::vector<float
     }
 
     llama_token_data_array cur_p = { cur.data(), cur.size(), false };
-    TEST(llama_constraint_init_softmax(), &cur_p);
+    APPLY(llama_constraint_init_softmax(), &cur_p);
     DUMP(&cur_p);
-    TEST(llama_constraint_init_top_k(k, 1), &cur_p);
+    APPLY(llama_constraint_init_top_k(k), &cur_p);
     DUMP(&cur_p);
 
     GGML_ASSERT(cur_p.size == expected_probs.size());
@@ -58,9 +58,9 @@ static void test_top_p(const std::vector<float> & probs, const std::vector<float
     }
 
     llama_token_data_array cur_p = { cur.data(), cur.size(), false };
-    TEST(llama_constraint_init_softmax(), &cur_p);
+    APPLY(llama_constraint_init_softmax(), &cur_p);
     DUMP(&cur_p);
-    TEST(llama_constraint_init_top_p(p, 1), &cur_p);
+    APPLY(llama_constraint_init_top_p(p, 1), &cur_p);
     DUMP(&cur_p);
 
     GGML_ASSERT(cur_p.size == expected_probs.size());
@@ -81,7 +81,7 @@ static void test_tfs(const std::vector<float> & probs, const std::vector<float>
 
     llama_token_data_array cur_p = { cur.data(), cur.size(), false };
     DUMP(&cur_p);
-    TEST(llama_constraint_init_tail_free(z, 1), &cur_p);
+    APPLY(llama_constraint_init_tail_free(z, 1), &cur_p);
     DUMP(&cur_p);
 
     GGML_ASSERT(cur_p.size == expected_probs.size());
@@ -102,9 +102,9 @@ static void test_min_p(const std::vector<float> & probs, const std::vector<float
 
     llama_token_data_array cur_p = { cur.data(), cur.size(), false };
     DUMP(&cur_p);
-    TEST(llama_constraint_init_min_p(p, 1), &cur_p);
+    APPLY(llama_constraint_init_min_p(p, 1), &cur_p);
     DUMP(&cur_p);
-    TEST(llama_constraint_init_softmax(), &cur_p);
+    APPLY(llama_constraint_init_softmax(), &cur_p);
 
     GGML_ASSERT(cur_p.size == expected_probs.size());
     for (size_t i = 0; i < cur_p.size; i++) {
@@ -124,7 +124,7 @@ static void test_typical(const std::vector<float> & probs, const std::vector<flo
 
     llama_token_data_array cur_p = { cur.data(), cur.size(), false };
     DUMP(&cur_p);
-    TEST(llama_constraint_init_typical(p, 1), &cur_p);
+    APPLY(llama_constraint_init_typical(p, 1), &cur_p);
     DUMP(&cur_p);
 
     GGML_ASSERT(cur_p.size == expected_probs.size());
@@ -154,10 +154,10 @@ static void test_penalties(
     }
 
     llama_token_data_array cur_p = { cur.data(), cur.size(), false };
-    TEST(llama_constraint_init_softmax(), &cur_p);
+    APPLY(llama_constraint_init_softmax(), &cur_p);
     DUMP(&cur_p);
     llama_constraint_penalties_impl(&cur_p, token_count, repeat_penalty, alpha_frequency, alpha_presence); // TODO: avoid
-    TEST(llama_constraint_init_softmax(), &cur_p);
+    APPLY(llama_constraint_init_softmax(), &cur_p);
     DUMP(&cur_p);
 
     GGML_ASSERT(cur_p.size == expected_probs.size());
@@ -182,16 +182,16 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler
 
     for (auto s : samplers_sequence) {
         switch (s){
-            case 'k': TEST(llama_constraint_init_top_k(top_k, 1), &cur_p); break;
+            case 'k': APPLY(llama_constraint_init_top_k(top_k), &cur_p); break;
             case 'f': GGML_ABORT("tail_free test not implemented");
             case 'y': GGML_ABORT("typical test not implemented");
-            case 'p': TEST(llama_constraint_init_top_p(top_p, 1), &cur_p); break;
-            case 'm': TEST(llama_constraint_init_min_p(min_p, 1), &cur_p); break;
+            case 'p': APPLY(llama_constraint_init_top_p(top_p, 1), &cur_p); break;
+            case 'm': APPLY(llama_constraint_init_min_p(min_p, 1), &cur_p); break;
             case 't': GGML_ABORT("temperature test not implemented");
             default : GGML_ABORT("Unknown sampler");
         }
 
-        TEST(llama_constraint_init_softmax(), &cur_p); // make sure tokens are sorted for tests
+        APPLY(llama_constraint_init_softmax(), &cur_p); // make sure tokens are sorted for tests
 
         const int size = cur_p.size;
 