sampling : convert mirostat samplers to constraints
ggml-ci
parent 1a0de0b781
commit 0e1378c844

5 changed files with 304 additions and 214 deletions
@@ -47,9 +47,6 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
 
     lparams.seed = params.seed;
     lparams.n_prev = params.n_prev;
-    lparams.mirostat = params.mirostat;
-    lparams.mirostat_tau = params.mirostat_tau;
-    lparams.mirostat_eta = params.mirostat_eta;
 
     auto * result = new gpt_sampler {
         /* .params = */ params,
@@ -69,29 +66,39 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
         /* .smpl = */ llama_sampler_init(model, lparams)
     };
 
-    for (const auto & cnstr : params.constraints) {
-        switch (cnstr) {
-            case GPT_CONSTRAINT_TYPE_TOP_K:
-                llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_k    (params.top_k, params.min_keep));
-                break;
-            case GPT_CONSTRAINT_TYPE_TOP_P:
-                llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_p    (params.top_p, params.min_keep));
-                break;
-            case GPT_CONSTRAINT_TYPE_MIN_P:
-                llama_sampler_add_constraint(result->smpl, llama_constraint_init_min_p    (params.min_p, params.min_keep));
-                break;
-            case GPT_CONSTRAINT_TYPE_TFS_Z:
-                llama_sampler_add_constraint(result->smpl, llama_constraint_init_tail_free(params.tfs_z, params.min_keep));
-                break;
-            case GPT_CONSTRAINT_TYPE_TYPICAL_P:
-                llama_sampler_add_constraint(result->smpl, llama_constraint_init_typical  (params.typ_p, params.min_keep));
-                break;
-            case GPT_CONSTRAINT_TYPE_TEMPERATURE:
-                llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
-                break;
-            default:
-                GGML_ASSERT(false && "unknown constraint type");
-        }
-    }
+    if (params.mirostat == 0) {
+        for (const auto & cnstr : params.constraints) {
+            switch (cnstr) {
+                case GPT_CONSTRAINT_TYPE_TOP_K:
+                    llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_k    (params.top_k, params.min_keep));
+                    break;
+                case GPT_CONSTRAINT_TYPE_TOP_P:
+                    llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_p    (params.top_p, params.min_keep));
+                    break;
+                case GPT_CONSTRAINT_TYPE_MIN_P:
+                    llama_sampler_add_constraint(result->smpl, llama_constraint_init_min_p    (params.min_p, params.min_keep));
+                    break;
+                case GPT_CONSTRAINT_TYPE_TFS_Z:
+                    llama_sampler_add_constraint(result->smpl, llama_constraint_init_tail_free(params.tfs_z, params.min_keep));
+                    break;
+                case GPT_CONSTRAINT_TYPE_TYPICAL_P:
+                    llama_sampler_add_constraint(result->smpl, llama_constraint_init_typical  (params.typ_p, params.min_keep));
+                    break;
+                case GPT_CONSTRAINT_TYPE_TEMPERATURE:
+                    llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
+                    break;
+                default:
+                    GGML_ASSERT(false && "unknown constraint type");
+            }
+        }
+    } else if (params.mirostat == 1) {
+        llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp(params.temp));
+        llama_sampler_add_constraint(result->smpl, llama_constraint_init_mirostat(model, params.mirostat_tau, params.mirostat_eta));
+    } else if (params.mirostat == 2) {
+        llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp(params.temp));
+        llama_sampler_add_constraint(result->smpl, llama_constraint_init_mirostat_v2(params.mirostat_tau, params.mirostat_eta));
+    } else {
+        GGML_ASSERT(false && "unknown mirostat version");
+    }
 
     return result;
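For reference, the params.mirostat == 2 branch above is nothing more than the public constraint API composed by hand. A minimal sketch of the equivalent direct calls, assuming a llama_sampler created elsewhere with llama_sampler_init() and a prepared llama_token_data_array (both outside this diff); the values are examples only:

    // Illustrative sketch, not part of the commit.
    llama_sampler_add_constraint(smpl, llama_constraint_init_temp(0.8f));
    llama_sampler_add_constraint(smpl, llama_constraint_init_mirostat_v2(/*tau =*/ 5.0f, /*eta =*/ 0.1f));

    llama_sampler_apply(smpl, cur_p);                              // temperature, then the mirostat-v2 surprise cutoff
    const llama_token id = llama_sampler_sample_dist(smpl, cur_p); // no dedicated mirostat sampling entry point anymore
    llama_sampler_accept(smpl, id);                                // assumed to reach the constraint's accept(), which updates mu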
@@ -153,7 +160,6 @@ static llama_token gpt_sampler_sample(
         struct llama_sampler * smpl,
         struct llama_token_data_array * cur_p,
         float temp,
-        int mirostat,
         int n_probs) {
     llama_token res = 0;
 
@@ -167,24 +173,20 @@ static llama_token gpt_sampler_sample(
         // apply all sampling constraints and then sample
         llama_sampler_apply(smpl, cur_p);
 
-        if (mirostat != 0) {
-            res = llama_sampler_sample_mirostat(smpl, cur_p);
-        } else {
-            res = llama_sampler_sample_dist(smpl, cur_p);
+        res = llama_sampler_sample_dist(smpl, cur_p);
 
         //{
         //    const int n_top = 10;
         //    LOG("top %d candidates:\n", n_top);
 
         //    for (int i = 0; i < n_top; i++) {
         //        const llama_token id = cur_p.data[i].id;
         //        (void)id; // To avoid a warning that id is unused when logging is disabled.
         //        LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl, id).c_str(), cur_p.data[i].p);
         //    }
         //}
 
         //LOG("sampled token: %5d: '%s'\n", res, llama_token_to_piece(smpl, res).c_str());
-        }
     }
 
     return res;
@@ -208,7 +210,7 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
     llama_constraint_apply(pnlt, cur_p);
 
     // first, sample the token without any grammar constraints
-    const llama_token id = gpt_sampler_sample(smpl, nullptr, params.temp, params.mirostat, params.n_probs);
+    const llama_token id = gpt_sampler_sample(smpl, nullptr, params.temp, params.n_probs);
 
     // check if it the sampled token fits the grammar
     {
@@ -231,7 +233,7 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
     llama_constraint_apply(pnlt, cur_p);
     llama_constraint_apply(grmr, cur_p);
 
-    return gpt_sampler_sample(smpl, cur_p, params.temp, params.mirostat, params.n_probs);
+    return gpt_sampler_sample(smpl, cur_p, params.temp, params.n_probs);
 }
 
 void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p) {
@@ -369,16 +369,18 @@ extern "C" {
         float bias;
     } llama_logit_bias;
 
+    enum llama_sampler_type {
+        LLAMA_SAMPLER_TYPE_GREEDY = 0,
+        LLAMA_SAMPLER_TYPE_DIST   = 1,
+    };
+
     typedef struct llama_sampler_params {
         uint32_t seed; // the seed used to initialize the rng of the sampler
 
         int32_t n_prev; // size of ring buffer to keep previous accepted tokens (needed for llama_sampler_prev_ API)
 
-        int32_t mirostat;     // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
-        float   mirostat_tau; // target entropy
-        float   mirostat_eta; // learning rate
-
-        // TODO: add type of sampler: greedy, dist, mirostat, etc.
+        // TODO: will be used by the llama_decode_with_sampler() API in the future
+        enum llama_sampler_type type;
     } llama_sampler_params;
 
     // performance timing information
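The three mirostat fields are gone from llama_sampler_params; the struct now only carries the seed, the ring-buffer size, and the new sampler type. A short sketch of filling it in, assuming llama_sampler_default_params() (updated later in this commit) as the starting point:

    // Illustrative sketch, not part of the commit.
    struct llama_sampler_params sparams = llama_sampler_default_params(); // seed, n_prev, type = LLAMA_SAMPLER_TYPE_GREEDY

    sparams.seed   = 1234;
    sparams.n_prev = 64;
    sparams.type   = LLAMA_SAMPLER_TYPE_DIST; // per the TODO above, to be consumed by llama_decode_with_sampler() later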
@@ -1005,17 +1007,18 @@ extern "C" {
     //
     // - Samplers
     //   The llama_sampler samples a token based on the candidate token probabilities. Before the actual sampling, the
-    //   sampler can apply a sequence of constraints to the candidate tokens.
+    //   sampler can apply a sequence of constraints in order to modify the probabilities of the candidates.
     //
     //   The llama_sampler object contains the entire sampling information:
     //
     //   - RNG state (seed and generator)
     //   - Custom set of constraints (see llama_sampler_add_constraint)
-    //   - Sampling method (greedy, dist, mirostat)
+    //   - Sampling method (greedy, dist)
     //   - Previous tokens
     //
     //   In the future, it will be utilized offload the sampling to the backends (e.g. GPU).
     //
+    // TODO: in the future, the entire API should be changed to accept llama_vocab, instead of llama_model
 
     // constraints
 
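The comment block above is the design in a nutshell: constraints only reshape the candidate probabilities, while the sampler owns the RNG, the previous-token ring buffer, and the final pick. A hedged usage sketch of that flow with an ordinary top-k / top-p / temperature chain, assuming a loaded model and a prepared llama_token_data_array * cur_p (values arbitrary, cleanup omitted):

    // Illustrative sketch, not part of the commit: the documented flow.
    struct llama_sampler * smpl = llama_sampler_init(model, llama_sampler_default_params());

    llama_sampler_add_constraint(smpl, llama_constraint_init_top_k(40, 1));
    llama_sampler_add_constraint(smpl, llama_constraint_init_top_p(0.95f, 1));
    llama_sampler_add_constraint(smpl, llama_constraint_init_temp (0.70f));

    llama_sampler_apply(smpl, cur_p);                        // run the constraint chain over the candidates
    llama_token id = llama_sampler_sample_dist(smpl, cur_p); // or llama_sampler_sample_greedy(smpl, cur_p, false)
    llama_sampler_accept(smpl, id);                          // record the token in the prev ring buffer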
@@ -1041,14 +1044,23 @@ extern "C" {
         llama_constraint_context_t ctx;
     };
 
     LLAMA_API struct llama_constraint * llama_constraint_init_softmax   (void);
     LLAMA_API struct llama_constraint * llama_constraint_init_top_k     (int32_t k, int32_t min_keep);
     LLAMA_API struct llama_constraint * llama_constraint_init_top_p     (float   p, int32_t min_keep);
     LLAMA_API struct llama_constraint * llama_constraint_init_min_p     (float   p, int32_t min_keep);
     LLAMA_API struct llama_constraint * llama_constraint_init_tail_free (float   z, int32_t min_keep);
     LLAMA_API struct llama_constraint * llama_constraint_init_typical   (float   p, int32_t min_keep);
     LLAMA_API struct llama_constraint * llama_constraint_init_temp      (float   t);
     LLAMA_API struct llama_constraint * llama_constraint_init_temp_ext  (float   t, float delta, float exponent);
+
+    LLAMA_API struct llama_constraint * llama_constraint_init_mirostat(
+            const struct llama_model * model,
+                               float   tau,
+                               float   eta);
+
+    LLAMA_API struct llama_constraint * llama_constraint_init_mirostat_v2(
+                               float   tau,
+                               float   eta);
 
     LLAMA_API struct llama_constraint * llama_constraint_init_grammar(
             const struct llama_model * model,
@@ -1095,9 +1107,8 @@ extern "C" {
     LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token);
     LLAMA_API void llama_sampler_apply (struct llama_sampler * smpl, llama_token_data_array * cur_p);
 
     LLAMA_API llama_token llama_sampler_sample_dist    (struct llama_sampler * smpl, llama_token_data_array * cur_p);
-    LLAMA_API llama_token llama_sampler_sample_greedy  (struct llama_sampler * smpl, llama_token_data_array * cur_p, bool probs);
-    LLAMA_API llama_token llama_sampler_sample_mirostat(struct llama_sampler * smpl, llama_token_data_array * cur_p);
+    LLAMA_API llama_token llama_sampler_sample_greedy(struct llama_sampler * smpl, llama_token_data_array * cur_p, bool probs);
 
     /// @details Get the number of accepted tokens so far (max of n_prev)
     LLAMA_API int llama_sampler_n_prev(const struct llama_sampler * smpl);
@@ -450,8 +450,8 @@ struct llama_constraint * llama_constraint_init_softmax_impl() {
 // top-k
 
 struct llama_constraint_context_top_k {
-    int32_t k;
-    size_t  min_keep;
+    const int32_t k;
+    const size_t  min_keep;
 };
 
 static struct llama_constraint_i llama_constraint_top_k_i = {
@@ -486,8 +486,8 @@ struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k, size_t min
 // top-p
 
 struct llama_constraint_context_top_p {
-    float  p;
-    size_t min_keep;
+    const float  p;
+    const size_t min_keep;
 };
 
 static struct llama_constraint_i llama_constraint_top_p_i = {
@@ -522,8 +522,8 @@ struct llama_constraint * llama_constraint_init_top_p_impl(float p, size_t min_k
 // min-p
 
 struct llama_constraint_context_min_p {
-    float  p;
-    size_t min_keep;
+    const float  p;
+    const size_t min_keep;
 };
 
 static struct llama_constraint_i llama_constraint_min_p_i = {
@@ -558,8 +558,8 @@ struct llama_constraint * llama_constraint_init_min_p_impl(float p, size_t min_k
 // tail-free
 
 struct llama_constraint_context_tail_free {
-    float  z;
-    size_t min_keep;
+    const float  z;
+    const size_t min_keep;
 };
 
 static struct llama_constraint_i llama_constraint_tail_free_i = {
@@ -594,8 +594,8 @@ struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t m
 // typical
 
 struct llama_constraint_context_typical {
-    float  p;
-    size_t min_keep;
+    const float  p;
+    const size_t min_keep;
 };
 
 static struct llama_constraint_i llama_constraint_typical_i = {
@@ -630,7 +630,7 @@ struct llama_constraint * llama_constraint_init_typical_impl(float p, size_t min
 // temp
 
 struct llama_constraint_context_temp {
-    float temp;
+    const float temp;
 };
 
 static struct llama_constraint_i llama_constraint_temp_i = {
@@ -664,9 +664,9 @@ struct llama_constraint * llama_constraint_init_temp_impl(float temp) {
 // temp-ext
 
 struct llama_constraint_context_temp_ext {
-    float temp;
-    float delta;
-    float exponent;
+    const float temp;
+    const float delta;
+    const float exponent;
 };
 
 static struct llama_constraint_i llama_constraint_temp_ext_i = {
@@ -706,6 +706,176 @@ struct llama_constraint * llama_constraint_init_temp_ext_impl(float temp, float
     return result;
 }
 
+// mirostat
+
+struct llama_constraint_context_mirostat {
+    const struct llama_vocab * vocab;
+
+    const float tau;
+    const float eta;
+
+    const int32_t m;
+
+    float mu;
+
+    std::vector<llama_token_data> cur;
+};
+
+static struct llama_constraint_i llama_constraint_mirostat_i = {
+    /* .name   = */ [](const struct llama_constraint * /*cnstr*/) { return "mirostat"; },
+    /* .accept = */ [](struct llama_constraint * cnstr, llama_token token) {
+        auto * ctx = (llama_constraint_context_mirostat *) cnstr->ctx;
+
+        int32_t idx = -1;
+        for (size_t i = 0; i < ctx->cur.size(); ++i) {
+            if (ctx->cur[i].id == token) {
+                idx = i;
+                break;
+            }
+        }
+
+        float observed_surprise = -log2f(ctx->cur[idx].p);
+        float e = observed_surprise - ctx->tau;
+
+        // Update mu using the learning rate and error
+        ctx->mu = ctx->mu - ctx->eta * e;
+    },
+    /* .apply  = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) {
+        auto * ctx = (llama_constraint_context_mirostat *) cnstr->ctx;
+
+        llama_constraint_softmax_impl(cur_p);
+
+        // Estimate s_hat using the most probable m tokens
+        float s_hat = 0.0;
+        float sum_ti_bi = 0.0;
+        float sum_ti_sq = 0.0;
+        for (size_t i = 0; i < size_t(ctx->m - 1) && i < cur_p->size - 1; ++i) {
+            float t_i = logf(float(i + 2) / float(i + 1));
+            float b_i = logf(cur_p->data[i].p / cur_p->data[i + 1].p);
+            sum_ti_bi += t_i * b_i;
+            sum_ti_sq += t_i * t_i;
+        }
+        s_hat = sum_ti_bi / sum_ti_sq;
+
+        // Compute k from the estimated s_hat and target surprise value
+        float epsilon_hat = s_hat - 1;
+        float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->vocab->n_vocab, -epsilon_hat)), 1 / s_hat);
+
+        llama_constraint_top_k_impl(cur_p, int(k), 1);
+
+        // remember the order to be able to compute the distance later when accepting the token
+        ctx->cur.resize(cur_p->size);
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            ctx->cur[i] = cur_p->data[i];
+        }
+    },
+    /* .reset  = */ [](struct llama_constraint * cnstr) {
+        auto * ctx = (llama_constraint_context_mirostat *) cnstr->ctx;
+        ctx->mu = 0.0f;
+    },
+    /* .copy   = */ [](const struct llama_constraint * cnstr) {
+        const auto * ctx = (const llama_constraint_context_mirostat *) cnstr->ctx;
+        return llama_constraint_init_mirostat_impl(*ctx->vocab, ctx->tau, ctx->eta, ctx->m);
+    },
+    /* .free   = */ [](struct llama_constraint * cnstr) {
+        delete (llama_constraint_context_mirostat *) cnstr->ctx;
+    },
+};
+
+struct llama_constraint * llama_constraint_init_mirostat_impl(
+        const struct llama_vocab & vocab,
+                           float   tau,
+                           float   eta,
+                         int32_t   m) {
+    struct llama_constraint * result = new llama_constraint {
+        /* .iface = */ &llama_constraint_mirostat_i,
+        /* .ctx   = */ new llama_constraint_context_mirostat {
+            /*.vocab =*/ &vocab,
+            /*.tau   =*/ tau,
+            /*.eta   =*/ eta,
+            /*.m     =*/ m,
+            /*.mu    =*/ 0.0f,
+            /*.cur   =*/ {},
+        },
+    };
+
+    return result;
+}
+
+// mirostat v2
+
+struct llama_constraint_context_mirostat_v2 {
+    const float tau;
+    const float eta;
+
+    float mu;
+
+    std::vector<llama_token_data> cur;
+};
+
+static struct llama_constraint_i llama_constraint_mirostat_v2_i = {
+    /* .name   = */ [](const struct llama_constraint * /*cnstr*/) { return "mirostat-v2"; },
+    /* .accept = */ [](struct llama_constraint * cnstr, llama_token token) {
+        auto * ctx = (llama_constraint_context_mirostat_v2 *) cnstr->ctx;
+
+        int32_t idx = -1;
+        for (size_t i = 0; i < ctx->cur.size(); ++i) {
+            if (ctx->cur[i].id == token) {
+                idx = i;
+                break;
+            }
+        }
+
+        float observed_surprise = -log2f(ctx->cur[idx].p);
+        float e = observed_surprise - ctx->tau;
+
+        // Update mu using the learning rate and error
+        ctx->mu = ctx->mu - ctx->eta * e;
+    },
+    /* .apply  = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) {
+        auto * ctx = (llama_constraint_context_mirostat_v2 *) cnstr->ctx;
+
+        llama_constraint_softmax_impl(cur_p);
+
+        // Truncate the words with surprise values greater than mu
+        cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) {
+            return -log2f(candidate.p) > ctx->mu;
+        }));
+
+        if (cur_p->size == 0) {
+            cur_p->size = 1;
+        }
+
+        // Normalize the probabilities of the remaining words
+        llama_constraint_softmax_impl(cur_p);
+    },
+    /* .reset  = */ [](struct llama_constraint * cnstr) {
+        auto * ctx = (llama_constraint_context_mirostat_v2 *) cnstr->ctx;
+        ctx->mu = 0.0f;
+    },
+    /* .copy   = */ [](const struct llama_constraint * cnstr) {
+        const auto * ctx = (const llama_constraint_context_mirostat_v2 *) cnstr->ctx;
+        return llama_constraint_init_mirostat_v2_impl(ctx->tau, ctx->eta);
+    },
+    /* .free   = */ [](struct llama_constraint * cnstr) {
+        delete (llama_constraint_context_mirostat_v2 *) cnstr->ctx;
+    },
+};
+
+struct llama_constraint * llama_constraint_init_mirostat_v2_impl(float tau, float eta) {
+    struct llama_constraint * result = new llama_constraint {
+        /* .iface = */ &llama_constraint_mirostat_v2_i,
+        /* .ctx   = */ new llama_constraint_context_mirostat_v2 {
+            /*.tau =*/ tau,
+            /*.eta =*/ eta,
+            /*.mu  =*/ 0.0f,
+            /*.cur =*/ {},
+        },
+    };
+
+    return result;
+}
+
 // grammar
 
 struct llama_constraint_context_grammar {
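The v1 constraint above folds the old two-step mirostat sampler into apply/accept: apply estimates the Zipf exponent s_hat from the top m candidates and truncates to a dynamically chosen k, while accept updates mu from the surprise of the token that was actually picked. A compact, standalone restatement of that arithmetic (illustrative only, not part of the commit):

    #include <cmath>
    #include <cstddef>
    #include <cstdint>

    // Illustrative restatement of the mirostat v1 arithmetic used above.
    // p[] are the softmax-normalized candidate probabilities, sorted descending.
    static float mirostat_dynamic_k(const float * p, size_t n, int32_t m, int32_t n_vocab, float mu) {
        // Estimate the Zipf exponent s_hat from the top m candidates
        float sum_ti_bi = 0.0f;
        float sum_ti_sq = 0.0f;
        for (size_t i = 0; i + 1 < n && i + 1 < (size_t) m; ++i) {
            const float t_i = logf(float(i + 2) / float(i + 1));
            const float b_i = logf(p[i] / p[i + 1]);
            sum_ti_bi += t_i * b_i;
            sum_ti_sq += t_i * t_i;
        }
        const float s_hat = sum_ti_bi / sum_ti_sq;

        // k such that the expected surprise of a top-k draw matches mu
        const float epsilon_hat = s_hat - 1.0f;
        return powf((epsilon_hat * powf(2.0f, mu)) / (1.0f - powf((float) n_vocab, -epsilon_hat)), 1.0f / s_hat);
    }

    // On accept: move mu toward the target surprise tau with learning rate eta
    static float mirostat_update_mu(float mu, float p_sampled, float tau, float eta) {
        const float observed_surprise = -log2f(p_sampled);
        return mu - eta * (observed_surprise - tau);
    }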
@@ -796,13 +966,13 @@ struct llama_constraint * llama_constraint_init_grammar_impl(const struct llama_
 struct llama_constraint_context_penalties {
     const struct llama_vocab * vocab;
 
-    int32_t penalty_last_n;
-    float   penalty_repeat;
-    float   penalty_freq;
-    float   penalty_present;
+    const int32_t penalty_last_n;
+    const float   penalty_repeat;
+    const float   penalty_freq;
+    const float   penalty_present;
 
-    bool penalize_nl;
-    bool ignore_eos;
+    const bool penalize_nl;
+    const bool ignore_eos;
 
     ring_buffer<llama_token> prev;
 };
@@ -980,7 +1150,6 @@ struct llama_sampler * llama_sampler_init_impl(const struct llama_vocab & vocab,
 
         /* .rng = */ std::mt19937(params.seed),
 
-        /* .mirostat_mu = */ 0.0f,
         /* .prev        = */ { (size_t) params.n_prev },
         /* .constraints = */ {},
         /* .cur         = */ {},
@@ -1011,7 +1180,6 @@ struct llama_sampler * llama_sampler_cp_impl(const struct llama_sampler & smpl)
 
         /* .rng = */ smpl.rng,
 
-        /* .mirostat_mu = */ smpl.mirostat_mu,
         /* .prev        = */ smpl.prev,
         /* .constraints = */ {},
         /* .cur         = */ {},
@@ -1077,74 +1245,6 @@ int llama_sampler_n_prev_impl(const struct llama_sampler & smpl) {
     return smpl.prev.size();
 }
 
-llama_token llama_sampler_sample_mirostat_impl(struct llama_token_data_array * cur_p, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu) {
-    llama_constraint_softmax_impl(cur_p);
-
-    // Estimate s_hat using the most probable m tokens
-    float s_hat = 0.0;
-    float sum_ti_bi = 0.0;
-    float sum_ti_sq = 0.0;
-    for (size_t i = 0; i < size_t(m - 1) && i < cur_p->size - 1; ++i) {
-        float t_i = logf(float(i + 2) / float(i + 1));
-        float b_i = logf(cur_p->data[i].p / cur_p->data[i + 1].p);
-        sum_ti_bi += t_i * b_i;
-        sum_ti_sq += t_i * t_i;
-    }
-    s_hat = sum_ti_bi / sum_ti_sq;
-
-    // Compute k from the estimated s_hat and target surprise value
-    float epsilon_hat = s_hat - 1;
-    float k = powf((epsilon_hat * powf(2, mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat);
-
-    // Sample the next word X using top-k sampling
-    llama_constraint_top_k_impl(cur_p, int(k), 1);
-    llama_token X = llama_sampler_sample_dist_impl(cur_p, rng);
-
-    // Compute error as the difference between observed surprise and target surprise value
-    size_t X_idx = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) {
-        return candidate.id == X;
-    }));
-    float observed_surprise = -log2f(cur_p->data[X_idx].p);
-    float e = observed_surprise - tau;
-
-    // Update mu using the learning rate and error
-    mu = mu - eta * e;
-
-    return X;
-}
-
-llama_token llama_sampler_sample_mirostat_v2_impl(struct llama_token_data_array * cur_p, std::mt19937 & rng, float tau, float eta, float & mu) {
-    llama_constraint_softmax_impl(cur_p);
-
-    // Truncate the words with surprise values greater than mu
-    cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) {
-        return -log2f(candidate.p) > mu;
-    }));
-
-    if (cur_p->size == 0) {
-        cur_p->size = 1;
-    }
-
-    // Normalize the probabilities of the remaining words
-    llama_constraint_softmax_impl(cur_p);
-
-    // Sample the next word X from the remaining words
-    llama_token X = llama_sampler_sample_dist_impl(cur_p, rng);
-
-    // Compute error as the difference between observed surprise and target surprise value
-    size_t X_idx = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) {
-        return candidate.id == X;
-    }));
-
-    float observed_surprise = -log2f(cur_p->data[X_idx].p);
-    float e = observed_surprise - tau;
-
-    // Update mu using the learning rate and error
-    mu = mu - eta * e;
-
-    return X;
-}
-
 llama_token llama_sampler_sample_greedy_impl(llama_token_data_array * cur_p, bool probs) {
     if (probs) {
         // if probs are needed, we apply softmax to get the probabilities
@@ -20,16 +20,38 @@ void llama_constraint_penalties_impl(
 
 // constraints
 
 struct llama_constraint * llama_constraint_init_softmax_impl  ();
 struct llama_constraint * llama_constraint_init_top_k_impl    (int32_t k, size_t min_keep);
 struct llama_constraint * llama_constraint_init_top_p_impl    (float p, size_t min_keep);
 struct llama_constraint * llama_constraint_init_min_p_impl    (float p, size_t min_keep);
-struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t min_keep);
+struct llama_constraint * llama_constraint_init_tail_free_impl (float z, size_t min_keep);
 struct llama_constraint * llama_constraint_init_typical_impl  (float p, size_t min_keep);
 struct llama_constraint * llama_constraint_init_temp_impl     (float t);
 struct llama_constraint * llama_constraint_init_temp_ext_impl (float t, float delta, float exponent);
 
-struct llama_constraint * llama_constraint_init_grammar_impl  (
+/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
+/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
+/// @param tau  The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
+/// @param eta  The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
+/// @param m    The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
+/// @param mu   Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
+struct llama_constraint * llama_constraint_init_mirostat_impl(
+        const struct llama_vocab & vocab,
+                           float   tau,
+                           float   eta,
+                         int32_t   m);
+
+/// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
+/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
+/// @param tau  The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
+/// @param eta  The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
+/// @param mu   Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
+struct llama_constraint * llama_constraint_init_mirostat_v2_impl(
+                           float   tau,
+                           float   eta);
+
+struct llama_constraint * llama_constraint_init_grammar_impl(
         const struct llama_vocab & vocab,
               const char * grammar_str,
               const char * grammar_root);
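The doc comments carried over above describe tau, eta, m and mu in the paper's terms. For the 2.0 variant the apply step reduces to a surprise cutoff: drop every candidate whose surprise -log2(p) exceeds mu, keep at least one, and renormalize. A standalone illustration of that rule (not part of the commit):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Illustrative: the mirostat v2 truncation rule on a plain probability vector.
    // Keeps the candidates whose surprise -log2(p) is at most mu (at least one survives),
    // then renormalizes, matching the constraint's apply step in this commit.
    static void mirostat_v2_truncate(std::vector<float> & p, float mu) {
        size_t keep = p.size();
        for (size_t i = 0; i < p.size(); ++i) {
            if (-log2f(p[i]) > mu) { keep = i; break; } // p is sorted descending, so this cuts the tail
        }
        keep = std::max<size_t>(keep, 1);
        p.resize(keep);

        float sum = 0.0f;
        for (float v : p) { sum += v; }
        for (float & v : p) { v /= sum; }
    }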
@@ -67,8 +89,6 @@ struct llama_sampler {
 
     std::mt19937 rng;
 
-    float mirostat_mu;
-
     ring_buffer<llama_token> prev;
 
     std::vector<llama_constraint *> constraints;
@@ -97,20 +117,5 @@ void llama_sampler_apply_impl (struct llama_sampler & smpl, struct llama_token_d
 llama_token llama_sampler_prev_impl  (const struct llama_sampler & smpl, int ith);
 int         llama_sampler_n_prev_impl(const struct llama_sampler & smpl);
 
-/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
-/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
-/// @param tau  The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
-/// @param eta  The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
-/// @param m    The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
-/// @param mu   Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
-llama_token llama_sampler_sample_mirostat_impl   (struct llama_token_data_array * cur_p, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu);
-
-/// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
-/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
-/// @param tau  The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
-/// @param eta  The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
-/// @param mu   Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
-llama_token llama_sampler_sample_mirostat_v2_impl(struct llama_token_data_array * cur_p, std::mt19937 & rng, float tau, float eta, float & mu);
-
 llama_token llama_sampler_sample_greedy_impl(struct llama_token_data_array * cur_p, bool probs);
 llama_token llama_sampler_sample_dist_impl  (struct llama_token_data_array * cur_p, std::mt19937 & rng);
@@ -17939,9 +17939,7 @@ struct llama_sampler_params llama_sampler_default_params() {
     struct llama_sampler_params result = {
        /*.seed         =*/ LLAMA_DEFAULT_SEED,
        /*.n_prev       =*/ 256,
-       /*.mirostat     =*/ 0,
-       /*.mirostat_tau =*/ 5.00f,
-       /*.mirostat_eta =*/ 0.10f,
+       /*.type         =*/ LLAMA_SAMPLER_TYPE_GREEDY,
     };
 
     return result;
@@ -20641,6 +20639,14 @@ struct llama_constraint * llama_constraint_init_temp_ext(float temp, float delta
     return llama_constraint_init_temp_ext_impl(temp, delta, exponent);
 }
 
+struct llama_constraint * llama_constraint_init_mirostat(const struct llama_model * model, float tau, float eta) {
+    return llama_constraint_init_mirostat_impl(model->vocab, tau, eta, 100);
+}
+
+struct llama_constraint * llama_constraint_init_mirostat_v2(float tau, float eta) {
+    return llama_constraint_init_mirostat_v2_impl(tau, eta);
+}
+
 struct llama_constraint * llama_constraint_init_grammar(const struct llama_model * model, const char * grammar_str, const char * grammar_root) {
     return llama_constraint_init_grammar_impl(model->vocab, grammar_str, grammar_root);
 }
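Note that the public v1 wrapper above fixes m = 100 (the value used in the paper) and resolves the vocab from the model. Roughly, for illustration only:

    // Illustrative: the wrapper above expands to the following _impl call.
    struct llama_constraint * cnstr = llama_constraint_init_mirostat(model, 5.0f, 0.1f);
    // ... which is equivalent to:
    //     llama_constraint_init_mirostat_impl(model->vocab, /*tau =*/ 5.0f, /*eta =*/ 0.1f, /*m =*/ 100);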
@@ -20741,40 +20747,6 @@ void llama_sampler_apply(struct llama_sampler * smpl, llama_token_data_array * c
     llama_sampler_apply_impl(*smpl, cur_p);
 }
 
-llama_token llama_sampler_sample_mirostat(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
-    time_meas tm(smpl->t_sample_us);
-
-    if (cur_p == nullptr) {
-        cur_p = &smpl->cur_p;
-    }
-
-    const auto type = smpl->params.mirostat;
-
-    llama_token res;
-
-    if (type == 1) {
-        res = llama_sampler_sample_mirostat_impl(cur_p,
-                smpl->rng,
-                smpl->params.mirostat_tau,
-                smpl->params.mirostat_eta,
-                100,
-                smpl->vocab->n_vocab,
-                smpl->mirostat_mu);
-    } else if (type == 2) {
-        res = llama_sampler_sample_mirostat_v2_impl(cur_p,
-                smpl->rng,
-                smpl->params.mirostat_tau,
-                smpl->params.mirostat_eta,
-                smpl->mirostat_mu);
-    } else {
-        GGML_ABORT("invalid mirostat type: %d", type);
-    }
-
-    smpl->n_sample++;
-
-    return res;
-}
-
 llama_token llama_sampler_sample_greedy(struct llama_sampler * smpl, llama_token_data_array * cur_p, bool probs) {
     time_meas tm(smpl->t_sample_us);
 