From c160818ec056311d8499810cd108716b61caf15d Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Sat, 27 Apr 2024 00:28:36 +0300
Subject: [PATCH] wip

---
 llama.cpp                           |  7 +++--
 tests/test-tokenizer-0-llama-v3.cpp |  4 +--
 unicode.cpp                         | 43 ++++++++++++++++++++++++++---
 3 files changed, 45 insertions(+), 9 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index 923a49877..e80e0c862 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -12032,10 +12032,11 @@ struct llm_tokenizer_bpe {
             case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
                 word_collection = unicode_regex_split(text, {
                     // TODO: ??????????????
-                    //"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+
+                    //"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
-                    // TODO: this is wrong - need to use ReFlex and update unicode.cpp to support the regex above
-                    "\\p{P}+",
+                    // TODO: this is not the same as the original regex:
+                    // - need to use ReFlex and update unicode.cpp to support the regex above
+                    // - or implement a custom function similar to unicode_gpt2_regex_preprocess()
                     "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
                     "\\p{N}+",
                     "[0-9][0-9][0-9]"
diff --git a/tests/test-tokenizer-0-llama-v3.cpp b/tests/test-tokenizer-0-llama-v3.cpp
index 2e91b717f..ab471c853 100644
--- a/tests/test-tokenizer-0-llama-v3.cpp
+++ b/tests/test-tokenizer-0-llama-v3.cpp
@@ -141,12 +141,12 @@ int main(int argc, char **argv) {
                 llama_detokenize_bpe(ctx, test_kv.second).c_str());
             fprintf(stderr, "%s : expected tokens: ", __func__);
             for (const auto & t : test_kv.second) {
-                fprintf(stderr, "%6d, ", t);
+                fprintf(stderr, "%6d '%s', ", t, llama_token_to_piece(ctx, t).c_str());
             }
             fprintf(stderr, "\n");
             fprintf(stderr, "%s : got tokens: ", __func__);
             for (const auto & t : res) {
-                fprintf(stderr, "%6d, ", t);
+                fprintf(stderr, "%6d '%s', ", t, llama_token_to_piece(ctx, t).c_str());
             }
             fprintf(stderr, "\n");
diff --git a/unicode.cpp b/unicode.cpp
index c87b24835..0c5d796cb 100644
--- a/unicode.cpp
+++ b/unicode.cpp
@@ -1,4 +1,4 @@
-#include "unicode.h" 
+#include "unicode.h"
 #include "unicode-data.h"
 
 #include <cassert>
@@ -225,7 +225,7 @@ static std::vector<std::string> unicode_byte_encoding_process(const std::vector<
 }
 
 static std::vector<size_t> unicode_gpt2_regex_preprocess(const std::wstring & wtext, const std::vector<size_t> & offsets) {
-    std::vector<size_t> bpe_offsets; // stroe the offset of each word
+    std::vector<size_t> bpe_offsets; // store the offset of each word
     bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
 
     size_t start = 0;
@@ -364,7 +364,7 @@ static std::vector<size_t> unicode_gpt2_regex_preprocess(const std::wstring & wt
 
 static std::vector<size_t> unicode_regex_preprocess(const std::wstring & text, const std::vector<size_t> & offsets, const std::wstring & regex_expr) {
     std::wregex expr(regex_expr);
-    std::vector<size_t> bpe_offsets; // stroe the offset of each word
+    std::vector<size_t> bpe_offsets; // store the offset of each word
     bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
     size_t start = 0;
     for (auto offset : offsets) {
@@ -391,6 +391,35 @@ static std::vector<size_t> unicode_regex_preprocess(const std::wstring & text, c
     return bpe_offsets;
 }
 
+static std::vector<size_t> unicode_regex_preprocess_fallback(const std::string & text, const std::vector<size_t> & offsets, const std::string & regex_expr) {
+    std::regex expr(regex_expr);
+    std::vector<size_t> bpe_offsets; // store the offset of each word
+    bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
+
+    size_t start = 0;
+    for (auto offset : offsets) {
+        std::cregex_iterator it(text.data() + start, text.data() + start + offset, expr);
+        std::cregex_iterator end;
+
+        int64_t start_idx = 0;
+        while (it != end) {
+            std::cmatch match = *it;
+            if (match.position() > start_idx) {
+                bpe_offsets.emplace_back(match.position() - start_idx);
+            }
+            bpe_offsets.emplace_back(match.length());
+            start_idx = match.position() + match.length();
+            ++it;
+        }
+
+        if (start_idx < (int64_t) offset) {
+            bpe_offsets.emplace_back(offset - start_idx);
+        }
+        start += offset;
+    }
+    return bpe_offsets;
+}
+
 static bool unicode_regex_equivalent_wregex_exists(const std::string & regex) {
     return unicode_regex_equivalent_wregex.find(regex) != unicode_regex_equivalent_wregex.end();
 }
@@ -503,7 +532,13 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
             const std::wstring & wregex_expr = unicode_regex_equivalent_wregex.at(regex_expr);
             bpe_offsets = unicode_regex_preprocess(wtext, bpe_offsets, wregex_expr);
         } else {
-            throw std::runtime_error("Unicode regex is not found");
+            try {
+                bpe_offsets = unicode_regex_preprocess_fallback(text, bpe_offsets, regex_expr);
+            } catch (std::regex_error & e) {
+                fprintf(stderr, "Failed to process regex: '%s'\n", regex_expr.c_str());
+                fprintf(stderr, "Regex error: %s\n", e.what());
+                throw std::runtime_error("Failed to process regex");
+            }
         }
     }