add --in-prefix-bos to prefix BOS to user inputs; keep EOS

The BOS precedes the string specified by `--in-prefix`.
Model-generated EOS tokens are now kept in the context.

This provides a way to strictly follow the prompt format used in
Llama-2-chat.

The EOS handling also benefits some existing finetunes that use
EOS to mark the end of a turn.
This commit is contained in:
Xiao-Yong Jin 2023-07-20 22:04:06 -05:00
parent e782c9e735
commit 47031e4094
3 changed files with 34 additions and 17 deletions

View file

@ -421,6 +421,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
exit(0); exit(0);
} else if (arg == "--random-prompt") { } else if (arg == "--random-prompt") {
params.random_prompt = true; params.random_prompt = true;
} else if (arg == "--in-prefix-bos") {
params.input_prefix_bos = true;
} else if (arg == "--in-prefix") { } else if (arg == "--in-prefix") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
@ -484,6 +486,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " not supported with --interactive or other interactive options\n"); fprintf(stderr, " not supported with --interactive or other interactive options\n");
fprintf(stderr, " --prompt-cache-ro if specified, uses the prompt cache but does not update it.\n"); fprintf(stderr, " --prompt-cache-ro if specified, uses the prompt cache but does not update it.\n");
fprintf(stderr, " --random-prompt start with a randomized prompt.\n"); fprintf(stderr, " --random-prompt start with a randomized prompt.\n");
fprintf(stderr, " --in-prefix-bos prefix BOS to user inputs, preceding the `--in-prefix` string\n");
fprintf(stderr, " --in-prefix STRING string to prefix user inputs with (default: empty)\n"); fprintf(stderr, " --in-prefix STRING string to prefix user inputs with (default: empty)\n");
fprintf(stderr, " --in-suffix STRING string to suffix after user inputs with (default: empty)\n"); fprintf(stderr, " --in-suffix STRING string to suffix after user inputs with (default: empty)\n");
fprintf(stderr, " -f FNAME, --file FNAME\n"); fprintf(stderr, " -f FNAME, --file FNAME\n");

View file

@ -64,6 +64,7 @@ struct gpt_params {
std::string input_prefix = ""; // string to prefix user inputs with std::string input_prefix = ""; // string to prefix user inputs with
std::string input_suffix = ""; // string to suffix user inputs with std::string input_suffix = ""; // string to suffix user inputs with
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
bool input_prefix_bos = false; // if true, prefix BOS to user inputs, before input_prefix
std::string lora_adapter = ""; // lora adapter path std::string lora_adapter = ""; // lora adapter path
std::string lora_base = ""; // base model path for the lora adapter std::string lora_base = ""; // base model path for the lora adapter

View file

@ -327,6 +327,10 @@ int main(int argc, char ** argv) {
} }
} }
if (params.input_prefix_bos) {
fprintf(stderr, "Input prefix with BOS\n");
}
if (!params.input_prefix.empty()) { if (!params.input_prefix.empty()) {
fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str()); fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
} }
@ -602,16 +606,6 @@ int main(int argc, char ** argv) {
last_n_tokens.push_back(id); last_n_tokens.push_back(id);
} }
// replace end of text token with newline token when in interactive mode
if (id == llama_token_eos() && params.interactive && !params.instruct) {
id = llama_token_newline.front();
if (params.antiprompt.size() != 0) {
// tokenize and inject first reverse prompt
const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false);
embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
}
}
// add it to the context // add it to the context
embd.push_back(id); embd.push_back(id);
@ -677,11 +671,34 @@ int main(int argc, char ** argv) {
} }
} }
// deal with end of text token in interactive mode
if (last_n_tokens.back() == llama_token_eos()) {
if (params.interactive) {
if (params.antiprompt.size() != 0) {
// tokenize and inject first reverse prompt
const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false);
embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
is_antiprompt = true;
}
is_interacting = true;
printf("\n");
console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);
fflush(stdout);
} else if (params.instruct) {
is_interacting = true;
}
}
if (n_past > 0 && is_interacting) { if (n_past > 0 && is_interacting) {
if (params.instruct) { if (params.instruct) {
printf("\n> "); printf("\n> ");
} }
if (params.input_prefix_bos) {
embd_inp.push_back(llama_token_bos());
}
std::string buffer; std::string buffer;
if (!params.input_prefix.empty()) { if (!params.input_prefix.empty()) {
buffer += params.input_prefix; buffer += params.input_prefix;
@ -733,14 +750,10 @@ int main(int argc, char ** argv) {
} }
// end of text token // end of text token
if (!embd.empty() && embd.back() == llama_token_eos()) { if (!embd.empty() && embd.back() == llama_token_eos() && !(params.instruct || params.interactive)) {
if (params.instruct) {
is_interacting = true;
} else {
fprintf(stderr, " [end of text]\n"); fprintf(stderr, " [end of text]\n");
break; break;
} }
}
// In interactive mode, respect the maximum number of tokens and drop back to user input when reached. // In interactive mode, respect the maximum number of tokens and drop back to user input when reached.
if (params.interactive && n_remain <= 0 && params.n_predict != -1) { if (params.interactive && n_remain <= 0 && params.n_predict != -1) {