cont : add n_prev to llama_sampler_params
This commit is contained in:
parent
91cbb40b29
commit
437376e708
6 changed files with 35 additions and 12 deletions
|
@ -38,6 +38,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
|
|||
llama_sampler_params lparams = llama_sampler_default_params();
|
||||
|
||||
lparams.seed = params.seed;
|
||||
lparams.n_prev = params.n_prev;
|
||||
lparams.mirostat = params.mirostat;
|
||||
lparams.mirostat_tau = params.mirostat_tau;
|
||||
lparams.mirostat_eta = params.mirostat_eta;
|
||||
|
@ -177,8 +178,10 @@ llama_token gpt_sampler_sample(
|
|||
|
||||
llama_sampler_set_logits(smpl, llama_get_logits_ith(ctx, idx));
|
||||
|
||||
auto * cur_p = llama_sampler_get_candidates(smpl);
|
||||
|
||||
// first, sample the token without any grammar constraints
|
||||
const llama_token id = gpt_sampler_sample(smpl, nullptr, params.temp, params.mirostat, params.n_probs);
|
||||
const llama_token id = gpt_sampler_sample(smpl, cur_p, params.temp, params.mirostat, params.n_probs);
|
||||
|
||||
// create an array with a single token data element for the sampled id
|
||||
llama_token_data single_token_data = { id, 1.0f, 0.0f };
|
||||
|
@ -194,7 +197,6 @@ llama_token gpt_sampler_sample(
|
|||
|
||||
// if the token is not valid, sample again, after applying the grammar constraints
|
||||
llama_sampler_set_logits(smpl, llama_get_logits_ith(ctx, idx));
|
||||
auto * cur_p = llama_sampler_get_candidates(smpl);
|
||||
|
||||
llama_constraint_apply(grmr, cur_p);
|
||||
|
||||
|
|
|
@ -375,6 +375,8 @@ extern "C" {
|
|||
typedef struct llama_sampler_params {
|
||||
uint32_t seed; // the seed used to initialize the rng of the sampler
|
||||
|
||||
int32_t n_prev; // size of ring buffer to keep previous accepted tokens (needed for llama_sampler_prev_ API)
|
||||
|
||||
int32_t mirostat; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
||||
float mirostat_tau; // target entropy
|
||||
float mirostat_eta; // learning rate
|
||||
|
|
|
@ -56,7 +56,6 @@ const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal
|
|||
// the ring buffer works similarly to std::deque, but with a fixed capacity
|
||||
template<typename T>
|
||||
struct ring_buffer {
|
||||
ring_buffer() {}
|
||||
ring_buffer(size_t cap) : capacity(cap), data(cap) {}
|
||||
|
||||
T & front() {
|
||||
|
|
|
@ -983,7 +983,7 @@ struct llama_constraint * llama_constraint_init_penalties_impl(const struct llam
|
|||
/*.penalty_present =*/ penalty_present,
|
||||
/*.penalize_nl =*/ penalize_nl,
|
||||
/*.ignore_eos =*/ ignore_eos,
|
||||
/*.prev =*/ {},
|
||||
/*.prev =*/ ring_buffer<llama_token>(penalty_last_n),
|
||||
},
|
||||
};
|
||||
|
||||
|
@ -1069,12 +1069,20 @@ void llama_constraint_reset_impl(struct llama_constraint & cnstr) {
|
|||
// samplers
|
||||
|
||||
struct llama_sampler * llama_sampler_init_impl(const struct llama_vocab & vocab, struct llama_sampler_params params) {
|
||||
auto * result = new llama_sampler;
|
||||
auto * result = new llama_sampler {
|
||||
/* .params = */ params,
|
||||
/* .vocab = */ &vocab,
|
||||
|
||||
result->params = params;
|
||||
result->vocab = &vocab;
|
||||
/* .rng = */ std::mt19937(params.seed),
|
||||
|
||||
result->rng.seed(params.seed);
|
||||
/* .mirostat_mu = */ 0.0f,
|
||||
/* .prev = */ { (size_t) params.n_prev },
|
||||
/* .constraints = */ {},
|
||||
/* .cur = */ {},
|
||||
/* .cur_p = */ {},
|
||||
/* .t_sample_us = */ 0,
|
||||
/* .n_sample = */ 0,
|
||||
};
|
||||
|
||||
return result;
|
||||
}
|
||||
|
@ -1092,9 +1100,20 @@ void llama_sampler_free_impl(struct llama_sampler * smpl) {
|
|||
}
|
||||
|
||||
struct llama_sampler * llama_sampler_cp_impl(const struct llama_sampler & smpl) {
|
||||
auto * result = new llama_sampler;
|
||||
auto * result = new llama_sampler {
|
||||
/* .params = */ smpl.params,
|
||||
/* .vocab = */ smpl.vocab,
|
||||
|
||||
*result = smpl;
|
||||
/* .rng = */ smpl.rng,
|
||||
|
||||
/* .mirostat_mu = */ smpl.mirostat_mu,
|
||||
/* .prev = */ smpl.prev,
|
||||
/* .constraints = */ {},
|
||||
/* .cur = */ {},
|
||||
/* .cur_p = */ {},
|
||||
/* .t_sample_us = */ 0,
|
||||
/* .n_sample = */ 0,
|
||||
};
|
||||
|
||||
// copy the constraints objects
|
||||
result->constraints.clear();
|
||||
|
|
|
@ -111,9 +111,9 @@ struct llama_sampler {
|
|||
|
||||
// timing
|
||||
|
||||
mutable int64_t t_sample_us = 0;
|
||||
mutable int64_t t_sample_us;
|
||||
|
||||
mutable int32_t n_sample = 0;
|
||||
mutable int32_t n_sample;
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_impl (const struct llama_vocab & vocab, struct llama_sampler_params params);
|
||||
|
|
|
@ -17938,6 +17938,7 @@ struct llama_context_params llama_context_default_params() {
|
|||
struct llama_sampler_params llama_sampler_default_params() {
|
||||
struct llama_sampler_params result = {
|
||||
/*.seed =*/ LLAMA_DEFAULT_SEED,
|
||||
/*.n_prev =*/ 256,
|
||||
/*.mirostat =*/ 0,
|
||||
/*.mirostat_tau =*/ 5.00f,
|
||||
/*.mirostat_eta =*/ 0.10f,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue