From 6174762877639b6d9fbf3ee69afc441d762b8591 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 12 Aug 2024 19:24:12 +0300 Subject: [PATCH] cont : store params in llama_sampling implementation ggml-ci --- common/sampling.cpp | 20 +++++--------- include/llama.h | 20 +++++--------- src/llama-sampling.cpp | 63 ++++++++++++++++++++++++++++++++++-------- src/llama-sampling.h | 21 ++++++++++++-- src/llama.cpp | 24 ++++++++++------ 5 files changed, 98 insertions(+), 50 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 3082c0621..e05cb754c 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -32,24 +32,18 @@ struct llama_sampling_context * llama_sampling_init(const struct gpt_sampling_pa lp.penalize_nl = params.penalize_nl; lp.ignore_eos = params.ignore_eos; - lp.grammar = params.grammar.c_str(); - lp.grammar_root = "root"; - - lp.cfg_prompt = params.cfg_negative_prompt.c_str(); - lp.cfg_scale = params.cfg_scale; - - lp.n_logit_bias = params.logit_bias.size(); - lp.logit_bias = params.logit_bias.data(); - result->smpl = llama_sampling_init(model, lp); + + llama_sampling_set_rng_seed (result->smpl, params.seed); + llama_sampling_set_grammar (result->smpl, params.grammar.c_str(), "root"); + llama_sampling_set_cfg (result->smpl, params.cfg_negative_prompt.c_str(), params.cfg_scale); + llama_sampling_set_logit_bias(result->smpl, params.logit_bias.size(), params.logit_bias.data()); } result->prev.resize(params.n_prev); result->n_valid = 0; - llama_sampling_set_rng_seed(result->smpl, params.seed); - return result; } @@ -60,7 +54,7 @@ void llama_sampling_free(struct llama_sampling_context * ctx) { } void llama_sampling_reset(llama_sampling_context * ctx) { - llama_sampling_reset(ctx->smpl, ctx->params.grammar.c_str(), "root"); + llama_sampling_reset(ctx->smpl); std::fill(ctx->prev.begin(), ctx->prev.end(), 0); ctx->cur.clear(); @@ -378,7 +372,7 @@ static llama_token_data_array llama_sampling_prepare_impl( if (ctx_cfg) { float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx); - llama_sampling_apply_guidance(smpl, logits, logits_guidance, params.cfg_scale); + llama_sampling_cfg(smpl, logits, logits_guidance, params.cfg_scale); } cur.resize(n_vocab); diff --git a/include/llama.h b/include/llama.h index 075a30b0e..3ffccf398 100644 --- a/include/llama.h +++ b/include/llama.h @@ -384,16 +384,6 @@ extern "C" { float mirostat_tau; // target entropy float mirostat_eta; // learning rate - // https://github.com/ggerganov/llama.cpp/pull/1773 - const char * grammar; - const char * grammar_root; - - const char * cfg_prompt; // string to help guidance in negative direction - float cfg_scale; // how strong is guidance - - int32_t n_logit_bias; - const llama_logit_bias * logit_bias; - // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value. bool penalize_nl; // consider newlines as a repeatable token bool ignore_eos; // ignore the end-of-sequence token @@ -1020,10 +1010,14 @@ extern "C" { LLAMA_API struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl); - LLAMA_API void llama_sampling_reset(struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root); + //LLAMA_API void llama_sampling_reset(struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root); + LLAMA_API void llama_sampling_reset(struct llama_sampling * smpl); // Sets the current rng seed. - LLAMA_API void llama_sampling_set_rng_seed(struct llama_sampling * smpl, uint32_t seed); + LLAMA_API void llama_sampling_set_rng_seed (struct llama_sampling * smpl, uint32_t seed); + LLAMA_API void llama_sampling_set_grammar (struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root); + LLAMA_API void llama_sampling_set_cfg (struct llama_sampling * smpl, const char * cfg_prompt, float cfg_scale); + LLAMA_API void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias); /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. LLAMA_API void llama_sampling_softmax( @@ -1098,7 +1092,7 @@ extern "C" { /// @param logits Logits extracted from the original generation context. /// @param logits_guidance Logits extracted from a separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context. /// @param scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance. - LLAMA_API void llama_sampling_apply_guidance( + LLAMA_API void llama_sampling_cfg( struct llama_sampling * smpl, float * logits, float * logits_guidance, diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 1a248a8ee..5ba41e949 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -36,11 +36,9 @@ llama_sampling::~llama_sampling() { struct llama_sampling * llama_sampling_init_impl(const struct llama_vocab & vocab, struct llama_sampling_params params) { auto * result = new llama_sampling(vocab); - // TODO: store params + result->params = params; - if (params.grammar != nullptr && params.grammar[0] != '\0') { - result->grammar = llama_grammar_init_impl(result->vocab, params.grammar, params.grammar_root); - } + llama_sampling_set_rng_seed_impl(*result, params.seed); return result; } @@ -52,6 +50,16 @@ void llama_sampling_free_impl(struct llama_sampling * sampling) { struct llama_sampling * llama_sampling_cp_impl(const struct llama_sampling & smpl) { auto * result = new llama_sampling(smpl.vocab); + result->params = smpl.params; + + result->grammar_str = smpl.grammar_str; + result->grammar_root = smpl.grammar_root; + + result->cfg_prompt = smpl.cfg_prompt; + result->cfg_scale = smpl.cfg_scale; + + result->logit_bias = smpl.logit_bias; + if (smpl.grammar) { result->grammar = llama_grammar_copy_impl(*smpl.grammar); } @@ -59,19 +67,14 @@ struct llama_sampling * llama_sampling_cp_impl(const struct llama_sampling & smp return result; } -void llama_sampling_reset_impl(struct llama_sampling & smpl, const char * grammar_str, const char * grammar_root) { - // TODO: this is dumb, need to fix - const struct llama_vocab * vocab = nullptr; - +void llama_sampling_reset_impl(struct llama_sampling & smpl) { if (smpl.grammar) { - vocab = &smpl.grammar->vocab; - llama_grammar_free_impl(smpl.grammar); smpl.grammar = nullptr; } - if (grammar_str != nullptr && grammar_str[0] != '\0') { - smpl.grammar = llama_grammar_init_impl(*vocab, grammar_str, grammar_root); + if (!smpl.grammar_str.empty()) { + smpl.grammar = llama_grammar_init_impl(smpl.vocab, smpl.grammar_str.data(), smpl.grammar_root.data()); } } @@ -83,6 +86,42 @@ void llama_sampling_set_rng_seed_impl(struct llama_sampling & smpl, uint32_t see smpl.rng.seed(seed); } +void llama_sampling_set_grammar_impl(struct llama_sampling & smpl, const char * grammar_str, const char * grammar_root) { + if (smpl.grammar) { + llama_grammar_free_impl(smpl.grammar); + smpl.grammar = nullptr; + } + + if (grammar_str != nullptr && grammar_str[0] != '\0') { + smpl.grammar_str = grammar_str; + smpl.grammar_root = grammar_root; + + smpl.grammar = llama_grammar_init_impl(smpl.vocab, grammar_str, grammar_root); + } else { + smpl.grammar_str.clear(); + smpl.grammar_root.clear(); + } +} + +void llama_sampling_set_cfg_impl(struct llama_sampling & smpl, const char * cfg_prompt, float cfg_scale) { + if (cfg_prompt != nullptr && cfg_prompt[0] != '\0') { + smpl.cfg_prompt = cfg_prompt; + } else { + smpl.cfg_prompt.clear(); + } + + smpl.cfg_scale = cfg_scale; +} + +void llama_sampling_set_logit_bias_impl(struct llama_sampling & smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias) { + smpl.logit_bias.clear(); + smpl.logit_bias.reserve(n_logit_bias); + + for (int32_t i = 0; i < n_logit_bias; ++i) { + smpl.logit_bias.push_back(logit_bias[i]); + } +} + void llama_sampling_softmax_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates) { GGML_ASSERT(candidates->size > 0); diff --git a/src/llama-sampling.h b/src/llama-sampling.h index a5af5e257..5d142b2ce 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -9,10 +9,22 @@ struct llama_sampling { llama_sampling(const struct llama_vocab & vocab); ~llama_sampling(); - const llama_vocab & vocab; + llama_sampling_params params; + + std::string grammar_str; + std::string grammar_root; + + std::string cfg_prompt; + float cfg_scale = 1.0f; + + std::vector logit_bias; // logit biases to apply + + // state std::mt19937 rng; + const struct llama_vocab & vocab; + struct llama_grammar * grammar = nullptr; mutable int64_t t_total_us = 0; @@ -30,10 +42,13 @@ void llama_sampling_free_impl(struct llama_sampling * sampling); struct llama_sampling * llama_sampling_cp_impl(const struct llama_sampling & smpl); -void llama_sampling_reset_impl(struct llama_sampling & smpl, const char * grammar_str, const char * grammar_root); +void llama_sampling_reset_impl(struct llama_sampling & smpl); // TODO: move the API below as member functions of llama_sampling -void llama_sampling_set_rng_seed_impl(struct llama_sampling & smpl, uint32_t seed); +void llama_sampling_set_rng_seed_impl (struct llama_sampling & smpl, uint32_t seed); +void llama_sampling_set_grammar_impl (struct llama_sampling & smpl, const char * grammar_str, const char * grammar_root); +void llama_sampling_set_cfg_impl (struct llama_sampling & smpl, const char * cfg_prompt, float cfg_scale); +void llama_sampling_set_logit_bias_impl(struct llama_sampling & smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias); void llama_sampling_softmax_impl (struct llama_sampling & smpl, llama_token_data_array * candidates); void llama_sampling_top_k_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep); diff --git a/src/llama.cpp b/src/llama.cpp index 25f6ec01e..46a5340df 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -16511,12 +16511,6 @@ struct llama_sampling_params llama_sampling_default_params() { /*.mirostat =*/ 0, /*.mirostat_tau =*/ 5.00f, /*.mirostat_eta =*/ 0.10f, - /*.grammar =*/ nullptr, - /*.grammar_root =*/ nullptr, - /*.cfg_prompt =*/ nullptr, - /*.cfg_scale =*/ 1.00f, - /*.n_logit_bias =*/ 0, - /*.logit_bias =*/ nullptr, /*.penalize_nl =*/ false, /*.ignore_eos =*/ false, }; @@ -19109,14 +19103,26 @@ struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl) { return llama_sampling_cp_impl(*smpl); } -void llama_sampling_reset(struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root) { - llama_sampling_reset_impl(*smpl, grammar_str, grammar_root); +void llama_sampling_reset(struct llama_sampling * smpl) { + llama_sampling_reset_impl(*smpl); } void llama_sampling_set_rng_seed(struct llama_sampling * smpl, uint32_t seed) { llama_sampling_set_rng_seed_impl(*smpl, seed); } +void llama_sampling_set_grammar(struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root) { + llama_sampling_set_grammar_impl(*smpl, grammar_str, grammar_root); +} + +void llama_sampling_set_cfg(struct llama_sampling * smpl, const char * cfg_prompt, float cfg_scale) { + llama_sampling_set_cfg_impl(*smpl, cfg_prompt, cfg_scale); +} + +void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias) { + llama_sampling_set_logit_bias_impl(*smpl, n_logit_bias, logit_bias); +} + void llama_sampling_softmax(struct llama_sampling * smpl, llama_token_data_array * candidates) { time_meas tm(smpl->t_total_us); @@ -19186,7 +19192,7 @@ void llama_sampling_repetition_penalties( llama_sampling_repetition_penalties_impl(*smpl, candidates, last_tokens, penalty_last_n, penalty_repeat, penalty_freq, penalty_present); } -void llama_sampling_apply_guidance( +void llama_sampling_cfg( struct llama_sampling * smpl, float * logits, float * logits_guidance,