From 5d2656d6707399860c7902c2f48fd1c23b3fcd40 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 18 Aug 2023 17:29:20 +0300
Subject: [PATCH] llama : avoid hardcoded special tokens

---
 common/common.cpp                      |  8 +++-
 common/common.h                        |  6 ++-
 examples/embd-input/embd-input-lib.cpp |  6 +--
 examples/llama-bench/llama-bench.cpp   |  4 +-
 examples/main/main.cpp                 | 19 ++++----
 examples/perplexity/perplexity.cpp     |  2 +-
 examples/server/server.cpp             | 14 +++---
 examples/simple/simple.cpp             |  2 +-
 .../train-text-from-scratch.cpp        | 16 +++----
 llama.cpp                              | 43 ++++++++-----------
 llama.h                                |  6 +--
 11 files changed, 61 insertions(+), 65 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index 8ea7bdda0..d7e1a5725 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -427,7 +427,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             }
             params.hellaswag_tasks = std::stoi(argv[i]);
         } else if (arg == "--ignore-eos") {
-            params.logit_bias[llama_token_eos()] = -INFINITY;
+            params.ignore_eos = true;
         } else if (arg == "--no-penalize-nl") {
            params.penalize_nl = false;
         } else if (arg == "-l" || arg == "--logit-bias") {
@@ -662,7 +662,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     return lparams;
 }
 
-std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params) {
+std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
     auto lparams = llama_context_params_from_gpt_params(params);
 
     llama_model * model = llama_load_model_from_file(params.model.c_str(), lparams);
@@ -691,6 +691,10 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
         }
     }
 
+    if (params.ignore_eos) {
+        params.logit_bias[llama_token_eos(lctx)] = -INFINITY;
+    }
+
     return std::make_tuple(model, lctx);
 }
 
diff --git a/common/common.h b/common/common.h
index 50145c932..c50a6edfc 100644
--- a/common/common.h
+++ b/common/common.h
@@ -32,7 +32,6 @@ struct gpt_params {
     float   rope_freq_scale = 1.0f; // RoPE frequency scaling factor
 
     // sampling parameters
-    std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
     int32_t top_k         = 40;    // <= 0 to use vocab size
     float   top_p         = 0.95f; // 1.0 = disabled
     float   tfs_z         = 1.00f; // 1.0 = disabled
@@ -46,6 +45,8 @@ struct gpt_params {
     float   mirostat_tau  = 5.00f; // target entropy
     float   mirostat_eta  = 0.10f; // learning rate
 
+    std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
+
     // Classifier-Free Guidance
     // https://arxiv.org/abs/2306.17806
     std::string cfg_negative_prompt; // string to help guidance
@@ -81,6 +82,7 @@ struct gpt_params {
     bool simple_io        = false; // improves compatibility with subprocesses and limited consoles
     bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
+    bool ignore_eos       = false; // ignore generated EOS tokens
     bool instruct         = false; // instruction mode (used for Alpaca models)
     bool penalize_nl      = true;  // consider newlines as a repeatable token
     bool perplexity       = false; // compute perplexity over the prompt
@@ -102,7 +104,7 @@ std::string gpt_random_prompt(std::mt19937 & rng);
 // Model utils
 //
 
-std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params);
+std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params);
 struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
 
 //
diff --git a/examples/embd-input/embd-input-lib.cpp b/examples/embd-input/embd-input-lib.cpp
index 2185b9b0e..8a6ad882e 100644
--- a/examples/embd-input/embd-input-lib.cpp
+++ b/examples/embd-input/embd-input-lib.cpp
@@ -167,7 +167,7 @@ llama_token sampling_id(struct MyModel* mymodel) {
     llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
 
     // TODO: Apply penalties
-    // float nl_logit = logits[llama_token_nl()];
+    // float nl_logit = logits[llama_token_nl(ctx)];
     // auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
     // llama_sample_repetition_penalty(ctx, &candidates_p,
     //     last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
@@ -176,7 +176,7 @@ llama_token sampling_id(struct MyModel* mymodel) {
     //     last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
     //     last_n_repeat, alpha_frequency, alpha_presence);
     // if (!penalize_nl) {
-    //     logits[llama_token_nl()] = nl_logit;
+    //     logits[llama_token_nl(ctx)] = nl_logit;
     // }
 
     if (temp <= 0) {
@@ -211,7 +211,7 @@ const char * sampling(struct MyModel * mymodel) {
     llama_context * ctx = mymodel->ctx;
     int id = sampling_id(mymodel);
     static std::string ret;
-    if (id == llama_token_eos()) {
+    if (id == llama_token_eos(ctx)) {
         ret = "</s>";
     } else {
         ret = llama_token_to_str(ctx, id);
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index 1e2b892fa..d11fff288 100755
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -851,7 +851,7 @@ struct sql_printer : public printer {
 };
 
 static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
-    std::vector<llama_token> tokens(n_batch, llama_token_bos());
+    std::vector<llama_token> tokens(n_batch, llama_token_bos(ctx));
     int n_processed = 0;
     while (n_processed < n_prompt) {
         int n_tokens = std::min(n_prompt - n_processed, n_batch);
@@ -861,7 +861,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
 }
 
 static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
-    llama_token token = llama_token_bos();
+    llama_token token = llama_token_bos(ctx);
     for (int i = 0; i < n_gen; i++) {
         llama_eval(ctx, &token, 1, n_past + i, n_threads);
     }
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 5c2f64883..388e1f7d7 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -143,7 +143,7 @@ int main(int argc, char ** argv) {
     {
         fprintf(stderr, "%s: testing memory usage for n_batch = %d, n_ctx = %d\n", __func__, params.n_batch, params.n_ctx);
 
-        const std::vector<llama_token> tmp(params.n_batch, llama_token_bos());
+        const std::vector<llama_token> tmp(params.n_batch, llama_token_bos(ctx));
         llama_eval(ctx, tmp.data(), tmp.size(), params.n_ctx, params.n_threads);
     }
 
@@ -345,10 +345,9 @@ int main(int argc, char ** argv) {
     fprintf(stderr, "\n");
 
     {
-        auto it = params.logit_bias.find(llama_token_eos());
+        auto it = params.logit_bias.find(llama_token_eos(ctx));
         if (it != params.logit_bias.end() && it->second == -INFINITY) {
-            fprintf(stderr,
-                "%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__);
+            fprintf(stderr, "%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__);
         }
     }
 
@@ -398,7 +397,7 @@ int main(int argc, char ** argv) {
 
     // do one empty run to warm up the model
     {
-        const std::vector<llama_token> tmp = { llama_token_bos(), };
+        const std::vector<llama_token> tmp = { llama_token_bos(ctx), };
         llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads);
         llama_reset_timings(ctx);
     }
@@ -582,7 +581,7 @@ int main(int argc, char ** argv) {
                 }
 
                 // Apply penalties
-                float nl_logit = logits[llama_token_nl()];
+                float nl_logit = logits[llama_token_nl(ctx)];
                 auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
                 llama_sample_repetition_penalty(ctx, &candidates_p,
                     last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
@@ -591,7 +590,7 @@ int main(int argc, char ** argv) {
                     last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
                     last_n_repeat, alpha_frequency, alpha_presence);
                 if (!penalize_nl) {
-                    logits[llama_token_nl()] = nl_logit;
+                    logits[llama_token_nl(ctx)] = nl_logit;
                 }
 
                 if (grammar != NULL) {
@@ -697,7 +696,7 @@ int main(int argc, char ** argv) {
             }
 
             // deal with end of text token in interactive mode
-            if (last_n_tokens.back() == llama_token_eos()) {
+            if (last_n_tokens.back() == llama_token_eos(ctx)) {
                 if (params.interactive) {
                     if (params.antiprompt.size() != 0) {
                         // tokenize and inject first reverse prompt
@@ -721,7 +720,7 @@ int main(int argc, char ** argv) {
             }
 
             if (params.input_prefix_bos) {
-                embd_inp.push_back(llama_token_bos());
+                embd_inp.push_back(llama_token_bos(ctx));
             }
 
             std::string buffer;
@@ -786,7 +785,7 @@ int main(int argc, char ** argv) {
         }
 
         // end of text token
-        if (!embd.empty() && embd.back() == llama_token_eos() && !(params.instruct || params.interactive)) {
+        if (!embd.empty() && embd.back() == llama_token_eos(ctx) && !(params.instruct || params.interactive)) {
             fprintf(stderr, " [end of text]\n");
             break;
         }
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index b9b28a20b..9eadbeaa9 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -63,7 +63,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
 
             // add BOS token for the first batch of each chunk
             if (j == 0) {
-                tokens[batch_start] = llama_token_bos();
+                tokens[batch_start] = llama_token_bos(ctx);
             }
 
             if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 9337e2104..a04f1910c 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -279,7 +279,7 @@ struct llama_server_context
         grammar_parser::print_grammar(stderr, parsed_grammar);
 
         {
-            auto it = params.logit_bias.find(llama_token_eos());
+            auto it = params.logit_bias.find(llama_token_eos(ctx));
             if (it != params.logit_bias.end() && it->second == -INFINITY) {
                 LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {});
             }
@@ -402,7 +402,7 @@ struct llama_server_context
         if (params.n_predict == 0)
         {
             has_next_token = false;
-            result.tok = llama_token_eos();
+            result.tok = llama_token_eos(ctx);
             return result;
         }
 
@@ -442,7 +442,7 @@ struct llama_server_context
         llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};
 
         // Apply penalties
-        float nl_logit = logits[llama_token_nl()];
+        float nl_logit = logits[llama_token_nl(ctx)];
         auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx);
         llama_sample_repetition_penalty(ctx, &candidates_p,
             last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
@@ -452,7 +452,7 @@ struct llama_server_context
             last_n_repeat, alpha_frequency, alpha_presence);
         if (!penalize_nl)
         {
-            logits[llama_token_nl()] = nl_logit;
+            logits[llama_token_nl(ctx)] = nl_logit;
         }
 
         if (grammar != nullptr) {
@@ -515,7 +515,7 @@ struct llama_server_context
         // decrement remaining sampling budget
         --n_remain;
 
-        if (!embd.empty() && embd.back() == llama_token_eos())
+        if (!embd.empty() && embd.back() == llama_token_eos(ctx))
         {
             // stopping_word = llama_token_to_str(ctx, embd.back());
             has_next_token = false;
@@ -949,7 +949,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
 
 static json format_generation_settings(llama_server_context &llama)
 {
-    const auto eos_bias = llama.params.logit_bias.find(llama_token_eos());
+    const auto eos_bias = llama.params.logit_bias.find(llama_token_eos(llama.ctx));
     const bool ignore_eos = eos_bias != llama.params.logit_bias.end() &&
                             eos_bias->second < 0.0f && std::isinf(eos_bias->second);
 
@@ -1084,7 +1084,7 @@ static void parse_options_completion(const json &body, llama_server_context &lla
     llama.params.logit_bias.clear();
     if (body.value("ignore_eos", false))
     {
-        llama.params.logit_bias[llama_token_eos()] = -INFINITY;
+        llama.params.logit_bias[llama_token_eos(llama.ctx)] = -INFINITY;
     }
 
     const auto &logit_bias = body.find("logit_bias");
diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp
index 476b31b2e..132f7fbf9 100644
--- a/examples/simple/simple.cpp
+++ b/examples/simple/simple.cpp
@@ -106,7 +106,7 @@ int main(int argc, char ** argv) {
         new_token_id = llama_sample_token_greedy(ctx , &candidates_p);
 
         // is it an end of stream ?
-        if (new_token_id == llama_token_eos()) {
+        if (new_token_id == llama_token_eos(ctx)) {
             fprintf(stderr, " [end of text]\n");
             break;
         }
diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index 79599951c..922518da4 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -1996,7 +1996,7 @@ void print_tokens_batch(struct llama_context* ctx, struct ggml_tensor * tokens)
     }
 }
 
-void get_example_targets(const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) {
+void get_example_targets(struct llama_context * lctx, const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) {
     int n_tokens = tokens_input->ne[0];
     int n_vocab  = target_logits->ne[0];
@@ -2005,7 +2005,7 @@ void get_example_targets(const int * train_samples, size_t n_train_samples, cons
     ggml_set_f32(target_logits, -1.0f/n_vocab);
     ggml_set_f32(target_probs, 0.0f);
-    ggml_set_i32_1d(tokens_input, 0, llama_token_bos());
+    ggml_set_i32_1d(tokens_input, 0, llama_token_bos(lctx));
     for (int i=1; i<n_tokens+1; ++i) {
@@ ... @@ void get_example_targets(const int * train_samples, size_t n_train_samples, cons
-void get_example_targets_batch(struct llama_context * /*lctx*/, const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) {
+void get_example_targets_batch(struct llama_context * lctx, const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) {
     GGML_ASSERT(tokens_input->n_dims  == 2);
     GGML_ASSERT(target_logits->n_dims == 3);
     GGML_ASSERT(target_probs->n_dims  == 3);
@@ -2036,7 +2036,7 @@ void get_example_targets_batch(struct llama_context * /*lctx*/, const int * trai
         size_t sample = train_samples[(example_id*n_batch + k) % n_train_samples];
         GGML_ASSERT(sample+n_tokens-1 < n_train_data);
 
-        set_i32_2d(tokens_input, 0, k, llama_token_bos());
+        set_i32_2d(tokens_input, 0, k, llama_token_bos(lctx));
 
         for (int i=1; i<n_tokens+1; ++i) {
@@ ... @@ llama_token sample(struct my_llama_sampler * sampler, float * logits, const llam
     const auto params = sampler->params;
 
     // Apply penalties
-    const float nl_logit = logits[llama_token_nl()];
+    const float nl_logit = logits[llama_token_nl(ctx)];
 
     const int n_last = std::min(std::min(n_last_tokens, params.repeat_last_n), sampler->n_ctx);
@@ -2313,7 +2313,7 @@ llama_token sample(struct my_llama_sampler * sampler, float * logits, const llam
         params.alpha_presence);
 
     if (!params.penalize_nl) {
-        logits[llama_token_nl()] = nl_logit;
+        logits[llama_token_nl(ctx)] = nl_logit;
     }
 
     llama_token token = 0;
@@ -3181,7 +3181,7 @@ int main(int argc, char ** argv) {
     std::vector<int> train_samples;
     train_samples.push_back(0);
     for (int i = 1; i < (int) train_tokens.size() - n_tokens; ++i) {
-        if (!params.samples_start_after_nl || (train_tokens[i-1] == llama_token_nl())) {
+        if (!params.samples_start_after_nl || (train_tokens[i-1] == llama_token_nl(lctx))) {
             train_samples.push_back(i);
         }
     }
@@ -3341,7 +3341,7 @@ int main(int argc, char ** argv) {
         struct ggml_tensor * target_logits = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, n_vocab, n_tokens);
         struct ggml_tensor * target_probs  = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, n_vocab, n_tokens);
 
-        get_example_targets(train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), rand()%train_samples.size(), tokens_input, target_logits, target_probs);
+        get_example_targets(lctx, train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), rand()%train_samples.size(), tokens_input, target_logits, target_probs);
 
         for (int i=sample_ctx; i<n_tokens; ++i) {
diff --git a/llama.cpp b/llama.cpp
--- a/llama.cpp
+++ b/llama.cpp
@@ ... @@ struct llama_vocab {
     std::unordered_map<token, id> token_to_id;
     std::vector<token_data>       id_to_token;
 
-    id special_bos_id = -1;
-    id special_eos_id = -1;
+    // default LLaMA special tokens
+    id special_bos_id = 1;
+    id special_eos_id = 2;
     id special_unk_id = -1;
     id special_sep_id = -1;
     id special_pad_id = -1;
 
-    id linefeed_id = -1;
+    id linefeed_id = 13;
 };
 
 struct llama_model {
@@ -2351,21 +2352,11 @@ static bool llama_is_control_token(const llama_vocab & vocab, llama_token token)
 }
 
 static bool llama_is_bos_token(const llama_vocab & vocab, llama_token token) {
-    if (llama_vocab_type(vocab) == "spm") {
-        return token == 1;
-    }
-
-    // TODO: improve?
-    return false;
+    return token == vocab.special_bos_id;
 }
 
 static bool llama_is_eos_token(const llama_vocab & vocab, llama_token token) {
-    if (llama_vocab_type(vocab) == "spm") {
-        return token == 2;
-    }
-
-    // TODO: improve?
-    return false;
+    return token == vocab.special_eos_id;
 }
 
 static bool llama_is_user_defined_token(const llama_vocab & vocab, llama_token token) {
@@ -2608,7 +2599,7 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
     }
 
     if (bos) {
-        output.push_back(llama_token_bos());
+        output.push_back(vocab.special_bos_id);
     }
 
     std::string text;
@@ -3293,7 +3284,7 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
         }
     }
 
-    const llama_token eos = llama_token_eos();
+    const llama_token eos = llama_token_eos(ctx);
 
     std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
     std::vector<llama_grammar_candidate>                              candidates_grammar;
@@ -3503,7 +3494,7 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
 void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
     const int64_t t_start_sample_us = ggml_time_us();
 
-    if (token == llama_token_eos()) {
+    if (token == llama_token_eos(ctx)) {
         for (const auto & stack : grammar->stacks) {
             if (stack.empty()) {
                 return;
@@ -4340,7 +4331,7 @@ struct llama_context * llama_new_context_with_model(
             // build worst-case graph
             int n_tokens = std::min((int)hparams.n_ctx, params.n_batch);
             int n_past = hparams.n_ctx - n_tokens;
-            llama_token token = llama_token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
+            llama_token token = llama_token_bos(ctx); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
             ggml_cgraph * gf = llama_build_graph(*ctx, &token, NULL, n_tokens, n_past);
 #ifdef GGML_USE_METAL
             if (params.n_gpu_layers > 0) {
@@ -4950,7 +4941,7 @@ int llama_eval_export(struct llama_context * ctx, const char * fname) {
     const int n_batch = 1;
     const int n_ctx   = 512 - n_batch;
 
-    const std::vector<llama_token> tmp(n_batch, llama_token_bos());
+    const std::vector<llama_token> tmp(n_batch, llama_token_bos(ctx));
 
     if (!llama_eval_internal(*ctx, tmp.data(), nullptr, tmp.size(), n_ctx, 1, fname)) {
         LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
@@ -4989,16 +4980,16 @@ int llama_model_get_vocab(
     return n;
 }
 
-llama_token llama_token_bos(void) {
-    return 1;
+llama_token llama_token_bos(const struct llama_context * ctx) {
+    return ctx->model.vocab.special_bos_id;
 }
 
-llama_token llama_token_eos(void) {
-    return 2;
+llama_token llama_token_eos(const struct llama_context * ctx) {
+    return ctx->model.vocab.special_eos_id;
 }
 
-llama_token llama_token_nl(void) {
-    return 13;
+llama_token llama_token_nl(const struct llama_context * ctx) {
+    return ctx->model.vocab.linefeed_id;
 }
 
 int llama_tokenize(
diff --git a/llama.h b/llama.h
index e2b28afbb..54081840b 100644
--- a/llama.h
+++ b/llama.h
@@ -340,9 +340,9 @@ extern "C" {
                          int   capacity);
 
     // Special tokens
-    LLAMA_API llama_token llama_token_bos(/*struct llama_model * model*/ void); // beginning-of-sentence
-    LLAMA_API llama_token llama_token_eos(/*struct llama_model * model*/ void); // end-of-sentence
-    LLAMA_API llama_token llama_token_nl (/*struct llama_model * model*/ void); // next-line
+    LLAMA_API llama_token llama_token_bos(const struct llama_context * ctx); // beginning-of-sentence
+    LLAMA_API llama_token llama_token_eos(const struct llama_context * ctx); // end-of-sentence
+    LLAMA_API llama_token llama_token_nl (const struct llama_context * ctx); // next-line
 
     //
     // Tokenization
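
A minimal usage sketch (not part of the patch; the helper names below are illustrative): after this change the special-token getters take a llama_context, so BOS/EOS/NL ids are resolved from the loaded model's vocabulary (defaulting to 1, 2 and 13 for LLaMA) instead of being hardcoded at the call sites.

    #include "llama.h"

    // Assumes `ctx` was created with llama_new_context_with_model() for a loaded model.
    static bool is_end_of_generation(const struct llama_context * ctx, llama_token tok) {
        // The EOS id now comes from the model's vocab rather than a hardcoded 2.
        return tok == llama_token_eos(ctx);
    }

    static llama_token warmup_token(const struct llama_context * ctx) {
        // The BOS id likewise comes from the vocab rather than a hardcoded 1.
        return llama_token_bos(ctx);
    }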