From 1602ca681c43fd369a3072de3f729768b3d78338 Mon Sep 17 00:00:00 2001
From: yuguorui
Date: Sun, 19 Mar 2023 13:37:24 +0800
Subject: [PATCH] Fix tokenization for variable-length characters

The current tokenization logic assumes that each character is one byte
long, so users get garbled output when they enter multi-byte
characters. Processing the input per Unicode code point gives users
basic multilingual support, but tokenization may still fail, because
not every required character is covered by the LLaMA tokenizer's
vocabulary.

Signed-off-by: yuguorui
---
 utils.cpp | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/utils.cpp b/utils.cpp
index efa2e3c35..b3c913860 100644
--- a/utils.cpp
+++ b/utils.cpp
@@ -8,6 +8,8 @@
 #include
 #include
 #include
+#include <codecvt>
+#include <locale>
 
 #if defined(_MSC_VER) || defined(__MINGW32__)
 #include <malloc.h> // using malloc.h with MSC/MINGW
@@ -294,7 +296,10 @@ std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, const std::st
     std::vector<gpt_vocab::id> res;
     std::vector<int> score;
     std::vector<gpt_vocab::id> prev;
-    int len = text.length();
+    std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
+
+    std::wstring wtext = converter.from_bytes(text);
+    int len = wtext.length();
 
     score.resize(len + 1);
     prev.resize(len + 1);
@@ -303,8 +308,8 @@ std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, const std::st
     for (int i = 0; i < len; i++) {
         int max_len = std::min(len - i, MAX_TOKEN_LEN);
         for (int sub_len = 1; sub_len <= max_len; sub_len++) {
-            auto sub = text.substr(i, sub_len);
-            auto token = vocab.token_to_id.find(sub);
+            auto sub = wtext.substr(i, sub_len);
+            auto token = vocab.token_to_id.find(converter.to_bytes(sub));
             if (token != vocab.token_to_id.end()) {
                 int token_score = sub.length() * sub.length();
                 int local_score = score[i] + token_score;
@@ -322,12 +327,10 @@ std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, const std::st
     while (i > 0) {
         gpt_vocab::id token_id = prev[i];
         if (token_id == 0) {
-            // TODO: Return error or something more meaningful
-            printf("failed to tokenize string!\n");
-            break;
+            throw std::range_error("failed to tokenize string");
        }
         res.push_back(token_id);
-        auto token = (*vocab.id_to_token.find(token_id)).second;
+        auto token = converter.from_bytes((*vocab.id_to_token.find(token_id)).second);
         i -= token.length();
     }
 
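
Note (not part of the patch): the sketch below is a minimal standalone
C++11 program illustrating why the substring search has to operate on
code points rather than bytes, and how the std::wstring_convert round
trip used above behaves. std::codecvt_utf8 and std::wstring_convert are
deprecated since C++17 but still available.

    #include <codecvt>
    #include <iostream>
    #include <locale>
    #include <string>

    int main() {
        // "你好" is two code points but six bytes in UTF-8.
        std::string utf8 = "\xe4\xbd\xa0\xe5\xa5\xbd";

        // Byte-wise indexing splits the first character in the middle,
        // which is what currently garbles multi-byte input.
        std::cout << utf8.substr(0, 1).size() << " byte, not a valid character\n";

        // Converting to a wide string makes each element one code point, so
        // substr()/length() operate on whole characters, and to_bytes()
        // recovers the exact UTF-8 needed for the vocabulary lookup.
        std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
        std::wstring wide = converter.from_bytes(utf8);
        std::cout << wide.length() << " code points\n";                          // 2
        std::cout << converter.to_bytes(wide.substr(0, 1)).size() << " bytes\n"; // 3
        return 0;
    }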