Added needed functionality, testing remains

This commit is contained in:
Kazim Abrar Mahi 2024-04-16 04:56:35 +06:00 committed by Georgi Gerganov
parent 7e308ed212
commit feeaf4f39c
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
5 changed files with 175 additions and 173 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -2,6 +2,8 @@
#include <cstdint>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>
@ -14,4 +16,5 @@ extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_symbol;
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_control;
extern const std::multimap<uint32_t, uint32_t> unicode_map_nfd;
extern const std::map<char32_t, char32_t> unicode_map_lowercase;
extern const std::map<std::string, std::wstring> unicode_regex_to_wregex;
extern const std::map<std::string, std::wstring> unicode_regex_equivalent_wregex;
extern const std::set<std::string> unicode_regex_with_custom_preprocessor;

View file

@ -228,8 +228,13 @@ static std::vector<std::string> unicode_byte_encoding_process(const std::vector<
return bpe_encoded_words;
}
static std::vector<std::string> unicode_custom_preprocess(const std::string & text) {
std::vector<std::string> bpe_words;
static std::vector<size_t> unicode_gpt2_regex_preprocess(const std::wstring & wtext, const std::vector<size_t> & offsets) {
std::vector<size_t> bpe_offsets; // store the offset of each word
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
size_t start = 0;
for(auto offset : offsets) {
const std::string text = unicode_wstring_to_utf8(std::wstring(wtext, start, offset));
std::string token = "";
// GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
@ -241,7 +246,6 @@ static std::vector<std::string> unicode_custom_preprocess(const std::string & te
std::vector<std::string> text_utf;
text_utf.reserve(text.size());
bpe_words.reserve(text.size());
const auto cpts = unicode_cpts_from_utf8(text);
for (size_t i = 0; i < cpts.size(); ++i)
@ -263,10 +267,10 @@ static std::vector<std::string> unicode_custom_preprocess(const std::string & te
}
if (split_condition) {
if (token.size()) {
bpe_words.emplace_back(token); // push previous content as token
bpe_offsets.emplace_back(unicode_wstring_from_utf8(token).size());
}
token = utf_char + utf_char_next;
bpe_words.emplace_back(token);
bpe_offsets.emplace_back(unicode_wstring_from_utf8(token).size());
token = "";
i++;
continue;
@ -284,10 +288,10 @@ static std::vector<std::string> unicode_custom_preprocess(const std::string & te
if (split_condition) {
// current token + next token can be defined
if (token.size()) {
bpe_words.emplace_back(token); // push previous content as token
bpe_offsets.emplace_back(unicode_wstring_from_utf8(token).size());
}
token = utf_char + utf_char_next + utf_char_next_next;
bpe_words.emplace_back(token); // the contraction
bpe_offsets.emplace_back(unicode_wstring_from_utf8(token).size());
token = "";
i += 2;
continue;
@ -340,7 +344,7 @@ static std::vector<std::string> unicode_custom_preprocess(const std::string & te
if (split_condition) {
if (token.size()) {
bpe_words.emplace_back(token);
bpe_offsets.emplace_back(unicode_wstring_from_utf8(token).size());
}
token = utf_char;
collecting = false;
@ -353,8 +357,9 @@ static std::vector<std::string> unicode_custom_preprocess(const std::string & te
token += utf_char;
}
}
}
return bpe_words;
return bpe_offsets;
}
static std::vector<size_t> unicode_regex_preprocess(const std::wstring & text, const std::vector<size_t> & offsets, const std::wstring & regex_expr) {
@ -386,16 +391,22 @@ static std::vector<size_t> unicode_regex_preprocess(const std::wstring & text, c
return bpe_offsets;
}
static bool unicode_regex_matched(const std::wstring & text, const std::vector<std::wstring> & regex_exprs) {
for(auto & regex_expr: regex_exprs) {
std::wregex expr(regex_expr);
if(std::regex_match(text, expr)) {
return true;
}
// Returns true when `regex` is a key of the unicode_regex_equivalent_wregex map.
static bool unicode_regex_equivalent_wregex_exists(const std::string & regex) {
    const bool is_known = unicode_regex_equivalent_wregex.count(regex) > 0;
    return is_known;
}
return false;
// Returns true when `regex` is a member of the unicode_regex_with_custom_preprocessor set.
static bool unicode_regex_with_custom_preprocessor_exists(const std::string & regex) {
    const bool is_listed = unicode_regex_with_custom_preprocessor.count(regex) != 0;
    return is_listed;
}
// Routes a regex that has a hand-written splitter to that splitter.
//
// regex   - the pattern string, compared verbatim against the one supported pattern
// wtext   - the wide-character text to split
// offsets - the current chunk lengths, forwarded unchanged to the splitter
//
// Returns the offsets produced by unicode_gpt2_regex_preprocess when `regex` is the
// supported pattern; for any other pattern the result is an empty vector.
static std::vector<size_t> unicode_regex_custom_preprocess(const std::string & regex, const std::wstring & wtext, const std::vector<size_t> & offsets) {
    // Exact string comparison: the caller must pass the pattern in precisely this form.
    if (regex != "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") {
        return {};
    }
    return unicode_gpt2_regex_preprocess(wtext, offsets);
}
//
@ -480,32 +491,28 @@ char32_t unicode_tolower(char32_t cp) {
return it == unicode_map_lowercase.end() ? cp : it->second;
}
// Returns true when `regex` is a key of the unicode_regex_to_wregex map.
bool unicode_wregex_exists(const std::string & regex) {
    const bool is_known = unicode_regex_to_wregex.count(regex) > 0;
    return is_known;
}
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::wstring> & regex_exprs) {
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
std::wstring wtext = unicode_wstring_from_utf8(text);
std::vector<size_t> bpe_offsets = {wtext.size()};
for(auto & regex_expr : regex_exprs) {
bpe_offsets = unicode_regex_preprocess(wtext, bpe_offsets, regex_expr);
if (unicode_regex_equivalent_wregex_exists(regex_expr)) {
const std::wstring& wregex_expr = unicode_regex_equivalent_wregex.at(regex_expr);
bpe_offsets = unicode_regex_preprocess(wtext, bpe_offsets, wregex_expr);
} else if (unicode_regex_with_custom_preprocessor_exists(regex_expr)) {
bpe_offsets = unicode_regex_custom_preprocess(regex_expr, wtext, bpe_offsets);
} else {
throw std::runtime_error("Unicode regex is not found");
}
}
std::vector<std::string> bpe_words;
bpe_words.reserve(bpe_offsets.size()); // Reserve memory for the approximate size
size_t start = 0;
for(size_t & offset : bpe_offsets) {
const auto temp_word = std::wstring(wtext, start, offset);
if(unicode_regex_matched(temp_word, regex_exprs)) {
bpe_words.emplace_back(unicode_wstring_to_utf8(temp_word));
} else {
auto custom_bpe_words = unicode_custom_preprocess(unicode_wstring_to_utf8(temp_word));
bpe_words.insert(bpe_words.end(), custom_bpe_words.begin(), custom_bpe_words.end());
}
bpe_words.emplace_back(unicode_wstring_to_utf8(std::wstring(wtext, start, offset)));
start += offset;
}

View file

@ -28,5 +28,4 @@ uint8_t unicode_utf8_to_byte(const std::string & utf8);
char32_t unicode_tolower(char32_t cp);
bool unicode_wregex_exists(const std::string & regex);
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::wstring> & regex_exprs);
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs);