From 9b3b07cc5ce2cae67b5da1bce8657a15e0daf39a Mon Sep 17 00:00:00 2001 From: Ivan Stepanov Date: Sat, 22 Apr 2023 14:31:08 +0300 Subject: [PATCH] Sample interface, new samplers. New samplers: - locally typical sampling - tail free sampling - frequency and presence penalty - mirostat Ignore EOS fix: -inf should be used. --- CMakeLists.txt | 2 +- Makefile | 2 +- examples/common.cpp | 28 +++ examples/main/main.cpp | 62 +++++- llama.cpp | 460 +++++++++++++++++++++++++++++++---------- llama.h | 34 +-- tests/CMakeLists.txt | 3 + 7 files changed, 455 insertions(+), 136 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 5fdbeddfc..9d7c9d1ed 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -76,7 +76,7 @@ option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE}) # Compile flags # -set(CMAKE_CXX_STANDARD 11) +set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED true) set(CMAKE_C_STANDARD 11) set(CMAKE_C_STANDARD_REQUIRED true) diff --git a/Makefile b/Makefile index 0715e857b..b4af18c0e 100644 --- a/Makefile +++ b/Makefile @@ -35,7 +35,7 @@ endif # keep standard at C11 and C++11 CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC -CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC +CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++20 -fPIC LDFLAGS = # warnings diff --git a/examples/common.cpp b/examples/common.cpp index 9f10dc268..a8f57360a 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -114,6 +114,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.temp = std::stof(argv[i]); + } else if (arg == "--tfs") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.tfs_z = std::stof(argv[i]); + } else if (arg == "--typical") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.typical_p = std::stof(argv[i]); } else if (arg == "--repeat_last_n") { if (++i >= argc) { invalid_param = true; @@ -126,6 +138,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.repeat_penalty = std::stof(argv[i]); + } else if (arg == "--alpha_frequency") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.alpha_frequency = std::stof(argv[i]); + } else if (arg == "--alpha_presence") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.alpha_presence = std::stof(argv[i]); } else if (arg == "-b" || arg == "--batch_size") { if (++i >= argc) { invalid_param = true; @@ -242,6 +266,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict); fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k); fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", (double)params.top_p); + fprintf(stderr, " --tfs N tail free sampling (default: %.1f)\n", (double)params.tfs_z); + fprintf(stderr, " --typical N locally typical sampling (default: %.1f)\n", (double)params.typical_p); + fprintf(stderr, " --alpha_presence N repeat alpha presence (default: %d)\n", params.alpha_presence); + fprintf(stderr, " --alpha_frequency N repeat alpha frequency (default: %.1f)\n", (double)params.alpha_frequency); fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n); fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f)\n", (double)params.repeat_penalty); fprintf(stderr, " -c N, --ctx_size N size of the prompt context (default: %d)\n", params.n_ctx); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index fda65574f..9b795bd3a 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -276,8 +276,8 @@ int main(int argc, char ** argv) { fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str()); } } - fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", - params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty); + fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, alpha_presence = %f, alpha_frequency = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f\n", + params.repeat_last_n, params.repeat_penalty, params.alpha_presence, params.alpha_frequency, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp); fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); fprintf(stderr, "\n\n"); @@ -387,10 +387,15 @@ int main(int argc, char ** argv) { if ((int) embd_inp.size() <= n_consumed && !is_interacting) { // out of user input, sample next token - const int32_t top_k = params.top_k; - const float top_p = params.top_p; const float temp = params.temp; + const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k; + const float top_p = params.top_p; + const float tfs_z = params.tfs_z; + const float typical_p = params.typical_p; + const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n; const float repeat_penalty = params.repeat_penalty; + const float alpha_presence = params.alpha_presence; + const float alpha_frequency = params.alpha_frequency; // optionally save the session on first sample (for faster prompt loading next time) if (!path_session.empty() && need_to_save_session) { @@ -402,14 +407,55 @@ int main(int argc, char ** argv) { { auto logits = llama_get_logits(ctx); + auto n_vocab = llama_n_vocab(ctx); if (params.ignore_eos) { - logits[llama_token_eos()] = 0; + logits[llama_token_eos()] = -INFINITY; } - id = llama_sample_top_p_top_k(ctx, - last_n_tokens.data() + n_ctx - params.repeat_last_n, - params.repeat_last_n, top_k, top_p, temp, repeat_penalty); + std::vector candidates; + candidates.reserve(n_vocab); + for (size_t i = 0; i < n_vocab; i++) { + candidates.emplace_back(i, logits[i], 0.0f); + } + + llama_token_data_array candidates_p = { candidates.data(), candidates.size() }; + + // Apply penalties + auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx); + llama_sample_repetition_penalty(&candidates_p, + last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, + last_n_repeat, repeat_penalty); + llama_sample_frequency_and_presence_penalties(&candidates_p, + last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, + last_n_repeat, alpha_frequency, alpha_presence); + + +#if 1 + if (temp <= 0) { + // Greedy sampling + id = llama_sample_token_greedy(ctx, &candidates_p); + } else { + // Temperature sampling + llama_sample_top_k(&candidates_p, top_k); + llama_sample_tail_free(&candidates_p, tfs_z); + llama_sample_typical(&candidates_p, typical_p); + llama_sample_top_p(&candidates_p, top_p); + + llama_sample_temperature(&candidates_p, temp); + // printf("`%d`", candidates_p.size); + id = llama_sample_token(ctx, &candidates_p); + } +#else + const float tau = 5.0f; + static float mu = 2.0f * tau; + static int k = 40; + const float eta = 0.1f; + const int m = 100; + const float N = n_vocab; + id = llama_sample_mirostat(ctx, &candidates_p, tau, eta, m, N, &k, &mu); + // id = llama_sample_mirostat_v2(ctx, &candidates_p, tau, eta, &mu); +#endif last_n_tokens.erase(last_n_tokens.begin()); last_n_tokens.push_back(id); diff --git a/llama.cpp b/llama.cpp index dca017db6..64debd715 100644 --- a/llama.cpp +++ b/llama.cpp @@ -28,6 +28,7 @@ #include #include #include +#include #define LLAMA_USE_SCRATCH #define LLAMA_MAX_SCRATCH_BUFFERS 16 @@ -1478,109 +1479,369 @@ static std::vector llama_tokenize(const llama_vocab & vocab, co // sampling // -static void sample_top_k(std::vector> & logits_id, int top_k) { - // find the top k tokens - std::partial_sort( - logits_id.begin(), - logits_id.begin() + top_k, logits_id.end(), - [](const std::pair & a, const std::pair & b) { - return a.first > b.first; - }); +void llama_sample_softmax(llama_token_data_array * candidates) { + assert(candidates->size > 0); + std::span tokens(candidates->data, candidates->size); - logits_id.resize(top_k); + // Sort the logits in descending order + if (!candidates->sorted) { + std::sort(tokens.begin(), tokens.end(), [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }); + candidates->sorted = true; + } + + float max_l = tokens[0].logit; + float cum_sum = 0.0f; + for (size_t i = 0; i < tokens.size(); ++i) { + // printf("llama_sample_softmax: i: %d, logit: %f\n", i, tokens[i].logit); + float p = expf(tokens[i].logit - max_l); + tokens[i].p = p; + cum_sum += p; + } + for (size_t i = 0; i < tokens.size(); ++i) { + tokens[i].p /= cum_sum; + } } -static llama_vocab::id llama_sample_top_p_top_k( - llama_context & lctx, - const std::vector & last_n_tokens, - int top_k, - float top_p, - float temp, - float repeat_penalty) { - auto & rng = lctx.rng; +void llama_sample_top_k(llama_token_data_array * candidates_p, int k) { + assert(k > 0); + std::span candidates(candidates_p->data, candidates_p->size); - const int n_logits = lctx.model.hparams.n_vocab; - - const auto & logits = lctx.logits; - const auto * plogits = logits.data() + logits.size() - n_logits; - - if (temp <= 0) { - // select the token with the highest logit directly - float max_logit = plogits[0]; - llama_vocab::id max_id = 0; - - for (int i = 1; i < n_logits; ++i) { - if (plogits[i] > max_logit) { - max_logit = plogits[i]; - max_id = i; - } + // Sort scores in descending order + if (!candidates_p->sorted) { + if (k >= candidates_p->size) { + std::sort(candidates.begin(), candidates.end(), [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }); + } else { + std::partial_sort(candidates.begin(), candidates.begin() + k, candidates.end(), + [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }); } - return max_id; + candidates_p->sorted = true; + } + candidates_p->size = std::min(k, (int) candidates.size()); +} + +void llama_sample_top_p(llama_token_data_array * candidates_p, float p, size_t min_keep) { + if (p >= 1.0f) { + return; } - std::vector> logits_id; - logits_id.reserve(n_logits); + llama_sample_softmax(candidates_p); - { - const float scale = 1.0f/temp; - for (int i = 0; i < n_logits; ++i) { - // repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858) - // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main - if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) { - // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability - if (plogits[i] < 0.0f) { - logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i)); - } else { - logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i)); - } - } else { - logits_id.push_back(std::make_pair(plogits[i]*scale, i)); - } + // Compute the cumulative probabilities + float cum_sum = 0.0f; + size_t last_idx = candidates_p->size; + + for (size_t i = 0; i < candidates_p->size; ++i) { + cum_sum += candidates_p->data[i].p; + + // Check if the running sum is greater than p or if we have kept at least min_keep tokens + if (cum_sum > p && i >= min_keep) { + last_idx = i; + break; } } - sample_top_k(logits_id, top_k > 0 ? std::min(top_k, n_logits) : n_logits); + // Resize the output vector to keep only the top-p tokens + candidates_p->size = last_idx; +} + +// https://www.trentonbricken.com/Tail-Free-Sampling/ +void llama_sample_tail_free(llama_token_data_array * candidates_p, float z, size_t min_keep) { + if (z >= 1.0f || candidates_p->size <= 2) { + return; + } + + llama_sample_softmax(candidates_p); + + // Compute the first and second derivatives + std::vector first_derivatives(candidates_p->size - 1); + std::vector second_derivatives(candidates_p->size - 2); + + for (size_t i = 0; i < first_derivatives.size(); ++i) { + first_derivatives[i] = candidates_p->data[i].p - candidates_p->data[i + 1].p; + } + for (size_t i = 0; i < second_derivatives.size(); ++i) { + second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1]; + } + + // Calculate absolute value of second derivatives + for (size_t i = 0; i < second_derivatives.size(); ++i) { + second_derivatives[i] = abs(second_derivatives[i]); + } + + // Normalize the second derivatives + float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f); + for (float & value : second_derivatives) { + value /= second_derivatives_sum; + } + + float cum_sum = 0.0f; + size_t last_idx = candidates_p->size; + for (size_t i = 0; i < second_derivatives.size(); ++i) { + cum_sum += second_derivatives[i]; + + // Check if the running sum is greater than z or if we have kept at least min_keep tokens + if (cum_sum > z && i >= min_keep) { + last_idx = i; + break; + } + } + + // Resize the output vector to keep only the tokens above the tail location + candidates_p->size = last_idx; +} + + +// https://arxiv.org/pdf/2202.00666.pdf +// https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr +void llama_sample_typical(llama_token_data_array * candidates_p, float typical_p, size_t min_keep) { + if (typical_p >= 1.0f) { + return; + } + + // Compute the softmax of logits and calculate entropy + llama_sample_softmax(candidates_p); + + std::span candidates(candidates_p->data, candidates_p->size); + + float entropy = 0.0f; + for (const auto & candidate : candidates) { + entropy += -candidate.p * logf(candidate.p); + } + + // Compute the absolute difference between negative log probability and entropy for each candidate + std::vector shifted_scores; + for (const auto & candidate : candidates) { + float shifted_score = fabsf(-logf(candidate.p) - entropy); + shifted_scores.push_back(shifted_score); + } + + // Sort candidates based on the shifted_scores and their corresponding indices + std::vector indices(candidates.size()); + std::iota(indices.begin(), indices.end(), 0); + + std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) { + return shifted_scores[a] < shifted_scores[b]; + }); + + // Compute the cumulative probabilities + float cum_sum = 0.0f; + size_t last_idx = indices.size(); + + for (size_t i = 0; i < indices.size(); ++i) { + size_t idx = indices[i]; + cum_sum += candidates[idx].p; + + // Check if the running sum is greater than typical or if we have kept at least min_keep tokens + if (cum_sum > typical_p && i >= min_keep - 1) { + last_idx = i + 1; + break; + } + } + + // Resize the output vector to keep only the locally typical tokens + std::vector new_candidates; + for (size_t i = 0; i < last_idx; ++i) { + size_t idx = indices[i]; + new_candidates.push_back(candidates[idx]); + } + + // Replace the data in candidates_p with the new_candidates data + std::copy(new_candidates.begin(), new_candidates.end(), candidates_p->data); + candidates_p->size = new_candidates.size(); +} + + +void llama_sample_temperature(llama_token_data_array * candidates_p, float temp) { + std::span candidates(candidates_p->data, candidates_p->size); + for (auto & candidate : candidates) { + candidate.logit /= temp; + } +} + +void llama_sample_repetition_penalty(llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float penalty) { + if (last_tokens_size == 0 || penalty == 1.0f) { + return; + } + + // CTRL paper: https://arxiv.org/pdf/1909.05858.pdf + std::span candidates(candidates_p->data, candidates_p->size); + std::span last_tokens(last_tokens_p, last_tokens_size); + + for (size_t i = 0; i < candidates.size(); ++i) { + auto token_iter = std::find(last_tokens.begin(), last_tokens.end(), candidates[i].id); + if (token_iter == last_tokens.end()) { + continue; + } + + // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong. + // This is common fix for this problem, which is to multiply by the penalty instead of dividing. + if (candidates[i].logit <= 0) { + candidates[i].logit *= penalty; + } else { + candidates[i].logit /= penalty; + } + + // But it does not penalize tokens that logits are near zero, which is a problem. + // Another solution is to convert the logits to probabilities, apply the penalty, and then convert back to logits. + // float probability = std::exp(candidates[i].logit); + // probability /= penalty; + // candidates[i].logit = std::log(probability); + } + + candidates_p->sorted = false; +} + +void llama_sample_frequency_and_presence_penalties(llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence) { + if (last_tokens_size == 0 || (alpha_frequency == 0.0f && alpha_presence == 0.0f)) { + return; + } + + std::span candidates(candidates_p->data, candidates_p->size); + std::span last_tokens(last_tokens_p, last_tokens_size); + + // Create a frequency map to count occurrences of each token in last_tokens + std::unordered_map token_count; + for (const auto & token : last_tokens) { + token_count[token]++; + } + + // Apply frequency and presence penalties to the candidates + for (size_t i = 0; i < candidates.size(); ++i) { + auto token_iter = token_count.find(candidates[i].id); + if (token_iter == token_count.end()) { + continue; + } + + int count = token_iter->second; + candidates[i].logit -= count * alpha_frequency + float(count > 0) * alpha_presence; + } + + candidates_p->sorted = false; +} + +/// @brief Mirostat 1.0 implementation. +/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. +/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. +/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. +/// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. +/// @param N The size of the vocabulary. This is used in the calculation of the `k` value. +/// @param k A reference to the integer variable used to store the calculated top-k value. The top-k value determines how many of the most probable tokens are considered for sampling. +/// @param mu A reference to the floating-point variable that represents the maximum cross-entropy value. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. +llama_token llama_sample_mirostat(struct llama_context * ctx, llama_token_data_array * candidates_p, float tau, float eta, int m, float N, int * k, float * mu) { + // https://arxiv.org/abs/2007.14966 + std::span candidates(candidates_p->data, candidates_p->size); + + // printf("llama_sample_mirostat: candidates.size() = %d, m = %d, N = %f, tau = %f, eta = %f, *k = %d, *mu = %f\n", candidates.size(), m, N, tau, eta, *k, *mu); + + llama_sample_softmax(candidates_p); + + // Estimate s_hat using the most probable m tokens + float s_hat = 0.0; + float sum_ti_bi = 0.0; + float sum_ti_sq = 0.0; + for (int i = 0; i < m - 1 && i < candidates.size() - 1; ++i) { + float t_i = logf((i + 2) / float(i + 1)); + float b_i = logf(candidates[i].p / candidates[i + 1].p); + sum_ti_bi += t_i * b_i; + sum_ti_sq += t_i * t_i; + } + s_hat = sum_ti_bi / sum_ti_sq; + + // Compute k from the estimated s_hat and target surprise value + float epsilon_hat = s_hat - 1; + // printf("llama_sample_mirostat: s_hat = %f, epsilon_hat = %f, *mu = %f, N = %f\n", s_hat, epsilon_hat, *mu, N); + float new_k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(N, -epsilon_hat)), 1 / s_hat); + *k = std::min(new_k, float(candidates.size())); + + // Sample the next word X using top-k sampling + // printf("llama_sample_mirostat *k = %d\n", *k); + llama_sample_top_k(candidates_p, *k); + llama_token X = llama_sample_token(ctx, candidates_p); + + // Compute error as the difference between observed surprise and target surprise value + int X_idx = std::distance(candidates.begin(), std::find_if(candidates.begin(), candidates.end(), [&](const llama_token_data & candidate) { + return candidate.id == X; + })); + float observed_surprise = -log2f(candidates[X_idx].p); + float e = observed_surprise - tau; + + // Update mu using the learning rate and error + *mu = *mu - eta * e; + + return X; +} + +llama_token llama_sample_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates_p, float tau, float eta, float * mu) { + std::span candidates(candidates_p->data, candidates_p->size); + + llama_sample_softmax(candidates_p); + + // Truncate the words with surprise values greater than mu + candidates_p->size = std::distance(candidates.begin(), std::find_if(candidates.begin(), candidates.end(), [&](const llama_token_data & candidate) { + return -log2f(candidate.p) > *mu; + })); + + // Normalize the probabilities of the remaining words + llama_sample_softmax(candidates_p); + + // Sample the next word X from the remaining words + llama_token X = llama_sample_token(ctx, candidates_p); + + // Compute error as the difference between observed surprise and target surprise value + int X_idx = std::distance(candidates.begin(), std::find_if(candidates.begin(), candidates.end(), [&](const llama_token_data & candidate) { + return candidate.id == X; + })); + float observed_surprise = -log2f(candidates[X_idx].p); + float e = observed_surprise - tau; + + // Update mu using the learning rate and error + *mu = *mu - eta * e; + + return X; +} + + +llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates_p) { + // Find max element + std::span candidates(candidates_p->data, candidates_p->size); + auto max_iter = std::max_element(candidates.begin(), candidates.end(), [](const llama_token_data & a, const llama_token_data & b) { + return a.logit < b.logit; + }); + + llama_token result = max_iter->id; + ctx->n_sample++; + return result; +} + + +llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates_p) { + // const int64_t t_start_sample_us = ggml_time_us(); + llama_sample_softmax(candidates_p); + + std::span candidates(candidates_p->data, candidates_p->size); - // compute probs for the top k tokens std::vector probs; - probs.reserve(logits_id.size()); - - float maxl = logits_id[0].first; - double sum = 0.0; - for (const auto & kv : logits_id) { - const float p = expf(kv.first - maxl); - probs.push_back(p); - sum += p; + probs.reserve(candidates.size()); + for (auto & candidate : candidates) { + probs.push_back(candidate.p); } - // normalize the probs - for (auto & p : probs) { - p /= sum; - } - - if (top_p < 1.0) { - double cumsum = 0.0; - for (int i = 0; i < (int) probs.size(); i++) { - cumsum += probs[i]; - if (cumsum >= top_p) { - probs.resize(i + 1); - logits_id.resize(i + 1); - break; - } - } - } - - //printf("\n"); - //for (int i = 0; i < (int) 10; i++) { - // printf("%d: '%s' %f\n", i, lctx.vocab.id_to_token.at(logits_id[i].second).tok.c_str(), probs[i]); - //} - //printf("\n\n"); - //exit(0); - std::discrete_distribution<> dist(probs.begin(), probs.end()); + auto & rng = ctx->rng; int idx = dist(rng); - return logits_id[idx].second; + llama_token result = candidates[idx].id; + + // ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + ctx->n_sample++; + + return result; } // @@ -2352,35 +2613,6 @@ llama_token llama_token_eos() { return 2; } -llama_token llama_sample_top_p_top_k( - llama_context * ctx, - const llama_token * last_n_tokens_data, - int last_n_tokens_size, - int top_k, - float top_p, - float temp, - float repeat_penalty) { - const int64_t t_start_sample_us = ggml_time_us(); - - llama_token result = 0; - - // TODO: avoid this ... - const auto last_n_tokens = std::vector(last_n_tokens_data, last_n_tokens_data + last_n_tokens_size); - - result = llama_sample_top_p_top_k( - *ctx, - last_n_tokens, - top_k, - top_p, - temp, - repeat_penalty); - - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - ctx->n_sample++; - - return result; -} - void llama_print_timings(struct llama_context * ctx) { const int64_t t_end_us = ggml_time_us(); diff --git a/llama.h b/llama.h index 86a7d279a..129574eed 100644 --- a/llama.h +++ b/llama.h @@ -39,12 +39,16 @@ extern "C" { typedef struct llama_token_data { llama_token id; // token id - + float logit; // log-odds of the token float p; // probability of the token - float plog; // log probability of the token - } llama_token_data; + typedef struct llama_token_data_array { + llama_token_data * data; + size_t size; + bool sorted; + } llama_token_data_array; + typedef void (*llama_progress_callback)(float progress, void *ctx); struct llama_context_params { @@ -182,15 +186,21 @@ extern "C" { LLAMA_API llama_token llama_token_bos(); LLAMA_API llama_token llama_token_eos(); - // TODO: improve the last_n_tokens interface ? - LLAMA_API llama_token llama_sample_top_p_top_k( - struct llama_context * ctx, - const llama_token * last_n_tokens_data, - int last_n_tokens_size, - int top_k, - float top_p, - float temp, - float repeat_penalty); + // Sampling functions + LLAMA_API void llama_sample_repetition_penalty(llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float penalty); + LLAMA_API void llama_sample_frequency_and_presence_penalties(llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence); + + LLAMA_API void llama_sample_softmax(llama_token_data_array * candidates); + LLAMA_API void llama_sample_top_k(llama_token_data_array * candidates, int k); + LLAMA_API void llama_sample_top_p(llama_token_data_array * candidates, float p, size_t min_keep = 1); + LLAMA_API void llama_sample_tail_free(llama_token_data_array * candidates, float z, size_t min_keep = 1); + LLAMA_API void llama_sample_typical(llama_token_data_array * candidates, float p, size_t min_keep = 1); + LLAMA_API void llama_sample_temperature(llama_token_data_array * candidates, float temp); + + LLAMA_API llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates); + LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates); + LLAMA_API llama_token llama_sample_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float N, int * k, float * mu); + LLAMA_API llama_token llama_sample_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu); // Performance information LLAMA_API void llama_print_timings(struct llama_context * ctx); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 81eadbc4d..9bc5ea036 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -3,9 +3,12 @@ function(llama_add_test source) add_executable(${TEST_TARGET} ${source}) target_link_libraries(${TEST_TARGET} PRIVATE llama) add_test(NAME ${TEST_TARGET} COMMAND $ ${ARGN}) + target_compile_options(${TEST_TARGET} PRIVATE -fsanitize=address) + target_link_options(${TEST_TARGET} PRIVATE -fsanitize=address) endfunction() # llama_add_test(test-double-float.c) # SLOW llama_add_test(test-quantize-fns.cpp) llama_add_test(test-quantize-perf.cpp) +llama_add_test(test-sampling.cpp) llama_add_test(test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab.bin)