From c2406383749aec4c3617c90054cd1d64134fcf15 Mon Sep 17 00:00:00 2001 From: jaime-m-p <> Date: Thu, 8 Aug 2024 01:35:20 +0200 Subject: [PATCH] Reimplement unicode_regex_split() --- src/unicode.cpp | 417 +++++++++++++++++++----------------------------- 1 file changed, 167 insertions(+), 250 deletions(-) diff --git a/src/unicode.cpp b/src/unicode.cpp index 6ebef0ec9..4a5728ed6 100644 --- a/src/unicode.cpp +++ b/src/unicode.cpp @@ -644,141 +644,128 @@ uint32_t unicode_tolower(uint32_t cp) { } std::vector unicode_regex_split(const std::string & text, const std::vector & regex_exprs) { - // std::regex does not support unicode categories: \p{N}, \p{L}, \p{Lu}, \p{Ll} ... - // std::regex does not support unicode whitespaces \s: 0x85, 0xA0, 0x001680 ... 0x003000. - // Generate a "collapsed" representation of the regex, where all unicode categories are replaced by codepoints ranges. - // Generate a "collapsed" representation of the text, where all codepoints are forced to fall into generated category ranges. - // Text codepoints not found in generated category ranges are replaced by a "collapsed" codepoint. - // This implementation generalizes the original implementation adding support to unicode subcategories: - // https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2081479935 + // std::wregex does not support unicode categories: \p{N}, \p{L}, \p{Lu}, \p{Ll} ... + // std::wregex does not support unicode whitespaces \s: 0x85, 0xA0, 0x001680 ... 0x003000. + // std::wregex allows full wchar_t 32 bit codepoints, not limited to standard max 0x110000. + // The main idea is to insert unicode category bits into all regex and text codepoints. + // Max unicode codepoint 0x110000 fits in 21 bits. + // Store unicode category and subcategory in 10 bits. + // Set the high bit to zero to keep wchar_t positive (uint32_t codepoints). 
+ // Categorized codepoint: + // 1 bit zero + 7 bits category + 3 bits subcategory index + 21 bits codepoint + // 0b0'XXXXXXX'xxx'ccccccccccccccccccccc + // A "categorized codepoint" re-defines the ordering keeping category hierarchy. + // All high category codepoints \p{X} fall into the range: + // 0b0'XXXXXXX'000'000000000000000000000 + // 0b0'XXXXXXX'111'111111111111111111111 + // All subcategory codepoints \p{Xx} fall into the range: + // 0b0'XXXXXXX'xxx'000000000000000000000 + // 0b0'XXXXXXX'xxx'111111111111111111111 + // Processing steps: + // Build lists of "categorized codepoints/ranges" for replacing regex \s \w and \d. + // Replace all regex codepoints/ranges with respective "categorized codepoints/ranges". + // Replace all text codepoints with respective "categorized codepoints". + // Caveats: + // Some regex ranges start and end with a different category/subcategory. + // Split such ranges into sub-ranges, each within a single category, to maintain the new hierarchy. + // This forces iterating all ranges and could produce long sub-range sequences. - // Definitions: - // - Unicode cagegory: high unicode categories, \p{C}, \p{L}, \p{M}, \p{N}, \p{P}, \p{S}, \p{Z}. - // - Unicode subcagegory: including all unicode categories, \p{Cc}, \p{Cf}, \p{Co}, \p{Cs}, ..., \p{Zs}. - // - Collapsed codepoint: unused codepoint representing a unicode subcategory. - // - Collapsed range: sequence of "collapsed" codepoint, representing one unicode category. - // - Collapsed regex: original regex including "collapsed" codepoints and ranges. + //TODO: Regex processing can be cached. - // (1) Build the "collapsed" regex: - // (1.1) Generate a replacement list of codepoint ranges: - // (1.1.1) For each unicode category. - // (1.1.2) For each unicode subcategory. - // (1.1.3) Expand \s adding unicode whitespaces. - // (1.2) Each list includes its respective "collaped" codepoint/range. - // (1.3) [Optimization] Only build lists of categories present in the regex.
- // (1.4) Build the "collapsed" regex replacing categories and subcategories by this "collapsed" lists. - // (2) Build a list of codepoint ranges. - // (2.1) If a codepoint is not found in this list, then it is "collapsable". - // (2.2) [Optimization] Only build lists of ranges present in the regex. - // (3) For each input text: - // (3.1) Search codepoints in the regex codepoint ranges. - // (3.2) If found, it is a valid codepoint (the "collapsed" regex uses it), literal copy. - // (3.3) If not found, replace with its "collapsed" codepoint so the "collapsed" regex can process it. + // insert unicode category and subcategory before codepoint bits + // 1 bit zero + 7 bits category + 3 bits subcategory index + 21 bits zero + static const auto categorized_prefix = [] (const codepoint_categ categ) -> wchar_t { + static const uint32_t MASK = codepoint_categ::MASK; // category mask + static const uint32_t SUBMASK = codepoint_categ::SUBMASK & ~codepoint_categ::MASK; // subcategory mask + return (wchar_t) (((categ.encoded & MASK) << (21+3)) | ((categ.encoded & SUBMASK) << (21-7))); + }; - //TODO: Refactor optimizations - // Steps (1) and (2) only depends on the regex expression text. - // Step (3) needs 'regex_expr_ranges' for text "collapsing" and 'wregex_collapsed'. - // Optimization: store and reuse 'wregex_collapsed' and 'regex_expr_ranges'. 
+ // insert unicode category and subcategory before codepoint bits + // 1 bit zero + 7 bits category + 3 bits subcategory index + 21 bits codepoint + static const auto categorize_codepoint = [] (const uint32_t cpt) -> wchar_t { + GGML_ASSERT(cpt < (1 << 21)); + return categorized_prefix(unicode_cpt_category(cpt)) | (wchar_t)cpt; + }; - // 0xDB80 to 0xDBFF: Private Use High Surrogate (128 range values) - static const uint32_t COLLAPSE_CPT_RANGE_FIRST = 0xDB80; - static const uint32_t COLLAPSE_CPT_RANGE_LAST = 0xDBFF; + // remove the categorized prefix bits and restore original codepoint bits + static const auto decategorize_codepoint = [] (const wchar_t cpt) -> uint32_t { + return (uint32_t) cpt & ((1 << 21) - 1); + }; - // return the collapsed codepoint of an unicode category or subcategory - auto category_to_collapsed_cpt = [] (const codepoint_categ categ) { - const uint16_t subindex = categ.get_subcategory() >> 7; // subcategory stored in 3 bits - switch(categ.get_category()) { // category fits in other 3 bits - case codepoint_categ::UNDEF: return COLLAPSE_CPT_RANGE_FIRST + ((0 << 3) | subindex); - case codepoint_categ::C: return COLLAPSE_CPT_RANGE_FIRST + ((1 << 3) | subindex); - case codepoint_categ::L: return COLLAPSE_CPT_RANGE_FIRST + ((2 << 3) | subindex); - case codepoint_categ::M: return COLLAPSE_CPT_RANGE_FIRST + ((3 << 3) | subindex); - case codepoint_categ::N: return COLLAPSE_CPT_RANGE_FIRST + ((4 << 3) | subindex); - case codepoint_categ::P: return COLLAPSE_CPT_RANGE_FIRST + ((5 << 3) | subindex); - case codepoint_categ::S: return COLLAPSE_CPT_RANGE_FIRST + ((6 << 3) | subindex); - case codepoint_categ::Z: return COLLAPSE_CPT_RANGE_FIRST + ((7 << 3) | subindex); - default: GGML_ABORT("invalid category"); + // returns the respective categorized codepoint range of the category/subcategory + static const auto categorize_range_from_chars = [] (const char categ, const char subcateg) { + const wchar_t range_ini = 
categorized_prefix(codepoint_categ::from_chars(categ, subcateg)); + const wchar_t range_end = (wchar_t) (range_ini | (subcateg ? (1<<21)-1 : (1<<24)-1)); + return std::pair(range_ini, range_end); + }; + + // helper function to append/concat regex expressions + auto wregex_append_subregex = [] (std::wstring & wregex, const std::wstring & subregex, const bool add_squares, const bool negated) { + if (add_squares) { + wregex += '['; + if (negated) { + wregex += '^'; + } + wregex += subregex; + wregex += ']'; + } else { + GGML_ASSERT(!negated); //TODO: negation inside square brackets: \S \W \D + wregex += subregex; } }; - // return the collapsed range of an unicode category (range including all subcategories) - auto category_to_collapsed_range = [&] (const codepoint_categ categ) { - // \p{Ll} --> \p{Ll} to \p{Ll} // has subcategory ? yes - // \p{Lu} --> \p{Lu} to \p{Lu} // has subcategory ? yes - // \p{L} --> \p{Ll} to \p{Lu} // has subcategory ? no - GGML_ASSERT((COLLAPSE_CPT_RANGE_FIRST & 0x7) == 0); - const uint32_t collapsed = category_to_collapsed_cpt(categ); - const uint32_t range = (collapsed & 0x7) ? 0 : 0x7; // has subcategory ? 
- return std::pair(collapsed, collapsed + range); + // \d digits replacement + static const std::wstring wregex_digits = { + categorize_codepoint('0'), '-', categorize_codepoint('9'), }; - GGML_ASSERT(sizeof(wchar_t) == sizeof(u_int32_t)); + // \w words replacement + static const std::wstring wregex_words = { + categorize_codepoint('_'), + categorize_codepoint('0'), '-', categorize_codepoint('9'), + categorize_codepoint('A'), '-', categorize_codepoint('Z'), + categorize_codepoint('a'), '-', categorize_codepoint('z'), + }; - const auto cpts = unicode_cpts_from_utf8(text); + // \s whitespaces replacement + static const std::wstring wregex_whitespaces = [] { + std::wstring wregex_whitespaces; + for (const auto & range : unicode_ranges_whitespace) { + wregex_whitespaces += categorize_codepoint(range.first); + if (range.second > range.first) { + wregex_whitespaces += '-'; + wregex_whitespaces += categorize_codepoint(range.second); + } + } + return wregex_whitespaces; + }(); - std::vector bpe_offsets = { cpts.size() }; + GGML_ASSERT(sizeof(wchar_t) == sizeof(uint32_t)); + std::wstring wtext = unicode_wstring_from_utf8(text); + + std::vector offsets = { wtext.size() }; for (auto & regex_expr : regex_exprs) { // first, see if we have an efficient custom regex implementation - auto tmp = unicode_regex_split_custom(text, regex_expr, bpe_offsets); + auto tmp = unicode_regex_split_custom(text, regex_expr, offsets); if (!tmp.empty()) { - bpe_offsets = std::move(tmp); + offsets = std::move(tmp); continue; } - std::vector> regex_expr_ranges; // start codepoint, last codepoint - std::vector> regex_expr_categs; // offset, codepoint category - std::map map_categ_wregex; // categ --> regex utf32 string - std::wstring wregex_collapsed; - std::wstring wtext_collapsed; + std::wstring wregex; bool inside_square = false; bool is_cpt_range = false; - // (2) Build a list of codepoint ranges - // common ranges: \w \d - regex_expr_ranges.emplace_back('a', 'z'); - 
regex_expr_ranges.emplace_back('A', 'Z'); - regex_expr_ranges.emplace_back('0', '9'); - regex_expr_ranges.emplace_back('_', '_'); - - // (2) Build a list of codepoint ranges - // common ranges: \s - for (uint32_t cpt : unicode_vec_whitespace) { - const auto categ_prev = unicode_cpt_category(regex_expr_ranges.back().second); - const auto categ_last = unicode_cpt_category(cpt); - if (categ_prev == categ_last && regex_expr_ranges.back().second + 1 == cpt) { - regex_expr_ranges.back().second = cpt; - } else { - regex_expr_ranges.emplace_back(cpt, cpt); - } - } - - // (1.1.3) Expand \s adding unicode whitespaces. - // std::wregex \s does not match non-ASCII whitespaces - static const codepoint_categ categ_whitespace(codepoint_categ::MASK + 1); // UNDEF category, subcategory 1 - std::wstring & wregex_whitespaces = map_categ_wregex[categ_whitespace.get_subcategory()]; - wregex_whitespaces += L"\\s"; - for (uint32_t cpt : unicode_vec_whitespace) { - if (cpt >= 0x80) { // non-ASCII whitespaces - if (wregex_whitespaces.back() + 1 == (wchar_t) cpt) { - if (*(wregex_whitespaces.end() - 2) == '-') { - wregex_whitespaces.back() = cpt; - } else { - wregex_whitespaces += '-'; - wregex_whitespaces += cpt; - } - } else { - wregex_whitespaces += (wchar_t) cpt; - } - } - } - const auto cpts_regex = unicode_cpts_from_utf8(regex_expr); + wregex.reserve(2 * cpts_regex.size()); for (size_t i = 0; i < cpts_regex.size(); ++i) { uint32_t cpt = cpts_regex[i]; - // skip regex metacharacters + // parse regex metacharacters + wregex += (wchar_t) cpt; if (inside_square) { switch(cpt) { case '^': @@ -811,6 +798,7 @@ std::vector unicode_regex_split(const std::string & text, const std case '{': while (cpt && cpt != '}') { cpt = cpts_regex[++i]; + wregex += (wchar_t) cpt; } continue; case '}': @@ -819,12 +807,19 @@ std::vector unicode_regex_split(const std::string & text, const std case '(': if (cpts_regex[i + 1] == '?') { // (?: (?i: (?= (?! (?<= (? 
unicode_regex_split(const std::string & text, const std continue; } } + wregex.pop_back(); - // parse unicode categories and subcategories + // parse unicode categories and subcategories, replace category with the categorized range if (cpt == '\\' && cpts_regex[i + 1] == 'p' && cpts_regex[i + 2] == '{') { GGML_ASSERT(cpts_regex[i + 3] && cpts_regex[i + 4]); - codepoint_categ categ = {}; + std::pair range; if (cpts_regex[i + 4] == '}') { - categ = codepoint_categ::from_chars((char)cpts_regex[i + 3]); + range = categorize_range_from_chars((char)cpts_regex[i + 3], (char)'\0'); + i += 4; } else { - categ = codepoint_categ::from_chars((char)cpts_regex[i + 3], (char)cpts_regex[i + 4]); - GGML_ASSERT(cpts_regex[i + 5] == '}'); + range = categorize_range_from_chars((char)cpts_regex[i + 3], (char)cpts_regex[i + 4]); + i += 5; } - // (2) Build a list of codepoint ranges. (2.2) [Optimization] Only build lists of ranges present in the regex. - categ.set_flag(codepoint_categ::WHITESPACE, inside_square); //NOTE: reusing flag 'WHITESPACE' to store 'inside square brackets' - regex_expr_categs.emplace_back((uint32_t)i, categ); - i += cpts_regex[i + 4] == '}' ? 4 : 5; + GGML_ASSERT(cpts_regex[i] == '}'); + const std::wstring subregex = {range.first, '-', range.second}; + wregex_append_subregex(wregex, subregex, !inside_square, false); continue; } - if (cpt == '\\') { - if (cpts_regex[i + 1] == 's' || cpts_regex[i + 1] == 'S') { // \s \S - // (2) Build a list of codepoint ranges. (2.2) [Optimization] Only build lists of ranges present in the regex. 
- regex_expr_categs.emplace_back((uint32_t)i, categ_whitespace); - //NOTE: reusing flag 'WHITESPACE' to store 'inside square brackets' - regex_expr_categs.back().second.set_flag(codepoint_categ::WHITESPACE, inside_square); - i += 1; - continue; - } - } - // parse more metcharacters and espaped characters if (cpt == '\\') { switch (cpts_regex[i + 1]) { - case 's': ++i; continue; // \s whitespaces - case 'w': ++i; continue; // \w words - case 'd': ++i; continue; // \d digits - case 'S': ++i; continue; // \S no whitespaces - case 'W': ++i; continue; // \W no words - case 'D': ++i; continue; // \D no digits + case 's': // \s whitespaces + case 'S': // \S no whitespaces + wregex_append_subregex(wregex, wregex_whitespaces, !inside_square, cpts_regex[++i] == 'S'); + continue; + case 'w': // \w words + case 'W': // \W no words + wregex_append_subregex(wregex, wregex_words, !inside_square, cpts_regex[++i] == 'W'); + continue; + case 'd': // \d digits + case 'D': // \D no digits + wregex_append_subregex(wregex, wregex_digits, !inside_square, cpts_regex[++i] == 'D'); + continue; case 't': ++i; cpt = '\t'; break; case 'r': ++i; cpt = '\r'; break; case 'n': ++i; cpt = '\n'; break; @@ -886,139 +877,65 @@ std::vector unicode_regex_split(const std::string & text, const std GGML_ASSERT(!is_cpt_range); cpt = cpts_regex[++i]; GGML_ASSERT(cpt < 0x80); - break; + break; } } - // ensure there is not a collission with any "collapsed" codepoints - GGML_ASSERT(cpt < COLLAPSE_CPT_RANGE_FIRST || COLLAPSE_CPT_RANGE_LAST < cpt); - - // (2) Build a list of codepoint ranges if (is_cpt_range) { + // Some regex ranges starts and ends with different category/subcategory. + // Split the ranges in sub-ranges to ensure a single category to maintain the new hierarchy. + // Warning: This forces iterating all ranges and could produce long sub-range sequences. 
+ GGML_ASSERT(wregex.size() && wregex.back() == '-'); + wregex.pop_back(); + wchar_t categorized = wregex.back(); + uint32_t range_ini = decategorize_codepoint(categorized); + const uint32_t range_end = cpt; + GGML_ASSERT(range_ini <= range_end); + codepoint_categ range_categ = unicode_cpt_category(range_ini); + for (cpt = range_ini + 1; cpt <= range_end; ++cpt) { + codepoint_categ categ = unicode_cpt_category(cpt); + if (categ == range_categ) { // still same range category ? + ++categorized; + if (cpt == range_ini + 1) { // single step, no need range + wregex += categorized; + } else if (cpt == range_ini + 2) { // need range if +2 step + wregex.back() = '-'; + wregex += categorized; + } else { + wregex.back() = categorized; // keep range growing + } + } else { // new range category + categorized = categorize_codepoint(cpt); + wregex += categorized; + range_categ = categ; + range_ini = cpt; + } + } is_cpt_range = false; - regex_expr_ranges.back().second = cpt; } else { - regex_expr_ranges.emplace_back(cpt, cpt); + wregex += categorize_codepoint(cpt); } } - // assign collapsed codepoint to each category regex \p{...} - for (auto offset_categ : regex_expr_categs) { - const uint16_t subcateg = offset_categ.second.get_subcategory(); - auto it = map_categ_wregex.find(subcateg); - if (it == map_categ_wregex.end()) { - // (1.2) Each list includes its respective "collaped" codepoint/range. 
- const auto collapsed_range = category_to_collapsed_range(offset_categ.second); - map_categ_wregex[subcateg] = (wchar_t) collapsed_range.first; - if (collapsed_range.first < collapsed_range.second) { - map_categ_wregex[subcateg] += (wchar_t) '-'; - map_categ_wregex[subcateg] += (wchar_t) collapsed_range.second; - } + // categorize all wtext codepoints + if (wtext.size() && wtext[0] < MAX_CODEPOINTS) { // if not already categorized + for (size_t i = 0; i < wtext.size(); ++i) { + wtext[i] = categorize_codepoint((uint32_t) wtext[i]); } } - // copy found regex ranges to each category regex - uint32_t regex_expr_ranges_uniques = 0; - std::pair prev_range = {0, -1}; - std::sort(regex_expr_ranges.begin(), regex_expr_ranges.end()); - for (auto range : regex_expr_ranges) { - range.first = std::max(range.first, prev_range.second + 1); // prevent overlapping //TODO: as error? - if (range.first > range.second) { // skip overlapping and repetitions - continue; - } - // (1.1) Generate a replacement list of codepoint ranges - codepoint_categ categ = unicode_cpt_category(range.first); - GGML_ASSERT(categ == unicode_cpt_category(range.second)); - auto it0 = map_categ_wregex.find(categ.get_category()); - auto it1 = map_categ_wregex.find(categ.get_subcategory()); - for (const auto & it : {it0, it1}) { - if (it != map_categ_wregex.end()) { - it->second += (wchar_t) range.first; - if (range.first < range.second) { - it->second += (wchar_t) '-'; - it->second += (wchar_t) range.second; - } - } - } - prev_range = range; - regex_expr_ranges[regex_expr_ranges_uniques++] = range; - } - regex_expr_ranges.resize(regex_expr_ranges_uniques); - - // replace categories with respective collapsed codepoint and ranges - uint32_t i = 0; - wregex_collapsed.reserve(regex_expr.size()); - for (auto offset_categ : regex_expr_categs) { - while (i < offset_categ.first) { // copy original regex until reaching the category - wregex_collapsed += (wchar_t) cpts_regex[i]; - i++; - } - GGML_ASSERT(cpts_regex[i] 
== '\\'); - const uint32_t cpt_next = cpts_regex[i + 1]; - const bool is_negated = cpt_next < 'a'; // is uppercase - if (cpt_next == 'p' || cpt_next == 'P') { - GGML_ASSERT(cpts_regex[i + 2] == '{' && cpts_regex[i + 3]); - i += cpts_regex[i + 4] == '}' ? 5 : 6; - GGML_ASSERT(cpts_regex[i - 1] == '}'); - } else { - GGML_ASSERT(cpt_next == 's' || cpt_next == 'w' || cpt_next == 'd' || // \s \w \d - cpt_next == 'S' || cpt_next == 'W' || cpt_next == 'D'); // \S \W \D - i += 2; - } - // (1.4) Build the "collapsed" regex replacing categories and subcategories by this "collapsed" lists. - const codepoint_categ categ = offset_categ.second; - auto it = map_categ_wregex.find(categ.get_subcategory()); - GGML_ASSERT(it != map_categ_wregex.end()); - if (it != map_categ_wregex.end()) { - if (categ.is_whitespace()) { // inside square brackets //NOTE: reusing flag WHITESPACE - GGML_ASSERT(is_negated == false); - wregex_collapsed += it->second; - } else if(it->second.size() == 1 && !is_negated) { - wregex_collapsed += it->second; - } else { - wregex_collapsed += '['; - if (is_negated) { - wregex_collapsed += '^'; - } - wregex_collapsed += it->second; - wregex_collapsed += ']'; - } - } - } - while (i < (uint32_t)cpts_regex.size()) { - wregex_collapsed += cpts_regex[i]; - i++; - } - - // collapse text codepoints not included in 'regex_expr_ranges' - wtext_collapsed.reserve(cpts.size()); - for (uint32_t cpt : cpts) { - const codepoint_categ categ = unicode_cpt_category(cpt); - // (3.1) Search codepoints in the regex codepoint ranges. - auto it = std::lower_bound(regex_expr_ranges.begin(), regex_expr_ranges.end(), cpt, - [] (const std::pair range, const uint32_t cpt) { - return range.second < cpt; - } - ); - if (it == regex_expr_ranges.end() || cpt < it->first || it->second < cpt) { - // (3.3) If not found, replace with its "collapsed" codepoint so the "collapsed" regex can process it. 
- cpt = category_to_collapsed_cpt(categ); // not found, collapse to category codepoint - } - // (3.2) If found, it is a valid codepoint (the "collapsed" regex uses it), literal copy. - wtext_collapsed += (wchar_t) cpt; - } - - bpe_offsets = unicode_regex_split_stl(wtext_collapsed, wregex_collapsed, bpe_offsets); + offsets = unicode_regex_split_stl(wtext, wregex, offsets); } std::vector bpe_words; - bpe_words.reserve(bpe_offsets.size()); // reserve memory for the approximate size + bpe_words.reserve(offsets.size()); // reserve memory for the approximate size size_t start = 0; - for (size_t & offset : bpe_offsets) { + for (size_t & offset : offsets) { bpe_words.emplace_back(); for (size_t i = start; i < start + offset; ++i) { - bpe_words.back() += unicode_cpt_to_utf8(cpts[i]); + const uint32_t cpt = decategorize_codepoint(wtext[i]); + bpe_words.back() += unicode_cpt_to_utf8(cpt); } start += offset; }