GPT2 custom regex split

This commit is contained in:
jaime-m-p 2024-04-29 19:13:18 +02:00
parent 5c38f6ed7a
commit 1d8fcc06ba

View file

@ -224,138 +224,109 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
std::vector<size_t> bpe_offsets; // store the offset of each word std::vector<size_t> bpe_offsets; // store the offset of each word
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
size_t start = 0;
const auto cpts = unicode_cpts_from_utf8(text); const auto cpts = unicode_cpts_from_utf8(text);
size_t start = 0;
for (auto offset : offsets) { for (auto offset : offsets) {
std::string token;
bool collecting_numeric = false; const size_t offset_ini = start;
bool collecting_letter = false; const size_t offset_end = start + offset;
bool collecting_special = false; assert(offset_end <= cpts.size());
bool collecting_whitespace_lookahead = false; start = offset_end;
bool collecting = false;
std::vector<std::string> text_utf; auto _get_cpt = [&] (const size_t pos) -> char32_t {
text_utf.reserve(offset); return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
};
for (size_t i = start; i < start + offset; ++i) { auto _get_cpt_type = [&] (const size_t pos) -> int {
text_utf.emplace_back(unicode_cpt_to_utf8(cpts[i])); return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_type(cpts[pos]) : CODEPOINT_TYPE_UNIDENTIFIED;
} };
for (int i = 0; i < (int)text_utf.size(); i++) { size_t _prev_end = offset_ini;
const std::string & utf_char = text_utf[i]; auto _add_token = [&] (const size_t end) -> size_t {
bool split_condition = false; assert(_prev_end <= end && end <= offset_end);
int bytes_remain = text_utf.size() - i; size_t len = end - _prev_end;
if(len > 0)
bpe_offsets.push_back(len);
_prev_end = end;
//if(len) {
// std::string s = "";
// for(size_t p = end-len; p < end; p++)
// s += unicode_cpt_to_utf8(cpts[p]);
// printf(">>> '%s'\n", s.c_str());
//}
return len;
};
// forward backward lookups for(size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
const std::string & utf_char_next = (i + 1 < (int)text_utf.size()) ? text_utf[i + 1] : ""; const char32_t cpt = _get_cpt(pos);
const std::string & utf_char_next_next = (i + 2 < (int)text_utf.size()) ? text_utf[i + 2] : ""; const int cpt_type = _get_cpt_type(pos);
// handling contractions // regex: 's|'t|'re|'ve|'m|'ll|'d
if (!split_condition && bytes_remain >= 2) { if (cpt == '\'' && pos+1 < offset_end) {
// 's|'t|'m|'d char32_t cpt_next = _get_cpt(pos+1);
if (utf_char == "\'" && (utf_char_next == "s" || utf_char_next == "t" || utf_char_next == "m" || utf_char_next == "d")) { if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
split_condition = true; pos += _add_token(pos+2);
}
if (split_condition) {
if (token.size()) {
bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
}
token = utf_char + utf_char_next;
bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
token = "";
i++;
continue; continue;
} } else if (pos+2 < offset_end) {
} char32_t cpt_next_next = _get_cpt(pos+2);
if (!split_condition && bytes_remain >= 3) { if ((cpt_next == 'r' && cpt_next_next == 'e') ||
// 're|'ve|'ll (cpt_next == 'v' && cpt_next_next == 'e') ||
if (utf_char == "\'" && ( (cpt_next == 'l' && cpt_next_next == 'l')) {
(utf_char_next == "r" && utf_char_next_next == "e") || pos += _add_token(pos+3);
(utf_char_next == "v" && utf_char_next_next == "e") || continue;
(utf_char_next == "l" && utf_char_next_next == "l"))
) {
split_condition = true;
}
if (split_condition) {
// current token + next token can be defined
if (token.size()) {
bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
} }
token = utf_char;
token += utf_char_next;
token += utf_char_next_next;
bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
token = "";
i += 2;
continue;
} }
} }
if (!split_condition && !collecting) { int cpt2_type = (cpt == ' ' ? _get_cpt_type(pos+1) : cpt_type);
if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || (token.empty() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER)) { // regex: <space>?\p{L}+
collecting_letter = true; if (cpt2_type == CODEPOINT_TYPE_LETTER) {
collecting = true; pos += (cpt == ' ');
} while(cpt2_type == CODEPOINT_TYPE_LETTER)
else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || (token.empty() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) { cpt2_type = _get_cpt_type(++pos);
collecting_numeric = true; _add_token(pos);
collecting = true; continue;
}
else if (
((unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) && (unicode_cpt_type(utf_char) != CODEPOINT_TYPE_WHITESPACE)) ||
(token.empty() && utf_char == " " && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_DIGIT && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE)
) {
collecting_special = true;
collecting = true;
}
else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_WHITESPACE) {
collecting_whitespace_lookahead = true;
collecting = true;
}
else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE) {
split_condition = true;
}
} }
else if (!split_condition && collecting) { // regex: <space>?\p{N}+
if (collecting_letter && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER) { if (cpt2_type == CODEPOINT_TYPE_DIGIT) {
split_condition = true; pos += (cpt == ' ');
} while(cpt2_type == CODEPOINT_TYPE_DIGIT)
else if (collecting_numeric && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) { cpt2_type = _get_cpt_type(++pos);
split_condition = true; _add_token(pos);
} continue;
else if (collecting_special && (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE)) { }
split_condition = true; // regex: <space>?[^\s\p{L}\p{N}]+
} if (cpt2_type != CODEPOINT_TYPE_WHITESPACE && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_DIGIT && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
else if (collecting_whitespace_lookahead && (unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) { pos += (cpt == ' ');
split_condition = true; while(cpt2_type != CODEPOINT_TYPE_WHITESPACE && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_DIGIT && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED)
} cpt2_type = _get_cpt_type(++pos);
_add_token(pos);
continue;
} }
if (utf_char_next == "") { size_t num_whitespaces = 0;
split_condition = true; // final while (_get_cpt_type(pos+num_whitespaces) == CODEPOINT_TYPE_WHITESPACE) {
token += utf_char; num_whitespaces++;
}
// regex: \s+(?!\S)
if(num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) {
pos += num_whitespaces - 1;
_add_token(pos);
continue;
} }
if (split_condition) { // regex: \s+
if (token.size()) { if(num_whitespaces > 0) {
bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size()); pos += num_whitespaces;
} _add_token(pos);
token = utf_char; continue;
collecting = false;
collecting_letter = false;
collecting_numeric = false;
collecting_special = false;
collecting_whitespace_lookahead = false;
}
else {
token += utf_char;
} }
// no matches
_add_token(++pos);
} }
start += offset;
} }
return bpe_offsets; return bpe_offsets;