unicode : clean-up

This commit is contained in:
Georgi Gerganov 2024-04-28 18:01:59 +03:00
parent d63cc9068b
commit e972e6cbf8
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
3 changed files with 53 additions and 62 deletions

View file

@ -1,4 +1,4 @@
#include "unicode-data.h" #include "unicode-data.h"
#include <cstdint> #include <cstdint>
#include <map> #include <map>
@ -1649,7 +1649,3 @@ const std::map<char32_t, char32_t> unicode_map_lowercase = {
{0x1E917, 0x1E939}, {0x1E918, 0x1E93A}, {0x1E919, 0x1E93B}, {0x1E91A, 0x1E93C}, {0x1E91B, 0x1E93D}, {0x1E91C, 0x1E93E}, {0x1E917, 0x1E939}, {0x1E918, 0x1E93A}, {0x1E919, 0x1E93B}, {0x1E91A, 0x1E93C}, {0x1E91B, 0x1E93D}, {0x1E91C, 0x1E93E},
{0x1E91D, 0x1E93F}, {0x1E91E, 0x1E940}, {0x1E91F, 0x1E941}, {0x1E920, 0x1E942}, {0x1E921, 0x1E943}, {0x1E91D, 0x1E93F}, {0x1E91E, 0x1E940}, {0x1E91F, 0x1E941}, {0x1E920, 0x1E942}, {0x1E921, 0x1E943},
}; };
const std::set<std::string> unicode_regex_with_custom_preprocessor = {
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)"
};

View file

@ -2,8 +2,6 @@
#include <cstdint> #include <cstdint>
#include <map> #include <map>
#include <set>
#include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
@ -14,6 +12,5 @@ extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_accent_ma
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_punctuation; extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_punctuation;
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_symbol; extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_symbol;
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_control; extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_control;
extern const std::multimap<uint32_t, uint32_t> unicode_map_nfd; extern const std::multimap<uint32_t, uint32_t> unicode_map_nfd;
extern const std::map<char32_t, char32_t> unicode_map_lowercase; extern const std::map<char32_t, char32_t> unicode_map_lowercase;
extern const std::set<std::string> unicode_regex_with_custom_preprocessor;

View file

@ -202,20 +202,16 @@ static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
return conv.from_bytes(s); return conv.from_bytes(s);
} }
//static inline std::string unicode_wstring_to_utf8(const std::wstring & ws) {
// std::wstring_convert<std::codecvt_utf8<wchar_t>> conv;
// return conv.to_bytes(ws);
//}
static std::vector<std::string> unicode_byte_encoding_process(const std::vector<std::string> & bpe_words) { static std::vector<std::string> unicode_byte_encoding_process(const std::vector<std::string> & bpe_words) {
std::vector<std::string>bpe_encoded_words; std::vector<std::string> bpe_encoded_words;
for (auto word : bpe_words) { for (const auto & word : bpe_words) {
std::string text_utf = ""; std::string text_utf;
auto utf_word = unicode_cpts_from_utf8(word); auto utf_word = unicode_cpts_from_utf8(word);
for (size_t i = 0; i < utf_word.size(); ++i) for (size_t i = 0; i < utf_word.size(); ++i) {
text_utf += unicode_cpt_to_utf8(utf_word[i]); text_utf += unicode_cpt_to_utf8(utf_word[i]);
}
std::string encoded_token = ""; std::string encoded_token;
for (char & c : text_utf) { for (char & c : text_utf) {
encoded_token += unicode_byte_to_utf8(c); encoded_token += unicode_byte_to_utf8(c);
} }
@ -225,7 +221,7 @@ static std::vector<std::string> unicode_byte_encoding_process(const std::vector<
} }
// GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+ // GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
static std::vector<size_t> unicode_gpt2_regex_preprocess(const std::string & text, const std::vector<size_t> & offsets) { static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & text, const std::vector<size_t> & offsets) {
std::vector<size_t> bpe_offsets; // store the offset of each word std::vector<size_t> bpe_offsets; // store the offset of each word
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
@ -289,7 +285,10 @@ static std::vector<size_t> unicode_gpt2_regex_preprocess(const std::string & tex
if (token.size()) { if (token.size()) {
bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size()); bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
} }
token = utf_char + utf_char_next + utf_char_next_next; token = utf_char;
token += utf_char_next;
token += utf_char_next_next;
bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size()); bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
token = ""; token = "";
i += 2; i += 2;
@ -298,17 +297,17 @@ static std::vector<size_t> unicode_gpt2_regex_preprocess(const std::string & tex
} }
if (!split_condition && !collecting) { if (!split_condition && !collecting) {
if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || (!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER)) { if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || (token.empty() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER)) {
collecting_letter = true; collecting_letter = true;
collecting = true; collecting = true;
} }
else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || (!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) { else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || (token.empty() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
collecting_numeric = true; collecting_numeric = true;
collecting = true; collecting = true;
} }
else if ( else if (
((unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) && (unicode_cpt_type(utf_char) != CODEPOINT_TYPE_WHITESPACE)) || ((unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) && (unicode_cpt_type(utf_char) != CODEPOINT_TYPE_WHITESPACE)) ||
(!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_DIGIT && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE) (token.empty() && utf_char == " " && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_DIGIT && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE)
) { ) {
collecting_special = true; collecting_special = true;
collecting = true; collecting = true;
@ -363,7 +362,8 @@ static std::vector<size_t> unicode_gpt2_regex_preprocess(const std::string & tex
return bpe_offsets; return bpe_offsets;
} }
static std::vector<size_t> unicode_regex_preprocess(const std::wstring & wtext, const std::vector<size_t> & offsets, const std::wstring & regex_expr) { // use std::wregex to split the text
static std::vector<size_t> unicode_regex_split_stl(const std::wstring & wtext, const std::vector<size_t> & offsets, const std::wstring & regex_expr) {
std::wregex expr(regex_expr); std::wregex expr(regex_expr);
std::vector<size_t> bpe_offsets; // store the offset of each word std::vector<size_t> bpe_offsets; // store the offset of each word
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
@ -393,7 +393,7 @@ static std::vector<size_t> unicode_regex_preprocess(const std::wstring & wtext,
} }
// use std::regex to split the text // use std::regex to split the text
static std::vector<size_t> unicode_regex_preprocess(const std::string & text, const std::vector<size_t> & offsets, const std::string & regex_expr) { static std::vector<size_t> unicode_regex_split_stl(const std::string & text, const std::vector<size_t> & offsets, const std::string & regex_expr) {
std::regex expr(regex_expr); std::regex expr(regex_expr);
std::vector<size_t> bpe_offsets; // store the offset of each word std::vector<size_t> bpe_offsets; // store the offset of each word
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
@ -422,15 +422,11 @@ static std::vector<size_t> unicode_regex_preprocess(const std::string & text, co
return bpe_offsets; return bpe_offsets;
} }
static bool unicode_regex_with_custom_preprocessor_exists(const std::string & regex) { static std::vector<size_t> unicode_regex_split_custom(const std::string & regex, const std::string & text, const std::vector<size_t> & offsets) {
return unicode_regex_with_custom_preprocessor.find(regex) != unicode_regex_with_custom_preprocessor.end();
}
static std::vector<size_t> unicode_regex_custom_preprocess(const std::string & regex, const std::string & text, const std::vector<size_t> & offsets) {
std::vector<size_t> bpe_offsets; std::vector<size_t> bpe_offsets;
if (regex == "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") { if (regex == "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") {
bpe_offsets = unicode_gpt2_regex_preprocess(text, offsets); bpe_offsets = unicode_regex_split_custom_gpt2(text, offsets);
} }
return bpe_offsets; return bpe_offsets;
@ -518,20 +514,6 @@ char32_t unicode_tolower(char32_t cp) {
return it == unicode_map_lowercase.end() ? cp : it->second; return it == unicode_map_lowercase.end() ? cp : it->second;
} }
// Replace every non-overlapping occurrence of `search` in `s` with `replace`,
// scanning left to right. The result is assembled in a separate buffer and then
// moved into `s`, so text inserted by `replace` is never re-scanned.
//
// s       - string to modify in place
// search  - substring to look for; if empty, `s` is left unchanged
// replace - text substituted for each occurrence of `search`
static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
    // std::string::find matches an empty needle at every position without
    // advancing, so without this guard the loop below would never terminate
    if (search.empty()) {
        return;
    }

    std::string result;
    result.reserve(s.size());

    size_t pos = 0;
    for (;;) {
        const size_t new_pos = s.find(search, pos);
        if (new_pos == std::string::npos) {
            break;
        }
        // copy the unmatched text before the occurrence, then the replacement
        result.append(s, pos, new_pos - pos);
        result += replace;
        // continue right after the matched occurrence
        pos = new_pos + search.length();
    }

    // copy whatever follows the last occurrence (the whole string if there was none)
    result.append(s, pos, std::string::npos);

    s = std::move(result);
}
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) { std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
// unicode categories // unicode categories
static const std::map<std::string, int> k_ucat_enum = { static const std::map<std::string, int> k_ucat_enum = {
@ -547,16 +529,16 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
}; };
static const std::map<int, std::string> k_ucat_map = { static const std::map<int, std::string> k_ucat_map = {
{ CODEPOINT_TYPE_DIGIT, "\x30-\x39" }, { CODEPOINT_TYPE_DIGIT, "\x30-\x39" }, // 0-9
{ CODEPOINT_TYPE_LETTER, "\x41-\x5A\x61-\x7A" }, { CODEPOINT_TYPE_LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
{ CODEPOINT_TYPE_PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, { CODEPOINT_TYPE_PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
}; };
// compute collapsed codepoints only if needed by at least one regex // compute collapsed codepoints only if needed by at least one regex
bool need_collapse = false; bool need_collapse = false;
for (auto & regex_expr : regex_exprs) { for (auto & regex_expr : regex_exprs) {
// search for unicode categories // search for unicode categories
for (auto & ucat : k_ucat_enum) { for (const auto & ucat : k_ucat_enum) {
if (std::string::npos != regex_expr.find(ucat.first)) { if (std::string::npos != regex_expr.find(ucat.first)) {
need_collapse = true; need_collapse = true;
break; break;
@ -564,10 +546,13 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
} }
} }
const auto cpts = unicode_cpts_from_utf8(text);
// generate a "collapsed" representation of the text, where each codepoint is replaced by a single byte
// ref: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2081479935
std::string text_collapsed; std::string text_collapsed;
if (need_collapse) { if (need_collapse) {
// collapse all unicode categories // collapse all unicode categories
const auto cpts = unicode_cpts_from_utf8(text);
text_collapsed.resize(cpts.size()); text_collapsed.resize(cpts.size());
for (size_t i = 0; i < cpts.size(); ++i) { for (size_t i = 0; i < cpts.size(); ++i) {
@ -587,14 +572,16 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
} }
} }
const auto cpts = unicode_cpts_from_utf8(text);
std::vector<size_t> bpe_offsets = { cpts.size() }; std::vector<size_t> bpe_offsets = { cpts.size() };
for (auto & regex_expr : regex_exprs) { for (auto & regex_expr : regex_exprs) {
if (unicode_regex_with_custom_preprocessor_exists(regex_expr)) { // first, see if we have an efficient custom regex implementation
bpe_offsets = unicode_regex_custom_preprocess(regex_expr, text, bpe_offsets); auto tmp = unicode_regex_split_custom(regex_expr, text, bpe_offsets);
if (!tmp.empty()) {
bpe_offsets = std::move(tmp);
} else { } else {
// fallback // fallback to general-purpose std::regex / std::wregex
try { try {
// if a unicode category is used in the regex, we use the collapsed text and replace the unicode category // if a unicode category is used in the regex, we use the collapsed text and replace the unicode category
// with the corresponding collapsed representation // with the corresponding collapsed representation
@ -607,20 +594,32 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
} }
if (use_collapsed) { if (use_collapsed) {
// sanity-check that the original regex does not contain any non-ASCII characters
const auto cpts_regex = unicode_cpts_from_utf8(regex_expr);
for (size_t i = 0; i < cpts_regex.size(); ++i) {
if (cpts_regex[i] >= 128) {
throw std::runtime_error("Regex includes both unicode categories and non-ASCII characters - not supported");
}
}
// generate a collapsed representation of the regex
std::string regex_expr_collapsed; std::string regex_expr_collapsed;
// track if we are inside [] // track if we are inside [], because nested [] are not allowed
bool inside = false; bool inside = false;
for (size_t i = 0; i < regex_expr.size(); ++i) { for (size_t i = 0; i < regex_expr.size(); ++i) {
if (regex_expr[i] == '[' && (i == 0 || regex_expr[i - 1] != '\\')) { if (regex_expr[i] == '[' && (i == 0 || regex_expr[i - 1] != '\\')) {
regex_expr_collapsed += '['; regex_expr_collapsed += '[';
inside = true; inside = true;
continue; continue;
} else if (inside && regex_expr[i] == ']' && regex_expr[i - 1] != '\\') { }
if (inside && regex_expr[i] == ']' && regex_expr[i - 1] != '\\') {
regex_expr_collapsed += ']'; regex_expr_collapsed += ']';
inside = false; inside = false;
continue; continue;
} }
if (regex_expr[i] == '\\' && i + 1 < regex_expr.size()) { if (regex_expr[i] == '\\' && i + 1 < regex_expr.size()) {
if (regex_expr[i + 1] == 'p') { if (regex_expr[i + 1] == 'p') {
if (i + 3 < regex_expr.size() && regex_expr[i + 2] == '{') { if (i + 3 < regex_expr.size() && regex_expr[i + 2] == '{') {
@ -648,7 +647,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
//printf("text_collapsed: %s\n", text_collapsed.c_str()); //printf("text_collapsed: %s\n", text_collapsed.c_str());
//printf("regex_expr_collapsed: %s\n", regex_expr_collapsed.c_str()); //printf("regex_expr_collapsed: %s\n", regex_expr_collapsed.c_str());
bpe_offsets = unicode_regex_preprocess(text_collapsed, bpe_offsets, regex_expr_collapsed); bpe_offsets = unicode_regex_split_stl(text_collapsed, bpe_offsets, regex_expr_collapsed);
} else { } else {
// no unicode category used, we can use std::wregex directly // no unicode category used, we can use std::wregex directly
const std::wstring wtext = unicode_wstring_from_utf8(text); const std::wstring wtext = unicode_wstring_from_utf8(text);
@ -656,7 +655,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
//printf("text: %s\n", text.c_str()); //printf("text: %s\n", text.c_str());
//printf("regex_expr: %s\n", regex_expr.c_str()); //printf("regex_expr: %s\n", regex_expr.c_str());
bpe_offsets = unicode_regex_preprocess(wtext, bpe_offsets, wregex_expr); bpe_offsets = unicode_regex_split_stl(wtext, bpe_offsets, wregex_expr);
} }
} catch (std::regex_error & e) { } catch (std::regex_error & e) {
fprintf(stderr, "Failed to process regex: '%s'\n", regex_expr.c_str()); fprintf(stderr, "Failed to process regex: '%s'\n", regex_expr.c_str());
@ -667,11 +666,10 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
} }
std::vector<std::string> bpe_words; std::vector<std::string> bpe_words;
bpe_words.reserve(bpe_offsets.size()); // Reserve memory for the approximate size bpe_words.reserve(bpe_offsets.size()); // reserve memory for the approximate size
size_t start = 0; size_t start = 0;
for (size_t & offset : bpe_offsets) { for (size_t & offset : bpe_offsets) {
//bpe_words.emplace_back(unicode_wstring_to_utf8(std::wstring(wtext, start, offset)));
bpe_words.emplace_back(); bpe_words.emplace_back();
for (size_t i = start; i < start + offset; ++i) { for (size_t i = start; i < start + offset; ++i) {
bpe_words.back() += unicode_cpt_to_utf8(cpts[i]); bpe_words.back() += unicode_cpt_to_utf8(cpts[i]);