From 3f436ea3f301293eba839223cdbc6049ea8bbc42 Mon Sep 17 00:00:00 2001
From: Jhen
Date: Wed, 23 Aug 2023 15:52:49 +0800
Subject: [PATCH] avoid unnecessary empty data event & send rest of partial tokens on stop

---
 examples/server/server.cpp | 35 +++++++++++++++++++----------------
 1 file changed, 19 insertions(+), 16 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 8f5433d8c..0b1c9ee04 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1330,39 +1330,42 @@ int main(int argc, char **argv)
                     size_t pos = std::min(sent_count, llama.generated_text.size());
 
                     const std::string str_test = llama.generated_text.substr(pos);
+                    bool is_stop_full = false;
                     size_t stop_pos =
                         llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL);
                     if (stop_pos != std::string::npos) {
+                        is_stop_full = true;
                         llama.generated_text.erase(
                             llama.generated_text.begin() + pos + stop_pos,
                             llama.generated_text.end());
                         pos = std::min(sent_count, llama.generated_text.size());
                     } else {
+                        is_stop_full = false;
                         stop_pos = llama.findStoppingStrings(str_test, token_text.size(),
                                                              STOP_PARTIAL);
                     }
 
-                    const std::string to_send = stop_pos == std::string::npos
-                        ? llama.generated_text.substr(pos, std::string::npos)
-                        : ""; // just don't send anything if we're not done
+                    if (
+                        stop_pos == std::string::npos ||
+                        // Send rest of the text if we are at the end of the generation
+                        (!llama.has_next_token && !is_stop_full && stop_pos > 0)
+                    ) {
+                        const std::string to_send = llama.generated_text.substr(pos, std::string::npos);
 
-                    sent_count += to_send.size();
+                        sent_count += to_send.size();
 
-                    std::vector<completion_token_output> probs_output = {};
+                        std::vector<completion_token_output> probs_output = {};
 
-                    if (llama.params.n_probs > 0) {
-                        const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
-                        size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
-                        size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
-                        if (probs_pos < probs_stop_pos) {
-                            probs_output = std::vector<completion_token_output>(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos);
+                        if (llama.params.n_probs > 0) {
+                            const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
+                            size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
+                            size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
+                            if (probs_pos < probs_stop_pos) {
+                                probs_output = std::vector<completion_token_output>(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos);
+                            }
+                            sent_token_probs_index = probs_stop_pos;
                         }
-                        sent_token_probs_index = probs_stop_pos;
-                    }
 
-                    {
-                        // Always send partial response
-                        // so we can get the correct partial response of the last to_send in the client
                         const json data = format_partial_response(llama, to_send, probs_output);
 
                         const std::string str =