From 7c60721217d59cced92647a6594760bc58a7721d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Pazdiora?= Date: Sun, 9 Apr 2023 19:49:33 +0200 Subject: [PATCH] update implementation --- configs/alpaca.txt | 10 +++-- configs/chat-with-bob.txt | 11 ++++-- configs/vicuna-simple.txt | 11 ++++-- configs/vicuna-stop.txt | 12 ++++-- configs/vicuna.txt | 10 +++-- examples/common.cpp | 39 +++++++++++-------- examples/common.h | 12 ++++-- examples/main/main.cpp | 80 +++++++++++++++++++++++++-------------- 8 files changed, 120 insertions(+), 65 deletions(-) diff --git a/configs/alpaca.txt b/configs/alpaca.txt index 06158cdb9..5e1558d18 100644 --- a/configs/alpaca.txt +++ b/configs/alpaca.txt @@ -1,5 +1,9 @@ ---clean-interface --interactive-first --keep -1 ---in-prefix-bos --in-prefix "\n\n### Instruction:\n\n" -r "### Instruction:\n\n" ---in-suffix "\n\n### Response:\n\n" +--clean-interface +--interactive-first +--keep -1 +--ins-prefix-bos +--ins-prefix "\n\n### Instruction:\n\n" +--ins-suffix "\n\n### Response:\n\n" +--reverse-prompt "### Instruction:\n\n" -p "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n" diff --git a/configs/chat-with-bob.txt b/configs/chat-with-bob.txt index 9019753d0..ede63d35a 100644 --- a/configs/chat-with-bob.txt +++ b/configs/chat-with-bob.txt @@ -1,7 +1,10 @@ ---interactive-first --keep -1 ---in-prefix-bos ---in-prefix "\nUser: " -r "User: " ---in-suffix "\nBob: " +--interactive-first +--keep -1 +--ins-prefix-bos +--ins-prefix "\nUser: " +--ins-suffix "\nBob: " +--reverse-prompt "User: " +--rm-trailing-space-workaround -p "Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision. User: Hello, Bob. 
diff --git a/configs/vicuna-simple.txt b/configs/vicuna-simple.txt index 053391479..a87c70082 100644 --- a/configs/vicuna-simple.txt +++ b/configs/vicuna-simple.txt @@ -1,4 +1,7 @@ ---interactive-first --keep -1 ---in-prefix-bos ---in-prefix "\n### Human: " --reverse-prompt "### Human: " ---in-suffix "\n### Assistant: " \ No newline at end of file +--interactive-first +--keep -1 +--ins-prefix-bos +--ins-prefix "\n### Human: " +--ins-suffix "\n### Assistant: " +--reverse-prompt "### Human: " +--rm-trailing-space-workaround diff --git a/configs/vicuna-stop.txt b/configs/vicuna-stop.txt index e88a9fd05..a8d890724 100644 --- a/configs/vicuna-stop.txt +++ b/configs/vicuna-stop.txt @@ -1,4 +1,8 @@ ---interactive-first --keep -1 ---in-prefix-bos ---in-prefix "\n### Human: " --reverse-prompt "### Human: " ---in-suffix "\n### Assistant: " --stop-prompt "### Assistant: " \ No newline at end of file +--interactive-first +--keep -1 +--ins-prefix-bos +--ins-prefix "\n### Human: " +--ins-suffix "\n### Assistant: " +--reverse-prompt "### Human: " +--stop-prompt "### Assistant: " +--rm-trailing-space-workaround diff --git a/configs/vicuna.txt b/configs/vicuna.txt index bf3b966c7..f4d0d2d74 100644 --- a/configs/vicuna.txt +++ b/configs/vicuna.txt @@ -1,5 +1,9 @@ ---interactive-first --keep -1 ---in-prefix-bos --in-prefix "\n### Human: " -r "### Human: " ---in-suffix "\n### Assistant: " +--interactive-first +--keep -1 +--ins-prefix-bos +--ins-prefix "\n### Human: " +--ins-suffix "\n### Assistant: " +--reverse-prompt "### Human: " +--rm-trailing-space-workaround -p "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." 
diff --git a/examples/common.cpp b/examples/common.cpp index 59b0fadfc..9bf5152fe 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -231,18 +231,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { "--clean-interface " "--interactive-first " "--keep -1 " - "--in-prefix-bos " - "--in-prefix \"\\n\\n### Instruction:\\n\\n\" " - "--in-suffix \"\\n\\n### Response:\\n\\n\" " + "--ins-prefix-bos " + "--ins-prefix \"\\n\\n### Instruction:\\n\\n\" " + "--ins-suffix \"\\n\\n### Response:\\n\\n\" " "-r \"### Instruction:\\n\\n\" " "\n\n"); // params.instruct = true; params.clean_interface = true; params.interactive_start = true; params.n_keep = -1; - params.input_prefix_bos = true; - params.input_prefix = "\n\n### Instruction:\n\n"; - params.input_suffix = "\n\n### Response:\n\n"; + params.instruct_prefix_bos = true; + params.instruct_prefix = "\n\n### Instruction:\n\n"; + params.instruct_suffix = "\n\n### Response:\n\n"; params.antiprompt.push_back("### Instruction:\n\n"); } else if (arg == "--color") { params.use_color = true; @@ -268,6 +268,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.stopprompt.push_back(args[i]); + } else if (arg == "--rm-trailing-space-workaround") { + params.rm_trailing_space_workaround = true; } else if (arg == "--perplexity") { params.perplexity = true; } else if (arg == "--ignore-eos") { @@ -283,22 +285,28 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { exit(0); } else if (arg == "--random-prompt") { params.random_prompt = true; - } else if (arg == "--in-prefix-bos") { - params.input_prefix_bos = true; } else if (arg == "--in-prefix") { if (++i >= args_c) { invalid_param = true; break; } params.input_prefix = args[i]; - } else if (arg == "--in-suffix-bos") { - params.input_suffix_bos = true; - } else if (arg == "--in-suffix") { + } else if (arg == "--ins-prefix-bos") { + params.instruct_prefix_bos = true; + } else if (arg == "--ins-prefix") { if 
(++i >= args_c) { invalid_param = true; break; } - params.input_suffix = args[i]; + params.instruct_prefix = args[i]; + } else if (arg == "--ins-suffix-bos") { + params.instruct_suffix_bos = true; + } else if (arg == "--ins-suffix") { + if (++i >= args_c) { + invalid_param = true; + break; + } + params.instruct_suffix = args[i]; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); gpt_print_usage(argv[0], default_params); @@ -332,10 +340,11 @@ void gpt_print_usage(char * argv_0, const gpt_params & params) { fprintf(stderr, " -p PROMPT, --prompt PROMPT\n"); fprintf(stderr, " prompt to start generation with (default: empty)\n"); fprintf(stderr, " --random-prompt start with a randomized prompt.\n"); - fprintf(stderr, " --in-prefix-bos append bos token to input prefix.\n"); fprintf(stderr, " --in-prefix STRING string to prefix user inputs with (default: empty)\n"); - fprintf(stderr, " --in-suffix-bos append bos token to input suffix.\n"); - fprintf(stderr, " --in-suffix STRING string to suffix user inputs with (default: empty)\n"); + fprintf(stderr, " --ins-prefix STRING (instruct) prefix user inputs with tokenized string (default: empty)\n"); + fprintf(stderr, " --ins-prefix-bos (instruct) prepend bos token to instruct prefix.\n"); + fprintf(stderr, " --ins-suffix STRING (instruct) suffix user inputs with tokenized string (default: empty)\n"); + fprintf(stderr, " --ins-suffix-bos (instruct) prepend bos token to instruct suffix.\n"); fprintf(stderr, " -f FNAME, --file FNAME\n"); fprintf(stderr, " prompt file to start generation.\n"); fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict); diff --git a/examples/common.h b/examples/common.h index 2c4632b41..413604600 100644 --- a/examples/common.h +++ b/examples/common.h @@ -32,12 +32,16 @@ struct gpt_params { std::string model = "models/lamma-7B/ggml-model.bin"; // model path std::string prompt = ""; std::string input_prefix = ""; // string 
to prefix user inputs with - bool input_prefix_bos = false; // append bos token to input prefix - std::string input_suffix = ""; // string to suffix user inputs with - bool input_suffix_bos = false; // append bos token to input suffix + + std::string instruct_prefix = ""; // prefix user inputs with tokenized string + bool instruct_prefix_bos = false; // prepend bos token to instruct prefix + std::string instruct_suffix = ""; // suffix user inputs with tokenized string + bool instruct_suffix_bos = false; // prepend bos token to instruct suffix std::vector antiprompt; // string upon seeing which more user input is prompted - std::vector stopprompt; // string upon seeing which more user input is prompted (without adding prefixes and suffixes) + std::vector stopprompt; // string upon seeing which more user input is prompted (without adding instruct prefixes and suffixes) + + bool rm_trailing_space_workaround = false; // workaround for removing trailing space from reverse/stop prompts bool memory_f16 = true; // use f16 instead of f32 for memory kv bool random_prompt = false; // do not randomize prompt if none provided diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 7ea9c3654..47969d218 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -85,6 +85,8 @@ int main(int argc, char ** argv) { params.prompt = gpt_random_prompt(rng); } + bool instruct_mode = !params.instruct_prefix.empty() || !params.instruct_suffix.empty(); // instruct mode is active when an instruct prefix or suffix is configured + // params.prompt = R"(// this function checks if the number n is prime //bool is_prime(int n) {)"; @@ -154,10 +156,12 @@ int main(int argc, char ** argv) { } // prefix & suffix for instruct mode - const auto inp_pfx = ::llama_tokenize(ctx, params.input_prefix, params.input_prefix_bos); - std::string input_suffix = params.input_suffix; - if (input_suffix.back() == ' ') { input_suffix.pop_back(); } // (remove trailing space workaround) - const auto inp_sfx = ::llama_tokenize(ctx, input_suffix, params.input_suffix_bos); + 
const auto inp_pfx = ::llama_tokenize(ctx, params.instruct_prefix, params.instruct_prefix_bos); + std::string instruct_suffix = params.instruct_suffix; + if (params.rm_trailing_space_workaround) { + if (!instruct_suffix.empty() && instruct_suffix.back() == ' ') { instruct_suffix.pop_back(); } // back() on an empty string is undefined + } + const auto inp_sfx = ::llama_tokenize(ctx, instruct_suffix, params.instruct_suffix_bos); // enable interactive mode if reverse prompt or interactive start is specified if (params.antiprompt.size() != 0 || params.stopprompt.size() != 0 || params.interactive_start) { @@ -209,10 +213,13 @@ int main(int argc, char ** argv) { } if (!params.input_prefix.empty()) { - fprintf(stderr, "Input prefix %s: '%s'\n", params.input_prefix_bos ? "(with bos token)" : "", params.input_prefix.c_str()); + fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str()); } - if (!params.input_suffix.empty()) { - fprintf(stderr, "Input suffix %s: '%s'\n", params.input_suffix_bos ? "(with bos token)" : "", params.input_suffix.c_str()); + if (!params.instruct_prefix.empty()) { + fprintf(stderr, "Instruct prefix %s: '%s'\n", params.instruct_prefix_bos ? "(with bos token)" : "", params.instruct_prefix.c_str()); + } + if (!params.instruct_suffix.empty()) { + fprintf(stderr, "Instruct suffix %s: '%s'\n", params.instruct_suffix_bos ? 
"(with bos token)" : "", params.instruct_suffix.c_str()); } } fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", @@ -232,9 +239,9 @@ int main(int argc, char ** argv) { ); if (params.multiline_mode) { #if defined (_WIN32) - fprintf(stderr, " - Press Ctrl+Z and Return (EOF) to return control to LLaMa.\n\n"); + fprintf(stderr, " - [MULTILINE MODE] Press Ctrl+Z and Return (EOF) to return control to LLaMa.\n\n"); #else - fprintf(stderr, " - Press Ctrl+D (EOF) to return control to LLaMa.\n\n"); + fprintf(stderr, " - [MULTILINE MODE] Press Ctrl+D (EOF) to return control to LLaMa.\n\n"); #endif } else { @@ -320,7 +327,7 @@ int main(int argc, char ** argv) { } // replace end of text token with newline token when in interactive mode - if (id == llama_token_eos() && params.interactive && params.input_prefix.empty()) { + if (id == llama_token_eos() && params.interactive && !instruct_mode) { id = llama_token_newline.front(); if (params.antiprompt.size() != 0) { // tokenize and inject first reverse prompt @@ -377,8 +384,10 @@ int main(int argc, char ** argv) { antiprompt.is_stop_prompt = false; // Check if each of the reverse prompts appears at the end of the output. for (std::string & prompt : params.antiprompt) { - antiprompt.trailing_space = prompt.back() == ' '; - antiprompt.len = prompt.length() - (antiprompt.trailing_space ? 1 : 0); + // the trailing space is excluded from the match only when the workaround is enabled + antiprompt.trailing_space = params.rm_trailing_space_workaround && !prompt.empty() && prompt.back() == ' '; + // len is assigned unconditionally: the find() below uses it as both offset and count + antiprompt.len = prompt.length() - (antiprompt.trailing_space ? 
1 : 0); if (last_output.find(prompt.c_str(), last_output.length() - antiprompt.len, antiprompt.len) != std::string::npos) { is_interacting = true; antiprompt.any = true; @@ -389,8 +398,10 @@ int main(int argc, char ** argv) { } if (!antiprompt.any) { for (std::string & prompt : params.stopprompt) { - antiprompt.trailing_space = prompt.back() == ' '; - antiprompt.len = prompt.length() - (antiprompt.trailing_space ? 
1 : 0); + // the trailing space is excluded from the match only when the workaround is enabled + antiprompt.trailing_space = params.rm_trailing_space_workaround && !prompt.empty() && prompt.back() == ' '; + // len is assigned unconditionally: the find() below uses it as both offset and count + antiprompt.len = prompt.length() - (antiprompt.trailing_space ? 1 : 0); if (last_output.find(prompt.c_str(), last_output.length() - antiprompt.len, antiprompt.len) != std::string::npos) { is_interacting = true; antiprompt.any = true; @@ -406,21 +417,26 @@ int main(int argc, char ** argv) { if (n_past > 0 && is_interacting) { std::string buffer; - if (!params.clean_interface && !params.input_prefix.empty() && !antiprompt.any) { + if (!params.clean_interface && !params.instruct_prefix.empty() && !antiprompt.any) { // avoid printing again user's new line (TODO: try to revert enter press and print newline) - int i = params.input_prefix.front() == '\n' ? 1 : 0; + int i = params.instruct_prefix.front() == '\n' ? 1 : 0; for (; i < inp_pfx.size(); i++) { printf("%s", llama_token_to_str(ctx, inp_pfx[i])); } fflush(stdout); } - if (antiprompt.any && antiprompt.trailing_space) { - // add back removed trailing space to buffer(workaround) - buffer += ' '; - if (!params.clean_interface) { - printf("%s", buffer.c_str()); + if (params.rm_trailing_space_workaround) { + // add only if not stopprompt (as stopprompt could be used to pause + // assistant and then continue without input - adding back trailing + // space may mess it up.) 
+ if (!antiprompt.is_stop_prompt && antiprompt.any && antiprompt.trailing_space) { + // add back removed trailing space to buffer(workaround) + buffer += ' '; + if (!params.clean_interface) { + printf("%s", buffer.c_str()); + } + fflush(stdout); } - fflush(stdout); } // potentially set color to indicate we are taking user input @@ -435,6 +451,11 @@ int main(int argc, char ** argv) { printf("\n> "); } + if (!params.input_prefix.empty()) { + buffer += params.input_prefix; + printf("%s", buffer.c_str()); + } + if (!get_input_text(buffer, !params.multiline_mode)) { // input stream is bad return 1; @@ -446,13 +467,16 @@ int main(int argc, char ** argv) { // done taking input, reset color set_console_color(con_st, CONSOLE_COLOR_DEFAULT); - if (!params.clean_interface && !params.input_suffix.empty() && !antiprompt.is_stop_prompt) { + if (!params.clean_interface && !params.instruct_suffix.empty() && !antiprompt.is_stop_prompt) { // avoid printing again user's new line (TODO: try to revert enter press and print newline) - int i = params.input_suffix.front() == '\n' ? 1 : 0; + int i = params.instruct_suffix.front() == '\n' ? 1 : 0; for (; i < inp_sfx.size(); i++) { printf("%s", llama_token_to_str(ctx, inp_sfx[i])); } - // we won't add back removed trailing space here (workaround) + // if (remove trailing space workaround) { + // We won't add back removed trailing space here, because assistant continues here, + // and it may mess up its output (remove trailing space workaround). 
+ // } fflush(stdout); } @@ -461,7 +485,7 @@ int main(int argc, char ** argv) { if (buffer.length() > 1) { // insert input prefix - if (!params.input_prefix.empty() && !antiprompt.any) { + if (!params.instruct_prefix.empty() && !antiprompt.any) { n_consumed = embd_inp.size(); embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end()); } @@ -470,7 +494,7 @@ int main(int argc, char ** argv) { embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end()); // insert response suffix - if (!params.input_suffix.empty() && !antiprompt.is_stop_prompt) { + if (!params.instruct_suffix.empty() && !antiprompt.is_stop_prompt) { embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end()); } @@ -487,7 +511,7 @@ int main(int argc, char ** argv) { // end of text token if (!embd.empty() && embd.back() == llama_token_eos()) { - if (params.interactive && !params.input_prefix.empty()) { + if (instruct_mode) { is_interacting = true; } else { fprintf(stderr, " [end of text]\n");