From d352b01af94149f83cba1d5515a1a891a76f5165 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 12 Aug 2024 17:14:56 +0300
Subject: [PATCH] cont : llama_sampling_init() use llama_sampling_params

ggml-ci
---
 common/sampling.cpp                          | 38 ++++++++++++++++++-
 common/sampling.h                            |  2 +-
 examples/batched.swift/Sources/main.swift    |  2 +-
 examples/batched/batched.cpp                 |  2 +-
 examples/gritlm/gritlm.cpp                   |  2 +-
 .../llama.cpp.swift/LibLlama.swift           |  2 +-
 examples/passkey/passkey.cpp                 |  2 +-
 examples/save-load-state/save-load-state.cpp |  6 +--
 examples/simple/simple.cpp                   |  2 +-
 include/llama.h                              | 16 +++++---
 src/llama-sampling.cpp                       | 26 +++++++------
 src/llama-sampling.h                         |  8 ++--
 src/llama-vocab.cpp                          | 12 ++++++
 src/llama-vocab.h                            |  2 +
 src/llama.cpp                                | 10 +++--
 tests/test-backend-ops.cpp                   |  1 -
 tests/test-sampling.cpp                      | 15 ++++----
 17 files changed, 103 insertions(+), 45 deletions(-)

diff --git a/common/sampling.cpp b/common/sampling.cpp
index a4f250d99..3082c0621 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -6,7 +6,43 @@ struct llama_sampling_context * llama_sampling_init(const struct gpt_sampling_pa
     struct llama_sampling_context * result = new llama_sampling_context();
 
     result->params = params;
-    result->smpl   = llama_sampling_init(model, params.grammar.c_str(), "root");
+
+    {
+        auto lp = llama_sampling_default_params();
+
+        lp.seed              = params.seed;
+        lp.n_prev            = params.n_prev;
+        lp.n_probs           = params.n_probs;
+        lp.min_keep          = params.min_keep;
+        lp.top_k             = params.top_k;
+        lp.top_p             = params.top_p;
+        lp.min_p             = params.min_p;
+        lp.tfs_z             = params.tfs_z;
+        lp.typical_p         = params.typical_p;
+        lp.temp              = params.temp;
+        lp.dynatemp_range    = params.dynatemp_range;
+        lp.dynatemp_exponent = params.dynatemp_exponent;
+        lp.penalty_last_n    = params.penalty_last_n;
+        lp.penalty_repeat    = params.penalty_repeat;
+        lp.penalty_freq      = params.penalty_freq;
+        lp.penalty_present   = params.penalty_present;
+        lp.mirostat          = params.mirostat;
+        lp.mirostat_tau      = params.mirostat_tau;
+        lp.mirostat_eta      = params.mirostat_eta;
+        lp.penalize_nl       = params.penalize_nl;
+        lp.ignore_eos        = params.ignore_eos;
+
+        lp.grammar      = params.grammar.c_str();
+        lp.grammar_root = "root";
+
+        lp.cfg_prompt = params.cfg_negative_prompt.c_str();
+        lp.cfg_scale  = params.cfg_scale;
+
+        lp.n_logit_bias = params.logit_bias.size();
+        lp.logit_bias   = params.logit_bias.data();
+
+        result->smpl = llama_sampling_init(model, lp);
+    }
 
     result->prev.resize(params.n_prev);
 
diff --git a/common/sampling.h b/common/sampling.h
index 79b8f5020..74020d52b 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -17,6 +17,7 @@ enum class llama_sampler_type : char {
 
 // sampling parameters
 typedef struct gpt_sampling_params {
+    uint32_t    seed              = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
     int32_t     n_prev            = 64;    // number of previous tokens to remember
     int32_t     n_probs           = 0;     // if greater than 0, output the probabilities of top n_probs tokens.
     int32_t     min_keep          = 0;     // 0 = disabled, otherwise samplers should return at least min_keep tokens
@@ -37,7 +38,6 @@ typedef struct gpt_sampling_params {
     float       mirostat_eta      = 0.10f; // learning rate
     bool        penalize_nl       = false; // consider newlines as a repeatable token
     bool        ignore_eos        = false;
-    uint32_t    seed              = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
 
     std::vector<llama_sampler_type> samplers_sequence = {
         llama_sampler_type::TOP_K,
diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift
index 72b4b43f1..c87ca03ff 100644
--- a/examples/batched.swift/Sources/main.swift
+++ b/examples/batched.swift/Sources/main.swift
@@ -50,7 +50,7 @@ defer {
     llama_free(context)
 }
 
-let smpl = llama_sampling_init(model, nil, nil)
+let smpl = llama_sampling_init(model, llama_sampling_default_params())
 guard smpl != nil else {
     print("Failed to initialize sampling")
     exit(1)
diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp
index 02d98a3a3..7e76bc9a8 100644
--- a/examples/batched/batched.cpp
+++ b/examples/batched/batched.cpp
@@ -64,7 +64,7 @@ int main(int argc, char ** argv) {
     ctx_params.n_batch = std::max(n_predict, n_parallel);
 
     llama_context * ctx = llama_new_context_with_model(model, ctx_params);
-    llama_sampling * smpl = llama_sampling_init(model, nullptr, nullptr);
+    llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params());
 
     if (ctx == NULL) {
         fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp
index e89819d2b..8a201fbb1 100644
--- a/examples/gritlm/gritlm.cpp
+++ b/examples/gritlm/gritlm.cpp
@@ -172,7 +172,7 @@ int main(int argc, char * argv[]) {
 
     // create generation context
     llama_context * ctx = llama_new_context_with_model(model, cparams);
-    llama_sampling * smpl = llama_sampling_init(model, nullptr, nullptr);
+    llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params());
 
     // ### Embedding/Representation ###
     // samples taken from: https://github.com/ContextualAI/gritlm#basic
diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
index 5b63f5ac4..17b9aa1cf 100644
--- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
+++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
@@ -43,7 +43,7 @@ actor LlamaContext {
         self.tokens_list = []
         self.batch = llama_batch_init(512, 0, 1)
         self.temporary_invalid_cchars = []
-        self.sampling = llama_sampling_init(context, nil, nil);
+        self.sampling = llama_sampling_init(context, llama_sampling_default_params())
     }
 
     deinit {
diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp
index e9e4de260..03c0ea997 100644
--- a/examples/passkey/passkey.cpp
+++ b/examples/passkey/passkey.cpp
@@ -83,7 +83,7 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    llama_sampling * smpl = llama_sampling_init(model, nullptr, nullptr);
+    llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params());
 
     // tokenize the prompt
     std::vector<llama_token> tokens_list;
diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp
index a62c0f294..5ccec0eee 100644
--- a/examples/save-load-state/save-load-state.cpp
+++ b/examples/save-load-state/save-load-state.cpp
@@ -37,7 +37,7 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    llama_sampling * smpl = llama_sampling_init(model, nullptr, nullptr);
+    llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params());
 
     // tokenize prompt
     auto tokens = llama_tokenize(ctx, params.prompt, true);
@@ -97,7 +97,7 @@ int main(int argc, char ** argv) {
 
     // make new context
     auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
-    llama_sampling * smpl2 = llama_sampling_init(model, nullptr, nullptr);
+    llama_sampling * smpl2 = llama_sampling_init(model, llama_sampling_default_params());
 
     printf("\nsecond run: %s", params.prompt.c_str());
 
@@ -162,7 +162,7 @@ int main(int argc, char ** argv) {
 
     // make new context
     auto * ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
-    llama_sampling * smpl3 = llama_sampling_init(model, nullptr, nullptr);
+    llama_sampling * smpl3 = llama_sampling_init(model, llama_sampling_default_params());
 
     printf("\nsingle seq run: %s", params.prompt.c_str());
 
diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp
index cf5e74002..8ec078579 100644
--- a/examples/simple/simple.cpp
+++ b/examples/simple/simple.cpp
@@ -55,7 +55,7 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    llama_sampling * smpl = llama_sampling_init(model, nullptr, nullptr);
+    llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params());
 
     // tokenize the prompt
 
diff --git a/include/llama.h b/include/llama.h
index a9db3eebf..075a30b0e 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -328,7 +328,8 @@ extern "C" {
         enum ggml_type type_k; // data type for K cache [EXPERIMENTAL]
         enum ggml_type type_v; // data type for V cache [EXPERIMENTAL]
 
-        // Keep the booleans together to avoid misalignment during copy-by-value.
+        // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
+        // TODO: move at the end of the struct
         bool logits_all;  // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
         bool embeddings;  // if true, extract embeddings (together with logits)
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
@@ -382,14 +383,20 @@ extern "C" {
         int32_t mirostat;     // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
         float   mirostat_tau; // target entropy
         float   mirostat_eta; // learning rate
-        bool    penalize_nl;  // consider newlines as a repeatable token
-        bool    ignore_eos;   // ignore the end-of-sequence token
 
+        // https://github.com/ggerganov/llama.cpp/pull/1773
         const char * grammar;
         const char * grammar_root;
 
+        const char * cfg_prompt; // string to help guidance in negative direction
+        float        cfg_scale;  // how strong is guidance
+
         int32_t                  n_logit_bias;
         const llama_logit_bias * logit_bias;
+
+        // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
+        bool penalize_nl; // consider newlines as a repeatable token
+        bool ignore_eos;  // ignore the end-of-sequence token
     } llama_sampling_params;
 
     // performance timing information
@@ -1006,9 +1013,8 @@ extern "C" {
     // Sampling functions
     //
 
-    // TODO: args become llama_sampling_params
     // TODO: llama_model should become llama_vocab
-    LLAMA_API struct llama_sampling * llama_sampling_init(const struct llama_model * model, const char * grammar_str, const char * grammar_root);
+    LLAMA_API struct llama_sampling * llama_sampling_init(const struct llama_model * model, struct llama_sampling_params params);
 
     LLAMA_API void llama_sampling_free(struct llama_sampling * smpl);
 
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 240631fac..1a248a8ee 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -24,13 +24,7 @@ static void llama_log_softmax(float * array, size_t size) {
     }
 }
 
-llama_sampling::llama_sampling(uint32_t n_vocab) : n_vocab(n_vocab) {
-}
-
-llama_sampling::llama_sampling(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) : n_vocab(vocab.n_vocab) {
-    if (grammar_str != nullptr && grammar_str[0] != '\0') {
-        grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root);
-    }
+llama_sampling::llama_sampling(const struct llama_vocab & vocab) : vocab(vocab) {
 }
 
 llama_sampling::~llama_sampling() {
@@ -39,8 +33,16 @@ llama_sampling::~llama_sampling() {
     }
 }
 
-struct llama_sampling * llama_sampling_init_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) {
-    return new llama_sampling(vocab, grammar_str, grammar_root);
+struct llama_sampling * llama_sampling_init_impl(const struct llama_vocab & vocab, struct llama_sampling_params params) {
+    auto * result = new llama_sampling(vocab);
+
+    // TODO: store params
+
+    if (params.grammar != nullptr && params.grammar[0] != '\0') {
+        result->grammar = llama_grammar_init_impl(result->vocab, params.grammar, params.grammar_root);
+    }
+
+    return result;
 }
 
 void llama_sampling_free_impl(struct llama_sampling * sampling) {
@@ -48,7 +50,7 @@ void llama_sampling_free_impl(struct llama_sampling * sampling) {
 }
 
 struct llama_sampling * llama_sampling_cp_impl(const struct llama_sampling & smpl) {
-    auto * result = new llama_sampling(smpl.n_vocab);
+    auto * result = new llama_sampling(smpl.vocab);
 
     if (smpl.grammar) {
         result->grammar = llama_grammar_copy_impl(*smpl.grammar);
@@ -493,7 +495,7 @@ void llama_sampling_apply_guidance_impl(
         float * logits,
         float * logits_guidance,
         float   scale) {
-    const auto n_vocab = smpl.n_vocab;
+    const auto n_vocab = smpl.vocab.n_vocab;
 
     llama_log_softmax(logits,          n_vocab);
     llama_log_softmax(logits_guidance, n_vocab);
@@ -507,7 +509,7 @@ void llama_sampling_apply_guidance_impl(
 }
 
 llama_token llama_sampling_sample_mirostat_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
-    const int32_t n_vocab = float(smpl.n_vocab);
+    const float n_vocab = float(smpl.vocab.n_vocab);
 
     llama_sampling_softmax_impl(smpl, candidates);
 
diff --git a/src/llama-sampling.h b/src/llama-sampling.h
index 8ee593c72..a5af5e257 100644
--- a/src/llama-sampling.h
+++ b/src/llama-sampling.h
@@ -1,17 +1,15 @@
 #pragma once
 
-#include "llama-impl.h"
 #include "llama-grammar.h"
 
 struct llama_vocab;
 struct llama_grammar;
 
 struct llama_sampling {
-    llama_sampling(uint32_t n_vocab);
-    llama_sampling(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root);
+    llama_sampling(const struct llama_vocab & vocab);
     ~llama_sampling();
 
-    const uint32_t n_vocab;
+    const llama_vocab & vocab;
 
     std::mt19937 rng;
 
@@ -26,7 +24,7 @@ struct llama_sampling {
 // internal API
 //
 
-struct llama_sampling * llama_sampling_init_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root);
+struct llama_sampling * llama_sampling_init_impl(const struct llama_vocab & vocab, struct llama_sampling_params params);
 
 void llama_sampling_free_impl(struct llama_sampling * sampling);
 
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index 749f85718..69b6a9d3b 100644
--- a/src/llama-vocab.cpp
+++ b/src/llama-vocab.cpp
@@ -93,6 +93,18 @@ int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string
     return it->second;
 }
 
+llama_vocab llama_vocab::for_tests(uint32_t n_vocab) {
+    llama_vocab vocab;
+    vocab.n_vocab = n_vocab;
+    vocab.token_to_id.reserve(n_vocab);
+    vocab.id_to_token.reserve(n_vocab);
+    for (uint32_t i = 0; i < n_vocab; i++) {
+        vocab.token_to_id[format("token_%u", i)] = i;
+        vocab.id_to_token.push_back({ format("token_%u", i), 0.0f, LLAMA_TOKEN_ATTR_NORMAL });
+    }
+    return vocab;
+}
+
 static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) {
     return vocab.type;
 }
diff --git a/src/llama-vocab.h b/src/llama-vocab.h
index f48bbe1bf..20da3bb82 100644
--- a/src/llama-vocab.h
+++ b/src/llama-vocab.h
@@ -62,6 +62,8 @@ struct llama_vocab {
     std::vector<char> precompiled_charsmap;
 
     int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
+
+    static llama_vocab for_tests(uint32_t n_vocab);
 };
 
 //
diff --git a/src/llama.cpp b/src/llama.cpp
index 773008c71..25f6ec01e 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -16511,12 +16511,14 @@ struct llama_sampling_params llama_sampling_default_params() {
         /*.mirostat          =*/ 0,
         /*.mirostat_tau      =*/ 5.00f,
         /*.mirostat_eta      =*/ 0.10f,
-        /*.penalize_nl       =*/ false,
-        /*.ignore_eos        =*/ false,
         /*.grammar           =*/ nullptr,
         /*.grammar_root      =*/ nullptr,
+        /*.cfg_prompt        =*/ nullptr,
+        /*.cfg_scale         =*/ 1.00f,
         /*.n_logit_bias      =*/ 0,
         /*.logit_bias        =*/ nullptr,
+        /*.penalize_nl       =*/ false,
+        /*.ignore_eos        =*/ false,
     };
 
     return result;
@@ -19091,8 +19093,8 @@ int32_t llama_chat_apply_template(
 // sampling
 //
 
-struct llama_sampling * llama_sampling_init(const struct llama_model * model, const char * grammar_str, const char * grammar_root) {
-    return llama_sampling_init_impl(model->vocab, grammar_str, grammar_root);
+struct llama_sampling * llama_sampling_init(const struct llama_model * model, struct llama_sampling_params params) {
+    return llama_sampling_init_impl(model->vocab, params);
 }
 
 void llama_sampling_free(struct llama_sampling * smpl) {
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 2f4117a62..1c58bc01d 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -2470,7 +2470,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     }
 
     GGML_ABORT("fatal error");
-    return false;
 }
 
 static void usage(char ** argv) {
diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp
index 45dcb21ad..371b7e511 100644
--- a/tests/test-sampling.cpp
+++ b/tests/test-sampling.cpp
@@ -1,5 +1,6 @@
 #include "ggml.h"
 #include "llama.h"
+#include "llama-vocab.h"
 #include "llama-sampling.h"
 
 #ifdef NDEBUG
@@ -21,7 +22,7 @@ static void dump(const llama_token_data_array * candidates) {
 static void test_top_k(const std::vector<float> & probs, const std::vector<float> & expected_probs, int k) {
     const size_t n_vocab = probs.size();
 
-    llama_sampling smpl(n_vocab);
+    llama_sampling smpl(llama_vocab::for_tests(n_vocab));
 
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
@@ -44,7 +45,7 @@ static void test_top_k(const std::vector<float> & probs, const std::vector<floa
 static void test_top_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
     const size_t n_vocab = probs.size();
 
-    llama_sampling smpl(n_vocab);
+    llama_sampling smpl(llama_vocab::for_tests(n_vocab));
 
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
@@ -67,7 +68,7 @@ static void test_top_p(const std::vector<float> & probs, const std::vector<floa
 static void test_tfs(const std::vector<float> & probs, const std::vector<float> & expected_probs, float z) {
     const size_t n_vocab = probs.size();
 
-    llama_sampling smpl(n_vocab);
+    llama_sampling smpl(llama_vocab::for_tests(n_vocab));
 
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
@@ -89,7 +90,7 @@ static void test_tfs(const std::vector<float> & probs, const std::vector<float>
 static void test_min_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
     const size_t n_vocab = probs.size();
 
-    llama_sampling smpl(n_vocab);
+    llama_sampling smpl(llama_vocab::for_tests(n_vocab));
 
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
@@ -112,7 +113,7 @@ static void test_min_p(const std::vector<float> & probs, const std::vector<floa
 static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
     const size_t n_vocab = probs.size();
 
-    llama_sampling smpl(n_vocab);
+    llama_sampling smpl(llama_vocab::for_tests(n_vocab));
 
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
@@ -139,7 +140,7 @@ static void test_repetition_penalties(
     GGML_ASSERT(probs.size() == expected_probs.size());
 
     const size_t n_vocab = probs.size();
-    llama_sampling smpl(n_vocab);
+    llama_sampling smpl(llama_vocab::for_tests(n_vocab));
 
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
@@ -163,7 +164,7 @@ static void test_repetition_penalties(
 
 static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p
 ) {
-    llama_sampling smpl(n_vocab);
+    llama_sampling smpl(llama_vocab::for_tests(n_vocab));
 
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
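
Usage sketch (not part of the patch): a minimal caller of the reworked entry point, assuming the API exactly as declared above — fill a llama_sampling_params obtained from llama_sampling_default_params() and pass it by value to llama_sampling_init(). The build line, the model path argument, and the chosen parameter values are placeholders, and the snippet targets this in-progress branch rather than the previous (grammar_str, grammar_root) signature.

// sketch of the new llama_sampling_init() calling convention (this branch only)
// build (placeholder): g++ -std=c++11 sampling_params_demo.cpp -lllama
#include "llama.h"

#include <cstdio>

int main(int argc, char ** argv) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s <model.gguf>\n", argv[0]);
        return 1;
    }

    llama_backend_init();

    llama_model * model = llama_load_model_from_file(argv[1], llama_model_default_params());
    if (model == NULL) {
        fprintf(stderr, "%s: error: failed to load the model\n", __func__);
        return 1;
    }

    // start from the library defaults, then override only what the caller cares about
    llama_sampling_params sparams = llama_sampling_default_params();

    sparams.seed         = 1234;
    sparams.top_k        = 40;
    sparams.top_p        = 0.95f;
    sparams.temp         = 0.80f;
    sparams.grammar      = NULL;   // optionally a GBNF grammar string
    sparams.grammar_root = "root"; // start symbol used when a grammar is set

    // the params struct replaces the old (grammar_str, grammar_root) arguments
    llama_sampling * smpl = llama_sampling_init(model, sparams);
    if (smpl == NULL) {
        fprintf(stderr, "%s: error: failed to initialize sampling\n", __func__);
        return 1;
    }

    // ... create a context, decode and sample, as in the examples touched above ...

    llama_sampling_free(smpl);
    llama_free_model(model);
    llama_backend_free();

    return 0;
}

For the full mapping from the common-layer gpt_sampling_params to llama_sampling_params (CFG prompt, logit biases, penalties, etc.), see the new block in common/sampling.cpp above.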