unicode : category support via std::regex
This commit is contained in:
parent
581c4a0239
commit
b97add52a4
4 changed files with 172 additions and 99 deletions
|
@ -4290,7 +4290,7 @@ static void llm_load_vocab(
|
|||
}
|
||||
|
||||
if (tokenizer_pre.empty()) {
|
||||
LLAMA_LOG_WARN("%s: missing tokenizer pre, using default tokenizer pre: 'default'\n", __func__);
|
||||
LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__);
|
||||
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
|
||||
} else if (tokenizer_pre == "default") {
|
||||
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
|
||||
|
|
File diff suppressed because one or more lines are too long
|
@ -16,5 +16,4 @@ extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_symbol;
|
|||
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_control;
|
||||
extern const std::multimap<uint32_t, uint32_t> unicode_map_nfd;
|
||||
extern const std::map<char32_t, char32_t> unicode_map_lowercase;
|
||||
extern const std::map<std::string, std::wstring> unicode_regex_equivalent_wregex;
|
||||
extern const std::set<std::string> unicode_regex_with_custom_preprocessor;
|
||||
|
|
243
unicode.cpp
243
unicode.cpp
|
@ -202,15 +202,10 @@ static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
|
|||
return conv.from_bytes(s);
|
||||
}
|
||||
|
||||
// Encode a wide string as UTF-8 bytes.
//
// The codecvt facet is chosen per toolchain: the MSVC branch uses
// std::codecvt_utf8_utf16<wchar_t>, every other compiler uses
// std::codecvt_utf8<wchar_t>.
static inline std::string unicode_wstring_to_utf8(const std::wstring & ws) {
#if defined(_MSC_VER)
    using utf8_facet = std::codecvt_utf8_utf16<wchar_t>;
#else
    using utf8_facet = std::codecvt_utf8<wchar_t>;
#endif
    std::wstring_convert<utf8_facet, wchar_t> utf8_codec;
    return utf8_codec.to_bytes(ws);
}
|
||||
//static inline std::string unicode_wstring_to_utf8(const std::wstring & ws) {
|
||||
// std::wstring_convert<std::codecvt_utf8<wchar_t>> conv;
|
||||
// return conv.to_bytes(ws);
|
||||
//}
|
||||
|
||||
static std::vector<std::string> unicode_byte_encoding_process(const std::vector<std::string> & bpe_words) {
|
||||
std::vector<std::string>bpe_encoded_words;
|
||||
|
@ -229,16 +224,18 @@ static std::vector<std::string> unicode_byte_encoding_process(const std::vector<
|
|||
return bpe_encoded_words;
|
||||
}
|
||||
|
||||
static std::vector<size_t> unicode_gpt2_regex_preprocess(const std::wstring & wtext, const std::vector<size_t> & offsets) {
|
||||
// GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
|
||||
static std::vector<size_t> unicode_gpt2_regex_preprocess(const std::string & text, const std::vector<size_t> & offsets) {
|
||||
std::vector<size_t> bpe_offsets; // store the offset of each word
|
||||
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
|
||||
|
||||
size_t start = 0;
|
||||
|
||||
for (auto offset : offsets) {
|
||||
const std::string text = unicode_wstring_to_utf8(std::wstring(wtext, start, offset));
|
||||
const auto cpts = unicode_cpts_from_utf8(text);
|
||||
|
||||
for (auto offset : offsets) {
|
||||
std::string token;
|
||||
|
||||
std::string token = "";
|
||||
// GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
|
||||
bool collecting_numeric = false;
|
||||
bool collecting_letter = false;
|
||||
bool collecting_special = false;
|
||||
|
@ -246,10 +243,9 @@ static std::vector<size_t> unicode_gpt2_regex_preprocess(const std::wstring & wt
|
|||
bool collecting = false;
|
||||
|
||||
std::vector<std::string> text_utf;
|
||||
text_utf.reserve(text.size());
|
||||
text_utf.reserve(offset);
|
||||
|
||||
const auto cpts = unicode_cpts_from_utf8(text);
|
||||
for (size_t i = 0; i < cpts.size(); ++i) {
|
||||
for (size_t i = start; i < start + offset; ++i) {
|
||||
text_utf.emplace_back(unicode_cpt_to_utf8(cpts[i]));
|
||||
}
|
||||
|
||||
|
@ -270,10 +266,10 @@ static std::vector<size_t> unicode_gpt2_regex_preprocess(const std::wstring & wt
|
|||
}
|
||||
if (split_condition) {
|
||||
if (token.size()) {
|
||||
bpe_offsets.emplace_back(unicode_wstring_from_utf8(token).size());
|
||||
bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
|
||||
}
|
||||
token = utf_char + utf_char_next;
|
||||
bpe_offsets.emplace_back(unicode_wstring_from_utf8(token).size());
|
||||
bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
|
||||
token = "";
|
||||
i++;
|
||||
continue;
|
||||
|
@ -291,10 +287,10 @@ static std::vector<size_t> unicode_gpt2_regex_preprocess(const std::wstring & wt
|
|||
if (split_condition) {
|
||||
// current token + next token can be defined
|
||||
if (token.size()) {
|
||||
bpe_offsets.emplace_back(unicode_wstring_from_utf8(token).size());
|
||||
bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
|
||||
}
|
||||
token = utf_char + utf_char_next + utf_char_next_next;
|
||||
bpe_offsets.emplace_back(unicode_wstring_from_utf8(token).size());
|
||||
bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
|
||||
token = "";
|
||||
i += 2;
|
||||
continue;
|
||||
|
@ -347,7 +343,7 @@ static std::vector<size_t> unicode_gpt2_regex_preprocess(const std::wstring & wt
|
|||
|
||||
if (split_condition) {
|
||||
if (token.size()) {
|
||||
bpe_offsets.emplace_back(unicode_wstring_from_utf8(token).size());
|
||||
bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
|
||||
}
|
||||
token = utf_char;
|
||||
collecting = false;
|
||||
|
@ -367,13 +363,13 @@ static std::vector<size_t> unicode_gpt2_regex_preprocess(const std::wstring & wt
|
|||
return bpe_offsets;
|
||||
}
|
||||
|
||||
static std::vector<size_t> unicode_regex_preprocess(const std::wstring & text, const std::vector<size_t> & offsets, const std::wstring & regex_expr) {
|
||||
static std::vector<size_t> unicode_regex_preprocess(const std::wstring & wtext, const std::vector<size_t> & offsets, const std::wstring & regex_expr) {
|
||||
std::wregex expr(regex_expr);
|
||||
std::vector<size_t> bpe_offsets; // store the offset of each word
|
||||
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
|
||||
size_t start = 0;
|
||||
for (auto offset : offsets) {
|
||||
std::wcregex_iterator it(text.data() + start, text.data() + start + offset, expr);
|
||||
std::wcregex_iterator it(wtext.data() + start, wtext.data() + start + offset, expr);
|
||||
std::wcregex_iterator end;
|
||||
|
||||
int64_t start_idx = 0;
|
||||
|
@ -396,19 +392,45 @@ static std::vector<size_t> unicode_regex_preprocess(const std::wstring & text, c
|
|||
return bpe_offsets;
|
||||
}
|
||||
|
||||
static bool unicode_regex_equivalent_wregex_exists(const std::string & regex) {
|
||||
return unicode_regex_equivalent_wregex.find(regex) != unicode_regex_equivalent_wregex.end();
|
||||
// Split `text` with std::regex, one input chunk at a time.
//
// `offsets` holds the lengths of consecutive chunks of `text` (in chars). Each chunk is
// scanned with `regex_expr`; every match becomes one output length, and any unmatched
// text before a match or after the last match becomes an output length of its own.
// Returns the refined list of chunk lengths, in order.
static std::vector<size_t> unicode_regex_preprocess(const std::string & text, const std::vector<size_t> & offsets, const std::string & regex_expr) {
    const std::regex pattern(regex_expr);

    std::vector<size_t> pieces;
    pieces.reserve(offsets.size()); // lower bound only; matches may add more entries

    const char * chunk_begin = text.data();

    for (const size_t chunk_len : offsets) {
        const char * chunk_end = chunk_begin + chunk_len;

        // number of chars of the current chunk already emitted
        int64_t consumed = 0;

        for (std::cregex_iterator m(chunk_begin, chunk_end, pattern), last; m != last; ++m) {
            const int64_t pos = m->position();
            const int64_t len = m->length();

            // unmatched gap between the previous match and this one
            if (pos > consumed) {
                pieces.emplace_back(pos - consumed);
            }
            pieces.emplace_back(len);
            consumed = pos + len;
        }

        // unmatched tail of the chunk
        if (consumed < (int64_t) chunk_len) {
            pieces.emplace_back(chunk_len - consumed);
        }

        chunk_begin = chunk_end;
    }

    return pieces;
}
|
||||
|
||||
static bool unicode_regex_with_custom_preprocessor_exists(const std::string & regex) {
|
||||
return unicode_regex_with_custom_preprocessor.find(regex) != unicode_regex_with_custom_preprocessor.end();
|
||||
}
|
||||
|
||||
static std::vector<size_t> unicode_regex_custom_preprocess(const std::string & regex, const std::wstring & wtext, const std::vector<size_t> & offsets) {
|
||||
static std::vector<size_t> unicode_regex_custom_preprocess(const std::string & regex, const std::string & text, const std::vector<size_t> & offsets) {
|
||||
std::vector<size_t> bpe_offsets;
|
||||
|
||||
if (regex == "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") {
|
||||
bpe_offsets = unicode_gpt2_regex_preprocess(wtext, offsets);
|
||||
bpe_offsets = unicode_gpt2_regex_preprocess(text, offsets);
|
||||
}
|
||||
|
||||
return bpe_offsets;
|
||||
|
@ -496,82 +518,144 @@ char32_t unicode_tolower(char32_t cp) {
|
|||
return it == unicode_map_lowercase.end() ? cp : it->second;
|
||||
}
|
||||
|
||||
// Replace every non-overlapping occurrence of `search` in `s` with `replace`, in place.
//
// Occurrences are located left to right in the original contents of `s`; text inserted
// by a replacement is never rescanned. An empty `search` leaves `s` unchanged.
static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
    // An empty pattern matches at every position without advancing the cursor,
    // which would otherwise loop forever while growing the result.
    if (search.empty()) {
        return;
    }

    std::string result;
    result.reserve(s.size());

    size_t pos = 0;
    while (true) {
        const size_t hit = s.find(search, pos);
        if (hit == std::string::npos) {
            // no more matches: keep the remaining tail verbatim
            result.append(s, pos, std::string::npos);
            break;
        }
        result.append(s, pos, hit - pos);
        result += replace;
        pos = hit + search.length();
    }

    s = std::move(result);
}
|
||||
|
||||
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
|
||||
std::wstring wtext = unicode_wstring_from_utf8(text);
|
||||
// unicode categories
|
||||
static const std::map<std::string, int> k_ucat_enum = {
|
||||
{ "\\p{N}", CODEPOINT_TYPE_DIGIT },
|
||||
{ "\\p{L}", CODEPOINT_TYPE_LETTER },
|
||||
{ "\\p{P}", CODEPOINT_TYPE_PUNCTUATION },
|
||||
};
|
||||
|
||||
static const std::map<int, int> k_ucat_cpt = {
|
||||
{ CODEPOINT_TYPE_DIGIT, 0xD1 },
|
||||
{ CODEPOINT_TYPE_LETTER, 0xD2 },
|
||||
{ CODEPOINT_TYPE_PUNCTUATION, 0xD3 },
|
||||
};
|
||||
|
||||
static const std::map<int, std::string> k_ucat_map = {
|
||||
{ CODEPOINT_TYPE_DIGIT, "\x30-\x39" },
|
||||
{ CODEPOINT_TYPE_LETTER, "\x41-\x5A\x61-\x7A" },
|
||||
{ CODEPOINT_TYPE_PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" },
|
||||
};
|
||||
|
||||
// compute collapsed codepoints only if needed by at least one regex
|
||||
bool need_collapse = false;
|
||||
for (auto & regex_expr : regex_exprs) {
|
||||
// search for \\p{L} or \\p{N}
|
||||
if (std::string::npos != regex_expr.find("\\p{N}") ||
|
||||
std::string::npos != regex_expr.find("\\p{L}") ||
|
||||
std::string::npos != regex_expr.find("\\p{P}")) {
|
||||
// search for unicode categories
|
||||
for (auto & ucat : k_ucat_enum) {
|
||||
if (std::string::npos != regex_expr.find(ucat.first)) {
|
||||
need_collapse = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::wstring wtext_collapsed;
|
||||
std::string text_collapsed;
|
||||
if (need_collapse) {
|
||||
// collapse all digit, letter and punctuation cpts to a single codepoint
|
||||
// the collapsed codepoint is selected to be the one at the end of the range
|
||||
//
|
||||
// - convert text to cpts
|
||||
// - collapse digit cpts to 0x0001FBF9
|
||||
// - collapse letter cpts to 0x0003134A
|
||||
// - collapse punctuation cpts to 0x0001E95F
|
||||
// - convert back to text
|
||||
auto cpts = unicode_cpts_from_utf8(text);
|
||||
// collapse all unicode categories
|
||||
const auto cpts = unicode_cpts_from_utf8(text);
|
||||
text_collapsed.resize(cpts.size());
|
||||
|
||||
for (size_t i = 0; i < cpts.size(); ++i) {
|
||||
if (unicode_cpt_type(cpts[i]) == CODEPOINT_TYPE_DIGIT) {
|
||||
cpts[i] = 0x0001FBF9;
|
||||
} else if (unicode_cpt_type(cpts[i]) == CODEPOINT_TYPE_LETTER) {
|
||||
cpts[i] = 0x0003134A;
|
||||
} else if (unicode_cpt_type(cpts[i]) == CODEPOINT_TYPE_PUNCTUATION) {
|
||||
cpts[i] = 0x0001E95F;
|
||||
}
|
||||
}
|
||||
wtext_collapsed = unicode_wstring_from_utf8(unicode_cpts_to_utf8(cpts));
|
||||
// keep single-byte codepoints as is
|
||||
if (cpts[i] < 128) {
|
||||
text_collapsed[i] = cpts[i];
|
||||
continue;
|
||||
}
|
||||
|
||||
std::vector<size_t> bpe_offsets = {wtext.size()};
|
||||
const int cpt_type = unicode_cpt_type(cpts[i]);
|
||||
|
||||
if (k_ucat_cpt.find(cpt_type) != k_ucat_cpt.end()) {
|
||||
text_collapsed[i] = k_ucat_cpt.at(cpt_type);
|
||||
} else {
|
||||
text_collapsed[i] = (char) 0xD0; // fallback
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto cpts = unicode_cpts_from_utf8(text);
|
||||
std::vector<size_t> bpe_offsets = { cpts.size() };
|
||||
|
||||
for (auto & regex_expr : regex_exprs) {
|
||||
if (unicode_regex_with_custom_preprocessor_exists(regex_expr)) {
|
||||
bpe_offsets = unicode_regex_custom_preprocess(regex_expr, wtext, bpe_offsets);
|
||||
//} else if (unicode_regex_equivalent_wregex_exists(regex_expr)) {
|
||||
// const std::wstring & wregex_expr = unicode_regex_equivalent_wregex.at(regex_expr);
|
||||
// bpe_offsets = unicode_regex_preprocess(wtext, bpe_offsets, wregex_expr);
|
||||
bpe_offsets = unicode_regex_custom_preprocess(regex_expr, text, bpe_offsets);
|
||||
} else {
|
||||
// fallback
|
||||
try {
|
||||
// if a unicode category is used in the regex, we use the collapsed text and replace the unicode category
|
||||
// with the corresponding collapsed codepoint
|
||||
if (std::string::npos != regex_expr.find("\\p{N}") ||
|
||||
std::string::npos != regex_expr.find("\\p{L}") ||
|
||||
std::string::npos != regex_expr.find("\\p{P}")) {
|
||||
// replace \\p{N} with \U0001FBF9
|
||||
// replace \\p{L} with \U0003134A
|
||||
// replace \\p{P} with \U0001E95F
|
||||
std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr);
|
||||
for (size_t i = 0; i < wregex_expr.size(); ++i) {
|
||||
if (wregex_expr[i] == L'\\' && i + 1 < wregex_expr.size()) {
|
||||
if (wregex_expr[i + 1] == L'p' && i + 3 < wregex_expr.size()) {
|
||||
if (wregex_expr[i + 2] == L'{' && wregex_expr[i + 4] == L'}') {
|
||||
if (wregex_expr[i + 3] == L'N') {
|
||||
wregex_expr.replace(i, 5, L"\U0001FBF9");
|
||||
} else if (wregex_expr[i + 3] == L'L') {
|
||||
wregex_expr.replace(i, 5, L"\U0003134A");
|
||||
} else if (wregex_expr[i + 3] == L'P') {
|
||||
wregex_expr.replace(i, 5, L"\U0001E95F");
|
||||
// with the corresponding collapsed representation
|
||||
bool use_collapsed = false;
|
||||
for (auto & ucat : k_ucat_enum) {
|
||||
if (std::string::npos != regex_expr.find(ucat.first)) {
|
||||
use_collapsed = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (use_collapsed) {
|
||||
std::string regex_expr_collapsed;
|
||||
|
||||
// track if we are inside []
|
||||
bool inside = false;
|
||||
for (size_t i = 0; i < regex_expr.size(); ++i) {
|
||||
if (regex_expr[i] == '[' && (i == 0 || regex_expr[i - 1] != '\\')) {
|
||||
regex_expr_collapsed += '[';
|
||||
inside = true;
|
||||
continue;
|
||||
} else if (inside && regex_expr[i] == ']' && regex_expr[i - 1] != '\\') {
|
||||
regex_expr_collapsed += ']';
|
||||
inside = false;
|
||||
continue;
|
||||
}
|
||||
if (regex_expr[i] == '\\' && i + 1 < regex_expr.size()) {
|
||||
if (regex_expr[i + 1] == 'p') {
|
||||
if (i + 3 < regex_expr.size() && regex_expr[i + 2] == '{') {
|
||||
if (regex_expr[i + 4] == '}') {
|
||||
const std::string pat = regex_expr.substr(i, 5);
|
||||
if (k_ucat_enum.find(pat) != k_ucat_enum.end()) {
|
||||
if (!inside) {
|
||||
regex_expr_collapsed += '[';
|
||||
}
|
||||
regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat));
|
||||
regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat));
|
||||
if (!inside) {
|
||||
regex_expr_collapsed += ']';
|
||||
}
|
||||
i += 4;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
bpe_offsets = unicode_regex_preprocess(wtext_collapsed, bpe_offsets, wregex_expr);
|
||||
|
||||
regex_expr_collapsed += regex_expr[i];
|
||||
}
|
||||
|
||||
//printf("text_collapsed: %s\n", text_collapsed.c_str());
|
||||
//printf("regex_expr_collapsed: %s\n", regex_expr_collapsed.c_str());
|
||||
bpe_offsets = unicode_regex_preprocess(text_collapsed, bpe_offsets, regex_expr_collapsed);
|
||||
} else {
|
||||
// no unicode category used, we can use std::wregex directly
|
||||
const std::wstring wtext = unicode_wstring_from_utf8(text);
|
||||
const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr);
|
||||
|
||||
//printf("text: %s\n", text.c_str());
|
||||
//printf("regex_expr: %s\n", regex_expr.c_str());
|
||||
bpe_offsets = unicode_regex_preprocess(wtext, bpe_offsets, wregex_expr);
|
||||
}
|
||||
} catch (std::regex_error & e) {
|
||||
|
@ -584,9 +668,14 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
|
|||
|
||||
std::vector<std::string> bpe_words;
|
||||
bpe_words.reserve(bpe_offsets.size()); // Reserve memory for the approximate size
|
||||
|
||||
size_t start = 0;
|
||||
for (size_t & offset : bpe_offsets) {
|
||||
bpe_words.emplace_back(unicode_wstring_to_utf8(std::wstring(wtext, start, offset)));
|
||||
//bpe_words.emplace_back(unicode_wstring_to_utf8(std::wstring(wtext, start, offset)));
|
||||
bpe_words.emplace_back();
|
||||
for (size_t i = start; i < start + offset; ++i) {
|
||||
bpe_words.back() += unicode_cpt_to_utf8(cpts[i]);
|
||||
}
|
||||
start += offset;
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue