Initial implementation of stop keywords

This commit is contained in:
Joshua Williams 2023-03-21 12:26:36 -05:00
parent 16ffc013c6
commit 3eed8c0914
3 changed files with 70 additions and 43 deletions

110
main.cpp
View file

@ -1016,6 +1016,13 @@ int main(int argc, char ** argv) {
} }
} }
} }
if(params.stop_keyword.size()) {
for (auto stop_keyword : params.stop_keyword) {
fprintf(stderr, "Stop keyword: '%s'\n", stop_keyword.c_str());
}
}
fprintf(stderr, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty); fprintf(stderr, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
fprintf(stderr, "\n\n"); fprintf(stderr, "\n\n");
@ -1129,58 +1136,75 @@ int main(int argc, char ** argv) {
printf(ANSI_COLOR_RESET); printf(ANSI_COLOR_RESET);
} }
// in interactive mode, and not currently processing queued inputs; // If we are not processing queued inputs, check for reverse prompt and stop keywords
// check if we should prompt the user for more if((int) embd_inp.size() <= input_consumed) {
if (params.interactive && (int) embd_inp.size() <= input_consumed) { // Build the output string
// check for reverse prompt // TODO - Recomputing this whole string every iteration is not efficient
std::string last_output; std::string last_output;
for (auto id : last_n_tokens) { for (auto id : last_n_tokens) {
last_output += vocab.id_to_token[id]; last_output += vocab.id_to_token[id];
} }
// Check if each of the reverse prompts appears at the end of the output. // Check for stop keywords
for (std::string antiprompt : params.antiprompt) { bool stop = false;
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) { for (std::string stop_keyword : params.stop_keyword) {
is_interacting = true; if (last_output.find(stop_keyword.c_str(), last_output.length() - stop_keyword.length(), stop_keyword.length()) != std::string::npos) {
stop = true;
break; break;
} }
} }
if (is_interacting) { if(stop) {
if (params.instruct) { break;
input_consumed = embd_inp.size(); }
embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
// in interactive mode, and not currently processing queued inputs;
printf("\n> "); // check if we should prompt the user for more
} if (params.interactive) {
// currently being interactive // Check if each of the reverse prompts appears at the end of the output.
if (params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN); for (std::string antiprompt : params.antiprompt) {
std::string buffer; if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
std::string line; is_interacting = true;
bool another_line = true; break;
do { }
std::getline(std::cin, line); }
if (line.empty() || line.back() != '\\') { if (is_interacting) {
another_line = false; if (params.instruct) {
} else { input_consumed = embd_inp.size();
line.pop_back(); // Remove the continue character embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
}
buffer += line + '\n'; // Append the line to the result printf("\n> ");
} while (another_line); }
if (params.use_color) printf(ANSI_COLOR_RESET);
// currently being interactive
std::vector<llama_vocab::id> line_inp = ::llama_tokenize(vocab, buffer, false); if (params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN);
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end()); std::string buffer;
std::string line;
if (params.instruct) { bool another_line = true;
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end()); do {
} std::getline(std::cin, line);
if (line.empty() || line.back() != '\\') {
remaining_tokens -= line_inp.size(); another_line = false;
} else {
input_noecho = true; // do not echo this again line.pop_back(); // Remove the continue character
}
buffer += line + '\n'; // Append the line to the result
} while (another_line);
if (params.use_color) printf(ANSI_COLOR_RESET);
std::vector<llama_vocab::id> line_inp = ::llama_tokenize(vocab, buffer, false);
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
if (params.instruct) {
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
}
remaining_tokens -= line_inp.size();
input_noecho = true; // do not echo this again
}
is_interacting = false;
} }
is_interacting = false;
} }
// end of text token // end of text token

View file

@ -72,6 +72,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.use_color = true; params.use_color = true;
} else if (arg == "-r" || arg == "--reverse-prompt") { } else if (arg == "-r" || arg == "--reverse-prompt") {
params.antiprompt.push_back(argv[++i]); params.antiprompt.push_back(argv[++i]);
} else if (arg == "--stop") {
params.stop_keyword.push_back(argv[++i]);
} else if (arg == "--perplexity") { } else if (arg == "--perplexity") {
params.perplexity = true; params.perplexity = true;
} else if (arg == "--ignore-eos") { } else if (arg == "--ignore-eos") {

View file

@ -32,6 +32,7 @@ struct gpt_params {
std::string prompt = ""; std::string prompt = "";
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
std::vector<std::string> stop_keyword; // string upon seeing which the model will stop
bool memory_f16 = false; // use f16 instead of f32 for memory kv bool memory_f16 = false; // use f16 instead of f32 for memory kv
bool random_prompt = false; // do not randomize prompt if none provided bool random_prompt = false; // do not randomize prompt if none provided