From af5a6ceb12786d9733567f54d3f1383be649e0b5 Mon Sep 17 00:00:00 2001
From: qhduan
Date: Sat, 18 Nov 2023 10:30:54 +0800
Subject: [PATCH] remove multibyte_pending from server

---
 examples/server/server.cpp | 75 +++++++++++++++++---------------------
 1 file changed, 33 insertions(+), 42 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 86f8b6b77..445b5f094 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -362,7 +362,6 @@ struct llama_client_slot
 
     int32_t num_prompt_tokens = 0;
     int32_t num_prompt_tokens_processed = 0;
-    int32_t multibyte_pending = 0;
 
     json prompt;
     std::string generated_text;
@@ -405,7 +404,6 @@ struct llama_client_slot
         stopped_word = false;
         stopped_limit = false;
         stopping_word = "";
-        multibyte_pending = 0;
         n_past = 0;
         sent_count = 0;
         sent_token_probs_index = 0;
@@ -949,6 +947,37 @@ struct llama_server_context
         return stop_pos;
     }
 
+    bool is_valid_utf8(const std::string& str) {
+        int bytesToProcess = 0;
+
+        for (unsigned char c : str) {
+            if (bytesToProcess == 0) {
+                if ((c >> 7) == 0x0) {
+                    // 1-byte character
+                    continue;
+                } else if ((c >> 5) == 0x6) {
+                    // 2-byte character: 110xxxxx 10xxxxxx
+                    bytesToProcess = 1;
+                } else if ((c >> 4) == 0xE) {
+                    // 3-byte character: 1110xxxx 10xxxxxx 10xxxxxx
+                    bytesToProcess = 2;
+                } else if ((c >> 3) == 0x1E) {
+                    // 4-byte character: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
+                    bytesToProcess = 3;
+                } else {
+                    return false; // Invalid first byte of a character
+                }
+            } else {
+                if ((c >> 6) != 0x2) { // check this: 10xxxxxx
+                    return false; // Invalid subsequent byte
+                }
+                --bytesToProcess;
+            }
+        }
+
+        return bytesToProcess == 0; // True if all characters were processed completely
+    }
+
     bool process_token(completion_token_output &result, llama_client_slot &slot) {
         // remember which tokens were sampled - used for repetition penalties during sampling
         const std::string token_str = llama_token_to_piece(ctx, result.tok);
@@ -958,45 +987,7 @@ struct llama_server_context
         slot.generated_text += token_str;
         slot.has_next_token = true;
 
-        if (slot.multibyte_pending > 0)
-        {
-            slot.multibyte_pending -= token_str.size();
-        }
-        else if (token_str.size() == 1)
-        {
-            const char c = token_str[0];
-            // 2-byte characters: 110xxxxx 10xxxxxx
-            if ((c & 0xE0) == 0xC0)
-            {
-                slot.multibyte_pending = 1;
-                // 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx
-            }
-            else if ((c & 0xF0) == 0xE0)
-            {
-                slot.multibyte_pending = 2;
-                // 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
-            }
-            else if ((c & 0xF8) == 0xF0)
-            {
-                slot.multibyte_pending = 3;
-            }
-            else
-            {
-                slot.multibyte_pending = 0;
-            }
-        }
-        else if (token_str.size() == 2)
-        {
-            const char c0 = token_str[0];
-            const char c1 = token_str[1];
-            if (((c0 & 0xF0) == 0xE0) && ((c1 & 0xC0) == 0x80))
-            {
-                slot.multibyte_pending = 1;
-                // 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx
-            }
-        }
-
-        if (slot.multibyte_pending == 0)
+        if (is_valid_utf8(token_str))
         {
             size_t pos = std::min(slot.sent_count, slot.generated_text.size());
             const std::string str_test = slot.generated_text.substr(pos);
@@ -1031,7 +1022,7 @@ struct llama_server_context
             }
         }
 
-        if (slot.multibyte_pending > 0 && !slot.has_next_token)
+        if (!is_valid_utf8(token_str) && !slot.has_next_token)
         {
             slot.has_next_token = true;
        }
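
For reference, a minimal standalone sketch of the validator this patch introduces: the per-slot multibyte_pending counter is replaced by a stateless is_valid_utf8() check on the current token, so a token that ends in the middle of a multibyte character is held back and streamed only once a later token completes it. The sketch below copies that logic into a free function with a tiny test driver; the file name, test strings, and build command are illustrative assumptions, not anything the patch ships.

// utf8_check_demo.cpp -- standalone sketch, not part of the patch.
// Build with e.g.: g++ -std=c++11 utf8_check_demo.cpp -o utf8_check_demo
#include <cassert>
#include <iostream>
#include <string>

// Same validation logic as the member function added to llama_server_context.
static bool is_valid_utf8(const std::string & str) {
    int bytesToProcess = 0;

    for (unsigned char c : str) {
        if (bytesToProcess == 0) {
            if ((c >> 7) == 0x0) {
                continue;               // 1-byte (ASCII) character
            } else if ((c >> 5) == 0x6) {
                bytesToProcess = 1;     // 110xxxxx: expect 1 continuation byte
            } else if ((c >> 4) == 0xE) {
                bytesToProcess = 2;     // 1110xxxx: expect 2 continuation bytes
            } else if ((c >> 3) == 0x1E) {
                bytesToProcess = 3;     // 11110xxx: expect 3 continuation bytes
            } else {
                return false;           // invalid leading byte
            }
        } else {
            if ((c >> 6) != 0x2) {
                return false;           // continuation byte must be 10xxxxxx
            }
            --bytesToProcess;
        }
    }

    return bytesToProcess == 0;         // false if a sequence is still incomplete
}

int main() {
    // U+4F60 is encoded as the three bytes E4 BD A0.
    assert(is_valid_utf8("hello"));              // pure ASCII
    assert(is_valid_utf8("\xE4\xBD\xA0"));       // complete 3-byte character
    assert(!is_valid_utf8("\xE4\xBD"));          // truncated: the server keeps buffering
    assert(!is_valid_utf8("\xE4\xBD\xA0\xE4"));  // ends mid-character: still buffering
    std::cout << "all checks passed" << std::endl;
    return 0;
}

Re-scanning the whole token string with is_valid_utf8() trades the old cross-token pending-byte counter for a simple stateless check, which is why both call sites in process_token() can use the same predicate.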