Reimplement unicode_regex_split():
- Using std::basic_regex. - Custom std::ctype specialization for 32bits codepoints. - Custom std::regex_traits specialization for 32bits codepoints. - Implementing custom 'character class expression' for \p{Xx}. - Single pass regex preparation.
This commit is contained in:
parent
b565148cb4
commit
5a93d2ec50
2 changed files with 279 additions and 335 deletions
597
src/unicode.cpp
597
src/unicode.cpp
|
@ -451,66 +451,6 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
|
|||
return bpe_offsets;
|
||||
}
|
||||
|
||||
// use std::wregex to split the text
|
||||
static std::vector<size_t> unicode_regex_split_stl(const std::wstring & wtext, const std::wstring & regex_expr, const std::vector<size_t> & offsets) {
|
||||
std::wregex expr(regex_expr);
|
||||
std::vector<size_t> bpe_offsets; // store the offset of each word
|
||||
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
|
||||
size_t start = 0;
|
||||
for (auto offset : offsets) {
|
||||
std::wcregex_iterator it(wtext.data() + start, wtext.data() + start + offset, expr);
|
||||
std::wcregex_iterator end;
|
||||
|
||||
int64_t start_idx = 0;
|
||||
while (it != end) {
|
||||
std::wcmatch match = *it;
|
||||
if (match.position() > start_idx) {
|
||||
bpe_offsets.emplace_back(match.position() - start_idx);
|
||||
}
|
||||
bpe_offsets.emplace_back(match.length());
|
||||
start_idx = match.position() + match.length();
|
||||
++it;
|
||||
}
|
||||
|
||||
if (start_idx < (int64_t) offset) {
|
||||
bpe_offsets.emplace_back(offset - start_idx);
|
||||
}
|
||||
start += offset;
|
||||
}
|
||||
|
||||
return bpe_offsets;
|
||||
}
|
||||
|
||||
// use std::regex to split the text
|
||||
static std::vector<size_t> unicode_regex_split_stl(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
|
||||
std::regex expr(regex_expr);
|
||||
std::vector<size_t> bpe_offsets; // store the offset of each word
|
||||
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
|
||||
size_t start = 0;
|
||||
for (auto offset : offsets) {
|
||||
std::cregex_iterator it(text.data() + start, text.data() + start + offset, expr);
|
||||
std::cregex_iterator end;
|
||||
|
||||
int64_t start_idx = 0;
|
||||
while (it != end) {
|
||||
std::cmatch match = *it;
|
||||
if (match.position() > start_idx) {
|
||||
bpe_offsets.emplace_back(match.position() - start_idx);
|
||||
}
|
||||
bpe_offsets.emplace_back(match.length());
|
||||
start_idx = match.position() + match.length();
|
||||
++it;
|
||||
}
|
||||
|
||||
if (start_idx < (int64_t) offset) {
|
||||
bpe_offsets.emplace_back(offset - start_idx);
|
||||
}
|
||||
start += offset;
|
||||
}
|
||||
|
||||
return bpe_offsets;
|
||||
}
|
||||
|
||||
static std::vector<size_t> unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
|
||||
std::vector<size_t> bpe_offsets;
|
||||
|
||||
|
@ -526,6 +466,261 @@ static std::vector<size_t> unicode_regex_split_custom(const std::string & text,
|
|||
return bpe_offsets;
|
||||
}
|
||||
|
||||
// Custom std::regex specializations for 32bit unicode codepoints
|
||||
// std::wregex does not support unicode categories: \p{N}, \p{L}, \p{Lu}, \p{Ll} ...
|
||||
// std::wregex does not support unicode whitespaces \s: 0x85, 0xA0, 0x001680 ... 0x003000.
|
||||
// std::wregex supports full 32 bit codepoints, not limited to standard max 0x110000.
|
||||
namespace std {
|
||||
using codepoint = uint32_t; // codepoint type for all template specializations
|
||||
|
||||
// Minimal required implementation for std::regex string processing
|
||||
template<> // custom specialized std::ctype<codepoint>
|
||||
class ctype<codepoint> {
|
||||
public:
|
||||
|
||||
using CharT = codepoint;
|
||||
using char_type = CharT;
|
||||
|
||||
using mask = uint8_t; //NOTE: see std::ctype_base
|
||||
static const mask digit = 1; // requiered variable names
|
||||
static const mask xdigit = 2; // user defined values
|
||||
static const mask alpha = 3; // used to be a bitmask
|
||||
static const mask upper = 4; // we do not need a bitmask
|
||||
static const mask lower = 5; // using a sequence instead
|
||||
|
||||
static locale::id id; // required by std::locale::facet
|
||||
|
||||
bool is(mask m, char_type c) const {
|
||||
switch (m) {
|
||||
case digit: return ('0' <= c && c <= '9');
|
||||
case xdigit: return ('0' <= c && c <= '9') || ('A' <= c && c <= 'F');
|
||||
case alpha: return ('A' <= c && c <= 'Z') || ('a' <= c && c <= 'z');
|
||||
case upper: return ('A' <= c && c <= 'Z');
|
||||
case lower: return ('a' <= c && c <= 'z');
|
||||
default: return false;
|
||||
}
|
||||
}
|
||||
|
||||
char_type toupper(char_type c) const {
|
||||
return ('a' <= c && c <= 'z') ? c - ('a' - 'A') : c;
|
||||
}
|
||||
|
||||
char_type tolower(char_type c) const {
|
||||
return ('A' <= c && c <= 'Z') ? c + ('a' - 'A') : c;
|
||||
}
|
||||
|
||||
char_type widen(char c) const { // char to codepoint
|
||||
return (char_type) c;
|
||||
}
|
||||
|
||||
char narrow(char_type c, char dfault) const { // codepoint to char
|
||||
return (c < 0x80 ? (char)c : dfault);
|
||||
}
|
||||
};
|
||||
|
||||
locale::id ctype<codepoint>::id = {};
|
||||
|
||||
template<> // specialization to use our custom specialized std::ctype<codepoint>
|
||||
const std::ctype<codepoint> & use_facet<std::ctype<codepoint>>(const std::locale &) {
|
||||
static std::ctype<codepoint> ctype_uint32 = {};
|
||||
return ctype_uint32;
|
||||
}
|
||||
|
||||
template<> // specialization to use our custom specialized std::ctype<codepoint>
|
||||
const std::ctype<codepoint> & use_facet<const std::ctype<codepoint>>(const std::locale & loc) {
|
||||
return use_facet<std::ctype<codepoint>>(loc);
|
||||
}
|
||||
|
||||
// Minimal required implementation for std::regex string processing
|
||||
template<> // custom specialized std::regex_traits<codepoint>
|
||||
class regex_traits<codepoint> {
|
||||
public:
|
||||
|
||||
using CharT = codepoint;
|
||||
using char_type = codepoint;
|
||||
using size_type = size_t;
|
||||
using string_type = std::basic_string<CharT>;
|
||||
using locale_type = std::locale;
|
||||
using char_class_type = uint64_t;
|
||||
|
||||
#if (defined(_WIN32) || defined(_WIN64)) // MSVC class _Regex_traits
|
||||
using _Uelem = CharT;
|
||||
static const auto _Ch_upper = std::ctype<CharT>::upper;
|
||||
static const auto _Ch_alpha = std::ctype<CharT>::alpha;
|
||||
#endif
|
||||
|
||||
static size_type length(const CharT * str) {
|
||||
return std::char_traits<CharT>::length(str);
|
||||
}
|
||||
|
||||
CharT translate(CharT c) const {
|
||||
return c;
|
||||
}
|
||||
|
||||
CharT translate_nocase(CharT c) const {
|
||||
return unicode_tolower(c);
|
||||
}
|
||||
|
||||
template<typename It>
|
||||
string_type transform(It first, It last) const {
|
||||
GGML_ASSERT(false); //TODO: not needed ?
|
||||
return {first, last}; //TODO: not tested
|
||||
}
|
||||
|
||||
template<typename It>
|
||||
string_type transform_primary(It first, It last) const {
|
||||
(void) first;
|
||||
(void) last;
|
||||
GGML_ASSERT(*first < MAX_CODEPOINTS); // valid codepoint
|
||||
return {};
|
||||
}
|
||||
|
||||
template<typename It>
|
||||
string_type lookup_collatename(It first, It last) const {
|
||||
(void) last;
|
||||
GGML_ASSERT(*first & (1 << 31));
|
||||
return {*first};
|
||||
}
|
||||
|
||||
template<typename It>
|
||||
char_class_type lookup_classname(It first, It last, bool icase = false) const {
|
||||
(void) last;
|
||||
(void) icase;
|
||||
const uint32_t encoded = *first;
|
||||
codepoint_categ categ = {};
|
||||
switch(encoded) {
|
||||
case 's':
|
||||
case 'S': // negation is internally tracked
|
||||
categ.set_flag(codepoint_categ::WHITESPACES);
|
||||
return categ.expand_bits();
|
||||
case 'w':
|
||||
case 'W': // negation is internally tracked
|
||||
categ.set_flag(codepoint_categ::WORDS);
|
||||
return categ.expand_bits();
|
||||
case 'd':
|
||||
case 'D': // negation is internally tracked
|
||||
categ.set_flag(codepoint_categ::DIGITS);
|
||||
return categ.expand_bits();
|
||||
default: { // unicode category \p{Xx} encoded in codepoint
|
||||
GGML_ASSERT(encoded & (1 << 31)); // make sure its our custom codepoint encoding the category
|
||||
const bool negated = encoded & (1 << 30); // negation of 'character class expression' are not internally tracked
|
||||
categ = {(uint16_t) encoded};
|
||||
return ((uint64_t) negated << 63) | categ.expand_bits(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool isctype(CharT c, char_class_type mask) const {
|
||||
const bool negated = mask & (1llu << 63);
|
||||
mask &= unicode_cpt_category(c).expand_bits();
|
||||
return negated ^ (bool) mask;
|
||||
}
|
||||
|
||||
int value(CharT c, int radix) const { // char to int value
|
||||
switch (radix) {
|
||||
case 8: return ('0' <= c && c <= '7') ? (int)c - '0' : -1;
|
||||
case 10: return ('0' <= c && c <= '9') ? (int)c - '0' : -1;
|
||||
case 16: return ('0' <= c && c <= '9') ? (int)c - '0' : (('A' <= c && c <= 'F') ? (int)c - 'A' + 10 : -1);
|
||||
default: return -1;
|
||||
}
|
||||
}
|
||||
|
||||
const locale_type & imbue(const locale_type &) { // set locale //NOTE: ignoring locales
|
||||
return std::locale::classic();
|
||||
}
|
||||
|
||||
const locale_type & getloc() const { // get locale //NOTE: ignoring locales
|
||||
return std::locale::classic();
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
static std::vector<uint32_t> unicode_regex_prepare(const std::string & regex) {
|
||||
std::vector<uint32_t> regex_cpts;
|
||||
regex_cpts.reserve(regex.size() * 12 / 10); // estimate +20%
|
||||
|
||||
size_t offset = 0;
|
||||
int inside_square = 0;
|
||||
bool any_positive = false;
|
||||
bool any_negative = false;
|
||||
|
||||
const size_t size = regex.size();
|
||||
while (offset < size) {
|
||||
inside_square += regex[offset] == '[';
|
||||
inside_square -= regex[offset] == ']';
|
||||
GGML_ASSERT(inside_square >= 0);
|
||||
if (!inside_square) {
|
||||
any_positive = false;
|
||||
any_negative = false;
|
||||
}
|
||||
|
||||
if (regex[offset] == '\\') {
|
||||
const size_t i = offset + 1;
|
||||
if (regex[i] == 'p' || regex[i] == 'P') {
|
||||
// convert \p{Xx} to custom 'character class expression' [:Xy:]
|
||||
if (regex[i + 1] == '{' && regex[i + 2] && regex[i + 3]) {
|
||||
codepoint_categ categ = {};
|
||||
if (regex[i + 3] == '}') {
|
||||
categ = codepoint_categ::from_chars(regex[i + 2]);
|
||||
offset += 5;
|
||||
} else if (regex[i + 3] != '}' && regex[i + 4] == '}') {
|
||||
categ = codepoint_categ::from_chars(regex[i + 2], regex[i + 3]);
|
||||
offset += 6;
|
||||
}
|
||||
bool negated = regex[i] == 'P';
|
||||
any_positive |= !negated;
|
||||
any_negative |= negated;
|
||||
GGML_ASSERT(any_positive != any_negative); //BUG: can not mix 'p' and 'P' inside []
|
||||
GGML_ASSERT(sizeof(categ) <= 2);
|
||||
// encoded category in 32 bits codepoint
|
||||
uint32_t cpt_categ = (1 << 31) | (negated << 30) | categ.encoded;
|
||||
if (inside_square) {
|
||||
regex_cpts.insert(regex_cpts.end(), {'[', ':', cpt_categ, ':', ']'});
|
||||
} else {
|
||||
regex_cpts.insert(regex_cpts.end(), {'[', '[', ':', cpt_categ, ':', ']', ']'});
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
regex_cpts.push_back(unicode_cpt_from_utf8(regex, offset));
|
||||
}
|
||||
|
||||
return regex_cpts;
|
||||
}
|
||||
|
||||
// use std::basic_regex<uint32_t> to split the text codepoints
|
||||
static std::vector<size_t> unicode_regex_split_stl(const std::vector<uint32_t> & text_cpts, const std::vector<uint32_t> & regex_cpts, const std::vector<size_t> & offsets) {
|
||||
using regex_type = std::basic_regex<uint32_t>;
|
||||
using iter_type = std::regex_iterator<const uint32_t *>;
|
||||
regex_type regex(regex_cpts.begin(), regex_cpts.end());
|
||||
const iter_type end;
|
||||
|
||||
std::vector<size_t> bpe_offsets; // store the offset of each word
|
||||
bpe_offsets.reserve(offsets.size()); // reserve memory for the approximate size
|
||||
const uint32_t * text_data = text_cpts.data();
|
||||
for (auto offset : offsets) {
|
||||
iter_type it(text_data, text_data + offset, regex);
|
||||
int64_t start_idx = 0;
|
||||
while (it != end) {
|
||||
if (it->position() > start_idx) {
|
||||
bpe_offsets.emplace_back(it->position() - start_idx);
|
||||
}
|
||||
bpe_offsets.emplace_back(it->length());
|
||||
start_idx = it->position() + it->length();
|
||||
++it;
|
||||
}
|
||||
|
||||
if (start_idx < (int64_t) offset) {
|
||||
bpe_offsets.emplace_back(offset - start_idx);
|
||||
}
|
||||
text_data += offset;
|
||||
}
|
||||
|
||||
return bpe_offsets;
|
||||
}
|
||||
|
||||
//
|
||||
// interface
|
||||
//
|
||||
|
@ -639,288 +834,21 @@ uint32_t unicode_tolower(uint32_t cp) {
|
|||
return it == unicode_map_lowercase.end() ? cp : it->second;
|
||||
}
|
||||
|
||||
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
|
||||
// std::wregex does not support unicode categories: \p{N}, \p{L}, \p{Lu}, \p{Ll} ...
|
||||
// std::wregex does not support unicode whitespaces \s: 0x85, 0xA0, 0x001680 ... 0x003000.
|
||||
// std::wregex allows full wchar_t 32 bit codepoints, not limited to standard max 0x110000.
|
||||
// The main idea is to insert unicode category bits into all regex and text codepoints.
|
||||
// Max unicode codepoint 0x110000 fits in 21 bits.
|
||||
// Store unicode category and subcategory in 10 bits.
|
||||
// Set the high bit to zero to keep wchar_t positive (uint32_t codepoints).
|
||||
// Categorized codepoint:
|
||||
// 1 bit zero + 7 bits category + 3 bits subcategory index + 21 bits codepoint
|
||||
// 0b0'XXXXXXX'xxx'ccccccccccccccccccccc
|
||||
// A "categorized codepoint" re-defines the ordering keeping category hierarchy.
|
||||
// All high category codepoints \p{X} fall into the range:
|
||||
// 0b0'XXXXXXX'000'000000000000000000000
|
||||
// 0b0'XXXXXXX'111'111111111111111111111
|
||||
// All subcategory codepoints \p{Xx} fall into the range:
|
||||
// 0b0'XXXXXXX'xxx'000000000000000000000
|
||||
// 0b0'XXXXXXX'xxx'111111111111111111111
|
||||
// Processing steps:
|
||||
// Build a lists of "categorized codepoints/ranges" for replacing regex \s \w and \d.
|
||||
// Replace all regex codepoints/ranges with respective "categorized codepoints/ranges".
|
||||
// Replace all text codepoints with respective "categorized codepoints".
|
||||
// Caveats:
|
||||
// Some regex ranges starts and ends with different category/subcategory.
|
||||
// Split the ranges in sub-ranges to ensure a single category to maintain the new hierarchy.
|
||||
// This forces iterating all ranges and could produce long sub-range sequences.
|
||||
|
||||
//TODO: Regex processing can be cached.
|
||||
|
||||
// insert unicode category and subcategory before codepoint bits
|
||||
// 1 bit zero + 7 bits category + 3 bits subcategory index + 21 bits zero
|
||||
static const auto categorized_prefix = [] (const codepoint_categ categ) -> wchar_t {
|
||||
static const uint32_t MASK = codepoint_categ::MASK; // category mask
|
||||
static const uint32_t SUBMASK = codepoint_categ::SUBMASK & ~codepoint_categ::MASK; // subcategory mask
|
||||
return (wchar_t) (((categ.encoded & MASK) << (21+3)) | ((categ.encoded & SUBMASK) << (21-7)));
|
||||
};
|
||||
|
||||
// insert unicode category and subcategory before codepoint bits
|
||||
// 1 bit zero + 7 bits category + 3 bits subcategory index + 21 bits codepoint
|
||||
static const auto categorize_codepoint = [] (const uint32_t cpt) -> wchar_t {
|
||||
GGML_ASSERT(cpt < (1 << 21));
|
||||
return categorized_prefix(unicode_cpt_category(cpt)) | (wchar_t)cpt;
|
||||
};
|
||||
|
||||
// remove the categorized prefix bits and restore original codepoint bits
|
||||
static const auto decategorize_codepoint = [] (const wchar_t cpt) -> uint32_t {
|
||||
return (uint32_t) cpt & ((1 << 21) - 1);
|
||||
};
|
||||
|
||||
// returns the respective categorized codepoint range of the category/subcategory
|
||||
static const auto categorize_range_from_chars = [] (const char categ, const char subcateg) {
|
||||
const wchar_t range_ini = categorized_prefix(codepoint_categ::from_chars(categ, subcateg));
|
||||
const wchar_t range_end = (wchar_t) (range_ini | (subcateg ? (1<<21)-1 : (1<<24)-1));
|
||||
return std::pair<wchar_t, wchar_t>(range_ini, range_end);
|
||||
};
|
||||
|
||||
// helper function to append/concat regex expressions
|
||||
auto wregex_append_subregex = [] (std::wstring & wregex, const std::wstring & subregex, const bool add_squares, const bool negated) {
|
||||
if (add_squares) {
|
||||
wregex += '[';
|
||||
if (negated) {
|
||||
wregex += '^';
|
||||
}
|
||||
wregex += subregex;
|
||||
wregex += ']';
|
||||
} else {
|
||||
GGML_ASSERT(!negated); //TODO: negation inside square brackets: \S \W \D
|
||||
wregex += subregex;
|
||||
}
|
||||
};
|
||||
|
||||
// \d digits replacement
|
||||
static const std::wstring wregex_digits = {
|
||||
categorize_codepoint('0'), '-', categorize_codepoint('9'),
|
||||
};
|
||||
|
||||
// \w words replacement
|
||||
static const std::wstring wregex_words = {
|
||||
categorize_codepoint('_'),
|
||||
categorize_codepoint('0'), '-', categorize_codepoint('9'),
|
||||
categorize_codepoint('A'), '-', categorize_codepoint('Z'),
|
||||
categorize_codepoint('a'), '-', categorize_codepoint('z'),
|
||||
};
|
||||
|
||||
// \s whitespaces replacement
|
||||
static const std::wstring wregex_whitespaces = [] {
|
||||
std::wstring wregex_whitespaces;
|
||||
for (const auto & range : unicode_ranges_whitespace) {
|
||||
wregex_whitespaces += categorize_codepoint(range.first);
|
||||
if (range.second > range.first) {
|
||||
wregex_whitespaces += '-';
|
||||
wregex_whitespaces += categorize_codepoint(range.second);
|
||||
}
|
||||
}
|
||||
return wregex_whitespaces;
|
||||
}();
|
||||
|
||||
GGML_ASSERT(sizeof(wchar_t) == sizeof(uint32_t));
|
||||
std::wstring wtext = unicode_wstring_from_utf8(text);
|
||||
|
||||
std::vector<size_t> offsets = { wtext.size() };
|
||||
std::vector<std::string> unicode_regex_split(const std::string & text_utf8, const std::vector<std::string> & regex_exprs) {
|
||||
const std::vector<uint32_t> cpts = unicode_cpts_from_utf8(text_utf8);
|
||||
std::vector<size_t> offsets = { cpts.size() };
|
||||
|
||||
for (auto & regex_expr : regex_exprs) {
|
||||
// first, see if we have an efficient custom regex implementation
|
||||
auto tmp = unicode_regex_split_custom(text, regex_expr, offsets);
|
||||
auto tmp = unicode_regex_split_custom(text_utf8, regex_expr, offsets);
|
||||
|
||||
if (!tmp.empty()) {
|
||||
offsets = std::move(tmp);
|
||||
continue;
|
||||
}
|
||||
|
||||
std::wstring wregex;
|
||||
bool inside_square = false;
|
||||
bool is_cpt_range = false;
|
||||
|
||||
const auto cpts_regex = unicode_cpts_from_utf8(regex_expr);
|
||||
wregex.reserve(2 * cpts_regex.size());
|
||||
|
||||
for (size_t i = 0; i < cpts_regex.size(); ++i) {
|
||||
uint32_t cpt = cpts_regex[i];
|
||||
|
||||
// parse regex metacharacters
|
||||
wregex += (wchar_t) cpt;
|
||||
if (inside_square) {
|
||||
switch(cpt) {
|
||||
case '^':
|
||||
if (cpts_regex[i - 1] != '[') {
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
case ']':
|
||||
inside_square = false;
|
||||
continue;
|
||||
case '-':
|
||||
is_cpt_range = true;
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
switch(cpt) {
|
||||
case '^':
|
||||
if (i > 0) {
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
case '$':
|
||||
if (i + 1 < cpts_regex.size()) {
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
case '[':
|
||||
inside_square = true;
|
||||
continue;
|
||||
case '{':
|
||||
while (cpt && cpt != '}') {
|
||||
cpt = cpts_regex[++i];
|
||||
wregex += (wchar_t) cpt;
|
||||
}
|
||||
continue;
|
||||
case '}':
|
||||
case ']':
|
||||
GGML_ABORT("invalid regex");
|
||||
case '(':
|
||||
if (cpts_regex[i + 1] == '?') { // (?: (?i: (?= (?! (?<= (?<!
|
||||
if (cpts_regex[i + 2] == ':') {
|
||||
wregex += (wchar_t) cpts_regex[++i];
|
||||
wregex += (wchar_t) cpts_regex[++i];
|
||||
} else if (cpts_regex[i + 2] == 'i') {
|
||||
wregex += (wchar_t) cpts_regex[++i];
|
||||
wregex += (wchar_t) cpts_regex[++i];
|
||||
wregex += (wchar_t) cpts_regex[++i];
|
||||
GGML_ASSERT(cpts_regex[i] == ':');
|
||||
} else {
|
||||
wregex += (wchar_t) cpts_regex[++i];
|
||||
wregex += (wchar_t) cpts_regex[++i];
|
||||
if (cpts_regex[i] == '<') {
|
||||
wregex += (wchar_t) cpts_regex[++i];
|
||||
}
|
||||
GGML_ASSERT(cpts_regex[i] == '=' || cpts_regex[i] == '!');
|
||||
}
|
||||
}
|
||||
continue;
|
||||
case ')':
|
||||
case '|':
|
||||
case '.':
|
||||
case '?':
|
||||
case '+':
|
||||
case '*':
|
||||
continue;
|
||||
}
|
||||
}
|
||||
wregex.pop_back();
|
||||
|
||||
// parse unicode categories and subcategories, replace category with the categorized range
|
||||
if (cpt == '\\' && cpts_regex[i + 1] == 'p' && cpts_regex[i + 2] == '{') {
|
||||
GGML_ASSERT(cpts_regex[i + 3] && cpts_regex[i + 4]);
|
||||
std::pair<wchar_t, wchar_t> range;
|
||||
if (cpts_regex[i + 4] == '}') {
|
||||
range = categorize_range_from_chars((char)cpts_regex[i + 3], (char)'\0');
|
||||
i += 4;
|
||||
} else {
|
||||
range = categorize_range_from_chars((char)cpts_regex[i + 3], (char)cpts_regex[i + 4]);
|
||||
i += 5;
|
||||
}
|
||||
GGML_ASSERT(cpts_regex[i] == '}');
|
||||
const std::wstring subregex = {range.first, '-', range.second};
|
||||
wregex_append_subregex(wregex, subregex, !inside_square, false);
|
||||
continue;
|
||||
}
|
||||
|
||||
// parse more metcharacters and espaped characters
|
||||
if (cpt == '\\') {
|
||||
switch (cpts_regex[i + 1]) {
|
||||
case 's': // \s whitespaces
|
||||
case 'S': // \S no whitespaces
|
||||
wregex_append_subregex(wregex, wregex_whitespaces, !inside_square, cpts_regex[++i] == 'S');
|
||||
continue;
|
||||
case 'w': // \w words
|
||||
case 'W': // \W no words
|
||||
wregex_append_subregex(wregex, wregex_words, !inside_square, cpts_regex[++i] == 'W');
|
||||
continue;
|
||||
case 'd': // \d digits
|
||||
case 'D': // \D no digits
|
||||
wregex_append_subregex(wregex, wregex_digits, !inside_square, cpts_regex[++i] == 'D');
|
||||
continue;
|
||||
case 't': ++i; cpt = '\t'; break;
|
||||
case 'r': ++i; cpt = '\r'; break;
|
||||
case 'n': ++i; cpt = '\n'; break;
|
||||
case 'x': GGML_ABORT("TODO"); //TODO: hex values
|
||||
case 'u': GGML_ABORT("TODO"); //TODO: unicode values
|
||||
case 'U': GGML_ABORT("TODO"); //TODO: unicode values
|
||||
default: // escaped character
|
||||
GGML_ASSERT(!is_cpt_range);
|
||||
cpt = cpts_regex[++i];
|
||||
GGML_ASSERT(cpt < 0x80);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (is_cpt_range) {
|
||||
// Some regex ranges starts and ends with different category/subcategory.
|
||||
// Split the ranges in sub-ranges to ensure a single category to maintain the new hierarchy.
|
||||
// Warning: This forces iterating all ranges and could produce long sub-range sequences.
|
||||
GGML_ASSERT(wregex.size() && wregex.back() == '-');
|
||||
wregex.pop_back();
|
||||
wchar_t categorized = wregex.back();
|
||||
uint32_t range_ini = decategorize_codepoint(categorized);
|
||||
const uint32_t range_end = cpt;
|
||||
GGML_ASSERT(range_ini <= range_end);
|
||||
codepoint_categ range_categ = unicode_cpt_category(range_ini);
|
||||
for (cpt = range_ini + 1; cpt <= range_end; ++cpt) {
|
||||
codepoint_categ categ = unicode_cpt_category(cpt);
|
||||
if (categ == range_categ) { // still same range category ?
|
||||
++categorized;
|
||||
if (cpt == range_ini + 1) { // single step, no need range
|
||||
wregex += categorized;
|
||||
} else if (cpt == range_ini + 2) { // need range if +2 step
|
||||
wregex.back() = '-';
|
||||
wregex += categorized;
|
||||
} else {
|
||||
wregex.back() = categorized; // keep range growing
|
||||
}
|
||||
} else { // new range category
|
||||
categorized = categorize_codepoint(cpt);
|
||||
wregex += categorized;
|
||||
range_categ = categ;
|
||||
range_ini = cpt;
|
||||
}
|
||||
}
|
||||
is_cpt_range = false;
|
||||
} else {
|
||||
wregex += categorize_codepoint(cpt);
|
||||
}
|
||||
}
|
||||
|
||||
// categorize all wtext codepoints
|
||||
if (wtext.size() && wtext[0] < MAX_CODEPOINTS) { // if not already categorized
|
||||
for (size_t i = 0; i < wtext.size(); ++i) {
|
||||
wtext[i] = categorize_codepoint((uint32_t) wtext[i]);
|
||||
}
|
||||
}
|
||||
|
||||
offsets = unicode_regex_split_stl(wtext, wregex, offsets);
|
||||
const auto regex_cpts = unicode_regex_prepare(regex_expr);
|
||||
offsets = unicode_regex_split_stl(cpts, regex_cpts, offsets);
|
||||
}
|
||||
|
||||
std::vector<std::string> bpe_words;
|
||||
|
@ -930,8 +858,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
|
|||
for (size_t & offset : offsets) {
|
||||
bpe_words.emplace_back();
|
||||
for (size_t i = start; i < start + offset; ++i) {
|
||||
const uint32_t cpt = decategorize_codepoint(wtext[i]);
|
||||
bpe_words.back() += unicode_cpt_to_utf8(cpt);
|
||||
bpe_words.back() += unicode_cpt_to_utf8(cpts[i]);
|
||||
}
|
||||
start += offset;
|
||||
}
|
||||
|
|
|
@ -113,6 +113,23 @@ struct codepoint_categ {
|
|||
inline bool is_Zp() const { return (encoded & MASK) == Zp; }
|
||||
inline bool is_Zs() const { return (encoded & MASK) == Zs; }
|
||||
|
||||
inline uint64_t expand_bits(const bool add_categ=true) const { // one bit for each category/subcateory and flags
|
||||
const uint32_t subindex = encoded & SUBMASK;
|
||||
const uint64_t bits = (encoded & MASK) >> 3;
|
||||
const uint64_t flags = encoded >> 10;
|
||||
return (flags << (7 * 8)) | (bits << (7 * subindex)) | (bits * add_categ);
|
||||
}
|
||||
|
||||
inline bool is_in_range(const codepoint_categ other) const { // this.first <= other <= this.last
|
||||
if (encoded & SUBMASK) {
|
||||
return encoded == other.encoded; // no range
|
||||
}
|
||||
if (encoded & MASK) {
|
||||
return encoded == (other.encoded & ~SUBMASK); // from 0bffffff'ccccccc'000 to 0bffffff'ccccccc'111
|
||||
}
|
||||
return encoded == (other.encoded & ~MASK); // from 0bffffff'0000000'000 to 0bffffff'1111111'111
|
||||
}
|
||||
|
||||
inline bool operator == (const codepoint_categ other) const {
|
||||
return encoded == other.encoded;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue