diff --git a/src/unicode.cpp b/src/unicode.cpp index dd413c809..68cadf0c4 100644 --- a/src/unicode.cpp +++ b/src/unicode.cpp @@ -636,66 +636,38 @@ uint32_t unicode_tolower(uint32_t cp) { } std::vector unicode_regex_split(const std::string & text, const std::vector & regex_exprs) { - // unicode categories - static const std::map k_ucat_enum = { - { "\\p{N}", codepoint_categ::N }, - { "\\p{L}", codepoint_categ::L }, - { "\\p{P}", codepoint_categ::P }, - }; - - static const std::map k_ucat_cpt = { - { codepoint_categ::N, 0xD1 }, - { codepoint_categ::L, 0xD2 }, - { codepoint_categ::P, 0xD3 }, - }; - - static const std::map k_ucat_map = { - { codepoint_categ::N, "\x30-\x39" }, // 0-9 - { codepoint_categ::L, "\x41-\x5A\x61-\x7A" }, // A-Za-z - { codepoint_categ::P, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\} - }; - - // compute collapsed codepoints only if needed by at least one regex - bool need_collapse = false; - for (auto & regex_expr : regex_exprs) { - // search for unicode categories - for (const auto & ucat : k_ucat_enum) { - if (std::string::npos != regex_expr.find(ucat.first)) { - need_collapse = true; - break; - } - } - } - - const auto cpts = unicode_cpts_from_utf8(text); - + //TODO: update and add more comments // generate a "collapsed" representation of the text, where all codepoints are replaced by a single byte // ref: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2081479935 - std::string text_collapsed; - if (need_collapse) { - // collapse all unicode categories - text_collapsed.resize(cpts.size()); - for (size_t i = 0; i < cpts.size(); ++i) { - // keep single-byte codepoints as is - if (cpts[i] < 128) { - text_collapsed[i] = cpts[i]; - continue; - } - - const auto categ = unicode_cpt_category(cpts[i]); - - if (categ.is_whitespace()) { - //NOTE: C++ std::regex \s does not mach 0x85, Rust and Python regex does. - //text_collapsed[i] = (char) 0x85; // as whitespace fallback - text_collapsed[i] = (char) 0x0B; // as whitespace fallback - } else if (k_ucat_cpt.find(categ.get_category()) != k_ucat_cpt.end()) { - text_collapsed[i] = k_ucat_cpt.at(categ.get_category()); - } else { - text_collapsed[i] = (char) 0xD0; // fallback - } + // 0xDB80 to 0xDBFF: Private Use High Surrogate (128 range values) + static const uint32_t COLLAPSE_CPT_RANGE_FIRST = 0xDB80; + static const uint32_t COLLAPSE_CPT_RANGE_LAST = 0xDBFF; + auto category_to_collapsed_cpt = [] (const codepoint_categ categ) { + const uint16_t subindex = categ.get_subcategory() >> 7; // subcategory stored in 3 bits + switch(categ.get_category()) { // category fits in other 3 bits + case codepoint_categ::UNDEF: return COLLAPSE_CPT_RANGE_FIRST + ((0 << 3) | subindex); + case codepoint_categ::C: return COLLAPSE_CPT_RANGE_FIRST + ((1 << 3) | subindex); + case codepoint_categ::L: return COLLAPSE_CPT_RANGE_FIRST + ((2 << 3) | subindex); + case codepoint_categ::M: return COLLAPSE_CPT_RANGE_FIRST + ((3 << 3) | subindex); + case codepoint_categ::N: return COLLAPSE_CPT_RANGE_FIRST + ((4 << 3) | subindex); + case codepoint_categ::P: return COLLAPSE_CPT_RANGE_FIRST + ((5 << 3) | subindex); + case codepoint_categ::S: return COLLAPSE_CPT_RANGE_FIRST + ((6 << 3) | subindex); + case codepoint_categ::Z: return COLLAPSE_CPT_RANGE_FIRST + ((7 << 3) | subindex); + default: assert (false); return COLLAPSE_CPT_RANGE_FIRST; } - } + }; + auto category_to_collapsed_range = [&] (const codepoint_categ categ) { + // \p{Ll} --> \p{Ll} to \p{Ll} // has subcategory ? yes + // \p{Lu} --> \p{Lu} to \p{Lu} // has subcategory ? yes + // \p{L} --> \p{Ll} to \p{Lu} // has subcategory ? no + assert ((COLLAPSE_CPT_RANGE_FIRST & 0b111) == 0); + const uint32_t collapsed = category_to_collapsed_cpt(categ); + const uint32_t range = (collapsed & 0b111) ? 0 : 0b111; // has subcategory ? + return std::pair(collapsed, collapsed + range); + }; + + const auto cpts = unicode_cpts_from_utf8(text); std::vector bpe_offsets = { cpts.size() }; @@ -708,91 +680,272 @@ std::vector unicode_regex_split(const std::string & text, const std continue; } - // fallback to general-purpose std::regex / std::wregex - try { - // if a unicode category is used in the regex, we use the collapsed text and replace the unicode category - // with the corresponding collapsed representation - bool use_collapsed = false; - for (auto & ucat : k_ucat_enum) { - if (std::string::npos != regex_expr.find(ucat.first)) { - use_collapsed = true; + std::vector> regex_expr_ranges; // start codepoint, last codepoint + std::vector> regex_expr_categs; // offset, codepoint category + std::map map_categ_wregex; // categ --> regex utf32 string + std::wstring wregex_collapsed; + std::wstring wtext_collapsed; + bool inside_square = false; + bool is_cpt_range = false; + + // common ranges: \w \d + regex_expr_ranges.emplace_back('a', 'z'); + regex_expr_ranges.emplace_back('A', 'Z'); + regex_expr_ranges.emplace_back('0', '9'); + regex_expr_ranges.emplace_back('_', '_'); + // common ranges: \s + for (uint32_t cpt : unicode_vec_whitespace) { + const auto categ_prev = unicode_cpt_category(regex_expr_ranges.back().second); + const auto categ_last = unicode_cpt_category(cpt); + if (categ_prev == categ_last && regex_expr_ranges.back().second + 1 == cpt) { + regex_expr_ranges.back().second = cpt; + } else { + regex_expr_ranges.emplace_back(cpt, cpt); + } + } + + // std::wregex \s does not match non-ASCII whitespaces + static const codepoint_categ categ_whitespace(codepoint_categ::MASK + 1); // UNDEF category, subcategory 1 + std::wstring & wregex_whitespaces = map_categ_wregex[categ_whitespace.get_subcategory()]; + wregex_whitespaces += L"\\s"; + for (uint32_t cpt : unicode_vec_whitespace) { + if (cpt >= 0x80) { // non-ASCII whitespaces + if (wregex_whitespaces.back() + 1 == cpt) { + if (*(wregex_whitespaces.end() - 2) == '-') { + wregex_whitespaces.back() = cpt; + } else { + wregex_whitespaces += '-'; + wregex_whitespaces += cpt; + } + } else { + wregex_whitespaces += cpt; + } + } + } + + const auto cpts_regex = unicode_cpts_from_utf8(regex_expr); + + for (size_t i = 0; i < cpts_regex.size(); ++i) { + uint32_t cpt = cpts_regex[i]; + + if (inside_square) { + switch(cpt) { + case '^': + if (cpts_regex[i - 1] != '[') { + break; + } + continue; + case ']': + inside_square = false; + continue; + case '-': + is_cpt_range = true; + continue; + } + } else { + switch(cpt) { + case '^': + if (i > 0) { + break; + } + continue; + case '$': + if (i + 1 < cpts_regex.size()) { + break; + } + continue; + case '[': + inside_square = true; + continue; + case '{': + while (cpt && cpt != '}') { + cpt = cpts_regex[++i]; + } + continue; + case '}': + case ']': + assert (false); + case '(': + if (cpts_regex[i + 1] == '?') { // (?: (?i: (?= (?! (?<= (?= 128) { - throw std::runtime_error("Regex includes both unicode categories and non-ASCII characters - not supported"); - } - } + assert (cpt < COLLAPSE_CPT_RANGE_FIRST || COLLAPSE_CPT_RANGE_LAST < cpt); - // generate a collapsed representation of the regex - std::string regex_expr_collapsed; - - // track if we are inside [], because nested [] are not allowed - bool inside = false; - for (size_t i = 0; i < regex_expr.size(); ++i) { - if (regex_expr[i] == '[' && (i == 0 || regex_expr[i - 1] != '\\')) { - regex_expr_collapsed += '['; - inside = true; - continue; - } - - if (inside && regex_expr[i] == ']' && regex_expr[i - 1] != '\\') { - regex_expr_collapsed += ']'; - inside = false; - continue; - } - - if (regex_expr[i + 0] == '\\' && i + 4 < regex_expr.size() && - regex_expr[i + 1] == 'p' && - regex_expr[i + 2] == '{' && - regex_expr[i + 4] == '}') { - const std::string pat = regex_expr.substr(i, 5); - if (k_ucat_enum.find(pat) != k_ucat_enum.end()) { - if (!inside) { - regex_expr_collapsed += '['; - } - regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat)); - regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat)); - if (!inside) { - regex_expr_collapsed += ']'; - } - i += 4; - continue; - } - } - - regex_expr_collapsed += regex_expr[i]; - } - - //printf("text_collapsed: %s\n", text_collapsed.c_str()); - //printf("regex_expr_collapsed: %s\n", regex_expr_collapsed.c_str()); - bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets); + if (is_cpt_range) { + is_cpt_range = false; + regex_expr_ranges.back().second = cpt; } else { - // no unicode category used, we can use std::wregex directly - const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr); - - // std::wregex \s does not mach non-ASCII whitespaces, using 0x0B as fallback - std::wstring wtext(cpts.begin(), cpts.end()); - for (size_t i = 0; i < wtext.size(); ++i) { - if (wtext[i] > 0x7F && unicode_cpt_category(wtext[i]).is_whitespace()) { - wtext[i] = 0x0B; - } - } - - //printf("text: %s\n", text.c_str()); - //printf("regex_expr: %s\n", regex_expr.c_str()); - bpe_offsets = unicode_regex_split_stl(wtext, wregex_expr, bpe_offsets); + regex_expr_ranges.emplace_back(cpt, cpt); } - } catch (std::regex_error & e) { - fprintf(stderr, "Failed to process regex: '%s'\n", regex_expr.c_str()); - fprintf(stderr, "Regex error: %s\n", e.what()); - throw std::runtime_error("Failed to process regex"); } + + // assign collapsed codepoint to each category regex \p{...} + for (auto offset_categ : regex_expr_categs) { + const uint16_t subcateg = offset_categ.second.get_subcategory(); + auto it = map_categ_wregex.find(subcateg); + if (it == map_categ_wregex.end()) { + const auto collapsed_range = category_to_collapsed_range(offset_categ.second); + map_categ_wregex[subcateg] = (wchar_t) collapsed_range.first; + if (collapsed_range.first < collapsed_range.second) { + map_categ_wregex[subcateg] += (wchar_t) '-'; + map_categ_wregex[subcateg] += (wchar_t) collapsed_range.second; + } + } + } + + // copy found regex ranges to each category regex + uint32_t regex_expr_ranges_uniques = 0; + std::pair prev_range = {0, -1}; + std::sort(regex_expr_ranges.begin(), regex_expr_ranges.end()); + for (auto range : regex_expr_ranges) { + range.first = std::max(range.first, prev_range.second + 1); // prevent overlapping //TODO: as error? + if (range.first > range.second) { // skip overlapping and repetitions + continue; + } + codepoint_categ categ = unicode_cpt_category(range.first); + assert (categ == unicode_cpt_category(range.second)); + auto it0 = map_categ_wregex.find(categ.get_category()); + auto it1 = map_categ_wregex.find(categ.get_subcategory()); + for (const auto & it : {it0, it1}) { + if (it != map_categ_wregex.end()) { + it->second += (wchar_t) range.first; + if (range.first < range.second) { + it->second += (wchar_t) '-'; + it->second += (wchar_t) range.second; + } + } + } + prev_range = range; + regex_expr_ranges[regex_expr_ranges_uniques++] = range; + } + regex_expr_ranges.resize(regex_expr_ranges_uniques); + + // replace categories with respective collapsed codepoint and ranges + uint32_t i = 0; + wregex_collapsed.reserve(regex_expr.size()); + for (auto offset_categ : regex_expr_categs) { + while (i < offset_categ.first) { // copy original regex until reaching the category + wregex_collapsed += (wchar_t) cpts_regex[i]; + i++; + } + assert (cpts_regex[i] == '\\'); + const uint32_t cpt_next = cpts_regex[i + 1]; + const bool is_negated = cpt_next < 'a'; // is uppercase + if (cpt_next == 'p' || cpt_next == 'P') { + assert (cpts_regex[i + 2] == '{' && cpts_regex[i + 3]); + i += cpts_regex[i + 4] == '}' ? 5 : 6; + assert (cpts_regex[i - 1] == '}'); + } else { + assert (cpt_next == 's' || cpt_next == 'w' || cpt_next == 'd' || // \s \w \d + cpt_next == 'S' || cpt_next == 'W' || cpt_next == 'D'); // \S \W \D + i += 2; + } + const codepoint_categ categ = offset_categ.second; + auto it = map_categ_wregex.find(categ.get_subcategory()); + assert (it != map_categ_wregex.end()); + if (it != map_categ_wregex.end()) { + if (categ.is_whitespace()) { // inside square brackets //NOTE: reusing flag WHITESPACE + assert (is_negated == false); + wregex_collapsed += it->second; + } else if(it->second.size() == 1 && !is_negated) { + wregex_collapsed += it->second; + } else { + wregex_collapsed += '['; + if (is_negated) { + wregex_collapsed += '^'; + } + wregex_collapsed += it->second; + wregex_collapsed += ']'; + } + } + } + while (i < (uint32_t)cpts_regex.size()) { + wregex_collapsed += cpts_regex[i]; + i++; + } + + // collapse text codepoints not included in 'regex_expr_ranges' + wtext_collapsed.reserve(cpts.size()); + for (uint32_t cpt : cpts) { + const codepoint_categ categ = unicode_cpt_category(cpt); + auto it = std::lower_bound(regex_expr_ranges.begin(), regex_expr_ranges.end(), cpt, + [] (const std::pair range, const uint32_t cpt) { + return range.second < cpt; + } + ); + if (it == regex_expr_ranges.end() || cpt < it->first || it->second < cpt) { + cpt = category_to_collapsed_cpt(categ); // not found, collapse to category codepoint + } + wtext_collapsed += (wchar_t) cpt; + } + + bpe_offsets = unicode_regex_split_stl(wtext_collapsed, wregex_collapsed, bpe_offsets); } std::vector bpe_words;