simplify code

This commit is contained in:
Evan Jones 2023-05-09 23:16:40 -04:00
parent b4d04d1613
commit 2041e1e0b5
3 changed files with 25 additions and 38 deletions

View file

@ -365,8 +365,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n");
fprintf(stderr, " --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n");
fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n");
fprintf(stderr, " run in interactive mode and poll user input upon seeing PROMPT\n");
fprintf(stderr, " (can be specified more than once for multiple reverse prompts).\n");
fprintf(stderr, " run in interactive mode and poll user input upon seeing PROMPT (can be\n");
fprintf(stderr, " specified more than once for multiple prompts).\n");
fprintf(stderr, " --stop KEYWORD a string that, when output by the model, will stop generation\n");
fprintf(stderr, " (can be specified more than once for multiple keywords).\n");
fprintf(stderr, " --color colorise output to distinguish prompt and user input from generations\n");

View file

@ -46,10 +46,10 @@ struct gpt_params {
std::string model = "models/lamma-7B/ggml-model.bin"; // model path
std::string prompt = "";
std::string path_session = ""; // path to file for saving/loading model eval state
std::string input_prefix = ""; // string to prefix user inputs with
std::string input_suffix = ""; // string to suffix user inputs with
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
std::string path_session = ""; // path to file for saving/loading model eval state
std::string input_prefix = ""; // string to prefix user inputs with
std::string input_suffix = ""; // string to suffix user inputs with
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
std::vector<std::string> stop_keywords; // string upon seeing which the model will stop
std::string lora_adapter = ""; // lora adapter path

View file

@ -266,7 +266,7 @@ int main(int argc, char ** argv) {
}
if (params.stop_keywords.size()) {
for (auto stop_keyword : params.stop_keywords) {
for (auto & stop_keyword : params.stop_keywords) {
fprintf(stderr, "Stop keyword: '%s'\n", stop_keyword.c_str());
}
}
@ -516,22 +516,17 @@ int main(int argc, char ** argv) {
console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
}
// in interactive mode, and not currently processing queued inputs;
// check if we should prompt the user for more
if (params.interactive && (int) embd_inp.size() <= n_consumed) {
// check for stop keywords if we're processing generations
if (params.stop_keywords.size() && (int) embd_inp.size() <= n_consumed) {
std::string last_output;
if (params.antiprompt.size() || params.stop_keywords.size()) {
for (auto id : last_n_tokens) {
last_output += llama_token_to_str(ctx, id);
}
for (auto id : last_n_tokens) {
last_output += llama_token_to_str(ctx, id);
}
// Check for stop keywords, a configurable alternative to the end-of-text token
// This should stop also the interactive mode, useful to stop interactive mode without SIGTERM
bool stop = false;
for (std::string stop_keyword : params.stop_keywords) {
if (last_output.find(stop_keyword.c_str(), last_output.length() - stop_keyword.length(), stop_keyword.length()) != std::string::npos) {
for (auto & stop_keyword : params.stop_keywords) {
const size_t stop_pos = last_output.find(stop_keyword.c_str(),
last_output.length() - stop_keyword.length(), stop_keyword.length());
if (stop_pos != std::string::npos) {
stop = true;
break;
}
@ -539,9 +534,19 @@ int main(int argc, char ** argv) {
if (stop) {
break;
}
}
// in interactive mode, and not currently processing queued inputs;
// check if we should prompt the user for more
if (params.interactive && (int) embd_inp.size() <= n_consumed) {
// check for reverse prompt
if (params.antiprompt.size()) {
std::string last_output;
for (auto id : last_n_tokens) {
last_output += llama_token_to_str(ctx, id);
}
is_antiprompt = false;
// Check if each of the reverse prompts appears at the end of the output.
for (std::string & antiprompt : params.antiprompt) {
@ -608,24 +613,6 @@ int main(int argc, char ** argv) {
}
}
// Check for stop keywords, a configurable alternative to the end-of-text token
if (!params.interactive && params.stop_keywords.size() && !is_interacting) {
std::string last_output;
for (auto id : last_n_tokens) {
last_output += llama_token_to_str(ctx, id);
}
bool stop = false;
for (std::string stop_keyword : params.stop_keywords) {
if (last_output.find(stop_keyword.c_str(), last_output.length() - stop_keyword.length(), stop_keyword.length()) != std::string::npos) {
stop = true;
break;
}
}
if (stop) {
break;
}
}
// end of text token
if (!embd.empty() && embd.back() == llama_token_eos()) {
if (params.instruct) {