Reimplement 'collapsed' unicode categories:

- Add all unicode categories.
- Fix \s with non-ASCII problem.
This commit is contained in:
jaime-m-p 2024-07-26 00:43:43 +02:00
parent 8f7d56ec5b
commit 1cd7ac090b

View file

@ -636,66 +636,38 @@ uint32_t unicode_tolower(uint32_t cp) {
}
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
// unicode categories
static const std::map<std::string, int> k_ucat_enum = {
{ "\\p{N}", codepoint_categ::N },
{ "\\p{L}", codepoint_categ::L },
{ "\\p{P}", codepoint_categ::P },
};
static const std::map<int, int> k_ucat_cpt = {
{ codepoint_categ::N, 0xD1 },
{ codepoint_categ::L, 0xD2 },
{ codepoint_categ::P, 0xD3 },
};
static const std::map<int, std::string> k_ucat_map = {
{ codepoint_categ::N, "\x30-\x39" }, // 0-9
{ codepoint_categ::L, "\x41-\x5A\x61-\x7A" }, // A-Za-z
{ codepoint_categ::P, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
};
// compute collapsed codepoints only if needed by at least one regex
bool need_collapse = false;
for (auto & regex_expr : regex_exprs) {
// search for unicode categories
for (const auto & ucat : k_ucat_enum) {
if (std::string::npos != regex_expr.find(ucat.first)) {
need_collapse = true;
break;
}
}
}
const auto cpts = unicode_cpts_from_utf8(text);
//TODO: update and add more comments
// generate a "collapsed" representation of the text, where all codepoints are replaced by a single byte
// ref: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2081479935
std::string text_collapsed;
if (need_collapse) {
// collapse all unicode categories
text_collapsed.resize(cpts.size());
for (size_t i = 0; i < cpts.size(); ++i) {
// keep single-byte codepoints as is
if (cpts[i] < 128) {
text_collapsed[i] = cpts[i];
continue;
}
const auto categ = unicode_cpt_category(cpts[i]);
if (categ.is_whitespace()) {
//NOTE: C++ std::regex \s does not mach 0x85, Rust and Python regex does.
//text_collapsed[i] = (char) 0x85; // <Next Line> as whitespace fallback
text_collapsed[i] = (char) 0x0B; // <vertical tab> as whitespace fallback
} else if (k_ucat_cpt.find(categ.get_category()) != k_ucat_cpt.end()) {
text_collapsed[i] = k_ucat_cpt.at(categ.get_category());
} else {
text_collapsed[i] = (char) 0xD0; // fallback
}
// 0xDB80 to 0xDBFF: Private Use High Surrogate (128 range values)
static const uint32_t COLLAPSE_CPT_RANGE_FIRST = 0xDB80;
static const uint32_t COLLAPSE_CPT_RANGE_LAST = 0xDBFF;
auto category_to_collapsed_cpt = [] (const codepoint_categ categ) {
const uint16_t subindex = categ.get_subcategory() >> 7; // subcategory stored in 3 bits
switch(categ.get_category()) { // category fits in other 3 bits
case codepoint_categ::UNDEF: return COLLAPSE_CPT_RANGE_FIRST + ((0 << 3) | subindex);
case codepoint_categ::C: return COLLAPSE_CPT_RANGE_FIRST + ((1 << 3) | subindex);
case codepoint_categ::L: return COLLAPSE_CPT_RANGE_FIRST + ((2 << 3) | subindex);
case codepoint_categ::M: return COLLAPSE_CPT_RANGE_FIRST + ((3 << 3) | subindex);
case codepoint_categ::N: return COLLAPSE_CPT_RANGE_FIRST + ((4 << 3) | subindex);
case codepoint_categ::P: return COLLAPSE_CPT_RANGE_FIRST + ((5 << 3) | subindex);
case codepoint_categ::S: return COLLAPSE_CPT_RANGE_FIRST + ((6 << 3) | subindex);
case codepoint_categ::Z: return COLLAPSE_CPT_RANGE_FIRST + ((7 << 3) | subindex);
default: assert (false); return COLLAPSE_CPT_RANGE_FIRST;
}
}
};
auto category_to_collapsed_range = [&] (const codepoint_categ categ) {
// \p{Ll} --> \p{Ll} to \p{Ll} // has subcategory ? yes
// \p{Lu} --> \p{Lu} to \p{Lu} // has subcategory ? yes
// \p{L} --> \p{Ll} to \p{Lu} // has subcategory ? no
assert ((COLLAPSE_CPT_RANGE_FIRST & 0b111) == 0);
const uint32_t collapsed = category_to_collapsed_cpt(categ);
const uint32_t range = (collapsed & 0b111) ? 0 : 0b111; // has subcategory ?
return std::pair<uint32_t, uint32_t>(collapsed, collapsed + range);
};
const auto cpts = unicode_cpts_from_utf8(text);
std::vector<size_t> bpe_offsets = { cpts.size() };
@ -708,91 +680,272 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
continue;
}
// fallback to general-purpose std::regex / std::wregex
try {
// if a unicode category is used in the regex, we use the collapsed text and replace the unicode category
// with the corresponding collapsed representation
bool use_collapsed = false;
for (auto & ucat : k_ucat_enum) {
if (std::string::npos != regex_expr.find(ucat.first)) {
use_collapsed = true;
std::vector<std::pair<uint32_t, uint32_t>> regex_expr_ranges; // start codepoint, last codepoint
std::vector<std::pair<uint32_t, codepoint_categ>> regex_expr_categs; // offset, codepoint category
std::map<uint16_t, std::wstring> map_categ_wregex; // categ --> regex utf32 string
std::wstring wregex_collapsed;
std::wstring wtext_collapsed;
bool inside_square = false;
bool is_cpt_range = false;
// common ranges: \w \d
regex_expr_ranges.emplace_back('a', 'z');
regex_expr_ranges.emplace_back('A', 'Z');
regex_expr_ranges.emplace_back('0', '9');
regex_expr_ranges.emplace_back('_', '_');
// common ranges: \s
for (uint32_t cpt : unicode_vec_whitespace) {
const auto categ_prev = unicode_cpt_category(regex_expr_ranges.back().second);
const auto categ_last = unicode_cpt_category(cpt);
if (categ_prev == categ_last && regex_expr_ranges.back().second + 1 == cpt) {
regex_expr_ranges.back().second = cpt;
} else {
regex_expr_ranges.emplace_back(cpt, cpt);
}
}
// std::wregex \s does not match non-ASCII whitespaces
static const codepoint_categ categ_whitespace(codepoint_categ::MASK + 1); // UNDEF category, subcategory 1
std::wstring & wregex_whitespaces = map_categ_wregex[categ_whitespace.get_subcategory()];
wregex_whitespaces += L"\\s";
for (uint32_t cpt : unicode_vec_whitespace) {
if (cpt >= 0x80) { // non-ASCII whitespaces
if (wregex_whitespaces.back() + 1 == cpt) {
if (*(wregex_whitespaces.end() - 2) == '-') {
wregex_whitespaces.back() = cpt;
} else {
wregex_whitespaces += '-';
wregex_whitespaces += cpt;
}
} else {
wregex_whitespaces += cpt;
}
}
}
const auto cpts_regex = unicode_cpts_from_utf8(regex_expr);
for (size_t i = 0; i < cpts_regex.size(); ++i) {
uint32_t cpt = cpts_regex[i];
if (inside_square) {
switch(cpt) {
case '^':
if (cpts_regex[i - 1] != '[') {
break;
}
continue;
case ']':
inside_square = false;
continue;
case '-':
is_cpt_range = true;
continue;
}
} else {
switch(cpt) {
case '^':
if (i > 0) {
break;
}
continue;
case '$':
if (i + 1 < cpts_regex.size()) {
break;
}
continue;
case '[':
inside_square = true;
continue;
case '{':
while (cpt && cpt != '}') {
cpt = cpts_regex[++i];
}
continue;
case '}':
case ']':
assert (false);
case '(':
if (cpts_regex[i + 1] == '?') { // (?: (?i: (?= (?! (?<= (?<!
if (cpts_regex[i + 2] == ':') {
i += 2;
} else if (cpts_regex[i + 2] == 'i') {
i += 3;
assert (cpts_regex[i] == ':');
} else {
i += 2 + (cpts_regex[i + 2] == '<');
assert (cpts_regex[i] == '=' || cpts_regex[i] == '!');
}
}
continue;
case ')':
case '|':
case '.':
case '?':
case '+':
case '*':
continue;
}
}
if (cpt == '\\' && cpts_regex[i + 1] == 'p' && cpts_regex[i + 2] == '{') {
assert (cpts_regex[i + 3] && cpts_regex[i + 4]);
codepoint_categ categ = {};
if (cpts_regex[i + 4] == '}') {
categ = codepoint_categ::from_chars((char)cpts_regex[i + 3]);
} else {
categ = codepoint_categ::from_chars((char)cpts_regex[i + 3], (char)cpts_regex[i + 4]);
assert (cpts_regex[i + 5] == '}');
}
categ.set_flag(codepoint_categ::WHITESPACE, inside_square); //NOTE: reusing flag 'WHITESPACE' to store 'inside square brackets'
regex_expr_categs.emplace_back(i, categ);
i += cpts_regex[i + 4] == '}' ? 4 : 5;
continue;
}
if (cpt == '\\') {
if (cpts_regex[i + 1] == 's' || cpts_regex[i + 1] == 'S') { // \s \S
regex_expr_categs.emplace_back(i, categ_whitespace);
//NOTE: reusing flag 'WHITESPACE' to store 'inside square brackets'
regex_expr_categs.back().second.set_flag(codepoint_categ::WHITESPACE, inside_square);
i += 1;
continue;
}
}
if (cpt == '\\') {
switch (cpts_regex[i + 1]) {
case 's': ++i; continue; // \s whitespaces
case 'w': ++i; continue; // \w words
case 'd': ++i; continue; // \d digits
case 'S': ++i; continue; // \S no whitespaces
case 'W': ++i; continue; // \W no words
case 'D': ++i; continue; // \D no digits
case 't': ++i; cpt = '\t'; break;
case 'r': ++i; cpt = '\r'; break;
case 'n': ++i; cpt = '\n'; break;
case 'x': assert (false); break; //TODO: hex values
case 'u': assert (false); break; //TODO: unicode values
case 'U': assert (false); break; //TODO: unicode values
default: // escaped character
assert (!is_cpt_range);
cpt = cpts_regex[++i];
assert (cpt < 0x80);
break;
}
}
if (use_collapsed) {
// sanity-check that the original regex does not contain any non-ASCII characters
const auto cpts_regex = unicode_cpts_from_utf8(regex_expr);
for (size_t i = 0; i < cpts_regex.size(); ++i) {
if (cpts_regex[i] >= 128) {
throw std::runtime_error("Regex includes both unicode categories and non-ASCII characters - not supported");
}
}
assert (cpt < COLLAPSE_CPT_RANGE_FIRST || COLLAPSE_CPT_RANGE_LAST < cpt);
// generate a collapsed representation of the regex
std::string regex_expr_collapsed;
// track if we are inside [], because nested [] are not allowed
bool inside = false;
for (size_t i = 0; i < regex_expr.size(); ++i) {
if (regex_expr[i] == '[' && (i == 0 || regex_expr[i - 1] != '\\')) {
regex_expr_collapsed += '[';
inside = true;
continue;
}
if (inside && regex_expr[i] == ']' && regex_expr[i - 1] != '\\') {
regex_expr_collapsed += ']';
inside = false;
continue;
}
if (regex_expr[i + 0] == '\\' && i + 4 < regex_expr.size() &&
regex_expr[i + 1] == 'p' &&
regex_expr[i + 2] == '{' &&
regex_expr[i + 4] == '}') {
const std::string pat = regex_expr.substr(i, 5);
if (k_ucat_enum.find(pat) != k_ucat_enum.end()) {
if (!inside) {
regex_expr_collapsed += '[';
}
regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat));
regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat));
if (!inside) {
regex_expr_collapsed += ']';
}
i += 4;
continue;
}
}
regex_expr_collapsed += regex_expr[i];
}
//printf("text_collapsed: %s\n", text_collapsed.c_str());
//printf("regex_expr_collapsed: %s\n", regex_expr_collapsed.c_str());
bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets);
if (is_cpt_range) {
is_cpt_range = false;
regex_expr_ranges.back().second = cpt;
} else {
// no unicode category used, we can use std::wregex directly
const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr);
// std::wregex \s does not mach non-ASCII whitespaces, using 0x0B as fallback
std::wstring wtext(cpts.begin(), cpts.end());
for (size_t i = 0; i < wtext.size(); ++i) {
if (wtext[i] > 0x7F && unicode_cpt_category(wtext[i]).is_whitespace()) {
wtext[i] = 0x0B;
}
}
//printf("text: %s\n", text.c_str());
//printf("regex_expr: %s\n", regex_expr.c_str());
bpe_offsets = unicode_regex_split_stl(wtext, wregex_expr, bpe_offsets);
regex_expr_ranges.emplace_back(cpt, cpt);
}
} catch (std::regex_error & e) {
fprintf(stderr, "Failed to process regex: '%s'\n", regex_expr.c_str());
fprintf(stderr, "Regex error: %s\n", e.what());
throw std::runtime_error("Failed to process regex");
}
// assign collapsed codepoint to each category regex \p{...}
for (auto offset_categ : regex_expr_categs) {
const uint16_t subcateg = offset_categ.second.get_subcategory();
auto it = map_categ_wregex.find(subcateg);
if (it == map_categ_wregex.end()) {
const auto collapsed_range = category_to_collapsed_range(offset_categ.second);
map_categ_wregex[subcateg] = (wchar_t) collapsed_range.first;
if (collapsed_range.first < collapsed_range.second) {
map_categ_wregex[subcateg] += (wchar_t) '-';
map_categ_wregex[subcateg] += (wchar_t) collapsed_range.second;
}
}
}
// copy found regex ranges to each category regex
uint32_t regex_expr_ranges_uniques = 0;
std::pair<uint32_t, uint32_t> prev_range = {0, -1};
std::sort(regex_expr_ranges.begin(), regex_expr_ranges.end());
for (auto range : regex_expr_ranges) {
range.first = std::max(range.first, prev_range.second + 1); // prevent overlapping //TODO: as error?
if (range.first > range.second) { // skip overlapping and repetitions
continue;
}
codepoint_categ categ = unicode_cpt_category(range.first);
assert (categ == unicode_cpt_category(range.second));
auto it0 = map_categ_wregex.find(categ.get_category());
auto it1 = map_categ_wregex.find(categ.get_subcategory());
for (const auto & it : {it0, it1}) {
if (it != map_categ_wregex.end()) {
it->second += (wchar_t) range.first;
if (range.first < range.second) {
it->second += (wchar_t) '-';
it->second += (wchar_t) range.second;
}
}
}
prev_range = range;
regex_expr_ranges[regex_expr_ranges_uniques++] = range;
}
regex_expr_ranges.resize(regex_expr_ranges_uniques);
// replace categories with respective collapsed codepoint and ranges
uint32_t i = 0;
wregex_collapsed.reserve(regex_expr.size());
for (auto offset_categ : regex_expr_categs) {
while (i < offset_categ.first) { // copy original regex until reaching the category
wregex_collapsed += (wchar_t) cpts_regex[i];
i++;
}
assert (cpts_regex[i] == '\\');
const uint32_t cpt_next = cpts_regex[i + 1];
const bool is_negated = cpt_next < 'a'; // is uppercase
if (cpt_next == 'p' || cpt_next == 'P') {
assert (cpts_regex[i + 2] == '{' && cpts_regex[i + 3]);
i += cpts_regex[i + 4] == '}' ? 5 : 6;
assert (cpts_regex[i - 1] == '}');
} else {
assert (cpt_next == 's' || cpt_next == 'w' || cpt_next == 'd' || // \s \w \d
cpt_next == 'S' || cpt_next == 'W' || cpt_next == 'D'); // \S \W \D
i += 2;
}
const codepoint_categ categ = offset_categ.second;
auto it = map_categ_wregex.find(categ.get_subcategory());
assert (it != map_categ_wregex.end());
if (it != map_categ_wregex.end()) {
if (categ.is_whitespace()) { // inside square brackets //NOTE: reusing flag WHITESPACE
assert (is_negated == false);
wregex_collapsed += it->second;
} else if(it->second.size() == 1 && !is_negated) {
wregex_collapsed += it->second;
} else {
wregex_collapsed += '[';
if (is_negated) {
wregex_collapsed += '^';
}
wregex_collapsed += it->second;
wregex_collapsed += ']';
}
}
}
while (i < (uint32_t)cpts_regex.size()) {
wregex_collapsed += cpts_regex[i];
i++;
}
// collapse text codepoints not included in 'regex_expr_ranges'
wtext_collapsed.reserve(cpts.size());
for (uint32_t cpt : cpts) {
const codepoint_categ categ = unicode_cpt_category(cpt);
auto it = std::lower_bound(regex_expr_ranges.begin(), regex_expr_ranges.end(), cpt,
[] (const std::pair<uint32_t, uint32_t> range, const uint32_t cpt) {
return range.second < cpt;
}
);
if (it == regex_expr_ranges.end() || cpt < it->first || it->second < cpt) {
cpt = category_to_collapsed_cpt(categ); // not found, collapse to category codepoint
}
wtext_collapsed += (wchar_t) cpt;
}
bpe_offsets = unicode_regex_split_stl(wtext_collapsed, wregex_collapsed, bpe_offsets);
}
std::vector<std::string> bpe_words;