unicode : category support via std::regex

This commit is contained in:
Georgi Gerganov 2024-04-28 13:42:00 +03:00
parent 581c4a0239
commit b97add52a4
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
4 changed files with 172 additions and 99 deletions

View file

@ -202,15 +202,10 @@ static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
return conv.from_bytes(s);
}
static inline std::string unicode_wstring_to_utf8(const std::wstring & ws) {
#if defined(_MSC_VER)
std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>, wchar_t> converter;
return converter.to_bytes(ws);
#else
std::wstring_convert<std::codecvt_utf8<wchar_t>> conv;
return conv.to_bytes(ws);
#endif
}
//static inline std::string unicode_wstring_to_utf8(const std::wstring & ws) {
// std::wstring_convert<std::codecvt_utf8<wchar_t>> conv;
// return conv.to_bytes(ws);
//}
static std::vector<std::string> unicode_byte_encoding_process(const std::vector<std::string> & bpe_words) {
std::vector<std::string>bpe_encoded_words;
@ -229,16 +224,18 @@ static std::vector<std::string> unicode_byte_encoding_process(const std::vector<
return bpe_encoded_words;
}
static std::vector<size_t> unicode_gpt2_regex_preprocess(const std::wstring & wtext, const std::vector<size_t> & offsets) {
// GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
static std::vector<size_t> unicode_gpt2_regex_preprocess(const std::string & text, const std::vector<size_t> & offsets) {
std::vector<size_t> bpe_offsets; // store the offset of each word
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
size_t start = 0;
for (auto offset : offsets) {
const std::string text = unicode_wstring_to_utf8(std::wstring(wtext, start, offset));
const auto cpts = unicode_cpts_from_utf8(text);
for (auto offset : offsets) {
std::string token;
std::string token = "";
// GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
bool collecting_numeric = false;
bool collecting_letter = false;
bool collecting_special = false;
@ -246,10 +243,9 @@ static std::vector<size_t> unicode_gpt2_regex_preprocess(const std::wstring & wt
bool collecting = false;
std::vector<std::string> text_utf;
text_utf.reserve(text.size());
text_utf.reserve(offset);
const auto cpts = unicode_cpts_from_utf8(text);
for (size_t i = 0; i < cpts.size(); ++i) {
for (size_t i = start; i < start + offset; ++i) {
text_utf.emplace_back(unicode_cpt_to_utf8(cpts[i]));
}
@ -270,10 +266,10 @@ static std::vector<size_t> unicode_gpt2_regex_preprocess(const std::wstring & wt
}
if (split_condition) {
if (token.size()) {
bpe_offsets.emplace_back(unicode_wstring_from_utf8(token).size());
bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
}
token = utf_char + utf_char_next;
bpe_offsets.emplace_back(unicode_wstring_from_utf8(token).size());
bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
token = "";
i++;
continue;
@ -291,10 +287,10 @@ static std::vector<size_t> unicode_gpt2_regex_preprocess(const std::wstring & wt
if (split_condition) {
// current token + next token can be defined
if (token.size()) {
bpe_offsets.emplace_back(unicode_wstring_from_utf8(token).size());
bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
}
token = utf_char + utf_char_next + utf_char_next_next;
bpe_offsets.emplace_back(unicode_wstring_from_utf8(token).size());
bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
token = "";
i += 2;
continue;
@ -347,7 +343,7 @@ static std::vector<size_t> unicode_gpt2_regex_preprocess(const std::wstring & wt
if (split_condition) {
if (token.size()) {
bpe_offsets.emplace_back(unicode_wstring_from_utf8(token).size());
bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
}
token = utf_char;
collecting = false;
@ -367,13 +363,13 @@ static std::vector<size_t> unicode_gpt2_regex_preprocess(const std::wstring & wt
return bpe_offsets;
}
static std::vector<size_t> unicode_regex_preprocess(const std::wstring & text, const std::vector<size_t> & offsets, const std::wstring & regex_expr) {
static std::vector<size_t> unicode_regex_preprocess(const std::wstring & wtext, const std::vector<size_t> & offsets, const std::wstring & regex_expr) {
std::wregex expr(regex_expr);
std::vector<size_t> bpe_offsets; // store the offset of each word
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
size_t start = 0;
for (auto offset : offsets) {
std::wcregex_iterator it(text.data() + start, text.data() + start + offset, expr);
std::wcregex_iterator it(wtext.data() + start, wtext.data() + start + offset, expr);
std::wcregex_iterator end;
int64_t start_idx = 0;
@ -396,19 +392,45 @@ static std::vector<size_t> unicode_regex_preprocess(const std::wstring & text, c
return bpe_offsets;
}
static bool unicode_regex_equivalent_wregex_exists(const std::string & regex) {
return unicode_regex_equivalent_wregex.find(regex) != unicode_regex_equivalent_wregex.end();
// use std::regex to split the text
static std::vector<size_t> unicode_regex_preprocess(const std::string & text, const std::vector<size_t> & offsets, const std::string & regex_expr) {
std::regex expr(regex_expr);
std::vector<size_t> bpe_offsets; // store the offset of each word
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
size_t start = 0;
for (auto offset : offsets) {
std::cregex_iterator it(text.data() + start, text.data() + start + offset, expr);
std::cregex_iterator end;
int64_t start_idx = 0;
while (it != end) {
std::cmatch match = *it;
if (match.position() > start_idx) {
bpe_offsets.emplace_back(match.position() - start_idx);
}
bpe_offsets.emplace_back(match.length());
start_idx = match.position() + match.length();
++it;
}
if (start_idx < (int64_t) offset) {
bpe_offsets.emplace_back(offset - start_idx);
}
start += offset;
}
return bpe_offsets;
}
static bool unicode_regex_with_custom_preprocessor_exists(const std::string & regex) {
return unicode_regex_with_custom_preprocessor.find(regex) != unicode_regex_with_custom_preprocessor.end();
}
static std::vector<size_t> unicode_regex_custom_preprocess(const std::string & regex, const std::wstring & wtext, const std::vector<size_t> & offsets) {
static std::vector<size_t> unicode_regex_custom_preprocess(const std::string & regex, const std::string & text, const std::vector<size_t> & offsets) {
std::vector<size_t> bpe_offsets;
if (regex == "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") {
bpe_offsets = unicode_gpt2_regex_preprocess(wtext, offsets);
bpe_offsets = unicode_gpt2_regex_preprocess(text, offsets);
}
return bpe_offsets;
@ -496,82 +518,144 @@ char32_t unicode_tolower(char32_t cp) {
return it == unicode_map_lowercase.end() ? cp : it->second;
}
static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
std::string result;
for (size_t pos = 0; ; pos += search.length()) {
auto new_pos = s.find(search, pos);
if (new_pos == std::string::npos) {
result += s.substr(pos, s.size() - pos);
break;
}
result += s.substr(pos, new_pos - pos) + replace;
pos = new_pos;
}
s = std::move(result);
}
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
std::wstring wtext = unicode_wstring_from_utf8(text);
// unicode categories
static const std::map<std::string, int> k_ucat_enum = {
{ "\\p{N}", CODEPOINT_TYPE_DIGIT },
{ "\\p{L}", CODEPOINT_TYPE_LETTER },
{ "\\p{P}", CODEPOINT_TYPE_PUNCTUATION },
};
static const std::map<int, int> k_ucat_cpt = {
{ CODEPOINT_TYPE_DIGIT, 0xD1 },
{ CODEPOINT_TYPE_LETTER, 0xD2 },
{ CODEPOINT_TYPE_PUNCTUATION, 0xD3 },
};
static const std::map<int, std::string> k_ucat_map = {
{ CODEPOINT_TYPE_DIGIT, "\x30-\x39" },
{ CODEPOINT_TYPE_LETTER, "\x41-\x5A\x61-\x7A" },
{ CODEPOINT_TYPE_PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" },
};
// compute collapsed codepoints only if needed by at least one regex
bool need_collapse = false;
for (auto & regex_expr : regex_exprs) {
// search for \\p{L} or \\p{N}
if (std::string::npos != regex_expr.find("\\p{N}") ||
std::string::npos != regex_expr.find("\\p{L}") ||
std::string::npos != regex_expr.find("\\p{P}")) {
need_collapse = true;
break;
}
}
std::wstring wtext_collapsed;
if (need_collapse) {
// collapse all digit, letter and punctuation cpts to a single codepoint
// the collapsed codepoint is selected to be the one at the end of the range
//
// - convert text to cpts
// - collapse digit cpts to 0x0001FBF9
// - collapse letter cpts to 0x0003134A
// - collapse punctuation cpts to 0x0001E95F
// - convert back to text
auto cpts = unicode_cpts_from_utf8(text);
for (size_t i = 0; i < cpts.size(); ++i) {
if (unicode_cpt_type(cpts[i]) == CODEPOINT_TYPE_DIGIT) {
cpts[i] = 0x0001FBF9;
} else if (unicode_cpt_type(cpts[i]) == CODEPOINT_TYPE_LETTER) {
cpts[i] = 0x0003134A;
} else if (unicode_cpt_type(cpts[i]) == CODEPOINT_TYPE_PUNCTUATION) {
cpts[i] = 0x0001E95F;
// search for unicode categories
for (auto & ucat : k_ucat_enum) {
if (std::string::npos != regex_expr.find(ucat.first)) {
need_collapse = true;
break;
}
}
wtext_collapsed = unicode_wstring_from_utf8(unicode_cpts_to_utf8(cpts));
}
std::vector<size_t> bpe_offsets = {wtext.size()};
std::string text_collapsed;
if (need_collapse) {
// collapse all unicode categories
const auto cpts = unicode_cpts_from_utf8(text);
text_collapsed.resize(cpts.size());
for (size_t i = 0; i < cpts.size(); ++i) {
// keep single-byte codepoints as is
if (cpts[i] < 128) {
text_collapsed[i] = cpts[i];
continue;
}
const int cpt_type = unicode_cpt_type(cpts[i]);
if (k_ucat_cpt.find(cpt_type) != k_ucat_cpt.end()) {
text_collapsed[i] = k_ucat_cpt.at(cpt_type);
} else {
text_collapsed[i] = (char) 0xD0; // fallback
}
}
}
const auto cpts = unicode_cpts_from_utf8(text);
std::vector<size_t> bpe_offsets = { cpts.size() };
for (auto & regex_expr : regex_exprs) {
if (unicode_regex_with_custom_preprocessor_exists(regex_expr)) {
bpe_offsets = unicode_regex_custom_preprocess(regex_expr, wtext, bpe_offsets);
//} else if (unicode_regex_equivalent_wregex_exists(regex_expr)) {
// const std::wstring & wregex_expr = unicode_regex_equivalent_wregex.at(regex_expr);
// bpe_offsets = unicode_regex_preprocess(wtext, bpe_offsets, wregex_expr);
bpe_offsets = unicode_regex_custom_preprocess(regex_expr, text, bpe_offsets);
} else {
// fallback
try {
// if a unicode category is used in the regex, we use the collapsed text and replace the unicode category
// with the corresponding collapsed codepoint
if (std::string::npos != regex_expr.find("\\p{N}") ||
std::string::npos != regex_expr.find("\\p{L}") ||
std::string::npos != regex_expr.find("\\p{P}")) {
// replace \\p{N} with \U0001FBF9
// replace \\p{L} with \U0003134A
// replace \\p{P} with \U0001E95F
std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr);
for (size_t i = 0; i < wregex_expr.size(); ++i) {
if (wregex_expr[i] == L'\\' && i + 1 < wregex_expr.size()) {
if (wregex_expr[i + 1] == L'p' && i + 3 < wregex_expr.size()) {
if (wregex_expr[i + 2] == L'{' && wregex_expr[i + 4] == L'}') {
if (wregex_expr[i + 3] == L'N') {
wregex_expr.replace(i, 5, L"\U0001FBF9");
} else if (wregex_expr[i + 3] == L'L') {
wregex_expr.replace(i, 5, L"\U0003134A");
} else if (wregex_expr[i + 3] == L'P') {
wregex_expr.replace(i, 5, L"\U0001E95F");
// with the corresponding collapsed representation
bool use_collapsed = false;
for (auto & ucat : k_ucat_enum) {
if (std::string::npos != regex_expr.find(ucat.first)) {
use_collapsed = true;
break;
}
}
if (use_collapsed) {
std::string regex_expr_collapsed;
// track if we are inside []
bool inside = false;
for (size_t i = 0; i < regex_expr.size(); ++i) {
if (regex_expr[i] == '[' && (i == 0 || regex_expr[i - 1] != '\\')) {
regex_expr_collapsed += '[';
inside = true;
continue;
} else if (inside && regex_expr[i] == ']' && regex_expr[i - 1] != '\\') {
regex_expr_collapsed += ']';
inside = false;
continue;
}
if (regex_expr[i] == '\\' && i + 1 < regex_expr.size()) {
if (regex_expr[i + 1] == 'p') {
if (i + 3 < regex_expr.size() && regex_expr[i + 2] == '{') {
if (regex_expr[i + 4] == '}') {
const std::string pat = regex_expr.substr(i, 5);
if (k_ucat_enum.find(pat) != k_ucat_enum.end()) {
if (!inside) {
regex_expr_collapsed += '[';
}
regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat));
regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat));
if (!inside) {
regex_expr_collapsed += ']';
}
i += 4;
continue;
}
}
}
}
}
regex_expr_collapsed += regex_expr[i];
}
bpe_offsets = unicode_regex_preprocess(wtext_collapsed, bpe_offsets, wregex_expr);
//printf("text_collapsed: %s\n", text_collapsed.c_str());
//printf("regex_expr_collapsed: %s\n", regex_expr_collapsed.c_str());
bpe_offsets = unicode_regex_preprocess(text_collapsed, bpe_offsets, regex_expr_collapsed);
} else {
// no unicode category used, we can use std::wregex directly
const std::wstring wtext = unicode_wstring_from_utf8(text);
const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr);
//printf("text: %s\n", text.c_str());
//printf("regex_expr: %s\n", regex_expr.c_str());
bpe_offsets = unicode_regex_preprocess(wtext, bpe_offsets, wregex_expr);
}
} catch (std::regex_error & e) {
@ -584,9 +668,14 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
std::vector<std::string> bpe_words;
bpe_words.reserve(bpe_offsets.size()); // Reserve memory for the approximate size
size_t start = 0;
for (size_t & offset : bpe_offsets) {
bpe_words.emplace_back(unicode_wstring_to_utf8(std::wstring(wtext, start, offset)));
//bpe_words.emplace_back(unicode_wstring_to_utf8(std::wstring(wtext, start, offset)));
bpe_words.emplace_back();
for (size_t i = start; i < start + offset; ++i) {
bpe_words.back() += unicode_cpt_to_utf8(cpts[i]);
}
start += offset;
}