diff --git a/src/unicode.cpp b/src/unicode.cpp index 20c1287c4..988dd35e4 100644 --- a/src/unicode.cpp +++ b/src/unicode.cpp @@ -451,66 +451,6 @@ static std::vector unicode_regex_split_custom_llama3(const std::string & return bpe_offsets; } -// use std::wregex to split the text -static std::vector unicode_regex_split_stl(const std::wstring & wtext, const std::wstring & regex_expr, const std::vector & offsets) { - std::wregex expr(regex_expr); - std::vector bpe_offsets; // store the offset of each word - bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size - size_t start = 0; - for (auto offset : offsets) { - std::wcregex_iterator it(wtext.data() + start, wtext.data() + start + offset, expr); - std::wcregex_iterator end; - - int64_t start_idx = 0; - while (it != end) { - std::wcmatch match = *it; - if (match.position() > start_idx) { - bpe_offsets.emplace_back(match.position() - start_idx); - } - bpe_offsets.emplace_back(match.length()); - start_idx = match.position() + match.length(); - ++it; - } - - if (start_idx < (int64_t) offset) { - bpe_offsets.emplace_back(offset - start_idx); - } - start += offset; - } - - return bpe_offsets; -} - -// use std::regex to split the text -static std::vector unicode_regex_split_stl(const std::string & text, const std::string & regex_expr, const std::vector & offsets) { - std::regex expr(regex_expr); - std::vector bpe_offsets; // store the offset of each word - bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size - size_t start = 0; - for (auto offset : offsets) { - std::cregex_iterator it(text.data() + start, text.data() + start + offset, expr); - std::cregex_iterator end; - - int64_t start_idx = 0; - while (it != end) { - std::cmatch match = *it; - if (match.position() > start_idx) { - bpe_offsets.emplace_back(match.position() - start_idx); - } - bpe_offsets.emplace_back(match.length()); - start_idx = match.position() + match.length(); - ++it; - } - - if (start_idx < 
(int64_t) offset) { - bpe_offsets.emplace_back(offset - start_idx); - } - start += offset; - } - - return bpe_offsets; -} - static std::vector unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector & offsets) { std::vector bpe_offsets; @@ -526,6 +466,261 @@ static std::vector unicode_regex_split_custom(const std::string & text, return bpe_offsets; } +// Custom std::regex specializations for 32bit unicode codepoints +// std::wregex does not support unicode categories: \p{N}, \p{L}, \p{Lu}, \p{Ll} ... +// std::wregex does not support unicode whitespaces \s: 0x85, 0xA0, 0x001680 ... 0x003000. +// std::wregex supports full 32 bit codepoints, not limited to standard max 0x110000. +namespace std { + using codepoint = uint32_t; // codepoint type for all template specializations + + // Minimal required implementation for std::regex string processing + template<> // custom specialized std::ctype + class ctype { + public: + + using CharT = codepoint; + using char_type = CharT; + + using mask = uint8_t; //NOTE: see std::ctype_base + static const mask digit = 1; // required variable names + static const mask xdigit = 2; // user defined values + static const mask alpha = 3; // used to be a bitmask + static const mask upper = 4; // we do not need a bitmask + static const mask lower = 5; // using a sequence instead + + static locale::id id; // required by std::locale::facet + + bool is(mask m, char_type c) const { + switch (m) { + case digit: return ('0' <= c && c <= '9'); + case xdigit: return ('0' <= c && c <= '9') || ('A' <= c && c <= 'F'); + case alpha: return ('A' <= c && c <= 'Z') || ('a' <= c && c <= 'z'); + case upper: return ('A' <= c && c <= 'Z'); + case lower: return ('a' <= c && c <= 'z'); + default: return false; + } + } + + char_type toupper(char_type c) const { + return ('a' <= c && c <= 'z') ? c - ('a' - 'A') : c; + } + + char_type tolower(char_type c) const { + return ('A' <= c && c <= 'Z') ? 
c + ('a' - 'A') : c; + } + + char_type widen(char c) const { // char to codepoint + return (char_type) c; + } + + char narrow(char_type c, char dfault) const { // codepoint to char + return (c < 0x80 ? (char)c : dfault); + } + }; + + locale::id ctype::id = {}; + + template<> // specialization to use our custom specialized std::ctype + const std::ctype & use_facet>(const std::locale &) { + static std::ctype ctype_uint32 = {}; + return ctype_uint32; + } + + template<> // specialization to use our custom specialized std::ctype + const std::ctype & use_facet>(const std::locale & loc) { + return use_facet>(loc); + } + + // Minimal required implementation for std::regex string processing + template<> // custom specialized std::regex_traits + class regex_traits { + public: + + using CharT = codepoint; + using char_type = codepoint; + using size_type = size_t; + using string_type = std::basic_string; + using locale_type = std::locale; + using char_class_type = uint64_t; + + #if (defined(_WIN32) || defined(_WIN64)) // MSVC class _Regex_traits + using _Uelem = CharT; + static const auto _Ch_upper = std::ctype::upper; + static const auto _Ch_alpha = std::ctype::alpha; + #endif + + static size_type length(const CharT * str) { + return std::char_traits::length(str); + } + + CharT translate(CharT c) const { + return c; + } + + CharT translate_nocase(CharT c) const { + return unicode_tolower(c); + } + + template + string_type transform(It first, It last) const { + GGML_ASSERT(false); //TODO: not needed ? 
+ return {first, last}; //TODO: not tested + } + + template + string_type transform_primary(It first, It last) const { + (void) first; + (void) last; + GGML_ASSERT(*first < MAX_CODEPOINTS); // valid codepoint + return {}; + } + + template + string_type lookup_collatename(It first, It last) const { + (void) last; + GGML_ASSERT(*first & (1 << 31)); + return {*first}; + } + + template + char_class_type lookup_classname(It first, It last, bool icase = false) const { + (void) last; + (void) icase; + const uint32_t encoded = *first; + codepoint_categ categ = {}; + switch(encoded) { + case 's': + case 'S': // negation is internally tracked + categ.set_flag(codepoint_categ::WHITESPACES); + return categ.expand_bits(); + case 'w': + case 'W': // negation is internally tracked + categ.set_flag(codepoint_categ::WORDS); + return categ.expand_bits(); + case 'd': + case 'D': // negation is internally tracked + categ.set_flag(codepoint_categ::DIGITS); + return categ.expand_bits(); + default: { // unicode category \p{Xx} encoded in codepoint + GGML_ASSERT(encoded & (1 << 31)); // make sure its our custom codepoint encoding the category + const bool negated = encoded & (1 << 30); // negation of 'character class expression' are not internally tracked + categ = {(uint16_t) encoded}; + return ((uint64_t) negated << 63) | categ.expand_bits(false); + } + } + } + + bool isctype(CharT c, char_class_type mask) const { + const bool negated = mask & (1llu << 63); + mask &= unicode_cpt_category(c).expand_bits(); + return negated ^ (bool) mask; + } + + int value(CharT c, int radix) const { // char to int value + switch (radix) { + case 8: return ('0' <= c && c <= '7') ? (int)c - '0' : -1; + case 10: return ('0' <= c && c <= '9') ? (int)c - '0' : -1; + case 16: return ('0' <= c && c <= '9') ? (int)c - '0' : (('A' <= c && c <= 'F') ? 
(int)c - 'A' + 10 : -1); + default: return -1; + } + } + + const locale_type & imbue(const locale_type &) { // set locale //NOTE: ignoring locales + return std::locale::classic(); + } + + const locale_type & getloc() const { // get locale //NOTE: ignoring locales + return std::locale::classic(); + } + }; +} + +static std::vector unicode_regex_prepare(const std::string & regex) { + std::vector regex_cpts; + regex_cpts.reserve(regex.size() * 12 / 10); // estimate +20% + + size_t offset = 0; + int inside_square = 0; + bool any_positive = false; + bool any_negative = false; + + const size_t size = regex.size(); + while (offset < size) { + inside_square += regex[offset] == '['; + inside_square -= regex[offset] == ']'; + GGML_ASSERT(inside_square >= 0); + if (!inside_square) { + any_positive = false; + any_negative = false; + } + + if (regex[offset] == '\\') { + const size_t i = offset + 1; + if (regex[i] == 'p' || regex[i] == 'P') { + // convert \p{Xx} to custom 'character class expression' [:Xy:] + if (regex[i + 1] == '{' && regex[i + 2] && regex[i + 3]) { + codepoint_categ categ = {}; + if (regex[i + 3] == '}') { + categ = codepoint_categ::from_chars(regex[i + 2]); + offset += 5; + } else if (regex[i + 3] != '}' && regex[i + 4] == '}') { + categ = codepoint_categ::from_chars(regex[i + 2], regex[i + 3]); + offset += 6; + } + bool negated = regex[i] == 'P'; + any_positive |= !negated; + any_negative |= negated; + GGML_ASSERT(any_positive != any_negative); //BUG: can not mix 'p' and 'P' inside [] + GGML_ASSERT(sizeof(categ) <= 2); + // encoded category in 32 bits codepoint + uint32_t cpt_categ = (1 << 31) | (negated << 30) | categ.encoded; + if (inside_square) { + regex_cpts.insert(regex_cpts.end(), {'[', ':', cpt_categ, ':', ']'}); + } else { + regex_cpts.insert(regex_cpts.end(), {'[', '[', ':', cpt_categ, ':', ']', ']'}); + } + continue; + } + } + } + + regex_cpts.push_back(unicode_cpt_from_utf8(regex, offset)); + } + + return regex_cpts; +} + +// use std::basic_regex 
to split the text codepoints +static std::vector unicode_regex_split_stl(const std::vector & text_cpts, const std::vector & regex_cpts, const std::vector & offsets) { + using regex_type = std::basic_regex; + using iter_type = std::regex_iterator; + regex_type regex(regex_cpts.begin(), regex_cpts.end()); + const iter_type end; + + std::vector bpe_offsets; // store the offset of each word + bpe_offsets.reserve(offsets.size()); // reserve memory for the approximate size + const uint32_t * text_data = text_cpts.data(); + for (auto offset : offsets) { + iter_type it(text_data, text_data + offset, regex); + int64_t start_idx = 0; + while (it != end) { + if (it->position() > start_idx) { + bpe_offsets.emplace_back(it->position() - start_idx); + } + bpe_offsets.emplace_back(it->length()); + start_idx = it->position() + it->length(); + ++it; + } + + if (start_idx < (int64_t) offset) { + bpe_offsets.emplace_back(offset - start_idx); + } + text_data += offset; + } + + return bpe_offsets; +} + // // interface // @@ -639,288 +834,21 @@ uint32_t unicode_tolower(uint32_t cp) { return it == unicode_map_lowercase.end() ? cp : it->second; } -std::vector unicode_regex_split(const std::string & text, const std::vector & regex_exprs) { - // std::wregex does not support unicode categories: \p{N}, \p{L}, \p{Lu}, \p{Ll} ... - // std::wregex does not support unicode whitespaces \s: 0x85, 0xA0, 0x001680 ... 0x003000. - // std::wregex allows full wchar_t 32 bit codepoints, not limited to standard max 0x110000. - // The main idea is to insert unicode category bits into all regex and text codepoints. - // Max unicode codepoint 0x110000 fits in 21 bits. - // Store unicode category and subcategory in 10 bits. - // Set the high bit to zero to keep wchar_t positive (uint32_t codepoints). 
- // Categorized codepoint: - // 1 bit zero + 7 bits category + 3 bits subcategory index + 21 bits codepoint - // 0b0'XXXXXXX'xxx'ccccccccccccccccccccc - // A "categorized codepoint" re-defines the ordering keeping category hierarchy. - // All high category codepoints \p{X} fall into the range: - // 0b0'XXXXXXX'000'000000000000000000000 - // 0b0'XXXXXXX'111'111111111111111111111 - // All subcategory codepoints \p{Xx} fall into the range: - // 0b0'XXXXXXX'xxx'000000000000000000000 - // 0b0'XXXXXXX'xxx'111111111111111111111 - // Processing steps: - // Build a lists of "categorized codepoints/ranges" for replacing regex \s \w and \d. - // Replace all regex codepoints/ranges with respective "categorized codepoints/ranges". - // Replace all text codepoints with respective "categorized codepoints". - // Caveats: - // Some regex ranges starts and ends with different category/subcategory. - // Split the ranges in sub-ranges to ensure a single category to maintain the new hierarchy. - // This forces iterating all ranges and could produce long sub-range sequences. - - //TODO: Regex processing can be cached. 
- - // insert unicode category and subcategory before codepoint bits - // 1 bit zero + 7 bits category + 3 bits subcategory index + 21 bits zero - static const auto categorized_prefix = [] (const codepoint_categ categ) -> wchar_t { - static const uint32_t MASK = codepoint_categ::MASK; // category mask - static const uint32_t SUBMASK = codepoint_categ::SUBMASK & ~codepoint_categ::MASK; // subcategory mask - return (wchar_t) (((categ.encoded & MASK) << (21+3)) | ((categ.encoded & SUBMASK) << (21-7))); - }; - - // insert unicode category and subcategory before codepoint bits - // 1 bit zero + 7 bits category + 3 bits subcategory index + 21 bits codepoint - static const auto categorize_codepoint = [] (const uint32_t cpt) -> wchar_t { - GGML_ASSERT(cpt < (1 << 21)); - return categorized_prefix(unicode_cpt_category(cpt)) | (wchar_t)cpt; - }; - - // remove the categorized prefix bits and restore original codepoint bits - static const auto decategorize_codepoint = [] (const wchar_t cpt) -> uint32_t { - return (uint32_t) cpt & ((1 << 21) - 1); - }; - - // returns the respective categorized codepoint range of the category/subcategory - static const auto categorize_range_from_chars = [] (const char categ, const char subcateg) { - const wchar_t range_ini = categorized_prefix(codepoint_categ::from_chars(categ, subcateg)); - const wchar_t range_end = (wchar_t) (range_ini | (subcateg ? 
(1<<21)-1 : (1<<24)-1)); - return std::pair(range_ini, range_end); - }; - - // helper function to append/concat regex expressions - auto wregex_append_subregex = [] (std::wstring & wregex, const std::wstring & subregex, const bool add_squares, const bool negated) { - if (add_squares) { - wregex += '['; - if (negated) { - wregex += '^'; - } - wregex += subregex; - wregex += ']'; - } else { - GGML_ASSERT(!negated); //TODO: negation inside square brackets: \S \W \D - wregex += subregex; - } - }; - - // \d digits replacement - static const std::wstring wregex_digits = { - categorize_codepoint('0'), '-', categorize_codepoint('9'), - }; - - // \w words replacement - static const std::wstring wregex_words = { - categorize_codepoint('_'), - categorize_codepoint('0'), '-', categorize_codepoint('9'), - categorize_codepoint('A'), '-', categorize_codepoint('Z'), - categorize_codepoint('a'), '-', categorize_codepoint('z'), - }; - - // \s whitespaces replacement - static const std::wstring wregex_whitespaces = [] { - std::wstring wregex_whitespaces; - for (const auto & range : unicode_ranges_whitespace) { - wregex_whitespaces += categorize_codepoint(range.first); - if (range.second > range.first) { - wregex_whitespaces += '-'; - wregex_whitespaces += categorize_codepoint(range.second); - } - } - return wregex_whitespaces; - }(); - - GGML_ASSERT(sizeof(wchar_t) == sizeof(uint32_t)); - std::wstring wtext = unicode_wstring_from_utf8(text); - - std::vector offsets = { wtext.size() }; +std::vector unicode_regex_split(const std::string & text_utf8, const std::vector & regex_exprs) { + const std::vector cpts = unicode_cpts_from_utf8(text_utf8); + std::vector offsets = { cpts.size() }; for (auto & regex_expr : regex_exprs) { // first, see if we have an efficient custom regex implementation - auto tmp = unicode_regex_split_custom(text, regex_expr, offsets); + auto tmp = unicode_regex_split_custom(text_utf8, regex_expr, offsets); if (!tmp.empty()) { offsets = std::move(tmp); continue; } - 
std::wstring wregex; - bool inside_square = false; - bool is_cpt_range = false; - - const auto cpts_regex = unicode_cpts_from_utf8(regex_expr); - wregex.reserve(2 * cpts_regex.size()); - - for (size_t i = 0; i < cpts_regex.size(); ++i) { - uint32_t cpt = cpts_regex[i]; - - // parse regex metacharacters - wregex += (wchar_t) cpt; - if (inside_square) { - switch(cpt) { - case '^': - if (cpts_regex[i - 1] != '[') { - break; - } - continue; - case ']': - inside_square = false; - continue; - case '-': - is_cpt_range = true; - continue; - } - } else { - switch(cpt) { - case '^': - if (i > 0) { - break; - } - continue; - case '$': - if (i + 1 < cpts_regex.size()) { - break; - } - continue; - case '[': - inside_square = true; - continue; - case '{': - while (cpt && cpt != '}') { - cpt = cpts_regex[++i]; - wregex += (wchar_t) cpt; - } - continue; - case '}': - case ']': - GGML_ABORT("invalid regex"); - case '(': - if (cpts_regex[i + 1] == '?') { // (?: (?i: (?= (?! (?<= (? range; - if (cpts_regex[i + 4] == '}') { - range = categorize_range_from_chars((char)cpts_regex[i + 3], (char)'\0'); - i += 4; - } else { - range = categorize_range_from_chars((char)cpts_regex[i + 3], (char)cpts_regex[i + 4]); - i += 5; - } - GGML_ASSERT(cpts_regex[i] == '}'); - const std::wstring subregex = {range.first, '-', range.second}; - wregex_append_subregex(wregex, subregex, !inside_square, false); - continue; - } - - // parse more metcharacters and espaped characters - if (cpt == '\\') { - switch (cpts_regex[i + 1]) { - case 's': // \s whitespaces - case 'S': // \S no whitespaces - wregex_append_subregex(wregex, wregex_whitespaces, !inside_square, cpts_regex[++i] == 'S'); - continue; - case 'w': // \w words - case 'W': // \W no words - wregex_append_subregex(wregex, wregex_words, !inside_square, cpts_regex[++i] == 'W'); - continue; - case 'd': // \d digits - case 'D': // \D no digits - wregex_append_subregex(wregex, wregex_digits, !inside_square, cpts_regex[++i] == 'D'); - continue; - case 't': 
++i; cpt = '\t'; break; - case 'r': ++i; cpt = '\r'; break; - case 'n': ++i; cpt = '\n'; break; - case 'x': GGML_ABORT("TODO"); //TODO: hex values - case 'u': GGML_ABORT("TODO"); //TODO: unicode values - case 'U': GGML_ABORT("TODO"); //TODO: unicode values - default: // escaped character - GGML_ASSERT(!is_cpt_range); - cpt = cpts_regex[++i]; - GGML_ASSERT(cpt < 0x80); - break; - } - } - - if (is_cpt_range) { - // Some regex ranges starts and ends with different category/subcategory. - // Split the ranges in sub-ranges to ensure a single category to maintain the new hierarchy. - // Warning: This forces iterating all ranges and could produce long sub-range sequences. - GGML_ASSERT(wregex.size() && wregex.back() == '-'); - wregex.pop_back(); - wchar_t categorized = wregex.back(); - uint32_t range_ini = decategorize_codepoint(categorized); - const uint32_t range_end = cpt; - GGML_ASSERT(range_ini <= range_end); - codepoint_categ range_categ = unicode_cpt_category(range_ini); - for (cpt = range_ini + 1; cpt <= range_end; ++cpt) { - codepoint_categ categ = unicode_cpt_category(cpt); - if (categ == range_categ) { // still same range category ? 
- ++categorized; - if (cpt == range_ini + 1) { // single step, no need range - wregex += categorized; - } else if (cpt == range_ini + 2) { // need range if +2 step - wregex.back() = '-'; - wregex += categorized; - } else { - wregex.back() = categorized; // keep range growing - } - } else { // new range category - categorized = categorize_codepoint(cpt); - wregex += categorized; - range_categ = categ; - range_ini = cpt; - } - } - is_cpt_range = false; - } else { - wregex += categorize_codepoint(cpt); - } - } - - // categorize all wtext codepoints - if (wtext.size() && wtext[0] < MAX_CODEPOINTS) { // if not already categorized - for (size_t i = 0; i < wtext.size(); ++i) { - wtext[i] = categorize_codepoint((uint32_t) wtext[i]); - } - } - - offsets = unicode_regex_split_stl(wtext, wregex, offsets); + const auto regex_cpts = unicode_regex_prepare(regex_expr); + offsets = unicode_regex_split_stl(cpts, regex_cpts, offsets); } std::vector bpe_words; @@ -930,8 +858,7 @@ std::vector unicode_regex_split(const std::string & text, const std for (size_t & offset : offsets) { bpe_words.emplace_back(); for (size_t i = start; i < start + offset; ++i) { - const uint32_t cpt = decategorize_codepoint(wtext[i]); - bpe_words.back() += unicode_cpt_to_utf8(cpt); + bpe_words.back() += unicode_cpt_to_utf8(cpts[i]); } start += offset; } diff --git a/src/unicode.h b/src/unicode.h index 3aeb74771..f2c3e7147 100644 --- a/src/unicode.h +++ b/src/unicode.h @@ -113,6 +113,23 @@ struct codepoint_categ { inline bool is_Zp() const { return (encoded & MASK) == Zp; } inline bool is_Zs() const { return (encoded & MASK) == Zs; } + inline uint64_t expand_bits(const bool add_categ=true) const { // one bit for each category/subcategory and flags + const uint32_t subindex = encoded & SUBMASK; + const uint64_t bits = (encoded & MASK) >> 3; + const uint64_t flags = encoded >> 10; + return (flags << (7 * 8)) | (bits << (7 * subindex)) | (bits * add_categ); + } + + inline bool is_in_range(const codepoint_categ 
other) const { // this.first <= other <= this.last + if (encoded & SUBMASK) { + return encoded == other.encoded; // no range + } + if (encoded & MASK) { + return encoded == (other.encoded & ~SUBMASK); // from 0bffffff'ccccccc'000 to 0bffffff'ccccccc'111 + } + return encoded == (other.encoded & ~MASK); // from 0bffffff'0000000'000 to 0bffffff'1111111'111 + } + inline bool operator == (const codepoint_categ other) const { return encoded == other.encoded; }