mirostat
This commit is contained in:
parent
9b3b07cc5c
commit
f01c67fe55
7 changed files with 425 additions and 91 deletions
|
@ -150,6 +150,24 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||
break;
|
||||
}
|
||||
params.alpha_presence = std::stof(argv[i]);
|
||||
} else if (arg == "--mirostat") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.mirostat = std::stoi(argv[i]);
|
||||
} else if (arg == "--mirostat_eta") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.mirostat_eta = std::stof(argv[i]);
|
||||
} else if (arg == "--mirostat_tau") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.mirostat_tau = std::stof(argv[i]);
|
||||
} else if (arg == "-b" || arg == "--batch_size") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
|
@ -264,14 +282,17 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
|||
fprintf(stderr, " -f FNAME, --file FNAME\n");
|
||||
fprintf(stderr, " prompt file to start generation.\n");
|
||||
fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
|
||||
fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k);
|
||||
fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", (double)params.top_p);
|
||||
fprintf(stderr, " --tfs N tail free sampling (default: %.1f)\n", (double)params.tfs_z);
|
||||
fprintf(stderr, " --typical N locally typical sampling (default: %.1f)\n", (double)params.typical_p);
|
||||
fprintf(stderr, " --alpha_presence N repeat alpha presence (default: %d)\n", params.alpha_presence);
|
||||
fprintf(stderr, " --alpha_frequency N repeat alpha frequency (default: %.1f)\n", (double)params.alpha_frequency);
|
||||
fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n);
|
||||
fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f)\n", (double)params.repeat_penalty);
|
||||
fprintf(stderr, " --top_k N top-k sampling (default: %d, disabled: 0)\n", params.top_k);
|
||||
fprintf(stderr, " --top_p N top-p sampling (default: %.1f, disabled: 1.0)\n", (double)params.top_p);
|
||||
fprintf(stderr, " --tfs N tail free sampling, parameter z (default: %.1f, disabled: 1.0)\n", (double)params.tfs_z);
|
||||
fprintf(stderr, " --typical N locally typical sampling, parameter p (default: %.1f, disabled: 1.0)\n", (double)params.typical_p);
|
||||
fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d, disabled: 0)\n", params.repeat_last_n);
|
||||
fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f, disabled: 1.0)\n", (double)params.repeat_penalty);
|
||||
fprintf(stderr, " --alpha_presence N repeat alpha presence (default: %.1f, disabled: 0.0)\n", (double)params.alpha_presence);
|
||||
fprintf(stderr, " --alpha_frequency N repeat alpha frequency (default: %.1f, disabled: 0.0)\n", (double)params.alpha_frequency);
|
||||
fprintf(stderr, " --mirostat N use mirostat sampling (default: %d, disabled: 0, mirostat: 1, mirostat 2.0: 2)\n", params.mirostat);
|
||||
fprintf(stderr, " --mirostat_eta N mirostat learning rate (default: %.1f)\n", (double)params.mirostat_eta);
|
||||
fprintf(stderr, " --mirostat_tau N mirostat target entropy (default: %.1f)\n", (double)params.mirostat_tau);
|
||||
fprintf(stderr, " -c N, --ctx_size N size of the prompt context (default: %d)\n", params.n_ctx);
|
||||
fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating\n");
|
||||
fprintf(stderr, " --memory_f32 use f32 instead of f16 for memory key+value\n");
|
||||
|
|
|
@ -17,17 +17,24 @@ struct gpt_params {
|
|||
int32_t seed = -1; // RNG seed
|
||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
int32_t n_predict = 128; // new tokens to predict
|
||||
int32_t repeat_last_n = 64; // last n tokens to penalize
|
||||
int32_t n_parts = -1; // amount of model parts (-1 = determine from model dimensions)
|
||||
int32_t n_ctx = 512; // context size
|
||||
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
|
||||
int32_t n_keep = 0; // number of tokens to keep from initial prompt
|
||||
|
||||
// sampling parameters
|
||||
int32_t top_k = 40;
|
||||
float top_p = 0.95f;
|
||||
float temp = 0.80f;
|
||||
float repeat_penalty = 1.10f;
|
||||
int32_t top_k = 0; // <= 0 to use vocab size
|
||||
float top_p = 1.0f; // 1.0 = disabled
|
||||
float tfs_z = 1.0f; // 1.0 = disabled
|
||||
float typical_p = 1.0f; // 1.0 = disabled
|
||||
float temp = 1.0f; // 1.0 = disabled
|
||||
float repeat_penalty = 1.0f; // 1.0 = disabled
|
||||
int32_t repeat_last_n = -1; // last n tokens to penalize (0 = disable penalty, -1 = context size)
|
||||
float alpha_frequency = 0.0f; // 0.0 = disabled
|
||||
float alpha_presence = 0.0f; // 0.0 = disabled
|
||||
int mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
||||
float mirostat_tau = 5.0f; // target entropy
|
||||
float mirostat_eta = 0.1f; // learning rate
|
||||
|
||||
std::string model = "models/lamma-7B/ggml-model.bin"; // model path
|
||||
std::string prompt = "";
|
||||
|
|
|
@ -276,8 +276,8 @@ int main(int argc, char ** argv) {
|
|||
fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
|
||||
}
|
||||
}
|
||||
fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, alpha_presence = %f, alpha_frequency = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f\n",
|
||||
params.repeat_last_n, params.repeat_penalty, params.alpha_presence, params.alpha_frequency, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp);
|
||||
fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, alpha_presence = %f, alpha_frequency = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_eta = %f, mirostat_tau = %f\n",
|
||||
params.repeat_last_n, params.repeat_penalty, params.alpha_presence, params.alpha_frequency, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau);
|
||||
fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
|
||||
fprintf(stderr, "\n\n");
|
||||
|
||||
|
@ -396,6 +396,9 @@ int main(int argc, char ** argv) {
|
|||
const float repeat_penalty = params.repeat_penalty;
|
||||
const float alpha_presence = params.alpha_presence;
|
||||
const float alpha_frequency = params.alpha_frequency;
|
||||
const int mirostat = params.mirostat;
|
||||
const float mirostat_tau = params.mirostat_tau;
|
||||
const float mirostat_eta = params.mirostat_eta;
|
||||
|
||||
// optionally save the session on first sample (for faster prompt loading next time)
|
||||
if (!path_session.empty() && need_to_save_session) {
|
||||
|
@ -415,47 +418,45 @@ int main(int argc, char ** argv) {
|
|||
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (size_t i = 0; i < n_vocab; i++) {
|
||||
for (size_t i = 0; i < (size_t) n_vocab; i++) {
|
||||
candidates.emplace_back(i, logits[i], 0.0f);
|
||||
}
|
||||
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size() };
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
|
||||
// Apply penalties
|
||||
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
|
||||
llama_sample_repetition_penalty(&candidates_p,
|
||||
llama_sample_repetition_penalty(ctx, &candidates_p,
|
||||
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
|
||||
last_n_repeat, repeat_penalty);
|
||||
llama_sample_frequency_and_presence_penalties(&candidates_p,
|
||||
llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
|
||||
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
|
||||
last_n_repeat, alpha_frequency, alpha_presence);
|
||||
|
||||
|
||||
#if 1
|
||||
if (temp <= 0) {
|
||||
// Greedy sampling
|
||||
id = llama_sample_token_greedy(ctx, &candidates_p);
|
||||
} else {
|
||||
// Temperature sampling
|
||||
llama_sample_top_k(&candidates_p, top_k);
|
||||
llama_sample_tail_free(&candidates_p, tfs_z);
|
||||
llama_sample_typical(&candidates_p, typical_p);
|
||||
llama_sample_top_p(&candidates_p, top_p);
|
||||
|
||||
llama_sample_temperature(&candidates_p, temp);
|
||||
// printf("`%d`", candidates_p.size);
|
||||
id = llama_sample_token(ctx, &candidates_p);
|
||||
if (mirostat == 1) {
|
||||
static float mirostat_mu = 2.0f * mirostat_tau;
|
||||
static int mirostat_k = 40;
|
||||
const int mirostat_m = 100;
|
||||
id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, float(n_vocab), &mirostat_k, &mirostat_mu);
|
||||
} else if (mirostat == 2) {
|
||||
static float mirostat_mu = 2.0f * mirostat_tau;
|
||||
id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
|
||||
} else {
|
||||
// Temperature sampling
|
||||
llama_sample_top_k(ctx, &candidates_p, top_k);
|
||||
llama_sample_tail_free(ctx, &candidates_p, tfs_z);
|
||||
llama_sample_typical(ctx, &candidates_p, typical_p);
|
||||
llama_sample_top_p(ctx, &candidates_p, top_p);
|
||||
llama_sample_temperature(ctx, &candidates_p, temp);
|
||||
id = llama_sample_token(ctx, &candidates_p);
|
||||
}
|
||||
}
|
||||
#else
|
||||
const float tau = 5.0f;
|
||||
static float mu = 2.0f * tau;
|
||||
static int k = 40;
|
||||
const float eta = 0.1f;
|
||||
const int m = 100;
|
||||
const float N = n_vocab;
|
||||
id = llama_sample_mirostat(ctx, &candidates_p, tau, eta, m, N, &k, &mu);
|
||||
// id = llama_sample_mirostat_v2(ctx, &candidates_p, tau, eta, &mu);
|
||||
#endif
|
||||
// printf("`%d`", candidates_p.size);
|
||||
|
||||
last_n_tokens.erase(last_n_tokens.begin());
|
||||
last_n_tokens.push_back(id);
|
||||
|
|
158
llama.cpp
158
llama.cpp
|
@ -1479,8 +1479,11 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
|
|||
// sampling
|
||||
//
|
||||
|
||||
void llama_sample_softmax(llama_token_data_array * candidates) {
|
||||
void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) {
|
||||
assert(candidates->size > 0);
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
std::span<llama_token_data> tokens(candidates->data, candidates->size);
|
||||
|
||||
// Sort the logits in descending order
|
||||
|
@ -1502,35 +1505,47 @@ void llama_sample_softmax(llama_token_data_array * candidates) {
|
|||
for (size_t i = 0; i < tokens.size(); ++i) {
|
||||
tokens[i].p /= cum_sum;
|
||||
}
|
||||
|
||||
if (ctx) {
|
||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_top_k(llama_token_data_array * candidates_p, int k) {
|
||||
assert(k > 0);
|
||||
void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates_p, int k, size_t min_keep) {
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
k = std::max(k, (int) min_keep);
|
||||
k = std::min(k, (int) candidates_p->size);
|
||||
|
||||
std::span<llama_token_data> candidates(candidates_p->data, candidates_p->size);
|
||||
|
||||
// Sort scores in descending order
|
||||
if (!candidates_p->sorted) {
|
||||
if (k >= candidates_p->size) {
|
||||
std::sort(candidates.begin(), candidates.end(), [](const llama_token_data & a, const llama_token_data & b) {
|
||||
return a.logit > b.logit;
|
||||
});
|
||||
auto comp = [](const llama_token_data & a, const llama_token_data & b) {
|
||||
return a.logit > b.logit;
|
||||
};
|
||||
if (k == (int) candidates_p->size) {
|
||||
std::sort(candidates.begin(), candidates.end(), comp);
|
||||
} else {
|
||||
std::partial_sort(candidates.begin(), candidates.begin() + k, candidates.end(),
|
||||
[](const llama_token_data & a, const llama_token_data & b) {
|
||||
return a.logit > b.logit;
|
||||
});
|
||||
std::partial_sort(candidates.begin(), candidates.begin() + k, candidates.end(), comp);
|
||||
}
|
||||
candidates_p->sorted = true;
|
||||
}
|
||||
candidates_p->size = std::min(k, (int) candidates.size());
|
||||
candidates_p->size = k;
|
||||
|
||||
if (ctx) {
|
||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_top_p(llama_token_data_array * candidates_p, float p, size_t min_keep) {
|
||||
void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates_p, float p, size_t min_keep) {
|
||||
if (p >= 1.0f) {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_sample_softmax(candidates_p);
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
llama_sample_softmax(ctx, candidates_p);
|
||||
|
||||
// Compute the cumulative probabilities
|
||||
float cum_sum = 0.0f;
|
||||
|
@ -1548,15 +1563,21 @@ void llama_sample_top_p(llama_token_data_array * candidates_p, float p, size_t m
|
|||
|
||||
// Resize the output vector to keep only the top-p tokens
|
||||
candidates_p->size = last_idx;
|
||||
|
||||
if (ctx) {
|
||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
// https://www.trentonbricken.com/Tail-Free-Sampling/
|
||||
void llama_sample_tail_free(llama_token_data_array * candidates_p, float z, size_t min_keep) {
|
||||
void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates_p, float z, size_t min_keep) {
|
||||
if (z >= 1.0f || candidates_p->size <= 2) {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_sample_softmax(candidates_p);
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
llama_sample_softmax(nullptr, candidates_p);
|
||||
|
||||
// Compute the first and second derivatives
|
||||
std::vector<float> first_derivatives(candidates_p->size - 1);
|
||||
|
@ -1594,18 +1615,23 @@ void llama_sample_tail_free(llama_token_data_array * candidates_p, float z, size
|
|||
|
||||
// Resize the output vector to keep only the tokens above the tail location
|
||||
candidates_p->size = last_idx;
|
||||
}
|
||||
|
||||
if (ctx) {
|
||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
// https://arxiv.org/pdf/2202.00666.pdf
|
||||
// https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
|
||||
void llama_sample_typical(llama_token_data_array * candidates_p, float typical_p, size_t min_keep) {
|
||||
void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates_p, float typical_p, size_t min_keep) {
|
||||
if (typical_p >= 1.0f) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
// Compute the softmax of logits and calculate entropy
|
||||
llama_sample_softmax(candidates_p);
|
||||
llama_sample_softmax(nullptr, candidates_p);
|
||||
|
||||
std::span<llama_token_data> candidates(candidates_p->data, candidates_p->size);
|
||||
|
||||
|
@ -1654,21 +1680,32 @@ void llama_sample_typical(llama_token_data_array * candidates_p, float typical_p
|
|||
// Replace the data in candidates_p with the new_candidates data
|
||||
std::copy(new_candidates.begin(), new_candidates.end(), candidates_p->data);
|
||||
candidates_p->size = new_candidates.size();
|
||||
|
||||
if (ctx) {
|
||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) {
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
void llama_sample_temperature(llama_token_data_array * candidates_p, float temp) {
|
||||
std::span<llama_token_data> candidates(candidates_p->data, candidates_p->size);
|
||||
for (auto & candidate : candidates) {
|
||||
candidate.logit /= temp;
|
||||
}
|
||||
|
||||
if (ctx) {
|
||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_repetition_penalty(llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float penalty) {
|
||||
void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float penalty) {
|
||||
if (last_tokens_size == 0 || penalty == 1.0f) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
// CTRL paper: https://arxiv.org/pdf/1909.05858.pdf
|
||||
std::span<llama_token_data> candidates(candidates_p->data, candidates_p->size);
|
||||
std::span<llama_token> last_tokens(last_tokens_p, last_tokens_size);
|
||||
|
@ -1695,13 +1732,19 @@ void llama_sample_repetition_penalty(llama_token_data_array * candidates_p, llam
|
|||
}
|
||||
|
||||
candidates_p->sorted = false;
|
||||
|
||||
if (ctx) {
|
||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_frequency_and_presence_penalties(llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence) {
|
||||
void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence) {
|
||||
if (last_tokens_size == 0 || (alpha_frequency == 0.0f && alpha_presence == 0.0f)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
std::span<llama_token_data> candidates(candidates_p->data, candidates_p->size);
|
||||
std::span<llama_token> last_tokens(last_tokens_p, last_tokens_size);
|
||||
|
||||
|
@ -1723,6 +1766,10 @@ void llama_sample_frequency_and_presence_penalties(llama_token_data_array * cand
|
|||
}
|
||||
|
||||
candidates_p->sorted = false;
|
||||
|
||||
if (ctx) {
|
||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Mirostat 1.0 implementation.
|
||||
|
@ -1733,20 +1780,26 @@ void llama_sample_frequency_and_presence_penalties(llama_token_data_array * cand
|
|||
/// @param N The size of the vocabulary. This is used in the calculation of the `k` value.
|
||||
/// @param k A reference to the integer variable used to store the calculated top-k value. The top-k value determines how many of the most probable tokens are considered for sampling.
|
||||
/// @param mu A reference to the floating-point variable that represents the maximum cross-entropy value. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
|
||||
llama_token llama_sample_mirostat(struct llama_context * ctx, llama_token_data_array * candidates_p, float tau, float eta, int m, float N, int * k, float * mu) {
|
||||
llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates_p, float tau, float eta, int m, float N, int * k, float * mu) {
|
||||
assert(ctx);
|
||||
|
||||
int64_t t_start_sample_us;
|
||||
t_start_sample_us = ggml_time_us();
|
||||
|
||||
// https://arxiv.org/abs/2007.14966
|
||||
// Algorithm 1
|
||||
std::span<llama_token_data> candidates(candidates_p->data, candidates_p->size);
|
||||
|
||||
// printf("llama_sample_mirostat: candidates.size() = %d, m = %d, N = %f, tau = %f, eta = %f, *k = %d, *mu = %f\n", candidates.size(), m, N, tau, eta, *k, *mu);
|
||||
|
||||
llama_sample_softmax(candidates_p);
|
||||
llama_sample_softmax(nullptr, candidates_p);
|
||||
|
||||
// Estimate s_hat using the most probable m tokens
|
||||
float s_hat = 0.0;
|
||||
float sum_ti_bi = 0.0;
|
||||
float sum_ti_sq = 0.0;
|
||||
for (int i = 0; i < m - 1 && i < candidates.size() - 1; ++i) {
|
||||
float t_i = logf((i + 2) / float(i + 1));
|
||||
for (size_t i = 0; i < size_t(m - 1) && i < candidates.size() - 1; ++i) {
|
||||
float t_i = logf(float(i + 2) / float(i + 1));
|
||||
float b_i = logf(candidates[i].p / candidates[i + 1].p);
|
||||
sum_ti_bi += t_i * b_i;
|
||||
sum_ti_sq += t_i * t_i;
|
||||
|
@ -1757,15 +1810,20 @@ llama_token llama_sample_mirostat(struct llama_context * ctx, llama_token_data_a
|
|||
float epsilon_hat = s_hat - 1;
|
||||
// printf("llama_sample_mirostat: s_hat = %f, epsilon_hat = %f, *mu = %f, N = %f\n", s_hat, epsilon_hat, *mu, N);
|
||||
float new_k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(N, -epsilon_hat)), 1 / s_hat);
|
||||
*k = std::min(new_k, float(candidates.size()));
|
||||
// printf("llama_sample_mirostat: new_k = %f\n", new_k);
|
||||
*k = int(std::min(new_k, float(candidates.size())));
|
||||
|
||||
// Sample the next word X using top-k sampling
|
||||
// printf("llama_sample_mirostat *k = %d\n", *k);
|
||||
llama_sample_top_k(candidates_p, *k);
|
||||
llama_sample_top_k(nullptr, candidates_p, *k);
|
||||
if (ctx) {
|
||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
llama_token X = llama_sample_token(ctx, candidates_p);
|
||||
t_start_sample_us = ggml_time_us();
|
||||
|
||||
// Compute error as the difference between observed surprise and target surprise value
|
||||
int X_idx = std::distance(candidates.begin(), std::find_if(candidates.begin(), candidates.end(), [&](const llama_token_data & candidate) {
|
||||
size_t X_idx = std::distance(candidates.begin(), std::find_if(candidates.begin(), candidates.end(), [&](const llama_token_data & candidate) {
|
||||
return candidate.id == X;
|
||||
}));
|
||||
float observed_surprise = -log2f(candidates[X_idx].p);
|
||||
|
@ -1774,13 +1832,23 @@ llama_token llama_sample_mirostat(struct llama_context * ctx, llama_token_data_a
|
|||
// Update mu using the learning rate and error
|
||||
*mu = *mu - eta * e;
|
||||
|
||||
if (ctx) {
|
||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
ctx->n_sample++;
|
||||
}
|
||||
return X;
|
||||
}
|
||||
|
||||
llama_token llama_sample_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates_p, float tau, float eta, float * mu) {
|
||||
llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates_p, float tau, float eta, float * mu) {
|
||||
assert(ctx);
|
||||
int64_t t_start_sample_us;
|
||||
t_start_sample_us = ggml_time_us();
|
||||
|
||||
// https://arxiv.org/abs/2007.14966
|
||||
// Algorithm 2
|
||||
std::span<llama_token_data> candidates(candidates_p->data, candidates_p->size);
|
||||
|
||||
llama_sample_softmax(candidates_p);
|
||||
llama_sample_softmax(ctx, candidates_p);
|
||||
|
||||
// Truncate the words with surprise values greater than mu
|
||||
candidates_p->size = std::distance(candidates.begin(), std::find_if(candidates.begin(), candidates.end(), [&](const llama_token_data & candidate) {
|
||||
|
@ -1788,13 +1856,17 @@ llama_token llama_sample_mirostat_v2(struct llama_context * ctx, llama_token_dat
|
|||
}));
|
||||
|
||||
// Normalize the probabilities of the remaining words
|
||||
llama_sample_softmax(candidates_p);
|
||||
llama_sample_softmax(ctx, candidates_p);
|
||||
|
||||
// Sample the next word X from the remaining words
|
||||
if (ctx) {
|
||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
llama_token X = llama_sample_token(ctx, candidates_p);
|
||||
t_start_sample_us = ggml_time_us();
|
||||
|
||||
// Compute error as the difference between observed surprise and target surprise value
|
||||
int X_idx = std::distance(candidates.begin(), std::find_if(candidates.begin(), candidates.end(), [&](const llama_token_data & candidate) {
|
||||
size_t X_idx = std::distance(candidates.begin(), std::find_if(candidates.begin(), candidates.end(), [&](const llama_token_data & candidate) {
|
||||
return candidate.id == X;
|
||||
}));
|
||||
float observed_surprise = -log2f(candidates[X_idx].p);
|
||||
|
@ -1803,11 +1875,15 @@ llama_token llama_sample_mirostat_v2(struct llama_context * ctx, llama_token_dat
|
|||
// Update mu using the learning rate and error
|
||||
*mu = *mu - eta * e;
|
||||
|
||||
if (ctx) {
|
||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
return X;
|
||||
}
|
||||
|
||||
|
||||
llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates_p) {
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
// Find max element
|
||||
std::span<llama_token_data> candidates(candidates_p->data, candidates_p->size);
|
||||
auto max_iter = std::max_element(candidates.begin(), candidates.end(), [](const llama_token_data & a, const llama_token_data & b) {
|
||||
|
@ -1815,14 +1891,17 @@ llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_da
|
|||
});
|
||||
|
||||
llama_token result = max_iter->id;
|
||||
ctx->n_sample++;
|
||||
if (ctx) {
|
||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
ctx->n_sample++;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates_p) {
|
||||
// const int64_t t_start_sample_us = ggml_time_us();
|
||||
llama_sample_softmax(candidates_p);
|
||||
assert(ctx);
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
llama_sample_softmax(nullptr, candidates_p);
|
||||
|
||||
std::span<llama_token_data> candidates(candidates_p->data, candidates_p->size);
|
||||
|
||||
|
@ -1838,9 +1917,8 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
|
|||
|
||||
llama_token result = candidates[idx].id;
|
||||
|
||||
// ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
ctx->n_sample++;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
|
|
20
llama.h
20
llama.h
|
@ -187,20 +187,20 @@ extern "C" {
|
|||
LLAMA_API llama_token llama_token_eos();
|
||||
|
||||
// Sampling functions
|
||||
LLAMA_API void llama_sample_repetition_penalty(llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float penalty);
|
||||
LLAMA_API void llama_sample_frequency_and_presence_penalties(llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence);
|
||||
LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float penalty);
|
||||
LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence);
|
||||
|
||||
LLAMA_API void llama_sample_softmax(llama_token_data_array * candidates);
|
||||
LLAMA_API void llama_sample_top_k(llama_token_data_array * candidates, int k);
|
||||
LLAMA_API void llama_sample_top_p(llama_token_data_array * candidates, float p, size_t min_keep = 1);
|
||||
LLAMA_API void llama_sample_tail_free(llama_token_data_array * candidates, float z, size_t min_keep = 1);
|
||||
LLAMA_API void llama_sample_typical(llama_token_data_array * candidates, float p, size_t min_keep = 1);
|
||||
LLAMA_API void llama_sample_temperature(llama_token_data_array * candidates, float temp);
|
||||
LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
|
||||
LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep = 1);
|
||||
LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1);
|
||||
LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep = 1);
|
||||
LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1);
|
||||
LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
|
||||
|
||||
LLAMA_API llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float N, int * k, float * mu);
|
||||
LLAMA_API llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu);
|
||||
LLAMA_API llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates);
|
||||
LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);
|
||||
LLAMA_API llama_token llama_sample_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float N, int * k, float * mu);
|
||||
LLAMA_API llama_token llama_sample_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu);
|
||||
|
||||
// Performance information
|
||||
LLAMA_API void llama_print_timings(struct llama_context * ctx);
|
||||
|
|
|
@ -3,8 +3,6 @@ function(llama_add_test source)
|
|||
add_executable(${TEST_TARGET} ${source})
|
||||
target_link_libraries(${TEST_TARGET} PRIVATE llama)
|
||||
add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}> ${ARGN})
|
||||
target_compile_options(${TEST_TARGET} PRIVATE -fsanitize=address)
|
||||
target_link_options(${TEST_TARGET} PRIVATE -fsanitize=address)
|
||||
endfunction()
|
||||
|
||||
# llama_add_test(test-double-float.c) # SLOW
|
||||
|
|
229
tests/test-sampling.cpp
Normal file
229
tests/test-sampling.cpp
Normal file
|
@ -0,0 +1,229 @@
|
|||
#include "ggml.h"
|
||||
#include "llama.h"
|
||||
#include <assert.h>
|
||||
#include <math.h>
|
||||
#include <numeric>
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
|
||||
void dump(const llama_token_data_array * candidates) {
|
||||
for (size_t i = 0; i < candidates->size; i++) {
|
||||
printf("%d: %f (%f)\n", candidates->data[i].id, candidates->data[i].p, candidates->data[i].logit);
|
||||
}
|
||||
}
|
||||
|
||||
#define DUMP(__candidates) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__candidates)); printf("-\n"); } while(0)
|
||||
|
||||
|
||||
void test_top_k(const std::vector<float> & probs,
|
||||
const std::vector<float> & expected_probs,
|
||||
int k) {
|
||||
size_t n_vocab = probs.size();
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (int i = 0; i < n_vocab; i++) {
|
||||
float logit = log(probs[i]);
|
||||
candidates.emplace_back(llama_token_data{i, logit, 0.0f});
|
||||
}
|
||||
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size() };
|
||||
llama_sample_softmax(nullptr, &candidates_p);
|
||||
// DUMP(&candidates_p);
|
||||
llama_sample_top_k(nullptr, &candidates_p, k);
|
||||
// DUMP(&candidates_p);
|
||||
|
||||
assert(candidates_p.size == expected_probs.size());
|
||||
for (size_t i = 0; i < candidates_p.size; i++) {
|
||||
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-5);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void test_top_p(const std::vector<float> & probs,
|
||||
const std::vector<float> & expected_probs,
|
||||
float p) {
|
||||
|
||||
size_t n_vocab = probs.size();
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (int i = 0; i < n_vocab; i++) {
|
||||
float logit = log(probs[i]);
|
||||
candidates.emplace_back(llama_token_data{i, logit, 0.0f});
|
||||
}
|
||||
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size() };
|
||||
// DUMP(&candidates_p);
|
||||
llama_sample_top_p(nullptr, &candidates_p, p);
|
||||
// DUMP(&candidates_p);
|
||||
|
||||
assert(candidates_p.size == expected_probs.size());
|
||||
for (size_t i = 0; i < candidates_p.size; i++) {
|
||||
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-5);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void test_tfs(const std::vector<float> & probs,
|
||||
const std::vector<float> & expected_probs,
|
||||
float z) {
|
||||
size_t n_vocab = probs.size();
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (int i = 0; i < n_vocab; i++) {
|
||||
float logit = log(probs[i]);
|
||||
candidates.emplace_back(llama_token_data{i, logit, 0.0f});
|
||||
}
|
||||
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size() };
|
||||
// DUMP(&candidates_p);
|
||||
llama_sample_tail_free(nullptr, &candidates_p, z);
|
||||
// DUMP(&candidates_p);
|
||||
|
||||
assert(candidates_p.size == expected_probs.size());
|
||||
for (size_t i = 0; i < candidates_p.size; i++) {
|
||||
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-6);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void test_typical(const std::vector<float> & probs,
|
||||
const std::vector<float> & expected_probs,
|
||||
float p) {
|
||||
size_t n_vocab = probs.size();
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (int i = 0; i < n_vocab; i++) {
|
||||
float logit = log(probs[i]);
|
||||
candidates.emplace_back(llama_token_data{i, logit, 0.0f});
|
||||
}
|
||||
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size() };
|
||||
// DUMP(&candidates_p);
|
||||
llama_sample_typical(nullptr, &candidates_p, p);
|
||||
// DUMP(&candidates_p);
|
||||
|
||||
assert(candidates_p.size == expected_probs.size());
|
||||
for (size_t i = 0; i < candidates_p.size; i++) {
|
||||
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-6);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void test_repetition_penalty(
|
||||
const std::vector<float> & probs,
|
||||
const std::vector<llama_token> & last_tokens,
|
||||
const std::vector<float> & expected_probs,
|
||||
float penalty) {
|
||||
assert(probs.size() == expected_probs.size());
|
||||
|
||||
size_t n_vocab = probs.size();
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (int i = 0; i < n_vocab; i++) {
|
||||
float logit = log(probs[i]);
|
||||
candidates.emplace_back(llama_token_data{i, logit, 0.0f});
|
||||
}
|
||||
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size() };
|
||||
llama_sample_softmax(nullptr, &candidates_p);
|
||||
DUMP(&candidates_p);
|
||||
llama_sample_repetition_penalty(nullptr, &candidates_p, (llama_token *)last_tokens.data(), last_tokens.size(), penalty);
|
||||
llama_sample_softmax(nullptr, &candidates_p);
|
||||
DUMP(&candidates_p);
|
||||
|
||||
assert(candidates_p.size == expected_probs.size());
|
||||
for (size_t i = 0; i < candidates_p.size; i++) {
|
||||
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-6);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void test_frequency_presence_penalty(
|
||||
const std::vector<float> & probs,
|
||||
const std::vector<llama_token> & last_tokens,
|
||||
const std::vector<float> & expected_probs,
|
||||
float alpha_frequency, float alpha_presence) {
|
||||
assert(probs.size() == expected_probs.size());
|
||||
|
||||
size_t n_vocab = probs.size();
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (int i = 0; i < n_vocab; i++) {
|
||||
float logit = log(probs[i]);
|
||||
candidates.emplace_back(llama_token_data{i, logit, 0.0f});
|
||||
}
|
||||
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size() };
|
||||
llama_sample_softmax(nullptr, &candidates_p);
|
||||
// DUMP(&candidates_p);
|
||||
llama_sample_frequency_and_presence_penalties(nullptr, &candidates_p, (llama_token *)last_tokens.data(), last_tokens.size(), alpha_frequency, alpha_presence);
|
||||
llama_sample_softmax(nullptr, &candidates_p);
|
||||
// DUMP(&candidates_p);
|
||||
|
||||
assert(candidates_p.size == expected_probs.size());
|
||||
for (size_t i = 0; i < candidates_p.size; i++) {
|
||||
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-6);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void test_mirostat() {
|
||||
std::vector<float> probs = {0.1, 0.2, 0.3, 0.4};
|
||||
std::vector<float> expected_probs = {0.1, 0.2, 0.3, 0.4};
|
||||
|
||||
size_t n_vocab = probs.size();
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (int i = 0; i < n_vocab; i++) {
|
||||
float logit = log(probs[i]);
|
||||
candidates.emplace_back(llama_token_data{i, logit, 0.0f});
|
||||
}
|
||||
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size() };
|
||||
DUMP(&candidates_p);
|
||||
|
||||
float tau = 5.0f;
|
||||
float mu = 2.0f * tau;
|
||||
int k = 0;
|
||||
float eta = 0.1f;
|
||||
int m = 100;
|
||||
// float N = 32000;
|
||||
float N = 4;
|
||||
// llama_sample_mirostat(ctx, &candidates_p, tau, eta, m, N, &k, &mu);
|
||||
DUMP(&candidates_p);
|
||||
|
||||
// assert(candidates_p.size == expected_probs.size());
|
||||
// for (size_t i = 0; i < candidates_p.size; i++) {
|
||||
// assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-6);
|
||||
// }
|
||||
}
|
||||
|
||||
int main(void) {
|
||||
test_top_k({0.1, 0.2, 0.3, 0.4}, {0.4}, 1);
|
||||
test_top_k({0.1, 0.2, 0.3, 0.4}, {0.4, 0.3, 0.2}, 3);
|
||||
|
||||
test_top_p({0.1, 0.2, 0.3, 0.4}, {0.4}, 0);
|
||||
test_top_p({0.1, 0.2, 0.3, 0.4}, {0.4, 0.3}, 0.7);
|
||||
test_top_p({0.1, 0.2, 0.3, 0.4}, {0.4, 0.3, 0.2, 0.1}, 1);
|
||||
|
||||
test_tfs({0.1, 0.15, 0.2, 0.25, 0.3}, {0.3}, 0.25);
|
||||
test_tfs({0.1, 0.15, 0.2, 0.25, 0.3}, {0.3, 0.25}, 0.75);
|
||||
test_tfs({0.1, 0.15, 0.2, 0.25, 0.3}, {0.3, 0.25}, 0.99);
|
||||
|
||||
test_typical({0.97, 0.01, 0.01, 0.01}, {0.97}, 0.5);
|
||||
test_typical({0.4, 0.2, 0.2, 0.2}, {0.2, 0.2, 0.2}, 0.5);
|
||||
|
||||
test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0}, {0, 0.25, 0.25, 0.25, 0.25}, 50.0);
|
||||
test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2}, {0, 0, 0, 0.5, 0.5}, 50.0);
|
||||
test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2, 0, 0}, {0, 0, 0, 0.5, 0.5}, 50.0);
|
||||
|
||||
test_frequency_presence_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0}, {0.249997, 0.249997, 0.249997, 0.249997, 0.000011}, 5.0, 5.0);
|
||||
test_frequency_presence_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2}, {0.499966, 0.499966, 0.000023, 0.000023, 0.000023}, 5.0, 5.0);
|
||||
test_frequency_presence_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2, 0, 0}, {0.499977, 0.499977, 0.000023, 0.000023, 0.000000}, 5.0, 5.0);
|
||||
|
||||
// test_mirostat();
|
||||
|
||||
printf("OK\n");
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue