llama : redirect external API to internal APIs

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-07-19 16:56:20 +03:00
parent 66ac80f5b9
commit 39fbaf9f50
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
9 changed files with 838 additions and 519 deletions

View file

@ -330,7 +330,7 @@ static llama_token llama_sampling_sample_impl(
llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
// Apply grammar constraints to the single token
llama_grammar_sample(ctx_main, &single_token_data_array, ctx_sampling->grammar);
llama_grammar_sample(ctx_sampling->grammar, ctx_main, &single_token_data_array);
// Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
@ -421,7 +421,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
// apply grammar checks before sampling logic
if (apply_grammar && ctx_sampling->grammar != NULL) {
llama_grammar_sample(ctx_main, &cur_p, ctx_sampling->grammar);
llama_grammar_sample(ctx_sampling->grammar, ctx_main, &cur_p);
}
return cur_p;
@ -455,6 +455,6 @@ void llama_sampling_accept(
ctx_sampling->prev.push_back(id);
if (ctx_sampling->grammar != NULL && apply_grammar) {
llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id);
}
}

View file

@ -965,6 +965,10 @@ extern "C" {
bool remove_special,
bool unparse_special);
//
// Chat templates
//
/// Apply chat template. Inspired by hf apply_chat_template() on python.
/// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
/// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
@ -1005,10 +1009,10 @@ extern "C" {
/// @details Apply constraints from grammar
LLAMA_API void llama_grammar_sample(
struct llama_context * ctx,
llama_token_data_array * candidates,
const struct llama_grammar * grammar);
LLAMA_API DEPRECATED(bool llama_sample_grammar(
const struct llama_grammar * grammar,
const struct llama_context * ctx,
llama_token_data_array * candidates);
LLAMA_API DEPRECATED(void llama_sample_grammar(
struct llama_context * ctx,
llama_token_data_array * candidates,
const struct llama_grammar * grammar),
@ -1016,8 +1020,8 @@ extern "C" {
/// @details Accepts the sampled token into the grammar
LLAMA_API void llama_grammar_accept_token(
struct llama_context * ctx,
struct llama_grammar * grammar,
struct llama_context * ctx,
llama_token token);
//

View file

@ -384,7 +384,7 @@ static bool llama_grammar_detect_left_recursion(
// grammar - external
//
struct llama_grammar * llama_grammar_init(
struct llama_grammar * llama_grammar_init_impl(
const llama_grammar_element ** rules,
size_t n_rules,
size_t start_rule_index) {
@ -441,11 +441,11 @@ struct llama_grammar * llama_grammar_init(
return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
}
void llama_grammar_free(struct llama_grammar * grammar) {
void llama_grammar_free_impl(struct llama_grammar * grammar) {
delete grammar;
}
struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) {
struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar) {
llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 };
// redirect elements in stacks to point to new rules
@ -464,8 +464,10 @@ struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar)
return result;
}
void llama_grammar_sample(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) {
GGML_ASSERT(ctx);
void llama_grammar_sample(const struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token_data_array * candidates) {
GGML_ASSERT(grammar);
GGML_ASSERT(vocab);
int64_t t_start_sample_us = ggml_time_us();
bool allow_eog = false;
@ -484,9 +486,9 @@ void llama_grammar_sample(struct llama_context * ctx, llama_token_data_array * c
for (size_t i = 0; i < candidates->size; ++i) {
const llama_token id = candidates->data[i].id;
const std::string & piece = llama_get_vocab(ctx)->cache_token_to_piece.at(id);
const std::string & piece = vocab->cache_token_to_piece.at(id);
if (llama_token_is_eog(llama_get_model(ctx), id)) {
if (llama_token_is_eog(*vocab, id)) {
if (!allow_eog) {
candidates->data[i].logit = -INFINITY;
}
@ -503,13 +505,13 @@ void llama_grammar_sample(struct llama_context * ctx, llama_token_data_array * c
candidates->data[reject.index].logit = -INFINITY;
}
llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
}
void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
void llama_grammar_accept_token(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) {
const int64_t t_start_sample_us = ggml_time_us();
if (llama_token_is_eog(llama_get_model(ctx), token)) {
if (llama_token_is_eog(*vocab, token)) {
for (const auto & stack : grammar->stacks) {
if (stack.empty()) {
return;
@ -518,7 +520,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
GGML_ASSERT(false);
}
const std::string & piece = llama_get_vocab(ctx)->cache_token_to_piece.at(token);
const std::string & piece = vocab->cache_token_to_piece.at(token);
// Note terminating 0 in decoded string
const auto decoded = decode_utf8(piece, grammar->partial_utf8);
@ -533,5 +535,5 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
grammar->partial_utf8 = decoded.second;
GGML_ASSERT(!grammar->stacks.empty());
llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
}

View file

@ -3,6 +3,7 @@
#include "llama-impl.h"
struct llama_vocab;
struct llama_sampling;
struct llama_grammar {
const llama_grammar_rules rules;
@ -13,3 +14,24 @@ struct llama_grammar {
};
struct llama_grammar * llama_get_grammar(struct llama_context * ctx);
struct llama_grammar * llama_grammar_init_impl(
const llama_grammar_element ** rules,
size_t n_rules,
size_t start_rule_index);
void llama_grammar_free_impl(struct llama_grammar * grammar);
struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar);
void llama_grammar_sample(
const struct llama_grammar * grammar,
const struct llama_vocab * vocab,
const struct llama_sampling * smpl,
llama_token_data_array * candidates);
void llama_grammar_accept_token(
struct llama_grammar * grammar,
const struct llama_vocab * vocab,
const struct llama_sampling * smpl,
llama_token token);

View file

@ -7,15 +7,29 @@
#include <numeric>
#include <unordered_map>
void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) {
static void llama_log_softmax(float * array, size_t size) {
float max_l = *std::max_element(array, array + size);
float sum = 0.f;
for (size_t i = 0; i < size; ++i) {
float p = expf(array[i] - max_l);
sum += p;
array[i] = p;
}
for (size_t i = 0; i < size; ++i) {
array[i] = logf(array[i] / sum);
}
}
void llama_set_rng_seed(struct llama_sampling * smpl, uint32_t seed) {
if (seed == LLAMA_DEFAULT_SEED) {
seed = time(NULL);
}
llama_get_sampling(ctx)->rng.seed(seed);
smpl->rng.seed(seed);
}
void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) {
void llama_sample_softmax(struct llama_sampling * smpl, llama_token_data_array * candidates) {
GGML_ASSERT(candidates->size > 0);
const int64_t t_start_sample_us = ggml_time_us();
@ -39,12 +53,12 @@ void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * c
candidates->data[i].p /= cum_sum;
}
if (ctx) {
llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
if (smpl) {
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
}
}
void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
void llama_sample_top_k(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
// TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
// if (k >= (int32_t)candidates->size) {
// return;
@ -120,17 +134,17 @@ void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * can
}
candidates->size = k;
if (ctx) {
llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
if (smpl) {
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
}
}
void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
void llama_sample_top_p(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
if (p >= 1.0f) {
return;
}
llama_sample_softmax(ctx, candidates);
llama_sample_softmax(smpl, candidates);
const int64_t t_start_sample_us = ggml_time_us();
@ -152,12 +166,12 @@ void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * can
// Resize the output vector to keep only the top-p tokens
candidates->size = last_idx;
if (ctx) {
llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
if (smpl) {
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
}
}
void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
void llama_sample_min_p(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
if (p <= 0.0f || !candidates->size) {
return;
}
@ -213,17 +227,17 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
candidates->size = i;
}
if (ctx) {
llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
if (smpl) {
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
}
}
void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
void llama_sample_tail_free(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
if (z >= 1.0f || candidates->size <= 2) {
return;
}
llama_sample_softmax(nullptr, candidates);
llama_sample_softmax((struct llama_sampling *) nullptr, candidates);
const int64_t t_start_sample_us = ggml_time_us();
// Compute the first and second derivatives
@ -272,12 +286,12 @@ void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array *
// Resize the output vector to keep only the tokens above the tail location
candidates->size = last_idx;
if (ctx) {
llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
if (smpl) {
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
}
}
void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
void llama_sample_typical(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
// Reference implementation:
// https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
if (p >= 1.0f) {
@ -285,7 +299,7 @@ void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * c
}
// Compute the softmax of logits and calculate entropy
llama_sample_softmax(nullptr, candidates);
llama_sample_softmax((struct llama_sampling *) nullptr, candidates);
const int64_t t_start_sample_us = ggml_time_us();
@ -336,34 +350,34 @@ void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * c
candidates->size = new_candidates.size();
candidates->sorted = false;
if (ctx) {
llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
if (smpl) {
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
}
}
void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) {
void llama_sample_entropy(struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) {
const int64_t t_start_sample_us = ggml_time_us();
// no need to do anything if there is only one (or zero) candidates
if(candidates_p->size <= 1) {
if(candidates->size <= 1) {
return;
}
// Calculate maximum possible entropy
float max_entropy = -logf(1.0f / candidates_p->size);
float max_entropy = -logf(1.0f / candidates->size);
llama_sample_softmax(nullptr, candidates_p);
llama_sample_softmax((struct llama_sampling *) nullptr, candidates);
// Calculate entropy of the softmax probabilities
float entropy = 0.0f;
for (size_t i = 0; i < candidates_p->size; ++i) {
float prob = candidates_p->data[i].p;
for (size_t i = 0; i < candidates->size; ++i) {
float prob = candidates->data[i].p;
if (prob > 0.0f) { // Ensure no log(0)
entropy -= prob * logf(prob);
}
}
// Normalize the entropy (max_entropy cannot be 0 here because we checked candidates_p->size != 1 above)
// Normalize the entropy (max_entropy cannot be 0 here because we checked candidates->size != 1 above)
float normalized_entropy = entropy / max_entropy;
// Map the normalized entropy to the desired temperature range using the power function
@ -379,55 +393,55 @@ void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * c
#endif
// Apply the dynamically calculated temperature scaling
for (size_t i = 0; i < candidates_p->size; ++i) {
candidates_p->data[i].logit /= dyn_temp;
for (size_t i = 0; i < candidates->size; ++i) {
candidates->data[i].logit /= dyn_temp;
}
// Re-compute softmax probabilities after scaling logits with dynamic temperature
double max_l_double = candidates_p->data[0].logit;
double max_l_double = candidates->data[0].logit;
double cum_sum_double = 0.0;
for (size_t i = 0; i < candidates_p->size; ++i) {
double p = exp(candidates_p->data[i].logit - max_l_double);
candidates_p->data[i].p = p; // Store the scaled probability
for (size_t i = 0; i < candidates->size; ++i) {
double p = exp(candidates->data[i].logit - max_l_double);
candidates->data[i].p = p; // Store the scaled probability
cum_sum_double += p;
}
for (size_t i = 0; i < candidates_p->size; ++i) {
candidates_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities
for (size_t i = 0; i < candidates->size; ++i) {
candidates->data[i].p /= cum_sum_double; // Re-normalize the probabilities
}
#ifdef DEBUG
// Print the updated top 25 probabilities after temperature scaling
LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
for (size_t i = 0; i < 25 && i < candidates_p->size; ++i) {
LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates_p->data[i].p * 100.0f);
for (size_t i = 0; i < 25 && i < candidates->size; ++i) {
LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates->data[i].p * 100.0f);
}
#endif
if (ctx) {
llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
if (smpl) {
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
}
}
void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) {
void llama_sample_temp(struct llama_sampling * smpl, llama_token_data_array * candidates, float temp) {
const int64_t t_start_sample_us = ggml_time_us();
for (size_t i = 0; i < candidates_p->size; ++i) {
candidates_p->data[i].logit /= temp;
for (size_t i = 0; i < candidates->size; ++i) {
candidates->data[i].logit /= temp;
}
if (ctx) {
llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
if (smpl) {
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
}
}
void llama_sample_repetition_penalties(
struct llama_context * ctx,
llama_token_data_array * candidates,
const llama_token * last_tokens,
size_t penalty_last_n,
float penalty_repeat,
float penalty_freq,
float penalty_present) {
struct llama_sampling * smpl,
llama_token_data_array * candidates,
const llama_token * last_tokens,
size_t penalty_last_n,
float penalty_repeat,
float penalty_freq,
float penalty_present) {
if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
return;
}
@ -462,34 +476,20 @@ void llama_sample_repetition_penalties(
candidates->sorted = false;
if (ctx) {
llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
}
}
static void llama_log_softmax(float * array, size_t size) {
float max_l = *std::max_element(array, array + size);
float sum = 0.f;
for (size_t i = 0; i < size; ++i) {
float p = expf(array[i] - max_l);
sum += p;
array[i] = p;
}
for (size_t i = 0; i < size; ++i) {
array[i] = logf(array[i] / sum);
if (smpl) {
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
}
}
void llama_sample_apply_guidance(
struct llama_context * ctx,
float * logits,
float * logits_guidance,
float scale) {
GGML_ASSERT(ctx);
struct llama_sampling * smpl,
float * logits,
float * logits_guidance,
float scale) {
GGML_ASSERT(smpl);
const auto t_start_sample_us = ggml_time_us();
const auto n_vocab = llama_get_sampling(ctx)->n_vocab;
const auto n_vocab = smpl->n_vocab;
llama_log_softmax(logits, n_vocab);
llama_log_softmax(logits_guidance, n_vocab);
@ -501,17 +501,17 @@ void llama_sample_apply_guidance(
l = scale * (l - g) + g;
}
llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
}
llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
GGML_ASSERT(ctx);
llama_token llama_sample_token_mirostat(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
GGML_ASSERT(smpl);
const int32_t n_vocab = float(llama_get_sampling(ctx)->n_vocab);
const int32_t n_vocab = float(smpl->n_vocab);
int64_t t_start_sample_us = ggml_time_us();
llama_sample_softmax(nullptr, candidates);
llama_sample_softmax((struct llama_sampling *) nullptr, candidates);
// Estimate s_hat using the most probable m tokens
float s_hat = 0.0;
@ -530,9 +530,9 @@ llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_
float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat);
// Sample the next word X using top-k sampling
llama_sample_top_k(nullptr, candidates, int(k), 1);
llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
llama_token X = llama_sample_token(ctx, candidates);
llama_sample_top_k((struct llama_sampling *) nullptr, candidates, int(k), 1);
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
llama_token X = llama_sample_token(smpl, candidates);
t_start_sample_us = ggml_time_us();
// Compute error as the difference between observed surprise and target surprise value
@ -545,15 +545,15 @@ llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_
// Update mu using the learning rate and error
*mu = *mu - eta * e;
llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
return X;
}
llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) {
llama_token llama_sample_token_mirostat_v2(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) {
int64_t t_start_sample_us;
t_start_sample_us = ggml_time_us();
llama_sample_softmax(ctx, candidates);
llama_sample_softmax(smpl, candidates);
// Truncate the words with surprise values greater than mu
candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
@ -564,15 +564,15 @@ llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_tok
candidates->size = 1;
}
if (ctx) {
llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
if (smpl) {
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
}
// Normalize the probabilities of the remaining words
llama_sample_softmax(ctx, candidates);
llama_sample_softmax(smpl, candidates);
// Sample the next word X from the remaining words
llama_token X = llama_sample_token(ctx, candidates);
llama_token X = llama_sample_token(smpl, candidates);
t_start_sample_us = ggml_time_us();
// Compute error as the difference between observed surprise and target surprise value
@ -585,13 +585,13 @@ llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_tok
// Update mu using the learning rate and error
*mu = *mu - eta * e;
if (ctx) {
llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
if (smpl) {
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
}
return X;
}
llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) {
llama_token llama_sample_token_greedy(struct llama_sampling * smpl, llama_token_data_array * candidates) {
const int64_t t_start_sample_us = ggml_time_us();
// Find max element
@ -600,18 +600,18 @@ llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_da
});
llama_token result = max_iter->id;
if (ctx) {
llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
llama_get_sampling(ctx)->n_sample++;
if (smpl) {
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
smpl->n_sample++;
}
return result;
}
llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) {
GGML_ASSERT(ctx);
llama_token llama_sample_token_with_rng(struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng) {
GGML_ASSERT(smpl);
const int64_t t_start_sample_us = ggml_time_us();
llama_sample_softmax(nullptr, candidates);
llama_sample_softmax((struct llama_sampling *) nullptr, candidates);
std::vector<float> probs;
probs.reserve(candidates->size);
@ -624,12 +624,12 @@ llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_
llama_token result = candidates->data[idx].id;
llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
llama_get_sampling(ctx)->n_sample++;
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
smpl->n_sample++;
return result;
}
llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
return llama_sample_token_with_rng(ctx, candidates, llama_get_sampling(ctx)->rng);
llama_token llama_sample_token(struct llama_sampling * smpl, llama_token_data_array * candidates) {
return llama_sample_token_with_rng(smpl, candidates, smpl->rng);
}

View file

@ -7,15 +7,48 @@ struct llama_sampling {
std::mt19937 rng;
int64_t t_sample_us = 0;
int32_t n_sample = 0;
int32_t n_vocab = 0;
void reset_timings() {
mutable int64_t t_sample_us = 0;
mutable int32_t n_sample = 0;
void reset_timings() const {
t_sample_us = 0;
n_sample = 0;
}
};
struct llama_sampling * llama_get_sampling(struct llama_context * ctx);
void llama_set_rng_seed(struct llama_sampling * smpl, uint32_t seed);
void llama_sample_softmax (struct llama_sampling * smpl, llama_token_data_array * candidates);
void llama_sample_top_k (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep);
void llama_sample_top_p (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
void llama_sample_min_p (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
void llama_sample_tail_free(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep);
void llama_sample_typical (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
void llama_sample_entropy (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);
void llama_sample_temp (struct llama_sampling * smpl, llama_token_data_array * candidates, float temp);
void llama_sample_repetition_penalties(
struct llama_sampling * smpl,
llama_token_data_array * candidates,
const llama_token * last_tokens,
size_t penalty_last_n,
float penalty_repeat,
float penalty_freq,
float penalty_present);
void llama_sample_apply_guidance(
struct llama_sampling * smpl,
float * logits,
float * logits_guidance,
float scale);
llama_token llama_sample_token_mirostat (struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu);
llama_token llama_sample_token_mirostat_v2(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu);
llama_token llama_sample_token_greedy (struct llama_sampling * smpl, llama_token_data_array * candidates);
llama_token llama_sample_token_with_rng (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng);
llama_token llama_sample_token (struct llama_sampling * smpl, llama_token_data_array * candidates);

View file

@ -12,29 +12,10 @@
#include <queue>
#include <sstream>
#if __cplusplus >= 202000L
#define LU8(x) (const char*)(u8##x)
#else
#define LU8(x) u8##x
#endif
//
// helpers
//
// trim whitespace from the beginning and end of a string
static std::string trim(const std::string & str) {
size_t start = 0;
size_t end = str.size();
while (start < end && isspace(str[start])) {
start += 1;
}
while (end > start && isspace(str[end - 1])) {
end -= 1;
}
return str.substr(start, end - start);
}
static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
std::string result;
for (size_t pos = 0; ; pos += search.length()) {
@ -1445,106 +1426,89 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
return output;
}
const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
const struct llama_vocab * vocab = llama_get_vocab(model);
GGML_ASSERT(vocab->type != LLAMA_VOCAB_TYPE_NONE);
return vocab->id_to_token[token].text.c_str();
const char * llama_token_get_text(const struct llama_vocab & vocab, llama_token token) {
GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
return vocab.id_to_token[token].text.c_str();
}
float llama_token_get_score(const struct llama_model * model, llama_token token) {
const struct llama_vocab * vocab = llama_get_vocab(model);
GGML_ASSERT(vocab->type != LLAMA_VOCAB_TYPE_NONE);
return vocab->id_to_token[token].score;
float llama_token_get_score(const struct llama_vocab & vocab, llama_token token) {
GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
return vocab.id_to_token[token].score;
}
llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token) {
const struct llama_vocab * vocab = llama_get_vocab(model);
GGML_ASSERT(vocab->type != LLAMA_VOCAB_TYPE_NONE);
return vocab->id_to_token[token].attr;
llama_token_attr llama_token_get_attr(const struct llama_vocab & vocab, llama_token token) {
GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
return vocab.id_to_token[token].attr;
}
bool llama_token_is_eog(const struct llama_model * model, llama_token token) {
bool llama_token_is_eog(const struct llama_vocab & vocab, llama_token token) {
return token != -1 && (
token == llama_token_eos(model) ||
token == llama_token_eot(model)
token == llama_token_eos(vocab) ||
token == llama_token_eot(vocab)
);
}
bool llama_token_is_control(const struct llama_model * model, llama_token token) {
const struct llama_vocab * vocab = llama_get_vocab(model);
return llama_is_control_token(*vocab, token);
bool llama_token_is_control(const struct llama_vocab & vocab, llama_token token) {
return llama_is_control_token(vocab, token);
}
llama_token llama_token_bos(const struct llama_model * model) {
const struct llama_vocab * vocab = llama_get_vocab(model);
return vocab->special_bos_id;
llama_token llama_token_bos(const struct llama_vocab & vocab) {
return vocab.special_bos_id;
}
llama_token llama_token_eos(const struct llama_model * model) {
const struct llama_vocab * vocab = llama_get_vocab(model);
return vocab->special_eos_id;
llama_token llama_token_eos(const struct llama_vocab & vocab) {
return vocab.special_eos_id;
}
llama_token llama_token_cls(const struct llama_model * model) {
const struct llama_vocab * vocab = llama_get_vocab(model);
return vocab->special_cls_id;
llama_token llama_token_cls(const struct llama_vocab & vocab) {
return vocab.special_cls_id;
}
llama_token llama_token_sep(const struct llama_model * model) {
const struct llama_vocab * vocab = llama_get_vocab(model);
return vocab->special_sep_id;
llama_token llama_token_sep(const struct llama_vocab & vocab) {
return vocab.special_sep_id;
}
llama_token llama_token_nl(const struct llama_model * model) {
const struct llama_vocab * vocab = llama_get_vocab(model);
return vocab->linefeed_id;
llama_token llama_token_nl(const struct llama_vocab & vocab) {
return vocab.linefeed_id;
}
int32_t llama_add_bos_token(const struct llama_model * model) {
const struct llama_vocab * vocab = llama_get_vocab(model);
return vocab->tokenizer_add_bos;
llama_token llama_token_pad(const struct llama_vocab & vocab) {
return vocab.special_pad_id;
}
int32_t llama_add_eos_token(const struct llama_model * model) {
const struct llama_vocab * vocab = llama_get_vocab(model);
return vocab->tokenizer_add_eos;
int32_t llama_add_bos_token(const struct llama_vocab & vocab) {
return vocab.tokenizer_add_bos;
}
llama_token llama_token_prefix(const struct llama_model * model) {
const struct llama_vocab * vocab = llama_get_vocab(model);
return vocab->special_prefix_id;
int32_t llama_add_eos_token(const struct llama_vocab & vocab) {
return vocab.tokenizer_add_eos;
}
llama_token llama_token_middle(const struct llama_model * model) {
const struct llama_vocab * vocab = llama_get_vocab(model);
return vocab->special_middle_id;
llama_token llama_token_prefix(const struct llama_vocab & vocab) {
return vocab.special_prefix_id;
}
llama_token llama_token_suffix(const struct llama_model * model) {
const struct llama_vocab * vocab = llama_get_vocab(model);
return vocab->special_suffix_id;
llama_token llama_token_middle(const struct llama_vocab & vocab) {
return vocab.special_middle_id;
}
llama_token llama_token_eot(const struct llama_model * model) {
const struct llama_vocab * vocab = llama_get_vocab(model);
return vocab->special_eot_id;
llama_token llama_token_suffix(const struct llama_vocab & vocab) {
return vocab.special_suffix_id;
}
llama_token llama_token_pad(const struct llama_model * model) {
const struct llama_vocab * vocab = llama_get_vocab(model);
return vocab->special_pad_id;
llama_token llama_token_eot(const struct llama_vocab & vocab) {
return vocab.special_eot_id;
}
int32_t llama_tokenize(
const struct llama_model * model,
const struct llama_vocab & vocab,
const char * text,
int32_t text_len,
llama_token * tokens,
int32_t n_tokens_max,
bool add_special,
bool parse_special) {
const struct llama_vocab * vocab = llama_get_vocab(model);
auto res = llama_tokenize_internal(*vocab, std::string(text, text_len), add_special, parse_special);
auto res = llama_tokenize_internal(vocab, std::string(text, text_len), add_special, parse_special);
if (n_tokens_max < (int) res.size()) {
// LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
return -((int) res.size());
@ -1578,10 +1542,10 @@ static std::string llama_decode_text(const std::string & text) {
}
// does not write null-terminator to buf
int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) {
int32_t llama_token_to_piece(const struct llama_vocab & vocab, llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) {
// ref: https://github.com/ggerganov/llama.cpp/pull/7587#discussion_r1620983843
static const int attr_special = LLAMA_TOKEN_ATTR_UNKNOWN | LLAMA_TOKEN_ATTR_CONTROL;
const llama_token_attr attr = llama_token_get_attr(model, token);
const llama_token_attr attr = llama_token_get_attr(vocab, token);
if (!special && (attr & attr_special)) {
return 0;
}
@ -1600,11 +1564,9 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
return (int32_t) size;
};
const struct llama_vocab * vocab = llama_get_vocab(model);
// if we have a cache - use it
{
const auto & cache = vocab->cache_token_to_piece;
const auto & cache = vocab.cache_token_to_piece;
if (!cache.empty()) {
const auto & result = cache.at(token);
@ -1612,9 +1574,9 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
}
}
if (0 <= token && token < llama_n_vocab(model)) {
const std::string & token_text = vocab->id_to_token[token].text;
switch (llama_vocab_get_type(*vocab)) {
if (0 <= token && token < (int32_t) vocab.id_to_token.size()) {
const std::string & token_text = vocab.id_to_token[token].text;
switch (llama_vocab_get_type(vocab)) {
case LLAMA_VOCAB_TYPE_WPM:
case LLAMA_VOCAB_TYPE_SPM:
case LLAMA_VOCAB_TYPE_UGM: {
@ -1627,7 +1589,7 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
llama_unescape_whitespace(result);
return _try_copy(result.data(), result.size());
} else if (attr & LLAMA_TOKEN_ATTR_BYTE) {
char byte = (char) llama_token_to_byte(*vocab, token);
char byte = (char) llama_token_to_byte(vocab, token);
return _try_copy((char*) &byte, 1);
}
break;
@ -1647,11 +1609,12 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
GGML_ASSERT(false);
}
}
return 0;
}
int32_t llama_detokenize(
const struct llama_model * model,
const struct llama_vocab & vocab,
const llama_token * tokens,
int32_t n_tokens,
char * text,
@ -1661,28 +1624,26 @@ int32_t llama_detokenize(
int32_t avail = text_len_max;
int32_t total = 0;
const struct llama_vocab * vocab = llama_get_vocab(model);
// remove the leading space
bool remove_space = vocab->tokenizer_add_space_prefix;
bool remove_space = vocab.tokenizer_add_space_prefix;
if (remove_special && vocab->tokenizer_add_bos) {
if (n_tokens > 0 && tokens[0] == vocab->special_bos_id) {
if (remove_special && vocab.tokenizer_add_bos) {
if (n_tokens > 0 && tokens[0] == vocab.special_bos_id) {
remove_space = false;
n_tokens--;
tokens++;
}
}
if (remove_special && vocab->tokenizer_add_eos) {
if (n_tokens > 0 && tokens[n_tokens-1] == vocab->special_eos_id) {
if (remove_special && vocab.tokenizer_add_eos) {
if (n_tokens > 0 && tokens[n_tokens-1] == vocab.special_eos_id) {
n_tokens--;
}
}
for (int32_t i = 0; i < n_tokens; ++i) {
GGML_ASSERT(avail >= 0);
int32_t n_chars = llama_token_to_piece(model, tokens[i], text, avail, remove_space, unparse_special);
int32_t n_chars = llama_token_to_piece(vocab, tokens[i], text, avail, remove_space, unparse_special);
remove_space = false;
if (n_chars < 0) {
avail = 0;
@ -1698,7 +1659,7 @@ int32_t llama_detokenize(
return -total;
}
if (vocab->tokenizer_clean_spaces) {
if (vocab.tokenizer_clean_spaces) {
text -= total; // restart text
// first pass: characters ?!., //TODO: where do these characters come from?
@ -1758,298 +1719,3 @@ int32_t llama_detokenize(
return total <= text_len_max ? total : -total;
}
//
// chat templates
//
// Simple version of "llama_apply_chat_template" that only works with strings
// This function uses heuristic checks to determine commonly used template. It is not a jinja parser.
static int32_t llama_chat_apply_template_internal(
const std::string & tmpl,
const std::vector<const llama_chat_message *> & chat,
std::string & dest, bool add_ass) {
// Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
std::stringstream ss;
auto tmpl_contains = [&tmpl](std::string haystack) -> bool {
return tmpl.find(haystack) != std::string::npos;
};
if (tmpl == "chatml" || tmpl_contains("<|im_start|>")) {
// chatml template
for (auto message : chat) {
ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n";
}
if (add_ass) {
ss << "<|im_start|>assistant\n";
}
} else if (tmpl == "llama2" || tmpl == "mistral" || tmpl_contains("[INST]")) {
// llama2 template and its variants
// [variant] support system message
bool support_system_message = tmpl_contains("<<SYS>>") || tmpl == "mistral";
// [variant] space before + after response
bool space_around_response = tmpl_contains("' ' + eos_token");
// [variant] add BOS inside history
bool add_bos_inside_history = tmpl_contains("bos_token + '[INST]");
// [variant] trim spaces from the input message
bool strip_message = tmpl_contains("content.strip()");
// construct the prompt
bool is_inside_turn = true; // skip BOS at the beginning
ss << "[INST] ";
for (auto message : chat) {
std::string content = strip_message ? trim(message->content) : message->content;
std::string role(message->role);
if (!is_inside_turn) {
is_inside_turn = true;
ss << (add_bos_inside_history ? "<s>[INST] " : "[INST] ");
}
if (role == "system") {
if (support_system_message) {
ss << "<<SYS>>\n" << content << "\n<</SYS>>\n\n";
} else {
// if the model does not support system message, we still include it in the first message, but without <<SYS>>
ss << content << "\n";
}
} else if (role == "user") {
ss << content << " [/INST]";
} else {
ss << (space_around_response ? " " : "") << content << (space_around_response ? " " : "") << "</s>";
is_inside_turn = false;
}
}
// llama2 templates seem to not care about "add_generation_prompt"
} else if (tmpl == "phi3" || (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>"))) {
// Phi 3
for (auto message : chat) {
std::string role(message->role);
ss << "<|" << role << "|>\n" << message->content << "<|end|>\n";
}
if (add_ass) {
ss << "<|assistant|>\n";
}
} else if (tmpl == "zephyr" || tmpl_contains("<|user|>")) {
// zephyr template
for (auto message : chat) {
ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n";
}
if (add_ass) {
ss << "<|assistant|>\n";
}
} else if (tmpl == "monarch" || tmpl_contains("bos_token + message['role']")) {
// mlabonne/AlphaMonarch-7B template (the <s> is included inside history)
for (auto message : chat) {
std::string bos = (message == chat.front()) ? "" : "<s>"; // skip BOS for first message
ss << bos << message->role << "\n" << message->content << "</s>\n";
}
if (add_ass) {
ss << "<s>assistant\n";
}
} else if (tmpl == "gemma" || tmpl == "gemma2" || tmpl_contains("<start_of_turn>")) {
// google/gemma-7b-it
std::string system_prompt = "";
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
// there is no system message for gemma, but we will merge it with user prompt, so nothing is broken
system_prompt = trim(message->content);
continue;
}
// in gemma, "assistant" is "model"
role = role == "assistant" ? "model" : message->role;
ss << "<start_of_turn>" << role << "\n";
if (!system_prompt.empty() && role != "model") {
ss << system_prompt << "\n\n";
system_prompt = "";
}
ss << trim(message->content) << "<end_of_turn>\n";
}
if (add_ass) {
ss << "<start_of_turn>model\n";
}
} else if (tmpl == "orion" || tmpl_contains("'\\n\\nAssistant: ' + eos_token")) {
// OrionStarAI/Orion-14B-Chat
std::string system_prompt = "";
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
// there is no system message support, we will merge it with user prompt
system_prompt = message->content;
continue;
} else if (role == "user") {
ss << "Human: ";
if (!system_prompt.empty()) {
ss << system_prompt << "\n\n";
system_prompt = "";
}
ss << message->content << "\n\nAssistant: </s>";
} else {
ss << message->content << "</s>";
}
}
} else if (tmpl == "openchat" || tmpl_contains("GPT4 Correct ")) {
// openchat/openchat-3.5-0106,
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
ss << message->content << "<|end_of_turn|>";
} else {
role[0] = toupper(role[0]);
ss << "GPT4 Correct " << role << ": " << message->content << "<|end_of_turn|>";
}
}
if (add_ass) {
ss << "GPT4 Correct Assistant:";
}
} else if (tmpl == "vicuna" || tmpl == "vicuna-orca" || (tmpl_contains("USER: ") && tmpl_contains("ASSISTANT: "))) {
// eachadea/vicuna-13b-1.1 (and Orca variant)
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
// Orca-Vicuna variant uses a system prefix
if (tmpl == "vicuna-orca" || tmpl_contains("SYSTEM: ")) {
ss << "SYSTEM: " << message->content << "\n";
} else {
ss << message->content << "\n\n";
}
} else if (role == "user") {
ss << "USER: " << message->content << "\n";
} else if (role == "assistant") {
ss << "ASSISTANT: " << message->content << "</s>\n";
}
}
if (add_ass) {
ss << "ASSISTANT:";
}
} else if (tmpl == "deepseek" || (tmpl_contains("### Instruction:") && tmpl_contains("<|EOT|>"))) {
// deepseek-ai/deepseek-coder-33b-instruct
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
ss << message->content;
} else if (role == "user") {
ss << "### Instruction:\n" << message->content << "\n";
} else if (role == "assistant") {
ss << "### Response:\n" << message->content << "\n<|EOT|>\n";
}
}
if (add_ass) {
ss << "### Response:\n";
}
} else if (tmpl == "command-r" || (tmpl_contains("<|START_OF_TURN_TOKEN|>") && tmpl_contains("<|USER_TOKEN|>"))) {
// CohereForAI/c4ai-command-r-plus
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
ss << "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>";
} else if (role == "user") {
ss << "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>";
} else if (role == "assistant") {
ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>";
}
}
if (add_ass) {
ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>";
}
} else if (tmpl == "llama3" || (tmpl_contains("<|start_header_id|>") && tmpl_contains("<|end_header_id|>"))) {
// Llama 3
for (auto message : chat) {
std::string role(message->role);
ss << "<|start_header_id|>" << role << "<|end_header_id|>\n\n" << trim(message->content) << "<|eot_id|>";
}
if (add_ass) {
ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
}
} else if (tmpl == "chatglm3" || tmpl_contains("[gMASK]sop")) {
// chatglm3-6b
ss << "[gMASK]" << "sop";
for (auto message : chat) {
std::string role(message->role);
ss << "<|" << role << "|>" << "\n " << message->content;
}
if (add_ass) {
ss << "<|assistant|>";
}
} else if (tmpl == "chatglm4" || tmpl_contains("[gMASK]<sop>")) {
ss << "[gMASK]" << "<sop>";
for (auto message : chat) {
std::string role(message->role);
ss << "<|" << role << "|>" << "\n" << message->content;
}
if (add_ass) {
ss << "<|assistant|>";
}
} else if (tmpl == "minicpm" || tmpl_contains(LU8("<用户>"))) {
// MiniCPM-3B-OpenHermes-2.5-v2-GGUF
for (auto message : chat) {
std::string role(message->role);
if (role == "user") {
ss << LU8("<用户>");
ss << trim(message->content);
ss << "<AI>";
} else {
ss << trim(message->content);
}
}
} else if (tmpl == "deepseek2" || tmpl_contains("'Assistant: ' + message['content'] + eos_token")) {
// DeepSeek-V2
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
ss << message->content << "\n\n";
} else if (role == "user") {
ss << "User: " << message->content << "\n\n";
} else if (role == "assistant") {
ss << "Assistant: " << message->content << LU8("<end▁of▁sentence>");
}
}
if (add_ass) {
ss << "Assistant:";
}
} else {
// template not supported
return -1;
}
dest = ss.str();
return dest.size();
}
int32_t llama_chat_apply_template(
const struct llama_model * model,
const char * tmpl,
const struct llama_chat_message * chat,
size_t n_msg,
bool add_ass,
char * buf,
int32_t length) {
std::string curr_tmpl(tmpl == nullptr ? "" : tmpl);
if (tmpl == nullptr) {
GGML_ASSERT(model != nullptr);
// load template from model
std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
std::string template_key = "tokenizer.chat_template";
int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
if (res < 0) {
// worst case: there is no information about template, we will use chatml by default
curr_tmpl = "chatml"; // see llama_chat_apply_template_internal
} else {
curr_tmpl = std::string(model_template.data(), model_template.size());
}
}
// format the chat to string
std::vector<const llama_chat_message *> chat_vec;
chat_vec.resize(n_msg);
for (size_t i = 0; i < n_msg; i++) {
chat_vec[i] = &chat[i];
}
std::string formatted_chat;
int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass);
if (res < 0) {
return res;
}
if (buf && length > 0) {
strncpy(buf, formatted_chat.c_str(), length);
}
return res;
}

View file

@ -72,3 +72,55 @@ std::vector<llama_vocab::id> llama_tokenize_internal(
bool parse_special = false);
llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch);
const char * llama_token_get_text(const struct llama_vocab & vocab, llama_token token);
float llama_token_get_score(const struct llama_vocab & vocab, llama_token token);
llama_token_attr llama_token_get_attr(const struct llama_vocab & vocab, llama_token token);
bool llama_token_is_eog(const struct llama_vocab & vocab, llama_token token);
bool llama_token_is_control(const struct llama_vocab & vocab, llama_token token);
llama_token llama_token_bos(const struct llama_vocab & vocab);
llama_token llama_token_eos(const struct llama_vocab & vocab);
llama_token llama_token_cls(const struct llama_vocab & vocab);
llama_token llama_token_sep(const struct llama_vocab & vocab);
llama_token llama_token_nl (const struct llama_vocab & vocab);
llama_token llama_token_pad(const struct llama_vocab & vocab);
int32_t llama_add_bos_token(const struct llama_vocab & vocab);
int32_t llama_add_eos_token(const struct llama_vocab & vocab);
llama_token llama_token_prefix(const struct llama_vocab & vocab);
llama_token llama_token_middle(const struct llama_vocab & vocab);
llama_token llama_token_suffix(const struct llama_vocab & vocab);
llama_token llama_token_eot (const struct llama_vocab & vocab);
int32_t llama_tokenize(
const struct llama_vocab & vocab,
const char * text,
int32_t text_len,
llama_token * tokens,
int32_t n_tokens_max,
bool add_special,
bool parse_special);
// does not write null-terminator to buf
int32_t llama_token_to_piece(
const struct llama_vocab & vocab,
llama_token token,
char * buf,
int32_t length,
int32_t lstrip,
bool special);
int32_t llama_detokenize(
const struct llama_vocab & vocab,
const llama_token * tokens,
int32_t n_tokens,
char * text,
int32_t text_len_max,
bool remove_special,
bool unparse_special);

View file

@ -48,6 +48,12 @@
#include <io.h>
#endif
#if __cplusplus >= 202000L
#define LU8(x) (const char*)(u8##x)
#else
#define LU8(x) u8##x
#endif
#include <algorithm>
#include <array>
#include <cassert>
@ -85,6 +91,19 @@
// helpers
//
// trim whitespace from the beginning and end of a string
static std::string trim(const std::string & str) {
size_t start = 0;
size_t end = str.size();
while (start < end && isspace(str[start])) {
start += 1;
}
while (end > start && isspace(str[end - 1])) {
end -= 1;
}
return str.substr(start, end - start);
}
static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
std::string result;
for (size_t pos = 0; ; pos += search.length()) {
@ -18487,6 +18506,527 @@ float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id
return it->second.data();
}
//
// vocab
//
const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
return llama_token_get_text(model->vocab, token);
}
float llama_token_get_score(const struct llama_model * model, llama_token token) {
return llama_token_get_score(model->vocab, token);
}
enum llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token) {
return llama_token_get_attr(model->vocab, token);
}
bool llama_token_is_eog(const struct llama_model * model, llama_token token) {
return llama_token_is_eog(model->vocab, token);
}
bool llama_token_is_control(const struct llama_model * model, llama_token token) {
return llama_token_is_control(model->vocab, token);
}
llama_token llama_token_bos(const struct llama_model * model) {
return llama_token_bos(model->vocab);
}
llama_token llama_token_eos(const struct llama_model * model) {
return llama_token_eos(model->vocab);
}
llama_token llama_token_cls(const struct llama_model * model) {
return llama_token_cls(model->vocab);
}
llama_token llama_token_sep(const struct llama_model * model) {
return llama_token_sep(model->vocab);
}
llama_token llama_token_nl (const struct llama_model * model) {
return llama_token_nl (model->vocab);
}
llama_token llama_token_pad(const struct llama_model * model) {
return llama_token_pad(model->vocab);
}
int32_t llama_add_bos_token(const struct llama_model * model) {
return llama_add_bos_token(model->vocab);
}
int32_t llama_add_eos_token(const struct llama_model * model) {
return llama_add_eos_token(model->vocab);
}
llama_token llama_token_prefix(const struct llama_model * model) {
return llama_token_prefix(model->vocab);
}
llama_token llama_token_middle(const struct llama_model * model) {
return llama_token_middle(model->vocab);
}
llama_token llama_token_suffix(const struct llama_model * model) {
return llama_token_suffix(model->vocab);
}
llama_token llama_token_eot(const struct llama_model * model) {
return llama_token_eot(model->vocab);
}
//
// tokenization
//
int32_t llama_tokenize(
const struct llama_model * model,
const char * text,
int32_t text_len,
llama_token * tokens,
int32_t n_tokens_max,
bool add_special,
bool parse_special) {
return llama_tokenize(model->vocab, text, text_len, tokens, n_tokens_max, add_special, parse_special);
}
int32_t llama_token_to_piece(
const struct llama_model * model,
llama_token token,
char * buf,
int32_t length,
int32_t lstrip,
bool special) {
return llama_token_to_piece(model->vocab, token, buf, length, lstrip, special);
}
int32_t llama_detokenize(
const struct llama_model * model,
const llama_token * tokens,
int32_t n_tokens,
char * text,
int32_t text_len_max,
bool remove_special,
bool unparse_special) {
return llama_detokenize(model->vocab, tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
}
//
// chat templates
//
// Simple version of "llama_apply_chat_template" that only works with strings
// This function uses heuristic checks to determine commonly used template. It is not a jinja parser.
static int32_t llama_chat_apply_template_internal(
const std::string & tmpl,
const std::vector<const llama_chat_message *> & chat,
std::string & dest, bool add_ass) {
// Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
std::stringstream ss;
auto tmpl_contains = [&tmpl](std::string haystack) -> bool {
return tmpl.find(haystack) != std::string::npos;
};
if (tmpl == "chatml" || tmpl_contains("<|im_start|>")) {
// chatml template
for (auto message : chat) {
ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n";
}
if (add_ass) {
ss << "<|im_start|>assistant\n";
}
} else if (tmpl == "llama2" || tmpl == "mistral" || tmpl_contains("[INST]")) {
// llama2 template and its variants
// [variant] support system message
bool support_system_message = tmpl_contains("<<SYS>>") || tmpl == "mistral";
// [variant] space before + after response
bool space_around_response = tmpl_contains("' ' + eos_token");
// [variant] add BOS inside history
bool add_bos_inside_history = tmpl_contains("bos_token + '[INST]");
// [variant] trim spaces from the input message
bool strip_message = tmpl_contains("content.strip()");
// construct the prompt
bool is_inside_turn = true; // skip BOS at the beginning
ss << "[INST] ";
for (auto message : chat) {
std::string content = strip_message ? trim(message->content) : message->content;
std::string role(message->role);
if (!is_inside_turn) {
is_inside_turn = true;
ss << (add_bos_inside_history ? "<s>[INST] " : "[INST] ");
}
if (role == "system") {
if (support_system_message) {
ss << "<<SYS>>\n" << content << "\n<</SYS>>\n\n";
} else {
// if the model does not support system message, we still include it in the first message, but without <<SYS>>
ss << content << "\n";
}
} else if (role == "user") {
ss << content << " [/INST]";
} else {
ss << (space_around_response ? " " : "") << content << (space_around_response ? " " : "") << "</s>";
is_inside_turn = false;
}
}
// llama2 templates seem to not care about "add_generation_prompt"
} else if (tmpl == "phi3" || (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>"))) {
// Phi 3
for (auto message : chat) {
std::string role(message->role);
ss << "<|" << role << "|>\n" << message->content << "<|end|>\n";
}
if (add_ass) {
ss << "<|assistant|>\n";
}
} else if (tmpl == "zephyr" || tmpl_contains("<|user|>")) {
// zephyr template
for (auto message : chat) {
ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n";
}
if (add_ass) {
ss << "<|assistant|>\n";
}
} else if (tmpl == "monarch" || tmpl_contains("bos_token + message['role']")) {
// mlabonne/AlphaMonarch-7B template (the <s> is included inside history)
for (auto message : chat) {
std::string bos = (message == chat.front()) ? "" : "<s>"; // skip BOS for first message
ss << bos << message->role << "\n" << message->content << "</s>\n";
}
if (add_ass) {
ss << "<s>assistant\n";
}
} else if (tmpl == "gemma" || tmpl == "gemma2" || tmpl_contains("<start_of_turn>")) {
// google/gemma-7b-it
std::string system_prompt = "";
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
// there is no system message for gemma, but we will merge it with user prompt, so nothing is broken
system_prompt = trim(message->content);
continue;
}
// in gemma, "assistant" is "model"
role = role == "assistant" ? "model" : message->role;
ss << "<start_of_turn>" << role << "\n";
if (!system_prompt.empty() && role != "model") {
ss << system_prompt << "\n\n";
system_prompt = "";
}
ss << trim(message->content) << "<end_of_turn>\n";
}
if (add_ass) {
ss << "<start_of_turn>model\n";
}
} else if (tmpl == "orion" || tmpl_contains("'\\n\\nAssistant: ' + eos_token")) {
// OrionStarAI/Orion-14B-Chat
std::string system_prompt = "";
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
// there is no system message support, we will merge it with user prompt
system_prompt = message->content;
continue;
} else if (role == "user") {
ss << "Human: ";
if (!system_prompt.empty()) {
ss << system_prompt << "\n\n";
system_prompt = "";
}
ss << message->content << "\n\nAssistant: </s>";
} else {
ss << message->content << "</s>";
}
}
} else if (tmpl == "openchat" || tmpl_contains("GPT4 Correct ")) {
// openchat/openchat-3.5-0106,
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
ss << message->content << "<|end_of_turn|>";
} else {
role[0] = toupper(role[0]);
ss << "GPT4 Correct " << role << ": " << message->content << "<|end_of_turn|>";
}
}
if (add_ass) {
ss << "GPT4 Correct Assistant:";
}
} else if (tmpl == "vicuna" || tmpl == "vicuna-orca" || (tmpl_contains("USER: ") && tmpl_contains("ASSISTANT: "))) {
// eachadea/vicuna-13b-1.1 (and Orca variant)
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
// Orca-Vicuna variant uses a system prefix
if (tmpl == "vicuna-orca" || tmpl_contains("SYSTEM: ")) {
ss << "SYSTEM: " << message->content << "\n";
} else {
ss << message->content << "\n\n";
}
} else if (role == "user") {
ss << "USER: " << message->content << "\n";
} else if (role == "assistant") {
ss << "ASSISTANT: " << message->content << "</s>\n";
}
}
if (add_ass) {
ss << "ASSISTANT:";
}
} else if (tmpl == "deepseek" || (tmpl_contains("### Instruction:") && tmpl_contains("<|EOT|>"))) {
// deepseek-ai/deepseek-coder-33b-instruct
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
ss << message->content;
} else if (role == "user") {
ss << "### Instruction:\n" << message->content << "\n";
} else if (role == "assistant") {
ss << "### Response:\n" << message->content << "\n<|EOT|>\n";
}
}
if (add_ass) {
ss << "### Response:\n";
}
} else if (tmpl == "command-r" || (tmpl_contains("<|START_OF_TURN_TOKEN|>") && tmpl_contains("<|USER_TOKEN|>"))) {
// CohereForAI/c4ai-command-r-plus
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
ss << "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>";
} else if (role == "user") {
ss << "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>";
} else if (role == "assistant") {
ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>";
}
}
if (add_ass) {
ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>";
}
} else if (tmpl == "llama3" || (tmpl_contains("<|start_header_id|>") && tmpl_contains("<|end_header_id|>"))) {
// Llama 3
for (auto message : chat) {
std::string role(message->role);
ss << "<|start_header_id|>" << role << "<|end_header_id|>\n\n" << trim(message->content) << "<|eot_id|>";
}
if (add_ass) {
ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
}
} else if (tmpl == "chatglm3" || tmpl_contains("[gMASK]sop")) {
// chatglm3-6b
ss << "[gMASK]" << "sop";
for (auto message : chat) {
std::string role(message->role);
ss << "<|" << role << "|>" << "\n " << message->content;
}
if (add_ass) {
ss << "<|assistant|>";
}
} else if (tmpl == "chatglm4" || tmpl_contains("[gMASK]<sop>")) {
ss << "[gMASK]" << "<sop>";
for (auto message : chat) {
std::string role(message->role);
ss << "<|" << role << "|>" << "\n" << message->content;
}
if (add_ass) {
ss << "<|assistant|>";
}
} else if (tmpl == "minicpm" || tmpl_contains(LU8("<用户>"))) {
// MiniCPM-3B-OpenHermes-2.5-v2-GGUF
for (auto message : chat) {
std::string role(message->role);
if (role == "user") {
ss << LU8("<用户>");
ss << trim(message->content);
ss << "<AI>";
} else {
ss << trim(message->content);
}
}
} else if (tmpl == "deepseek2" || tmpl_contains("'Assistant: ' + message['content'] + eos_token")) {
// DeepSeek-V2
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
ss << message->content << "\n\n";
} else if (role == "user") {
ss << "User: " << message->content << "\n\n";
} else if (role == "assistant") {
ss << "Assistant: " << message->content << LU8("<end▁of▁sentence>");
}
}
if (add_ass) {
ss << "Assistant:";
}
} else {
// template not supported
return -1;
}
dest = ss.str();
return dest.size();
}
int32_t llama_chat_apply_template(
const struct llama_model * model,
const char * tmpl,
const struct llama_chat_message * chat,
size_t n_msg,
bool add_ass,
char * buf,
int32_t length) {
std::string curr_tmpl(tmpl == nullptr ? "" : tmpl);
if (tmpl == nullptr) {
GGML_ASSERT(model != nullptr);
// load template from model
std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
std::string template_key = "tokenizer.chat_template";
int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
if (res < 0) {
// worst case: there is no information about template, we will use chatml by default
curr_tmpl = "chatml"; // see llama_chat_apply_template_internal
} else {
curr_tmpl = std::string(model_template.data(), model_template.size());
}
}
// format the chat to string
std::vector<const llama_chat_message *> chat_vec;
chat_vec.resize(n_msg);
for (size_t i = 0; i < n_msg; i++) {
chat_vec[i] = &chat[i];
}
std::string formatted_chat;
int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass);
if (res < 0) {
return res;
}
if (buf && length > 0) {
strncpy(buf, formatted_chat.c_str(), length);
}
return res;
}
//
// grammar
//
struct llama_grammar * llama_grammar_init(
const llama_grammar_element ** rules,
size_t n_rules,
size_t start_rule_index) {
return llama_grammar_init_impl(rules, n_rules, start_rule_index);
}
void llama_grammar_free(struct llama_grammar * grammar) {
llama_grammar_free_impl(grammar);
}
struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) {
return llama_grammar_copy_impl(grammar);
}
void llama_grammar_sample(
const struct llama_grammar * grammar,
const struct llama_context * ctx,
llama_token_data_array * candidates) {
llama_grammar_sample(grammar, &ctx->model.vocab, &ctx->sampling, candidates);
}
void llama_sample_grammar(
struct llama_context * ctx,
llama_token_data_array * candidates,
const struct llama_grammar * grammar) {
llama_grammar_sample(grammar, ctx, candidates);
}
void llama_grammar_accept_token(
struct llama_grammar * grammar,
struct llama_context * ctx,
llama_token token) {
llama_grammar_accept_token(grammar, &ctx->model.vocab, &ctx->sampling, token);
}
//
// sampling
//
void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) {
llama_set_rng_seed(&ctx->sampling, seed);
}
void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) {
llama_sample_softmax(ctx ? &ctx->sampling : nullptr, candidates);
}
void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
llama_sample_top_k(ctx ? &ctx->sampling : nullptr, candidates, k, min_keep);
}
void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
llama_sample_top_p(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
}
void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
llama_sample_min_p(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
}
void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
llama_sample_tail_free(ctx ? &ctx->sampling : nullptr, candidates, z, min_keep);
}
void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
llama_sample_typical(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
}
void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) {
llama_sample_entropy(ctx ? &ctx->sampling : nullptr, candidates_p, min_temp, max_temp, exponent_val);
}
void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) {
llama_sample_temp(ctx ? &ctx->sampling : nullptr, candidates_p, temp);
}
void llama_sample_repetition_penalties(
struct llama_context * ctx,
llama_token_data_array * candidates,
const llama_token * last_tokens,
size_t penalty_last_n,
float penalty_repeat,
float penalty_freq,
float penalty_present) {
llama_sample_repetition_penalties(ctx ? &ctx->sampling : nullptr, candidates, last_tokens, penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
}
void llama_sample_apply_guidance(
struct llama_context * ctx,
float * logits,
float * logits_guidance,
float scale) {
llama_sample_apply_guidance(&ctx->sampling, logits, logits_guidance, scale);
}
llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
return llama_sample_token_mirostat(&ctx->sampling, candidates, tau, eta, m, mu);
}
llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) {
return llama_sample_token_mirostat_v2(ctx ? &ctx->sampling : nullptr, candidates, tau, eta, mu);
}
llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) {
return llama_sample_token_greedy(ctx ? &ctx->sampling : nullptr, candidates);
}
llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) {
return llama_sample_token_with_rng(&ctx->sampling, candidates, rng);
}
llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
return llama_sample_token_with_rng(&ctx->sampling, candidates, ctx->sampling.rng);
}
int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) {
static const char * const SPLIT_PATH_FORMAT = "%s-%05d-of-%05d.gguf";
if (snprintf(split_path, maxlen, SPLIT_PATH_FORMAT, path_prefix, split_no + 1, split_count)) {