From e972e6cbf8fc24d20696715adeacae7c796afb5c Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Sun, 28 Apr 2024 18:01:59 +0300
Subject: [PATCH] unicode : clean-up

---
 unicode-data.cpp |   6 +--
 unicode-data.h   |   7 +---
 unicode.cpp      | 102 +++++++++++++++++++++++------------------------
 3 files changed, 53 insertions(+), 62 deletions(-)

diff --git a/unicode-data.cpp b/unicode-data.cpp
index 3757944d2..e6bafb3a9 100644
--- a/unicode-data.cpp
+++ b/unicode-data.cpp
@@ -1,4 +1,4 @@
-#include "unicode-data.h"
+#include "unicode-data.h"
 
 #include <cstdint>
 #include <map>
@@ -1649,7 +1649,3 @@ const std::map<char32_t, char32_t> unicode_map_lowercase = {
 {0x1E917, 0x1E939}, {0x1E918, 0x1E93A}, {0x1E919, 0x1E93B}, {0x1E91A, 0x1E93C}, {0x1E91B, 0x1E93D}, {0x1E91C, 0x1E93E}, {0x1E91D, 0x1E93F}, {0x1E91E, 0x1E940},
 {0x1E91F, 0x1E941}, {0x1E920, 0x1E942}, {0x1E921, 0x1E943},
 };
-
-const std::set<std::string> unicode_regex_with_custom_preprocessor = {
-    "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)"
-};
diff --git a/unicode-data.h b/unicode-data.h
index 1fae1c714..cb9dd8aa5 100644
--- a/unicode-data.h
+++ b/unicode-data.h
@@ -2,8 +2,6 @@
 
 #include <cstdint>
 #include <map>
-#include <set>
-#include <string>
 #include <unordered_map>
 #include <vector>
 
@@ -14,6 +12,5 @@ extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_accent_ma
 extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_punctuation;
 extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_symbol;
 extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_control;
-extern const std::multimap<uint32_t, uint32_t> unicode_map_nfd;
-extern const std::map<char32_t, char32_t> unicode_map_lowercase;
-extern const std::set<std::string> unicode_regex_with_custom_preprocessor;
+extern const std::multimap<uint32_t, uint32_t> unicode_map_nfd;
+extern const std::map<char32_t, char32_t> unicode_map_lowercase;
diff --git a/unicode.cpp b/unicode.cpp
index c31d2a5b3..55fafb6c5 100644
--- a/unicode.cpp
+++ b/unicode.cpp
@@ -202,20 +202,16 @@ static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
     return conv.from_bytes(s);
 }
 
-//static inline std::string unicode_wstring_to_utf8(const std::wstring & ws) {
-//    std::wstring_convert<std::codecvt_utf8<wchar_t>> conv;
-//    return conv.to_bytes(ws);
-//}
-
 static std::vector<std::string> unicode_byte_encoding_process(const std::vector<std::string> & bpe_words) {
-    std::vector<std::string>bpe_encoded_words;
-    for (auto word : bpe_words) {
-        std::string text_utf = "";
+    std::vector<std::string> bpe_encoded_words;
+    for (const auto & word : bpe_words) {
+        std::string text_utf;
         auto utf_word = unicode_cpts_from_utf8(word);
-        for (size_t i = 0; i < utf_word.size(); ++i)
+        for (size_t i = 0; i < utf_word.size(); ++i) {
             text_utf += unicode_cpt_to_utf8(utf_word[i]);
+        }
 
-        std::string encoded_token = "";
+        std::string encoded_token;
         for (char & c : text_utf) {
             encoded_token += unicode_byte_to_utf8(c);
         }
@@ -225,7 +221,7 @@ static std::vector<std::string> unicode_byte_encoding_process(const std::vector<
 }
 
 // GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
-static std::vector<size_t> unicode_gpt2_regex_preprocess(const std::string & text, const std::vector<size_t> & offsets) {
+static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & text, const std::vector<size_t> & offsets) {
     std::vector<size_t> bpe_offsets; // store the offset of each word
     bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
 
@@ -289,7 +285,10 @@ static std::vector<size_t> unicode_gpt2_regex_preprocess(const std::string & tex
                     if (token.size()) {
                         bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
                     }
-                    token = utf_char + utf_char_next + utf_char_next_next;
+                    token = utf_char;
+                    token += utf_char_next;
+                    token += utf_char_next_next;
+
                     bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
                     token = "";
                     i += 2;
@@ -298,17 +297,17 @@ static std::vector<size_t> unicode_gpt2_regex_preprocess(const std::string & tex
             }
 
             if (!split_condition && !collecting) {
-                if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || (!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER)) {
+                if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || (token.empty() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER)) {
                     collecting_letter = true;
                     collecting = true;
                 }
-                else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || (!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
+                else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || (token.empty() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
                     collecting_numeric = true;
                     collecting = true;
                 }
                 else if (
                     ((unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) && (unicode_cpt_type(utf_char) != CODEPOINT_TYPE_WHITESPACE)) ||
-                    (!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_DIGIT && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE)
+                    (token.empty() && utf_char == " " && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_DIGIT && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE)
                     ) {
                     collecting_special = true;
                     collecting = true;
@@ -363,7 +362,8 @@ static std::vector<size_t> unicode_gpt2_regex_preprocess(const std::string & tex
     return bpe_offsets;
 }
 
-static std::vector<size_t> unicode_regex_preprocess(const std::wstring & wtext, const std::vector<size_t> & offsets, const std::wstring & regex_expr) {
+// use std::wregex to split the text
+static std::vector<size_t> unicode_regex_split_stl(const std::wstring & wtext, const std::vector<size_t> & offsets, const std::wstring & regex_expr) {
     std::wregex expr(regex_expr);
     std::vector<size_t> bpe_offsets; // store the offset of each word
     bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
@@ -393,7 +393,7 @@ static std::vector<size_t> unicode_regex_preprocess(const std::wstring & wtext,
 }
 
 // use std::regex to split the text
-static std::vector<size_t> unicode_regex_preprocess(const std::string & text, const std::vector<size_t> & offsets, const std::string & regex_expr) {
+static std::vector<size_t> unicode_regex_split_stl(const std::string & text, const std::vector<size_t> & offsets, const std::string & regex_expr) {
     std::regex expr(regex_expr);
     std::vector<size_t> bpe_offsets; // store the offset of each word
     bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
@@ -422,15 +422,11 @@ static std::vector<size_t> unicode_regex_preprocess(const std::string & text, co
     return bpe_offsets;
 }
 
-static bool unicode_regex_with_custom_preprocessor_exists(const std::string & regex) {
-    return unicode_regex_with_custom_preprocessor.find(regex) != unicode_regex_with_custom_preprocessor.end();
-}
-
-static std::vector<size_t> unicode_regex_custom_preprocess(const std::string & regex, const std::string & text, const std::vector<size_t> & offsets) {
+static std::vector<size_t> unicode_regex_split_custom(const std::string & regex, const std::string & text, const std::vector<size_t> & offsets) {
     std::vector<size_t> bpe_offsets;
 
     if (regex == "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") {
-        bpe_offsets = unicode_gpt2_regex_preprocess(text, offsets);
+        bpe_offsets = unicode_regex_split_custom_gpt2(text, offsets);
     }
 
     return bpe_offsets;
@@ -518,20 +514,6 @@ char32_t unicode_tolower(char32_t cp) {
     return it == unicode_map_lowercase.end() ? cp : it->second;
 }
 
-static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
-    std::string result;
-    for (size_t pos = 0; ; pos += search.length()) {
-        auto new_pos = s.find(search, pos);
-        if (new_pos == std::string::npos) {
-            result += s.substr(pos, s.size() - pos);
-            break;
-        }
-        result += s.substr(pos, new_pos - pos) + replace;
-        pos = new_pos;
-    }
-    s = std::move(result);
-}
-
 std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
     // unicode categories
     static const std::map<std::string, int> k_ucat_enum = {
@@ -547,16 +529,16 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
     };
 
     static const std::map<int, std::string> k_ucat_map = {
-        { CODEPOINT_TYPE_DIGIT,       "\x30-\x39" },
-        { CODEPOINT_TYPE_LETTER,      "\x41-\x5A\x61-\x7A" },
-        { CODEPOINT_TYPE_PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" },
+        { CODEPOINT_TYPE_DIGIT,       "\x30-\x39" },          // 0-9
+        { CODEPOINT_TYPE_LETTER,      "\x41-\x5A\x61-\x7A" }, // A-Za-z
+        { CODEPOINT_TYPE_PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
     };
 
     // compute collapsed codepoints only if needed by at least one regex
     bool need_collapse = false;
     for (auto & regex_expr : regex_exprs) {
         // search for unicode categories
-        for (auto & ucat : k_ucat_enum) {
+        for (const auto & ucat : k_ucat_enum) {
             if (std::string::npos != regex_expr.find(ucat.first)) {
                 need_collapse = true;
                 break;
@@ -564,10 +546,13 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
         }
     }
 
+    const auto cpts = unicode_cpts_from_utf8(text);
+
+    // generated a "collapsed" representation of the text, where all codepoints are replaced by a single byte
+    // ref: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2081479935
     std::string text_collapsed;
     if (need_collapse) {
         // collapse all unicode categories
-        const auto cpts = unicode_cpts_from_utf8(text);
         text_collapsed.resize(cpts.size());
 
         for (size_t i = 0; i < cpts.size(); ++i) {
@@ -587,14 +572,16 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
         }
     }
 
-    const auto cpts = unicode_cpts_from_utf8(text);
     std::vector<size_t> bpe_offsets = { cpts.size() };
 
     for (auto & regex_expr : regex_exprs) {
-        if (unicode_regex_with_custom_preprocessor_exists(regex_expr)) {
-            bpe_offsets = unicode_regex_custom_preprocess(regex_expr, text, bpe_offsets);
+        // first, see if we have an efficient custom regex implementation
+        auto tmp = unicode_regex_split_custom(regex_expr, text, bpe_offsets);
+
+        if (!tmp.empty()) {
+            bpe_offsets = std::move(tmp);
         } else {
-            // fallback
+            // fallback to general-purpose std::regex / std::wregex
             try {
                 // if a unicode category is used in the regex, we use the collapsed text and replace the unicode category
                 // with the corresponding collapsed representation
@@ -607,20 +594,32 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
                 }
 
                 if (use_collapsed) {
+                    // sanity-check that the original regex does not contain any non-ASCII characters
+                    const auto cpts_regex = unicode_cpts_from_utf8(regex_expr);
+                    for (size_t i = 0; i < cpts_regex.size(); ++i) {
+                        if (cpts_regex[i] >= 128) {
+                            throw std::runtime_error("Regex includes both unicode categories and non-ASCII characters - not supported");
+                        }
+                    }
+
+                    // generate a collapsed representation of the regex
                     std::string regex_expr_collapsed;
 
-                    // track if we are inside []
+                    // track if we are inside [], because nested [] are not allowed
                     bool inside = false;
                     for (size_t i = 0; i < regex_expr.size(); ++i) {
                         if (regex_expr[i] == '[' && (i == 0 || regex_expr[i - 1] != '\\')) {
                             regex_expr_collapsed += '[';
                             inside = true;
                             continue;
-                        } else if (inside && regex_expr[i] == ']' && regex_expr[i - 1] != '\\') {
+                        }
+
+                        if (inside && regex_expr[i] == ']' && regex_expr[i - 1] != '\\') {
                             regex_expr_collapsed += ']';
                             inside = false;
                             continue;
                         }
+
                         if (regex_expr[i] == '\\' && i + 1 < regex_expr.size()) {
                             if (regex_expr[i + 1] == 'p') {
                                 if (i + 3 < regex_expr.size() && regex_expr[i + 2] == '{') {
@@ -648,7 +647,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
 
                     //printf("text_collapsed: %s\n", text_collapsed.c_str());
                     //printf("regex_expr_collapsed: %s\n", regex_expr_collapsed.c_str());
-                    bpe_offsets = unicode_regex_preprocess(text_collapsed, bpe_offsets, regex_expr_collapsed);
+                    bpe_offsets = unicode_regex_split_stl(text_collapsed, bpe_offsets, regex_expr_collapsed);
                 } else {
                     // no unicode category used, we can use std::wregex directly
                     const std::wstring wtext = unicode_wstring_from_utf8(text);
@@ -656,7 +655,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
 
                     //printf("text: %s\n", text.c_str());
                     //printf("regex_expr: %s\n", regex_expr.c_str());
-                    bpe_offsets = unicode_regex_preprocess(wtext, bpe_offsets, wregex_expr);
+                    bpe_offsets = unicode_regex_split_stl(wtext, bpe_offsets, wregex_expr);
                 }
             } catch (std::regex_error & e) {
                 fprintf(stderr, "Failed to process regex: '%s'\n", regex_expr.c_str());
@@ -667,11 +666,10 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
     }
 
     std::vector<std::string> bpe_words;
-    bpe_words.reserve(bpe_offsets.size()); // Reserve memory for the approximate size
+    bpe_words.reserve(bpe_offsets.size()); // reserve memory for the approximate size
 
     size_t start = 0;
     for (size_t & offset : bpe_offsets) {
-        //bpe_words.emplace_back(unicode_wstring_to_utf8(std::wstring(wtext, start, offset)));
         bpe_words.emplace_back();
         for (size_t i = start; i < start + offset; ++i) {
             bpe_words.back() += unicode_cpt_to_utf8(cpts[i]);