Fix tokenization for variable-length characters

The current tokenization logic assumes that each character is one byte long,
so users get garbled output when they enter multi-byte (e.g. UTF-8-encoded) characters.

Decoding the input to Unicode before tokenizing gives users basic multilingual
support, but tokenization may still fail, because not all required Unicode
characters are included in the LLAMA tokenizer's vocabulary.

Signed-off-by: yuguorui <yuguorui@pku.edu.cn>
This commit is contained in:
yuguorui 2023-03-19 13:37:24 +08:00
parent d7def1a752
commit 1602ca681c

View file

@ -8,6 +8,8 @@
#include <iterator>
#include <string>
#include <math.h>
#include <locale>
#include <codecvt>
#if defined(_MSC_VER) || defined(__MINGW32__)
#include <malloc.h> // using malloc.h with MSC/MINGW
@ -294,7 +296,10 @@ std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, const std::st
std::vector<gpt_vocab::id> res;
std::vector<int> score;
std::vector<gpt_vocab::id> prev;
int len = text.length();
std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
std::wstring wtext = converter.from_bytes(text);
int len = wtext.length();
score.resize(len + 1);
prev.resize(len + 1);
@ -303,8 +308,8 @@ std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, const std::st
for (int i = 0; i < len; i++) {
int max_len = std::min(len - i, MAX_TOKEN_LEN);
for (int sub_len = 1; sub_len <= max_len; sub_len++) {
auto sub = text.substr(i, sub_len);
auto token = vocab.token_to_id.find(sub);
auto sub = wtext.substr(i, sub_len);
auto token = vocab.token_to_id.find(converter.to_bytes(sub));
if (token != vocab.token_to_id.end()) {
int token_score = sub.length() * sub.length();
int local_score = score[i] + token_score;
@ -322,12 +327,10 @@ std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, const std::st
while (i > 0) {
gpt_vocab::id token_id = prev[i];
if (token_id == 0) {
// TODO: Return error or something more meaningful
printf("failed to tokenize string!\n");
break;
throw std::range_error("failed to tokenize string");
}
res.push_back(token_id);
auto token = (*vocab.id_to_token.find(token_id)).second;
auto token = converter.from_bytes((*vocab.id_to_token.find(token_id)).second);
i -= token.length();
}