Check for reverse prompt by characters instead of tokens (#292)
This commit is contained in:
parent
074bea2eb1
commit
b646ffa1b1
1 changed files with 19 additions and 20 deletions
39
main.cpp
39
main.cpp
|
@ -12,6 +12,7 @@
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||||
#include <signal.h>
|
#include <signal.h>
|
||||||
|
@ -877,16 +878,9 @@ int main(int argc, char ** argv) {
|
||||||
params.interactive = true;
|
params.interactive = true;
|
||||||
params.antiprompt.push_back("### Instruction:\n\n");
|
params.antiprompt.push_back("### Instruction:\n\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
// tokenize the reverse prompt
|
|
||||||
std::vector<std::vector<gpt_vocab::id>> antipromptv_inp;
|
|
||||||
|
|
||||||
for (auto antiprompt : params.antiprompt) {
|
|
||||||
antipromptv_inp.push_back(::llama_tokenize(vocab, antiprompt, false));
|
|
||||||
}
|
|
||||||
|
|
||||||
// enable interactive mode if reverse prompt is specified
|
// enable interactive mode if reverse prompt is specified
|
||||||
if (antipromptv_inp.size() != 0) {
|
if (params.antiprompt.size() != 0) {
|
||||||
params.interactive = true;
|
params.interactive = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -910,15 +904,9 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
fprintf(stderr, "%s: interactive mode on.\n", __func__);
|
fprintf(stderr, "%s: interactive mode on.\n", __func__);
|
||||||
|
|
||||||
if(antipromptv_inp.size()) {
|
if(params.antiprompt.size()) {
|
||||||
for (size_t apindex = 0; apindex < antipromptv_inp.size(); ++apindex) {
|
for (auto antiprompt : params.antiprompt) {
|
||||||
auto antiprompt_inp = antipromptv_inp.at(apindex);
|
fprintf(stderr, "Antiprompt: %s\n", antiprompt);
|
||||||
fprintf(stderr, "%s: reverse prompt: '%s'\n", __func__, params.antiprompt.at(apindex).c_str());
|
|
||||||
fprintf(stderr, "%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size());
|
|
||||||
for (int i = 0; i < (int) antiprompt_inp.size(); i++) {
|
|
||||||
fprintf(stderr, "%6d -> '%s'\n", antiprompt_inp[i], vocab.id_to_token.at(antiprompt_inp[i]).c_str());
|
|
||||||
}
|
|
||||||
fprintf(stderr, "\n");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1035,12 +1023,23 @@ int main(int argc, char ** argv) {
|
||||||
// check if we should prompt the user for more
|
// check if we should prompt the user for more
|
||||||
if (params.interactive && embd_inp.size() <= input_consumed) {
|
if (params.interactive && embd_inp.size() <= input_consumed) {
|
||||||
// check for reverse prompt
|
// check for reverse prompt
|
||||||
for (auto antiprompt_inp : antipromptv_inp) {
|
|
||||||
if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) {
|
std::stringstream last_output_ss;
|
||||||
// reverse prompt found
|
for (auto id : last_n_tokens) {
|
||||||
|
last_output_ss << vocab.id_to_token[id];
|
||||||
|
}
|
||||||
|
std::string last_output = last_output_ss.str();
|
||||||
|
|
||||||
|
for (std::string antiprompt : params.antiprompt) {
|
||||||
|
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
|
||||||
is_interacting = true;
|
is_interacting = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
/*if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) {
|
||||||
|
// reverse prompt found
|
||||||
|
is_interacting = true;
|
||||||
|
break;
|
||||||
|
}*/
|
||||||
}
|
}
|
||||||
if (is_interacting) {
|
if (is_interacting) {
|
||||||
if (params.instruct) {
|
if (params.instruct) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue