diff --git a/common/sampling.cpp b/common/sampling.cpp index f5edd87a6..b528d4929 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -47,9 +47,6 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st lparams.seed = params.seed; lparams.n_prev = params.n_prev; - lparams.mirostat = params.mirostat; - lparams.mirostat_tau = params.mirostat_tau; - lparams.mirostat_eta = params.mirostat_eta; auto * result = new gpt_sampler { /* .params = */ params, @@ -69,29 +66,39 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st /* .smpl = */ llama_sampler_init(model, lparams) }; - for (const auto & cnstr : params.constraints) { - switch (cnstr) { - case GPT_CONSTRAINT_TYPE_TOP_K: - llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_k (params.top_k, params.min_keep)); - break; - case GPT_CONSTRAINT_TYPE_TOP_P: - llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_p (params.top_p, params.min_keep)); - break; - case GPT_CONSTRAINT_TYPE_MIN_P: - llama_sampler_add_constraint(result->smpl, llama_constraint_init_min_p (params.min_p, params.min_keep)); - break; - case GPT_CONSTRAINT_TYPE_TFS_Z: - llama_sampler_add_constraint(result->smpl, llama_constraint_init_tail_free(params.tfs_z, params.min_keep)); - break; - case GPT_CONSTRAINT_TYPE_TYPICAL_P: - llama_sampler_add_constraint(result->smpl, llama_constraint_init_typical (params.typ_p, params.min_keep)); - break; - case GPT_CONSTRAINT_TYPE_TEMPERATURE: - llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); - break; - default: - GGML_ASSERT(false && "unknown constraint type"); + if (params.mirostat == 0) { + for (const auto & cnstr : params.constraints) { + switch (cnstr) { + case GPT_CONSTRAINT_TYPE_TOP_K: + llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_k (params.top_k, params.min_keep)); + break; + case GPT_CONSTRAINT_TYPE_TOP_P: + llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_p (params.top_p, params.min_keep)); + break; + case GPT_CONSTRAINT_TYPE_MIN_P: + llama_sampler_add_constraint(result->smpl, llama_constraint_init_min_p (params.min_p, params.min_keep)); + break; + case GPT_CONSTRAINT_TYPE_TFS_Z: + llama_sampler_add_constraint(result->smpl, llama_constraint_init_tail_free(params.tfs_z, params.min_keep)); + break; + case GPT_CONSTRAINT_TYPE_TYPICAL_P: + llama_sampler_add_constraint(result->smpl, llama_constraint_init_typical (params.typ_p, params.min_keep)); + break; + case GPT_CONSTRAINT_TYPE_TEMPERATURE: + llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); + break; + default: + GGML_ASSERT(false && "unknown constraint type"); + } } + } else if (params.mirostat == 1) { + llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp(params.temp)); + llama_sampler_add_constraint(result->smpl, llama_constraint_init_mirostat(model, params.mirostat_tau, params.mirostat_eta)); + } else if (params.mirostat == 2) { + llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp(params.temp)); + llama_sampler_add_constraint(result->smpl, llama_constraint_init_mirostat_v2(params.mirostat_tau, params.mirostat_eta)); + } else { + GGML_ASSERT(false && "unknown mirostat version"); } return result; @@ -153,7 +160,6 @@ static llama_token gpt_sampler_sample( struct llama_sampler * smpl, struct llama_token_data_array * cur_p, float temp, - int mirostat, int n_probs) { llama_token res = 0; @@ -167,24 +173,20 @@ static llama_token gpt_sampler_sample( // apply all sampling constraints and then sample llama_sampler_apply(smpl, cur_p); - if (mirostat != 0) { - res = llama_sampler_sample_mirostat(smpl, cur_p); - } else { - res = llama_sampler_sample_dist(smpl, cur_p); + res = llama_sampler_sample_dist(smpl, cur_p); - //{ - // const int n_top = 10; - // LOG("top %d candidates:\n", n_top); + //{ + // const int n_top = 10; + // LOG("top %d candidates:\n", n_top); - // for (int i = 0; i < n_top; i++) { - // const llama_token id = cur_p.data[i].id; - // (void)id; // To avoid a warning that id is unused when logging is disabled. - // LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl, id).c_str(), cur_p.data[i].p); - // } - //} + // for (int i = 0; i < n_top; i++) { + // const llama_token id = cur_p.data[i].id; + // (void)id; // To avoid a warning that id is unused when logging is disabled. + // LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl, id).c_str(), cur_p.data[i].p); + // } + //} - //LOG("sampled token: %5d: '%s'\n", res, llama_token_to_piece(smpl, res).c_str()); - } + //LOG("sampled token: %5d: '%s'\n", res, llama_token_to_piece(smpl, res).c_str()); } return res; @@ -208,7 +210,7 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context llama_constraint_apply(pnlt, cur_p); // first, sample the token without any grammar constraints - const llama_token id = gpt_sampler_sample(smpl, nullptr, params.temp, params.mirostat, params.n_probs); + const llama_token id = gpt_sampler_sample(smpl, nullptr, params.temp, params.n_probs); // check if it the sampled token fits the grammar { @@ -231,7 +233,7 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context llama_constraint_apply(pnlt, cur_p); llama_constraint_apply(grmr, cur_p); - return gpt_sampler_sample(smpl, cur_p, params.temp, params.mirostat, params.n_probs); + return gpt_sampler_sample(smpl, cur_p, params.temp, params.n_probs); } void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p) { diff --git a/include/llama.h b/include/llama.h index bf67f24c7..8c2a2aff9 100644 --- a/include/llama.h +++ b/include/llama.h @@ -369,16 +369,18 @@ extern "C" { float bias; } llama_logit_bias; + enum llama_sampler_type { + LLAMA_SAMPLER_TYPE_GREEDY = 0, + LLAMA_SAMPLER_TYPE_DIST = 1, + }; + typedef struct llama_sampler_params { uint32_t seed; // the seed used to initialize the rng of the sampler int32_t n_prev; // size of ring buffer to keep previous accepted tokens (needed for llama_sampler_prev_ API) - int32_t mirostat; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 - float mirostat_tau; // target entropy - float mirostat_eta; // learning rate - - // TODO: add type of sampler: greedy, dist, mirostat, etc. + // TODO: will be used by the llama_decode_with_sampler() API in the future + enum llama_sampler_type type; } llama_sampler_params; // performance timing information @@ -1005,17 +1007,18 @@ extern "C" { // // - Samplers // The llama_sampler samples a token based on the candidate token probabilities. Before the actual sampling, the - // sampler can apply a sequence of constraints to the candidate tokens. + // sampler can apply a sequence of constraints in order to modify the probabilities of the candidates. // // The llama_sampler object contains the entire sampling information: // // - RNG state (seed and generator) // - Custom set of constraints (see llama_sampler_add_constraint) - // - Sampling method (greedy, dist, mirostat) + // - Sampling method (greedy, dist) // - Previous tokens // // In the future, it will be utilized offload the sampling to the backends (e.g. GPU). // + // TODO: in the future, the entire API should be changed to accept llama_vocab, instead of llama_model // constraints @@ -1041,14 +1044,23 @@ extern "C" { llama_constraint_context_t ctx; }; - LLAMA_API struct llama_constraint * llama_constraint_init_softmax (void); - LLAMA_API struct llama_constraint * llama_constraint_init_top_k (int32_t k, int32_t min_keep); - LLAMA_API struct llama_constraint * llama_constraint_init_top_p (float p, int32_t min_keep); - LLAMA_API struct llama_constraint * llama_constraint_init_min_p (float p, int32_t min_keep); - LLAMA_API struct llama_constraint * llama_constraint_init_tail_free (float z, int32_t min_keep); - LLAMA_API struct llama_constraint * llama_constraint_init_typical (float p, int32_t min_keep); - LLAMA_API struct llama_constraint * llama_constraint_init_temp (float t); - LLAMA_API struct llama_constraint * llama_constraint_init_temp_ext (float t, float delta, float exponent); + LLAMA_API struct llama_constraint * llama_constraint_init_softmax (void); + LLAMA_API struct llama_constraint * llama_constraint_init_top_k (int32_t k, int32_t min_keep); + LLAMA_API struct llama_constraint * llama_constraint_init_top_p (float p, int32_t min_keep); + LLAMA_API struct llama_constraint * llama_constraint_init_min_p (float p, int32_t min_keep); + LLAMA_API struct llama_constraint * llama_constraint_init_tail_free (float z, int32_t min_keep); + LLAMA_API struct llama_constraint * llama_constraint_init_typical (float p, int32_t min_keep); + LLAMA_API struct llama_constraint * llama_constraint_init_temp (float t); + LLAMA_API struct llama_constraint * llama_constraint_init_temp_ext (float t, float delta, float exponent); + + LLAMA_API struct llama_constraint * llama_constraint_init_mirostat( + const struct llama_model * model, + float tau, + float eta); + + LLAMA_API struct llama_constraint * llama_constraint_init_mirostat_v2( + float tau, + float eta); LLAMA_API struct llama_constraint * llama_constraint_init_grammar( const struct llama_model * model, @@ -1095,9 +1107,8 @@ extern "C" { LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token); LLAMA_API void llama_sampler_apply (struct llama_sampler * smpl, llama_token_data_array * cur_p); - LLAMA_API llama_token llama_sampler_sample_dist (struct llama_sampler * smpl, llama_token_data_array * cur_p); - LLAMA_API llama_token llama_sampler_sample_greedy (struct llama_sampler * smpl, llama_token_data_array * cur_p, bool probs); - LLAMA_API llama_token llama_sampler_sample_mirostat(struct llama_sampler * smpl, llama_token_data_array * cur_p); + LLAMA_API llama_token llama_sampler_sample_dist (struct llama_sampler * smpl, llama_token_data_array * cur_p); + LLAMA_API llama_token llama_sampler_sample_greedy(struct llama_sampler * smpl, llama_token_data_array * cur_p, bool probs); /// @details Get the number of accepted tokens so far (max of n_prev) LLAMA_API int llama_sampler_n_prev(const struct llama_sampler * smpl); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 733957fdd..385f7bec1 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -450,8 +450,8 @@ struct llama_constraint * llama_constraint_init_softmax_impl() { // top-k struct llama_constraint_context_top_k { - int32_t k; - size_t min_keep; + const int32_t k; + const size_t min_keep; }; static struct llama_constraint_i llama_constraint_top_k_i = { @@ -486,8 +486,8 @@ struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k, size_t min // top-p struct llama_constraint_context_top_p { - float p; - size_t min_keep; + const float p; + const size_t min_keep; }; static struct llama_constraint_i llama_constraint_top_p_i = { @@ -522,8 +522,8 @@ struct llama_constraint * llama_constraint_init_top_p_impl(float p, size_t min_k // min-p struct llama_constraint_context_min_p { - float p; - size_t min_keep; + const float p; + const size_t min_keep; }; static struct llama_constraint_i llama_constraint_min_p_i = { @@ -558,8 +558,8 @@ struct llama_constraint * llama_constraint_init_min_p_impl(float p, size_t min_k // tail-free struct llama_constraint_context_tail_free { - float z; - size_t min_keep; + const float z; + const size_t min_keep; }; static struct llama_constraint_i llama_constraint_tail_free_i = { @@ -594,8 +594,8 @@ struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t m // typical struct llama_constraint_context_typical { - float p; - size_t min_keep; + const float p; + const size_t min_keep; }; static struct llama_constraint_i llama_constraint_typical_i = { @@ -630,7 +630,7 @@ struct llama_constraint * llama_constraint_init_typical_impl(float p, size_t min // temp struct llama_constraint_context_temp { - float temp; + const float temp; }; static struct llama_constraint_i llama_constraint_temp_i = { @@ -664,9 +664,9 @@ struct llama_constraint * llama_constraint_init_temp_impl(float temp) { // temp-ext struct llama_constraint_context_temp_ext { - float temp; - float delta; - float exponent; + const float temp; + const float delta; + const float exponent; }; static struct llama_constraint_i llama_constraint_temp_ext_i = { @@ -706,6 +706,176 @@ struct llama_constraint * llama_constraint_init_temp_ext_impl(float temp, float return result; } +// mirostat + +struct llama_constraint_context_mirostat { + const struct llama_vocab * vocab; + + const float tau; + const float eta; + + const int32_t m; + + float mu; + + std::vector cur; +}; + +static struct llama_constraint_i llama_constraint_mirostat_i = { + /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "mirostat"; }, + /* .accept = */ [](struct llama_constraint * cnstr, llama_token token) { + auto * ctx = (llama_constraint_context_mirostat *) cnstr->ctx; + + int32_t idx = -1; + for (size_t i = 0; i < ctx->cur.size(); ++i) { + if (ctx->cur[i].id == token) { + idx = i; + break; + } + } + + float observed_surprise = -log2f(ctx->cur[idx].p); + float e = observed_surprise - ctx->tau; + + // Update mu using the learning rate and error + ctx->mu = ctx->mu - ctx->eta * e; + }, + /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { + auto * ctx = (llama_constraint_context_mirostat *) cnstr->ctx; + + llama_constraint_softmax_impl(cur_p); + + // Estimate s_hat using the most probable m tokens + float s_hat = 0.0; + float sum_ti_bi = 0.0; + float sum_ti_sq = 0.0; + for (size_t i = 0; i < size_t(ctx->m - 1) && i < cur_p->size - 1; ++i) { + float t_i = logf(float(i + 2) / float(i + 1)); + float b_i = logf(cur_p->data[i].p / cur_p->data[i + 1].p); + sum_ti_bi += t_i * b_i; + sum_ti_sq += t_i * t_i; + } + s_hat = sum_ti_bi / sum_ti_sq; + + // Compute k from the estimated s_hat and target surprise value + float epsilon_hat = s_hat - 1; + float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->vocab->n_vocab, -epsilon_hat)), 1 / s_hat); + + llama_constraint_top_k_impl(cur_p, int(k), 1); + + // remember the order to be able to compute the distance later when accepting the token + ctx->cur.resize(cur_p->size); + for (size_t i = 0; i < cur_p->size; ++i) { + ctx->cur[i] = cur_p->data[i]; + } + }, + /* .reset = */ [](struct llama_constraint * cnstr) { + auto * ctx = (llama_constraint_context_mirostat *) cnstr->ctx; + ctx->mu = 0.0f; + }, + /* .copy = */ [](const struct llama_constraint * cnstr) { + const auto * ctx = (const llama_constraint_context_mirostat *) cnstr->ctx; + return llama_constraint_init_mirostat_impl(*ctx->vocab, ctx->tau, ctx->eta, ctx->m); + }, + /* .free = */ [](struct llama_constraint * cnstr) { + delete (llama_constraint_context_mirostat *) cnstr->ctx; + }, +}; + +struct llama_constraint * llama_constraint_init_mirostat_impl( + const struct llama_vocab & vocab, + float tau, + float eta, + int32_t m) { + struct llama_constraint * result = new llama_constraint { + /* .iface = */ &llama_constraint_mirostat_i, + /* .ctx = */ new llama_constraint_context_mirostat { + /*.vocab =*/ &vocab, + /*.tau =*/ tau, + /*.eta =*/ eta, + /*.m =*/ m, + /*.mu =*/ 0.0f, + /*.cur =*/ {}, + }, + }; + + return result; +} + +// mirostat v2 + +struct llama_constraint_context_mirostat_v2 { + const float tau; + const float eta; + + float mu; + + std::vector cur; +}; + +static struct llama_constraint_i llama_constraint_mirostat_v2_i = { + /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "mirostat-v2"; }, + /* .accept = */ [](struct llama_constraint * cnstr, llama_token token) { + auto * ctx = (llama_constraint_context_mirostat_v2 *) cnstr->ctx; + + int32_t idx = -1; + for (size_t i = 0; i < ctx->cur.size(); ++i) { + if (ctx->cur[i].id == token) { + idx = i; + break; + } + } + + float observed_surprise = -log2f(ctx->cur[idx].p); + float e = observed_surprise - ctx->tau; + + // Update mu using the learning rate and error + ctx->mu = ctx->mu - ctx->eta * e; + }, + /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { + auto * ctx = (llama_constraint_context_mirostat_v2 *) cnstr->ctx; + + llama_constraint_softmax_impl(cur_p); + + // Truncate the words with surprise values greater than mu + cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) { + return -log2f(candidate.p) > ctx->mu; + })); + + if (cur_p->size == 0) { + cur_p->size = 1; + } + + // Normalize the probabilities of the remaining words + llama_constraint_softmax_impl(cur_p); + }, + /* .reset = */ [](struct llama_constraint * cnstr) { + auto * ctx = (llama_constraint_context_mirostat_v2 *) cnstr->ctx; + ctx->mu = 0.0f; + }, + /* .copy = */ [](const struct llama_constraint * cnstr) { + const auto * ctx = (const llama_constraint_context_mirostat_v2 *) cnstr->ctx; + return llama_constraint_init_mirostat_v2_impl(ctx->tau, ctx->eta); + }, + /* .free = */ [](struct llama_constraint * cnstr) { + delete (llama_constraint_context_mirostat_v2 *) cnstr->ctx; + }, +}; + +struct llama_constraint * llama_constraint_init_mirostat_v2_impl(float tau, float eta) { + struct llama_constraint * result = new llama_constraint { + /* .iface = */ &llama_constraint_mirostat_v2_i, + /* .ctx = */ new llama_constraint_context_mirostat_v2 { + /*.tau =*/ tau, + /*.eta =*/ eta, + /*.mu =*/ 0.0f, + /*.cur =*/ {}, + }, + }; + + return result; +} + // grammar struct llama_constraint_context_grammar { @@ -796,13 +966,13 @@ struct llama_constraint * llama_constraint_init_grammar_impl(const struct llama_ struct llama_constraint_context_penalties { const struct llama_vocab * vocab; - int32_t penalty_last_n; - float penalty_repeat; - float penalty_freq; - float penalty_present; + const int32_t penalty_last_n; + const float penalty_repeat; + const float penalty_freq; + const float penalty_present; - bool penalize_nl; - bool ignore_eos; + const bool penalize_nl; + const bool ignore_eos; ring_buffer prev; }; @@ -980,7 +1150,6 @@ struct llama_sampler * llama_sampler_init_impl(const struct llama_vocab & vocab, /* .rng = */ std::mt19937(params.seed), - /* .mirostat_mu = */ 0.0f, /* .prev = */ { (size_t) params.n_prev }, /* .constraints = */ {}, /* .cur = */ {}, @@ -1011,7 +1180,6 @@ struct llama_sampler * llama_sampler_cp_impl(const struct llama_sampler & smpl) /* .rng = */ smpl.rng, - /* .mirostat_mu = */ smpl.mirostat_mu, /* .prev = */ smpl.prev, /* .constraints = */ {}, /* .cur = */ {}, @@ -1077,74 +1245,6 @@ int llama_sampler_n_prev_impl(const struct llama_sampler & smpl) { return smpl.prev.size(); } -llama_token llama_sampler_sample_mirostat_impl(struct llama_token_data_array * cur_p, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu) { - llama_constraint_softmax_impl(cur_p); - - // Estimate s_hat using the most probable m tokens - float s_hat = 0.0; - float sum_ti_bi = 0.0; - float sum_ti_sq = 0.0; - for (size_t i = 0; i < size_t(m - 1) && i < cur_p->size - 1; ++i) { - float t_i = logf(float(i + 2) / float(i + 1)); - float b_i = logf(cur_p->data[i].p / cur_p->data[i + 1].p); - sum_ti_bi += t_i * b_i; - sum_ti_sq += t_i * t_i; - } - s_hat = sum_ti_bi / sum_ti_sq; - - // Compute k from the estimated s_hat and target surprise value - float epsilon_hat = s_hat - 1; - float k = powf((epsilon_hat * powf(2, mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat); - - // Sample the next word X using top-k sampling - llama_constraint_top_k_impl(cur_p, int(k), 1); - llama_token X = llama_sampler_sample_dist_impl(cur_p, rng); - - // Compute error as the difference between observed surprise and target surprise value - size_t X_idx = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) { - return candidate.id == X; - })); - float observed_surprise = -log2f(cur_p->data[X_idx].p); - float e = observed_surprise - tau; - - // Update mu using the learning rate and error - mu = mu - eta * e; - - return X; -} - -llama_token llama_sampler_sample_mirostat_v2_impl(struct llama_token_data_array * cur_p, std::mt19937 & rng, float tau, float eta, float & mu) { - llama_constraint_softmax_impl(cur_p); - - // Truncate the words with surprise values greater than mu - cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) { - return -log2f(candidate.p) > mu; - })); - - if (cur_p->size == 0) { - cur_p->size = 1; - } - - // Normalize the probabilities of the remaining words - llama_constraint_softmax_impl(cur_p); - - // Sample the next word X from the remaining words - llama_token X = llama_sampler_sample_dist_impl(cur_p, rng); - - // Compute error as the difference between observed surprise and target surprise value - size_t X_idx = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) { - return candidate.id == X; - })); - - float observed_surprise = -log2f(cur_p->data[X_idx].p); - float e = observed_surprise - tau; - - // Update mu using the learning rate and error - mu = mu - eta * e; - - return X; -} - llama_token llama_sampler_sample_greedy_impl(llama_token_data_array * cur_p, bool probs) { if (probs) { // if probs are needed, we apply softmax to get the probabilities diff --git a/src/llama-sampling.h b/src/llama-sampling.h index e4f910886..aad9f311a 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -20,16 +20,38 @@ void llama_constraint_penalties_impl( // constraints -struct llama_constraint * llama_constraint_init_softmax_impl (); -struct llama_constraint * llama_constraint_init_top_k_impl (int32_t k, size_t min_keep); -struct llama_constraint * llama_constraint_init_top_p_impl (float p, size_t min_keep); -struct llama_constraint * llama_constraint_init_min_p_impl (float p, size_t min_keep); -struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t min_keep); -struct llama_constraint * llama_constraint_init_typical_impl (float p, size_t min_keep); -struct llama_constraint * llama_constraint_init_temp_impl (float t); -struct llama_constraint * llama_constraint_init_temp_ext_impl (float t, float delta, float exponent); +struct llama_constraint * llama_constraint_init_softmax_impl (); +struct llama_constraint * llama_constraint_init_top_k_impl (int32_t k, size_t min_keep); +struct llama_constraint * llama_constraint_init_top_p_impl (float p, size_t min_keep); +struct llama_constraint * llama_constraint_init_min_p_impl (float p, size_t min_keep); +struct llama_constraint * llama_constraint_init_tail_free_impl (float z, size_t min_keep); +struct llama_constraint * llama_constraint_init_typical_impl (float p, size_t min_keep); +struct llama_constraint * llama_constraint_init_temp_impl (float t); +struct llama_constraint * llama_constraint_init_temp_ext_impl (float t, float delta, float exponent); -struct llama_constraint * llama_constraint_init_grammar_impl ( +/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. +/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. +/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. +/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. +/// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. +/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. + +struct llama_constraint * llama_constraint_init_mirostat_impl( + const struct llama_vocab & vocab, + float tau, + float eta, + int32_t m); + +/// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. +/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. +/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. +/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. +/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. +struct llama_constraint * llama_constraint_init_mirostat_v2_impl( + float tau, + float eta); + +struct llama_constraint * llama_constraint_init_grammar_impl( const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root); @@ -67,8 +89,6 @@ struct llama_sampler { std::mt19937 rng; - float mirostat_mu; - ring_buffer prev; std::vector constraints; @@ -97,20 +117,5 @@ void llama_sampler_apply_impl (struct llama_sampler & smpl, struct llama_token_d llama_token llama_sampler_prev_impl (const struct llama_sampler & smpl, int ith); int llama_sampler_n_prev_impl(const struct llama_sampler & smpl); -/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. -/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. -/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. -/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. -/// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. -/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. -llama_token llama_sampler_sample_mirostat_impl (struct llama_token_data_array * cur_p, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu); - -/// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. -/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. -/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. -/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. -/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. -llama_token llama_sampler_sample_mirostat_v2_impl(struct llama_token_data_array * cur_p, std::mt19937 & rng, float tau, float eta, float & mu); - llama_token llama_sampler_sample_greedy_impl(struct llama_token_data_array * cur_p, bool probs); llama_token llama_sampler_sample_dist_impl (struct llama_token_data_array * cur_p, std::mt19937 & rng); diff --git a/src/llama.cpp b/src/llama.cpp index 712d9cfb5..8f6503152 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -17939,9 +17939,7 @@ struct llama_sampler_params llama_sampler_default_params() { struct llama_sampler_params result = { /*.seed =*/ LLAMA_DEFAULT_SEED, /*.n_prev =*/ 256, - /*.mirostat =*/ 0, - /*.mirostat_tau =*/ 5.00f, - /*.mirostat_eta =*/ 0.10f, + /*.type =*/ LLAMA_SAMPLER_TYPE_GREEDY, }; return result; @@ -20641,6 +20639,14 @@ struct llama_constraint * llama_constraint_init_temp_ext(float temp, float delta return llama_constraint_init_temp_ext_impl(temp, delta, exponent); } +struct llama_constraint * llama_constraint_init_mirostat(const struct llama_model * model, float tau, float eta) { + return llama_constraint_init_mirostat_impl(model->vocab, tau, eta, 100); +} + +struct llama_constraint * llama_constraint_init_mirostat_v2(float tau, float eta) { + return llama_constraint_init_mirostat_v2_impl(tau, eta); +} + struct llama_constraint * llama_constraint_init_grammar(const struct llama_model * model, const char * grammar_str, const char * grammar_root) { return llama_constraint_init_grammar_impl(model->vocab, grammar_str, grammar_root); } @@ -20741,40 +20747,6 @@ void llama_sampler_apply(struct llama_sampler * smpl, llama_token_data_array * c llama_sampler_apply_impl(*smpl, cur_p); } -llama_token llama_sampler_sample_mirostat(struct llama_sampler * smpl, llama_token_data_array * cur_p) { - time_meas tm(smpl->t_sample_us); - - if (cur_p == nullptr) { - cur_p = &smpl->cur_p; - } - - const auto type = smpl->params.mirostat; - - llama_token res; - - if (type == 1) { - res = llama_sampler_sample_mirostat_impl(cur_p, - smpl->rng, - smpl->params.mirostat_tau, - smpl->params.mirostat_eta, - 100, - smpl->vocab->n_vocab, - smpl->mirostat_mu); - } else if (type == 2) { - res = llama_sampler_sample_mirostat_v2_impl(cur_p, - smpl->rng, - smpl->params.mirostat_tau, - smpl->params.mirostat_eta, - smpl->mirostat_mu); - } else { - GGML_ABORT("invalid mirostat type: %d", type); - } - - smpl->n_sample++; - - return res; -} - llama_token llama_sampler_sample_greedy(struct llama_sampler * smpl, llama_token_data_array * cur_p, bool probs) { time_meas tm(smpl->t_sample_us);