Added needed functionality, testing remains

This commit is contained in:
Kazim Abrar Mahi 2024-04-16 04:56:35 +06:00 committed by Georgi Gerganov
parent 7e308ed212
commit feeaf4f39c
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
5 changed files with 175 additions and 173 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -2,6 +2,8 @@
#include <cstdint>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>
@ -14,4 +16,5 @@ extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_symbol;
// Pairs of code points; presumably inclusive [first, last] ranges of control characters — confirm against the definition.
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_control;
// Code point -> code point table; presumably NFD decomposition data (a multimap, so a key may have several entries) — confirm.
extern const std::multimap<uint32_t, uint32_t> unicode_map_nfd;
// Code point -> lowercase code point; unicode_tolower() returns the input unchanged when it has no entry here.
extern const std::map<char32_t, char32_t> unicode_map_lowercase;
// Regex string -> wide-string regex; queried by unicode_wregex_exists().
extern const std::map<std::string, std::wstring> unicode_regex_to_wregex;
// Regex string -> equivalent wide-string regex; unicode_regex_split() looks the pattern up here first.
extern const std::map<std::string, std::wstring> unicode_regex_equivalent_wregex;
// Regex strings that are handled by a hand-written splitter instead of a wide-string regex.
extern const std::set<std::string> unicode_regex_with_custom_preprocessor;

View file

@ -228,133 +228,138 @@ static std::vector<std::string> unicode_byte_encoding_process(const std::vector<
return bpe_encoded_words;
}
static std::vector<std::string> unicode_custom_preprocess(const std::string & text) {
std::vector<std::string> bpe_words;
static std::vector<size_t> unicode_gpt2_regex_preprocess(const std::wstring & wtext, const std::vector<size_t> & offsets) {
std::vector<size_t> bpe_offsets; // store the offset of each word
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
size_t start = 0;
std::string token = "";
// GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
bool collecting_numeric = false;
bool collecting_letter = false;
bool collecting_special = false;
bool collecting_whitespace_lookahead = false;
bool collecting = false;
for(auto offset : offsets) {
const std::string text = unicode_wstring_to_utf8(std::wstring(wtext, start, offset));
std::vector<std::string> text_utf;
text_utf.reserve(text.size());
bpe_words.reserve(text.size());
std::string token = "";
// GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
bool collecting_numeric = false;
bool collecting_letter = false;
bool collecting_special = false;
bool collecting_whitespace_lookahead = false;
bool collecting = false;
const auto cpts = unicode_cpts_from_utf8(text);
for (size_t i = 0; i < cpts.size(); ++i)
text_utf.emplace_back(unicode_cpt_to_utf8(cpts[i]));
std::vector<std::string> text_utf;
text_utf.reserve(text.size());
for (int i = 0; i < (int)text_utf.size(); i++) {
const std::string & utf_char = text_utf[i];
bool split_condition = false;
int bytes_remain = text_utf.size() - i;
// forward backward lookups
const std::string & utf_char_next = (i + 1 < (int)text_utf.size()) ? text_utf[i + 1] : "";
const std::string & utf_char_next_next = (i + 2 < (int)text_utf.size()) ? text_utf[i + 2] : "";
const auto cpts = unicode_cpts_from_utf8(text);
for (size_t i = 0; i < cpts.size(); ++i)
text_utf.emplace_back(unicode_cpt_to_utf8(cpts[i]));
// handling contractions
if (!split_condition && bytes_remain >= 2) {
// 's|'t|'m|'d
if (utf_char == "\'" && (utf_char_next == "s" || utf_char_next == "t" || utf_char_next == "m" || utf_char_next == "d")) {
split_condition = true;
for (int i = 0; i < (int)text_utf.size(); i++) {
const std::string & utf_char = text_utf[i];
bool split_condition = false;
int bytes_remain = text_utf.size() - i;
// forward backward lookups
const std::string & utf_char_next = (i + 1 < (int)text_utf.size()) ? text_utf[i + 1] : "";
const std::string & utf_char_next_next = (i + 2 < (int)text_utf.size()) ? text_utf[i + 2] : "";
// handling contractions
if (!split_condition && bytes_remain >= 2) {
// 's|'t|'m|'d
if (utf_char == "\'" && (utf_char_next == "s" || utf_char_next == "t" || utf_char_next == "m" || utf_char_next == "d")) {
split_condition = true;
}
if (split_condition) {
if (token.size()) {
bpe_offsets.emplace_back(unicode_wstring_from_utf8(token).size());
}
token = utf_char + utf_char_next;
bpe_offsets.emplace_back(unicode_wstring_from_utf8(token).size());
token = "";
i++;
continue;
}
}
if (!split_condition && bytes_remain >= 3) {
// 're|'ve|'ll
if (utf_char == "\'" && (
(utf_char_next == "r" && utf_char_next_next == "e") ||
(utf_char_next == "v" && utf_char_next_next == "e") ||
(utf_char_next == "l" && utf_char_next_next == "l"))
) {
split_condition = true;
}
if (split_condition) {
// current token + next token can be defined
if (token.size()) {
bpe_offsets.emplace_back(unicode_wstring_from_utf8(token).size());
}
token = utf_char + utf_char_next + utf_char_next_next;
bpe_offsets.emplace_back(unicode_wstring_from_utf8(token).size());
token = "";
i += 2;
continue;
}
}
if (!split_condition && !collecting) {
if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || (!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER)) {
collecting_letter = true;
collecting = true;
}
else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || (!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
collecting_numeric = true;
collecting = true;
}
else if (
((unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) && (unicode_cpt_type(utf_char) != CODEPOINT_TYPE_WHITESPACE)) ||
(!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_DIGIT && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE)
) {
collecting_special = true;
collecting = true;
}
else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_WHITESPACE) {
collecting_whitespace_lookahead = true;
collecting = true;
}
else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE) {
split_condition = true;
}
}
else if (!split_condition && collecting) {
if (collecting_letter && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER) {
split_condition = true;
}
else if (collecting_numeric && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) {
split_condition = true;
}
else if (collecting_special && (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE)) {
split_condition = true;
}
else if (collecting_whitespace_lookahead && (unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
split_condition = true;
}
}
if (utf_char_next == "") {
split_condition = true; // final
token += utf_char;
}
if (split_condition) {
if (token.size()) {
bpe_words.emplace_back(token); // push previous content as token
bpe_offsets.emplace_back(unicode_wstring_from_utf8(token).size());
}
token = utf_char + utf_char_next;
bpe_words.emplace_back(token);
token = "";
i++;
continue;
token = utf_char;
collecting = false;
collecting_letter = false;
collecting_numeric = false;
collecting_special = false;
collecting_whitespace_lookahead = false;
}
}
if (!split_condition && bytes_remain >= 3) {
// 're|'ve|'ll
if (utf_char == "\'" && (
(utf_char_next == "r" && utf_char_next_next == "e") ||
(utf_char_next == "v" && utf_char_next_next == "e") ||
(utf_char_next == "l" && utf_char_next_next == "l"))
) {
split_condition = true;
else {
token += utf_char;
}
if (split_condition) {
// current token + next token can be defined
if (token.size()) {
bpe_words.emplace_back(token); // push previous content as token
}
token = utf_char + utf_char_next + utf_char_next_next;
bpe_words.emplace_back(token); // the contraction
token = "";
i += 2;
continue;
}
}
if (!split_condition && !collecting) {
if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || (!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER)) {
collecting_letter = true;
collecting = true;
}
else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || (!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
collecting_numeric = true;
collecting = true;
}
else if (
((unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) && (unicode_cpt_type(utf_char) != CODEPOINT_TYPE_WHITESPACE)) ||
(!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_DIGIT && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE)
) {
collecting_special = true;
collecting = true;
}
else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_WHITESPACE) {
collecting_whitespace_lookahead = true;
collecting = true;
}
else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE) {
split_condition = true;
}
}
else if (!split_condition && collecting) {
if (collecting_letter && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER) {
split_condition = true;
}
else if (collecting_numeric && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) {
split_condition = true;
}
else if (collecting_special && (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE)) {
split_condition = true;
}
else if (collecting_whitespace_lookahead && (unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
split_condition = true;
}
}
if (utf_char_next == "") {
split_condition = true; // final
token += utf_char;
}
if (split_condition) {
if (token.size()) {
bpe_words.emplace_back(token);
}
token = utf_char;
collecting = false;
collecting_letter = false;
collecting_numeric = false;
collecting_special = false;
collecting_whitespace_lookahead = false;
}
else {
token += utf_char;
}
}
return bpe_words;
return bpe_offsets;
}
static std::vector<size_t> unicode_regex_preprocess(const std::wstring & text, const std::vector<size_t> & offsets, const std::wstring & regex_expr) {
@ -386,16 +391,22 @@ static std::vector<size_t> unicode_regex_preprocess(const std::wstring & text, c
return bpe_offsets;
}
static bool unicode_regex_matched(const std::wstring & text, const std::vector<std::wstring> & regex_exprs) {
// Tells whether `regex` is a key of the unicode_regex_equivalent_wregex table,
// i.e. whether a wide-string regex has been registered for it.
static bool unicode_regex_equivalent_wregex_exists(const std::string & regex) {
    const bool has_equivalent = unicode_regex_equivalent_wregex.count(regex) != 0;
    return has_equivalent;
}
for(auto & regex_expr: regex_exprs) {
std::wregex expr(regex_expr);
if(std::regex_match(text, expr)) {
return true;
}
// Tells whether `regex` is listed in unicode_regex_with_custom_preprocessor,
// i.e. whether it is one of the patterns served by a hand-written splitter.
static bool unicode_regex_with_custom_preprocessor_exists(const std::string & regex) {
    const bool has_custom_splitter = unicode_regex_with_custom_preprocessor.count(regex) != 0;
    return has_custom_splitter;
}
// Dispatches `regex` to the hand-written splitter that stands in for it.
//   regex   - pattern text; compared as an exact string against the known pattern below
//   wtext   - text to split, as a wide string
//   offsets - existing split offsets, forwarded unchanged to the splitter
// Only the GPT-2 pattern is recognised; for any other `regex`, bpe_offsets is left empty.
static std::vector<size_t> unicode_regex_custom_preprocess(const std::string & regex, const std::wstring & wtext, const std::vector<size_t> & offsets) {
std::vector<size_t> bpe_offsets;
// Exact string comparison: the literal must stay identical to the pattern text the caller passes in.
if(regex == "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") {
bpe_offsets = unicode_gpt2_regex_preprocess(wtext, offsets);
}
// NOTE(review): `return false;` does not match the std::vector<size_t> return type and makes the
// next line unreachable; it looks like a leftover of the removed unicode_regex_matched() — confirm and drop.
return false;
return bpe_offsets;
}
//
@ -479,33 +490,29 @@ char32_t unicode_tolower(char32_t cp) {
auto it = unicode_map_lowercase.find(cp);
return it == unicode_map_lowercase.end() ? cp : it->second;
}
// Reports whether `regex` has an entry in the unicode_regex_to_wregex table.
bool unicode_wregex_exists(const std::string & regex) {
    const bool is_known = unicode_regex_to_wregex.count(regex) != 0;
    return is_known;
}
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::wstring> & regex_exprs) {
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
std::wstring wtext = unicode_wstring_from_utf8(text);
std::vector<size_t> bpe_offsets = {wtext.size()};
for(auto & regex_expr : regex_exprs) {
bpe_offsets = unicode_regex_preprocess(wtext, bpe_offsets, regex_expr);
if (unicode_regex_equivalent_wregex_exists(regex_expr)) {
const std::wstring& wregex_expr = unicode_regex_equivalent_wregex.at(regex_expr);
bpe_offsets = unicode_regex_preprocess(wtext, bpe_offsets, wregex_expr);
} else if (unicode_regex_with_custom_preprocessor_exists(regex_expr)) {
bpe_offsets = unicode_regex_custom_preprocess(regex_expr, wtext, bpe_offsets);
} else {
throw std::runtime_error("Unicode regex is not found");
}
}
std::vector<std::string> bpe_words;
bpe_words.reserve(bpe_offsets.size()); // Reserve memory for the approximate size
size_t start = 0;
for(size_t & offset : bpe_offsets){
const auto temp_word = std::wstring(wtext, start, offset);
if(unicode_regex_matched(temp_word, regex_exprs)) {
bpe_words.emplace_back(unicode_wstring_to_utf8(temp_word));
} else {
auto custom_bpe_words = unicode_custom_preprocess(unicode_wstring_to_utf8(temp_word));
bpe_words.insert(bpe_words.end(), custom_bpe_words.begin(), custom_bpe_words.end());
}
for(size_t & offset : bpe_offsets) {
bpe_words.emplace_back(unicode_wstring_to_utf8(std::wstring(wtext, start, offset)));
start += offset;
}

View file

@ -28,5 +28,4 @@ uint8_t unicode_utf8_to_byte(const std::string & utf8);
// Returns the lowercase mapping of `cp` from unicode_map_lowercase, or `cp` itself when it has no entry.
char32_t unicode_tolower(char32_t cp);
// Returns true when `regex` is a key of unicode_regex_to_wregex.
bool unicode_wregex_exists(const std::string & regex);
// Splits UTF-8 `text` by applying each wide-string regex in `regex_exprs` in order to the
// running list of split offsets; presumably returns the pieces re-encoded as UTF-8 — confirm.
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::wstring> & regex_exprs);
// Same split, but patterns are given as regex strings. Each one is resolved either to its
// equivalent wide-string regex or to a custom splitter; a pattern known to neither makes the
// function throw std::runtime_error("Unicode regex is not found").
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs);