From 47031e40948a923b62acae954003bfe693864da3 Mon Sep 17 00:00:00 2001 From: Xiao-Yong Jin Date: Thu, 20 Jul 2023 22:04:06 -0500 Subject: [PATCH] add `--in-prefix-bos` to prefix BOS to user inputs; keep EOS The BOS precedes the string specified by `--in-prefix`. Model-generated EOS is now kept in the context. It provides a way to strictly follow the prompt format used in Llama-2-chat. The EOS handling also benefits some existing finetunes that use EOS to mark the end of turn. --- examples/common.cpp | 3 +++ examples/common.h | 1 + examples/main/main.cpp | 47 +++++++++++++++++++++++++++--------------- 3 files changed, 34 insertions(+), 17 deletions(-) diff --git a/examples/common.cpp b/examples/common.cpp index fd6dbc0e3..e1e659d66 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -421,6 +421,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { exit(0); } else if (arg == "--random-prompt") { params.random_prompt = true; + } else if (arg == "--in-prefix-bos") { + params.input_prefix_bos = true; } else if (arg == "--in-prefix") { if (++i >= argc) { invalid_param = true; } @@ -484,6 +486,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " not supported with --interactive or other interactive options\n"); fprintf(stderr, " --prompt-cache-ro if specified, uses the prompt cache but does not update it.\n"); fprintf(stderr, " --random-prompt start with a randomized prompt.\n"); + fprintf(stderr, " --in-prefix-bos prefix BOS to user inputs, preceding the `--in-prefix` string\n"); fprintf(stderr, " --in-prefix STRING string to prefix user inputs with (default: empty)\n"); fprintf(stderr, " --in-suffix STRING string to suffix after user inputs with (default: empty)\n"); fprintf(stderr, " -f FNAME, --file FNAME\n"); diff --git a/examples/common.h b/examples/common.h index 037a4eecb..f6718d274 100644 --- a/examples/common.h +++ b/examples/common.h @@ -64,6 +64,7 @@ struct gpt_params { std::string 
input_prefix = ""; // string to prefix user inputs with std::string input_suffix = ""; // string to suffix user inputs with std::vector antiprompt; // string upon seeing which more user input is prompted + bool input_prefix_bos = false; // if true, prefix BOS to user inputs, before input_prefix std::string lora_adapter = ""; // lora adapter path std::string lora_base = ""; // base model path for the lora adapter diff --git a/examples/main/main.cpp b/examples/main/main.cpp index bcbcf12b0..7307db7c5 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -327,6 +327,10 @@ int main(int argc, char ** argv) { } } + if (params.input_prefix_bos) { + fprintf(stderr, "Input prefix with BOS\n"); + } + if (!params.input_prefix.empty()) { fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str()); } @@ -602,16 +606,6 @@ int main(int argc, char ** argv) { last_n_tokens.push_back(id); } - // replace end of text token with newline token when in interactive mode - if (id == llama_token_eos() && params.interactive && !params.instruct) { - id = llama_token_newline.front(); - if (params.antiprompt.size() != 0) { - // tokenize and inject first reverse prompt - const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false); - embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end()); - } - } - // add it to the context embd.push_back(id); @@ -677,11 +671,34 @@ int main(int argc, char ** argv) { } } + // deal with end of text token in interactive mode + if (last_n_tokens.back() == llama_token_eos()) { + if (params.interactive) { + if (params.antiprompt.size() != 0) { + // tokenize and inject first reverse prompt + const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false); + embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end()); + is_antiprompt = true; + } + + is_interacting = true; + printf("\n"); + console_set_color(con_st, CONSOLE_COLOR_USER_INPUT); + 
fflush(stdout); + } else if (params.instruct) { + is_interacting = true; + } + } + if (n_past > 0 && is_interacting) { if (params.instruct) { printf("\n> "); } + if (params.input_prefix_bos) { + embd_inp.push_back(llama_token_bos()); + } + std::string buffer; if (!params.input_prefix.empty()) { buffer += params.input_prefix; @@ -733,13 +750,9 @@ int main(int argc, char ** argv) { } // end of text token - if (!embd.empty() && embd.back() == llama_token_eos()) { - if (params.instruct) { - is_interacting = true; - } else { - fprintf(stderr, " [end of text]\n"); - break; - } + if (!embd.empty() && embd.back() == llama_token_eos() && !(params.instruct || params.interactive)) { + fprintf(stderr, " [end of text]\n"); + break; } // In interactive mode, respect the maximum number of tokens and drop back to user input when reached.