// -----------------------------------------------------------------------------
// NOTE(review): SOURCE arrived as a `git format-patch` email ("fix stopping
// strings", examples/server/server.cpp) whose newlines were destroyed in
// transit, so the patch no longer applies with `git am`.  Below is a clean,
// compilable reconstruction of the patch's net-new self-contained additions;
// the hunks that modify llama_server_context / main() depend on state outside
// this view (params, stopping_word, has_next_token, generated_text) and are
// preserved verbatim in the reference comment at the bottom.
// -----------------------------------------------------------------------------

// Distinguishes a confirmed stop-string ("antiprompt") match from a match
// that is only a prefix of one — the text's tail might still grow into it.
enum stop_type {
    STOP_FULL,
    STOP_PARTIAL,
};

// True when `str` ends with `suffix` (an empty suffix matches everything).
bool ends_with(const std::string &str, const std::string &suffix)
{
    return str.size() >= suffix.size() &&
           0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
}

// Returns the index in `text` where a *partial* occurrence of `stop` begins,
// i.e. where the longest non-empty prefix of `stop` that is also a suffix of
// `text` starts; std::string::npos when no prefix of `stop` ends the text.
// The streaming path uses this to withhold bytes that may still become a
// full stop string.
//
// Improvement over the patch: the original built `stop.substr(0, len)` for
// every candidate length (one heap allocation per probe) and used a signed
// int64_t index.  This version compares in place with the five-argument
// std::string::compare overload and probes longest-first with an unsigned
// length, preserving the original's longest-match result exactly.
size_t find_partial_stop_string(const std::string &stop, const std::string &text)
{
    if (text.empty() || stop.empty()) {
        return std::string::npos;
    }
    const char text_last_char = text.back();
    // Longest prefix first — mirrors the original high-to-low char_index scan.
    for (size_t len = stop.size(); len > 0; len--) {
        // Cheap reject: the prefix's last char must equal the text's last char.
        if (stop[len - 1] != text_last_char) {
            continue;
        }
        if (text.size() >= len &&
            0 == text.compare(text.size() - len, len, stop, 0, len)) {
            return text.size() - len;
        }
    }
    return std::string::npos;
}

/* --------------------------------------------------------------------------
 * Remaining hunks of the mangled patch, kept for reference (apply by hand):
 *
 * 1) New llama_server_context member — scans every antiprompt word and
 *    returns the earliest stop position; a STOP_FULL hit also records
 *    stopping_word and clears has_next_token, while a STOP_PARTIAL scan only
 *    reports where a possible stop begins:
 *
 *    size_t findStoppingStrings(const std::string &text,
 *                               const size_t last_token_size,
 *                               const stop_type type)
 *    {
 *        size_t stop_pos = std::string::npos;
 *        for (const std::string &word : params.antiprompt) {
 *            size_t pos;
 *            if (type == STOP_FULL) {
 *                // Only rescan the region a new token could have touched.
 *                const size_t tmp = word.size() + last_token_size;
 *                const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
 *                pos = text.find(word, from_pos);
 *            } else {
 *                pos = find_partial_stop_string(word, text);
 *            }
 *            if (pos != std::string::npos &&
 *                (stop_pos == std::string::npos || pos < stop_pos)) {
 *                if (type == STOP_FULL) {
 *                    stopping_word  = word;
 *                    has_next_token = false;
 *                }
 *                stop_pos = pos;
 *            }
 *        }
 *        return stop_pos;
 *    }
 *
 * 2) doCompletion(): the old inline antiprompt loop (find + erase + break)
 *    is deleted; stop handling moves to the callers via findStoppingStrings.
 *
 * 3) main(), non-streaming path: after each doCompletion(), run a STOP_FULL
 *    scan over llama.generated_text and erase from the stop position to the
 *    end before building the JSON response.
 *
 * 4) main(), streaming path: scan only the unsent tail
 *    (llama.generated_text.c_str() + pos); on STOP_FULL truncate
 *    generated_text and recompute pos, otherwise take a STOP_PARTIAL
 *    position as the substr length so bytes that might grow into a stop
 *    string are withheld:
 *        std::string to_send = llama.generated_text.substr(pos, stop_pos);
 *
 * 5) `return res.set_content(...)` becomes a plain statement — the handler
 *    lambda returns void, so returning the (void) result of set_content was
 *    legal but misleading; the patch drops the `return`.
 * -------------------------------------------------------------------------- */