sampling : change _cp/copy to clone

This commit is contained in:
Georgi Gerganov 2024-09-05 10:25:33 +03:00
parent 69551ffd60
commit ebeb65194b
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
9 changed files with 50 additions and 51 deletions

View file

@ -114,13 +114,13 @@ void gpt_sampler_free(struct gpt_sampler * gsmpl) {
}
}
struct gpt_sampler * gpt_sampler_cp(gpt_sampler * gsmpl) {
struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) {
return new gpt_sampler {
/* .params = */ gsmpl->params,
/* .bias = */ llama_constraint_cp(gsmpl->bias),
/* .pnlt = */ llama_constraint_cp(gsmpl->pnlt),
/* .grmr = */ llama_constraint_cp(gsmpl->grmr),
/* .smpl = */ llama_sampler_cp (gsmpl->smpl)
/* .bias = */ llama_constraint_clone(gsmpl->bias),
/* .pnlt = */ llama_constraint_clone(gsmpl->pnlt),
/* .grmr = */ llama_constraint_clone(gsmpl->grmr),
/* .smpl = */ llama_sampler_clone (gsmpl->smpl)
};
}

View file

@ -68,7 +68,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
void gpt_sampler_free(struct gpt_sampler * gsmpl);
struct gpt_sampler * gpt_sampler_cp(gpt_sampler * gsmpl);
struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl);
void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool apply_grammar);
void gpt_sampler_reset (struct gpt_sampler * gsmpl);

View file

@ -451,7 +451,7 @@ int main(int argc, char ** argv) {
if (drafts[0].smpl) {
gpt_sampler_free(drafts[0].smpl);
}
drafts[0].smpl = gpt_sampler_cp(smpl);
drafts[0].smpl = gpt_sampler_clone(smpl);
int n_seq_cur = 1;
int n_past_cur = n_past_dft;
@ -523,7 +523,7 @@ int main(int argc, char ** argv) {
if (drafts[n_seq_cur].smpl) {
gpt_sampler_free(drafts[n_seq_cur].smpl);
}
drafts[n_seq_cur].smpl = gpt_sampler_cp(drafts[s].smpl);
drafts[n_seq_cur].smpl = gpt_sampler_clone(drafts[s].smpl);
sa.push_back(n_seq_cur);

View file

@ -1032,7 +1032,7 @@ extern "C" {
void (*accept)( struct llama_constraint * cnstr, llama_token token); // can be NULL
void (*apply) ( struct llama_constraint * cnstr, llama_token_data_array * cur_p); // required
void (*reset) ( struct llama_constraint * cnstr); // can be NULL
struct llama_constraint * (*copy) (const struct llama_constraint * cnstr); // can be NULL if ctx is NULL
struct llama_constraint * (*clone) (const struct llama_constraint * cnstr); // can be NULL if ctx is NULL
void (*free) ( struct llama_constraint * cnstr); // can be NULL if ctx is NULL
// TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph
@ -1053,11 +1053,22 @@ extern "C" {
LLAMA_API struct llama_constraint * llama_constraint_init_temp (float t);
LLAMA_API struct llama_constraint * llama_constraint_init_temp_ext (float t, float delta, float exponent);
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
/// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
LLAMA_API struct llama_constraint * llama_constraint_init_mirostat(
const struct llama_model * model,
float tau,
float eta);
/// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
LLAMA_API struct llama_constraint * llama_constraint_init_mirostat_v2(
float tau,
float eta);
@ -1081,7 +1092,7 @@ extern "C" {
int32_t n_logit_bias,
const llama_logit_bias * logit_bias);
LLAMA_API struct llama_constraint * llama_constraint_cp(const struct llama_constraint * cnstr);
LLAMA_API struct llama_constraint * llama_constraint_clone(const struct llama_constraint * cnstr);
// important: do not call if the constraint has been added to a llama_sampler (via llama_sampler_constraint_add)
LLAMA_API void llama_constraint_free(struct llama_constraint * cnstr);
@ -1094,7 +1105,7 @@ extern "C" {
LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_model * model, struct llama_sampler_params params);
LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl);
LLAMA_API struct llama_sampler * llama_sampler_cp (const struct llama_sampler * smpl);
LLAMA_API struct llama_sampler * llama_sampler_clone (const struct llama_sampler * smpl);
LLAMA_API void llama_sampler_reset ( struct llama_sampler * smpl);
LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token);
LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p);

View file

@ -1050,7 +1050,7 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) {
delete grammar;
}
struct llama_grammar * llama_grammar_cp_impl(const struct llama_grammar & grammar) {
struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
llama_grammar * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, };
// redirect elements in stacks to point to new rules

View file

@ -131,7 +131,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
void llama_grammar_free_impl(struct llama_grammar * grammar);
struct llama_grammar * llama_grammar_cp_impl(const struct llama_grammar & grammar);
struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar);
// TODO: move the API below as member functions of llama_grammar
void llama_grammar_apply_impl(

View file

@ -433,7 +433,7 @@ static struct llama_constraint_i llama_constraint_softmax_i = {
llama_constraint_softmax_impl(cur_p);
},
/* .reset = */ nullptr,
/* .copy = */ nullptr,
/* .clone = */ nullptr,
/* .free = */ nullptr,
};
@ -458,7 +458,7 @@ static struct llama_constraint_i llama_constraint_top_k_i = {
llama_constraint_top_k_impl(cur_p, ctx->k);
},
/* .reset = */ nullptr,
/* .copy = */ [](const struct llama_constraint * cnstr) {
/* .clone = */ [](const struct llama_constraint * cnstr) {
const auto * ctx = (const llama_constraint_context_top_k *) cnstr->ctx;
return llama_constraint_init_top_k_impl(ctx->k);
},
@ -491,7 +491,7 @@ static struct llama_constraint_i llama_constraint_top_p_i = {
llama_constraint_top_p_impl(cur_p, ctx->p, ctx->min_keep);
},
/* .reset = */ nullptr,
/* .copy = */ [](const struct llama_constraint * cnstr) {
/* .clone = */ [](const struct llama_constraint * cnstr) {
const auto * ctx = (const llama_constraint_context_top_p *) cnstr->ctx;
return llama_constraint_init_top_p_impl(ctx->p, ctx->min_keep);
},
@ -525,7 +525,7 @@ static struct llama_constraint_i llama_constraint_min_p_i = {
llama_constraint_min_p_impl(cur_p, ctx->p, ctx->min_keep);
},
/* .reset = */ nullptr,
/* .copy = */ [](const struct llama_constraint * cnstr) {
/* .clone = */ [](const struct llama_constraint * cnstr) {
const auto * ctx = (const llama_constraint_context_min_p *) cnstr->ctx;
return llama_constraint_init_min_p_impl(ctx->p, ctx->min_keep);
},
@ -559,7 +559,7 @@ static struct llama_constraint_i llama_constraint_tail_free_i = {
llama_constraint_tail_free_impl(cur_p, ctx->z, ctx->min_keep);
},
/* .reset = */ nullptr,
/* .copy = */ [](const struct llama_constraint * cnstr) {
/* .clone = */ [](const struct llama_constraint * cnstr) {
const auto * ctx = (const llama_constraint_context_tail_free *) cnstr->ctx;
return llama_constraint_init_tail_free_impl(ctx->z, ctx->min_keep);
},
@ -593,7 +593,7 @@ static struct llama_constraint_i llama_constraint_typical_i = {
llama_constraint_typical_impl(cur_p, ctx->p, ctx->min_keep);
},
/* .reset = */ nullptr,
/* .copy = */ [](const struct llama_constraint * cnstr) {
/* .clone = */ [](const struct llama_constraint * cnstr) {
const auto * ctx = (const llama_constraint_context_typical *) cnstr->ctx;
return llama_constraint_init_typical_impl(ctx->p, ctx->min_keep);
},
@ -626,7 +626,7 @@ static struct llama_constraint_i llama_constraint_temp_i = {
llama_constraint_temp_impl(cur_p, ctx->temp);
},
/* .reset = */ nullptr,
/* .copy = */ [](const struct llama_constraint * cnstr) {
/* .clone = */ [](const struct llama_constraint * cnstr) {
const auto * ctx = (const llama_constraint_context_temp *) cnstr->ctx;
return llama_constraint_init_temp_impl(ctx->temp);
},
@ -667,7 +667,7 @@ static struct llama_constraint_i llama_constraint_temp_ext_i = {
}
},
/* .reset = */ nullptr,
/* .copy = */ [](const struct llama_constraint * cnstr) {
/* .clone = */ [](const struct llama_constraint * cnstr) {
const auto * ctx = (const llama_constraint_context_temp_ext *) cnstr->ctx;
return llama_constraint_init_temp_ext_impl(ctx->temp, ctx->delta, ctx->exponent);
},
@ -754,7 +754,7 @@ static struct llama_constraint_i llama_constraint_mirostat_i = {
auto * ctx = (llama_constraint_context_mirostat *) cnstr->ctx;
ctx->mu = 2.0f*ctx->tau;
},
/* .copy = */ [](const struct llama_constraint * cnstr) {
/* .clone = */ [](const struct llama_constraint * cnstr) {
const auto * ctx = (const llama_constraint_context_mirostat *) cnstr->ctx;
return llama_constraint_init_mirostat_impl(*ctx->vocab, ctx->tau, ctx->eta, ctx->m);
},
@ -834,7 +834,7 @@ static struct llama_constraint_i llama_constraint_mirostat_v2_i = {
auto * ctx = (llama_constraint_context_mirostat_v2 *) cnstr->ctx;
ctx->mu = 2.0f*ctx->tau;
},
/* .copy = */ [](const struct llama_constraint * cnstr) {
/* .clone = */ [](const struct llama_constraint * cnstr) {
const auto * ctx = (const llama_constraint_context_mirostat_v2 *) cnstr->ctx;
return llama_constraint_init_mirostat_v2_impl(ctx->tau, ctx->eta);
},
@ -891,7 +891,7 @@ static struct llama_constraint_i llama_constraint_grammar_i = {
llama_grammar_free_impl(ctx->grammar);
ctx->grammar = grammar_new;
},
/* .copy = */ [](const struct llama_constraint * cnstr) {
/* .clone = */ [](const struct llama_constraint * cnstr) {
const auto * ctx_src = (const llama_constraint_context_grammar *) cnstr->ctx;
auto * result = llama_constraint_init_grammar_impl(*ctx_src->vocab, nullptr, nullptr);
@ -901,7 +901,7 @@ static struct llama_constraint_i llama_constraint_grammar_i = {
ctx_dst->grammar_str = ctx_src->grammar_str;
ctx_dst->grammar_root = ctx_src->grammar_root;
ctx_dst->grammar = llama_grammar_cp_impl(*ctx_src->grammar);
ctx_dst->grammar = llama_grammar_clone_impl(*ctx_src->grammar);
}
return result;
@ -998,7 +998,7 @@ static struct llama_constraint_i llama_constraint_penalties_i = {
auto * ctx = (llama_constraint_context_penalties *) cnstr->ctx;
ctx->prev.clear();
},
/* .copy = */ [](const struct llama_constraint * cnstr) {
/* .clone = */ [](const struct llama_constraint * cnstr) {
const auto * ctx_src = (const llama_constraint_context_penalties *) cnstr->ctx;
auto * result = llama_constraint_init_penalties_impl(
*ctx_src->vocab,
@ -1059,7 +1059,7 @@ static struct llama_constraint_i llama_constraint_logit_bias_i = {
}
},
/* .reset = */ nullptr,
/* .copy = */ [](const struct llama_constraint * cnstr) {
/* .clone = */ [](const struct llama_constraint * cnstr) {
const auto * ctx_src = (const llama_constraint_context_logit_bias *) cnstr->ctx;
return llama_constraint_init_logit_bias_impl(*ctx_src->vocab, ctx_src->logit_bias.size(), ctx_src->logit_bias.data());
},
@ -1083,8 +1083,8 @@ struct llama_constraint * llama_constraint_init_logit_bias_impl(
////////////////////////////////////////
struct llama_constraint * llama_constraint_cp_impl(const struct llama_constraint & cnstr) {
return cnstr.iface->copy ? cnstr.iface->copy(&cnstr) : nullptr;
struct llama_constraint * llama_constraint_clone_impl(const struct llama_constraint & cnstr) {
return cnstr.iface->clone ? cnstr.iface->clone(&cnstr) : nullptr;
}
void llama_constraint_free_impl(struct llama_constraint * cnstr) {
@ -1148,7 +1148,7 @@ void llama_sampler_free_impl(struct llama_sampler * smpl) {
delete smpl;
}
struct llama_sampler * llama_sampler_cp_impl(const struct llama_sampler & smpl) {
struct llama_sampler * llama_sampler_clone_impl(const struct llama_sampler & smpl) {
auto * result = new llama_sampler {
/* .params = */ smpl.params,
/* .vocab = */ smpl.vocab,
@ -1163,7 +1163,7 @@ struct llama_sampler * llama_sampler_cp_impl(const struct llama_sampler & smpl)
/* .n_sample = */ 0,
};
// copy the constraints objects
// clone the constraints objects
result->constraints.clear();
for (const auto & cnstr : smpl.constraints) {
if (cnstr->ctx == nullptr) {
@ -1172,8 +1172,8 @@ struct llama_sampler * llama_sampler_cp_impl(const struct llama_sampler & smpl)
/* .ctx = */ nullptr,
});
} else {
GGML_ASSERT(cnstr->iface->copy);
result->constraints.push_back(cnstr->iface->copy(cnstr));
GGML_ASSERT(cnstr->iface->clone);
result->constraints.push_back(cnstr->iface->clone(cnstr));
}
}

View file

@ -29,24 +29,12 @@ struct llama_constraint * llama_constraint_init_typical_impl (float p, size
struct llama_constraint * llama_constraint_init_temp_impl (float t);
struct llama_constraint * llama_constraint_init_temp_ext_impl (float t, float delta, float exponent);
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
/// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
struct llama_constraint * llama_constraint_init_mirostat_impl(
const struct llama_vocab & vocab,
float tau,
float eta,
int32_t m);
/// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
struct llama_constraint * llama_constraint_init_mirostat_v2_impl(
float tau,
float eta);
@ -70,7 +58,7 @@ struct llama_constraint * llama_constraint_init_penalties_impl(
int32_t n_logit_bias,
const llama_logit_bias * logit_bias);
struct llama_constraint * llama_constraint_cp_impl(const struct llama_constraint & cnstr);
struct llama_constraint * llama_constraint_clone_impl(const struct llama_constraint & cnstr);
void llama_constraint_free_impl(struct llama_constraint * cnstr);
@ -106,7 +94,7 @@ struct llama_sampler {
struct llama_sampler * llama_sampler_init_impl (const struct llama_vocab & vocab, struct llama_sampler_params params);
void llama_sampler_free_impl ( struct llama_sampler * smpl);
struct llama_sampler * llama_sampler_cp_impl (const struct llama_sampler & smpl);
struct llama_sampler * llama_sampler_clone_impl (const struct llama_sampler & smpl);
void llama_sampler_reset_impl ( struct llama_sampler & smpl);
void llama_sampler_accept_impl( struct llama_sampler & smpl, llama_token token);
void llama_sampler_apply_impl ( struct llama_sampler & smpl, struct llama_token_data_array * cur_p);

View file

@ -20669,8 +20669,8 @@ LLAMA_API struct llama_constraint * llama_constraint_init_logit_bias(
return llama_constraint_init_logit_bias_impl(model->vocab, n_logit_bias, logit_bias);
}
struct llama_constraint * llama_constraint_cp(const struct llama_constraint * cnstr) {
return llama_constraint_cp_impl(*cnstr);
struct llama_constraint * llama_constraint_clone(const struct llama_constraint * cnstr) {
return llama_constraint_clone_impl(*cnstr);
}
void llama_constraint_free(struct llama_constraint * cnstr) {
@ -20705,8 +20705,8 @@ void llama_sampler_free(struct llama_sampler * smpl) {
llama_sampler_free_impl(smpl);
}
struct llama_sampler * llama_sampler_cp(const struct llama_sampler * smpl) {
return llama_sampler_cp_impl(*smpl);
struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
return llama_sampler_clone_impl(*smpl);
}
void llama_sampler_reset(struct llama_sampler * smpl) {