sampling : improve mirostat implementation
ggml-ci
This commit is contained in:
parent
bd88352834
commit
82a89df960
7 changed files with 95 additions and 88 deletions
|
@ -121,7 +121,7 @@ struct gpt_sampler {
|
||||||
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
|
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
|
||||||
}
|
}
|
||||||
|
|
||||||
cur_p = { cur.data(), cur.size(), LLAMA_TOKEN_NULL, false };
|
cur_p = { cur.data(), cur.size(), -1, false };
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -202,17 +202,17 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
|
||||||
GGML_ASSERT(false && "unknown sampler type");
|
GGML_ASSERT(false && "unknown sampler type");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
|
||||||
|
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
|
||||||
} else if (params.mirostat == 1) {
|
} else if (params.mirostat == 1) {
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
|
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(model, params.mirostat_tau, params.mirostat_eta));
|
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(model, params.seed, params.mirostat_tau, params.mirostat_eta));
|
||||||
} else if (params.mirostat == 2) {
|
} else if (params.mirostat == 2) {
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
|
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.mirostat_tau, params.mirostat_eta));
|
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
|
||||||
} else {
|
} else {
|
||||||
GGML_ASSERT(false && "unknown mirostat version");
|
GGML_ASSERT(false && "unknown mirostat version");
|
||||||
}
|
}
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
|
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
|
|
||||||
} else {
|
} else {
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
|
llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_greedy());
|
llama_sampler_chain_add(result->chain, llama_sampler_init_greedy());
|
||||||
|
@ -246,8 +246,8 @@ struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool apply_grammar) {
|
void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar) {
|
||||||
if (apply_grammar) {
|
if (accept_grammar) {
|
||||||
llama_sampler_accept(gsmpl->grmr, token);
|
llama_sampler_accept(gsmpl->grmr, token);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -293,9 +293,9 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
|
||||||
|
|
||||||
llama_sampler_apply(chain, &cur_p);
|
llama_sampler_apply(chain, &cur_p);
|
||||||
|
|
||||||
const llama_token id = cur_p.data[cur_p.selected].id;
|
GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
|
||||||
|
|
||||||
GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - check your sampling configuration");
|
const llama_token id = cur_p.data[cur_p.selected].id;
|
||||||
|
|
||||||
if (grammar_first) {
|
if (grammar_first) {
|
||||||
return id;
|
return id;
|
||||||
|
@ -304,7 +304,7 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
|
||||||
// check if it the sampled token fits the grammar
|
// check if it the sampled token fits the grammar
|
||||||
{
|
{
|
||||||
llama_token_data single_token_data = { id, 1.0f, 0.0f };
|
llama_token_data single_token_data = { id, 1.0f, 0.0f };
|
||||||
llama_token_data_array single_token_data_array = { &single_token_data, 1, LLAMA_TOKEN_NULL, false };
|
llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };
|
||||||
|
|
||||||
llama_sampler_apply(grmr, &single_token_data_array);
|
llama_sampler_apply(grmr, &single_token_data_array);
|
||||||
|
|
||||||
|
@ -324,7 +324,7 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
|
||||||
|
|
||||||
llama_sampler_apply(chain, &cur_p);
|
llama_sampler_apply(chain, &cur_p);
|
||||||
|
|
||||||
GGML_ASSERT(cur_p.data[cur_p.selected].id != LLAMA_TOKEN_NULL && "null token in the sampling history - check your sampling configuration");
|
GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
|
||||||
|
|
||||||
return cur_p.data[cur_p.selected].id;
|
return cur_p.data[cur_p.selected].id;
|
||||||
}
|
}
|
||||||
|
|
|
@ -70,7 +70,7 @@ void gpt_sampler_free(struct gpt_sampler * gsmpl);
|
||||||
|
|
||||||
struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl);
|
struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl);
|
||||||
|
|
||||||
void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool apply_grammar);
|
void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar);
|
||||||
void gpt_sampler_reset (struct gpt_sampler * gsmpl);
|
void gpt_sampler_reset (struct gpt_sampler * gsmpl);
|
||||||
|
|
||||||
llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl);
|
llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl);
|
||||||
|
|
|
@ -1068,6 +1068,7 @@ extern "C" {
|
||||||
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
|
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
|
||||||
LLAMA_API struct llama_sampler * llama_sampler_init_mirostat(
|
LLAMA_API struct llama_sampler * llama_sampler_init_mirostat(
|
||||||
const struct llama_model * model,
|
const struct llama_model * model,
|
||||||
|
uint32_t seed,
|
||||||
float tau,
|
float tau,
|
||||||
float eta);
|
float eta);
|
||||||
|
|
||||||
|
@ -1077,6 +1078,7 @@ extern "C" {
|
||||||
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
|
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
|
||||||
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
|
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
|
||||||
LLAMA_API struct llama_sampler * llama_sampler_init_mirostat_v2(
|
LLAMA_API struct llama_sampler * llama_sampler_init_mirostat_v2(
|
||||||
|
uint32_t seed,
|
||||||
float tau,
|
float tau,
|
||||||
float eta);
|
float eta);
|
||||||
|
|
||||||
|
|
|
@ -11,6 +11,17 @@
|
||||||
#include <random>
|
#include <random>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
|
static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng, std::vector<float> & probs) {
|
||||||
|
probs.resize(cur_p->size);
|
||||||
|
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||||
|
probs[i] = cur_p->data[i].p;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::discrete_distribution<size_t> dist(probs.begin(), probs.end());
|
||||||
|
|
||||||
|
return dist(rng);
|
||||||
|
}
|
||||||
|
|
||||||
static void llama_log_softmax(float * array, size_t size) {
|
static void llama_log_softmax(float * array, size_t size) {
|
||||||
float max_l = *std::max_element(array, array + size);
|
float max_l = *std::max_element(array, array + size);
|
||||||
float sum = 0.f;
|
float sum = 0.f;
|
||||||
|
@ -456,6 +467,8 @@ struct llama_sampler_context_dist {
|
||||||
const uint32_t seed;
|
const uint32_t seed;
|
||||||
|
|
||||||
std::mt19937 rng;
|
std::mt19937 rng;
|
||||||
|
|
||||||
|
std::vector<float> probs; // work array
|
||||||
};
|
};
|
||||||
|
|
||||||
static struct llama_sampler_i llama_sampler_dist_i = {
|
static struct llama_sampler_i llama_sampler_dist_i = {
|
||||||
|
@ -463,15 +476,7 @@ static struct llama_sampler_i llama_sampler_dist_i = {
|
||||||
/* .accept = */ nullptr,
|
/* .accept = */ nullptr,
|
||||||
/* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
/* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||||
auto * ctx = (llama_sampler_context_dist *) smpl->ctx;
|
auto * ctx = (llama_sampler_context_dist *) smpl->ctx;
|
||||||
std::vector<float> probs;
|
cur_p->selected = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
|
||||||
probs.reserve(cur_p->size);
|
|
||||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
|
||||||
probs.push_back(cur_p->data[i].p);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::discrete_distribution<size_t> dist(probs.begin(), probs.end());
|
|
||||||
|
|
||||||
cur_p->selected = dist(ctx->rng);
|
|
||||||
},
|
},
|
||||||
/* .reset = */ nullptr,
|
/* .reset = */ nullptr,
|
||||||
/* .clone = */ [](const struct llama_sampler * smpl) {
|
/* .clone = */ [](const struct llama_sampler * smpl) {
|
||||||
|
@ -489,6 +494,7 @@ struct llama_sampler * llama_sampler_init_dist_impl(uint32_t seed) {
|
||||||
/* .ctx = */ new llama_sampler_context_dist {
|
/* .ctx = */ new llama_sampler_context_dist {
|
||||||
/* .seed = */ seed,
|
/* .seed = */ seed,
|
||||||
/* .rng = */ std::mt19937(seed),
|
/* .rng = */ std::mt19937(seed),
|
||||||
|
/* .probs = */ {},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -761,6 +767,8 @@ struct llama_sampler * llama_sampler_init_temp_ext_impl(float temp, float delta,
|
||||||
struct llama_sampler_context_mirostat {
|
struct llama_sampler_context_mirostat {
|
||||||
const struct llama_vocab * vocab;
|
const struct llama_vocab * vocab;
|
||||||
|
|
||||||
|
const uint32_t seed;
|
||||||
|
|
||||||
const float tau;
|
const float tau;
|
||||||
const float eta;
|
const float eta;
|
||||||
|
|
||||||
|
@ -768,28 +776,14 @@ struct llama_sampler_context_mirostat {
|
||||||
|
|
||||||
float mu;
|
float mu;
|
||||||
|
|
||||||
std::vector<llama_token_data> cur;
|
std::mt19937 rng;
|
||||||
|
|
||||||
|
std::vector<float> probs;
|
||||||
};
|
};
|
||||||
|
|
||||||
static struct llama_sampler_i llama_sampler_mirostat_i = {
|
static struct llama_sampler_i llama_sampler_mirostat_i = {
|
||||||
/* .name = */ [](const struct llama_sampler * /*smpl*/) { return "mirostat"; },
|
/* .name = */ [](const struct llama_sampler * /*smpl*/) { return "mirostat"; },
|
||||||
/* .accept = */ [](struct llama_sampler * smpl, llama_token token) {
|
/* .accept = */ nullptr,
|
||||||
auto * ctx = (llama_sampler_context_mirostat *) smpl->ctx;
|
|
||||||
|
|
||||||
int32_t idx = -1;
|
|
||||||
for (size_t i = 0; i < ctx->cur.size(); ++i) {
|
|
||||||
if (ctx->cur[i].id == token) {
|
|
||||||
idx = i;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
float observed_surprise = -log2f(ctx->cur[idx].p);
|
|
||||||
float e = observed_surprise - ctx->tau;
|
|
||||||
|
|
||||||
// Update mu using the learning rate and error
|
|
||||||
ctx->mu = ctx->mu - ctx->eta * e;
|
|
||||||
},
|
|
||||||
/* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
/* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||||
auto * ctx = (llama_sampler_context_mirostat *) smpl->ctx;
|
auto * ctx = (llama_sampler_context_mirostat *) smpl->ctx;
|
||||||
|
|
||||||
|
@ -812,36 +806,44 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {
|
||||||
float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->vocab->n_vocab, -epsilon_hat)), 1 / s_hat);
|
float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->vocab->n_vocab, -epsilon_hat)), 1 / s_hat);
|
||||||
|
|
||||||
llama_sampler_top_k_impl(cur_p, std::max(int(k), 1));
|
llama_sampler_top_k_impl(cur_p, std::max(int(k), 1));
|
||||||
|
llama_sampler_softmax_impl(cur_p);
|
||||||
|
|
||||||
// remember the order to be able to compute the distance later when accepting the token
|
const int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
|
||||||
ctx->cur.resize(cur_p->size);
|
|
||||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
cur_p->selected = idx;
|
||||||
ctx->cur[i] = cur_p->data[i];
|
|
||||||
}
|
float observed_surprise = -log2f(cur_p->data[idx].p);
|
||||||
|
float e = observed_surprise - ctx->tau;
|
||||||
|
|
||||||
|
// Update mu using the learning rate and error
|
||||||
|
ctx->mu = ctx->mu - ctx->eta * e;
|
||||||
},
|
},
|
||||||
/* .reset = */ [](struct llama_sampler * smpl) {
|
/* .reset = */ [](struct llama_sampler * smpl) {
|
||||||
auto * ctx = (llama_sampler_context_mirostat *) smpl->ctx;
|
auto * ctx = (llama_sampler_context_mirostat *) smpl->ctx;
|
||||||
ctx->mu = 2.0f*ctx->tau;
|
ctx->mu = 2.0f*ctx->tau;
|
||||||
|
ctx->rng = std::mt19937(ctx->seed);
|
||||||
},
|
},
|
||||||
/* .clone = */ [](const struct llama_sampler * smpl) {
|
/* .clone = */ [](const struct llama_sampler * smpl) {
|
||||||
const auto * ctx = (const llama_sampler_context_mirostat *) smpl->ctx;
|
const auto * ctx = (const llama_sampler_context_mirostat *) smpl->ctx;
|
||||||
return llama_sampler_init_mirostat_impl(*ctx->vocab, ctx->tau, ctx->eta, ctx->m);
|
return llama_sampler_init_mirostat_impl(*ctx->vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m);
|
||||||
},
|
},
|
||||||
/* .free = */ [](struct llama_sampler * smpl) {
|
/* .free = */ [](struct llama_sampler * smpl) {
|
||||||
delete (llama_sampler_context_mirostat *) smpl->ctx;
|
delete (llama_sampler_context_mirostat *) smpl->ctx;
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_init_mirostat_impl(const struct llama_vocab & vocab, float tau, float eta, int32_t m) {
|
struct llama_sampler * llama_sampler_init_mirostat_impl(const struct llama_vocab & vocab, uint32_t seed, float tau, float eta, int32_t m) {
|
||||||
return new llama_sampler {
|
return new llama_sampler {
|
||||||
/* .iface = */ &llama_sampler_mirostat_i,
|
/* .iface = */ &llama_sampler_mirostat_i,
|
||||||
/* .ctx = */ new llama_sampler_context_mirostat {
|
/* .ctx = */ new llama_sampler_context_mirostat {
|
||||||
/* .vocab = */ &vocab,
|
/* .vocab = */ &vocab,
|
||||||
|
/* .seed = */ seed,
|
||||||
/* .tau = */ tau,
|
/* .tau = */ tau,
|
||||||
/* .eta = */ eta,
|
/* .eta = */ eta,
|
||||||
/* .m = */ m,
|
/* .m = */ m,
|
||||||
/* .mu = */ 2.0f*tau,
|
/* .mu = */ 2.0f*tau,
|
||||||
/* .cur = */ {},
|
/* .rng = */ std::mt19937(seed),
|
||||||
|
/* .probs = */ {},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -849,33 +851,21 @@ struct llama_sampler * llama_sampler_init_mirostat_impl(const struct llama_vocab
|
||||||
// mirostat v2
|
// mirostat v2
|
||||||
|
|
||||||
struct llama_sampler_context_mirostat_v2 {
|
struct llama_sampler_context_mirostat_v2 {
|
||||||
|
const uint32_t seed;
|
||||||
|
|
||||||
const float tau;
|
const float tau;
|
||||||
const float eta;
|
const float eta;
|
||||||
|
|
||||||
float mu;
|
float mu;
|
||||||
|
|
||||||
std::vector<llama_token_data> cur;
|
std::mt19937 rng;
|
||||||
|
|
||||||
|
std::vector<float> probs;
|
||||||
};
|
};
|
||||||
|
|
||||||
static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
|
static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
|
||||||
/* .name = */ [](const struct llama_sampler * /*smpl*/) { return "mirostat-v2"; },
|
/* .name = */ [](const struct llama_sampler * /*smpl*/) { return "mirostat-v2"; },
|
||||||
/* .accept = */ [](struct llama_sampler * smpl, llama_token token) {
|
/* .accept = */ nullptr,
|
||||||
auto * ctx = (llama_sampler_context_mirostat_v2 *) smpl->ctx;
|
|
||||||
|
|
||||||
int32_t idx = -1;
|
|
||||||
for (size_t i = 0; i < ctx->cur.size(); ++i) {
|
|
||||||
if (ctx->cur[i].id == token) {
|
|
||||||
idx = i;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
float observed_surprise = -log2f(ctx->cur[idx].p);
|
|
||||||
float e = observed_surprise - ctx->tau;
|
|
||||||
|
|
||||||
// Update mu using the learning rate and error
|
|
||||||
ctx->mu = ctx->mu - ctx->eta * e;
|
|
||||||
},
|
|
||||||
/* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
/* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||||
auto * ctx = (llama_sampler_context_mirostat_v2 *) smpl->ctx;
|
auto * ctx = (llama_sampler_context_mirostat_v2 *) smpl->ctx;
|
||||||
|
|
||||||
|
@ -893,33 +883,40 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
|
||||||
// Normalize the probabilities of the remaining words
|
// Normalize the probabilities of the remaining words
|
||||||
llama_sampler_softmax_impl(cur_p);
|
llama_sampler_softmax_impl(cur_p);
|
||||||
|
|
||||||
// remember the order to be able to compute the distance later when accepting the token
|
const int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
|
||||||
ctx->cur.resize(cur_p->size);
|
|
||||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
cur_p->selected = idx;
|
||||||
ctx->cur[i] = cur_p->data[i];
|
|
||||||
}
|
float observed_surprise = -log2f(cur_p->data[idx].p);
|
||||||
|
float e = observed_surprise - ctx->tau;
|
||||||
|
|
||||||
|
// Update mu using the learning rate and error
|
||||||
|
ctx->mu = ctx->mu - ctx->eta * e;
|
||||||
},
|
},
|
||||||
/* .reset = */ [](struct llama_sampler * smpl) {
|
/* .reset = */ [](struct llama_sampler * smpl) {
|
||||||
auto * ctx = (llama_sampler_context_mirostat_v2 *) smpl->ctx;
|
auto * ctx = (llama_sampler_context_mirostat_v2 *) smpl->ctx;
|
||||||
ctx->mu = 2.0f*ctx->tau;
|
ctx->mu = 2.0f*ctx->tau;
|
||||||
|
ctx->rng = std::mt19937(ctx->seed);
|
||||||
},
|
},
|
||||||
/* .clone = */ [](const struct llama_sampler * smpl) {
|
/* .clone = */ [](const struct llama_sampler * smpl) {
|
||||||
const auto * ctx = (const llama_sampler_context_mirostat_v2 *) smpl->ctx;
|
const auto * ctx = (const llama_sampler_context_mirostat_v2 *) smpl->ctx;
|
||||||
return llama_sampler_init_mirostat_v2_impl(ctx->tau, ctx->eta);
|
return llama_sampler_init_mirostat_v2_impl(ctx->seed, ctx->tau, ctx->eta);
|
||||||
},
|
},
|
||||||
/* .free = */ [](struct llama_sampler * smpl) {
|
/* .free = */ [](struct llama_sampler * smpl) {
|
||||||
delete (llama_sampler_context_mirostat_v2 *) smpl->ctx;
|
delete (llama_sampler_context_mirostat_v2 *) smpl->ctx;
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_init_mirostat_v2_impl(float tau, float eta) {
|
struct llama_sampler * llama_sampler_init_mirostat_v2_impl(uint32_t seed, float tau, float eta) {
|
||||||
return new llama_sampler {
|
return new llama_sampler {
|
||||||
/* .iface = */ &llama_sampler_mirostat_v2_i,
|
/* .iface = */ &llama_sampler_mirostat_v2_i,
|
||||||
/* .ctx = */ new llama_sampler_context_mirostat_v2 {
|
/* .ctx = */ new llama_sampler_context_mirostat_v2 {
|
||||||
/* .tau = */ tau,
|
/* .seed = */ seed,
|
||||||
/* .eta = */ eta,
|
/* .tau = */ tau,
|
||||||
/* .mu = */ 2.0f*tau,
|
/* .eta = */ eta,
|
||||||
/* .cur = */ {},
|
/* .mu = */ 2.0f*tau,
|
||||||
|
/* .rng = */ std::mt19937(seed),
|
||||||
|
/* .probs = */ {},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -1154,9 +1151,15 @@ struct llama_sampler * llama_sampler_init_logit_bias_impl(
|
||||||
|
|
||||||
static struct llama_sampler_i llama_sampler_chain_i = {
|
static struct llama_sampler_i llama_sampler_chain_i = {
|
||||||
/* .name = */ [](const struct llama_sampler * /*smpl*/) { return "chain"; },
|
/* .name = */ [](const struct llama_sampler * /*smpl*/) { return "chain"; },
|
||||||
/* .accept = */ [](struct llama_sampler * smpl, llama_token /*token*/) {
|
/* .accept = */ [](struct llama_sampler * smpl, llama_token token) {
|
||||||
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
||||||
|
|
||||||
|
time_meas tm(chain->t_sample_us, chain->params.no_timing);
|
||||||
|
|
||||||
|
for (auto * smpl : chain->samplers) {
|
||||||
|
llama_sampler_accept_impl(*smpl, token);
|
||||||
|
}
|
||||||
|
|
||||||
chain->n_sample++;
|
chain->n_sample++;
|
||||||
},
|
},
|
||||||
/* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
/* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||||
|
|
|
@ -58,11 +58,13 @@ struct llama_sampler * llama_sampler_init_temp_ext_impl (float t, float delta
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_init_mirostat_impl(
|
struct llama_sampler * llama_sampler_init_mirostat_impl(
|
||||||
const struct llama_vocab & vocab,
|
const struct llama_vocab & vocab,
|
||||||
|
uint32_t seed,
|
||||||
float tau,
|
float tau,
|
||||||
float eta,
|
float eta,
|
||||||
int32_t m);
|
int32_t m);
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_init_mirostat_v2_impl(
|
struct llama_sampler * llama_sampler_init_mirostat_v2_impl(
|
||||||
|
uint32_t seed,
|
||||||
float tau,
|
float tau,
|
||||||
float eta);
|
float eta);
|
||||||
|
|
||||||
|
|
|
@ -20676,12 +20676,12 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
|
||||||
return llama_sampler_init_temp_ext_impl(temp, delta, exponent);
|
return llama_sampler_init_temp_ext_impl(temp, delta, exponent);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_init_mirostat(const struct llama_model * model, float tau, float eta) {
|
struct llama_sampler * llama_sampler_init_mirostat(const struct llama_model * model, uint32_t seed, float tau, float eta) {
|
||||||
return llama_sampler_init_mirostat_impl(model->vocab, tau, eta, 100);
|
return llama_sampler_init_mirostat_impl(model->vocab, seed, tau, eta, 100);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_init_mirostat_v2(float tau, float eta) {
|
struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
|
||||||
return llama_sampler_init_mirostat_v2_impl(tau, eta);
|
return llama_sampler_init_mirostat_v2_impl(seed, tau, eta);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * model, const char * grammar_str, const char * grammar_root) {
|
struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * model, const char * grammar_str, const char * grammar_root) {
|
||||||
|
|
|
@ -35,7 +35,7 @@ static void test_top_k(const std::vector<float> & probs, const std::vector<float
|
||||||
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_token_data_array cur_p = { cur.data(), cur.size(), LLAMA_TOKEN_NULL, false };
|
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
|
||||||
APPLY(llama_sampler_init_softmax(), &cur_p);
|
APPLY(llama_sampler_init_softmax(), &cur_p);
|
||||||
DUMP(&cur_p);
|
DUMP(&cur_p);
|
||||||
APPLY(llama_sampler_init_top_k(k), &cur_p);
|
APPLY(llama_sampler_init_top_k(k), &cur_p);
|
||||||
|
@ -57,7 +57,7 @@ static void test_top_p(const std::vector<float> & probs, const std::vector<float
|
||||||
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_token_data_array cur_p = { cur.data(), cur.size(), LLAMA_TOKEN_NULL, false };
|
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
|
||||||
APPLY(llama_sampler_init_softmax(), &cur_p);
|
APPLY(llama_sampler_init_softmax(), &cur_p);
|
||||||
DUMP(&cur_p);
|
DUMP(&cur_p);
|
||||||
APPLY(llama_sampler_init_top_p(p, 1), &cur_p);
|
APPLY(llama_sampler_init_top_p(p, 1), &cur_p);
|
||||||
|
@ -79,7 +79,7 @@ static void test_tfs(const std::vector<float> & probs, const std::vector<float>
|
||||||
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_token_data_array cur_p = { cur.data(), cur.size(), LLAMA_TOKEN_NULL, false };
|
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
|
||||||
DUMP(&cur_p);
|
DUMP(&cur_p);
|
||||||
APPLY(llama_sampler_init_tail_free(z, 1), &cur_p);
|
APPLY(llama_sampler_init_tail_free(z, 1), &cur_p);
|
||||||
DUMP(&cur_p);
|
DUMP(&cur_p);
|
||||||
|
@ -100,7 +100,7 @@ static void test_min_p(const std::vector<float> & probs, const std::vector<float
|
||||||
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_token_data_array cur_p = { cur.data(), cur.size(), LLAMA_TOKEN_NULL, false };
|
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
|
||||||
DUMP(&cur_p);
|
DUMP(&cur_p);
|
||||||
APPLY(llama_sampler_init_min_p(p, 1), &cur_p);
|
APPLY(llama_sampler_init_min_p(p, 1), &cur_p);
|
||||||
DUMP(&cur_p);
|
DUMP(&cur_p);
|
||||||
|
@ -122,7 +122,7 @@ static void test_typical(const std::vector<float> & probs, const std::vector<flo
|
||||||
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_token_data_array cur_p = { cur.data(), cur.size(), LLAMA_TOKEN_NULL, false };
|
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
|
||||||
DUMP(&cur_p);
|
DUMP(&cur_p);
|
||||||
APPLY(llama_sampler_init_typical(p, 1), &cur_p);
|
APPLY(llama_sampler_init_typical(p, 1), &cur_p);
|
||||||
DUMP(&cur_p);
|
DUMP(&cur_p);
|
||||||
|
@ -153,7 +153,7 @@ static void test_penalties(
|
||||||
token_count[last_tokens[i]]++;
|
token_count[last_tokens[i]]++;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_token_data_array cur_p = { cur.data(), cur.size(), LLAMA_TOKEN_NULL, false };
|
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
|
||||||
APPLY(llama_sampler_init_softmax(), &cur_p);
|
APPLY(llama_sampler_init_softmax(), &cur_p);
|
||||||
DUMP(&cur_p);
|
DUMP(&cur_p);
|
||||||
llama_sampler_penalties_impl(&cur_p, token_count, repeat_penalty, alpha_frequency, alpha_presence); // TODO: avoid
|
llama_sampler_penalties_impl(&cur_p, token_count, repeat_penalty, alpha_frequency, alpha_presence); // TODO: avoid
|
||||||
|
@ -175,7 +175,7 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler
|
||||||
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_token_data_array cur_p = { cur.data(), cur.size(), LLAMA_TOKEN_NULL, false };
|
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
|
||||||
|
|
||||||
llama_token min_token_id = 0;
|
llama_token min_token_id = 0;
|
||||||
const llama_token max_token_id = n_vocab-1;
|
const llama_token max_token_id = n_vocab-1;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue