Initial implementation of stop keywords
This commit is contained in:
parent
16ffc013c6
commit
3eed8c0914
3 changed files with 70 additions and 43 deletions
110
main.cpp
110
main.cpp
|
@ -1016,6 +1016,13 @@ int main(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if(params.stop_keyword.size()) {
|
||||||
|
for (auto stop_keyword : params.stop_keyword) {
|
||||||
|
fprintf(stderr, "Stop keyword: '%s'\n", stop_keyword.c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fprintf(stderr, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
|
fprintf(stderr, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
|
||||||
fprintf(stderr, "\n\n");
|
fprintf(stderr, "\n\n");
|
||||||
|
|
||||||
|
@ -1129,58 +1136,75 @@ int main(int argc, char ** argv) {
|
||||||
printf(ANSI_COLOR_RESET);
|
printf(ANSI_COLOR_RESET);
|
||||||
}
|
}
|
||||||
|
|
||||||
// in interactive mode, and not currently processing queued inputs;
|
// If we are not processing queued inputs, check for reverse prompt and stop keywords
|
||||||
// check if we should prompt the user for more
|
if((int) embd_inp.size() <= input_consumed) {
|
||||||
if (params.interactive && (int) embd_inp.size() <= input_consumed) {
|
// Build the output string
|
||||||
// check for reverse prompt
|
// TODO - Recomputing this whole string every iteration is not efficient
|
||||||
std::string last_output;
|
std::string last_output;
|
||||||
for (auto id : last_n_tokens) {
|
for (auto id : last_n_tokens) {
|
||||||
last_output += vocab.id_to_token[id];
|
last_output += vocab.id_to_token[id];
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if each of the reverse prompts appears at the end of the output.
|
// Check for stop keywords
|
||||||
for (std::string antiprompt : params.antiprompt) {
|
bool stop = false;
|
||||||
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
|
for (std::string stop_keyword : params.stop_keyword) {
|
||||||
is_interacting = true;
|
if (last_output.find(stop_keyword.c_str(), last_output.length() - stop_keyword.length(), stop_keyword.length()) != std::string::npos) {
|
||||||
|
stop = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (is_interacting) {
|
if(stop) {
|
||||||
if (params.instruct) {
|
break;
|
||||||
input_consumed = embd_inp.size();
|
}
|
||||||
embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
|
|
||||||
|
// in interactive mode, and not currently processing queued inputs;
|
||||||
printf("\n> ");
|
// check if we should prompt the user for more
|
||||||
}
|
if (params.interactive) {
|
||||||
|
|
||||||
// currently being interactive
|
// Check if each of the reverse prompts appears at the end of the output.
|
||||||
if (params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN);
|
for (std::string antiprompt : params.antiprompt) {
|
||||||
std::string buffer;
|
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
|
||||||
std::string line;
|
is_interacting = true;
|
||||||
bool another_line = true;
|
break;
|
||||||
do {
|
}
|
||||||
std::getline(std::cin, line);
|
}
|
||||||
if (line.empty() || line.back() != '\\') {
|
if (is_interacting) {
|
||||||
another_line = false;
|
if (params.instruct) {
|
||||||
} else {
|
input_consumed = embd_inp.size();
|
||||||
line.pop_back(); // Remove the continue character
|
embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
|
||||||
}
|
|
||||||
buffer += line + '\n'; // Append the line to the result
|
printf("\n> ");
|
||||||
} while (another_line);
|
}
|
||||||
if (params.use_color) printf(ANSI_COLOR_RESET);
|
|
||||||
|
// currently being interactive
|
||||||
std::vector<llama_vocab::id> line_inp = ::llama_tokenize(vocab, buffer, false);
|
if (params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN);
|
||||||
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
|
std::string buffer;
|
||||||
|
std::string line;
|
||||||
if (params.instruct) {
|
bool another_line = true;
|
||||||
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
|
do {
|
||||||
}
|
std::getline(std::cin, line);
|
||||||
|
if (line.empty() || line.back() != '\\') {
|
||||||
remaining_tokens -= line_inp.size();
|
another_line = false;
|
||||||
|
} else {
|
||||||
input_noecho = true; // do not echo this again
|
line.pop_back(); // Remove the continue character
|
||||||
|
}
|
||||||
|
buffer += line + '\n'; // Append the line to the result
|
||||||
|
} while (another_line);
|
||||||
|
if (params.use_color) printf(ANSI_COLOR_RESET);
|
||||||
|
|
||||||
|
std::vector<llama_vocab::id> line_inp = ::llama_tokenize(vocab, buffer, false);
|
||||||
|
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
|
||||||
|
|
||||||
|
if (params.instruct) {
|
||||||
|
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
remaining_tokens -= line_inp.size();
|
||||||
|
|
||||||
|
input_noecho = true; // do not echo this again
|
||||||
|
}
|
||||||
|
is_interacting = false;
|
||||||
}
|
}
|
||||||
is_interacting = false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// end of text token
|
// end of text token
|
||||||
|
|
|
@ -72,6 +72,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
||||||
params.use_color = true;
|
params.use_color = true;
|
||||||
} else if (arg == "-r" || arg == "--reverse-prompt") {
|
} else if (arg == "-r" || arg == "--reverse-prompt") {
|
||||||
params.antiprompt.push_back(argv[++i]);
|
params.antiprompt.push_back(argv[++i]);
|
||||||
|
} else if (arg == "--stop") {
|
||||||
|
params.stop_keyword.push_back(argv[++i]);
|
||||||
} else if (arg == "--perplexity") {
|
} else if (arg == "--perplexity") {
|
||||||
params.perplexity = true;
|
params.perplexity = true;
|
||||||
} else if (arg == "--ignore-eos") {
|
} else if (arg == "--ignore-eos") {
|
||||||
|
|
1
utils.h
1
utils.h
|
@ -32,6 +32,7 @@ struct gpt_params {
|
||||||
std::string prompt = "";
|
std::string prompt = "";
|
||||||
|
|
||||||
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
|
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
|
||||||
|
std::vector<std::string> stop_keyword; // string upon seeing which the model will stop
|
||||||
|
|
||||||
bool memory_f16 = false; // use f16 instead of f32 for memory kv
|
bool memory_f16 = false; // use f16 instead of f32 for memory kv
|
||||||
bool random_prompt = false; // do not randomize prompt if none provided
|
bool random_prompt = false; // do not randomize prompt if none provided
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue