diff --git a/common/sampling.cpp b/common/sampling.cpp index a5e76dfd4..4dfbe9021 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -38,6 +38,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st llama_sampler_params lparams = llama_sampler_default_params(); lparams.seed = params.seed; + lparams.n_prev = params.n_prev; lparams.mirostat = params.mirostat; lparams.mirostat_tau = params.mirostat_tau; lparams.mirostat_eta = params.mirostat_eta; @@ -177,8 +178,10 @@ llama_token gpt_sampler_sample( llama_sampler_set_logits(smpl, llama_get_logits_ith(ctx, idx)); + auto * cur_p = llama_sampler_get_candidates(smpl); + // first, sample the token without any grammar constraints - const llama_token id = gpt_sampler_sample(smpl, nullptr, params.temp, params.mirostat, params.n_probs); + const llama_token id = gpt_sampler_sample(smpl, cur_p, params.temp, params.mirostat, params.n_probs); // create an array with a single token data element for the sampled id llama_token_data single_token_data = { id, 1.0f, 0.0f }; @@ -194,7 +197,6 @@ llama_token gpt_sampler_sample( // if the token is not valid, sample again, after applying the grammar constraints llama_sampler_set_logits(smpl, llama_get_logits_ith(ctx, idx)); - auto * cur_p = llama_sampler_get_candidates(smpl); llama_constraint_apply(grmr, cur_p); diff --git a/include/llama.h b/include/llama.h index 920952d68..49011b7fe 100644 --- a/include/llama.h +++ b/include/llama.h @@ -375,6 +375,8 @@ extern "C" { typedef struct llama_sampler_params { uint32_t seed; // the seed used to initialize the rng of the sampler + int32_t n_prev; // size of ring buffer to keep previous accepted tokens (needed for llama_sampler_prev_ API) + int32_t mirostat; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 float mirostat_tau; // target entropy float mirostat_eta; // learning rate diff --git a/src/llama-impl.h b/src/llama-impl.h index b67f511c0..6d388655d 100644 --- a/src/llama-impl.h +++ b/src/llama-impl.h @@ -56,7 +56,6 @@ const std::vector> & llama_internal // the ring buffer works similarly to std::deque, but with a fixed capacity template struct ring_buffer { - ring_buffer() {} ring_buffer(size_t cap) : capacity(cap), data(cap) {} T & front() { diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 7a1f8a805..abf9d5a8e 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -983,7 +983,7 @@ struct llama_constraint * llama_constraint_init_penalties_impl(const struct llam /*.penalty_present =*/ penalty_present, /*.penalize_nl =*/ penalize_nl, /*.ignore_eos =*/ ignore_eos, - /*.prev =*/ {}, + /*.prev =*/ ring_buffer(penalty_last_n), }, }; @@ -1069,12 +1069,20 @@ void llama_constraint_reset_impl(struct llama_constraint & cnstr) { // samplers struct llama_sampler * llama_sampler_init_impl(const struct llama_vocab & vocab, struct llama_sampler_params params) { - auto * result = new llama_sampler; + auto * result = new llama_sampler { + /* .params = */ params, + /* .vocab = */ &vocab, - result->params = params; - result->vocab = &vocab; + /* .rng = */ std::mt19937(params.seed), - result->rng.seed(params.seed); + /* .mirostat_mu = */ 0.0f, + /* .prev = */ { (size_t) params.n_prev }, + /* .constraints = */ {}, + /* .cur = */ {}, + /* .cur_p = */ {}, + /* .t_sample_us = */ 0, + /* .n_sample = */ 0, + }; return result; } @@ -1092,9 +1100,20 @@ void llama_sampler_free_impl(struct llama_sampler * smpl) { } struct llama_sampler * llama_sampler_cp_impl(const struct llama_sampler & smpl) { - auto * result = new llama_sampler; + auto * result = new llama_sampler { + /* .params = */ smpl.params, + /* .vocab = */ smpl.vocab, - *result = smpl; + /* .rng = */ smpl.rng, + + /* .mirostat_mu = */ smpl.mirostat_mu, + /* .prev = */ smpl.prev, + /* .constraints = */ {}, + /* .cur = */ {}, + /* .cur_p = */ {}, + /* .t_sample_us = */ 0, + /* .n_sample = */ 0, + }; // copy the constraints objects result->constraints.clear(); diff --git a/src/llama-sampling.h b/src/llama-sampling.h index dd9236392..501f11de8 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -111,9 +111,9 @@ struct llama_sampler { // timing - mutable int64_t t_sample_us = 0; + mutable int64_t t_sample_us; - mutable int32_t n_sample = 0; + mutable int32_t n_sample; }; struct llama_sampler * llama_sampler_init_impl (const struct llama_vocab & vocab, struct llama_sampler_params params); diff --git a/src/llama.cpp b/src/llama.cpp index a40fc4c30..903be4575 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -17938,6 +17938,7 @@ struct llama_context_params llama_context_default_params() { struct llama_sampler_params llama_sampler_default_params() { struct llama_sampler_params result = { /*.seed =*/ LLAMA_DEFAULT_SEED, + /*.n_prev =*/ 256, /*.mirostat =*/ 0, /*.mirostat_tau =*/ 5.00f, /*.mirostat_eta =*/ 0.10f,