fix stopping strings

anon 2023-05-31 20:31:58 -03:00
parent 342604bb81
commit e9b1f0bf5c


@@ -20,6 +20,33 @@ static size_t common_part(const std::vector<llama_token> & a, const std::vector<
     return i;
 }
 
+enum stop_type {
+    STOP_FULL,
+    STOP_PARTIAL,
+};
+
+bool ends_with(const std::string &str, const std::string &suffix)
+{
+    return str.size() >= suffix.size() &&
+           0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
+}
+
+size_t find_partial_stop_string(const std::string &stop, const std::string &text)
+{
+    if (!text.empty()) {
+        const char text_last_char = text.back();
+        for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
+            if (stop[char_index] == text_last_char) {
+                const std::string current_partial = stop.substr(0, char_index + 1);
+                if (ends_with(text, current_partial)) {
+                    return text.size() - char_index - 1;
+                }
+            }
+        }
+    }
+    return std::string::npos;
+}
+
 struct llama_server_context
 {
     bool stream = false;
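
Note: a minimal standalone sketch of what the new helper computes (the file name, main() and sample strings are illustrative, not part of the commit): the index in text at which the trailing characters match a prefix of the stop string, or std::string::npos when the end of text cannot grow into it.

    // sketch.cpp (illustrative): helpers copied verbatim from the hunk above.
    #include <cstdint>
    #include <cstdio>
    #include <string>

    bool ends_with(const std::string &str, const std::string &suffix)
    {
        return str.size() >= suffix.size() &&
               0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
    }

    size_t find_partial_stop_string(const std::string &stop, const std::string &text)
    {
        if (!text.empty()) {
            const char text_last_char = text.back();
            for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
                if (stop[char_index] == text_last_char) {
                    const std::string current_partial = stop.substr(0, char_index + 1);
                    if (ends_with(text, current_partial)) {
                        return text.size() - char_index - 1;
                    }
                }
            }
        }
        return std::string::npos;
    }

    int main()
    {
        // "### " is a typical antiprompt. The trailing "##" could still grow
        // into it, so the partial match is reported at index 10.
        printf("%zu\n", find_partial_stop_string("### ", "User says ##"));  // 10
        // "hello" cannot be extended into "### ": no partial match.
        printf("%d\n", find_partial_stop_string("### ", "hello") == std::string::npos);  // 1
    }
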
@@ -248,6 +275,31 @@ struct llama_server_context
         return result;
     }
 
+    size_t findStoppingStrings(const std::string &text, const size_t last_token_size,
+                               const stop_type type)
+    {
+        size_t stop_pos = std::string::npos;
+        for (const std::string &word : params.antiprompt) {
+            size_t pos;
+            if (type == STOP_FULL) {
+                const size_t tmp = word.size() + last_token_size;
+                const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
+                pos = text.find(word, from_pos);
+            } else {
+                pos = find_partial_stop_string(word, text);
+            }
+
+            if (pos != std::string::npos &&
+                (stop_pos == std::string::npos || pos < stop_pos)) {
+                if (type == STOP_FULL) {
+                    stopping_word = word;
+                    has_next_token = false;
+                }
+                stop_pos = pos;
+            }
+        }
+
+        return stop_pos;
+    }
+
     std::string doCompletion()
     {
         llama_token token = nextToken();
@@ -272,16 +324,6 @@ struct llama_server_context
                     stopping_word.c_str());
         }
 
-        for (const std::string& word : params.antiprompt) {
-            size_t i = generated_text.find(word, generated_text.size() - (word.size() + token_text.size()));
-            if (i != std::string::npos) {
-                generated_text.erase(generated_text.begin() + i, generated_text.end());
-                stopping_word = word;
-                has_next_token = false;
-                break;
-            }
-        }
-
         return token_text;
     }
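
The two modes divide the work: STOP_FULL scans only the last word.size() + last_token_size bytes (a complete stop string must involve at least one byte of the newest token), records stopping_word, and halts generation; STOP_PARTIAL only reports where a possible stop-string prefix begins so the caller can hold those bytes back. A free-function rendering for illustration only; the real code is the method above on llama_server_context, which reads params.antiprompt and mutates the members shown. It assumes the helpers from the previous sketch are in scope.

    // Illustrative stand-in for llama_server_context::findStoppingStrings.
    #include <string>
    #include <vector>

    enum stop_type { STOP_FULL, STOP_PARTIAL };

    size_t find_stopping_strings(const std::string &text,
                                 const std::vector<std::string> &antiprompt,
                                 const size_t last_token_size, const stop_type type)
    {
        size_t stop_pos = std::string::npos;
        for (const std::string &word : antiprompt) {
            size_t pos;
            if (type == STOP_FULL) {
                // A full stop string must use at least one byte of the newest
                // token, so only the tail window needs rescanning.
                const size_t tmp = word.size() + last_token_size;
                const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
                pos = text.find(word, from_pos);
            } else {
                pos = find_partial_stop_string(word, text);
            }
            // Keep the earliest match across all stop words.
            if (pos != std::string::npos &&
                (stop_pos == std::string::npos || pos < stop_pos)) {
                stop_pos = pos;
            }
        }
        return stop_pos;
    }

    // find_stopping_strings("User says ### stop", {"### "}, 5, STOP_FULL)    -> 10
    // find_stopping_strings("User says ##",       {"### "}, 2, STOP_PARTIAL) -> 10
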
@@ -711,7 +753,14 @@ int main(int argc, char **argv)
             if (!llama.stream) {
                 while (llama.has_next_token) {
-                    llama.doCompletion();
+                    const std::string token_text = llama.doCompletion();
+                    const size_t stop_pos = llama.findStoppingStrings(
+                        llama.generated_text, token_text.size(), STOP_FULL);
+
+                    if (stop_pos != std::string::npos) {
+                        llama.generated_text.erase(llama.generated_text.begin() + stop_pos,
+                                                   llama.generated_text.end());
+                    }
                 }
 
                 json data = {{"content", llama.generated_text},
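
Note on the non-streaming loop: rescanning only word.size() + token_text.size() bytes is enough because a stop string absent on the previous iteration must overlap the newest token; the from_pos ternary also guards the size_t underflow the removed loop in doCompletion() had when generated_text was shorter than that window. A self-contained trace with illustrative strings:

    #include <cstdio>
    #include <string>

    int main()
    {
        // "step one ##" was scanned on earlier iterations without a full
        // match; the newest token completes the stop string "###".
        std::string generated = "step one ##";
        const std::string token = "# done";
        generated += token;

        const std::string word = "###";  // one entry of params.antiprompt
        const size_t tmp = word.size() + token.size();
        const size_t from_pos = generated.size() > tmp ? generated.size() - tmp : 0;

        const size_t stop_pos = generated.find(word, from_pos);
        if (stop_pos != std::string::npos) {
            // Truncate everything from the stop string onward.
            generated.erase(generated.begin() + stop_pos, generated.end());
        }
        printf("'%s'\n", generated.c_str());  // prints 'step one '
    }
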
@@ -724,7 +773,7 @@ int main(int argc, char **argv)
                 llama_print_timings(llama.ctx);
 
-                return res.set_content(
+                res.set_content(
                     data.dump(llama.json_indent, ' ', false, json::error_handler_t::replace),
                     "application/json");
             } else {
@@ -733,7 +782,7 @@ int main(int argc, char **argv)
                 int32_t multibyte_pending = 0;
 
                 while (llama.has_next_token) {
-                    std::string token_text = llama.doCompletion();
+                    const std::string token_text = llama.doCompletion();
 
                     if (multibyte_pending > 0) {
                         multibyte_pending -= token_text.size();
@@ -761,8 +810,22 @@ int main(int argc, char **argv)
                         continue;
                     }
 
-                    const size_t pos = std::min(sent_count, llama.generated_text.size());
-                    std::string to_send = llama.generated_text.substr(pos);
+                    size_t pos = std::min(sent_count, llama.generated_text.size());
+
+                    const char *str_test = llama.generated_text.c_str() + pos;
+                    size_t stop_pos =
+                        llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL);
+                    if (stop_pos != std::string::npos) {
+                        llama.generated_text.erase(
+                            llama.generated_text.begin() + pos + stop_pos,
+                            llama.generated_text.end());
+                        pos = std::min(sent_count, llama.generated_text.size());
+                    } else {
+                        stop_pos = llama.findStoppingStrings(str_test, token_text.size(),
+                                                             STOP_PARTIAL);
+                    }
+
+                    std::string to_send = llama.generated_text.substr(pos, stop_pos);
+
                     sent_count += to_send.size();
 
                     json data;
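
The streaming subtlety: std::string::substr takes a length as its second argument, and stop_pos is an offset inside str_test, which itself starts at pos, so substr(pos, stop_pos) sends exactly the bytes before a possible stop-string prefix and keeps the ambiguous suffix buffered; with no match, substr(pos, npos) sends the whole tail. A short trace with illustrative values, assuming <algorithm> plus the helpers from the first sketch:

    // The unsent tail ends in "##", which might still grow into "### ".
    std::string generated_text = "Answer: 42 ##";
    size_t sent_count = 0;  // bytes already delivered to the client

    size_t pos = std::min(sent_count, generated_text.size());      // 0
    const char *str_test = generated_text.c_str() + pos;
    // STOP_PARTIAL path: where does a possible stop prefix begin?
    size_t stop_pos = find_partial_stop_string("### ", str_test);  // 11

    std::string to_send = generated_text.substr(pos, stop_pos);    // "Answer: 42 "
    sent_count += to_send.size();  // "##" is held back until the next token decides
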
@@ -808,7 +871,6 @@ int main(int argc, char **argv)
                 }
             });
 
-
     svr.Post("/tokenize", [&llama](const Request &req, Response &res)
              {
                  json body = json::parse(req.body);