Sample interface, new samplers.

New samplers:
- locally typical sampling
- tail free sampling
- frequency and presence penalty
- mirostat

Ignore EOS fix: -inf should be used.
This commit is contained in:
Ivan Stepanov 2023-04-22 14:31:08 +03:00
parent 5fba3c016b
commit 9b3b07cc5c
7 changed files with 455 additions and 136 deletions

View file

@ -76,7 +76,7 @@ option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
# Compile flags
#
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED true)
set(CMAKE_C_STANDARD 11)
set(CMAKE_C_STANDARD_REQUIRED true)

View file

@ -35,7 +35,7 @@ endif
# keep standard at C11 and C++11
CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC
CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC
CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++20 -fPIC
LDFLAGS =
# warnings

View file

@ -114,6 +114,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.temp = std::stof(argv[i]);
} else if (arg == "--tfs") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.tfs_z = std::stof(argv[i]);
} else if (arg == "--typical") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.typical_p = std::stof(argv[i]);
} else if (arg == "--repeat_last_n") {
if (++i >= argc) {
invalid_param = true;
@ -126,6 +138,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.repeat_penalty = std::stof(argv[i]);
} else if (arg == "--alpha_frequency") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.alpha_frequency = std::stof(argv[i]);
} else if (arg == "--alpha_presence") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.alpha_presence = std::stof(argv[i]);
} else if (arg == "-b" || arg == "--batch_size") {
if (++i >= argc) {
invalid_param = true;
@ -242,6 +266,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k);
fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", (double)params.top_p);
fprintf(stderr, " --tfs N tail free sampling (default: %.1f)\n", (double)params.tfs_z);
fprintf(stderr, " --typical N locally typical sampling (default: %.1f)\n", (double)params.typical_p);
fprintf(stderr, " --alpha_presence N repeat alpha presence (default: %d)\n", params.alpha_presence);
fprintf(stderr, " --alpha_frequency N repeat alpha frequency (default: %.1f)\n", (double)params.alpha_frequency);
fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n);
fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f)\n", (double)params.repeat_penalty);
fprintf(stderr, " -c N, --ctx_size N size of the prompt context (default: %d)\n", params.n_ctx);

View file

@ -276,8 +276,8 @@ int main(int argc, char ** argv) {
fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
}
}
fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n",
params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, alpha_presence = %f, alpha_frequency = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f\n",
params.repeat_last_n, params.repeat_penalty, params.alpha_presence, params.alpha_frequency, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp);
fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
fprintf(stderr, "\n\n");
@ -387,10 +387,15 @@ int main(int argc, char ** argv) {
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
// out of user input, sample next token
const int32_t top_k = params.top_k;
const float top_p = params.top_p;
const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
const float top_p = params.top_p;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
const float repeat_penalty = params.repeat_penalty;
const float alpha_presence = params.alpha_presence;
const float alpha_frequency = params.alpha_frequency;
// optionally save the session on first sample (for faster prompt loading next time)
if (!path_session.empty() && need_to_save_session) {
@ -402,14 +407,55 @@ int main(int argc, char ** argv) {
{
auto logits = llama_get_logits(ctx);
auto n_vocab = llama_n_vocab(ctx);
if (params.ignore_eos) {
logits[llama_token_eos()] = 0;
logits[llama_token_eos()] = -INFINITY;
}
id = llama_sample_top_p_top_k(ctx,
last_n_tokens.data() + n_ctx - params.repeat_last_n,
params.repeat_last_n, top_k, top_p, temp, repeat_penalty);
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (size_t i = 0; i < n_vocab; i++) {
candidates.emplace_back(i, logits[i], 0.0f);
}
llama_token_data_array candidates_p = { candidates.data(), candidates.size() };
// Apply penalties
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
llama_sample_repetition_penalty(&candidates_p,
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
last_n_repeat, repeat_penalty);
llama_sample_frequency_and_presence_penalties(&candidates_p,
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
last_n_repeat, alpha_frequency, alpha_presence);
#if 1
if (temp <= 0) {
// Greedy sampling
id = llama_sample_token_greedy(ctx, &candidates_p);
} else {
// Temperature sampling
llama_sample_top_k(&candidates_p, top_k);
llama_sample_tail_free(&candidates_p, tfs_z);
llama_sample_typical(&candidates_p, typical_p);
llama_sample_top_p(&candidates_p, top_p);
llama_sample_temperature(&candidates_p, temp);
// printf("`%d`", candidates_p.size);
id = llama_sample_token(ctx, &candidates_p);
}
#else
const float tau = 5.0f;
static float mu = 2.0f * tau;
static int k = 40;
const float eta = 0.1f;
const int m = 100;
const float N = n_vocab;
id = llama_sample_mirostat(ctx, &candidates_p, tau, eta, m, N, &k, &mu);
// id = llama_sample_mirostat_v2(ctx, &candidates_p, tau, eta, &mu);
#endif
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(id);

460
llama.cpp
View file

@ -28,6 +28,7 @@
#include <atomic>
#include <mutex>
#include <sstream>
#include <span>
#define LLAMA_USE_SCRATCH
#define LLAMA_MAX_SCRATCH_BUFFERS 16
@ -1478,109 +1479,369 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
// sampling
//
static void sample_top_k(std::vector<std::pair<float, llama_vocab::id>> & logits_id, int top_k) {
// find the top k tokens
std::partial_sort(
logits_id.begin(),
logits_id.begin() + top_k, logits_id.end(),
[](const std::pair<float, llama_vocab::id> & a, const std::pair<float, llama_vocab::id> & b) {
return a.first > b.first;
});
void llama_sample_softmax(llama_token_data_array * candidates) {
assert(candidates->size > 0);
std::span<llama_token_data> tokens(candidates->data, candidates->size);
logits_id.resize(top_k);
// Sort the logits in descending order
if (!candidates->sorted) {
std::sort(tokens.begin(), tokens.end(), [](const llama_token_data & a, const llama_token_data & b) {
return a.logit > b.logit;
});
candidates->sorted = true;
}
float max_l = tokens[0].logit;
float cum_sum = 0.0f;
for (size_t i = 0; i < tokens.size(); ++i) {
// printf("llama_sample_softmax: i: %d, logit: %f\n", i, tokens[i].logit);
float p = expf(tokens[i].logit - max_l);
tokens[i].p = p;
cum_sum += p;
}
for (size_t i = 0; i < tokens.size(); ++i) {
tokens[i].p /= cum_sum;
}
}
static llama_vocab::id llama_sample_top_p_top_k(
llama_context & lctx,
const std::vector<llama_vocab::id> & last_n_tokens,
int top_k,
float top_p,
float temp,
float repeat_penalty) {
auto & rng = lctx.rng;
void llama_sample_top_k(llama_token_data_array * candidates_p, int k) {
assert(k > 0);
std::span<llama_token_data> candidates(candidates_p->data, candidates_p->size);
const int n_logits = lctx.model.hparams.n_vocab;
const auto & logits = lctx.logits;
const auto * plogits = logits.data() + logits.size() - n_logits;
if (temp <= 0) {
// select the token with the highest logit directly
float max_logit = plogits[0];
llama_vocab::id max_id = 0;
for (int i = 1; i < n_logits; ++i) {
if (plogits[i] > max_logit) {
max_logit = plogits[i];
max_id = i;
}
// Sort scores in descending order
if (!candidates_p->sorted) {
if (k >= candidates_p->size) {
std::sort(candidates.begin(), candidates.end(), [](const llama_token_data & a, const llama_token_data & b) {
return a.logit > b.logit;
});
} else {
std::partial_sort(candidates.begin(), candidates.begin() + k, candidates.end(),
[](const llama_token_data & a, const llama_token_data & b) {
return a.logit > b.logit;
});
}
return max_id;
candidates_p->sorted = true;
}
candidates_p->size = std::min(k, (int) candidates.size());
}
void llama_sample_top_p(llama_token_data_array * candidates_p, float p, size_t min_keep) {
if (p >= 1.0f) {
return;
}
std::vector<std::pair<float, llama_vocab::id>> logits_id;
logits_id.reserve(n_logits);
llama_sample_softmax(candidates_p);
{
const float scale = 1.0f/temp;
for (int i = 0; i < n_logits; ++i) {
// repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858)
// credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
// if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
if (plogits[i] < 0.0f) {
logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i));
} else {
logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i));
}
} else {
logits_id.push_back(std::make_pair(plogits[i]*scale, i));
}
// Compute the cumulative probabilities
float cum_sum = 0.0f;
size_t last_idx = candidates_p->size;
for (size_t i = 0; i < candidates_p->size; ++i) {
cum_sum += candidates_p->data[i].p;
// Check if the running sum is greater than p or if we have kept at least min_keep tokens
if (cum_sum > p && i >= min_keep) {
last_idx = i;
break;
}
}
sample_top_k(logits_id, top_k > 0 ? std::min(top_k, n_logits) : n_logits);
// Resize the output vector to keep only the top-p tokens
candidates_p->size = last_idx;
}
// https://www.trentonbricken.com/Tail-Free-Sampling/
void llama_sample_tail_free(llama_token_data_array * candidates_p, float z, size_t min_keep) {
if (z >= 1.0f || candidates_p->size <= 2) {
return;
}
llama_sample_softmax(candidates_p);
// Compute the first and second derivatives
std::vector<float> first_derivatives(candidates_p->size - 1);
std::vector<float> second_derivatives(candidates_p->size - 2);
for (size_t i = 0; i < first_derivatives.size(); ++i) {
first_derivatives[i] = candidates_p->data[i].p - candidates_p->data[i + 1].p;
}
for (size_t i = 0; i < second_derivatives.size(); ++i) {
second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1];
}
// Calculate absolute value of second derivatives
for (size_t i = 0; i < second_derivatives.size(); ++i) {
second_derivatives[i] = abs(second_derivatives[i]);
}
// Normalize the second derivatives
float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f);
for (float & value : second_derivatives) {
value /= second_derivatives_sum;
}
float cum_sum = 0.0f;
size_t last_idx = candidates_p->size;
for (size_t i = 0; i < second_derivatives.size(); ++i) {
cum_sum += second_derivatives[i];
// Check if the running sum is greater than z or if we have kept at least min_keep tokens
if (cum_sum > z && i >= min_keep) {
last_idx = i;
break;
}
}
// Resize the output vector to keep only the tokens above the tail location
candidates_p->size = last_idx;
}
// https://arxiv.org/pdf/2202.00666.pdf
// https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
void llama_sample_typical(llama_token_data_array * candidates_p, float typical_p, size_t min_keep) {
if (typical_p >= 1.0f) {
return;
}
// Compute the softmax of logits and calculate entropy
llama_sample_softmax(candidates_p);
std::span<llama_token_data> candidates(candidates_p->data, candidates_p->size);
float entropy = 0.0f;
for (const auto & candidate : candidates) {
entropy += -candidate.p * logf(candidate.p);
}
// Compute the absolute difference between negative log probability and entropy for each candidate
std::vector<float> shifted_scores;
for (const auto & candidate : candidates) {
float shifted_score = fabsf(-logf(candidate.p) - entropy);
shifted_scores.push_back(shifted_score);
}
// Sort candidates based on the shifted_scores and their corresponding indices
std::vector<size_t> indices(candidates.size());
std::iota(indices.begin(), indices.end(), 0);
std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
return shifted_scores[a] < shifted_scores[b];
});
// Compute the cumulative probabilities
float cum_sum = 0.0f;
size_t last_idx = indices.size();
for (size_t i = 0; i < indices.size(); ++i) {
size_t idx = indices[i];
cum_sum += candidates[idx].p;
// Check if the running sum is greater than typical or if we have kept at least min_keep tokens
if (cum_sum > typical_p && i >= min_keep - 1) {
last_idx = i + 1;
break;
}
}
// Resize the output vector to keep only the locally typical tokens
std::vector<llama_token_data> new_candidates;
for (size_t i = 0; i < last_idx; ++i) {
size_t idx = indices[i];
new_candidates.push_back(candidates[idx]);
}
// Replace the data in candidates_p with the new_candidates data
std::copy(new_candidates.begin(), new_candidates.end(), candidates_p->data);
candidates_p->size = new_candidates.size();
}
void llama_sample_temperature(llama_token_data_array * candidates_p, float temp) {
std::span<llama_token_data> candidates(candidates_p->data, candidates_p->size);
for (auto & candidate : candidates) {
candidate.logit /= temp;
}
}
void llama_sample_repetition_penalty(llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float penalty) {
if (last_tokens_size == 0 || penalty == 1.0f) {
return;
}
// CTRL paper: https://arxiv.org/pdf/1909.05858.pdf
std::span<llama_token_data> candidates(candidates_p->data, candidates_p->size);
std::span<llama_token> last_tokens(last_tokens_p, last_tokens_size);
for (size_t i = 0; i < candidates.size(); ++i) {
auto token_iter = std::find(last_tokens.begin(), last_tokens.end(), candidates[i].id);
if (token_iter == last_tokens.end()) {
continue;
}
// The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
// This is common fix for this problem, which is to multiply by the penalty instead of dividing.
if (candidates[i].logit <= 0) {
candidates[i].logit *= penalty;
} else {
candidates[i].logit /= penalty;
}
// But it does not penalize tokens that logits are near zero, which is a problem.
// Another solution is to convert the logits to probabilities, apply the penalty, and then convert back to logits.
// float probability = std::exp(candidates[i].logit);
// probability /= penalty;
// candidates[i].logit = std::log(probability);
}
candidates_p->sorted = false;
}
void llama_sample_frequency_and_presence_penalties(llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence) {
if (last_tokens_size == 0 || (alpha_frequency == 0.0f && alpha_presence == 0.0f)) {
return;
}
std::span<llama_token_data> candidates(candidates_p->data, candidates_p->size);
std::span<llama_token> last_tokens(last_tokens_p, last_tokens_size);
// Create a frequency map to count occurrences of each token in last_tokens
std::unordered_map<llama_token, int> token_count;
for (const auto & token : last_tokens) {
token_count[token]++;
}
// Apply frequency and presence penalties to the candidates
for (size_t i = 0; i < candidates.size(); ++i) {
auto token_iter = token_count.find(candidates[i].id);
if (token_iter == token_count.end()) {
continue;
}
int count = token_iter->second;
candidates[i].logit -= count * alpha_frequency + float(count > 0) * alpha_presence;
}
candidates_p->sorted = false;
}
/// @brief Mirostat 1.0 implementation.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
/// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
/// @param N The size of the vocabulary. This is used in the calculation of the `k` value.
/// @param k A reference to the integer variable used to store the calculated top-k value. The top-k value determines how many of the most probable tokens are considered for sampling.
/// @param mu A reference to the floating-point variable that represents the maximum cross-entropy value. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
llama_token llama_sample_mirostat(struct llama_context * ctx, llama_token_data_array * candidates_p, float tau, float eta, int m, float N, int * k, float * mu) {
// https://arxiv.org/abs/2007.14966
std::span<llama_token_data> candidates(candidates_p->data, candidates_p->size);
// printf("llama_sample_mirostat: candidates.size() = %d, m = %d, N = %f, tau = %f, eta = %f, *k = %d, *mu = %f\n", candidates.size(), m, N, tau, eta, *k, *mu);
llama_sample_softmax(candidates_p);
// Estimate s_hat using the most probable m tokens
float s_hat = 0.0;
float sum_ti_bi = 0.0;
float sum_ti_sq = 0.0;
for (int i = 0; i < m - 1 && i < candidates.size() - 1; ++i) {
float t_i = logf((i + 2) / float(i + 1));
float b_i = logf(candidates[i].p / candidates[i + 1].p);
sum_ti_bi += t_i * b_i;
sum_ti_sq += t_i * t_i;
}
s_hat = sum_ti_bi / sum_ti_sq;
// Compute k from the estimated s_hat and target surprise value
float epsilon_hat = s_hat - 1;
// printf("llama_sample_mirostat: s_hat = %f, epsilon_hat = %f, *mu = %f, N = %f\n", s_hat, epsilon_hat, *mu, N);
float new_k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(N, -epsilon_hat)), 1 / s_hat);
*k = std::min(new_k, float(candidates.size()));
// Sample the next word X using top-k sampling
// printf("llama_sample_mirostat *k = %d\n", *k);
llama_sample_top_k(candidates_p, *k);
llama_token X = llama_sample_token(ctx, candidates_p);
// Compute error as the difference between observed surprise and target surprise value
int X_idx = std::distance(candidates.begin(), std::find_if(candidates.begin(), candidates.end(), [&](const llama_token_data & candidate) {
return candidate.id == X;
}));
float observed_surprise = -log2f(candidates[X_idx].p);
float e = observed_surprise - tau;
// Update mu using the learning rate and error
*mu = *mu - eta * e;
return X;
}
llama_token llama_sample_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates_p, float tau, float eta, float * mu) {
std::span<llama_token_data> candidates(candidates_p->data, candidates_p->size);
llama_sample_softmax(candidates_p);
// Truncate the words with surprise values greater than mu
candidates_p->size = std::distance(candidates.begin(), std::find_if(candidates.begin(), candidates.end(), [&](const llama_token_data & candidate) {
return -log2f(candidate.p) > *mu;
}));
// Normalize the probabilities of the remaining words
llama_sample_softmax(candidates_p);
// Sample the next word X from the remaining words
llama_token X = llama_sample_token(ctx, candidates_p);
// Compute error as the difference between observed surprise and target surprise value
int X_idx = std::distance(candidates.begin(), std::find_if(candidates.begin(), candidates.end(), [&](const llama_token_data & candidate) {
return candidate.id == X;
}));
float observed_surprise = -log2f(candidates[X_idx].p);
float e = observed_surprise - tau;
// Update mu using the learning rate and error
*mu = *mu - eta * e;
return X;
}
llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates_p) {
// Find max element
std::span<llama_token_data> candidates(candidates_p->data, candidates_p->size);
auto max_iter = std::max_element(candidates.begin(), candidates.end(), [](const llama_token_data & a, const llama_token_data & b) {
return a.logit < b.logit;
});
llama_token result = max_iter->id;
ctx->n_sample++;
return result;
}
llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates_p) {
// const int64_t t_start_sample_us = ggml_time_us();
llama_sample_softmax(candidates_p);
std::span<llama_token_data> candidates(candidates_p->data, candidates_p->size);
// compute probs for the top k tokens
std::vector<float> probs;
probs.reserve(logits_id.size());
float maxl = logits_id[0].first;
double sum = 0.0;
for (const auto & kv : logits_id) {
const float p = expf(kv.first - maxl);
probs.push_back(p);
sum += p;
probs.reserve(candidates.size());
for (auto & candidate : candidates) {
probs.push_back(candidate.p);
}
// normalize the probs
for (auto & p : probs) {
p /= sum;
}
if (top_p < 1.0) {
double cumsum = 0.0;
for (int i = 0; i < (int) probs.size(); i++) {
cumsum += probs[i];
if (cumsum >= top_p) {
probs.resize(i + 1);
logits_id.resize(i + 1);
break;
}
}
}
//printf("\n");
//for (int i = 0; i < (int) 10; i++) {
// printf("%d: '%s' %f\n", i, lctx.vocab.id_to_token.at(logits_id[i].second).tok.c_str(), probs[i]);
//}
//printf("\n\n");
//exit(0);
std::discrete_distribution<> dist(probs.begin(), probs.end());
auto & rng = ctx->rng;
int idx = dist(rng);
return logits_id[idx].second;
llama_token result = candidates[idx].id;
// ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
ctx->n_sample++;
return result;
}
//
@ -2352,35 +2613,6 @@ llama_token llama_token_eos() {
return 2;
}
llama_token llama_sample_top_p_top_k(
llama_context * ctx,
const llama_token * last_n_tokens_data,
int last_n_tokens_size,
int top_k,
float top_p,
float temp,
float repeat_penalty) {
const int64_t t_start_sample_us = ggml_time_us();
llama_token result = 0;
// TODO: avoid this ...
const auto last_n_tokens = std::vector<llama_token>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_size);
result = llama_sample_top_p_top_k(
*ctx,
last_n_tokens,
top_k,
top_p,
temp,
repeat_penalty);
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
ctx->n_sample++;
return result;
}
void llama_print_timings(struct llama_context * ctx) {
const int64_t t_end_us = ggml_time_us();

34
llama.h
View file

@ -39,12 +39,16 @@ extern "C" {
typedef struct llama_token_data {
llama_token id; // token id
float logit; // log-odds of the token
float p; // probability of the token
float plog; // log probability of the token
} llama_token_data;
typedef struct llama_token_data_array {
llama_token_data * data;
size_t size;
bool sorted;
} llama_token_data_array;
typedef void (*llama_progress_callback)(float progress, void *ctx);
struct llama_context_params {
@ -182,15 +186,21 @@ extern "C" {
LLAMA_API llama_token llama_token_bos();
LLAMA_API llama_token llama_token_eos();
// TODO: improve the last_n_tokens interface ?
LLAMA_API llama_token llama_sample_top_p_top_k(
struct llama_context * ctx,
const llama_token * last_n_tokens_data,
int last_n_tokens_size,
int top_k,
float top_p,
float temp,
float repeat_penalty);
// Sampling functions
LLAMA_API void llama_sample_repetition_penalty(llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float penalty);
LLAMA_API void llama_sample_frequency_and_presence_penalties(llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence);
LLAMA_API void llama_sample_softmax(llama_token_data_array * candidates);
LLAMA_API void llama_sample_top_k(llama_token_data_array * candidates, int k);
LLAMA_API void llama_sample_top_p(llama_token_data_array * candidates, float p, size_t min_keep = 1);
LLAMA_API void llama_sample_tail_free(llama_token_data_array * candidates, float z, size_t min_keep = 1);
LLAMA_API void llama_sample_typical(llama_token_data_array * candidates, float p, size_t min_keep = 1);
LLAMA_API void llama_sample_temperature(llama_token_data_array * candidates, float temp);
LLAMA_API llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates);
LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);
LLAMA_API llama_token llama_sample_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float N, int * k, float * mu);
LLAMA_API llama_token llama_sample_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu);
// Performance information
LLAMA_API void llama_print_timings(struct llama_context * ctx);

View file

@ -3,9 +3,12 @@ function(llama_add_test source)
add_executable(${TEST_TARGET} ${source})
target_link_libraries(${TEST_TARGET} PRIVATE llama)
add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}> ${ARGN})
target_compile_options(${TEST_TARGET} PRIVATE -fsanitize=address)
target_link_options(${TEST_TARGET} PRIVATE -fsanitize=address)
endfunction()
# llama_add_test(test-double-float.c) # SLOW
llama_add_test(test-quantize-fns.cpp)
llama_add_test(test-quantize-perf.cpp)
llama_add_test(test-sampling.cpp)
llama_add_test(test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab.bin)