llama : avoid hardcoded special tokens

This commit is contained in:
Georgi Gerganov 2023-08-18 17:29:20 +03:00
parent 035d511457
commit 5d2656d670
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
11 changed files with 61 additions and 65 deletions

View file

@ -427,7 +427,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
} }
params.hellaswag_tasks = std::stoi(argv[i]); params.hellaswag_tasks = std::stoi(argv[i]);
} else if (arg == "--ignore-eos") { } else if (arg == "--ignore-eos") {
params.logit_bias[llama_token_eos()] = -INFINITY; params.ignore_eos = true;
} else if (arg == "--no-penalize-nl") { } else if (arg == "--no-penalize-nl") {
params.penalize_nl = false; params.penalize_nl = false;
} else if (arg == "-l" || arg == "--logit-bias") { } else if (arg == "-l" || arg == "--logit-bias") {
@ -662,7 +662,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
return lparams; return lparams;
} }
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params) { std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
auto lparams = llama_context_params_from_gpt_params(params); auto lparams = llama_context_params_from_gpt_params(params);
llama_model * model = llama_load_model_from_file(params.model.c_str(), lparams); llama_model * model = llama_load_model_from_file(params.model.c_str(), lparams);
@ -691,6 +691,10 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
} }
} }
if (params.ignore_eos) {
params.logit_bias[llama_token_eos(lctx)] = -INFINITY;
}
return std::make_tuple(model, lctx); return std::make_tuple(model, lctx);
} }

View file

@ -32,7 +32,6 @@ struct gpt_params {
float rope_freq_scale = 1.0f; // RoPE frequency scaling factor float rope_freq_scale = 1.0f; // RoPE frequency scaling factor
// sampling parameters // sampling parameters
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
int32_t top_k = 40; // <= 0 to use vocab size int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled float top_p = 0.95f; // 1.0 = disabled
float tfs_z = 1.00f; // 1.0 = disabled float tfs_z = 1.00f; // 1.0 = disabled
@ -46,6 +45,8 @@ struct gpt_params {
float mirostat_tau = 5.00f; // target entropy float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate float mirostat_eta = 0.10f; // learning rate
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
// Classifier-Free Guidance // Classifier-Free Guidance
// https://arxiv.org/abs/2306.17806 // https://arxiv.org/abs/2306.17806
std::string cfg_negative_prompt; // string to help guidance std::string cfg_negative_prompt; // string to help guidance
@ -81,6 +82,7 @@ struct gpt_params {
bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool simple_io = false; // improves compatibility with subprocesses and limited consoles
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
bool ignore_eos = false; // ignore generated EOS tokens
bool instruct = false; // instruction mode (used for Alpaca models) bool instruct = false; // instruction mode (used for Alpaca models)
bool penalize_nl = true; // consider newlines as a repeatable token bool penalize_nl = true; // consider newlines as a repeatable token
bool perplexity = false; // compute perplexity over the prompt bool perplexity = false; // compute perplexity over the prompt
@ -102,7 +104,7 @@ std::string gpt_random_prompt(std::mt19937 & rng);
// Model utils // Model utils
// //
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params); std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params);
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params); struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
// //

View file

@ -167,7 +167,7 @@ llama_token sampling_id(struct MyModel* mymodel) {
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
// TODO: Apply penalties // TODO: Apply penalties
// float nl_logit = logits[llama_token_nl()]; // float nl_logit = logits[llama_token_nl(ctx)];
// auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx); // auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
// llama_sample_repetition_penalty(ctx, &candidates_p, // llama_sample_repetition_penalty(ctx, &candidates_p,
// last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, // last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
@ -176,7 +176,7 @@ llama_token sampling_id(struct MyModel* mymodel) {
// last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, // last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
// last_n_repeat, alpha_frequency, alpha_presence); // last_n_repeat, alpha_frequency, alpha_presence);
// if (!penalize_nl) { // if (!penalize_nl) {
// logits[llama_token_nl()] = nl_logit; // logits[llama_token_nl(ctx)] = nl_logit;
// } // }
if (temp <= 0) { if (temp <= 0) {
@ -211,7 +211,7 @@ const char * sampling(struct MyModel * mymodel) {
llama_context * ctx = mymodel->ctx; llama_context * ctx = mymodel->ctx;
int id = sampling_id(mymodel); int id = sampling_id(mymodel);
static std::string ret; static std::string ret;
if (id == llama_token_eos()) { if (id == llama_token_eos(ctx)) {
ret = "</s>"; ret = "</s>";
} else { } else {
ret = llama_token_to_str(ctx, id); ret = llama_token_to_str(ctx, id);

View file

@ -851,7 +851,7 @@ struct sql_printer : public printer {
}; };
static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) { static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
std::vector<llama_token> tokens(n_batch, llama_token_bos()); std::vector<llama_token> tokens(n_batch, llama_token_bos(ctx));
int n_processed = 0; int n_processed = 0;
while (n_processed < n_prompt) { while (n_processed < n_prompt) {
int n_tokens = std::min(n_prompt - n_processed, n_batch); int n_tokens = std::min(n_prompt - n_processed, n_batch);
@ -861,7 +861,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
} }
static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) { static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
llama_token token = llama_token_bos(); llama_token token = llama_token_bos(ctx);
for (int i = 0; i < n_gen; i++) { for (int i = 0; i < n_gen; i++) {
llama_eval(ctx, &token, 1, n_past + i, n_threads); llama_eval(ctx, &token, 1, n_past + i, n_threads);
} }

View file

@ -143,7 +143,7 @@ int main(int argc, char ** argv) {
{ {
fprintf(stderr, "%s: testing memory usage for n_batch = %d, n_ctx = %d\n", __func__, params.n_batch, params.n_ctx); fprintf(stderr, "%s: testing memory usage for n_batch = %d, n_ctx = %d\n", __func__, params.n_batch, params.n_ctx);
const std::vector<llama_token> tmp(params.n_batch, llama_token_bos()); const std::vector<llama_token> tmp(params.n_batch, llama_token_bos(ctx));
llama_eval(ctx, tmp.data(), tmp.size(), params.n_ctx, params.n_threads); llama_eval(ctx, tmp.data(), tmp.size(), params.n_ctx, params.n_threads);
} }
@ -345,10 +345,9 @@ int main(int argc, char ** argv) {
fprintf(stderr, "\n"); fprintf(stderr, "\n");
{ {
auto it = params.logit_bias.find(llama_token_eos()); auto it = params.logit_bias.find(llama_token_eos(ctx));
if (it != params.logit_bias.end() && it->second == -INFINITY) { if (it != params.logit_bias.end() && it->second == -INFINITY) {
fprintf(stderr, fprintf(stderr, "%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__);
"%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__);
} }
} }
@ -398,7 +397,7 @@ int main(int argc, char ** argv) {
// do one empty run to warm up the model // do one empty run to warm up the model
{ {
const std::vector<llama_token> tmp = { llama_token_bos(), }; const std::vector<llama_token> tmp = { llama_token_bos(ctx), };
llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads); llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads);
llama_reset_timings(ctx); llama_reset_timings(ctx);
} }
@ -582,7 +581,7 @@ int main(int argc, char ** argv) {
} }
// Apply penalties // Apply penalties
float nl_logit = logits[llama_token_nl()]; float nl_logit = logits[llama_token_nl(ctx)];
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx); auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
llama_sample_repetition_penalty(ctx, &candidates_p, llama_sample_repetition_penalty(ctx, &candidates_p,
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
@ -591,7 +590,7 @@ int main(int argc, char ** argv) {
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
last_n_repeat, alpha_frequency, alpha_presence); last_n_repeat, alpha_frequency, alpha_presence);
if (!penalize_nl) { if (!penalize_nl) {
logits[llama_token_nl()] = nl_logit; logits[llama_token_nl(ctx)] = nl_logit;
} }
if (grammar != NULL) { if (grammar != NULL) {
@ -697,7 +696,7 @@ int main(int argc, char ** argv) {
} }
// deal with end of text token in interactive mode // deal with end of text token in interactive mode
if (last_n_tokens.back() == llama_token_eos()) { if (last_n_tokens.back() == llama_token_eos(ctx)) {
if (params.interactive) { if (params.interactive) {
if (params.antiprompt.size() != 0) { if (params.antiprompt.size() != 0) {
// tokenize and inject first reverse prompt // tokenize and inject first reverse prompt
@ -721,7 +720,7 @@ int main(int argc, char ** argv) {
} }
if (params.input_prefix_bos) { if (params.input_prefix_bos) {
embd_inp.push_back(llama_token_bos()); embd_inp.push_back(llama_token_bos(ctx));
} }
std::string buffer; std::string buffer;
@ -786,7 +785,7 @@ int main(int argc, char ** argv) {
} }
// end of text token // end of text token
if (!embd.empty() && embd.back() == llama_token_eos() && !(params.instruct || params.interactive)) { if (!embd.empty() && embd.back() == llama_token_eos(ctx) && !(params.instruct || params.interactive)) {
fprintf(stderr, " [end of text]\n"); fprintf(stderr, " [end of text]\n");
break; break;
} }

View file

@ -63,7 +63,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
// add BOS token for the first batch of each chunk // add BOS token for the first batch of each chunk
if (j == 0) { if (j == 0) {
tokens[batch_start] = llama_token_bos(); tokens[batch_start] = llama_token_bos(ctx);
} }
if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) { if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {

View file

@ -279,7 +279,7 @@ struct llama_server_context
grammar_parser::print_grammar(stderr, parsed_grammar); grammar_parser::print_grammar(stderr, parsed_grammar);
{ {
auto it = params.logit_bias.find(llama_token_eos()); auto it = params.logit_bias.find(llama_token_eos(ctx));
if (it != params.logit_bias.end() && it->second == -INFINITY) { if (it != params.logit_bias.end() && it->second == -INFINITY) {
LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {}); LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {});
} }
@ -402,7 +402,7 @@ struct llama_server_context
if (params.n_predict == 0) if (params.n_predict == 0)
{ {
has_next_token = false; has_next_token = false;
result.tok = llama_token_eos(); result.tok = llama_token_eos(ctx);
return result; return result;
} }
@ -442,7 +442,7 @@ struct llama_server_context
llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false}; llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};
// Apply penalties // Apply penalties
float nl_logit = logits[llama_token_nl()]; float nl_logit = logits[llama_token_nl(ctx)];
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx); auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx);
llama_sample_repetition_penalty(ctx, &candidates_p, llama_sample_repetition_penalty(ctx, &candidates_p,
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
@ -452,7 +452,7 @@ struct llama_server_context
last_n_repeat, alpha_frequency, alpha_presence); last_n_repeat, alpha_frequency, alpha_presence);
if (!penalize_nl) if (!penalize_nl)
{ {
logits[llama_token_nl()] = nl_logit; logits[llama_token_nl(ctx)] = nl_logit;
} }
if (grammar != nullptr) { if (grammar != nullptr) {
@ -515,7 +515,7 @@ struct llama_server_context
// decrement remaining sampling budget // decrement remaining sampling budget
--n_remain; --n_remain;
if (!embd.empty() && embd.back() == llama_token_eos()) if (!embd.empty() && embd.back() == llama_token_eos(ctx))
{ {
// stopping_word = llama_token_to_str(ctx, embd.back()); // stopping_word = llama_token_to_str(ctx, embd.back());
has_next_token = false; has_next_token = false;
@ -949,7 +949,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
static json format_generation_settings(llama_server_context &llama) static json format_generation_settings(llama_server_context &llama)
{ {
const auto eos_bias = llama.params.logit_bias.find(llama_token_eos()); const auto eos_bias = llama.params.logit_bias.find(llama_token_eos(llama.ctx));
const bool ignore_eos = eos_bias != llama.params.logit_bias.end() && const bool ignore_eos = eos_bias != llama.params.logit_bias.end() &&
eos_bias->second < 0.0f && std::isinf(eos_bias->second); eos_bias->second < 0.0f && std::isinf(eos_bias->second);
@ -1084,7 +1084,7 @@ static void parse_options_completion(const json &body, llama_server_context &lla
llama.params.logit_bias.clear(); llama.params.logit_bias.clear();
if (body.value("ignore_eos", false)) if (body.value("ignore_eos", false))
{ {
llama.params.logit_bias[llama_token_eos()] = -INFINITY; llama.params.logit_bias[llama_token_eos(llama.ctx)] = -INFINITY;
} }
const auto &logit_bias = body.find("logit_bias"); const auto &logit_bias = body.find("logit_bias");

View file

@ -106,7 +106,7 @@ int main(int argc, char ** argv) {
new_token_id = llama_sample_token_greedy(ctx , &candidates_p); new_token_id = llama_sample_token_greedy(ctx , &candidates_p);
// is it an end of stream ? // is it an end of stream ?
if (new_token_id == llama_token_eos()) { if (new_token_id == llama_token_eos(ctx)) {
fprintf(stderr, " [end of text]\n"); fprintf(stderr, " [end of text]\n");
break; break;
} }

View file

@ -1996,7 +1996,7 @@ void print_tokens_batch(struct llama_context* ctx, struct ggml_tensor * tokens)
} }
} }
void get_example_targets(const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) { void get_example_targets(struct llama_context * lctx, const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) {
int n_tokens = tokens_input->ne[0]; int n_tokens = tokens_input->ne[0];
int n_vocab = target_logits->ne[0]; int n_vocab = target_logits->ne[0];
@ -2005,7 +2005,7 @@ void get_example_targets(const int * train_samples, size_t n_train_samples, cons
ggml_set_f32(target_logits, -1.0f/n_vocab); ggml_set_f32(target_logits, -1.0f/n_vocab);
ggml_set_f32(target_probs, 0.0f); ggml_set_f32(target_probs, 0.0f);
ggml_set_i32_1d(tokens_input, 0, llama_token_bos()); ggml_set_i32_1d(tokens_input, 0, llama_token_bos(lctx));
for (int i=1; i<n_tokens+1; ++i) { for (int i=1; i<n_tokens+1; ++i) {
int token = clamp(train_data[sample+i-1], 0, n_vocab-1); int token = clamp(train_data[sample+i-1], 0, n_vocab-1);
set_f32_2d(target_logits, token, i-1, +1.0f); set_f32_2d(target_logits, token, i-1, +1.0f);
@ -2016,7 +2016,7 @@ void get_example_targets(const int * train_samples, size_t n_train_samples, cons
} }
} }
void get_example_targets_batch(struct llama_context * /*lctx*/, const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) { void get_example_targets_batch(struct llama_context * lctx, const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) {
GGML_ASSERT(tokens_input->n_dims == 2); GGML_ASSERT(tokens_input->n_dims == 2);
GGML_ASSERT(target_logits->n_dims == 3); GGML_ASSERT(target_logits->n_dims == 3);
GGML_ASSERT(target_probs->n_dims == 3); GGML_ASSERT(target_probs->n_dims == 3);
@ -2036,7 +2036,7 @@ void get_example_targets_batch(struct llama_context * /*lctx*/, const int * trai
size_t sample = train_samples[(example_id*n_batch + k) % n_train_samples]; size_t sample = train_samples[(example_id*n_batch + k) % n_train_samples];
GGML_ASSERT(sample+n_tokens-1 < n_train_data); GGML_ASSERT(sample+n_tokens-1 < n_train_data);
set_i32_2d(tokens_input, 0, k, llama_token_bos()); set_i32_2d(tokens_input, 0, k, llama_token_bos(lctx));
for (int i=1; i<n_tokens+1; ++i) { for (int i=1; i<n_tokens+1; ++i) {
int token = clamp(train_data[sample+i-1], 0, n_vocab-1); int token = clamp(train_data[sample+i-1], 0, n_vocab-1);
// print_token(lctx, token); // print_token(lctx, token);
@ -2294,7 +2294,7 @@ llama_token sample(struct my_llama_sampler * sampler, float * logits, const llam
const auto params = sampler->params; const auto params = sampler->params;
// Apply penalties // Apply penalties
const float nl_logit = logits[llama_token_nl()]; const float nl_logit = logits[llama_token_nl(ctx)];
const int n_last = std::min(std::min(n_last_tokens, params.repeat_last_n), sampler->n_ctx); const int n_last = std::min(std::min(n_last_tokens, params.repeat_last_n), sampler->n_ctx);
@ -2313,7 +2313,7 @@ llama_token sample(struct my_llama_sampler * sampler, float * logits, const llam
params.alpha_presence); params.alpha_presence);
if (!params.penalize_nl) { if (!params.penalize_nl) {
logits[llama_token_nl()] = nl_logit; logits[llama_token_nl(ctx)] = nl_logit;
} }
llama_token token = 0; llama_token token = 0;
@ -3181,7 +3181,7 @@ int main(int argc, char ** argv) {
std::vector<int> train_samples; std::vector<int> train_samples;
train_samples.push_back(0); train_samples.push_back(0);
for (int i = 1; i < (int) train_tokens.size() - n_tokens; ++i) { for (int i = 1; i < (int) train_tokens.size() - n_tokens; ++i) {
if (!params.samples_start_after_nl || (train_tokens[i-1] == llama_token_nl())) { if (!params.samples_start_after_nl || (train_tokens[i-1] == llama_token_nl(lctx))) {
train_samples.push_back(i); train_samples.push_back(i);
} }
} }
@ -3341,7 +3341,7 @@ int main(int argc, char ** argv) {
struct ggml_tensor * target_logits = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, n_vocab, n_tokens); struct ggml_tensor * target_logits = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, n_vocab, n_tokens);
struct ggml_tensor * target_probs = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, n_vocab, n_tokens); struct ggml_tensor * target_probs = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, n_vocab, n_tokens);
get_example_targets(train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), rand()%train_samples.size(), tokens_input, target_logits, target_probs); get_example_targets(lctx, train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), rand()%train_samples.size(), tokens_input, target_logits, target_probs);
for (int i=sample_ctx; i<n_tokens; ++i) { for (int i=sample_ctx; i<n_tokens; ++i) {
ggml_set_i32_1d(tokens_input, i, n_vocab/2); ggml_set_i32_1d(tokens_input, i, n_vocab/2);
} }

View file

@ -780,13 +780,14 @@ struct llama_vocab {
std::unordered_map<token, id> token_to_id; std::unordered_map<token, id> token_to_id;
std::vector<token_score> id_to_token; std::vector<token_score> id_to_token;
id special_bos_id = -1; // default LLaMA special tokens
id special_eos_id = -1; id special_bos_id = 1;
id special_eos_id = 2;
id special_unk_id = -1; id special_unk_id = -1;
id special_sep_id = -1; id special_sep_id = -1;
id special_pad_id = -1; id special_pad_id = -1;
id linefeed_id = -1; id linefeed_id = 13;
}; };
struct llama_model { struct llama_model {
@ -2351,21 +2352,11 @@ static bool llama_is_control_token(const llama_vocab & vocab, llama_token token)
} }
static bool llama_is_bos_token(const llama_vocab & vocab, llama_token token) { static bool llama_is_bos_token(const llama_vocab & vocab, llama_token token) {
if (llama_vocab_type(vocab) == "spm") { return token == vocab.special_bos_id;
return token == 1;
}
// TODO: improve?
return false;
} }
static bool llama_is_eos_token(const llama_vocab & vocab, llama_token token) { static bool llama_is_eos_token(const llama_vocab & vocab, llama_token token) {
if (llama_vocab_type(vocab) == "spm") { return token == vocab.special_eos_id;
return token == 2;
}
// TODO: improve?
return false;
} }
static bool llama_is_user_defined_token(const llama_vocab & vocab, llama_token token) { static bool llama_is_user_defined_token(const llama_vocab & vocab, llama_token token) {
@ -2608,7 +2599,7 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
} }
if (bos) { if (bos) {
output.push_back(llama_token_bos()); output.push_back(vocab.special_bos_id);
} }
std::string text; std::string text;
@ -3293,7 +3284,7 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
} }
} }
const llama_token eos = llama_token_eos(); const llama_token eos = llama_token_eos(ctx);
std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded; std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
std::vector<llama_grammar_candidate> candidates_grammar; std::vector<llama_grammar_candidate> candidates_grammar;
@ -3503,7 +3494,7 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) { void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
const int64_t t_start_sample_us = ggml_time_us(); const int64_t t_start_sample_us = ggml_time_us();
if (token == llama_token_eos()) { if (token == llama_token_eos(ctx)) {
for (const auto & stack : grammar->stacks) { for (const auto & stack : grammar->stacks) {
if (stack.empty()) { if (stack.empty()) {
return; return;
@ -4340,7 +4331,7 @@ struct llama_context * llama_new_context_with_model(
// build worst-case graph // build worst-case graph
int n_tokens = std::min((int)hparams.n_ctx, params.n_batch); int n_tokens = std::min((int)hparams.n_ctx, params.n_batch);
int n_past = hparams.n_ctx - n_tokens; int n_past = hparams.n_ctx - n_tokens;
llama_token token = llama_token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph llama_token token = llama_token_bos(ctx); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
ggml_cgraph * gf = llama_build_graph(*ctx, &token, NULL, n_tokens, n_past); ggml_cgraph * gf = llama_build_graph(*ctx, &token, NULL, n_tokens, n_past);
#ifdef GGML_USE_METAL #ifdef GGML_USE_METAL
if (params.n_gpu_layers > 0) { if (params.n_gpu_layers > 0) {
@ -4950,7 +4941,7 @@ int llama_eval_export(struct llama_context * ctx, const char * fname) {
const int n_batch = 1; const int n_batch = 1;
const int n_ctx = 512 - n_batch; const int n_ctx = 512 - n_batch;
const std::vector<llama_token> tmp(n_batch, llama_token_bos()); const std::vector<llama_token> tmp(n_batch, llama_token_bos(ctx));
if (!llama_eval_internal(*ctx, tmp.data(), nullptr, tmp.size(), n_ctx, 1, fname)) { if (!llama_eval_internal(*ctx, tmp.data(), nullptr, tmp.size(), n_ctx, 1, fname)) {
LLAMA_LOG_ERROR("%s: failed to eval\n", __func__); LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
@ -4989,16 +4980,16 @@ int llama_model_get_vocab(
return n; return n;
} }
llama_token llama_token_bos(void) { llama_token llama_token_bos(const struct llama_context * ctx) {
return 1; return ctx->model.vocab.special_bos_id;
} }
llama_token llama_token_eos(void) { llama_token llama_token_eos(const struct llama_context * ctx) {
return 2; return ctx->model.vocab.special_eos_id;
} }
llama_token llama_token_nl(void) { llama_token llama_token_nl(const struct llama_context * ctx) {
return 13; return ctx->model.vocab.linefeed_id;
} }
int llama_tokenize( int llama_tokenize(

View file

@ -340,9 +340,9 @@ extern "C" {
int capacity); int capacity);
// Special tokens // Special tokens
LLAMA_API llama_token llama_token_bos(/*struct llama_model * model*/ void); // beginning-of-sentence LLAMA_API llama_token llama_token_bos(const struct llama_context * ctx); // beginning-of-sentence
LLAMA_API llama_token llama_token_eos(/*struct llama_model * model*/ void); // end-of-sentence LLAMA_API llama_token llama_token_eos(const struct llama_context * ctx); // end-of-sentence
LLAMA_API llama_token llama_token_nl (/*struct llama_model * model*/ void); // next-line LLAMA_API llama_token llama_token_nl (const struct llama_context * ctx); // next-line
// //
// Tokenization // Tokenization