Reimplement 'collapsed' unicode categories:
- Add all unicode categories. - Fix \s with non-ASCII problem.
This commit is contained in:
parent
8f7d56ec5b
commit
1cd7ac090b
1 changed files with 286 additions and 133 deletions
419
src/unicode.cpp
419
src/unicode.cpp
|
@ -636,66 +636,38 @@ uint32_t unicode_tolower(uint32_t cp) {
|
|||
}
|
||||
|
||||
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
|
||||
// unicode categories
|
||||
static const std::map<std::string, int> k_ucat_enum = {
|
||||
{ "\\p{N}", codepoint_categ::N },
|
||||
{ "\\p{L}", codepoint_categ::L },
|
||||
{ "\\p{P}", codepoint_categ::P },
|
||||
};
|
||||
|
||||
static const std::map<int, int> k_ucat_cpt = {
|
||||
{ codepoint_categ::N, 0xD1 },
|
||||
{ codepoint_categ::L, 0xD2 },
|
||||
{ codepoint_categ::P, 0xD3 },
|
||||
};
|
||||
|
||||
static const std::map<int, std::string> k_ucat_map = {
|
||||
{ codepoint_categ::N, "\x30-\x39" }, // 0-9
|
||||
{ codepoint_categ::L, "\x41-\x5A\x61-\x7A" }, // A-Za-z
|
||||
{ codepoint_categ::P, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
|
||||
};
|
||||
|
||||
// compute collapsed codepoints only if needed by at least one regex
|
||||
bool need_collapse = false;
|
||||
for (auto & regex_expr : regex_exprs) {
|
||||
// search for unicode categories
|
||||
for (const auto & ucat : k_ucat_enum) {
|
||||
if (std::string::npos != regex_expr.find(ucat.first)) {
|
||||
need_collapse = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto cpts = unicode_cpts_from_utf8(text);
|
||||
|
||||
//TODO: update and add more comments
|
||||
// generate a "collapsed" representation of the text, where all codepoints are replaced by a single byte
|
||||
// ref: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2081479935
|
||||
std::string text_collapsed;
|
||||
if (need_collapse) {
|
||||
// collapse all unicode categories
|
||||
text_collapsed.resize(cpts.size());
|
||||
|
||||
for (size_t i = 0; i < cpts.size(); ++i) {
|
||||
// keep single-byte codepoints as is
|
||||
if (cpts[i] < 128) {
|
||||
text_collapsed[i] = cpts[i];
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto categ = unicode_cpt_category(cpts[i]);
|
||||
|
||||
if (categ.is_whitespace()) {
|
||||
//NOTE: C++ std::regex \s does not mach 0x85, Rust and Python regex does.
|
||||
//text_collapsed[i] = (char) 0x85; // <Next Line> as whitespace fallback
|
||||
text_collapsed[i] = (char) 0x0B; // <vertical tab> as whitespace fallback
|
||||
} else if (k_ucat_cpt.find(categ.get_category()) != k_ucat_cpt.end()) {
|
||||
text_collapsed[i] = k_ucat_cpt.at(categ.get_category());
|
||||
} else {
|
||||
text_collapsed[i] = (char) 0xD0; // fallback
|
||||
}
|
||||
// 0xDB80 to 0xDBFF: Private Use High Surrogate (128 range values)
|
||||
static const uint32_t COLLAPSE_CPT_RANGE_FIRST = 0xDB80;
|
||||
static const uint32_t COLLAPSE_CPT_RANGE_LAST = 0xDBFF;
|
||||
auto category_to_collapsed_cpt = [] (const codepoint_categ categ) {
|
||||
const uint16_t subindex = categ.get_subcategory() >> 7; // subcategory stored in 3 bits
|
||||
switch(categ.get_category()) { // category fits in other 3 bits
|
||||
case codepoint_categ::UNDEF: return COLLAPSE_CPT_RANGE_FIRST + ((0 << 3) | subindex);
|
||||
case codepoint_categ::C: return COLLAPSE_CPT_RANGE_FIRST + ((1 << 3) | subindex);
|
||||
case codepoint_categ::L: return COLLAPSE_CPT_RANGE_FIRST + ((2 << 3) | subindex);
|
||||
case codepoint_categ::M: return COLLAPSE_CPT_RANGE_FIRST + ((3 << 3) | subindex);
|
||||
case codepoint_categ::N: return COLLAPSE_CPT_RANGE_FIRST + ((4 << 3) | subindex);
|
||||
case codepoint_categ::P: return COLLAPSE_CPT_RANGE_FIRST + ((5 << 3) | subindex);
|
||||
case codepoint_categ::S: return COLLAPSE_CPT_RANGE_FIRST + ((6 << 3) | subindex);
|
||||
case codepoint_categ::Z: return COLLAPSE_CPT_RANGE_FIRST + ((7 << 3) | subindex);
|
||||
default: assert (false); return COLLAPSE_CPT_RANGE_FIRST;
|
||||
}
|
||||
}
|
||||
};
|
||||
auto category_to_collapsed_range = [&] (const codepoint_categ categ) {
|
||||
// \p{Ll} --> \p{Ll} to \p{Ll} // has subcategory ? yes
|
||||
// \p{Lu} --> \p{Lu} to \p{Lu} // has subcategory ? yes
|
||||
// \p{L} --> \p{Ll} to \p{Lu} // has subcategory ? no
|
||||
assert ((COLLAPSE_CPT_RANGE_FIRST & 0b111) == 0);
|
||||
const uint32_t collapsed = category_to_collapsed_cpt(categ);
|
||||
const uint32_t range = (collapsed & 0b111) ? 0 : 0b111; // has subcategory ?
|
||||
return std::pair<uint32_t, uint32_t>(collapsed, collapsed + range);
|
||||
};
|
||||
|
||||
const auto cpts = unicode_cpts_from_utf8(text);
|
||||
|
||||
std::vector<size_t> bpe_offsets = { cpts.size() };
|
||||
|
||||
|
@ -708,91 +680,272 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
|
|||
continue;
|
||||
}
|
||||
|
||||
// fallback to general-purpose std::regex / std::wregex
|
||||
try {
|
||||
// if a unicode category is used in the regex, we use the collapsed text and replace the unicode category
|
||||
// with the corresponding collapsed representation
|
||||
bool use_collapsed = false;
|
||||
for (auto & ucat : k_ucat_enum) {
|
||||
if (std::string::npos != regex_expr.find(ucat.first)) {
|
||||
use_collapsed = true;
|
||||
std::vector<std::pair<uint32_t, uint32_t>> regex_expr_ranges; // start codepoint, last codepoint
|
||||
std::vector<std::pair<uint32_t, codepoint_categ>> regex_expr_categs; // offset, codepoint category
|
||||
std::map<uint16_t, std::wstring> map_categ_wregex; // categ --> regex utf32 string
|
||||
std::wstring wregex_collapsed;
|
||||
std::wstring wtext_collapsed;
|
||||
bool inside_square = false;
|
||||
bool is_cpt_range = false;
|
||||
|
||||
// common ranges: \w \d
|
||||
regex_expr_ranges.emplace_back('a', 'z');
|
||||
regex_expr_ranges.emplace_back('A', 'Z');
|
||||
regex_expr_ranges.emplace_back('0', '9');
|
||||
regex_expr_ranges.emplace_back('_', '_');
|
||||
// common ranges: \s
|
||||
for (uint32_t cpt : unicode_vec_whitespace) {
|
||||
const auto categ_prev = unicode_cpt_category(regex_expr_ranges.back().second);
|
||||
const auto categ_last = unicode_cpt_category(cpt);
|
||||
if (categ_prev == categ_last && regex_expr_ranges.back().second + 1 == cpt) {
|
||||
regex_expr_ranges.back().second = cpt;
|
||||
} else {
|
||||
regex_expr_ranges.emplace_back(cpt, cpt);
|
||||
}
|
||||
}
|
||||
|
||||
// std::wregex \s does not match non-ASCII whitespaces
|
||||
static const codepoint_categ categ_whitespace(codepoint_categ::MASK + 1); // UNDEF category, subcategory 1
|
||||
std::wstring & wregex_whitespaces = map_categ_wregex[categ_whitespace.get_subcategory()];
|
||||
wregex_whitespaces += L"\\s";
|
||||
for (uint32_t cpt : unicode_vec_whitespace) {
|
||||
if (cpt >= 0x80) { // non-ASCII whitespaces
|
||||
if (wregex_whitespaces.back() + 1 == cpt) {
|
||||
if (*(wregex_whitespaces.end() - 2) == '-') {
|
||||
wregex_whitespaces.back() = cpt;
|
||||
} else {
|
||||
wregex_whitespaces += '-';
|
||||
wregex_whitespaces += cpt;
|
||||
}
|
||||
} else {
|
||||
wregex_whitespaces += cpt;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto cpts_regex = unicode_cpts_from_utf8(regex_expr);
|
||||
|
||||
for (size_t i = 0; i < cpts_regex.size(); ++i) {
|
||||
uint32_t cpt = cpts_regex[i];
|
||||
|
||||
if (inside_square) {
|
||||
switch(cpt) {
|
||||
case '^':
|
||||
if (cpts_regex[i - 1] != '[') {
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
case ']':
|
||||
inside_square = false;
|
||||
continue;
|
||||
case '-':
|
||||
is_cpt_range = true;
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
switch(cpt) {
|
||||
case '^':
|
||||
if (i > 0) {
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
case '$':
|
||||
if (i + 1 < cpts_regex.size()) {
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
case '[':
|
||||
inside_square = true;
|
||||
continue;
|
||||
case '{':
|
||||
while (cpt && cpt != '}') {
|
||||
cpt = cpts_regex[++i];
|
||||
}
|
||||
continue;
|
||||
case '}':
|
||||
case ']':
|
||||
assert (false);
|
||||
case '(':
|
||||
if (cpts_regex[i + 1] == '?') { // (?: (?i: (?= (?! (?<= (?<!
|
||||
if (cpts_regex[i + 2] == ':') {
|
||||
i += 2;
|
||||
} else if (cpts_regex[i + 2] == 'i') {
|
||||
i += 3;
|
||||
assert (cpts_regex[i] == ':');
|
||||
} else {
|
||||
i += 2 + (cpts_regex[i + 2] == '<');
|
||||
assert (cpts_regex[i] == '=' || cpts_regex[i] == '!');
|
||||
}
|
||||
}
|
||||
continue;
|
||||
case ')':
|
||||
case '|':
|
||||
case '.':
|
||||
case '?':
|
||||
case '+':
|
||||
case '*':
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (cpt == '\\' && cpts_regex[i + 1] == 'p' && cpts_regex[i + 2] == '{') {
|
||||
assert (cpts_regex[i + 3] && cpts_regex[i + 4]);
|
||||
codepoint_categ categ = {};
|
||||
if (cpts_regex[i + 4] == '}') {
|
||||
categ = codepoint_categ::from_chars((char)cpts_regex[i + 3]);
|
||||
} else {
|
||||
categ = codepoint_categ::from_chars((char)cpts_regex[i + 3], (char)cpts_regex[i + 4]);
|
||||
assert (cpts_regex[i + 5] == '}');
|
||||
}
|
||||
categ.set_flag(codepoint_categ::WHITESPACE, inside_square); //NOTE: reusing flag 'WHITESPACE' to store 'inside square brackets'
|
||||
regex_expr_categs.emplace_back(i, categ);
|
||||
i += cpts_regex[i + 4] == '}' ? 4 : 5;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (cpt == '\\') {
|
||||
if (cpts_regex[i + 1] == 's' || cpts_regex[i + 1] == 'S') { // \s \S
|
||||
regex_expr_categs.emplace_back(i, categ_whitespace);
|
||||
//NOTE: reusing flag 'WHITESPACE' to store 'inside square brackets'
|
||||
regex_expr_categs.back().second.set_flag(codepoint_categ::WHITESPACE, inside_square);
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (cpt == '\\') {
|
||||
switch (cpts_regex[i + 1]) {
|
||||
case 's': ++i; continue; // \s whitespaces
|
||||
case 'w': ++i; continue; // \w words
|
||||
case 'd': ++i; continue; // \d digits
|
||||
case 'S': ++i; continue; // \S no whitespaces
|
||||
case 'W': ++i; continue; // \W no words
|
||||
case 'D': ++i; continue; // \D no digits
|
||||
case 't': ++i; cpt = '\t'; break;
|
||||
case 'r': ++i; cpt = '\r'; break;
|
||||
case 'n': ++i; cpt = '\n'; break;
|
||||
case 'x': assert (false); break; //TODO: hex values
|
||||
case 'u': assert (false); break; //TODO: unicode values
|
||||
case 'U': assert (false); break; //TODO: unicode values
|
||||
default: // escaped character
|
||||
assert (!is_cpt_range);
|
||||
cpt = cpts_regex[++i];
|
||||
assert (cpt < 0x80);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (use_collapsed) {
|
||||
// sanity-check that the original regex does not contain any non-ASCII characters
|
||||
const auto cpts_regex = unicode_cpts_from_utf8(regex_expr);
|
||||
for (size_t i = 0; i < cpts_regex.size(); ++i) {
|
||||
if (cpts_regex[i] >= 128) {
|
||||
throw std::runtime_error("Regex includes both unicode categories and non-ASCII characters - not supported");
|
||||
}
|
||||
}
|
||||
assert (cpt < COLLAPSE_CPT_RANGE_FIRST || COLLAPSE_CPT_RANGE_LAST < cpt);
|
||||
|
||||
// generate a collapsed representation of the regex
|
||||
std::string regex_expr_collapsed;
|
||||
|
||||
// track if we are inside [], because nested [] are not allowed
|
||||
bool inside = false;
|
||||
for (size_t i = 0; i < regex_expr.size(); ++i) {
|
||||
if (regex_expr[i] == '[' && (i == 0 || regex_expr[i - 1] != '\\')) {
|
||||
regex_expr_collapsed += '[';
|
||||
inside = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (inside && regex_expr[i] == ']' && regex_expr[i - 1] != '\\') {
|
||||
regex_expr_collapsed += ']';
|
||||
inside = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (regex_expr[i + 0] == '\\' && i + 4 < regex_expr.size() &&
|
||||
regex_expr[i + 1] == 'p' &&
|
||||
regex_expr[i + 2] == '{' &&
|
||||
regex_expr[i + 4] == '}') {
|
||||
const std::string pat = regex_expr.substr(i, 5);
|
||||
if (k_ucat_enum.find(pat) != k_ucat_enum.end()) {
|
||||
if (!inside) {
|
||||
regex_expr_collapsed += '[';
|
||||
}
|
||||
regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat));
|
||||
regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat));
|
||||
if (!inside) {
|
||||
regex_expr_collapsed += ']';
|
||||
}
|
||||
i += 4;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
regex_expr_collapsed += regex_expr[i];
|
||||
}
|
||||
|
||||
//printf("text_collapsed: %s\n", text_collapsed.c_str());
|
||||
//printf("regex_expr_collapsed: %s\n", regex_expr_collapsed.c_str());
|
||||
bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets);
|
||||
if (is_cpt_range) {
|
||||
is_cpt_range = false;
|
||||
regex_expr_ranges.back().second = cpt;
|
||||
} else {
|
||||
// no unicode category used, we can use std::wregex directly
|
||||
const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr);
|
||||
|
||||
// std::wregex \s does not mach non-ASCII whitespaces, using 0x0B as fallback
|
||||
std::wstring wtext(cpts.begin(), cpts.end());
|
||||
for (size_t i = 0; i < wtext.size(); ++i) {
|
||||
if (wtext[i] > 0x7F && unicode_cpt_category(wtext[i]).is_whitespace()) {
|
||||
wtext[i] = 0x0B;
|
||||
}
|
||||
}
|
||||
|
||||
//printf("text: %s\n", text.c_str());
|
||||
//printf("regex_expr: %s\n", regex_expr.c_str());
|
||||
bpe_offsets = unicode_regex_split_stl(wtext, wregex_expr, bpe_offsets);
|
||||
regex_expr_ranges.emplace_back(cpt, cpt);
|
||||
}
|
||||
} catch (std::regex_error & e) {
|
||||
fprintf(stderr, "Failed to process regex: '%s'\n", regex_expr.c_str());
|
||||
fprintf(stderr, "Regex error: %s\n", e.what());
|
||||
throw std::runtime_error("Failed to process regex");
|
||||
}
|
||||
|
||||
// assign collapsed codepoint to each category regex \p{...}
|
||||
for (auto offset_categ : regex_expr_categs) {
|
||||
const uint16_t subcateg = offset_categ.second.get_subcategory();
|
||||
auto it = map_categ_wregex.find(subcateg);
|
||||
if (it == map_categ_wregex.end()) {
|
||||
const auto collapsed_range = category_to_collapsed_range(offset_categ.second);
|
||||
map_categ_wregex[subcateg] = (wchar_t) collapsed_range.first;
|
||||
if (collapsed_range.first < collapsed_range.second) {
|
||||
map_categ_wregex[subcateg] += (wchar_t) '-';
|
||||
map_categ_wregex[subcateg] += (wchar_t) collapsed_range.second;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// copy found regex ranges to each category regex
|
||||
uint32_t regex_expr_ranges_uniques = 0;
|
||||
std::pair<uint32_t, uint32_t> prev_range = {0, -1};
|
||||
std::sort(regex_expr_ranges.begin(), regex_expr_ranges.end());
|
||||
for (auto range : regex_expr_ranges) {
|
||||
range.first = std::max(range.first, prev_range.second + 1); // prevent overlapping //TODO: as error?
|
||||
if (range.first > range.second) { // skip overlapping and repetitions
|
||||
continue;
|
||||
}
|
||||
codepoint_categ categ = unicode_cpt_category(range.first);
|
||||
assert (categ == unicode_cpt_category(range.second));
|
||||
auto it0 = map_categ_wregex.find(categ.get_category());
|
||||
auto it1 = map_categ_wregex.find(categ.get_subcategory());
|
||||
for (const auto & it : {it0, it1}) {
|
||||
if (it != map_categ_wregex.end()) {
|
||||
it->second += (wchar_t) range.first;
|
||||
if (range.first < range.second) {
|
||||
it->second += (wchar_t) '-';
|
||||
it->second += (wchar_t) range.second;
|
||||
}
|
||||
}
|
||||
}
|
||||
prev_range = range;
|
||||
regex_expr_ranges[regex_expr_ranges_uniques++] = range;
|
||||
}
|
||||
regex_expr_ranges.resize(regex_expr_ranges_uniques);
|
||||
|
||||
// replace categories with respective collapsed codepoint and ranges
|
||||
uint32_t i = 0;
|
||||
wregex_collapsed.reserve(regex_expr.size());
|
||||
for (auto offset_categ : regex_expr_categs) {
|
||||
while (i < offset_categ.first) { // copy original regex until reaching the category
|
||||
wregex_collapsed += (wchar_t) cpts_regex[i];
|
||||
i++;
|
||||
}
|
||||
assert (cpts_regex[i] == '\\');
|
||||
const uint32_t cpt_next = cpts_regex[i + 1];
|
||||
const bool is_negated = cpt_next < 'a'; // is uppercase
|
||||
if (cpt_next == 'p' || cpt_next == 'P') {
|
||||
assert (cpts_regex[i + 2] == '{' && cpts_regex[i + 3]);
|
||||
i += cpts_regex[i + 4] == '}' ? 5 : 6;
|
||||
assert (cpts_regex[i - 1] == '}');
|
||||
} else {
|
||||
assert (cpt_next == 's' || cpt_next == 'w' || cpt_next == 'd' || // \s \w \d
|
||||
cpt_next == 'S' || cpt_next == 'W' || cpt_next == 'D'); // \S \W \D
|
||||
i += 2;
|
||||
}
|
||||
const codepoint_categ categ = offset_categ.second;
|
||||
auto it = map_categ_wregex.find(categ.get_subcategory());
|
||||
assert (it != map_categ_wregex.end());
|
||||
if (it != map_categ_wregex.end()) {
|
||||
if (categ.is_whitespace()) { // inside square brackets //NOTE: reusing flag WHITESPACE
|
||||
assert (is_negated == false);
|
||||
wregex_collapsed += it->second;
|
||||
} else if(it->second.size() == 1 && !is_negated) {
|
||||
wregex_collapsed += it->second;
|
||||
} else {
|
||||
wregex_collapsed += '[';
|
||||
if (is_negated) {
|
||||
wregex_collapsed += '^';
|
||||
}
|
||||
wregex_collapsed += it->second;
|
||||
wregex_collapsed += ']';
|
||||
}
|
||||
}
|
||||
}
|
||||
while (i < (uint32_t)cpts_regex.size()) {
|
||||
wregex_collapsed += cpts_regex[i];
|
||||
i++;
|
||||
}
|
||||
|
||||
// collapse text codepoints not included in 'regex_expr_ranges'
|
||||
wtext_collapsed.reserve(cpts.size());
|
||||
for (uint32_t cpt : cpts) {
|
||||
const codepoint_categ categ = unicode_cpt_category(cpt);
|
||||
auto it = std::lower_bound(regex_expr_ranges.begin(), regex_expr_ranges.end(), cpt,
|
||||
[] (const std::pair<uint32_t, uint32_t> range, const uint32_t cpt) {
|
||||
return range.second < cpt;
|
||||
}
|
||||
);
|
||||
if (it == regex_expr_ranges.end() || cpt < it->first || it->second < cpt) {
|
||||
cpt = category_to_collapsed_cpt(categ); // not found, collapse to category codepoint
|
||||
}
|
||||
wtext_collapsed += (wchar_t) cpt;
|
||||
}
|
||||
|
||||
bpe_offsets = unicode_regex_split_stl(wtext_collapsed, wregex_collapsed, bpe_offsets);
|
||||
}
|
||||
|
||||
std::vector<std::string> bpe_words;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue