cont : add n_prev to llama_sampler_params

This commit is contained in:
Georgi Gerganov 2024-09-04 11:54:49 +03:00
parent 91cbb40b29
commit 437376e708
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
6 changed files with 35 additions and 12 deletions

View file

@ -38,6 +38,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
llama_sampler_params lparams = llama_sampler_default_params();
lparams.seed = params.seed;
lparams.n_prev = params.n_prev;
lparams.mirostat = params.mirostat;
lparams.mirostat_tau = params.mirostat_tau;
lparams.mirostat_eta = params.mirostat_eta;
@ -177,8 +178,10 @@ llama_token gpt_sampler_sample(
llama_sampler_set_logits(smpl, llama_get_logits_ith(ctx, idx));
auto * cur_p = llama_sampler_get_candidates(smpl);
// first, sample the token without any grammar constraints
const llama_token id = gpt_sampler_sample(smpl, nullptr, params.temp, params.mirostat, params.n_probs);
const llama_token id = gpt_sampler_sample(smpl, cur_p, params.temp, params.mirostat, params.n_probs);
// create an array with a single token data element for the sampled id
llama_token_data single_token_data = { id, 1.0f, 0.0f };
@ -194,7 +197,6 @@ llama_token gpt_sampler_sample(
// if the token is not valid, sample again, after applying the grammar constraints
llama_sampler_set_logits(smpl, llama_get_logits_ith(ctx, idx));
auto * cur_p = llama_sampler_get_candidates(smpl);
llama_constraint_apply(grmr, cur_p);

View file

@ -375,6 +375,8 @@ extern "C" {
typedef struct llama_sampler_params {
uint32_t seed; // the seed used to initialize the rng of the sampler
int32_t n_prev; // size of ring buffer to keep previous accepted tokens (needed for llama_sampler_prev_ API)
int32_t mirostat; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float mirostat_tau; // target entropy
float mirostat_eta; // learning rate

View file

@ -56,7 +56,6 @@ const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal
// the ring buffer works similarly to std::deque, but with a fixed capacity
template<typename T>
struct ring_buffer {
ring_buffer() {}
ring_buffer(size_t cap) : capacity(cap), data(cap) {}
T & front() {

View file

@ -983,7 +983,7 @@ struct llama_constraint * llama_constraint_init_penalties_impl(const struct llam
/*.penalty_present =*/ penalty_present,
/*.penalize_nl =*/ penalize_nl,
/*.ignore_eos =*/ ignore_eos,
/*.prev =*/ {},
/*.prev =*/ ring_buffer<llama_token>(penalty_last_n),
},
};
@ -1069,12 +1069,20 @@ void llama_constraint_reset_impl(struct llama_constraint & cnstr) {
// samplers
struct llama_sampler * llama_sampler_init_impl(const struct llama_vocab & vocab, struct llama_sampler_params params) {
auto * result = new llama_sampler;
auto * result = new llama_sampler {
/* .params = */ params,
/* .vocab = */ &vocab,
result->params = params;
result->vocab = &vocab;
/* .rng = */ std::mt19937(params.seed),
result->rng.seed(params.seed);
/* .mirostat_mu = */ 0.0f,
/* .prev = */ { (size_t) params.n_prev },
/* .constraints = */ {},
/* .cur = */ {},
/* .cur_p = */ {},
/* .t_sample_us = */ 0,
/* .n_sample = */ 0,
};
return result;
}
@ -1092,9 +1100,20 @@ void llama_sampler_free_impl(struct llama_sampler * smpl) {
}
struct llama_sampler * llama_sampler_cp_impl(const struct llama_sampler & smpl) {
auto * result = new llama_sampler;
auto * result = new llama_sampler {
/* .params = */ smpl.params,
/* .vocab = */ smpl.vocab,
*result = smpl;
/* .rng = */ smpl.rng,
/* .mirostat_mu = */ smpl.mirostat_mu,
/* .prev = */ smpl.prev,
/* .constraints = */ {},
/* .cur = */ {},
/* .cur_p = */ {},
/* .t_sample_us = */ 0,
/* .n_sample = */ 0,
};
// copy the constraints objects
result->constraints.clear();

View file

@ -111,9 +111,9 @@ struct llama_sampler {
// timing
mutable int64_t t_sample_us = 0;
mutable int64_t t_sample_us;
mutable int32_t n_sample = 0;
mutable int32_t n_sample;
};
struct llama_sampler * llama_sampler_init_impl (const struct llama_vocab & vocab, struct llama_sampler_params params);

View file

@ -17938,6 +17938,7 @@ struct llama_context_params llama_context_default_params() {
struct llama_sampler_params llama_sampler_default_params() {
struct llama_sampler_params result = {
/*.seed =*/ LLAMA_DEFAULT_SEED,
/*.n_prev =*/ 256,
/*.mirostat =*/ 0,
/*.mirostat_tau =*/ 5.00f,
/*.mirostat_eta =*/ 0.10f,