cont : store params in llama_sampling implementation
ggml-ci
This commit is contained in:
parent
d352b01af9
commit
6174762877
5 changed files with 98 additions and 50 deletions
|
@ -32,24 +32,18 @@ struct llama_sampling_context * llama_sampling_init(const struct gpt_sampling_pa
|
|||
lp.penalize_nl = params.penalize_nl;
|
||||
lp.ignore_eos = params.ignore_eos;
|
||||
|
||||
lp.grammar = params.grammar.c_str();
|
||||
lp.grammar_root = "root";
|
||||
|
||||
lp.cfg_prompt = params.cfg_negative_prompt.c_str();
|
||||
lp.cfg_scale = params.cfg_scale;
|
||||
|
||||
lp.n_logit_bias = params.logit_bias.size();
|
||||
lp.logit_bias = params.logit_bias.data();
|
||||
|
||||
result->smpl = llama_sampling_init(model, lp);
|
||||
|
||||
llama_sampling_set_rng_seed (result->smpl, params.seed);
|
||||
llama_sampling_set_grammar (result->smpl, params.grammar.c_str(), "root");
|
||||
llama_sampling_set_cfg (result->smpl, params.cfg_negative_prompt.c_str(), params.cfg_scale);
|
||||
llama_sampling_set_logit_bias(result->smpl, params.logit_bias.size(), params.logit_bias.data());
|
||||
}
|
||||
|
||||
result->prev.resize(params.n_prev);
|
||||
|
||||
result->n_valid = 0;
|
||||
|
||||
llama_sampling_set_rng_seed(result->smpl, params.seed);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
|
@ -60,7 +54,7 @@ void llama_sampling_free(struct llama_sampling_context * ctx) {
|
|||
}
|
||||
|
||||
void llama_sampling_reset(llama_sampling_context * ctx) {
|
||||
llama_sampling_reset(ctx->smpl, ctx->params.grammar.c_str(), "root");
|
||||
llama_sampling_reset(ctx->smpl);
|
||||
|
||||
std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
|
||||
ctx->cur.clear();
|
||||
|
@ -378,7 +372,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
|
|||
|
||||
if (ctx_cfg) {
|
||||
float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
|
||||
llama_sampling_apply_guidance(smpl, logits, logits_guidance, params.cfg_scale);
|
||||
llama_sampling_cfg(smpl, logits, logits_guidance, params.cfg_scale);
|
||||
}
|
||||
|
||||
cur.resize(n_vocab);
|
||||
|
|
|
@ -384,16 +384,6 @@ extern "C" {
|
|||
float mirostat_tau; // target entropy
|
||||
float mirostat_eta; // learning rate
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/pull/1773
|
||||
const char * grammar;
|
||||
const char * grammar_root;
|
||||
|
||||
const char * cfg_prompt; // string to help guidance in negative direction
|
||||
float cfg_scale; // how strong is guidance
|
||||
|
||||
int32_t n_logit_bias;
|
||||
const llama_logit_bias * logit_bias;
|
||||
|
||||
// Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
|
||||
bool penalize_nl; // consider newlines as a repeatable token
|
||||
bool ignore_eos; // ignore the end-of-sequence token
|
||||
|
@ -1020,10 +1010,14 @@ extern "C" {
|
|||
|
||||
LLAMA_API struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl);
|
||||
|
||||
LLAMA_API void llama_sampling_reset(struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root);
|
||||
//LLAMA_API void llama_sampling_reset(struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root);
|
||||
LLAMA_API void llama_sampling_reset(struct llama_sampling * smpl);
|
||||
|
||||
// Sets the current rng seed.
|
||||
LLAMA_API void llama_sampling_set_rng_seed(struct llama_sampling * smpl, uint32_t seed);
|
||||
LLAMA_API void llama_sampling_set_rng_seed (struct llama_sampling * smpl, uint32_t seed);
|
||||
LLAMA_API void llama_sampling_set_grammar (struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root);
|
||||
LLAMA_API void llama_sampling_set_cfg (struct llama_sampling * smpl, const char * cfg_prompt, float cfg_scale);
|
||||
LLAMA_API void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias);
|
||||
|
||||
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
|
||||
LLAMA_API void llama_sampling_softmax(
|
||||
|
@ -1098,7 +1092,7 @@ extern "C" {
|
|||
/// @param logits Logits extracted from the original generation context.
|
||||
/// @param logits_guidance Logits extracted from a separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
|
||||
/// @param scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
|
||||
LLAMA_API void llama_sampling_apply_guidance(
|
||||
LLAMA_API void llama_sampling_cfg(
|
||||
struct llama_sampling * smpl,
|
||||
float * logits,
|
||||
float * logits_guidance,
|
||||
|
|
|
@ -36,11 +36,9 @@ llama_sampling::~llama_sampling() {
|
|||
struct llama_sampling * llama_sampling_init_impl(const struct llama_vocab & vocab, struct llama_sampling_params params) {
|
||||
auto * result = new llama_sampling(vocab);
|
||||
|
||||
// TODO: store params
|
||||
result->params = params;
|
||||
|
||||
if (params.grammar != nullptr && params.grammar[0] != '\0') {
|
||||
result->grammar = llama_grammar_init_impl(result->vocab, params.grammar, params.grammar_root);
|
||||
}
|
||||
llama_sampling_set_rng_seed_impl(*result, params.seed);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
@ -52,6 +50,16 @@ void llama_sampling_free_impl(struct llama_sampling * sampling) {
|
|||
struct llama_sampling * llama_sampling_cp_impl(const struct llama_sampling & smpl) {
|
||||
auto * result = new llama_sampling(smpl.vocab);
|
||||
|
||||
result->params = smpl.params;
|
||||
|
||||
result->grammar_str = smpl.grammar_str;
|
||||
result->grammar_root = smpl.grammar_root;
|
||||
|
||||
result->cfg_prompt = smpl.cfg_prompt;
|
||||
result->cfg_scale = smpl.cfg_scale;
|
||||
|
||||
result->logit_bias = smpl.logit_bias;
|
||||
|
||||
if (smpl.grammar) {
|
||||
result->grammar = llama_grammar_copy_impl(*smpl.grammar);
|
||||
}
|
||||
|
@ -59,19 +67,14 @@ struct llama_sampling * llama_sampling_cp_impl(const struct llama_sampling & smp
|
|||
return result;
|
||||
}
|
||||
|
||||
void llama_sampling_reset_impl(struct llama_sampling & smpl, const char * grammar_str, const char * grammar_root) {
|
||||
// TODO: this is dumb, need to fix
|
||||
const struct llama_vocab * vocab = nullptr;
|
||||
|
||||
void llama_sampling_reset_impl(struct llama_sampling & smpl) {
|
||||
if (smpl.grammar) {
|
||||
vocab = &smpl.grammar->vocab;
|
||||
|
||||
llama_grammar_free_impl(smpl.grammar);
|
||||
smpl.grammar = nullptr;
|
||||
}
|
||||
|
||||
if (grammar_str != nullptr && grammar_str[0] != '\0') {
|
||||
smpl.grammar = llama_grammar_init_impl(*vocab, grammar_str, grammar_root);
|
||||
if (!smpl.grammar_str.empty()) {
|
||||
smpl.grammar = llama_grammar_init_impl(smpl.vocab, smpl.grammar_str.data(), smpl.grammar_root.data());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -83,6 +86,42 @@ void llama_sampling_set_rng_seed_impl(struct llama_sampling & smpl, uint32_t see
|
|||
smpl.rng.seed(seed);
|
||||
}
|
||||
|
||||
void llama_sampling_set_grammar_impl(struct llama_sampling & smpl, const char * grammar_str, const char * grammar_root) {
|
||||
if (smpl.grammar) {
|
||||
llama_grammar_free_impl(smpl.grammar);
|
||||
smpl.grammar = nullptr;
|
||||
}
|
||||
|
||||
if (grammar_str != nullptr && grammar_str[0] != '\0') {
|
||||
smpl.grammar_str = grammar_str;
|
||||
smpl.grammar_root = grammar_root;
|
||||
|
||||
smpl.grammar = llama_grammar_init_impl(smpl.vocab, grammar_str, grammar_root);
|
||||
} else {
|
||||
smpl.grammar_str.clear();
|
||||
smpl.grammar_root.clear();
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sampling_set_cfg_impl(struct llama_sampling & smpl, const char * cfg_prompt, float cfg_scale) {
|
||||
if (cfg_prompt != nullptr && cfg_prompt[0] != '\0') {
|
||||
smpl.cfg_prompt = cfg_prompt;
|
||||
} else {
|
||||
smpl.cfg_prompt.clear();
|
||||
}
|
||||
|
||||
smpl.cfg_scale = cfg_scale;
|
||||
}
|
||||
|
||||
void llama_sampling_set_logit_bias_impl(struct llama_sampling & smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias) {
|
||||
smpl.logit_bias.clear();
|
||||
smpl.logit_bias.reserve(n_logit_bias);
|
||||
|
||||
for (int32_t i = 0; i < n_logit_bias; ++i) {
|
||||
smpl.logit_bias.push_back(logit_bias[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sampling_softmax_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates) {
|
||||
GGML_ASSERT(candidates->size > 0);
|
||||
|
||||
|
|
|
@ -9,10 +9,22 @@ struct llama_sampling {
|
|||
llama_sampling(const struct llama_vocab & vocab);
|
||||
~llama_sampling();
|
||||
|
||||
const llama_vocab & vocab;
|
||||
llama_sampling_params params;
|
||||
|
||||
std::string grammar_str;
|
||||
std::string grammar_root;
|
||||
|
||||
std::string cfg_prompt;
|
||||
float cfg_scale = 1.0f;
|
||||
|
||||
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
||||
|
||||
// state
|
||||
|
||||
std::mt19937 rng;
|
||||
|
||||
const struct llama_vocab & vocab;
|
||||
|
||||
struct llama_grammar * grammar = nullptr;
|
||||
|
||||
mutable int64_t t_total_us = 0;
|
||||
|
@ -30,10 +42,13 @@ void llama_sampling_free_impl(struct llama_sampling * sampling);
|
|||
|
||||
struct llama_sampling * llama_sampling_cp_impl(const struct llama_sampling & smpl);
|
||||
|
||||
void llama_sampling_reset_impl(struct llama_sampling & smpl, const char * grammar_str, const char * grammar_root);
|
||||
void llama_sampling_reset_impl(struct llama_sampling & smpl);
|
||||
|
||||
// TODO: move the API below as member functions of llama_sampling
|
||||
void llama_sampling_set_rng_seed_impl(struct llama_sampling & smpl, uint32_t seed);
|
||||
void llama_sampling_set_rng_seed_impl (struct llama_sampling & smpl, uint32_t seed);
|
||||
void llama_sampling_set_grammar_impl (struct llama_sampling & smpl, const char * grammar_str, const char * grammar_root);
|
||||
void llama_sampling_set_cfg_impl (struct llama_sampling & smpl, const char * cfg_prompt, float cfg_scale);
|
||||
void llama_sampling_set_logit_bias_impl(struct llama_sampling & smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias);
|
||||
|
||||
void llama_sampling_softmax_impl (struct llama_sampling & smpl, llama_token_data_array * candidates);
|
||||
void llama_sampling_top_k_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep);
|
||||
|
|
|
@ -16511,12 +16511,6 @@ struct llama_sampling_params llama_sampling_default_params() {
|
|||
/*.mirostat =*/ 0,
|
||||
/*.mirostat_tau =*/ 5.00f,
|
||||
/*.mirostat_eta =*/ 0.10f,
|
||||
/*.grammar =*/ nullptr,
|
||||
/*.grammar_root =*/ nullptr,
|
||||
/*.cfg_prompt =*/ nullptr,
|
||||
/*.cfg_scale =*/ 1.00f,
|
||||
/*.n_logit_bias =*/ 0,
|
||||
/*.logit_bias =*/ nullptr,
|
||||
/*.penalize_nl =*/ false,
|
||||
/*.ignore_eos =*/ false,
|
||||
};
|
||||
|
@ -19109,14 +19103,26 @@ struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl) {
|
|||
return llama_sampling_cp_impl(*smpl);
|
||||
}
|
||||
|
||||
void llama_sampling_reset(struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root) {
|
||||
llama_sampling_reset_impl(*smpl, grammar_str, grammar_root);
|
||||
void llama_sampling_reset(struct llama_sampling * smpl) {
|
||||
llama_sampling_reset_impl(*smpl);
|
||||
}
|
||||
|
||||
void llama_sampling_set_rng_seed(struct llama_sampling * smpl, uint32_t seed) {
|
||||
llama_sampling_set_rng_seed_impl(*smpl, seed);
|
||||
}
|
||||
|
||||
void llama_sampling_set_grammar(struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root) {
|
||||
llama_sampling_set_grammar_impl(*smpl, grammar_str, grammar_root);
|
||||
}
|
||||
|
||||
void llama_sampling_set_cfg(struct llama_sampling * smpl, const char * cfg_prompt, float cfg_scale) {
|
||||
llama_sampling_set_cfg_impl(*smpl, cfg_prompt, cfg_scale);
|
||||
}
|
||||
|
||||
void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias) {
|
||||
llama_sampling_set_logit_bias_impl(*smpl, n_logit_bias, logit_bias);
|
||||
}
|
||||
|
||||
void llama_sampling_softmax(struct llama_sampling * smpl, llama_token_data_array * candidates) {
|
||||
time_meas tm(smpl->t_total_us);
|
||||
|
||||
|
@ -19186,7 +19192,7 @@ void llama_sampling_repetition_penalties(
|
|||
llama_sampling_repetition_penalties_impl(*smpl, candidates, last_tokens, penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
|
||||
}
|
||||
|
||||
void llama_sampling_apply_guidance(
|
||||
void llama_sampling_cfg(
|
||||
struct llama_sampling * smpl,
|
||||
float * logits,
|
||||
float * logits_guidance,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue