llama : prefix input text for tokenization with whitespace

This commit is contained in:
Georgi Gerganov 2023-08-26 17:08:59 +03:00
parent 5cad62bce4
commit 5d0ffb69f5
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
5 changed files with 96 additions and 64 deletions

View file

@@ -56,9 +56,6 @@ int main(int argc, char ** argv) {
int n_past = 0; int n_past = 0;
// Add a space in front of the first character to match OG llama tokenizer behavior
params.prompt.insert(0, 1, ' ');
// tokenize the prompt // tokenize the prompt
auto embd_inp = ::llama_tokenize(ctx, params.prompt, true); auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);

View file

@@ -195,11 +195,6 @@ int main(int argc, char ** argv) {
// tokenize the prompt // tokenize the prompt
std::vector<llama_token> embd_inp; std::vector<llama_token> embd_inp;
if (llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM) {
// Add a space in front of the first character to match OG llama tokenizer behavior
params.prompt.insert(0, 1, ' ');
}
if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) { if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos); embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
} else { } else {

View file

@@ -1635,7 +1635,7 @@ static void llm_load_hparams(
} }
// TODO: This should probably be in llama.h // TODO: This should probably be in llama.h
static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, const std::string & raw_text, bool bos); static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos);
static void llm_load_vocab( static void llm_load_vocab(
llama_model_loader & ml, llama_model_loader & ml,
@@ -3026,10 +3026,8 @@ static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) {
return vocab.token_to_id.at(buf); return vocab.token_to_id.at(buf);
} }
static std::string llama_escape_whitespace(const std::string& text) { static void llama_escape_whitespace(std::string & text) {
std::string result = text; replace_all(text, " ", "\xe2\x96\x81");
replace_all(result, " ", "\xe2\x96\x81");
return result;
} }
static void llama_unescape_whitespace(std::string & word) { static void llama_unescape_whitespace(std::string & word) {
@@ -3373,22 +3371,25 @@ private:
llm_bigram_bpe::queue work_queue; llm_bigram_bpe::queue work_queue;
}; };
static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, const std::string & raw_text, bool bos) { static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos) {
std::vector<llama_vocab::id> output; std::vector<llama_vocab::id> output;
if (raw_text.empty()) {
return output;
}
if (bos && vocab.special_bos_id != -1) { if (bos && vocab.special_bos_id != -1) {
output.push_back(vocab.special_bos_id); output.push_back(vocab.special_bos_id);
} }
if (raw_text.empty()) {
return output;
}
raw_text = " " + raw_text;
switch (vocab.type) { switch (vocab.type) {
case LLAMA_VOCAB_TYPE_SPM: case LLAMA_VOCAB_TYPE_SPM:
{ {
llm_tokenizer_spm tokenizer(vocab); llm_tokenizer_spm tokenizer(vocab);
tokenizer.tokenize(llama_escape_whitespace(raw_text), output); llama_escape_whitespace(raw_text);
tokenizer.tokenize(raw_text, output);
} break; } break;
case LLAMA_VOCAB_TYPE_BPE: case LLAMA_VOCAB_TYPE_BPE:
{ {

View file

@@ -6,7 +6,7 @@
#include <map> #include <map>
#include <vector> #include <vector>
static std::string unescape_whitespace(llama_context* ctx, const std::vector<llama_token>& tokens) { static std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens) {
std::string result; std::string result;
for (size_t i = 0; i < tokens.size(); ++i) { for (size_t i = 0; i < tokens.size(); ++i) {
result += llama_token_to_str(ctx, tokens[i]); result += llama_token_to_str(ctx, tokens[i]);
@@ -16,38 +16,40 @@ static std::string unescape_whitespace(llama_context* ctx, const std::vector<lla
static const std::map<std::string, std::vector<llama_token>> & k_tests() { static const std::map<std::string, std::vector<llama_token>> & k_tests() {
static std::map<std::string, std::vector<llama_token>> _k_tests = { static std::map<std::string, std::vector<llama_token>> _k_tests = {
{ " ", { 1, 259, }, }, { "" , { }, },
{ " ", { 1, 1678, }, }, { " ", { 259, }, },
{ " ", { 1, 268, }, }, { " ", { 1678, }, },
{ "\t", { 1, 29871, 12, }, }, { " ", { 268, }, },
{ "\n", { 1, 29871, 13, }, }, { "\t", { 29871, 12, }, },
{ "\t\n", { 1, 29871, 12, 13, }, }, { "\n", { 29871, 13, }, },
{ "Hello world", { 1, 15043, 3186, }, }, { "\t\n", { 29871, 12, 13, }, },
{ " Hello world", { 1, 29871, 15043, 3186, }, }, { "Hello world", { 15043, 3186, }, },
{ "Hello World", { 1, 15043, 2787, }, }, { " Hello world", { 29871, 15043, 3186, }, },
{ " Hello World", { 1, 29871, 15043, 2787, }, }, { "Hello World", { 15043, 2787, }, },
{ " Hello World!", { 1, 29871, 15043, 2787, 29991, }, }, { " Hello World", { 29871, 15043, 2787, }, },
{ "Hello, world!", { 1, 15043, 29892, 3186, 29991, }, }, { " Hello World!", { 29871, 15043, 2787, 29991, }, },
{ " Hello, world!", { 1, 29871, 15043, 29892, 3186, 29991, }, }, { "Hello, world!", { 15043, 29892, 3186, 29991, }, },
{ " this is 🦙.cpp", { 1, 29871, 445, 338, 29871, 243, 162, 169, 156, 29889, 8223, }, }, { " Hello, world!", { 29871, 15043, 29892, 3186, 29991, }, },
{ "w048 7tuijk dsdfhu", { 1, 281, 29900, 29946, 29947, 29871, 29955, 9161, 13535, 18031, 2176, 6905, }, }, { " this is 🦙.cpp", { 29871, 445, 338, 29871, 243, 162, 169, 156, 29889, 8223, }, },
{ "нещо на Български", { 1, 1538, 4851, 665, 1386, 29713, 1305, }, }, { "w048 7tuijk dsdfhu", { 281, 29900, 29946, 29947, 29871, 29955, 9161, 13535, 18031, 2176, 6905, }, },
{ "កាន់តែពិសេសអាចខលចេញ", { 1, 29871, 31849, 31324, 31934, 228, 162, 142, 228, 161, { "нещо на Български", { 1538, 4851, 665, 1386, 29713, 1305, }, },
{ "កាន់តែពិសេសអាចខលចេញ",
{ 29871, 31849, 31324, 31934, 228, 162, 142, 228, 161,
146, 228, 162, 133, 228, 161, 153, 228, 161, 186, 146, 228, 162, 133, 228, 161, 153, 228, 161, 186,
31708, 228, 162, 132, 31708, 228, 161, 165, 31324, 228, 31708, 228, 162, 132, 31708, 228, 161, 165, 31324, 228,
161, 136, 228, 161, 132, 228, 161, 158, 228, 161, 161, 136, 228, 161, 132, 228, 161, 158, 228, 161,
136, 228, 162, 132, 228, 161, 140, }, }, 136, 228, 162, 132, 228, 161, 140, }, },
{ "🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)", { "🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)",
{ 1, 29871, 243, 162, 157, 131, 313, 8945, 29897, 29871, { 29871, 243, 162, 157, 131, 313, 8945, 29897, 29871,
243, 162, 155, 185, 30722, 243, 162, 143, 174, 30598, 243, 162, 155, 185, 30722, 243, 162, 143, 174, 30598,
313, 20787, 953, 3848, 275, 16125, 630, 29897, 29871, 31681, 313, 20787, 953, 3848, 275, 16125, 630, 29897, 29871, 31681,
313, 6194, 953, 29877, 2397, 393, 756, 967, 1914, 5993, 29897, }, }, 313, 6194, 953, 29877, 2397, 393, 756, 967, 1914, 5993, 29897, }, },
{ "Hello", { 1, 15043, }, }, { "Hello", { 15043, }, },
{ " Hello", { 1, 29871, 15043, }, }, { " Hello", { 29871, 15043, }, },
{ " Hello", { 1, 259, 15043, }, }, { " Hello", { 259, 15043, }, },
{ " Hello", { 1, 1678, 15043, }, }, { " Hello", { 1678, 15043, }, },
{ " Hello", { 1, 268, 15043, }, }, { " Hello", { 268, 15043, }, },
{ " Hello\n Hello", { 1, 268, 15043, 13, 1678, 15043, }, }, { " Hello\n Hello", { 268, 15043, 13, 1678, 15043, }, },
}; };
return _k_tests; return _k_tests;
@@ -102,15 +104,18 @@ int main(int argc, char **argv) {
bool success = true; bool success = true;
for (const auto & test_kv : k_tests()) { for (const auto & test_kv : k_tests()) {
// Add a space in front of the first character to match OG llama tokenizer behavior const std::vector<llama_token> res_bos = llama_tokenize(ctx, test_kv.first, true);
std::vector<llama_token> res = llama_tokenize(ctx, " " + test_kv.first, true); const std::vector<llama_token> res_nobos = llama_tokenize(ctx, test_kv.first, false);
fprintf(stderr, "%s : '%s' tokenized to '%s'\n",
__func__, test_kv.first.c_str(), unescape_whitespace(ctx, res).c_str());
bool correct = res.size() == test_kv.second.size(); fprintf(stderr, "%s : '%s' tokenized to '%s'\n", __func__, test_kv.first.c_str(), llama_detokenize(ctx, res_bos).c_str());
for (int i = 0; i < (int) res.size() && correct; ++i) { bool correct = res_nobos.size() == test_kv.second.size() && res_bos.size() == res_nobos.size() + 1 && res_bos[0] == 1;
if (res[i] != test_kv.second[i]) {
for (int i = 0; i < (int) res_nobos.size() && correct; ++i) {
if (test_kv.second[i] != res_bos[i + 1]) {
correct = false;
}
if (test_kv.second[i] != res_nobos[i]) {
correct = false; correct = false;
} }
} }
@@ -118,14 +123,15 @@ int main(int argc, char **argv) {
if (!correct) { if (!correct) {
fprintf(stderr, "%s : failed test: '%s'\n", __func__, test_kv.first.c_str()); fprintf(stderr, "%s : failed test: '%s'\n", __func__, test_kv.first.c_str());
fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__, fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__,
unescape_whitespace(ctx, res).c_str(), unescape_whitespace(ctx, test_kv.second).c_str()); llama_detokenize(ctx, res_nobos).c_str(),
llama_detokenize(ctx, test_kv.second).c_str());
fprintf(stderr, "%s : expected tokens: ", __func__); fprintf(stderr, "%s : expected tokens: ", __func__);
for (const auto & t : test_kv.second) { for (const auto & t : test_kv.second) {
fprintf(stderr, "%6d, ", t); fprintf(stderr, "%6d, ", t);
} }
fprintf(stderr, "\n"); fprintf(stderr, "\n");
fprintf(stderr, "%s : got tokens: ", __func__); fprintf(stderr, "%s : got tokens: ", __func__);
for (const auto & t : res) { for (const auto & t : res_nobos) {
fprintf(stderr, "%6d, ", t); fprintf(stderr, "%6d, ", t);
} }
fprintf(stderr, "\n"); fprintf(stderr, "\n");

View file

@@ -12,7 +12,40 @@ dir_tokenizer = args.dir_tokenizer
tokenizer = SentencePieceProcessor(dir_tokenizer + '/tokenizer.model') tokenizer = SentencePieceProcessor(dir_tokenizer + '/tokenizer.model')
text = 'Hello, world!' tests = [
print(text) ""
" ",
" ",
" ",
"\t",
"\n",
"\t\n",
"Hello world",
" Hello world",
"Hello World",
" Hello World",
" Hello World!",
"Hello, world!",
" Hello, world!",
" this is 🦙.cpp",
"w048 7tuijk dsdfhu",
"нещо на Български",
"កាន់តែពិសេសអាចខលចេញ",
"🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)",
"Hello",
" Hello",
" Hello",
" Hello",
" Hello",
" Hello\n Hello",
]
for text in tests:
print('text: ', text)
print('\nwith bos:')
print(tokenizer.encode(text, add_bos=True)) print(tokenizer.encode(text, add_bos=True))
print(tokenizer.decode(tokenizer.encode(text, add_bos=True))) print(tokenizer.decode(tokenizer.encode(text, add_bos=True)))
print('\nwithout bos:')
print(tokenizer.encode(text, add_bos=False))
print(tokenizer.decode(tokenizer.encode(text, add_bos=False)))