From 2b496429c3a42f91b8c62beecf3e7cbe072fbfbf Mon Sep 17 00:00:00 2001
From: "Berthold, Alexander"
Date: Thu, 8 Jun 2023 12:41:34 +0200
Subject: [PATCH] Fix crash in server example caused by out-of-bounds access
 during no-show-words scanning

- Replaced the scanning code with a lookahead-based strategy

---
 examples/server/server.cpp | 128 +++++++++++++++++++------------------
 1 file changed, 66 insertions(+), 62 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 31d8087ef..f22529ca6 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1,5 +1,5 @@
-#include <httplib.h>
-#include <json.hpp>
+#include "httplib.h"
+#include "json.hpp"
 
 #include "common.h"
 #include "llama.h"
@@ -27,7 +27,8 @@ struct llama_server_context
     std::vector<llama_token> llama_token_newline;
     std::vector<llama_token> embd_inp;
     std::vector<std::vector<llama_token>> no_show_words;
-    std::vector<llama_token> tokens_predicted;
+    std::deque<llama_token> tokens_predicted;
+    std::vector<llama_token>::size_type n_read_ahead = 1; // always look ahead at least one token
 
     llama_context *ctx;
     gpt_params params;
@@ -294,63 +295,53 @@ struct llama_server_context
         return result;
     }
 
-    std::string doCompletion()
-    {
-        llama_token token = nextToken();
-        if (token == -1) {
-            return "";
-        }
-        tokens_predicted.clear();
-        tokens_predicted.push_back(token);
-
-        // Avoid add the no show words to the response
-        for (std::vector<llama_token> word_tokens : no_show_words)
-        {
-            size_t match_token = 1;
-            if (tokens_predicted.front() == word_tokens.front())
-            {
-                bool execute_matching = true;
-                if (tokens_predicted.size() > 1) { // if previus tokens had been tested
-                    for (size_t i = 1; i < word_tokens.size(); i++)
-                    {
-                        if (i >= tokens_predicted.size()) {
-                            match_token = i;
-                            break;
-                        }
-                        if (tokens_predicted[i] == word_tokens[i])
-                        {
-                            continue;
-                        }
-                        else
-                        {
-                            execute_matching = false;
-                            break;
-                        }
-                    }
-                }
-                while (execute_matching) {
-                    if (match_token == word_tokens.size()) {
-                        return "";
-                    }
-                    token = nextToken();
-                    tokens_predicted.push_back(token);
-                    if (token == word_tokens[match_token])
-                    { // the token follow the sequence
-                        match_token++;
-                    }
-                    else if (match_token < word_tokens.size())
-                    { // no complete all word sequence
-                        break;
-                    }
-                }
-            }
-        }
-        if(as_loop) {
+    std::string doCompletion() {
+        if (as_loop) {
             generated_text = "";
         }
-        for (llama_token tkn : tokens_predicted)
-        {
-            generated_text += llama_token_to_str(ctx, tkn);
+
+        // Avoid adding the no-show words to the response
+        bool removed_no_show_words;
+        bool past_end_of_tokens = false;
+        do {
+            removed_no_show_words = false;
+
+            // Fill `tokens_predicted` up to `n_read_ahead` tokens if possible
+            while (tokens_predicted.size() < n_read_ahead) {
+                llama_token token = nextToken();
+                if (token == -1) {
+                    past_end_of_tokens = true;
+                    break;
+                }
+                tokens_predicted.push_back(token);
+            }
+
+            // Remove sequences of no_show_words from `tokens_predicted`
+            for (const auto &no_show : no_show_words) {
+
+                const auto &occurrence =
+                    std::search(tokens_predicted.begin(), tokens_predicted.end(),
+                                no_show.begin(), no_show.end());
+
+                // skip empty sequences: std::search matches an empty needle at begin()
+                if (!no_show.empty() && occurrence != tokens_predicted.end()) {
+                    tokens_predicted.erase(occurrence, occurrence + no_show.size());
+                    removed_no_show_words = true;
+                }
+            }
+
+            // Repeat while a sequence was removed and tokens remain
+        } while (removed_no_show_words && !past_end_of_tokens);
+
+        if (past_end_of_tokens) {
+            // If end of tokens, return all and clear
+            for (llama_token tkn : tokens_predicted) {
+                generated_text += llama_token_to_str(ctx, tkn);
+            }
+            tokens_predicted.clear();
+        } else {
+            // Otherwise emit only the first token and keep the lookahead buffered
+            generated_text += llama_token_to_str(ctx, tokens_predicted[0]);
+            tokens_predicted.pop_front();
         }
         return generated_text;
     }
@@ -476,6 +467,15 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para
         {
             params.embedding = true;
         }
+        else if (arg == "--keep")
+        {
+            if (++i >= argc)
+            {
+                invalid_param = true;
+                break;
+            }
+            params.n_keep = std::stoi(argv[i]);
+        }
         else if (arg == "-h" || arg == "--help")
         {
             server_print_usage(argc, argv, default_params);
@@ -622,18 +622,22 @@ bool parse_options_completion(json body, llama_server_context& llama, Response &
     if (!body["stop"].is_null())
     {
         std::vector<std::string> stop_words = body["stop"].get<std::vector<std::string>>();
-        for (std::string stop_word : stop_words)
+        for (const std::string& stop_word : stop_words)
         {
             llama.params.antiprompt.push_back(stop_word);
-            llama.no_show_words.push_back(::llama_tokenize(llama.ctx, stop_word, false));
+            auto tokens = ::llama_tokenize(llama.ctx, stop_word, false);
+            llama.n_read_ahead = std::max(llama.n_read_ahead, tokens.size());
+            llama.no_show_words.push_back(tokens);
         }
     }
     if (!body["exclude"].is_null())
    {
         std::vector<std::string> no_show_words = body["exclude"].get<std::vector<std::string>>();
-        for (std::string no_show : no_show_words)
+        for (const std::string& no_show : no_show_words)
         {
-            llama.no_show_words.push_back(::llama_tokenize(llama.ctx, no_show, false));
+            auto tokens = ::llama_tokenize(llama.ctx, no_show, false);
+            llama.n_read_ahead = std::max(llama.n_read_ahead, tokens.size());
+            llama.no_show_words.push_back(tokens);
         }
     }
     return true;
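
The lookahead strategy can be exercised in isolation. The program below is a
minimal, self-contained sketch (not part of the patch), assuming a stub
integer token stream in place of llama.cpp's nextToken() and
llama_token_to_str(); fake_stream, next_token, and the `token` alias are
illustrative stand-ins, not llama.cpp API:

// Sketch of the lookahead-based no-show filtering, with a stub token stream.
#include <algorithm>
#include <cstddef>
#include <deque>
#include <iostream>
#include <string>
#include <vector>

using token = int;

// Stub stream standing in for nextToken(); -1 marks end of generation.
static const std::vector<token> fake_stream = {1, 2, 7, 8, 9, 3, 4};
static std::size_t stream_pos = 0;

static token next_token() {
    return stream_pos < fake_stream.size() ? fake_stream[stream_pos++] : -1;
}

int main() {
    const std::vector<std::vector<token>> no_show_words = {{7, 8, 9}};
    const std::size_t n_read_ahead = 3; // length of the longest no-show sequence
    std::deque<token> tokens_predicted;
    std::string out;

    bool past_end = false;
    while (!past_end || !tokens_predicted.empty()) {
        bool removed;
        do {
            removed = false;
            // Fill the lookahead window; a no-show sequence starting at the
            // front is then guaranteed to fit entirely inside the window.
            while (tokens_predicted.size() < n_read_ahead) {
                token t = next_token();
                if (t == -1) { past_end = true; break; }
                tokens_predicted.push_back(t);
            }
            // Erase every complete no-show occurrence inside the window.
            for (const auto &seq : no_show_words) {
                auto hit = std::search(tokens_predicted.begin(), tokens_predicted.end(),
                                       seq.begin(), seq.end());
                if (!seq.empty() && hit != tokens_predicted.end()) {
                    tokens_predicted.erase(hit, hit + seq.size());
                    removed = true;
                }
            }
        } while (removed && !past_end);

        if (past_end) {
            // Stream exhausted: flush whatever survived filtering.
            for (token t : tokens_predicted) out += std::to_string(t) + " ";
            tokens_predicted.clear();
        } else {
            // Window is full and clean: the front token is safe to emit.
            out += std::to_string(tokens_predicted.front()) + " ";
            tokens_predicted.pop_front();
        }
    }
    std::cout << out << '\n'; // prints "1 2 3 4 "; the 7 8 9 run was filtered out
}

Because a token is only emitted once the window holds n_read_ahead tokens and
contains no no-show occurrence, a forbidden sequence can never be partially
emitted; the cost is that the response lags generation by up to n_read_ahead
tokens.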