From 5d0ffb69f5c6b8aa10ee2bb88c6a601a46df33d3 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sat, 26 Aug 2023 17:08:59 +0300
Subject: [PATCH] llama : prefix input text for tokenization with whitespace

---
 examples/embedding/embedding.cpp |  3 --
 examples/main/main.cpp           |  5 --
 llama.cpp                        | 23 +++++----
 tests/test-tokenizer-0.cpp       | 88 +++++++++++++++++---------
 tests/test-tokenizer-0.py        | 41 +++++++++++++--
 5 files changed, 96 insertions(+), 64 deletions(-)

diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index 38395c75b..abe5c8781 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -56,9 +56,6 @@ int main(int argc, char ** argv) {
 
     int n_past = 0;
 
-    // Add a space in front of the first character to match OG llama tokenizer behavior
-    params.prompt.insert(0, 1, ' ');
-
     // tokenize the prompt
     auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);
 
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 4665b82fe..0d3783d67 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -195,11 +195,6 @@ int main(int argc, char ** argv) {
     // tokenize the prompt
     std::vector<llama_token> embd_inp;
 
-    if (llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM) {
-        // Add a space in front of the first character to match OG llama tokenizer behavior
-        params.prompt.insert(0, 1, ' ');
-    }
-
     if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
         embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
     } else {
diff --git a/llama.cpp b/llama.cpp
index b0a3b5768..0453dd9cf 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1635,7 +1635,7 @@ static void llm_load_hparams(
 }
 
 // TODO: This should probably be in llama.h
-static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, const std::string & raw_text, bool bos);
+static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos);
 
 static void llm_load_vocab(
         llama_model_loader & ml,
@@ -3026,10 +3026,8 @@ static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) {
     return vocab.token_to_id.at(buf);
 }
 
-static std::string llama_escape_whitespace(const std::string& text) {
-    std::string result = text;
-    replace_all(result, " ", "\xe2\x96\x81");
-    return result;
+static void llama_escape_whitespace(std::string & text) {
+    replace_all(text, " ", "\xe2\x96\x81");
 }
 
 static void llama_unescape_whitespace(std::string & word) {
@@ -3373,22 +3371,25 @@ private:
     llm_bigram_bpe::queue work_queue;
 };
 
-static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, const std::string & raw_text, bool bos) {
+static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos) {
     std::vector<llama_vocab::id> output;
 
-    if (raw_text.empty()) {
-        return output;
-    }
-
     if (bos && vocab.special_bos_id != -1) {
        output.push_back(vocab.special_bos_id);
     }
 
+    if (raw_text.empty()) {
+        return output;
+    }
+
+    raw_text = " " + raw_text;
+
     switch (vocab.type) {
         case LLAMA_VOCAB_TYPE_SPM:
             {
                 llm_tokenizer_spm tokenizer(vocab);
-                tokenizer.tokenize(llama_escape_whitespace(raw_text), output);
+                llama_escape_whitespace(raw_text);
+                tokenizer.tokenize(raw_text, output);
             } break;
         case LLAMA_VOCAB_TYPE_BPE:
             {
diff --git a/tests/test-tokenizer-0.cpp b/tests/test-tokenizer-0.cpp
index 4bed054d6..8a86ad4e6 100644
--- a/tests/test-tokenizer-0.cpp
+++ b/tests/test-tokenizer-0.cpp
@@ -6,7 +6,7 @@
 #include <map>
 #include <vector>
 
-static std::string unescape_whitespace(llama_context* ctx, const std::vector<llama_token>& tokens) {
+static std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens) {
     std::string result;
     for (size_t i = 0; i < tokens.size(); ++i) {
         result += llama_token_to_str(ctx, tokens[i]);
@@ -16,38 +16,40 @@ static std::string unescape_whitespace(llama_context* ctx, const std::vector<llama_token>& tokens) {
 
 static const std::map<std::string, std::vector<llama_token>> & k_tests() {
     static std::map<std::string, std::vector<llama_token>> _k_tests = {
-        { " ", { 1, 259, }, },
-        { "  ", { 1, 1678, }, },
-        { "   ", { 1, 268, }, },
-        { "\t", { 1, 29871, 12, }, },
-        { "\n", { 1, 29871, 13, }, },
-        { "\t\n", { 1, 29871, 12, 13, }, },
-        { "Hello world", { 1, 15043, 3186, }, },
-        { " Hello world", { 1, 29871, 15043, 3186, }, },
-        { "Hello World", { 1, 15043, 2787, }, },
-        { " Hello World", { 1, 29871, 15043, 2787, }, },
-        { " Hello World!", { 1, 29871, 15043, 2787, 29991, }, },
-        { "Hello, world!", { 1, 15043, 29892, 3186, 29991, }, },
-        { " Hello, world!", { 1, 29871, 15043, 29892, 3186, 29991, }, },
-        { " this is πŸ¦™.cpp", { 1, 29871, 445, 338, 29871, 243, 162, 169, 156, 29889, 8223, }, },
-        { "w048 7tuijk dsdfhu", { 1, 281, 29900, 29946, 29947, 29871, 29955, 9161, 13535, 18031, 2176, 6905, }, },
-        { "Π½Π΅Ρ‰ΠΎ Π½Π° Π‘ΡŠΠ»Π³Π°Ρ€ΡΠΊΠΈ", { 1, 1538, 4851, 665, 1386, 29713, 1305, }, },
-        { "αž€αžΆαž“αŸ‹αžαŸ‚αž–αž·αžŸαŸαžŸαž’αžΆαž…αžαž›αž…αŸαž‰", { 1, 29871, 31849, 31324, 31934, 228, 162, 142, 228, 161,
-            146, 228, 162, 133, 228, 161, 153, 228, 161, 186,
-            31708, 228, 162, 132, 31708, 228, 161, 165, 31324, 228,
-            161, 136, 228, 161, 132, 228, 161, 158, 228, 161,
-            136, 228, 162, 132, 228, 161, 140, }, },
+        { "", { }, },
+        { " ", { 259, }, },
+        { "  ", { 1678, }, },
+        { "   ", { 268, }, },
+        { "\t", { 29871, 12, }, },
+        { "\n", { 29871, 13, }, },
+        { "\t\n", { 29871, 12, 13, }, },
+        { "Hello world", { 15043, 3186, }, },
+        { " Hello world", { 29871, 15043, 3186, }, },
+        { "Hello World", { 15043, 2787, }, },
+        { " Hello World", { 29871, 15043, 2787, }, },
+        { " Hello World!", { 29871, 15043, 2787, 29991, }, },
+        { "Hello, world!", { 15043, 29892, 3186, 29991, }, },
+        { " Hello, world!", { 29871, 15043, 29892, 3186, 29991, }, },
+        { " this is πŸ¦™.cpp", { 29871, 445, 338, 29871, 243, 162, 169, 156, 29889, 8223, }, },
+        { "w048 7tuijk dsdfhu", { 281, 29900, 29946, 29947, 29871, 29955, 9161, 13535, 18031, 2176, 6905, }, },
+        { "Π½Π΅Ρ‰ΠΎ Π½Π° Π‘ΡŠΠ»Π³Π°Ρ€ΡΠΊΠΈ", { 1538, 4851, 665, 1386, 29713, 1305, }, },
+        { "αž€αžΆαž“αŸ‹αžαŸ‚αž–αž·αžŸαŸαžŸαž’αžΆαž…αžαž›αž…αŸαž‰",
+          { 29871, 31849, 31324, 31934, 228, 162, 142, 228, 161,
+            146, 228, 162, 133, 228, 161, 153, 228, 161, 186,
+            31708, 228, 162, 132, 31708, 228, 161, 165, 31324, 228,
+            161, 136, 228, 161, 132, 228, 161, 158, 228, 161,
+            136, 228, 162, 132, 228, 161, 140, }, },
         { "πŸš€ (normal) πŸ˜Άβ€πŸŒ«οΈ (multiple emojis concatenated) βœ… (only emoji that has its own token)",
-            { 1, 29871, 243, 162, 157, 131, 313, 8945, 29897, 29871,
-              243, 162, 155, 185, 30722, 243, 162, 143, 174, 30598,
-              313, 20787, 953, 3848, 275, 16125, 630, 29897, 29871, 31681,
-              313, 6194, 953, 29877, 2397, 393, 756, 967, 1914, 5993, 29897, }, },
-        { "Hello", { 1, 15043, }, },
-        { " Hello", { 1, 29871, 15043, }, },
-        { "  Hello", { 1, 259, 15043, }, },
-        { "   Hello", { 1, 1678, 15043, }, },
-        { "    Hello", { 1, 268, 15043, }, },
-        { "    Hello\n    Hello", { 1, 268, 15043, 13, 1678, 15043, }, },
+          { 29871, 243, 162, 157, 131, 313, 8945, 29897, 29871,
+            243, 162, 155, 185, 30722, 243, 162, 143, 174, 30598,
+            313, 20787, 953, 3848, 275, 16125, 630, 29897, 29871, 31681,
+            313, 6194, 953, 29877, 2397, 393, 756, 967, 1914, 5993, 29897, }, },
+        { "Hello", { 15043, }, },
+        { " Hello", { 29871, 15043, }, },
+        { "  Hello", { 259, 15043, }, },
+        { "   Hello", { 1678, 15043, }, },
+        { "    Hello", { 268, 15043, }, },
+        { "    Hello\n    Hello", { 268, 15043, 13, 1678, 15043, }, },
     };
 
     return _k_tests;
@@ -102,15 +104,18 @@ int main(int argc, char **argv) {
     bool success = true;
 
     for (const auto & test_kv : k_tests()) {
-        // Add a space in front of the first character to match OG llama tokenizer behavior
-        std::vector<llama_token> res = llama_tokenize(ctx, " " + test_kv.first, true);
-        fprintf(stderr, "%s : '%s' tokenized to '%s'\n",
-            __func__, test_kv.first.c_str(), unescape_whitespace(ctx, res).c_str());
+        const std::vector<llama_token> res_bos   = llama_tokenize(ctx, test_kv.first, true);
+        const std::vector<llama_token> res_nobos = llama_tokenize(ctx, test_kv.first, false);
 
-        bool correct = res.size() == test_kv.second.size();
+        fprintf(stderr, "%s : '%s' tokenized to '%s'\n", __func__, test_kv.first.c_str(), llama_detokenize(ctx, res_bos).c_str());
 
-        for (int i = 0; i < (int) res.size() && correct; ++i) {
-            if (res[i] != test_kv.second[i]) {
+        bool correct = res_nobos.size() == test_kv.second.size() && res_bos.size() == res_nobos.size() + 1 && res_bos[0] == 1;
+
+        for (int i = 0; i < (int) res_nobos.size() && correct; ++i) {
+            if (test_kv.second[i] != res_bos[i + 1]) {
+                correct = false;
+            }
+            if (test_kv.second[i] != res_nobos[i]) {
                 correct = false;
             }
         }
@@ -118,14 +123,15 @@ int main(int argc, char **argv) {
         if (!correct) {
             fprintf(stderr, "%s : failed test: '%s'\n", __func__, test_kv.first.c_str());
             fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__,
-                unescape_whitespace(ctx, res).c_str(), unescape_whitespace(ctx, test_kv.second).c_str());
+                llama_detokenize(ctx, res_nobos).c_str(),
+                llama_detokenize(ctx, test_kv.second).c_str());
             fprintf(stderr, "%s : expected tokens: ", __func__);
             for (const auto & t : test_kv.second) {
                 fprintf(stderr, "%6d, ", t);
             }
             fprintf(stderr, "\n");
             fprintf(stderr, "%s : got tokens: ", __func__);
-            for (const auto & t : res) {
+            for (const auto & t : res_nobos) {
                 fprintf(stderr, "%6d, ", t);
             }
             fprintf(stderr, "\n");
diff --git a/tests/test-tokenizer-0.py b/tests/test-tokenizer-0.py
index d21f8b5a1..982615e60 100644
--- a/tests/test-tokenizer-0.py
+++ b/tests/test-tokenizer-0.py
@@ -12,7 +12,40 @@ dir_tokenizer = args.dir_tokenizer
 
 tokenizer = SentencePieceProcessor(dir_tokenizer + '/tokenizer.model')
 
-text = 'Hello, world!'
-print(text)
-print(tokenizer.encode(text, add_bos=True))
-print(tokenizer.decode(tokenizer.encode(text, add_bos=True)))
+tests = [
+        "",
+        " ",
+        "  ",
+        "   ",
+        "\t",
+        "\n",
+        "\t\n",
+        "Hello world",
+        " Hello world",
+        "Hello World",
+        " Hello World",
+        " Hello World!",
+        "Hello, world!",
+        " Hello, world!",
+        " this is πŸ¦™.cpp",
+        "w048 7tuijk dsdfhu",
+        "Π½Π΅Ρ‰ΠΎ Π½Π° Π‘ΡŠΠ»Π³Π°Ρ€ΡΠΊΠΈ",
+        "αž€αžΆαž“αŸ‹αžαŸ‚αž–αž·αžŸαŸαžŸαž’αžΆαž…αžαž›αž…αŸαž‰",
+        "πŸš€ (normal) πŸ˜Άβ€πŸŒ«οΈ (multiple emojis concatenated) βœ… (only emoji that has its own token)",
+        "Hello",
+        " Hello",
+        "  Hello",
+        "   Hello",
+        "    Hello",
+        "    Hello\n    Hello",
+    ]
+
+
+for text in tests:
+    print('text: ', text)
+    print('\nwith bos:')
+    print(tokenizer.encode(text, add_bos=True))
+    print(tokenizer.decode(tokenizer.encode(text, add_bos=True)))
+    print('\nwithout bos:')
+    print(tokenizer.encode(text, add_bos=False))
+    print(tokenizer.decode(tokenizer.encode(text, add_bos=False)))
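
Not part of the patch: a minimal sketch of what callers look like after this change. The leading-whitespace prefix now happens inside llama_tokenize_internal, so the examples pass the prompt through unmodified and only choose whether to request BOS. The function name tokenize_demo is made up for illustration; llama_context, llama_token and the ::llama_tokenize wrapper from common.h are the ones the patch itself uses, and the expected token IDs come from the test table above.

// Sketch only: assumes a LLaMA (SPM) model already loaded behind ctx.
#include <cstdio>
#include <string>
#include <vector>

#include "common.h"
#include "llama.h"

static void tokenize_demo(llama_context * ctx) {
    const std::string prompt = "Hello world";

    // No manual " " prefix any more: llama_tokenize_internal prepends it.
    const std::vector<llama_token> res_bos   = ::llama_tokenize(ctx, prompt, true);
    const std::vector<llama_token> res_nobos = ::llama_tokenize(ctx, prompt, false);

    // Per the test table above: res_nobos == { 15043, 3186 } and
    // res_bos == { 1, 15043, 3186 } (BOS id 1 prepended by the tokenizer).
    for (const llama_token t : res_bos) {
        fprintf(stderr, "%6d, ", t);
    }
    fprintf(stderr, "\n");
}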