Add more comments

This commit is contained in:
jaime-m-p 2024-08-04 23:22:56 +02:00
parent 1cd7ac090b
commit aeac342132

View file

@ -636,13 +636,47 @@ uint32_t unicode_tolower(uint32_t cp) {
}
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
//TODO: update and add more comments
// generate a "collapsed" representation of the text, where all codepoints are replaced by a single byte
// ref: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2081479935
// std::regex does not support unicode categories: \p{N}, \p{L}, \p{Lu}, \p{Ll} ...
// std::regex does not support unicode whitespaces \s: 0x85, 0xA0, 0x001680 ... 0x003000.
// Generate a "collapsed" representation of the regex, where all unicode categories are replaced by codepoints ranges.
// Generate a "collapsed" representation of the text, where all codepoints are forced to fall into generated category ranges.
// Text codepoints not found in generated category ranges are replaced by a "collapsed" codepoint.
// This implementation generalizes the original implementation adding support to unicode subcategories:
// https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2081479935
// Definitions:
// - Unicode cagegory: high unicode categories, \p{C}, \p{L}, \p{M}, \p{N}, \p{P}, \p{S}, \p{Z}.
// - Unicode subcagegory: including all unicode categories, \p{Cc}, \p{Cf}, \p{Co}, \p{Cs}, ..., \p{Zs}.
// - Collapsed codepoint: unused codepoint representing a unicode subcategory.
// - Collapsed range: sequence of "collapsed" codepoint, representing one unicode category.
// - Collapsed regex: original regex including "collapsed" codepoints and ranges.
// (1) Build the "collapsed" regex:
// (1.1) Generate a replacement list of codepoint ranges:
// (1.1.1) For each unicode category.
// (1.1.2) For each unicode subcategory.
// (1.1.3) Expand \s adding unicode whitespaces.
// (1.2) Each list includes its respective "collaped" codepoint/range.
// (1.3) [Optimization] Only build lists of categories present in the regex.
// (1.4) Build the "collapsed" regex replacing categories and subcategories by this "collapsed" lists.
// (2) Build a list of codepoint ranges.
// (2.1) If a codepoint is not found in this list, then it is "collapsable".
// (2.2) [Optimization] Only build lists of ranges present in the regex.
// (3) For each input text:
// (3.1) Search codepoints in the regex codepoint ranges.
// (3.2) If found, it is a valid codepoint (the "collapsed" regex uses it), literal copy.
// (3.3) If not found, replace with its "collapsed" codepoint so the "collapsed" regex can process it.
//TODO: Refactor optimizations
// Steps (1) and (2) only depends on the regex expression text.
// Step (3) needs 'regex_expr_ranges' for text "collapsing" and 'wregex_collapsed'.
// Optimization: store and reuse 'wregex_collapsed' and 'regex_expr_ranges'.
// 0xDB80 to 0xDBFF: Private Use High Surrogate (128 range values)
static const uint32_t COLLAPSE_CPT_RANGE_FIRST = 0xDB80;
static const uint32_t COLLAPSE_CPT_RANGE_LAST = 0xDBFF;
// return the collapsed codepoint of an unicode category or subcategory
auto category_to_collapsed_cpt = [] (const codepoint_categ categ) {
const uint16_t subindex = categ.get_subcategory() >> 7; // subcategory stored in 3 bits
switch(categ.get_category()) { // category fits in other 3 bits
@ -657,6 +691,8 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
default: assert (false); return COLLAPSE_CPT_RANGE_FIRST;
}
};
// return the collapsed range of an unicode category (range including all subcategories)
auto category_to_collapsed_range = [&] (const codepoint_categ categ) {
// \p{Ll} --> \p{Ll} to \p{Ll} // has subcategory ? yes
// \p{Lu} --> \p{Lu} to \p{Lu} // has subcategory ? yes
@ -688,11 +724,14 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
bool inside_square = false;
bool is_cpt_range = false;
// (2) Build a list of codepoint ranges
// common ranges: \w \d
regex_expr_ranges.emplace_back('a', 'z');
regex_expr_ranges.emplace_back('A', 'Z');
regex_expr_ranges.emplace_back('0', '9');
regex_expr_ranges.emplace_back('_', '_');
// (2) Build a list of codepoint ranges
// common ranges: \s
for (uint32_t cpt : unicode_vec_whitespace) {
const auto categ_prev = unicode_cpt_category(regex_expr_ranges.back().second);
@ -704,6 +743,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
}
}
// (1.1.3) Expand \s adding unicode whitespaces.
// std::wregex \s does not match non-ASCII whitespaces
static const codepoint_categ categ_whitespace(codepoint_categ::MASK + 1); // UNDEF category, subcategory 1
std::wstring & wregex_whitespaces = map_categ_wregex[categ_whitespace.get_subcategory()];
@ -728,6 +768,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
for (size_t i = 0; i < cpts_regex.size(); ++i) {
uint32_t cpt = cpts_regex[i];
// skip regex metacharacters
if (inside_square) {
switch(cpt) {
case '^':
@ -788,6 +829,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
}
}
// parse unicode categories and subcategories
if (cpt == '\\' && cpts_regex[i + 1] == 'p' && cpts_regex[i + 2] == '{') {
assert (cpts_regex[i + 3] && cpts_regex[i + 4]);
codepoint_categ categ = {};
@ -797,6 +839,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
categ = codepoint_categ::from_chars((char)cpts_regex[i + 3], (char)cpts_regex[i + 4]);
assert (cpts_regex[i + 5] == '}');
}
// (2) Build a list of codepoint ranges. (2.2) [Optimization] Only build lists of ranges present in the regex.
categ.set_flag(codepoint_categ::WHITESPACE, inside_square); //NOTE: reusing flag 'WHITESPACE' to store 'inside square brackets'
regex_expr_categs.emplace_back(i, categ);
i += cpts_regex[i + 4] == '}' ? 4 : 5;
@ -805,6 +848,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
if (cpt == '\\') {
if (cpts_regex[i + 1] == 's' || cpts_regex[i + 1] == 'S') { // \s \S
// (2) Build a list of codepoint ranges. (2.2) [Optimization] Only build lists of ranges present in the regex.
regex_expr_categs.emplace_back(i, categ_whitespace);
//NOTE: reusing flag 'WHITESPACE' to store 'inside square brackets'
regex_expr_categs.back().second.set_flag(codepoint_categ::WHITESPACE, inside_square);
@ -813,6 +857,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
}
}
// parse more metcharacters and espaped characters
if (cpt == '\\') {
switch (cpts_regex[i + 1]) {
case 's': ++i; continue; // \s whitespaces
@ -835,8 +880,10 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
}
}
// ensure there is not a collission with any "collapsed" codepoints
assert (cpt < COLLAPSE_CPT_RANGE_FIRST || COLLAPSE_CPT_RANGE_LAST < cpt);
// (2) Build a list of codepoint ranges
if (is_cpt_range) {
is_cpt_range = false;
regex_expr_ranges.back().second = cpt;
@ -850,6 +897,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
const uint16_t subcateg = offset_categ.second.get_subcategory();
auto it = map_categ_wregex.find(subcateg);
if (it == map_categ_wregex.end()) {
// (1.2) Each list includes its respective "collaped" codepoint/range.
const auto collapsed_range = category_to_collapsed_range(offset_categ.second);
map_categ_wregex[subcateg] = (wchar_t) collapsed_range.first;
if (collapsed_range.first < collapsed_range.second) {
@ -868,6 +916,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
if (range.first > range.second) { // skip overlapping and repetitions
continue;
}
// (1.1) Generate a replacement list of codepoint ranges
codepoint_categ categ = unicode_cpt_category(range.first);
assert (categ == unicode_cpt_category(range.second));
auto it0 = map_categ_wregex.find(categ.get_category());
@ -906,6 +955,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
cpt_next == 'S' || cpt_next == 'W' || cpt_next == 'D'); // \S \W \D
i += 2;
}
// (1.4) Build the "collapsed" regex replacing categories and subcategories by this "collapsed" lists.
const codepoint_categ categ = offset_categ.second;
auto it = map_categ_wregex.find(categ.get_subcategory());
assert (it != map_categ_wregex.end());
@ -934,14 +984,17 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
wtext_collapsed.reserve(cpts.size());
for (uint32_t cpt : cpts) {
const codepoint_categ categ = unicode_cpt_category(cpt);
// (3.1) Search codepoints in the regex codepoint ranges.
auto it = std::lower_bound(regex_expr_ranges.begin(), regex_expr_ranges.end(), cpt,
[] (const std::pair<uint32_t, uint32_t> range, const uint32_t cpt) {
return range.second < cpt;
}
);
if (it == regex_expr_ranges.end() || cpt < it->first || it->second < cpt) {
// (3.3) If not found, replace with its "collapsed" codepoint so the "collapsed" regex can process it.
cpt = category_to_collapsed_cpt(categ); // not found, collapse to category codepoint
}
// (3.2) If found, it is a valid codepoint (the "collapsed" regex uses it), literal copy.
wtext_collapsed += (wchar_t) cpt;
}