From b11e63ce43f84ea870daeb18932b1907574ab958 Mon Sep 17 00:00:00 2001
From: Mathijs Henquet
Date: Thu, 22 Aug 2024 00:32:28 +0200
Subject: [PATCH] Handle case where tokenizer splits along UTF-8 continuation
 bytes

---
 examples/server/README.md  |  2 +-
 examples/server/server.cpp | 15 ++++++++++++++-
 examples/server/utils.hpp  | 33 +++++++++++++++++++++++++++++++++
 3 files changed, 48 insertions(+), 2 deletions(-)

diff --git a/examples/server/README.md b/examples/server/README.md
index dba47d94d..82f9a373f 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -508,7 +508,7 @@ Notice that each `probs` is an array of length `n_probs`.
 
 **Response:**
 
-Returns a JSON object with a `tokens` field containing the tokenization result. The `tokens` array contains either just token IDs or objects with `id` and `piece` fields, depending on the `with_pieces` parameter.
+Returns a JSON object with a `tokens` field containing the tokenization result. The `tokens` array contains either just token IDs or objects with `id` and `piece` fields, depending on the `with_pieces` parameter. The `piece` field is a string if the piece is valid UTF-8, or a list of byte values otherwise.
 
 If `with_pieces` is `false`:
 
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 543092409..efb2121e0 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -3198,9 +3198,22 @@ int main(int argc, char ** argv) {
             if (with_pieces) {
                 for (const auto& token : tokens) {
                     std::string piece = llama_token_to_piece(ctx_server.ctx, token);
+                    json piece_json;
+
+                    // Check if the piece is valid UTF-8
+                    if (is_valid_utf8(piece)) {
+                        piece_json = piece;
+                    } else {
+                        // If not valid UTF-8, store as array of byte values
+                        piece_json = json::array();
+                        for (unsigned char c : piece) {
+                            piece_json.push_back(static_cast<int>(c));
+                        }
+                    }
+
                     tokens_response.push_back({
                         {"id", token},
-                        {"piece", piece}
+                        {"piece", piece_json}
                     });
                 }
             } else {
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 42635acca..6f81e4e6b 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -583,6 +583,39 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
     return res;
 }
 
+static bool is_valid_utf8(const std::string & str) {
+    const unsigned char* bytes = reinterpret_cast<const unsigned char*>(str.data());
+    const unsigned char* end = bytes + str.length();
+
+    while (bytes < end) {
+        if (*bytes <= 0x7F) {
+            // 1-byte sequence (0xxxxxxx)
+            bytes++;
+        } else if ((*bytes & 0xE0) == 0xC0) {
+            // 2-byte sequence (110xxxxx 10xxxxxx)
+            if (end - bytes < 2 || (bytes[1] & 0xC0) != 0x80)
+                return false;
+            bytes += 2;
+        } else if ((*bytes & 0xF0) == 0xE0) {
+            // 3-byte sequence (1110xxxx 10xxxxxx 10xxxxxx)
+            if (end - bytes < 3 || (bytes[1] & 0xC0) != 0x80 || (bytes[2] & 0xC0) != 0x80)
+                return false;
+            bytes += 3;
+        } else if ((*bytes & 0xF8) == 0xF0) {
+            // 4-byte sequence (11110xxx 10xxxxxx 10xxxxxx 10xxxxxx)
+            if (end - bytes < 4 || (bytes[1] & 0xC0) != 0x80 ||
+                (bytes[2] & 0xC0) != 0x80 || (bytes[3] & 0xC0) != 0x80)
+                return false;
+            bytes += 4;
+        } else {
+            // Invalid UTF-8 lead byte
+            return false;
+        }
+    }
+
+    return true;
+}
+
 static json format_tokenizer_response(const json & tokens) {
     return json {
         {"tokens", tokens}
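
Note (not part of the patch): a minimal standalone sketch of the case being handled. It copies the same validation logic as the new is_valid_utf8() helper and shows that a piece containing only part of a multi-byte UTF-8 character fails the check, which is what triggers the byte-array fallback in the /tokenize response. The file name, helper name, and example input below are illustrative only.

// sketch_utf8_split.cpp -- illustrative sketch, not part of the patch
#include <cstdio>
#include <string>

// Same validation logic as the is_valid_utf8() helper added in utils.hpp.
static bool is_valid_utf8_sketch(const std::string & str) {
    const unsigned char* bytes = reinterpret_cast<const unsigned char*>(str.data());
    const unsigned char* end   = bytes + str.length();

    while (bytes < end) {
        if (*bytes <= 0x7F) {
            bytes += 1;  // 1-byte sequence (0xxxxxxx)
        } else if ((*bytes & 0xE0) == 0xC0) {
            if (end - bytes < 2 || (bytes[1] & 0xC0) != 0x80) return false;
            bytes += 2;  // 2-byte sequence
        } else if ((*bytes & 0xF0) == 0xE0) {
            if (end - bytes < 3 || (bytes[1] & 0xC0) != 0x80 || (bytes[2] & 0xC0) != 0x80) return false;
            bytes += 3;  // 3-byte sequence
        } else if ((*bytes & 0xF8) == 0xF0) {
            if (end - bytes < 4 || (bytes[1] & 0xC0) != 0x80 ||
                (bytes[2] & 0xC0) != 0x80 || (bytes[3] & 0xC0) != 0x80) return false;
            bytes += 4;  // 4-byte sequence
        } else {
            return false;  // invalid lead byte (e.g. a lone continuation byte)
        }
    }
    return true;
}

int main() {
    // U+1F999 (llama emoji) is a single 4-byte UTF-8 sequence: F0 9F A6 99.
    const std::string whole = "\xF0\x9F\xA6\x99";
    // A tokenizer may emit only the first half of that sequence as one piece.
    const std::string split = whole.substr(0, 2);

    std::printf("whole piece valid: %d\n", is_valid_utf8_sketch(whole)); // 1 -> piece sent as a JSON string
    std::printf("split piece valid: %d\n", is_valid_utf8_sketch(split)); // 0 -> piece sent as [240, 159]
    return 0;
}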