From cf7619522335100284f35b94be5104dded20b407 Mon Sep 17 00:00:00 2001 From: "Wang Haoran(Robin)" Date: Thu, 22 Jun 2023 21:35:37 -0700 Subject: [PATCH] server: fix issue when handling probability output for incomplete tokens for multibyte character generation --- examples/server/server.cpp | 135 ++++++++++++++++++++----------------- 1 file changed, 73 insertions(+), 62 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 4cad658d3..d9af1308e 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -26,17 +26,6 @@ struct server_params { int32_t write_timeout = 600; }; -// completion string output with probabilities -struct completion_string_output { - struct token_prob { - std::string tok_str; - float prob; - }; - - std::vector probs; - std::string tok_str; -}; - // completion token output with probabilities struct completion_token_output { struct token_prob { @@ -108,6 +97,36 @@ static void server_log(const char * level, const char * function, int line, fflush(stdout); } +// format incomplete utf-8 multibyte character for output +static std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) { + const std::string out = token == -1 ? "" : llama_token_to_str(ctx, token); + if (out[0] > 127) { + out = "byte: \\x" + std::format("{:x}", out[0]); + } + return out; +} + +// convert a vector of completion_token_output to json +static json probs_vector_to_json(const llama_context * ctx, const vector probs) { + json out = json::array(); + for (const auto & prob : probs) { + json probs_for_token = json::array(); + for (const auto & p : prob.probs) { + std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok); + probs_for_token.push_back(json { + { "tok_str", tok_str }, + { "prob", p.prob }, + }); + } + std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok); + out.push_back(json { + {"content", tok_str}, + {"probs", probs_for_token}, + }); + } + return out; +} + static bool server_verbose = false; #if SERVER_VERBOSE != 1 @@ -129,7 +148,7 @@ struct llama_server_context { bool stream = false; bool has_next_token = false; std::string generated_text; - std::vector generated_text_probs; + std::vector generated_token_probs; size_t num_tokens_predicted = 0; size_t n_past = 0; @@ -160,7 +179,7 @@ struct llama_server_context { num_tokens_predicted = 0; generated_text = ""; generated_text.reserve(params.n_ctx); - generated_text_probs.clear(); + generated_token_probs.clear(); truncated = false; stopped_eos = false; stopped_word = false; @@ -406,22 +425,16 @@ struct llama_server_context { return stop_pos; } - completion_string_output doCompletion() { + completion_token_output doCompletion() { const completion_token_output token_with_probs = nextToken(); - completion_string_output result; const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(ctx, token_with_probs.tok); - result.tok_str = token_text; generated_text += token_text; - // iterate through token_with_probs.probs, if tok is valid, convert it to string and add to result.prob - for (const auto & prob : token_with_probs.probs) { - const std::string prob_text = prob.tok == -1 ? "" : llama_token_to_str(ctx, prob.tok); - result.probs.push_back({prob_text, prob.prob}); + if (params.n_probs > 0) { + generated_token_probs.push_back(token_with_probs); } - generated_text_probs.push_back(result); - if (multibyte_pending > 0) { multibyte_pending -= token_text.size(); } else if (token_text.size() == 1) { @@ -451,7 +464,7 @@ struct llama_server_context { LOG_VERBOSE("next token", { { "token", token_with_probs.tok }, - { "token_text", llama_token_to_str(ctx, token_with_probs.tok) }, + { "token_text", tokens_to_output_formatted_string(ctx, token_with_probs.tok) }, { "has_next_token", has_next_token }, { "n_remain", n_remain }, { "num_tokens_predicted", num_tokens_predicted }, @@ -461,7 +474,7 @@ struct llama_server_context { { "stopping_word", stopping_word }, }); - return result; + return token_with_probs; } std::vector getEmbedding() { @@ -713,26 +726,10 @@ static json format_embedding_response(llama_server_context & llama) { }; } -static json format_final_response(llama_server_context & llama, const std::string & content, const std::vector & probs) { +static json format_final_response(llama_server_context & llama, const std::string & content, const std::vector & probs) { - json completion_probabilities_json = json::array(); - for (const auto & prob : probs) { - json probs_for_token = json::array(); - for (const auto & p : prob.probs) { - probs_for_token.push_back(json { - { "tok_str", p.tok_str }, - { "prob", p.prob }, - }); - } - completion_probabilities_json.push_back(json { - {"content", prob.tok_str}, - {"probs", probs_for_token}, - }); - } - - return json { + json res = json { { "content", content }, - { "completion_probabilities", completion_probabilities_json}, { "stop", true }, { "model", llama.params.model_alias }, { "tokens_predicted", llama.num_tokens_predicted }, @@ -743,25 +740,25 @@ static json format_final_response(llama_server_context & llama, const std::strin { "stopped_word", llama.stopped_word }, { "stopped_limit", llama.stopped_limit }, { "stopping_word", llama.stopping_word }, - }; + } + + if (llama.params.n_probs > 0) { + json completion_probabilities_json = probs_vector_to_json(llama.ctx, probs); + res["completion_probabilities"] = completion_probabilities_json; + } + + return res; } -static json format_partial_response(const std::string & content, const completion_string_output & probs) { +static json format_partial_response(llama_server_context & llama, const std::string & content, const std::vector & probs) { json res = json { { "content", content }, { "stop", false }, }; - // iterate through probs.probs, and add to res - json probs_json = json::array(); - for (const auto & prob : probs.probs) { - probs_json.push_back(json { - { "tok_str", prob.tok_str }, - { "prob", prob.prob }, - }); - } - if (probs.probs.size() > 0) { - res["probs"] = probs_json; + if (llama.params.n_probs > 0) { + json completion_probabilities_json = probs_vector_to_json(llama.ctx, probs); + res["completion_probabilities"] = completion_probabilities_json; } return res; @@ -897,8 +894,8 @@ int main(int argc, char ** argv) { size_t stop_pos = std::string::npos; while (llama.has_next_token) { - const completion_string_output token_text_with_probs = llama.doCompletion(); - const std::string token_text = token_text_with_probs.tok_str; + const completion_token_output token_with_probs = llama.doCompletion(); + const std::string token_text = llama_token_to_str(llama.ctx, token_with_probs.tok); stop_pos = llama.findStoppingStrings(llama.generated_text, token_text.size(), STOP_FULL); @@ -912,7 +909,7 @@ int main(int argc, char ** argv) { llama.generated_text.end()); } - const json data = format_final_response(llama, llama.generated_text, llama.generated_text_probs); + const json data = format_final_response(llama, llama.generated_text, llama.generated_token_probs); llama_print_timings(llama.ctx); @@ -921,9 +918,11 @@ int main(int argc, char ** argv) { } else { const auto chunked_content_provider = [&](size_t, DataSink & sink) { size_t sent_count = 0; + size_t sent_token_probs_index = 0; while (llama.has_next_token) { - const completion_string_output token_text_with_probs = llama.doCompletion(); + const completion_token_output token_with_probs = llama.doCompletion(); + const std::string token_text = llama_token_to_str(llama.ctx, token_with_probs.tok); if (llama.multibyte_pending > 0) { continue; } @@ -932,24 +931,36 @@ int main(int argc, char ** argv) { const std::string str_test = llama.generated_text.substr(pos); size_t stop_pos = - llama.findStoppingStrings(str_test, token_text_with_probs.tok_str.size(), STOP_FULL); + llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL); if (stop_pos != std::string::npos) { llama.generated_text.erase( llama.generated_text.begin() + pos + stop_pos, llama.generated_text.end()); pos = std::min(sent_count, llama.generated_text.size()); } else { - stop_pos = llama.findStoppingStrings(str_test, token_text_with_probs.tok_str.size(), + stop_pos = llama.findStoppingStrings(str_test, token_text.size(), STOP_PARTIAL); } const std::string to_send = llama.generated_text.substr(pos, stop_pos); sent_count += to_send.size(); + std::vector probs_output = {}; + + if (llama.params.n_probs > 0) { + const std::vector to_send_toks = llama_tokenize(llama.ctx, to_send, false); + size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size()); + size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size()); + if (probs_pos < probs_stop_pos) { + probs_output = std::vector(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos); + } + sent_token_probs_index = probs_stop_pos; + } + const json data = llama.has_next_token - ? format_partial_response(to_send, token_text_with_probs) + ? format_partial_response(llama, to_send, probs_output) // Generation is done, send extra information. - : format_final_response(llama, to_send, {token_text_with_probs}); + : format_final_response(llama, to_send, probs_output); const std::string str = "data: " +