From f01c67fe55d4c48b7903394416303aafc20e3f3b Mon Sep 17 00:00:00 2001 From: Ivan Stepanov Date: Sat, 22 Apr 2023 21:23:10 +0300 Subject: [PATCH] mirostat --- examples/common.cpp | 37 +++++-- examples/common.h | 17 ++- examples/main/main.cpp | 53 +++++----- llama.cpp | 158 ++++++++++++++++++++------- llama.h | 20 ++-- tests/CMakeLists.txt | 2 - tests/test-sampling.cpp | 229 ++++++++++++++++++++++++++++++++++++++++ 7 files changed, 425 insertions(+), 91 deletions(-) create mode 100644 tests/test-sampling.cpp diff --git a/examples/common.cpp b/examples/common.cpp index a8f57360a..7e62be356 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -150,6 +150,24 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.alpha_presence = std::stof(argv[i]); + } else if (arg == "--mirostat") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.mirostat = std::stoi(argv[i]); + } else if (arg == "--mirostat_eta") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.mirostat_eta = std::stof(argv[i]); + } else if (arg == "--mirostat_tau") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.mirostat_tau = std::stof(argv[i]); } else if (arg == "-b" || arg == "--batch_size") { if (++i >= argc) { invalid_param = true; @@ -264,14 +282,17 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " -f FNAME, --file FNAME\n"); fprintf(stderr, " prompt file to start generation.\n"); fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict); - fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k); - fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", (double)params.top_p); - fprintf(stderr, " --tfs N tail free sampling (default: %.1f)\n", (double)params.tfs_z); - fprintf(stderr, " --typical N locally typical sampling (default: %.1f)\n", (double)params.typical_p); - fprintf(stderr, " --alpha_presence N repeat alpha presence (default: %d)\n", params.alpha_presence); - fprintf(stderr, " --alpha_frequency N repeat alpha frequency (default: %.1f)\n", (double)params.alpha_frequency); - fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n); - fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f)\n", (double)params.repeat_penalty); + fprintf(stderr, " --top_k N top-k sampling (default: %d, disabled: 0)\n", params.top_k); + fprintf(stderr, " --top_p N top-p sampling (default: %.1f, disabled: 1.0)\n", (double)params.top_p); + fprintf(stderr, " --tfs N tail free sampling, parameter z (default: %.1f, disabled: 1.0)\n", (double)params.tfs_z); + fprintf(stderr, " --typical N locally typical sampling, parameter p (default: %.1f, disabled: 1.0)\n", (double)params.typical_p); + fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d, disabled: 0)\n", params.repeat_last_n); + fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f, disabled: 1.0)\n", (double)params.repeat_penalty); + fprintf(stderr, " --alpha_presence N repeat alpha presence (default: %.1f, disabled: 0.0)\n", (double)params.alpha_presence); + fprintf(stderr, " --alpha_frequency N repeat alpha frequency (default: %.1f, disabled: 0.0)\n", (double)params.alpha_frequency); + fprintf(stderr, " --mirostat N use mirostat sampling (default: %d, disabled: 0, mirostat: 1, mirostat 2.0: 2)\n", params.mirostat); + fprintf(stderr, " --mirostat_eta N mirostat learning rate (default: %.1f)\n", (double)params.mirostat_eta); + fprintf(stderr, " --mirostat_tau N mirostat target entropy (default: %.1f)\n", (double)params.mirostat_tau); fprintf(stderr, " -c N, --ctx_size N size of the prompt context (default: %d)\n", params.n_ctx); fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating\n"); fprintf(stderr, " --memory_f32 use f32 instead of f16 for memory key+value\n"); diff --git a/examples/common.h b/examples/common.h index 9d3697d79..de25e6435 100644 --- a/examples/common.h +++ b/examples/common.h @@ -17,17 +17,24 @@ struct gpt_params { int32_t seed = -1; // RNG seed int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); int32_t n_predict = 128; // new tokens to predict - int32_t repeat_last_n = 64; // last n tokens to penalize int32_t n_parts = -1; // amount of model parts (-1 = determine from model dimensions) int32_t n_ctx = 512; // context size int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS) int32_t n_keep = 0; // number of tokens to keep from initial prompt // sampling parameters - int32_t top_k = 40; - float top_p = 0.95f; - float temp = 0.80f; - float repeat_penalty = 1.10f; + int32_t top_k = 0; // <= 0 to use vocab size + float top_p = 1.0f; // 1.0 = disabled + float tfs_z = 1.0f; // 1.0 = disabled + float typical_p = 1.0f; // 1.0 = disabled + float temp = 1.0f; // 1.0 = disabled + float repeat_penalty = 1.0f; // 1.0 = disabled + int32_t repeat_last_n = -1; // last n tokens to penalize (0 = disable penalty, -1 = context size) + float alpha_frequency = 0.0f; // 0.0 = disabled + float alpha_presence = 0.0f; // 0.0 = disabled + int mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + float mirostat_tau = 5.0f; // target entropy + float mirostat_eta = 0.1f; // learning rate std::string model = "models/lamma-7B/ggml-model.bin"; // model path std::string prompt = ""; diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 9b795bd3a..a6de98fed 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -276,8 +276,8 @@ int main(int argc, char ** argv) { fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str()); } } - fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, alpha_presence = %f, alpha_frequency = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f\n", - params.repeat_last_n, params.repeat_penalty, params.alpha_presence, params.alpha_frequency, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp); + fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, alpha_presence = %f, alpha_frequency = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_eta = %f, mirostat_tau = %f\n", + params.repeat_last_n, params.repeat_penalty, params.alpha_presence, params.alpha_frequency, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau); fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); fprintf(stderr, "\n\n"); @@ -396,6 +396,9 @@ int main(int argc, char ** argv) { const float repeat_penalty = params.repeat_penalty; const float alpha_presence = params.alpha_presence; const float alpha_frequency = params.alpha_frequency; + const int mirostat = params.mirostat; + const float mirostat_tau = params.mirostat_tau; + const float mirostat_eta = params.mirostat_eta; // optionally save the session on first sample (for faster prompt loading next time) if (!path_session.empty() && need_to_save_session) { @@ -415,47 +418,45 @@ int main(int argc, char ** argv) { std::vector candidates; candidates.reserve(n_vocab); - for (size_t i = 0; i < n_vocab; i++) { + for (size_t i = 0; i < (size_t) n_vocab; i++) { candidates.emplace_back(i, logits[i], 0.0f); } - llama_token_data_array candidates_p = { candidates.data(), candidates.size() }; + llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; // Apply penalties auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx); - llama_sample_repetition_penalty(&candidates_p, + llama_sample_repetition_penalty(ctx, &candidates_p, last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, last_n_repeat, repeat_penalty); - llama_sample_frequency_and_presence_penalties(&candidates_p, + llama_sample_frequency_and_presence_penalties(ctx, &candidates_p, last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, last_n_repeat, alpha_frequency, alpha_presence); -#if 1 if (temp <= 0) { // Greedy sampling id = llama_sample_token_greedy(ctx, &candidates_p); } else { - // Temperature sampling - llama_sample_top_k(&candidates_p, top_k); - llama_sample_tail_free(&candidates_p, tfs_z); - llama_sample_typical(&candidates_p, typical_p); - llama_sample_top_p(&candidates_p, top_p); - - llama_sample_temperature(&candidates_p, temp); - // printf("`%d`", candidates_p.size); - id = llama_sample_token(ctx, &candidates_p); + if (mirostat == 1) { + static float mirostat_mu = 2.0f * mirostat_tau; + static int mirostat_k = 40; + const int mirostat_m = 100; + id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, float(n_vocab), &mirostat_k, &mirostat_mu); + } else if (mirostat == 2) { + static float mirostat_mu = 2.0f * mirostat_tau; + id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu); + } else { + // Temperature sampling + llama_sample_top_k(ctx, &candidates_p, top_k); + llama_sample_tail_free(ctx, &candidates_p, tfs_z); + llama_sample_typical(ctx, &candidates_p, typical_p); + llama_sample_top_p(ctx, &candidates_p, top_p); + llama_sample_temperature(ctx, &candidates_p, temp); + id = llama_sample_token(ctx, &candidates_p); + } } -#else - const float tau = 5.0f; - static float mu = 2.0f * tau; - static int k = 40; - const float eta = 0.1f; - const int m = 100; - const float N = n_vocab; - id = llama_sample_mirostat(ctx, &candidates_p, tau, eta, m, N, &k, &mu); - // id = llama_sample_mirostat_v2(ctx, &candidates_p, tau, eta, &mu); -#endif + // printf("`%d`", candidates_p.size); last_n_tokens.erase(last_n_tokens.begin()); last_n_tokens.push_back(id); diff --git a/llama.cpp b/llama.cpp index 64debd715..4da4df1f2 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1479,8 +1479,11 @@ static std::vector llama_tokenize(const llama_vocab & vocab, co // sampling // -void llama_sample_softmax(llama_token_data_array * candidates) { +void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) { assert(candidates->size > 0); + + const int64_t t_start_sample_us = ggml_time_us(); + std::span tokens(candidates->data, candidates->size); // Sort the logits in descending order @@ -1502,35 +1505,47 @@ void llama_sample_softmax(llama_token_data_array * candidates) { for (size_t i = 0; i < tokens.size(); ++i) { tokens[i].p /= cum_sum; } + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } } -void llama_sample_top_k(llama_token_data_array * candidates_p, int k) { - assert(k > 0); +void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates_p, int k, size_t min_keep) { + const int64_t t_start_sample_us = ggml_time_us(); + + k = std::max(k, (int) min_keep); + k = std::min(k, (int) candidates_p->size); + std::span candidates(candidates_p->data, candidates_p->size); // Sort scores in descending order if (!candidates_p->sorted) { - if (k >= candidates_p->size) { - std::sort(candidates.begin(), candidates.end(), [](const llama_token_data & a, const llama_token_data & b) { - return a.logit > b.logit; - }); + auto comp = [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }; + if (k == (int) candidates_p->size) { + std::sort(candidates.begin(), candidates.end(), comp); } else { - std::partial_sort(candidates.begin(), candidates.begin() + k, candidates.end(), - [](const llama_token_data & a, const llama_token_data & b) { - return a.logit > b.logit; - }); + std::partial_sort(candidates.begin(), candidates.begin() + k, candidates.end(), comp); } candidates_p->sorted = true; } - candidates_p->size = std::min(k, (int) candidates.size()); + candidates_p->size = k; + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } } -void llama_sample_top_p(llama_token_data_array * candidates_p, float p, size_t min_keep) { +void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates_p, float p, size_t min_keep) { if (p >= 1.0f) { return; } - llama_sample_softmax(candidates_p); + const int64_t t_start_sample_us = ggml_time_us(); + + llama_sample_softmax(ctx, candidates_p); // Compute the cumulative probabilities float cum_sum = 0.0f; @@ -1548,15 +1563,21 @@ void llama_sample_top_p(llama_token_data_array * candidates_p, float p, size_t m // Resize the output vector to keep only the top-p tokens candidates_p->size = last_idx; + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } } // https://www.trentonbricken.com/Tail-Free-Sampling/ -void llama_sample_tail_free(llama_token_data_array * candidates_p, float z, size_t min_keep) { +void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates_p, float z, size_t min_keep) { if (z >= 1.0f || candidates_p->size <= 2) { return; } - llama_sample_softmax(candidates_p); + const int64_t t_start_sample_us = ggml_time_us(); + + llama_sample_softmax(nullptr, candidates_p); // Compute the first and second derivatives std::vector first_derivatives(candidates_p->size - 1); @@ -1594,18 +1615,23 @@ void llama_sample_tail_free(llama_token_data_array * candidates_p, float z, size // Resize the output vector to keep only the tokens above the tail location candidates_p->size = last_idx; -} + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} // https://arxiv.org/pdf/2202.00666.pdf // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr -void llama_sample_typical(llama_token_data_array * candidates_p, float typical_p, size_t min_keep) { +void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates_p, float typical_p, size_t min_keep) { if (typical_p >= 1.0f) { return; } + const int64_t t_start_sample_us = ggml_time_us(); + // Compute the softmax of logits and calculate entropy - llama_sample_softmax(candidates_p); + llama_sample_softmax(nullptr, candidates_p); std::span candidates(candidates_p->data, candidates_p->size); @@ -1654,21 +1680,32 @@ void llama_sample_typical(llama_token_data_array * candidates_p, float typical_p // Replace the data in candidates_p with the new_candidates data std::copy(new_candidates.begin(), new_candidates.end(), candidates_p->data); candidates_p->size = new_candidates.size(); + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } } +void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) { + const int64_t t_start_sample_us = ggml_time_us(); -void llama_sample_temperature(llama_token_data_array * candidates_p, float temp) { std::span candidates(candidates_p->data, candidates_p->size); for (auto & candidate : candidates) { candidate.logit /= temp; } + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } } -void llama_sample_repetition_penalty(llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float penalty) { +void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float penalty) { if (last_tokens_size == 0 || penalty == 1.0f) { return; } + const int64_t t_start_sample_us = ggml_time_us(); + // CTRL paper: https://arxiv.org/pdf/1909.05858.pdf std::span candidates(candidates_p->data, candidates_p->size); std::span last_tokens(last_tokens_p, last_tokens_size); @@ -1695,13 +1732,19 @@ void llama_sample_repetition_penalty(llama_token_data_array * candidates_p, llam } candidates_p->sorted = false; + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } } -void llama_sample_frequency_and_presence_penalties(llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence) { +void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence) { if (last_tokens_size == 0 || (alpha_frequency == 0.0f && alpha_presence == 0.0f)) { return; } + const int64_t t_start_sample_us = ggml_time_us(); + std::span candidates(candidates_p->data, candidates_p->size); std::span last_tokens(last_tokens_p, last_tokens_size); @@ -1723,6 +1766,10 @@ void llama_sample_frequency_and_presence_penalties(llama_token_data_array * cand } candidates_p->sorted = false; + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } } /// @brief Mirostat 1.0 implementation. @@ -1733,20 +1780,26 @@ void llama_sample_frequency_and_presence_penalties(llama_token_data_array * cand /// @param N The size of the vocabulary. This is used in the calculation of the `k` value. /// @param k A reference to the integer variable used to store the calculated top-k value. The top-k value determines how many of the most probable tokens are considered for sampling. /// @param mu A reference to the floating-point variable that represents the maximum cross-entropy value. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. -llama_token llama_sample_mirostat(struct llama_context * ctx, llama_token_data_array * candidates_p, float tau, float eta, int m, float N, int * k, float * mu) { +llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates_p, float tau, float eta, int m, float N, int * k, float * mu) { + assert(ctx); + + int64_t t_start_sample_us; + t_start_sample_us = ggml_time_us(); + // https://arxiv.org/abs/2007.14966 + // Algorithm 1 std::span candidates(candidates_p->data, candidates_p->size); // printf("llama_sample_mirostat: candidates.size() = %d, m = %d, N = %f, tau = %f, eta = %f, *k = %d, *mu = %f\n", candidates.size(), m, N, tau, eta, *k, *mu); - llama_sample_softmax(candidates_p); + llama_sample_softmax(nullptr, candidates_p); // Estimate s_hat using the most probable m tokens float s_hat = 0.0; float sum_ti_bi = 0.0; float sum_ti_sq = 0.0; - for (int i = 0; i < m - 1 && i < candidates.size() - 1; ++i) { - float t_i = logf((i + 2) / float(i + 1)); + for (size_t i = 0; i < size_t(m - 1) && i < candidates.size() - 1; ++i) { + float t_i = logf(float(i + 2) / float(i + 1)); float b_i = logf(candidates[i].p / candidates[i + 1].p); sum_ti_bi += t_i * b_i; sum_ti_sq += t_i * t_i; @@ -1757,15 +1810,20 @@ llama_token llama_sample_mirostat(struct llama_context * ctx, llama_token_data_a float epsilon_hat = s_hat - 1; // printf("llama_sample_mirostat: s_hat = %f, epsilon_hat = %f, *mu = %f, N = %f\n", s_hat, epsilon_hat, *mu, N); float new_k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(N, -epsilon_hat)), 1 / s_hat); - *k = std::min(new_k, float(candidates.size())); + // printf("llama_sample_mirostat: new_k = %f\n", new_k); + *k = int(std::min(new_k, float(candidates.size()))); // Sample the next word X using top-k sampling // printf("llama_sample_mirostat *k = %d\n", *k); - llama_sample_top_k(candidates_p, *k); + llama_sample_top_k(nullptr, candidates_p, *k); + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } llama_token X = llama_sample_token(ctx, candidates_p); + t_start_sample_us = ggml_time_us(); // Compute error as the difference between observed surprise and target surprise value - int X_idx = std::distance(candidates.begin(), std::find_if(candidates.begin(), candidates.end(), [&](const llama_token_data & candidate) { + size_t X_idx = std::distance(candidates.begin(), std::find_if(candidates.begin(), candidates.end(), [&](const llama_token_data & candidate) { return candidate.id == X; })); float observed_surprise = -log2f(candidates[X_idx].p); @@ -1774,13 +1832,23 @@ llama_token llama_sample_mirostat(struct llama_context * ctx, llama_token_data_a // Update mu using the learning rate and error *mu = *mu - eta * e; + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + ctx->n_sample++; + } return X; } -llama_token llama_sample_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates_p, float tau, float eta, float * mu) { +llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates_p, float tau, float eta, float * mu) { + assert(ctx); + int64_t t_start_sample_us; + t_start_sample_us = ggml_time_us(); + + // https://arxiv.org/abs/2007.14966 + // Algorithm 2 std::span candidates(candidates_p->data, candidates_p->size); - llama_sample_softmax(candidates_p); + llama_sample_softmax(ctx, candidates_p); // Truncate the words with surprise values greater than mu candidates_p->size = std::distance(candidates.begin(), std::find_if(candidates.begin(), candidates.end(), [&](const llama_token_data & candidate) { @@ -1788,13 +1856,17 @@ llama_token llama_sample_mirostat_v2(struct llama_context * ctx, llama_token_dat })); // Normalize the probabilities of the remaining words - llama_sample_softmax(candidates_p); + llama_sample_softmax(ctx, candidates_p); // Sample the next word X from the remaining words + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } llama_token X = llama_sample_token(ctx, candidates_p); + t_start_sample_us = ggml_time_us(); // Compute error as the difference between observed surprise and target surprise value - int X_idx = std::distance(candidates.begin(), std::find_if(candidates.begin(), candidates.end(), [&](const llama_token_data & candidate) { + size_t X_idx = std::distance(candidates.begin(), std::find_if(candidates.begin(), candidates.end(), [&](const llama_token_data & candidate) { return candidate.id == X; })); float observed_surprise = -log2f(candidates[X_idx].p); @@ -1803,11 +1875,15 @@ llama_token llama_sample_mirostat_v2(struct llama_context * ctx, llama_token_dat // Update mu using the learning rate and error *mu = *mu - eta * e; + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } return X; } - llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates_p) { + const int64_t t_start_sample_us = ggml_time_us(); + // Find max element std::span candidates(candidates_p->data, candidates_p->size); auto max_iter = std::max_element(candidates.begin(), candidates.end(), [](const llama_token_data & a, const llama_token_data & b) { @@ -1815,14 +1891,17 @@ llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_da }); llama_token result = max_iter->id; - ctx->n_sample++; + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + ctx->n_sample++; + } return result; } - llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates_p) { - // const int64_t t_start_sample_us = ggml_time_us(); - llama_sample_softmax(candidates_p); + assert(ctx); + const int64_t t_start_sample_us = ggml_time_us(); + llama_sample_softmax(nullptr, candidates_p); std::span candidates(candidates_p->data, candidates_p->size); @@ -1838,9 +1917,8 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra llama_token result = candidates[idx].id; - // ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; ctx->n_sample++; - return result; } diff --git a/llama.h b/llama.h index 129574eed..4f72c273c 100644 --- a/llama.h +++ b/llama.h @@ -187,20 +187,20 @@ extern "C" { LLAMA_API llama_token llama_token_eos(); // Sampling functions - LLAMA_API void llama_sample_repetition_penalty(llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float penalty); - LLAMA_API void llama_sample_frequency_and_presence_penalties(llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence); + LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float penalty); + LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence); - LLAMA_API void llama_sample_softmax(llama_token_data_array * candidates); - LLAMA_API void llama_sample_top_k(llama_token_data_array * candidates, int k); - LLAMA_API void llama_sample_top_p(llama_token_data_array * candidates, float p, size_t min_keep = 1); - LLAMA_API void llama_sample_tail_free(llama_token_data_array * candidates, float z, size_t min_keep = 1); - LLAMA_API void llama_sample_typical(llama_token_data_array * candidates, float p, size_t min_keep = 1); - LLAMA_API void llama_sample_temperature(llama_token_data_array * candidates, float temp); + LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates); + LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep = 1); + LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1); + LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep = 1); + LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1); + LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp); + LLAMA_API llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float N, int * k, float * mu); + LLAMA_API llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu); LLAMA_API llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates); LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates); - LLAMA_API llama_token llama_sample_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float N, int * k, float * mu); - LLAMA_API llama_token llama_sample_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu); // Performance information LLAMA_API void llama_print_timings(struct llama_context * ctx); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 9bc5ea036..645648585 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -3,8 +3,6 @@ function(llama_add_test source) add_executable(${TEST_TARGET} ${source}) target_link_libraries(${TEST_TARGET} PRIVATE llama) add_test(NAME ${TEST_TARGET} COMMAND $ ${ARGN}) - target_compile_options(${TEST_TARGET} PRIVATE -fsanitize=address) - target_link_options(${TEST_TARGET} PRIVATE -fsanitize=address) endfunction() # llama_add_test(test-double-float.c) # SLOW diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp new file mode 100644 index 000000000..0a23c80c5 --- /dev/null +++ b/tests/test-sampling.cpp @@ -0,0 +1,229 @@ +#include "ggml.h" +#include "llama.h" +#include +#include +#include +#include +#include +#include +#include + +void dump(const llama_token_data_array * candidates) { + for (size_t i = 0; i < candidates->size; i++) { + printf("%d: %f (%f)\n", candidates->data[i].id, candidates->data[i].p, candidates->data[i].logit); + } +} + +#define DUMP(__candidates) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__candidates)); printf("-\n"); } while(0) + + +void test_top_k(const std::vector & probs, + const std::vector & expected_probs, + int k) { + size_t n_vocab = probs.size(); + std::vector candidates; + candidates.reserve(n_vocab); + for (int i = 0; i < n_vocab; i++) { + float logit = log(probs[i]); + candidates.emplace_back(llama_token_data{i, logit, 0.0f}); + } + + llama_token_data_array candidates_p = { candidates.data(), candidates.size() }; + llama_sample_softmax(nullptr, &candidates_p); + // DUMP(&candidates_p); + llama_sample_top_k(nullptr, &candidates_p, k); + // DUMP(&candidates_p); + + assert(candidates_p.size == expected_probs.size()); + for (size_t i = 0; i < candidates_p.size; i++) { + assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-5); + } +} + + +void test_top_p(const std::vector & probs, + const std::vector & expected_probs, + float p) { + + size_t n_vocab = probs.size(); + std::vector candidates; + candidates.reserve(n_vocab); + for (int i = 0; i < n_vocab; i++) { + float logit = log(probs[i]); + candidates.emplace_back(llama_token_data{i, logit, 0.0f}); + } + + llama_token_data_array candidates_p = { candidates.data(), candidates.size() }; + // DUMP(&candidates_p); + llama_sample_top_p(nullptr, &candidates_p, p); + // DUMP(&candidates_p); + + assert(candidates_p.size == expected_probs.size()); + for (size_t i = 0; i < candidates_p.size; i++) { + assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-5); + } +} + + +void test_tfs(const std::vector & probs, + const std::vector & expected_probs, + float z) { + size_t n_vocab = probs.size(); + std::vector candidates; + candidates.reserve(n_vocab); + for (int i = 0; i < n_vocab; i++) { + float logit = log(probs[i]); + candidates.emplace_back(llama_token_data{i, logit, 0.0f}); + } + + llama_token_data_array candidates_p = { candidates.data(), candidates.size() }; + // DUMP(&candidates_p); + llama_sample_tail_free(nullptr, &candidates_p, z); + // DUMP(&candidates_p); + + assert(candidates_p.size == expected_probs.size()); + for (size_t i = 0; i < candidates_p.size; i++) { + assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-6); + } +} + + +void test_typical(const std::vector & probs, + const std::vector & expected_probs, + float p) { + size_t n_vocab = probs.size(); + std::vector candidates; + candidates.reserve(n_vocab); + for (int i = 0; i < n_vocab; i++) { + float logit = log(probs[i]); + candidates.emplace_back(llama_token_data{i, logit, 0.0f}); + } + + llama_token_data_array candidates_p = { candidates.data(), candidates.size() }; + // DUMP(&candidates_p); + llama_sample_typical(nullptr, &candidates_p, p); + // DUMP(&candidates_p); + + assert(candidates_p.size == expected_probs.size()); + for (size_t i = 0; i < candidates_p.size; i++) { + assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-6); + } +} + + +void test_repetition_penalty( + const std::vector & probs, + const std::vector & last_tokens, + const std::vector & expected_probs, + float penalty) { + assert(probs.size() == expected_probs.size()); + + size_t n_vocab = probs.size(); + std::vector candidates; + candidates.reserve(n_vocab); + for (int i = 0; i < n_vocab; i++) { + float logit = log(probs[i]); + candidates.emplace_back(llama_token_data{i, logit, 0.0f}); + } + + llama_token_data_array candidates_p = { candidates.data(), candidates.size() }; + llama_sample_softmax(nullptr, &candidates_p); + DUMP(&candidates_p); + llama_sample_repetition_penalty(nullptr, &candidates_p, (llama_token *)last_tokens.data(), last_tokens.size(), penalty); + llama_sample_softmax(nullptr, &candidates_p); + DUMP(&candidates_p); + + assert(candidates_p.size == expected_probs.size()); + for (size_t i = 0; i < candidates_p.size; i++) { + assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-6); + } +} + + +void test_frequency_presence_penalty( + const std::vector & probs, + const std::vector & last_tokens, + const std::vector & expected_probs, + float alpha_frequency, float alpha_presence) { + assert(probs.size() == expected_probs.size()); + + size_t n_vocab = probs.size(); + std::vector candidates; + candidates.reserve(n_vocab); + for (int i = 0; i < n_vocab; i++) { + float logit = log(probs[i]); + candidates.emplace_back(llama_token_data{i, logit, 0.0f}); + } + + llama_token_data_array candidates_p = { candidates.data(), candidates.size() }; + llama_sample_softmax(nullptr, &candidates_p); + // DUMP(&candidates_p); + llama_sample_frequency_and_presence_penalties(nullptr, &candidates_p, (llama_token *)last_tokens.data(), last_tokens.size(), alpha_frequency, alpha_presence); + llama_sample_softmax(nullptr, &candidates_p); + // DUMP(&candidates_p); + + assert(candidates_p.size == expected_probs.size()); + for (size_t i = 0; i < candidates_p.size; i++) { + assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-6); + } +} + + +void test_mirostat() { + std::vector probs = {0.1, 0.2, 0.3, 0.4}; + std::vector expected_probs = {0.1, 0.2, 0.3, 0.4}; + + size_t n_vocab = probs.size(); + std::vector candidates; + candidates.reserve(n_vocab); + for (int i = 0; i < n_vocab; i++) { + float logit = log(probs[i]); + candidates.emplace_back(llama_token_data{i, logit, 0.0f}); + } + + llama_token_data_array candidates_p = { candidates.data(), candidates.size() }; + DUMP(&candidates_p); + + float tau = 5.0f; + float mu = 2.0f * tau; + int k = 0; + float eta = 0.1f; + int m = 100; + // float N = 32000; + float N = 4; + // llama_sample_mirostat(ctx, &candidates_p, tau, eta, m, N, &k, &mu); + DUMP(&candidates_p); + + // assert(candidates_p.size == expected_probs.size()); + // for (size_t i = 0; i < candidates_p.size; i++) { + // assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-6); + // } +} + +int main(void) { + test_top_k({0.1, 0.2, 0.3, 0.4}, {0.4}, 1); + test_top_k({0.1, 0.2, 0.3, 0.4}, {0.4, 0.3, 0.2}, 3); + + test_top_p({0.1, 0.2, 0.3, 0.4}, {0.4}, 0); + test_top_p({0.1, 0.2, 0.3, 0.4}, {0.4, 0.3}, 0.7); + test_top_p({0.1, 0.2, 0.3, 0.4}, {0.4, 0.3, 0.2, 0.1}, 1); + + test_tfs({0.1, 0.15, 0.2, 0.25, 0.3}, {0.3}, 0.25); + test_tfs({0.1, 0.15, 0.2, 0.25, 0.3}, {0.3, 0.25}, 0.75); + test_tfs({0.1, 0.15, 0.2, 0.25, 0.3}, {0.3, 0.25}, 0.99); + + test_typical({0.97, 0.01, 0.01, 0.01}, {0.97}, 0.5); + test_typical({0.4, 0.2, 0.2, 0.2}, {0.2, 0.2, 0.2}, 0.5); + + test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0}, {0, 0.25, 0.25, 0.25, 0.25}, 50.0); + test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2}, {0, 0, 0, 0.5, 0.5}, 50.0); + test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2, 0, 0}, {0, 0, 0, 0.5, 0.5}, 50.0); + + test_frequency_presence_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0}, {0.249997, 0.249997, 0.249997, 0.249997, 0.000011}, 5.0, 5.0); + test_frequency_presence_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2}, {0.499966, 0.499966, 0.000023, 0.000023, 0.000023}, 5.0, 5.0); + test_frequency_presence_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2, 0, 0}, {0.499977, 0.499977, 0.000023, 0.000023, 0.000000}, 5.0, 5.0); + + // test_mirostat(); + + printf("OK\n"); +}