Sample interface, new samplers.
New samplers: - locally typical sampling - tail free sampling - frequency and presence penalty - mirostat Ignore EOS fix: -inf should be used.
This commit is contained in:
parent
5fba3c016b
commit
9b3b07cc5c
7 changed files with 455 additions and 136 deletions
|
@ -76,7 +76,7 @@ option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
|
|||
# Compile flags
|
||||
#
|
||||
|
||||
set(CMAKE_CXX_STANDARD 11)
|
||||
set(CMAKE_CXX_STANDARD 20)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED true)
|
||||
set(CMAKE_C_STANDARD 11)
|
||||
set(CMAKE_C_STANDARD_REQUIRED true)
|
||||
|
|
2
Makefile
2
Makefile
|
@ -35,7 +35,7 @@ endif
|
|||
|
||||
# keep standard at C11 and C++11
|
||||
CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC
|
||||
CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC
|
||||
CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++20 -fPIC
|
||||
LDFLAGS =
|
||||
|
||||
# warnings
|
||||
|
|
|
@ -114,6 +114,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||
break;
|
||||
}
|
||||
params.temp = std::stof(argv[i]);
|
||||
} else if (arg == "--tfs") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.tfs_z = std::stof(argv[i]);
|
||||
} else if (arg == "--typical") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.typical_p = std::stof(argv[i]);
|
||||
} else if (arg == "--repeat_last_n") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
|
@ -126,6 +138,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||
break;
|
||||
}
|
||||
params.repeat_penalty = std::stof(argv[i]);
|
||||
} else if (arg == "--alpha_frequency") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.alpha_frequency = std::stof(argv[i]);
|
||||
} else if (arg == "--alpha_presence") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.alpha_presence = std::stof(argv[i]);
|
||||
} else if (arg == "-b" || arg == "--batch_size") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
|
@ -242,6 +266,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
|||
fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
|
||||
fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k);
|
||||
fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", (double)params.top_p);
|
||||
fprintf(stderr, " --tfs N tail free sampling (default: %.1f)\n", (double)params.tfs_z);
|
||||
fprintf(stderr, " --typical N locally typical sampling (default: %.1f)\n", (double)params.typical_p);
|
||||
fprintf(stderr, " --alpha_presence N repeat alpha presence (default: %d)\n", params.alpha_presence);
|
||||
fprintf(stderr, " --alpha_frequency N repeat alpha frequency (default: %.1f)\n", (double)params.alpha_frequency);
|
||||
fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n);
|
||||
fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f)\n", (double)params.repeat_penalty);
|
||||
fprintf(stderr, " -c N, --ctx_size N size of the prompt context (default: %d)\n", params.n_ctx);
|
||||
|
|
|
@ -276,8 +276,8 @@ int main(int argc, char ** argv) {
|
|||
fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
|
||||
}
|
||||
}
|
||||
fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n",
|
||||
params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
|
||||
fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, alpha_presence = %f, alpha_frequency = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f\n",
|
||||
params.repeat_last_n, params.repeat_penalty, params.alpha_presence, params.alpha_frequency, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp);
|
||||
fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
|
||||
fprintf(stderr, "\n\n");
|
||||
|
||||
|
@ -387,10 +387,15 @@ int main(int argc, char ** argv) {
|
|||
|
||||
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
|
||||
// out of user input, sample next token
|
||||
const int32_t top_k = params.top_k;
|
||||
const float top_p = params.top_p;
|
||||
const float temp = params.temp;
|
||||
const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
|
||||
const float top_p = params.top_p;
|
||||
const float tfs_z = params.tfs_z;
|
||||
const float typical_p = params.typical_p;
|
||||
const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
|
||||
const float repeat_penalty = params.repeat_penalty;
|
||||
const float alpha_presence = params.alpha_presence;
|
||||
const float alpha_frequency = params.alpha_frequency;
|
||||
|
||||
// optionally save the session on first sample (for faster prompt loading next time)
|
||||
if (!path_session.empty() && need_to_save_session) {
|
||||
|
@ -402,14 +407,55 @@ int main(int argc, char ** argv) {
|
|||
|
||||
{
|
||||
auto logits = llama_get_logits(ctx);
|
||||
auto n_vocab = llama_n_vocab(ctx);
|
||||
|
||||
if (params.ignore_eos) {
|
||||
logits[llama_token_eos()] = 0;
|
||||
logits[llama_token_eos()] = -INFINITY;
|
||||
}
|
||||
|
||||
id = llama_sample_top_p_top_k(ctx,
|
||||
last_n_tokens.data() + n_ctx - params.repeat_last_n,
|
||||
params.repeat_last_n, top_k, top_p, temp, repeat_penalty);
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (size_t i = 0; i < n_vocab; i++) {
|
||||
candidates.emplace_back(i, logits[i], 0.0f);
|
||||
}
|
||||
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size() };
|
||||
|
||||
// Apply penalties
|
||||
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
|
||||
llama_sample_repetition_penalty(&candidates_p,
|
||||
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
|
||||
last_n_repeat, repeat_penalty);
|
||||
llama_sample_frequency_and_presence_penalties(&candidates_p,
|
||||
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
|
||||
last_n_repeat, alpha_frequency, alpha_presence);
|
||||
|
||||
|
||||
#if 1
|
||||
if (temp <= 0) {
|
||||
// Greedy sampling
|
||||
id = llama_sample_token_greedy(ctx, &candidates_p);
|
||||
} else {
|
||||
// Temperature sampling
|
||||
llama_sample_top_k(&candidates_p, top_k);
|
||||
llama_sample_tail_free(&candidates_p, tfs_z);
|
||||
llama_sample_typical(&candidates_p, typical_p);
|
||||
llama_sample_top_p(&candidates_p, top_p);
|
||||
|
||||
llama_sample_temperature(&candidates_p, temp);
|
||||
// printf("`%d`", candidates_p.size);
|
||||
id = llama_sample_token(ctx, &candidates_p);
|
||||
}
|
||||
#else
|
||||
const float tau = 5.0f;
|
||||
static float mu = 2.0f * tau;
|
||||
static int k = 40;
|
||||
const float eta = 0.1f;
|
||||
const int m = 100;
|
||||
const float N = n_vocab;
|
||||
id = llama_sample_mirostat(ctx, &candidates_p, tau, eta, m, N, &k, &mu);
|
||||
// id = llama_sample_mirostat_v2(ctx, &candidates_p, tau, eta, &mu);
|
||||
#endif
|
||||
|
||||
last_n_tokens.erase(last_n_tokens.begin());
|
||||
last_n_tokens.push_back(id);
|
||||
|
|
460
llama.cpp
460
llama.cpp
|
@ -28,6 +28,7 @@
|
|||
#include <atomic>
|
||||
#include <mutex>
|
||||
#include <sstream>
|
||||
#include <span>
|
||||
|
||||
#define LLAMA_USE_SCRATCH
|
||||
#define LLAMA_MAX_SCRATCH_BUFFERS 16
|
||||
|
@ -1478,109 +1479,369 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
|
|||
// sampling
|
||||
//
|
||||
|
||||
static void sample_top_k(std::vector<std::pair<float, llama_vocab::id>> & logits_id, int top_k) {
|
||||
// find the top k tokens
|
||||
std::partial_sort(
|
||||
logits_id.begin(),
|
||||
logits_id.begin() + top_k, logits_id.end(),
|
||||
[](const std::pair<float, llama_vocab::id> & a, const std::pair<float, llama_vocab::id> & b) {
|
||||
return a.first > b.first;
|
||||
});
|
||||
void llama_sample_softmax(llama_token_data_array * candidates) {
|
||||
assert(candidates->size > 0);
|
||||
std::span<llama_token_data> tokens(candidates->data, candidates->size);
|
||||
|
||||
logits_id.resize(top_k);
|
||||
// Sort the logits in descending order
|
||||
if (!candidates->sorted) {
|
||||
std::sort(tokens.begin(), tokens.end(), [](const llama_token_data & a, const llama_token_data & b) {
|
||||
return a.logit > b.logit;
|
||||
});
|
||||
candidates->sorted = true;
|
||||
}
|
||||
|
||||
float max_l = tokens[0].logit;
|
||||
float cum_sum = 0.0f;
|
||||
for (size_t i = 0; i < tokens.size(); ++i) {
|
||||
// printf("llama_sample_softmax: i: %d, logit: %f\n", i, tokens[i].logit);
|
||||
float p = expf(tokens[i].logit - max_l);
|
||||
tokens[i].p = p;
|
||||
cum_sum += p;
|
||||
}
|
||||
for (size_t i = 0; i < tokens.size(); ++i) {
|
||||
tokens[i].p /= cum_sum;
|
||||
}
|
||||
}
|
||||
|
||||
static llama_vocab::id llama_sample_top_p_top_k(
|
||||
llama_context & lctx,
|
||||
const std::vector<llama_vocab::id> & last_n_tokens,
|
||||
int top_k,
|
||||
float top_p,
|
||||
float temp,
|
||||
float repeat_penalty) {
|
||||
auto & rng = lctx.rng;
|
||||
void llama_sample_top_k(llama_token_data_array * candidates_p, int k) {
|
||||
assert(k > 0);
|
||||
std::span<llama_token_data> candidates(candidates_p->data, candidates_p->size);
|
||||
|
||||
const int n_logits = lctx.model.hparams.n_vocab;
|
||||
|
||||
const auto & logits = lctx.logits;
|
||||
const auto * plogits = logits.data() + logits.size() - n_logits;
|
||||
|
||||
if (temp <= 0) {
|
||||
// select the token with the highest logit directly
|
||||
float max_logit = plogits[0];
|
||||
llama_vocab::id max_id = 0;
|
||||
|
||||
for (int i = 1; i < n_logits; ++i) {
|
||||
if (plogits[i] > max_logit) {
|
||||
max_logit = plogits[i];
|
||||
max_id = i;
|
||||
}
|
||||
// Sort scores in descending order
|
||||
if (!candidates_p->sorted) {
|
||||
if (k >= candidates_p->size) {
|
||||
std::sort(candidates.begin(), candidates.end(), [](const llama_token_data & a, const llama_token_data & b) {
|
||||
return a.logit > b.logit;
|
||||
});
|
||||
} else {
|
||||
std::partial_sort(candidates.begin(), candidates.begin() + k, candidates.end(),
|
||||
[](const llama_token_data & a, const llama_token_data & b) {
|
||||
return a.logit > b.logit;
|
||||
});
|
||||
}
|
||||
return max_id;
|
||||
candidates_p->sorted = true;
|
||||
}
|
||||
candidates_p->size = std::min(k, (int) candidates.size());
|
||||
}
|
||||
|
||||
void llama_sample_top_p(llama_token_data_array * candidates_p, float p, size_t min_keep) {
|
||||
if (p >= 1.0f) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<std::pair<float, llama_vocab::id>> logits_id;
|
||||
logits_id.reserve(n_logits);
|
||||
llama_sample_softmax(candidates_p);
|
||||
|
||||
{
|
||||
const float scale = 1.0f/temp;
|
||||
for (int i = 0; i < n_logits; ++i) {
|
||||
// repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858)
|
||||
// credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
|
||||
if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
|
||||
// if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
|
||||
if (plogits[i] < 0.0f) {
|
||||
logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i));
|
||||
} else {
|
||||
logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i));
|
||||
}
|
||||
} else {
|
||||
logits_id.push_back(std::make_pair(plogits[i]*scale, i));
|
||||
}
|
||||
// Compute the cumulative probabilities
|
||||
float cum_sum = 0.0f;
|
||||
size_t last_idx = candidates_p->size;
|
||||
|
||||
for (size_t i = 0; i < candidates_p->size; ++i) {
|
||||
cum_sum += candidates_p->data[i].p;
|
||||
|
||||
// Check if the running sum is greater than p or if we have kept at least min_keep tokens
|
||||
if (cum_sum > p && i >= min_keep) {
|
||||
last_idx = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
sample_top_k(logits_id, top_k > 0 ? std::min(top_k, n_logits) : n_logits);
|
||||
// Resize the output vector to keep only the top-p tokens
|
||||
candidates_p->size = last_idx;
|
||||
}
|
||||
|
||||
// https://www.trentonbricken.com/Tail-Free-Sampling/
|
||||
void llama_sample_tail_free(llama_token_data_array * candidates_p, float z, size_t min_keep) {
|
||||
if (z >= 1.0f || candidates_p->size <= 2) {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_sample_softmax(candidates_p);
|
||||
|
||||
// Compute the first and second derivatives
|
||||
std::vector<float> first_derivatives(candidates_p->size - 1);
|
||||
std::vector<float> second_derivatives(candidates_p->size - 2);
|
||||
|
||||
for (size_t i = 0; i < first_derivatives.size(); ++i) {
|
||||
first_derivatives[i] = candidates_p->data[i].p - candidates_p->data[i + 1].p;
|
||||
}
|
||||
for (size_t i = 0; i < second_derivatives.size(); ++i) {
|
||||
second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1];
|
||||
}
|
||||
|
||||
// Calculate absolute value of second derivatives
|
||||
for (size_t i = 0; i < second_derivatives.size(); ++i) {
|
||||
second_derivatives[i] = abs(second_derivatives[i]);
|
||||
}
|
||||
|
||||
// Normalize the second derivatives
|
||||
float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f);
|
||||
for (float & value : second_derivatives) {
|
||||
value /= second_derivatives_sum;
|
||||
}
|
||||
|
||||
float cum_sum = 0.0f;
|
||||
size_t last_idx = candidates_p->size;
|
||||
for (size_t i = 0; i < second_derivatives.size(); ++i) {
|
||||
cum_sum += second_derivatives[i];
|
||||
|
||||
// Check if the running sum is greater than z or if we have kept at least min_keep tokens
|
||||
if (cum_sum > z && i >= min_keep) {
|
||||
last_idx = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Resize the output vector to keep only the tokens above the tail location
|
||||
candidates_p->size = last_idx;
|
||||
}
|
||||
|
||||
|
||||
// https://arxiv.org/pdf/2202.00666.pdf
|
||||
// https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
|
||||
void llama_sample_typical(llama_token_data_array * candidates_p, float typical_p, size_t min_keep) {
|
||||
if (typical_p >= 1.0f) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Compute the softmax of logits and calculate entropy
|
||||
llama_sample_softmax(candidates_p);
|
||||
|
||||
std::span<llama_token_data> candidates(candidates_p->data, candidates_p->size);
|
||||
|
||||
float entropy = 0.0f;
|
||||
for (const auto & candidate : candidates) {
|
||||
entropy += -candidate.p * logf(candidate.p);
|
||||
}
|
||||
|
||||
// Compute the absolute difference between negative log probability and entropy for each candidate
|
||||
std::vector<float> shifted_scores;
|
||||
for (const auto & candidate : candidates) {
|
||||
float shifted_score = fabsf(-logf(candidate.p) - entropy);
|
||||
shifted_scores.push_back(shifted_score);
|
||||
}
|
||||
|
||||
// Sort candidates based on the shifted_scores and their corresponding indices
|
||||
std::vector<size_t> indices(candidates.size());
|
||||
std::iota(indices.begin(), indices.end(), 0);
|
||||
|
||||
std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
|
||||
return shifted_scores[a] < shifted_scores[b];
|
||||
});
|
||||
|
||||
// Compute the cumulative probabilities
|
||||
float cum_sum = 0.0f;
|
||||
size_t last_idx = indices.size();
|
||||
|
||||
for (size_t i = 0; i < indices.size(); ++i) {
|
||||
size_t idx = indices[i];
|
||||
cum_sum += candidates[idx].p;
|
||||
|
||||
// Check if the running sum is greater than typical or if we have kept at least min_keep tokens
|
||||
if (cum_sum > typical_p && i >= min_keep - 1) {
|
||||
last_idx = i + 1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Resize the output vector to keep only the locally typical tokens
|
||||
std::vector<llama_token_data> new_candidates;
|
||||
for (size_t i = 0; i < last_idx; ++i) {
|
||||
size_t idx = indices[i];
|
||||
new_candidates.push_back(candidates[idx]);
|
||||
}
|
||||
|
||||
// Replace the data in candidates_p with the new_candidates data
|
||||
std::copy(new_candidates.begin(), new_candidates.end(), candidates_p->data);
|
||||
candidates_p->size = new_candidates.size();
|
||||
}
|
||||
|
||||
|
||||
void llama_sample_temperature(llama_token_data_array * candidates_p, float temp) {
|
||||
std::span<llama_token_data> candidates(candidates_p->data, candidates_p->size);
|
||||
for (auto & candidate : candidates) {
|
||||
candidate.logit /= temp;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_repetition_penalty(llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float penalty) {
|
||||
if (last_tokens_size == 0 || penalty == 1.0f) {
|
||||
return;
|
||||
}
|
||||
|
||||
// CTRL paper: https://arxiv.org/pdf/1909.05858.pdf
|
||||
std::span<llama_token_data> candidates(candidates_p->data, candidates_p->size);
|
||||
std::span<llama_token> last_tokens(last_tokens_p, last_tokens_size);
|
||||
|
||||
for (size_t i = 0; i < candidates.size(); ++i) {
|
||||
auto token_iter = std::find(last_tokens.begin(), last_tokens.end(), candidates[i].id);
|
||||
if (token_iter == last_tokens.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
|
||||
// This is common fix for this problem, which is to multiply by the penalty instead of dividing.
|
||||
if (candidates[i].logit <= 0) {
|
||||
candidates[i].logit *= penalty;
|
||||
} else {
|
||||
candidates[i].logit /= penalty;
|
||||
}
|
||||
|
||||
// But it does not penalize tokens that logits are near zero, which is a problem.
|
||||
// Another solution is to convert the logits to probabilities, apply the penalty, and then convert back to logits.
|
||||
// float probability = std::exp(candidates[i].logit);
|
||||
// probability /= penalty;
|
||||
// candidates[i].logit = std::log(probability);
|
||||
}
|
||||
|
||||
candidates_p->sorted = false;
|
||||
}
|
||||
|
||||
void llama_sample_frequency_and_presence_penalties(llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence) {
|
||||
if (last_tokens_size == 0 || (alpha_frequency == 0.0f && alpha_presence == 0.0f)) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::span<llama_token_data> candidates(candidates_p->data, candidates_p->size);
|
||||
std::span<llama_token> last_tokens(last_tokens_p, last_tokens_size);
|
||||
|
||||
// Create a frequency map to count occurrences of each token in last_tokens
|
||||
std::unordered_map<llama_token, int> token_count;
|
||||
for (const auto & token : last_tokens) {
|
||||
token_count[token]++;
|
||||
}
|
||||
|
||||
// Apply frequency and presence penalties to the candidates
|
||||
for (size_t i = 0; i < candidates.size(); ++i) {
|
||||
auto token_iter = token_count.find(candidates[i].id);
|
||||
if (token_iter == token_count.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int count = token_iter->second;
|
||||
candidates[i].logit -= count * alpha_frequency + float(count > 0) * alpha_presence;
|
||||
}
|
||||
|
||||
candidates_p->sorted = false;
|
||||
}
|
||||
|
||||
/// @brief Mirostat 1.0 implementation.
|
||||
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
||||
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
|
||||
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
|
||||
/// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
|
||||
/// @param N The size of the vocabulary. This is used in the calculation of the `k` value.
|
||||
/// @param k A reference to the integer variable used to store the calculated top-k value. The top-k value determines how many of the most probable tokens are considered for sampling.
|
||||
/// @param mu A reference to the floating-point variable that represents the maximum cross-entropy value. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
|
||||
llama_token llama_sample_mirostat(struct llama_context * ctx, llama_token_data_array * candidates_p, float tau, float eta, int m, float N, int * k, float * mu) {
|
||||
// https://arxiv.org/abs/2007.14966
|
||||
std::span<llama_token_data> candidates(candidates_p->data, candidates_p->size);
|
||||
|
||||
// printf("llama_sample_mirostat: candidates.size() = %d, m = %d, N = %f, tau = %f, eta = %f, *k = %d, *mu = %f\n", candidates.size(), m, N, tau, eta, *k, *mu);
|
||||
|
||||
llama_sample_softmax(candidates_p);
|
||||
|
||||
// Estimate s_hat using the most probable m tokens
|
||||
float s_hat = 0.0;
|
||||
float sum_ti_bi = 0.0;
|
||||
float sum_ti_sq = 0.0;
|
||||
for (int i = 0; i < m - 1 && i < candidates.size() - 1; ++i) {
|
||||
float t_i = logf((i + 2) / float(i + 1));
|
||||
float b_i = logf(candidates[i].p / candidates[i + 1].p);
|
||||
sum_ti_bi += t_i * b_i;
|
||||
sum_ti_sq += t_i * t_i;
|
||||
}
|
||||
s_hat = sum_ti_bi / sum_ti_sq;
|
||||
|
||||
// Compute k from the estimated s_hat and target surprise value
|
||||
float epsilon_hat = s_hat - 1;
|
||||
// printf("llama_sample_mirostat: s_hat = %f, epsilon_hat = %f, *mu = %f, N = %f\n", s_hat, epsilon_hat, *mu, N);
|
||||
float new_k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(N, -epsilon_hat)), 1 / s_hat);
|
||||
*k = std::min(new_k, float(candidates.size()));
|
||||
|
||||
// Sample the next word X using top-k sampling
|
||||
// printf("llama_sample_mirostat *k = %d\n", *k);
|
||||
llama_sample_top_k(candidates_p, *k);
|
||||
llama_token X = llama_sample_token(ctx, candidates_p);
|
||||
|
||||
// Compute error as the difference between observed surprise and target surprise value
|
||||
int X_idx = std::distance(candidates.begin(), std::find_if(candidates.begin(), candidates.end(), [&](const llama_token_data & candidate) {
|
||||
return candidate.id == X;
|
||||
}));
|
||||
float observed_surprise = -log2f(candidates[X_idx].p);
|
||||
float e = observed_surprise - tau;
|
||||
|
||||
// Update mu using the learning rate and error
|
||||
*mu = *mu - eta * e;
|
||||
|
||||
return X;
|
||||
}
|
||||
|
||||
llama_token llama_sample_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates_p, float tau, float eta, float * mu) {
|
||||
std::span<llama_token_data> candidates(candidates_p->data, candidates_p->size);
|
||||
|
||||
llama_sample_softmax(candidates_p);
|
||||
|
||||
// Truncate the words with surprise values greater than mu
|
||||
candidates_p->size = std::distance(candidates.begin(), std::find_if(candidates.begin(), candidates.end(), [&](const llama_token_data & candidate) {
|
||||
return -log2f(candidate.p) > *mu;
|
||||
}));
|
||||
|
||||
// Normalize the probabilities of the remaining words
|
||||
llama_sample_softmax(candidates_p);
|
||||
|
||||
// Sample the next word X from the remaining words
|
||||
llama_token X = llama_sample_token(ctx, candidates_p);
|
||||
|
||||
// Compute error as the difference between observed surprise and target surprise value
|
||||
int X_idx = std::distance(candidates.begin(), std::find_if(candidates.begin(), candidates.end(), [&](const llama_token_data & candidate) {
|
||||
return candidate.id == X;
|
||||
}));
|
||||
float observed_surprise = -log2f(candidates[X_idx].p);
|
||||
float e = observed_surprise - tau;
|
||||
|
||||
// Update mu using the learning rate and error
|
||||
*mu = *mu - eta * e;
|
||||
|
||||
return X;
|
||||
}
|
||||
|
||||
|
||||
llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates_p) {
|
||||
// Find max element
|
||||
std::span<llama_token_data> candidates(candidates_p->data, candidates_p->size);
|
||||
auto max_iter = std::max_element(candidates.begin(), candidates.end(), [](const llama_token_data & a, const llama_token_data & b) {
|
||||
return a.logit < b.logit;
|
||||
});
|
||||
|
||||
llama_token result = max_iter->id;
|
||||
ctx->n_sample++;
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates_p) {
|
||||
// const int64_t t_start_sample_us = ggml_time_us();
|
||||
llama_sample_softmax(candidates_p);
|
||||
|
||||
std::span<llama_token_data> candidates(candidates_p->data, candidates_p->size);
|
||||
|
||||
// compute probs for the top k tokens
|
||||
std::vector<float> probs;
|
||||
probs.reserve(logits_id.size());
|
||||
|
||||
float maxl = logits_id[0].first;
|
||||
double sum = 0.0;
|
||||
for (const auto & kv : logits_id) {
|
||||
const float p = expf(kv.first - maxl);
|
||||
probs.push_back(p);
|
||||
sum += p;
|
||||
probs.reserve(candidates.size());
|
||||
for (auto & candidate : candidates) {
|
||||
probs.push_back(candidate.p);
|
||||
}
|
||||
|
||||
// normalize the probs
|
||||
for (auto & p : probs) {
|
||||
p /= sum;
|
||||
}
|
||||
|
||||
if (top_p < 1.0) {
|
||||
double cumsum = 0.0;
|
||||
for (int i = 0; i < (int) probs.size(); i++) {
|
||||
cumsum += probs[i];
|
||||
if (cumsum >= top_p) {
|
||||
probs.resize(i + 1);
|
||||
logits_id.resize(i + 1);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//printf("\n");
|
||||
//for (int i = 0; i < (int) 10; i++) {
|
||||
// printf("%d: '%s' %f\n", i, lctx.vocab.id_to_token.at(logits_id[i].second).tok.c_str(), probs[i]);
|
||||
//}
|
||||
//printf("\n\n");
|
||||
//exit(0);
|
||||
|
||||
std::discrete_distribution<> dist(probs.begin(), probs.end());
|
||||
auto & rng = ctx->rng;
|
||||
int idx = dist(rng);
|
||||
|
||||
return logits_id[idx].second;
|
||||
llama_token result = candidates[idx].id;
|
||||
|
||||
// ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
ctx->n_sample++;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
//
|
||||
|
@ -2352,35 +2613,6 @@ llama_token llama_token_eos() {
|
|||
return 2;
|
||||
}
|
||||
|
||||
llama_token llama_sample_top_p_top_k(
|
||||
llama_context * ctx,
|
||||
const llama_token * last_n_tokens_data,
|
||||
int last_n_tokens_size,
|
||||
int top_k,
|
||||
float top_p,
|
||||
float temp,
|
||||
float repeat_penalty) {
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
llama_token result = 0;
|
||||
|
||||
// TODO: avoid this ...
|
||||
const auto last_n_tokens = std::vector<llama_token>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_size);
|
||||
|
||||
result = llama_sample_top_p_top_k(
|
||||
*ctx,
|
||||
last_n_tokens,
|
||||
top_k,
|
||||
top_p,
|
||||
temp,
|
||||
repeat_penalty);
|
||||
|
||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
ctx->n_sample++;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
void llama_print_timings(struct llama_context * ctx) {
|
||||
const int64_t t_end_us = ggml_time_us();
|
||||
|
|
34
llama.h
34
llama.h
|
@ -39,12 +39,16 @@ extern "C" {
|
|||
|
||||
typedef struct llama_token_data {
|
||||
llama_token id; // token id
|
||||
|
||||
float logit; // log-odds of the token
|
||||
float p; // probability of the token
|
||||
float plog; // log probability of the token
|
||||
|
||||
} llama_token_data;
|
||||
|
||||
typedef struct llama_token_data_array {
|
||||
llama_token_data * data;
|
||||
size_t size;
|
||||
bool sorted;
|
||||
} llama_token_data_array;
|
||||
|
||||
typedef void (*llama_progress_callback)(float progress, void *ctx);
|
||||
|
||||
struct llama_context_params {
|
||||
|
@ -182,15 +186,21 @@ extern "C" {
|
|||
LLAMA_API llama_token llama_token_bos();
|
||||
LLAMA_API llama_token llama_token_eos();
|
||||
|
||||
// TODO: improve the last_n_tokens interface ?
|
||||
LLAMA_API llama_token llama_sample_top_p_top_k(
|
||||
struct llama_context * ctx,
|
||||
const llama_token * last_n_tokens_data,
|
||||
int last_n_tokens_size,
|
||||
int top_k,
|
||||
float top_p,
|
||||
float temp,
|
||||
float repeat_penalty);
|
||||
// Sampling functions
|
||||
LLAMA_API void llama_sample_repetition_penalty(llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float penalty);
|
||||
LLAMA_API void llama_sample_frequency_and_presence_penalties(llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence);
|
||||
|
||||
LLAMA_API void llama_sample_softmax(llama_token_data_array * candidates);
|
||||
LLAMA_API void llama_sample_top_k(llama_token_data_array * candidates, int k);
|
||||
LLAMA_API void llama_sample_top_p(llama_token_data_array * candidates, float p, size_t min_keep = 1);
|
||||
LLAMA_API void llama_sample_tail_free(llama_token_data_array * candidates, float z, size_t min_keep = 1);
|
||||
LLAMA_API void llama_sample_typical(llama_token_data_array * candidates, float p, size_t min_keep = 1);
|
||||
LLAMA_API void llama_sample_temperature(llama_token_data_array * candidates, float temp);
|
||||
|
||||
LLAMA_API llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates);
|
||||
LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);
|
||||
LLAMA_API llama_token llama_sample_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float N, int * k, float * mu);
|
||||
LLAMA_API llama_token llama_sample_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu);
|
||||
|
||||
// Performance information
|
||||
LLAMA_API void llama_print_timings(struct llama_context * ctx);
|
||||
|
|
|
@ -3,9 +3,12 @@ function(llama_add_test source)
|
|||
add_executable(${TEST_TARGET} ${source})
|
||||
target_link_libraries(${TEST_TARGET} PRIVATE llama)
|
||||
add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}> ${ARGN})
|
||||
target_compile_options(${TEST_TARGET} PRIVATE -fsanitize=address)
|
||||
target_link_options(${TEST_TARGET} PRIVATE -fsanitize=address)
|
||||
endfunction()
|
||||
|
||||
# llama_add_test(test-double-float.c) # SLOW
|
||||
llama_add_test(test-quantize-fns.cpp)
|
||||
llama_add_test(test-quantize-perf.cpp)
|
||||
llama_add_test(test-sampling.cpp)
|
||||
llama_add_test(test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab.bin)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue