diff --git a/unicode.cpp b/unicode.cpp index 233dae564..9abf83d34 100644 --- a/unicode.cpp +++ b/unicode.cpp @@ -224,138 +224,109 @@ static std::vector unicode_regex_split_custom_gpt2(const std::string & t std::vector bpe_offsets; // store the offset of each word bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size - size_t start = 0; - const auto cpts = unicode_cpts_from_utf8(text); + size_t start = 0; for (auto offset : offsets) { - std::string token; - bool collecting_numeric = false; - bool collecting_letter = false; - bool collecting_special = false; - bool collecting_whitespace_lookahead = false; - bool collecting = false; + const size_t offset_ini = start; + const size_t offset_end = start + offset; + assert(offset_end <= cpts.size()); + start = offset_end; - std::vector text_utf; - text_utf.reserve(offset); + auto _get_cpt = [&] (const size_t pos) -> char32_t { + return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0; + }; - for (size_t i = start; i < start + offset; ++i) { - text_utf.emplace_back(unicode_cpt_to_utf8(cpts[i])); - } + auto _get_cpt_type = [&] (const size_t pos) -> int { + return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_type(cpts[pos]) : CODEPOINT_TYPE_UNIDENTIFIED; + }; - for (int i = 0; i < (int)text_utf.size(); i++) { - const std::string & utf_char = text_utf[i]; - bool split_condition = false; - int bytes_remain = text_utf.size() - i; + size_t _prev_end = offset_ini; + auto _add_token = [&] (const size_t end) -> size_t { + assert(_prev_end <= end && end <= offset_end); + size_t len = end - _prev_end; + if(len > 0) + bpe_offsets.push_back(len); + _prev_end = end; + //if(len) { + // std::string s = ""; + // for(size_t p = end-len; p < end; p++) + // s += unicode_cpt_to_utf8(cpts[p]); + // printf(">>> '%s'\n", s.c_str()); + //} + return len; + }; - // forward backward lookups - const std::string & utf_char_next = (i + 1 < (int)text_utf.size()) ? text_utf[i + 1] : ""; - const std::string & utf_char_next_next = (i + 2 < (int)text_utf.size()) ? text_utf[i + 2] : ""; + for(size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) { + const char32_t cpt = _get_cpt(pos); + const int cpt_type = _get_cpt_type(pos); - // handling contractions - if (!split_condition && bytes_remain >= 2) { - // 's|'t|'m|'d - if (utf_char == "\'" && (utf_char_next == "s" || utf_char_next == "t" || utf_char_next == "m" || utf_char_next == "d")) { - split_condition = true; - } - if (split_condition) { - if (token.size()) { - bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size()); - } - token = utf_char + utf_char_next; - bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size()); - token = ""; - i++; + // regex: 's|'t|'re|'ve|'m|'ll|'d + if (cpt == '\'' && pos+1 < offset_end) { + char32_t cpt_next = _get_cpt(pos+1); + if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') { + pos += _add_token(pos+2); continue; - } - } - if (!split_condition && bytes_remain >= 3) { - // 're|'ve|'ll - if (utf_char == "\'" && ( - (utf_char_next == "r" && utf_char_next_next == "e") || - (utf_char_next == "v" && utf_char_next_next == "e") || - (utf_char_next == "l" && utf_char_next_next == "l")) - ) { - split_condition = true; - } - if (split_condition) { - // current token + next token can be defined - if (token.size()) { - bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size()); + } else if (pos+2 < offset_end) { + char32_t cpt_next_next = _get_cpt(pos+2); + if ((cpt_next == 'r' && cpt_next_next == 'e') || + (cpt_next == 'v' && cpt_next_next == 'e') || + (cpt_next == 'l' && cpt_next_next == 'l')) { + pos += _add_token(pos+3); + continue; } - token = utf_char; - token += utf_char_next; - token += utf_char_next_next; - - bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size()); - token = ""; - i += 2; - continue; } } - if (!split_condition && !collecting) { - if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || (token.empty() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER)) { - collecting_letter = true; - collecting = true; - } - else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || (token.empty() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) { - collecting_numeric = true; - collecting = true; - } - else if ( - ((unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) && (unicode_cpt_type(utf_char) != CODEPOINT_TYPE_WHITESPACE)) || - (token.empty() && utf_char == " " && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_DIGIT && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE) - ) { - collecting_special = true; - collecting = true; - } - else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_WHITESPACE) { - collecting_whitespace_lookahead = true; - collecting = true; - } - else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE) { - split_condition = true; - } + int cpt2_type = (cpt == ' ' ? _get_cpt_type(pos+1) : cpt_type); + // regex: ?\p{L}+ + if (cpt2_type == CODEPOINT_TYPE_LETTER) { + pos += (cpt == ' '); + while(cpt2_type == CODEPOINT_TYPE_LETTER) + cpt2_type = _get_cpt_type(++pos); + _add_token(pos); + continue; } - else if (!split_condition && collecting) { - if (collecting_letter && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER) { - split_condition = true; - } - else if (collecting_numeric && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) { - split_condition = true; - } - else if (collecting_special && (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE)) { - split_condition = true; - } - else if (collecting_whitespace_lookahead && (unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) { - split_condition = true; - } + // regex: ?\p{N}+ + if (cpt2_type == CODEPOINT_TYPE_DIGIT) { + pos += (cpt == ' '); + while(cpt2_type == CODEPOINT_TYPE_DIGIT) + cpt2_type = _get_cpt_type(++pos); + _add_token(pos); + continue; + } + // regex: ?[^\s\p{L}\p{N}]+ + if (cpt2_type != CODEPOINT_TYPE_WHITESPACE && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_DIGIT && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) { + pos += (cpt == ' '); + while(cpt2_type != CODEPOINT_TYPE_WHITESPACE && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_DIGIT && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) + cpt2_type = _get_cpt_type(++pos); + _add_token(pos); + continue; } - if (utf_char_next == "") { - split_condition = true; // final - token += utf_char; + size_t num_whitespaces = 0; + while (_get_cpt_type(pos+num_whitespaces) == CODEPOINT_TYPE_WHITESPACE) { + num_whitespaces++; + } + + // regex: \s+(?!\S) + if(num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) { + pos += num_whitespaces - 1; + _add_token(pos); + continue; } - if (split_condition) { - if (token.size()) { - bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size()); - } - token = utf_char; - collecting = false; - collecting_letter = false; - collecting_numeric = false; - collecting_special = false; - collecting_whitespace_lookahead = false; - } - else { - token += utf_char; + // regex: \s+ + if(num_whitespaces > 0) { + pos += num_whitespaces; + _add_token(pos); + continue; } + + // no matches + _add_token(++pos); } - - start += offset; } return bpe_offsets;