diff --git a/CMakeLists.txt b/CMakeLists.txt
index 9d7c9d1ed..5fdbeddfc 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -76,7 +76,7 @@ option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
 #
 # Compile flags
 #
-set(CMAKE_CXX_STANDARD 20)
+set(CMAKE_CXX_STANDARD 11)
 set(CMAKE_CXX_STANDARD_REQUIRED true)
 set(CMAKE_C_STANDARD 11)
 set(CMAKE_C_STANDARD_REQUIRED true)
diff --git a/examples/common.cpp b/examples/common.cpp
index a4938b484..6c712c713 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -158,13 +158,13 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.mirostat = std::stoi(argv[i]);
-        } else if (arg == "--mirostat_eta") {
+        } else if (arg == "--mirostat_lr") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params.mirostat_eta = std::stof(argv[i]);
-        } else if (arg == "--mirostat_tau") {
+        } else if (arg == "--mirostat_ent") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
@@ -242,7 +242,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             char sign;
             std::string value_str;
             try {
-                if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-' || sign == '=' || sign == ':')) {
+                if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) {
                     params.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
                 } else {
                     throw std::exception();
@@ -309,18 +309,21 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "  --top_p N             top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
     fprintf(stderr, "  --tfs N               tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z);
     fprintf(stderr, "  --typical N           locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)params.typical_p);
-    fprintf(stderr, "  --repeat_last_n N     last n tokens to consider for penalize (default: %d, 0 = disabled)\n", params.repeat_last_n);
+    fprintf(stderr, "  --repeat_last_n N     last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", params.repeat_last_n);
     fprintf(stderr, "  --repeat_penalty N    penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)params.repeat_penalty);
     fprintf(stderr, "  --presence_penalty N  repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)params.presence_penalty);
     fprintf(stderr, "  --frequency_penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)params.frequency_penalty);
-    fprintf(stderr, "  --mirostat N          use mirostat sampling (default: %d, 0 = disabled, 1 = mirostat, 2 = mirostat 2.0)\n", params.mirostat);
-    fprintf(stderr, "  --mirostat_eta N      mirostat learning rate (default: %.1f)\n", (double)params.mirostat_eta);
-    fprintf(stderr, "  --mirostat_tau N      mirostat target entropy (default: %.1f)\n", (double)params.mirostat_tau);
-    fprintf(stderr, "  -l TOKEN+BIAS, --logit-bias TOKEN+BIAS");
+    fprintf(stderr, "  --mirostat N          use Mirostat sampling.\n");
+    fprintf(stderr, "                        Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
+    fprintf(stderr, "                        (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", params.mirostat);
+    fprintf(stderr, "  --mirostat_lr N       Mirostat learning rate, parameter eta (default: %.1f)\n", (double)params.mirostat_eta);
+    fprintf(stderr, "  --mirostat_ent N      Mirostat target entropy, parameter tau (default: %.1f)\n", (double)params.mirostat_tau);
+    fprintf(stderr, "  -l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS\n");
     fprintf(stderr, "                        modifies the likelihood of token appearing in the completion,\n");
-    fprintf(stderr, "                        i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello'\n");
+    fprintf(stderr, "                        i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n");
+    fprintf(stderr, "                        or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
     fprintf(stderr, "  -c N, --ctx_size N    size of the prompt context (default: %d)\n", params.n_ctx);
-    fprintf(stderr, "  --ignore-eos          ignore end of stream token and continue generating (implies --logit-bias 2+-inf)\n");
+    fprintf(stderr, "  --ignore-eos          ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
     fprintf(stderr, "  --no-penalize-nl      do not penalize newline token\n");
     fprintf(stderr, "  --memory_f32          use f32 instead of f16 for memory key+value\n");
     fprintf(stderr, "  --temp N              temperature (default: %.1f)\n", (double)params.temp);
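The logit-bias parser above is compact but easy to misread: `ss >> key` stops at the sign character, `ss >> sign` consumes it, and the rest of the argument becomes the bias magnitude. A minimal standalone sketch of the same parsing logic follows; the `parse_logit_bias` helper and its driver are illustrative, not part of the patch:

```cpp
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_map>

// Illustrative helper mirroring the gpt_params_parse branch above.
// "15043+1" -> bias +1.0 for token 15043; "15043-1" -> bias -1.0.
// After this change, '=' and ':' separators are rejected.
static bool parse_logit_bias(const std::string & arg, std::unordered_map<int, float> & out) {
    std::stringstream ss(arg);
    int key;
    char sign;
    std::string value_str;
    try {
        if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) {
            out[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
            return true;
        }
    } catch (const std::exception &) {
        // fall through: non-numeric bias value
    }
    return false;
}

int main() {
    std::unordered_map<int, float> logit_bias;
    std::cout << parse_logit_bias("15043+1", logit_bias) << '\n'; // 1
    std::cout << parse_logit_bias("15043-1", logit_bias) << '\n'; // 1
    std::cout << parse_logit_bias("15043=1", logit_bias) << '\n'; // 0 ('=' no longer accepted)
    std::cout << logit_bias[15043] << '\n';                       // -1 (last successful parse wins)
}
```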
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index da9740057..674920b8a 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -276,7 +276,7 @@ int main(int argc, char ** argv) {
             fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
         }
     }
-    fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_eta = %f, mirostat_tau = %f\n",
+    fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
             params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau);
     fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
     fprintf(stderr, "\n\n");
@@ -420,8 +420,8 @@ int main(int argc, char ** argv) {
             std::vector<llama_token_data> candidates;
             candidates.reserve(n_vocab);

-            for (size_t i = 0; i < (size_t) n_vocab; i++) {
-                candidates.emplace_back(i, logits[i], 0.0f);
+            for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+                candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
             }

             llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
@@ -445,11 +445,12 @@ int main(int argc, char ** argv) {
             } else {
                 if (mirostat == 1) {
                     static float mirostat_mu = 2.0f * mirostat_tau;
-                    static int mirostat_k = 40;
                     const int mirostat_m = 100;
-                    id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, float(n_vocab), &mirostat_k, &mirostat_mu);
+                    llama_sample_temperature(ctx, &candidates_p, temp);
+                    id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
                 } else if (mirostat == 2) {
                     static float mirostat_mu = 2.0f * mirostat_tau;
+                    llama_sample_temperature(ctx, &candidates_p, temp);
                     id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
                 } else {
                     // Temperature sampling
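For orientation, here is a condensed view of the per-token dispatch that main.cpp ends up with after this change, with the repetition/frequency penalty steps omitted for brevity; `ctx`, `candidates_p`, and the sampling parameters are assumed to be in scope as in the hunks above, so read it as a sketch rather than a verbatim excerpt. The point of the reordering is that both Mirostat variants now see temperature-scaled logits, while the caller no longer tracks `k` or passes the vocabulary size:

```cpp
llama_token id = 0;
if (temp <= 0) {
    // Greedy decoding: always pick the most probable token.
    id = llama_sample_token_greedy(ctx, &candidates_p);
} else if (mirostat == 1) {
    static float mirostat_mu = 2.0f * mirostat_tau; // persists across sampled tokens
    const int   mirostat_m  = 100;                  // candidates used to estimate s_hat
    llama_sample_temperature(ctx, &candidates_p, temp);
    id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
} else if (mirostat == 2) {
    static float mirostat_mu = 2.0f * mirostat_tau;
    llama_sample_temperature(ctx, &candidates_p, temp);
    id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
} else {
    // Classic pipeline: truncation samplers first, then temperature, then sampling.
    llama_sample_top_k      (ctx, &candidates_p, top_k);
    llama_sample_tail_free  (ctx, &candidates_p, tfs_z);
    llama_sample_typical    (ctx, &candidates_p, typical_p);
    llama_sample_top_p      (ctx, &candidates_p, top_p);
    llama_sample_temperature(ctx, &candidates_p, temp);
    id = llama_sample_token(ctx, &candidates_p);
}
```

This is also why the new usage text says the Top K, Nucleus, Tail Free and Locally Typical samplers are ignored when Mirostat is enabled: the Mirostat branches never run them.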
diff --git a/llama.cpp b/llama.cpp
index 2ec6d894a..5645a22e2 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1710,12 +1710,6 @@ void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens, size_t last_tokens_size, float penalty) {
         } else {
             candidates->data[i].logit /= penalty;
         }
-
-        // But it does not penalize tokens that logits are near zero, which is a problem.
-        // Another solution is to convert the logits to probabilities, apply the penalty, and then convert back to logits.
-        // float probability = std::exp(candidates[i].logit);
-        // probability /= penalty;
-        // candidates[i].logit = std::log(probability);
     }

     candidates->sorted = false;
@@ -1757,9 +1751,9 @@
 }

-llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float N, int * k, float * mu) {
+llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu) {
     assert(ctx);
-
+    auto N = float(llama_n_vocab(ctx));
     int64_t t_start_sample_us;
     t_start_sample_us = ggml_time_us();
@@ -1779,12 +1773,10 @@ llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu) {
     // Compute k from the estimated s_hat and target surprise value
     float epsilon_hat = s_hat - 1;
-    float new_k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(N, -epsilon_hat)), 1 / s_hat);
-    *k = int(std::min(new_k, float(candidates->size)));
+    float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(N, -epsilon_hat)), 1 / s_hat);

     // Sample the next word X using top-k sampling
-    // printf("llama_sample_mirostat *k = %d\n", *k);
-    llama_sample_top_k(nullptr, candidates, *k);
+    llama_sample_top_k(nullptr, candidates, int(k));
     if (ctx) {
         ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
     }
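The caller-visible `k` and `N` parameters disappear because both are derivable: `N` comes from `llama_n_vocab(ctx)` (hence the function now effectively requires a non-null context), and `k` is recomputed on every call from the Mirostat paper's closed form. In the notation of the code above, with `s_hat` the estimated Zipf exponent, `*mu` the current surprise bound, and `N` the vocabulary size:

```latex
\hat{\epsilon} = \hat{s} - 1,
\qquad
k = \left( \frac{\hat{\epsilon} \, 2^{\mu}}{1 - N^{-\hat{\epsilon}}} \right)^{1/\hat{s}}
```

The only state that has to survive between calls is `mu` itself, initialized to `2 * tau`.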
diff --git a/llama.h b/llama.h
index 5f61971ce..fccce707e 100644
--- a/llama.h
+++ b/llama.h
@@ -189,37 +189,47 @@ extern "C" {

     // Sampling functions

-    /// @brief Repetition penalty
-    /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/pdf/1909.05858.pdf with negative logit fix
+    /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
     LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens, size_t last_tokens_size, float penalty);

-    /// @brief Frequency and presence repetition penalties
-    /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details
+
+    /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
     LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence);

+    /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
     LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
+
+    /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
     LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep = 1);
+
+    /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
     LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1);

-    /// @brief Tail Free Sampling https://www.trentonbricken.com/Tail-Free-Sampling/
+    /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
     LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep = 1);

-    /// @brief Locally Typical Sampling https://arxiv.org/pdf/2202.00666.pdf
+    /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
     LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1);
     LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);

-    /// @brief Mirostat implementation.
     /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
-    /// @param ctx The llama context.
     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
     /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
     /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
     /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
-    /// @param N The size of the vocabulary. This is used in the calculation of the `k` value.
-    /// @param k A reference to the integer variable used to store the calculated top-k value. The top-k value determines how many of the most probable tokens are considered for sampling.
-    /// @param mu A reference to the floating-point variable that represents the maximum cross-entropy value. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
-    LLAMA_API llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float N, int * k, float * mu);
+    /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
+    LLAMA_API llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu);
+
+    /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
+    /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
+    /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
+    /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
+    /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
     LLAMA_API llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu);
+
+    /// @details Selects the token with the highest probability.
     LLAMA_API llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates);
+
+    /// @details Randomly selects a token from the candidates based on their probabilities.
     LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);

     // Performance information
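With the narrower signature, the only state a caller maintains for Mirostat 1.0 is `mu`. A minimal usage sketch follows, assuming a loaded `llama_context * ctx` and a `candidates_p` array populated from the current logits as in examples/main/main.cpp; the τ = 5.0, η = 0.1 values are the ones the removed test below also used, not mandated by the API:

```cpp
// Per-sequence Mirostat 1.0 state and parameters.
float tau = 5.0f;       // target surprise (entropy); --mirostat_ent
float eta = 0.1f;       // learning rate; --mirostat_lr
int   m   = 100;        // candidates used to estimate s_hat
float mu  = 2.0f * tau; // keep alive across tokens within one sequence

// Scale logits by temperature first (as main.cpp now does), then sample.
// The call derives N from llama_n_vocab(ctx), truncates the candidates with
// the computed top-k, samples a token, and updates mu in place.
llama_sample_temperature(ctx, &candidates_p, temp);
llama_token id = llama_sample_token_mirostat(ctx, &candidates_p, tau, eta, m, &mu);
```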
diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp
index 0a23c80c5..3f3d5d174 100644
--- a/tests/test-sampling.cpp
+++ b/tests/test-sampling.cpp
@@ -1,4 +1,3 @@
-#include "ggml.h"
 #include "llama.h"
 #include
 #include
@@ -23,12 +22,12 @@ void test_top_k(const std::vector<float> & probs,
     size_t n_vocab = probs.size();
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
-    for (int i = 0; i < n_vocab; i++) {
-        float logit = log(probs[i]);
-        candidates.emplace_back(llama_token_data{i, logit, 0.0f});
+    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
+        float logit = log(probs[token_id]);
+        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
     }

-    llama_token_data_array candidates_p = { candidates.data(), candidates.size() };
+    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
     llama_sample_softmax(nullptr, &candidates_p);
     // DUMP(&candidates_p);
     llama_sample_top_k(nullptr, &candidates_p, k);
@@ -48,12 +47,12 @@ void test_top_p(const std::vector<float> & probs,
     size_t n_vocab = probs.size();
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
-    for (int i = 0; i < n_vocab; i++) {
-        float logit = log(probs[i]);
-        candidates.emplace_back(llama_token_data{i, logit, 0.0f});
+    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
+        float logit = log(probs[token_id]);
+        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
     }

-    llama_token_data_array candidates_p = { candidates.data(), candidates.size() };
+    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
     // DUMP(&candidates_p);
     llama_sample_top_p(nullptr, &candidates_p, p);
     // DUMP(&candidates_p);
@@ -71,12 +70,12 @@ void test_tfs(const std::vector<float> & probs,
     size_t n_vocab = probs.size();
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
-    for (int i = 0; i < n_vocab; i++) {
-        float logit = log(probs[i]);
-        candidates.emplace_back(llama_token_data{i, logit, 0.0f});
+    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
+        float logit = log(probs[token_id]);
+        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
     }

-    llama_token_data_array candidates_p = { candidates.data(), candidates.size() };
+    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
     // DUMP(&candidates_p);
     llama_sample_tail_free(nullptr, &candidates_p, z);
     // DUMP(&candidates_p);
@@ -94,12 +93,12 @@ void test_typical(const std::vector<float> & probs,
     size_t n_vocab = probs.size();
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
-    for (int i = 0; i < n_vocab; i++) {
-        float logit = log(probs[i]);
-        candidates.emplace_back(llama_token_data{i, logit, 0.0f});
+    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
+        float logit = log(probs[token_id]);
+        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
     }

-    llama_token_data_array candidates_p = { candidates.data(), candidates.size() };
+    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
     // DUMP(&candidates_p);
     llama_sample_typical(nullptr, &candidates_p, p);
     // DUMP(&candidates_p);
@@ -121,12 +120,12 @@ void test_repetition_penalty(
     size_t n_vocab = probs.size();
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
-    for (int i = 0; i < n_vocab; i++) {
-        float logit = log(probs[i]);
-        candidates.emplace_back(llama_token_data{i, logit, 0.0f});
+    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
+        float logit = log(probs[token_id]);
+        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
     }

-    llama_token_data_array candidates_p = { candidates.data(), candidates.size() };
+    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
     llama_sample_softmax(nullptr, &candidates_p);
     DUMP(&candidates_p);
     llama_sample_repetition_penalty(nullptr, &candidates_p, (llama_token *)last_tokens.data(), last_tokens.size(), penalty);
@@ -150,12 +149,12 @@ void test_frequency_presence_penalty(
     size_t n_vocab = probs.size();
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
-    for (int i = 0; i < n_vocab; i++) {
-        float logit = log(probs[i]);
-        candidates.emplace_back(llama_token_data{i, logit, 0.0f});
+    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
+        float logit = log(probs[token_id]);
+        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
     }

-    llama_token_data_array candidates_p = { candidates.data(), candidates.size() };
+    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
     llama_sample_softmax(nullptr, &candidates_p);
     // DUMP(&candidates_p);
     llama_sample_frequency_and_presence_penalties(nullptr, &candidates_p, (llama_token *)last_tokens.data(), last_tokens.size(), alpha_frequency, alpha_presence);
@@ -168,38 +167,6 @@ void test_frequency_presence_penalty(
     }
 }

-
-void test_mirostat() {
-    std::vector<float> probs = {0.1, 0.2, 0.3, 0.4};
-    std::vector<float> expected_probs = {0.1, 0.2, 0.3, 0.4};
-
-    size_t n_vocab = probs.size();
-    std::vector<llama_token_data> candidates;
-    candidates.reserve(n_vocab);
-    for (int i = 0; i < n_vocab; i++) {
-        float logit = log(probs[i]);
-        candidates.emplace_back(llama_token_data{i, logit, 0.0f});
-    }
-
-    llama_token_data_array candidates_p = { candidates.data(), candidates.size() };
-    DUMP(&candidates_p);
-
-    float tau = 5.0f;
-    float mu = 2.0f * tau;
-    int k = 0;
-    float eta = 0.1f;
-    int m = 100;
-    // float N = 32000;
-    float N = 4;
-    // llama_sample_mirostat(ctx, &candidates_p, tau, eta, m, N, &k, &mu);
-    DUMP(&candidates_p);
-
-    // assert(candidates_p.size == expected_probs.size());
-    // for (size_t i = 0; i < candidates_p.size; i++) {
-    //     assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-6);
-    // }
-}
-
 int main(void) {
     test_top_k({0.1, 0.2, 0.3, 0.4}, {0.4}, 1);
     test_top_k({0.1, 0.2, 0.3, 0.4}, {0.4, 0.3, 0.2}, 3);
@@ -223,7 +190,5 @@ int main(void) {
     test_frequency_presence_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2}, {0.499966, 0.499966, 0.000023, 0.000023, 0.000023}, 5.0, 5.0);
     test_frequency_presence_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2, 0, 0}, {0.499977, 0.499977, 0.000023, 0.000023, 0.000000}, 5.0, 5.0);

-    // test_mirostat();
-
     printf("OK\n");
 }
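test_mirostat() is deleted outright rather than ported: the other sampler tests drive the API with a null context, but the reworked `llama_sample_token_mirostat` asserts a non-null `ctx` and reads `N = llama_n_vocab(ctx)`, so the Mirostat path can no longer be exercised with `nullptr` the way its siblings are. If the test is revived it will need a loaded context; one hypothetical shape, in the style of the surrounding tests (the function name and structure are illustrative only, not part of the patch):

```cpp
// Hypothetical revival of the Mirostat test against the new signature.
// Requires a loaded llama_context, because N is now taken from the model.
static void test_mirostat_v1(llama_context * ctx, const std::vector<float> & probs) {
    std::vector<llama_token_data> candidates;
    candidates.reserve(probs.size());
    for (llama_token token_id = 0; token_id < (llama_token) probs.size(); token_id++) {
        candidates.emplace_back(llama_token_data{token_id, (float) log(probs[token_id]), 0.0f});
    }
    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

    float tau = 5.0f;
    float eta = 0.1f;
    float mu  = 2.0f * tau;
    llama_token id = llama_sample_token_mirostat(ctx, &candidates_p, tau, eta, /*m =*/ 100, &mu);
    assert(id >= 0 && (size_t) id < probs.size());
}
```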