diff --git a/src/unicode.cpp b/src/unicode.cpp index 68cadf0c4..f5d149648 100644 --- a/src/unicode.cpp +++ b/src/unicode.cpp @@ -636,13 +636,47 @@ uint32_t unicode_tolower(uint32_t cp) { } std::vector unicode_regex_split(const std::string & text, const std::vector & regex_exprs) { - //TODO: update and add more comments - // generate a "collapsed" representation of the text, where all codepoints are replaced by a single byte - // ref: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2081479935 + // std::regex does not support unicode categories: \p{N}, \p{L}, \p{Lu}, \p{Ll} ... + // std::regex does not support unicode whitespaces \s: 0x85, 0xA0, 0x001680 ... 0x003000. + // Generate a "collapsed" representation of the regex, where all unicode categories are replaced by codepoints ranges. + // Generate a "collapsed" representation of the text, where all codepoints are forced to fall into generated category ranges. + // Text codepoints not found in generated category ranges are replaced by a "collapsed" codepoint. + // This implementation generalizes the original implementation adding support to unicode subcategories: + // https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2081479935 + + // Definitions: + // - Unicode cagegory: high unicode categories, \p{C}, \p{L}, \p{M}, \p{N}, \p{P}, \p{S}, \p{Z}. + // - Unicode subcagegory: including all unicode categories, \p{Cc}, \p{Cf}, \p{Co}, \p{Cs}, ..., \p{Zs}. + // - Collapsed codepoint: unused codepoint representing a unicode subcategory. + // - Collapsed range: sequence of "collapsed" codepoint, representing one unicode category. + // - Collapsed regex: original regex including "collapsed" codepoints and ranges. + + // (1) Build the "collapsed" regex: + // (1.1) Generate a replacement list of codepoint ranges: + // (1.1.1) For each unicode category. + // (1.1.2) For each unicode subcategory. + // (1.1.3) Expand \s adding unicode whitespaces. + // (1.2) Each list includes its respective "collaped" codepoint/range. + // (1.3) [Optimization] Only build lists of categories present in the regex. + // (1.4) Build the "collapsed" regex replacing categories and subcategories by this "collapsed" lists. + // (2) Build a list of codepoint ranges. + // (2.1) If a codepoint is not found in this list, then it is "collapsable". + // (2.2) [Optimization] Only build lists of ranges present in the regex. + // (3) For each input text: + // (3.1) Search codepoints in the regex codepoint ranges. + // (3.2) If found, it is a valid codepoint (the "collapsed" regex uses it), literal copy. + // (3.3) If not found, replace with its "collapsed" codepoint so the "collapsed" regex can process it. + + //TODO: Refactor optimizations + // Steps (1) and (2) only depends on the regex expression text. + // Step (3) needs 'regex_expr_ranges' for text "collapsing" and 'wregex_collapsed'. + // Optimization: store and reuse 'wregex_collapsed' and 'regex_expr_ranges'. // 0xDB80 to 0xDBFF: Private Use High Surrogate (128 range values) static const uint32_t COLLAPSE_CPT_RANGE_FIRST = 0xDB80; static const uint32_t COLLAPSE_CPT_RANGE_LAST = 0xDBFF; + + // return the collapsed codepoint of an unicode category or subcategory auto category_to_collapsed_cpt = [] (const codepoint_categ categ) { const uint16_t subindex = categ.get_subcategory() >> 7; // subcategory stored in 3 bits switch(categ.get_category()) { // category fits in other 3 bits @@ -657,6 +691,8 @@ std::vector unicode_regex_split(const std::string & text, const std default: assert (false); return COLLAPSE_CPT_RANGE_FIRST; } }; + + // return the collapsed range of an unicode category (range including all subcategories) auto category_to_collapsed_range = [&] (const codepoint_categ categ) { // \p{Ll} --> \p{Ll} to \p{Ll} // has subcategory ? yes // \p{Lu} --> \p{Lu} to \p{Lu} // has subcategory ? yes @@ -688,11 +724,14 @@ std::vector unicode_regex_split(const std::string & text, const std bool inside_square = false; bool is_cpt_range = false; + // (2) Build a list of codepoint ranges // common ranges: \w \d regex_expr_ranges.emplace_back('a', 'z'); regex_expr_ranges.emplace_back('A', 'Z'); regex_expr_ranges.emplace_back('0', '9'); regex_expr_ranges.emplace_back('_', '_'); + + // (2) Build a list of codepoint ranges // common ranges: \s for (uint32_t cpt : unicode_vec_whitespace) { const auto categ_prev = unicode_cpt_category(regex_expr_ranges.back().second); @@ -704,6 +743,7 @@ std::vector unicode_regex_split(const std::string & text, const std } } + // (1.1.3) Expand \s adding unicode whitespaces. // std::wregex \s does not match non-ASCII whitespaces static const codepoint_categ categ_whitespace(codepoint_categ::MASK + 1); // UNDEF category, subcategory 1 std::wstring & wregex_whitespaces = map_categ_wregex[categ_whitespace.get_subcategory()]; @@ -728,6 +768,7 @@ std::vector unicode_regex_split(const std::string & text, const std for (size_t i = 0; i < cpts_regex.size(); ++i) { uint32_t cpt = cpts_regex[i]; + // skip regex metacharacters if (inside_square) { switch(cpt) { case '^': @@ -788,6 +829,7 @@ std::vector unicode_regex_split(const std::string & text, const std } } + // parse unicode categories and subcategories if (cpt == '\\' && cpts_regex[i + 1] == 'p' && cpts_regex[i + 2] == '{') { assert (cpts_regex[i + 3] && cpts_regex[i + 4]); codepoint_categ categ = {}; @@ -797,6 +839,7 @@ std::vector unicode_regex_split(const std::string & text, const std categ = codepoint_categ::from_chars((char)cpts_regex[i + 3], (char)cpts_regex[i + 4]); assert (cpts_regex[i + 5] == '}'); } + // (2) Build a list of codepoint ranges. (2.2) [Optimization] Only build lists of ranges present in the regex. categ.set_flag(codepoint_categ::WHITESPACE, inside_square); //NOTE: reusing flag 'WHITESPACE' to store 'inside square brackets' regex_expr_categs.emplace_back(i, categ); i += cpts_regex[i + 4] == '}' ? 4 : 5; @@ -805,6 +848,7 @@ std::vector unicode_regex_split(const std::string & text, const std if (cpt == '\\') { if (cpts_regex[i + 1] == 's' || cpts_regex[i + 1] == 'S') { // \s \S + // (2) Build a list of codepoint ranges. (2.2) [Optimization] Only build lists of ranges present in the regex. regex_expr_categs.emplace_back(i, categ_whitespace); //NOTE: reusing flag 'WHITESPACE' to store 'inside square brackets' regex_expr_categs.back().second.set_flag(codepoint_categ::WHITESPACE, inside_square); @@ -813,6 +857,7 @@ std::vector unicode_regex_split(const std::string & text, const std } } + // parse more metcharacters and espaped characters if (cpt == '\\') { switch (cpts_regex[i + 1]) { case 's': ++i; continue; // \s whitespaces @@ -835,8 +880,10 @@ std::vector unicode_regex_split(const std::string & text, const std } } + // ensure there is not a collission with any "collapsed" codepoints assert (cpt < COLLAPSE_CPT_RANGE_FIRST || COLLAPSE_CPT_RANGE_LAST < cpt); + // (2) Build a list of codepoint ranges if (is_cpt_range) { is_cpt_range = false; regex_expr_ranges.back().second = cpt; @@ -850,6 +897,7 @@ std::vector unicode_regex_split(const std::string & text, const std const uint16_t subcateg = offset_categ.second.get_subcategory(); auto it = map_categ_wregex.find(subcateg); if (it == map_categ_wregex.end()) { + // (1.2) Each list includes its respective "collaped" codepoint/range. const auto collapsed_range = category_to_collapsed_range(offset_categ.second); map_categ_wregex[subcateg] = (wchar_t) collapsed_range.first; if (collapsed_range.first < collapsed_range.second) { @@ -868,6 +916,7 @@ std::vector unicode_regex_split(const std::string & text, const std if (range.first > range.second) { // skip overlapping and repetitions continue; } + // (1.1) Generate a replacement list of codepoint ranges codepoint_categ categ = unicode_cpt_category(range.first); assert (categ == unicode_cpt_category(range.second)); auto it0 = map_categ_wregex.find(categ.get_category()); @@ -906,6 +955,7 @@ std::vector unicode_regex_split(const std::string & text, const std cpt_next == 'S' || cpt_next == 'W' || cpt_next == 'D'); // \S \W \D i += 2; } + // (1.4) Build the "collapsed" regex replacing categories and subcategories by this "collapsed" lists. const codepoint_categ categ = offset_categ.second; auto it = map_categ_wregex.find(categ.get_subcategory()); assert (it != map_categ_wregex.end()); @@ -934,14 +984,17 @@ std::vector unicode_regex_split(const std::string & text, const std wtext_collapsed.reserve(cpts.size()); for (uint32_t cpt : cpts) { const codepoint_categ categ = unicode_cpt_category(cpt); + // (3.1) Search codepoints in the regex codepoint ranges. auto it = std::lower_bound(regex_expr_ranges.begin(), regex_expr_ranges.end(), cpt, [] (const std::pair range, const uint32_t cpt) { return range.second < cpt; } ); if (it == regex_expr_ranges.end() || cpt < it->first || it->second < cpt) { + // (3.3) If not found, replace with its "collapsed" codepoint so the "collapsed" regex can process it. cpt = category_to_collapsed_cpt(categ); // not found, collapse to category codepoint } + // (3.2) If found, it is a valid codepoint (the "collapsed" regex uses it), literal copy. wtext_collapsed += (wchar_t) cpt; }