cont : store params in llama_sampling implementation

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-08-12 19:24:12 +03:00
parent d352b01af9
commit 6174762877
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
5 changed files with 98 additions and 50 deletions

View file

@ -32,24 +32,18 @@ struct llama_sampling_context * llama_sampling_init(const struct gpt_sampling_pa
lp.penalize_nl = params.penalize_nl;
lp.ignore_eos = params.ignore_eos;
lp.grammar = params.grammar.c_str();
lp.grammar_root = "root";
lp.cfg_prompt = params.cfg_negative_prompt.c_str();
lp.cfg_scale = params.cfg_scale;
lp.n_logit_bias = params.logit_bias.size();
lp.logit_bias = params.logit_bias.data();
result->smpl = llama_sampling_init(model, lp);
llama_sampling_set_rng_seed (result->smpl, params.seed);
llama_sampling_set_grammar (result->smpl, params.grammar.c_str(), "root");
llama_sampling_set_cfg (result->smpl, params.cfg_negative_prompt.c_str(), params.cfg_scale);
llama_sampling_set_logit_bias(result->smpl, params.logit_bias.size(), params.logit_bias.data());
}
result->prev.resize(params.n_prev);
result->n_valid = 0;
llama_sampling_set_rng_seed(result->smpl, params.seed);
return result;
}
@ -60,7 +54,7 @@ void llama_sampling_free(struct llama_sampling_context * ctx) {
}
void llama_sampling_reset(llama_sampling_context * ctx) {
llama_sampling_reset(ctx->smpl, ctx->params.grammar.c_str(), "root");
llama_sampling_reset(ctx->smpl);
std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
ctx->cur.clear();
@ -378,7 +372,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
if (ctx_cfg) {
float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
llama_sampling_apply_guidance(smpl, logits, logits_guidance, params.cfg_scale);
llama_sampling_cfg(smpl, logits, logits_guidance, params.cfg_scale);
}
cur.resize(n_vocab);

View file

@ -384,16 +384,6 @@ extern "C" {
float mirostat_tau; // target entropy
float mirostat_eta; // learning rate
// https://github.com/ggerganov/llama.cpp/pull/1773
const char * grammar;
const char * grammar_root;
const char * cfg_prompt; // string to help guidance in negative direction
float cfg_scale; // how strong is guidance
int32_t n_logit_bias;
const llama_logit_bias * logit_bias;
// Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
bool penalize_nl; // consider newlines as a repeatable token
bool ignore_eos; // ignore the end-of-sequence token
@ -1020,10 +1010,14 @@ extern "C" {
LLAMA_API struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl);
LLAMA_API void llama_sampling_reset(struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root);
//LLAMA_API void llama_sampling_reset(struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root);
LLAMA_API void llama_sampling_reset(struct llama_sampling * smpl);
// Sets the current rng seed.
LLAMA_API void llama_sampling_set_rng_seed(struct llama_sampling * smpl, uint32_t seed);
LLAMA_API void llama_sampling_set_rng_seed (struct llama_sampling * smpl, uint32_t seed);
LLAMA_API void llama_sampling_set_grammar (struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root);
LLAMA_API void llama_sampling_set_cfg (struct llama_sampling * smpl, const char * cfg_prompt, float cfg_scale);
LLAMA_API void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias);
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
LLAMA_API void llama_sampling_softmax(
@ -1098,7 +1092,7 @@ extern "C" {
/// @param logits Logits extracted from the original generation context.
/// @param logits_guidance Logits extracted from a separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
/// @param scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
LLAMA_API void llama_sampling_apply_guidance(
LLAMA_API void llama_sampling_cfg(
struct llama_sampling * smpl,
float * logits,
float * logits_guidance,

View file

@ -36,11 +36,9 @@ llama_sampling::~llama_sampling() {
struct llama_sampling * llama_sampling_init_impl(const struct llama_vocab & vocab, struct llama_sampling_params params) {
auto * result = new llama_sampling(vocab);
// TODO: store params
result->params = params;
if (params.grammar != nullptr && params.grammar[0] != '\0') {
result->grammar = llama_grammar_init_impl(result->vocab, params.grammar, params.grammar_root);
}
llama_sampling_set_rng_seed_impl(*result, params.seed);
return result;
}
@ -52,6 +50,16 @@ void llama_sampling_free_impl(struct llama_sampling * sampling) {
struct llama_sampling * llama_sampling_cp_impl(const struct llama_sampling & smpl) {
auto * result = new llama_sampling(smpl.vocab);
result->params = smpl.params;
result->grammar_str = smpl.grammar_str;
result->grammar_root = smpl.grammar_root;
result->cfg_prompt = smpl.cfg_prompt;
result->cfg_scale = smpl.cfg_scale;
result->logit_bias = smpl.logit_bias;
if (smpl.grammar) {
result->grammar = llama_grammar_copy_impl(*smpl.grammar);
}
@ -59,19 +67,14 @@ struct llama_sampling * llama_sampling_cp_impl(const struct llama_sampling & smp
return result;
}
void llama_sampling_reset_impl(struct llama_sampling & smpl, const char * grammar_str, const char * grammar_root) {
// TODO: this is dumb, need to fix
const struct llama_vocab * vocab = nullptr;
void llama_sampling_reset_impl(struct llama_sampling & smpl) {
if (smpl.grammar) {
vocab = &smpl.grammar->vocab;
llama_grammar_free_impl(smpl.grammar);
smpl.grammar = nullptr;
}
if (grammar_str != nullptr && grammar_str[0] != '\0') {
smpl.grammar = llama_grammar_init_impl(*vocab, grammar_str, grammar_root);
if (!smpl.grammar_str.empty()) {
smpl.grammar = llama_grammar_init_impl(smpl.vocab, smpl.grammar_str.data(), smpl.grammar_root.data());
}
}
@ -83,6 +86,42 @@ void llama_sampling_set_rng_seed_impl(struct llama_sampling & smpl, uint32_t see
smpl.rng.seed(seed);
}
void llama_sampling_set_grammar_impl(struct llama_sampling & smpl, const char * grammar_str, const char * grammar_root) {
if (smpl.grammar) {
llama_grammar_free_impl(smpl.grammar);
smpl.grammar = nullptr;
}
if (grammar_str != nullptr && grammar_str[0] != '\0') {
smpl.grammar_str = grammar_str;
smpl.grammar_root = grammar_root;
smpl.grammar = llama_grammar_init_impl(smpl.vocab, grammar_str, grammar_root);
} else {
smpl.grammar_str.clear();
smpl.grammar_root.clear();
}
}
void llama_sampling_set_cfg_impl(struct llama_sampling & smpl, const char * cfg_prompt, float cfg_scale) {
if (cfg_prompt != nullptr && cfg_prompt[0] != '\0') {
smpl.cfg_prompt = cfg_prompt;
} else {
smpl.cfg_prompt.clear();
}
smpl.cfg_scale = cfg_scale;
}
void llama_sampling_set_logit_bias_impl(struct llama_sampling & smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias) {
smpl.logit_bias.clear();
smpl.logit_bias.reserve(n_logit_bias);
for (int32_t i = 0; i < n_logit_bias; ++i) {
smpl.logit_bias.push_back(logit_bias[i]);
}
}
void llama_sampling_softmax_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates) {
GGML_ASSERT(candidates->size > 0);

View file

@ -9,10 +9,22 @@ struct llama_sampling {
llama_sampling(const struct llama_vocab & vocab);
~llama_sampling();
const llama_vocab & vocab;
llama_sampling_params params;
std::string grammar_str;
std::string grammar_root;
std::string cfg_prompt;
float cfg_scale = 1.0f;
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
// state
std::mt19937 rng;
const struct llama_vocab & vocab;
struct llama_grammar * grammar = nullptr;
mutable int64_t t_total_us = 0;
@ -30,10 +42,13 @@ void llama_sampling_free_impl(struct llama_sampling * sampling);
struct llama_sampling * llama_sampling_cp_impl(const struct llama_sampling & smpl);
void llama_sampling_reset_impl(struct llama_sampling & smpl, const char * grammar_str, const char * grammar_root);
void llama_sampling_reset_impl(struct llama_sampling & smpl);
// TODO: move the API below as member functions of llama_sampling
void llama_sampling_set_rng_seed_impl(struct llama_sampling & smpl, uint32_t seed);
void llama_sampling_set_rng_seed_impl (struct llama_sampling & smpl, uint32_t seed);
void llama_sampling_set_grammar_impl (struct llama_sampling & smpl, const char * grammar_str, const char * grammar_root);
void llama_sampling_set_cfg_impl (struct llama_sampling & smpl, const char * cfg_prompt, float cfg_scale);
void llama_sampling_set_logit_bias_impl(struct llama_sampling & smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias);
void llama_sampling_softmax_impl (struct llama_sampling & smpl, llama_token_data_array * candidates);
void llama_sampling_top_k_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep);

View file

@ -16511,12 +16511,6 @@ struct llama_sampling_params llama_sampling_default_params() {
/*.mirostat =*/ 0,
/*.mirostat_tau =*/ 5.00f,
/*.mirostat_eta =*/ 0.10f,
/*.grammar =*/ nullptr,
/*.grammar_root =*/ nullptr,
/*.cfg_prompt =*/ nullptr,
/*.cfg_scale =*/ 1.00f,
/*.n_logit_bias =*/ 0,
/*.logit_bias =*/ nullptr,
/*.penalize_nl =*/ false,
/*.ignore_eos =*/ false,
};
@ -19109,14 +19103,26 @@ struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl) {
return llama_sampling_cp_impl(*smpl);
}
void llama_sampling_reset(struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root) {
llama_sampling_reset_impl(*smpl, grammar_str, grammar_root);
void llama_sampling_reset(struct llama_sampling * smpl) {
llama_sampling_reset_impl(*smpl);
}
void llama_sampling_set_rng_seed(struct llama_sampling * smpl, uint32_t seed) {
llama_sampling_set_rng_seed_impl(*smpl, seed);
}
void llama_sampling_set_grammar(struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root) {
llama_sampling_set_grammar_impl(*smpl, grammar_str, grammar_root);
}
void llama_sampling_set_cfg(struct llama_sampling * smpl, const char * cfg_prompt, float cfg_scale) {
llama_sampling_set_cfg_impl(*smpl, cfg_prompt, cfg_scale);
}
void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias) {
llama_sampling_set_logit_bias_impl(*smpl, n_logit_bias, logit_bias);
}
void llama_sampling_softmax(struct llama_sampling * smpl, llama_token_data_array * candidates) {
time_meas tm(smpl->t_total_us);
@ -19186,7 +19192,7 @@ void llama_sampling_repetition_penalties(
llama_sampling_repetition_penalties_impl(*smpl, candidates, last_tokens, penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
}
void llama_sampling_apply_guidance(
void llama_sampling_cfg(
struct llama_sampling * smpl,
float * logits,
float * logits_guidance,