llama : suffix the internal APIs with "_impl"

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-07-22 19:59:00 +03:00
parent 39fbaf9f50
commit dae3cae841
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
7 changed files with 179 additions and 166 deletions

View file

@ -464,7 +464,7 @@ struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * gram
return result;
}
void llama_grammar_sample(const struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token_data_array * candidates) {
void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token_data_array * candidates) {
GGML_ASSERT(grammar);
GGML_ASSERT(vocab);
@ -488,7 +488,7 @@ void llama_grammar_sample(const struct llama_grammar * grammar, const struct lla
const llama_token id = candidates->data[i].id;
const std::string & piece = vocab->cache_token_to_piece.at(id);
if (llama_token_is_eog(*vocab, id)) {
if (llama_token_is_eog_impl(*vocab, id)) {
if (!allow_eog) {
candidates->data[i].logit = -INFINITY;
}
@ -508,10 +508,10 @@ void llama_grammar_sample(const struct llama_grammar * grammar, const struct lla
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
}
void llama_grammar_accept_token(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) {
void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) {
const int64_t t_start_sample_us = ggml_time_us();
if (llama_token_is_eog(*vocab, token)) {
if (llama_token_is_eog_impl(*vocab, token)) {
for (const auto & stack : grammar->stacks) {
if (stack.empty()) {
return;

View file

@ -15,6 +15,10 @@ struct llama_grammar {
struct llama_grammar * llama_get_grammar(struct llama_context * ctx);
//
// internal API
//
struct llama_grammar * llama_grammar_init_impl(
const llama_grammar_element ** rules,
size_t n_rules,
@ -24,13 +28,13 @@ void llama_grammar_free_impl(struct llama_grammar * grammar);
struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar);
void llama_grammar_sample(
void llama_grammar_sample_impl(
const struct llama_grammar * grammar,
const struct llama_vocab * vocab,
const struct llama_sampling * smpl,
llama_token_data_array * candidates);
void llama_grammar_accept_token(
void llama_grammar_accept_token_impl(
struct llama_grammar * grammar,
const struct llama_vocab * vocab,
const struct llama_sampling * smpl,

View file

@ -21,7 +21,7 @@ static void llama_log_softmax(float * array, size_t size) {
}
}
void llama_set_rng_seed(struct llama_sampling * smpl, uint32_t seed) {
void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed) {
if (seed == LLAMA_DEFAULT_SEED) {
seed = time(NULL);
}
@ -29,7 +29,7 @@ void llama_set_rng_seed(struct llama_sampling * smpl, uint32_t seed) {
smpl->rng.seed(seed);
}
void llama_sample_softmax(struct llama_sampling * smpl, llama_token_data_array * candidates) {
void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
GGML_ASSERT(candidates->size > 0);
const int64_t t_start_sample_us = ggml_time_us();
@ -58,7 +58,7 @@ void llama_sample_softmax(struct llama_sampling * smpl, llama_token_data_array *
}
}
void llama_sample_top_k(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
// TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
// if (k >= (int32_t)candidates->size) {
// return;
@ -139,12 +139,12 @@ void llama_sample_top_k(struct llama_sampling * smpl, llama_token_data_array * c
}
}
void llama_sample_top_p(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
if (p >= 1.0f) {
return;
}
llama_sample_softmax(smpl, candidates);
llama_sample_softmax_impl(smpl, candidates);
const int64_t t_start_sample_us = ggml_time_us();
@ -171,7 +171,7 @@ void llama_sample_top_p(struct llama_sampling * smpl, llama_token_data_array * c
}
}
void llama_sample_min_p(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
if (p <= 0.0f || !candidates->size) {
return;
}
@ -232,12 +232,12 @@ void llama_sample_min_p(struct llama_sampling * smpl, llama_token_data_array * c
}
}
void llama_sample_tail_free(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
if (z >= 1.0f || candidates->size <= 2) {
return;
}
llama_sample_softmax((struct llama_sampling *) nullptr, candidates);
llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
const int64_t t_start_sample_us = ggml_time_us();
// Compute the first and second derivatives
@ -291,7 +291,7 @@ void llama_sample_tail_free(struct llama_sampling * smpl, llama_token_data_array
}
}
void llama_sample_typical(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
// Reference implementation:
// https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
if (p >= 1.0f) {
@ -299,7 +299,7 @@ void llama_sample_typical(struct llama_sampling * smpl, llama_token_data_array *
}
// Compute the softmax of logits and calculate entropy
llama_sample_softmax((struct llama_sampling *) nullptr, candidates);
llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
const int64_t t_start_sample_us = ggml_time_us();
@ -355,7 +355,7 @@ void llama_sample_typical(struct llama_sampling * smpl, llama_token_data_array *
}
}
void llama_sample_entropy(struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) {
void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) {
const int64_t t_start_sample_us = ggml_time_us();
// no need to do anything if there is only one (or zero) candidates
@ -366,7 +366,7 @@ void llama_sample_entropy(struct llama_sampling * smpl, llama_token_data_array *
// Calculate maximum possible entropy
float max_entropy = -logf(1.0f / candidates->size);
llama_sample_softmax((struct llama_sampling *) nullptr, candidates);
llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
// Calculate entropy of the softmax probabilities
float entropy = 0.0f;
@ -422,7 +422,7 @@ void llama_sample_entropy(struct llama_sampling * smpl, llama_token_data_array *
}
}
void llama_sample_temp(struct llama_sampling * smpl, llama_token_data_array * candidates, float temp) {
void llama_sample_temp_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float temp) {
const int64_t t_start_sample_us = ggml_time_us();
for (size_t i = 0; i < candidates->size; ++i) {
@ -434,7 +434,7 @@ void llama_sample_temp(struct llama_sampling * smpl, llama_token_data_array * ca
}
}
void llama_sample_repetition_penalties(
void llama_sample_repetition_penalties_impl(
struct llama_sampling * smpl,
llama_token_data_array * candidates,
const llama_token * last_tokens,
@ -481,7 +481,7 @@ void llama_sample_repetition_penalties(
}
}
void llama_sample_apply_guidance(
void llama_sample_apply_guidance_impl(
struct llama_sampling * smpl,
float * logits,
float * logits_guidance,
@ -504,14 +504,14 @@ void llama_sample_apply_guidance(
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
}
llama_token llama_sample_token_mirostat(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
GGML_ASSERT(smpl);
const int32_t n_vocab = float(smpl->n_vocab);
int64_t t_start_sample_us = ggml_time_us();
llama_sample_softmax((struct llama_sampling *) nullptr, candidates);
llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
// Estimate s_hat using the most probable m tokens
float s_hat = 0.0;
@ -530,9 +530,9 @@ llama_token llama_sample_token_mirostat(struct llama_sampling * smpl, llama_toke
float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat);
// Sample the next word X using top-k sampling
llama_sample_top_k((struct llama_sampling *) nullptr, candidates, int(k), 1);
llama_sample_top_k_impl((struct llama_sampling *) nullptr, candidates, int(k), 1);
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
llama_token X = llama_sample_token(smpl, candidates);
llama_token X = llama_sample_token_impl(smpl, candidates);
t_start_sample_us = ggml_time_us();
// Compute error as the difference between observed surprise and target surprise value
@ -549,11 +549,11 @@ llama_token llama_sample_token_mirostat(struct llama_sampling * smpl, llama_toke
return X;
}
llama_token llama_sample_token_mirostat_v2(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) {
llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) {
int64_t t_start_sample_us;
t_start_sample_us = ggml_time_us();
llama_sample_softmax(smpl, candidates);
llama_sample_softmax_impl(smpl, candidates);
// Truncate the words with surprise values greater than mu
candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
@ -569,10 +569,10 @@ llama_token llama_sample_token_mirostat_v2(struct llama_sampling * smpl, llama_t
}
// Normalize the probabilities of the remaining words
llama_sample_softmax(smpl, candidates);
llama_sample_softmax_impl(smpl, candidates);
// Sample the next word X from the remaining words
llama_token X = llama_sample_token(smpl, candidates);
llama_token X = llama_sample_token_impl(smpl, candidates);
t_start_sample_us = ggml_time_us();
// Compute error as the difference between observed surprise and target surprise value
@ -591,7 +591,7 @@ llama_token llama_sample_token_mirostat_v2(struct llama_sampling * smpl, llama_t
return X;
}
llama_token llama_sample_token_greedy(struct llama_sampling * smpl, llama_token_data_array * candidates) {
llama_token llama_sample_token_greedy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
const int64_t t_start_sample_us = ggml_time_us();
// Find max element
@ -607,11 +607,11 @@ llama_token llama_sample_token_greedy(struct llama_sampling * smpl, llama_token_
return result;
}
llama_token llama_sample_token_with_rng(struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng) {
llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng) {
GGML_ASSERT(smpl);
const int64_t t_start_sample_us = ggml_time_us();
llama_sample_softmax((struct llama_sampling *) nullptr, candidates);
llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
std::vector<float> probs;
probs.reserve(candidates->size);
@ -630,6 +630,6 @@ llama_token llama_sample_token_with_rng(struct llama_sampling * smpl, llama_toke
return result;
}
llama_token llama_sample_token(struct llama_sampling * smpl, llama_token_data_array * candidates) {
return llama_sample_token_with_rng(smpl, candidates, smpl->rng);
llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
return llama_sample_token_with_rng_impl(smpl, candidates, smpl->rng);
}

View file

@ -20,18 +20,22 @@ struct llama_sampling {
struct llama_sampling * llama_get_sampling(struct llama_context * ctx);
void llama_set_rng_seed(struct llama_sampling * smpl, uint32_t seed);
//
// internal API
//
void llama_sample_softmax (struct llama_sampling * smpl, llama_token_data_array * candidates);
void llama_sample_top_k (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep);
void llama_sample_top_p (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
void llama_sample_min_p (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
void llama_sample_tail_free(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep);
void llama_sample_typical (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
void llama_sample_entropy (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);
void llama_sample_temp (struct llama_sampling * smpl, llama_token_data_array * candidates, float temp);
void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed);
void llama_sample_repetition_penalties(
void llama_sample_softmax_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
void llama_sample_top_k_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep);
void llama_sample_top_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
void llama_sample_min_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep);
void llama_sample_typical_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
void llama_sample_entropy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);
void llama_sample_temp_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float temp);
void llama_sample_repetition_penalties_impl(
struct llama_sampling * smpl,
llama_token_data_array * candidates,
const llama_token * last_tokens,
@ -40,15 +44,15 @@ void llama_sample_repetition_penalties(
float penalty_freq,
float penalty_present);
void llama_sample_apply_guidance(
void llama_sample_apply_guidance_impl(
struct llama_sampling * smpl,
float * logits,
float * logits_guidance,
float scale);
llama_token llama_sample_token_mirostat (struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu);
llama_token llama_sample_token_mirostat_v2(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu);
llama_token llama_sample_token_greedy (struct llama_sampling * smpl, llama_token_data_array * candidates);
llama_token llama_sample_token_with_rng (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng);
llama_token llama_sample_token (struct llama_sampling * smpl, llama_token_data_array * candidates);
llama_token llama_sample_token_mirostat_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu);
llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu);
llama_token llama_sample_token_greedy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
llama_token llama_sample_token_with_rng_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng);
llama_token llama_sample_token_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);

View file

@ -163,30 +163,6 @@ static uint8_t llama_token_to_byte(const llama_vocab & vocab, llama_token id) {
}
}
llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) {
GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE);
static const char * hex = "0123456789ABCDEF";
switch (llama_vocab_get_type(vocab)) {
case LLAMA_VOCAB_TYPE_SPM:
case LLAMA_VOCAB_TYPE_UGM: {
const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 };
auto token = vocab.token_to_id.find(buf);
if (token != vocab.token_to_id.end()) {
return (*token).second;
}
// Try to fall back to just the byte as a string
const char buf2[2] = { (char)ch, 0 };
return vocab.token_to_id.at(buf2);
}
case LLAMA_VOCAB_TYPE_WPM:
case LLAMA_VOCAB_TYPE_BPE: {
return vocab.token_to_id.at(unicode_byte_to_utf8(ch));
}
default:
GGML_ASSERT(false);
}
}
static void llama_escape_whitespace(std::string & text) {
replace_all(text, " ", "\xe2\x96\x81");
}
@ -303,7 +279,7 @@ private:
// output any symbols that did not form tokens as bytes.
output.reserve(output.size() + symbol.n);
for (int j = 0; j < (int)symbol.n; ++j) {
llama_vocab::id token_id = llama_byte_to_token(vocab, symbol.text[j]);
llama_vocab::id token_id = llama_byte_to_token_impl(vocab, symbol.text[j]);
output.push_back(token_id);
}
return;
@ -1426,81 +1402,105 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
return output;
}
const char * llama_token_get_text(const struct llama_vocab & vocab, llama_token token) {
llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch) {
GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE);
static const char * hex = "0123456789ABCDEF";
switch (llama_vocab_get_type(vocab)) {
case LLAMA_VOCAB_TYPE_SPM:
case LLAMA_VOCAB_TYPE_UGM: {
const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 };
auto token = vocab.token_to_id.find(buf);
if (token != vocab.token_to_id.end()) {
return (*token).second;
}
// Try to fall back to just the byte as a string
const char buf2[2] = { (char)ch, 0 };
return vocab.token_to_id.at(buf2);
}
case LLAMA_VOCAB_TYPE_WPM:
case LLAMA_VOCAB_TYPE_BPE: {
return vocab.token_to_id.at(unicode_byte_to_utf8(ch));
}
default:
GGML_ASSERT(false);
}
}
const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token) {
GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
return vocab.id_to_token[token].text.c_str();
}
float llama_token_get_score(const struct llama_vocab & vocab, llama_token token) {
float llama_token_get_score_impl(const struct llama_vocab & vocab, llama_token token) {
GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
return vocab.id_to_token[token].score;
}
llama_token_attr llama_token_get_attr(const struct llama_vocab & vocab, llama_token token) {
llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, llama_token token) {
GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
return vocab.id_to_token[token].attr;
}
bool llama_token_is_eog(const struct llama_vocab & vocab, llama_token token) {
bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token) {
return token != -1 && (
token == llama_token_eos(vocab) ||
token == llama_token_eot(vocab)
token == llama_token_eos_impl(vocab) ||
token == llama_token_eot_impl(vocab)
);
}
bool llama_token_is_control(const struct llama_vocab & vocab, llama_token token) {
bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token) {
return llama_is_control_token(vocab, token);
}
llama_token llama_token_bos(const struct llama_vocab & vocab) {
llama_token llama_token_bos_impl(const struct llama_vocab & vocab) {
return vocab.special_bos_id;
}
llama_token llama_token_eos(const struct llama_vocab & vocab) {
llama_token llama_token_eos_impl(const struct llama_vocab & vocab) {
return vocab.special_eos_id;
}
llama_token llama_token_cls(const struct llama_vocab & vocab) {
llama_token llama_token_cls_impl(const struct llama_vocab & vocab) {
return vocab.special_cls_id;
}
llama_token llama_token_sep(const struct llama_vocab & vocab) {
llama_token llama_token_sep_impl(const struct llama_vocab & vocab) {
return vocab.special_sep_id;
}
llama_token llama_token_nl(const struct llama_vocab & vocab) {
llama_token llama_token_nl_impl(const struct llama_vocab & vocab) {
return vocab.linefeed_id;
}
llama_token llama_token_pad(const struct llama_vocab & vocab) {
llama_token llama_token_pad_impl(const struct llama_vocab & vocab) {
return vocab.special_pad_id;
}
int32_t llama_add_bos_token(const struct llama_vocab & vocab) {
int32_t llama_add_bos_token_impl(const struct llama_vocab & vocab) {
return vocab.tokenizer_add_bos;
}
int32_t llama_add_eos_token(const struct llama_vocab & vocab) {
int32_t llama_add_eos_token_impl(const struct llama_vocab & vocab) {
return vocab.tokenizer_add_eos;
}
llama_token llama_token_prefix(const struct llama_vocab & vocab) {
llama_token llama_token_prefix_impl(const struct llama_vocab & vocab) {
return vocab.special_prefix_id;
}
llama_token llama_token_middle(const struct llama_vocab & vocab) {
llama_token llama_token_middle_impl(const struct llama_vocab & vocab) {
return vocab.special_middle_id;
}
llama_token llama_token_suffix(const struct llama_vocab & vocab) {
llama_token llama_token_suffix_impl(const struct llama_vocab & vocab) {
return vocab.special_suffix_id;
}
llama_token llama_token_eot(const struct llama_vocab & vocab) {
llama_token llama_token_eot_impl(const struct llama_vocab & vocab) {
return vocab.special_eot_id;
}
int32_t llama_tokenize(
int32_t llama_tokenize_impl(
const struct llama_vocab & vocab,
const char * text,
int32_t text_len,
@ -1542,10 +1542,10 @@ static std::string llama_decode_text(const std::string & text) {
}
// does not write null-terminator to buf
int32_t llama_token_to_piece(const struct llama_vocab & vocab, llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) {
int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) {
// ref: https://github.com/ggerganov/llama.cpp/pull/7587#discussion_r1620983843
static const int attr_special = LLAMA_TOKEN_ATTR_UNKNOWN | LLAMA_TOKEN_ATTR_CONTROL;
const llama_token_attr attr = llama_token_get_attr(vocab, token);
const llama_token_attr attr = llama_token_get_attr_impl(vocab, token);
if (!special && (attr & attr_special)) {
return 0;
}
@ -1613,7 +1613,7 @@ int32_t llama_token_to_piece(const struct llama_vocab & vocab, llama_token token
return 0;
}
int32_t llama_detokenize(
int32_t llama_detokenize_impl(
const struct llama_vocab & vocab,
const llama_token * tokens,
int32_t n_tokens,
@ -1643,7 +1643,7 @@ int32_t llama_detokenize(
for (int32_t i = 0; i < n_tokens; ++i) {
GGML_ASSERT(avail >= 0);
int32_t n_chars = llama_token_to_piece(vocab, tokens[i], text, avail, remove_space, unparse_special);
int32_t n_chars = llama_token_to_piece_impl(vocab, tokens[i], text, avail, remove_space, unparse_special);
remove_space = false;
if (n_chars < 0) {
avail = 0;

View file

@ -64,6 +64,11 @@ struct llama_vocab {
const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx);
const struct llama_vocab * llama_get_vocab(const struct llama_model * model);
//
// internal API
//
// TODO: rename to llama_tokenize_impl
// TODO: This should probably be in llama.h
std::vector<llama_vocab::id> llama_tokenize_internal(
const llama_vocab & vocab,
@ -71,44 +76,44 @@ std::vector<llama_vocab::id> llama_tokenize_internal(
bool add_special,
bool parse_special = false);
llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch);
llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch);
const char * llama_token_get_text(const struct llama_vocab & vocab, llama_token token);
const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token);
float llama_token_get_score(const struct llama_vocab & vocab, llama_token token);
float llama_token_get_score_impl(const struct llama_vocab & vocab, llama_token token);
llama_token_attr llama_token_get_attr(const struct llama_vocab & vocab, llama_token token);
llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, llama_token token);
bool llama_token_is_eog(const struct llama_vocab & vocab, llama_token token);
bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token);
bool llama_token_is_control(const struct llama_vocab & vocab, llama_token token);
bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token);
llama_token llama_token_bos(const struct llama_vocab & vocab);
llama_token llama_token_eos(const struct llama_vocab & vocab);
llama_token llama_token_cls(const struct llama_vocab & vocab);
llama_token llama_token_sep(const struct llama_vocab & vocab);
llama_token llama_token_nl (const struct llama_vocab & vocab);
llama_token llama_token_pad(const struct llama_vocab & vocab);
llama_token llama_token_bos_impl(const struct llama_vocab & vocab);
llama_token llama_token_eos_impl(const struct llama_vocab & vocab);
llama_token llama_token_cls_impl(const struct llama_vocab & vocab);
llama_token llama_token_sep_impl(const struct llama_vocab & vocab);
llama_token llama_token_nl_impl (const struct llama_vocab & vocab);
llama_token llama_token_pad_impl(const struct llama_vocab & vocab);
int32_t llama_add_bos_token(const struct llama_vocab & vocab);
int32_t llama_add_eos_token(const struct llama_vocab & vocab);
int32_t llama_add_bos_token_impl(const struct llama_vocab & vocab);
int32_t llama_add_eos_token_impl(const struct llama_vocab & vocab);
llama_token llama_token_prefix(const struct llama_vocab & vocab);
llama_token llama_token_middle(const struct llama_vocab & vocab);
llama_token llama_token_suffix(const struct llama_vocab & vocab);
llama_token llama_token_eot (const struct llama_vocab & vocab);
llama_token llama_token_prefix_impl(const struct llama_vocab & vocab);
llama_token llama_token_middle_impl(const struct llama_vocab & vocab);
llama_token llama_token_suffix_impl(const struct llama_vocab & vocab);
llama_token llama_token_eot_impl (const struct llama_vocab & vocab);
int32_t llama_tokenize(
const struct llama_vocab & vocab,
const char * text,
int32_t text_len,
llama_token * tokens,
int32_t n_tokens_max,
bool add_special,
bool parse_special);
int32_t llama_tokenize_impl(
const struct llama_vocab & vocab,
const char * text,
int32_t text_len,
llama_token * tokens,
int32_t n_tokens_max,
bool add_special,
bool parse_special);
// does not write null-terminator to buf
int32_t llama_token_to_piece(
int32_t llama_token_to_piece_impl(
const struct llama_vocab & vocab,
llama_token token,
char * buf,
@ -116,7 +121,7 @@ int32_t llama_token_to_piece(
int32_t lstrip,
bool special);
int32_t llama_detokenize(
int32_t llama_detokenize_impl(
const struct llama_vocab & vocab,
const llama_token * tokens,
int32_t n_tokens,

View file

@ -5535,7 +5535,7 @@ static void llm_load_vocab(
}
}
try {
vocab.linefeed_id = llama_byte_to_token(vocab, '\n');
vocab.linefeed_id = llama_byte_to_token_impl(vocab, '\n');
} catch (const std::exception & e) {
LLAMA_LOG_WARN("%s: SPM vocabulary, but newline token not found: %s! Using special_pad_id instead.", __func__, e.what());
vocab.linefeed_id = vocab.special_pad_id;
@ -18511,71 +18511,71 @@ float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id
//
const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
return llama_token_get_text(model->vocab, token);
return llama_token_get_text_impl(model->vocab, token);
}
float llama_token_get_score(const struct llama_model * model, llama_token token) {
return llama_token_get_score(model->vocab, token);
return llama_token_get_score_impl(model->vocab, token);
}
enum llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token) {
return llama_token_get_attr(model->vocab, token);
return llama_token_get_attr_impl(model->vocab, token);
}
bool llama_token_is_eog(const struct llama_model * model, llama_token token) {
return llama_token_is_eog(model->vocab, token);
return llama_token_is_eog_impl(model->vocab, token);
}
bool llama_token_is_control(const struct llama_model * model, llama_token token) {
return llama_token_is_control(model->vocab, token);
return llama_token_is_control_impl(model->vocab, token);
}
llama_token llama_token_bos(const struct llama_model * model) {
return llama_token_bos(model->vocab);
return llama_token_bos_impl(model->vocab);
}
llama_token llama_token_eos(const struct llama_model * model) {
return llama_token_eos(model->vocab);
return llama_token_eos_impl(model->vocab);
}
llama_token llama_token_cls(const struct llama_model * model) {
return llama_token_cls(model->vocab);
return llama_token_cls_impl(model->vocab);
}
llama_token llama_token_sep(const struct llama_model * model) {
return llama_token_sep(model->vocab);
return llama_token_sep_impl(model->vocab);
}
llama_token llama_token_nl (const struct llama_model * model) {
return llama_token_nl (model->vocab);
return llama_token_nl_impl(model->vocab);
}
llama_token llama_token_pad(const struct llama_model * model) {
return llama_token_pad(model->vocab);
return llama_token_pad_impl(model->vocab);
}
int32_t llama_add_bos_token(const struct llama_model * model) {
return llama_add_bos_token(model->vocab);
return llama_add_bos_token_impl(model->vocab);
}
int32_t llama_add_eos_token(const struct llama_model * model) {
return llama_add_eos_token(model->vocab);
return llama_add_eos_token_impl(model->vocab);
}
llama_token llama_token_prefix(const struct llama_model * model) {
return llama_token_prefix(model->vocab);
return llama_token_prefix_impl(model->vocab);
}
llama_token llama_token_middle(const struct llama_model * model) {
return llama_token_middle(model->vocab);
return llama_token_middle_impl(model->vocab);
}
llama_token llama_token_suffix(const struct llama_model * model) {
return llama_token_suffix(model->vocab);
return llama_token_suffix_impl(model->vocab);
}
llama_token llama_token_eot(const struct llama_model * model) {
return llama_token_eot(model->vocab);
return llama_token_eot_impl(model->vocab);
}
//
@ -18590,7 +18590,7 @@ int32_t llama_tokenize(
int32_t n_tokens_max,
bool add_special,
bool parse_special) {
return llama_tokenize(model->vocab, text, text_len, tokens, n_tokens_max, add_special, parse_special);
return llama_tokenize_impl(model->vocab, text, text_len, tokens, n_tokens_max, add_special, parse_special);
}
int32_t llama_token_to_piece(
@ -18600,7 +18600,7 @@ int32_t llama_token_to_piece(
int32_t length,
int32_t lstrip,
bool special) {
return llama_token_to_piece(model->vocab, token, buf, length, lstrip, special);
return llama_token_to_piece_impl(model->vocab, token, buf, length, lstrip, special);
}
int32_t llama_detokenize(
@ -18611,7 +18611,7 @@ int32_t llama_detokenize(
int32_t text_len_max,
bool remove_special,
bool unparse_special) {
return llama_detokenize(model->vocab, tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
return llama_detokenize_impl(model->vocab, tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
}
//
@ -18931,7 +18931,7 @@ void llama_grammar_sample(
const struct llama_grammar * grammar,
const struct llama_context * ctx,
llama_token_data_array * candidates) {
llama_grammar_sample(grammar, &ctx->model.vocab, &ctx->sampling, candidates);
llama_grammar_sample_impl(grammar, &ctx->model.vocab, &ctx->sampling, candidates);
}
void llama_sample_grammar(
@ -18945,7 +18945,7 @@ void llama_grammar_accept_token(
struct llama_grammar * grammar,
struct llama_context * ctx,
llama_token token) {
llama_grammar_accept_token(grammar, &ctx->model.vocab, &ctx->sampling, token);
llama_grammar_accept_token_impl(grammar, &ctx->model.vocab, &ctx->sampling, token);
}
//
@ -18953,39 +18953,39 @@ void llama_grammar_accept_token(
//
void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) {
llama_set_rng_seed(&ctx->sampling, seed);
llama_set_rng_seed_impl(&ctx->sampling, seed);
}
void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) {
llama_sample_softmax(ctx ? &ctx->sampling : nullptr, candidates);
llama_sample_softmax_impl(ctx ? &ctx->sampling : nullptr, candidates);
}
void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
llama_sample_top_k(ctx ? &ctx->sampling : nullptr, candidates, k, min_keep);
llama_sample_top_k_impl(ctx ? &ctx->sampling : nullptr, candidates, k, min_keep);
}
void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
llama_sample_top_p(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
llama_sample_top_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
}
void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
llama_sample_min_p(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
llama_sample_min_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
}
void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
llama_sample_tail_free(ctx ? &ctx->sampling : nullptr, candidates, z, min_keep);
llama_sample_tail_free_impl(ctx ? &ctx->sampling : nullptr, candidates, z, min_keep);
}
void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
llama_sample_typical(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
llama_sample_typical_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
}
void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) {
llama_sample_entropy(ctx ? &ctx->sampling : nullptr, candidates_p, min_temp, max_temp, exponent_val);
llama_sample_entropy_impl(ctx ? &ctx->sampling : nullptr, candidates_p, min_temp, max_temp, exponent_val);
}
void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) {
llama_sample_temp(ctx ? &ctx->sampling : nullptr, candidates_p, temp);
llama_sample_temp_impl(ctx ? &ctx->sampling : nullptr, candidates_p, temp);
}
void llama_sample_repetition_penalties(
@ -18996,7 +18996,7 @@ void llama_sample_repetition_penalties(
float penalty_repeat,
float penalty_freq,
float penalty_present) {
llama_sample_repetition_penalties(ctx ? &ctx->sampling : nullptr, candidates, last_tokens, penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
llama_sample_repetition_penalties_impl(ctx ? &ctx->sampling : nullptr, candidates, last_tokens, penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
}
void llama_sample_apply_guidance(
@ -19004,27 +19004,27 @@ void llama_sample_apply_guidance(
float * logits,
float * logits_guidance,
float scale) {
llama_sample_apply_guidance(&ctx->sampling, logits, logits_guidance, scale);
llama_sample_apply_guidance_impl(&ctx->sampling, logits, logits_guidance, scale);
}
llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
return llama_sample_token_mirostat(&ctx->sampling, candidates, tau, eta, m, mu);
return llama_sample_token_mirostat_impl(&ctx->sampling, candidates, tau, eta, m, mu);
}
llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) {
return llama_sample_token_mirostat_v2(ctx ? &ctx->sampling : nullptr, candidates, tau, eta, mu);
return llama_sample_token_mirostat_v2_impl(ctx ? &ctx->sampling : nullptr, candidates, tau, eta, mu);
}
llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) {
return llama_sample_token_greedy(ctx ? &ctx->sampling : nullptr, candidates);
return llama_sample_token_greedy_impl(ctx ? &ctx->sampling : nullptr, candidates);
}
llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) {
return llama_sample_token_with_rng(&ctx->sampling, candidates, rng);
return llama_sample_token_with_rng_impl(&ctx->sampling, candidates, rng);
}
llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
return llama_sample_token_with_rng(&ctx->sampling, candidates, ctx->sampling.rng);
return llama_sample_token_with_rng_impl(&ctx->sampling, candidates, ctx->sampling.rng);
}
int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) {