Replace 'codepoint_flags' with 'codepoint_categ'

This commit is contained in:
jaime-m-p 2024-07-20 23:28:05 +02:00
parent 2636cb6170
commit 23cf064e3b
3 changed files with 57 additions and 47 deletions

View file

@ -15836,22 +15836,22 @@ struct llm_tokenizer_wpm {
std::vector<std::string> words(1, "");
for (const uint32_t cpt : cpts_nfd) {
const auto flags = unicode_cpt_flags(cpt);
const auto categ = unicode_cpt_category(cpt);
if (flags.is_whitespace) {
if (categ.is_whitespace()) {
if (words.back().size()) { // finish previous word if any
words.emplace_back();
}
continue;
}
assert (!flags.is_separator);
if (cpt == 0 || cpt == 0xFFFD || flags.is_control) {
assert (!categ.is_S());
if (cpt == 0 || cpt == 0xFFFD || categ.is_C()) {
continue;
}
const std::string s = unicode_cpt_to_utf8(unicode_tolower(cpt));
if (flags.is_punctuation || ( cpt < 0x7F && flags.is_symbol ) || is_chinese_char(cpt)) {
if (categ.is_P() || (cpt < 0x7F && categ.is_S()) || is_chinese_char(cpt)) {
if (words.back().size()) { // finish previous word if any
words.emplace_back();
}
@ -15869,7 +15869,7 @@ struct llm_tokenizer_wpm {
return words;
}
static bool is_chinese_char(uint32_t cpt) {
static bool is_chinese_char(uint32_t cpt) { //TODO: move to unicode-data.cpp? unicode_cpt_category(cpt).is_chinese()?
return
(cpt >= 0x04E00 && cpt <= 0x09FFF) ||
(cpt >= 0x03400 && cpt <= 0x04DBF) ||

View file

@ -203,8 +203,9 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
};
auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
static const codepoint_categ SENTINEL = codepoint_categ::MASK + 1;
auto _get_categ = [&] (const size_t pos) -> codepoint_categ {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_category(cpts[pos]) : SENTINEL;
};
size_t _prev_end = offset_ini;
@ -226,7 +227,7 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
const uint32_t cpt = _get_cpt(pos);
const auto flags = _get_flags(pos);
const auto categ = _get_categ(pos);
// regex: 's|'t|'re|'ve|'m|'ll|'d
if (cpt == '\'' && pos+1 < offset_end) {
@ -246,37 +247,37 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
}
}
auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
auto categ2 = (cpt == ' ' ? _get_categ(pos+1) : categ);
// regex: <space>?\p{L}+
if (flags2.is_letter) {
if (categ2.is_L()) {
pos += (cpt == ' ');
while (flags2.is_letter) {
flags2 = _get_flags(++pos);
while (categ2.is_L()) {
categ2 = _get_categ(++pos);
}
_add_token(pos);
continue;
}
// regex: <space>?\p{N}+
if (flags2.is_number) {
if (categ2.is_N()) {
pos += (cpt == ' ');
while (flags2.is_number) {
flags2 = _get_flags(++pos);
while (categ2.is_N()) {
categ2 = _get_categ(++pos);
}
_add_token(pos);
continue;
}
// regex: <space>?[^\s\p{L}\p{N}]+
if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
if (!(categ2.is_whitespace() | categ2.is_L() | categ2.is_N()) && categ2 != SENTINEL) {
pos += (cpt == ' ');
while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
flags2 = _get_flags(++pos);
while (!(categ2.is_whitespace() | categ2.is_L() | categ2.is_N()) && categ2 != SENTINEL) {
categ2 = _get_categ(++pos);
}
_add_token(pos);
continue;
}
size_t num_whitespaces = 0;
while (_get_flags(pos+num_whitespaces).is_whitespace) {
while (_get_categ(pos+num_whitespaces).is_whitespace()) {
num_whitespaces++;
}
@ -321,8 +322,9 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
};
auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
static const codepoint_categ SENTINEL = codepoint_categ::MASK + 1;
auto _get_categ = [&] (const size_t pos) -> codepoint_categ {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_category(cpts[pos]) : SENTINEL;
};
size_t _prev_end = offset_ini;
@ -344,7 +346,7 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
const uint32_t cpt = _get_cpt(pos);
const auto flags = _get_flags(pos);
const auto categ = _get_categ(pos);
// regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
if (cpt == '\'' && pos+1 < offset_end) {
@ -365,10 +367,10 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
}
// regex: [^\r\n\p{L}\p{N}]?\p{L}+
if (!(cpt == '\r' || cpt == '\n' || flags.is_number)) {
if (flags.is_letter || _get_flags(pos+1).is_letter) { // one or more letters
if (!(cpt == '\r' || cpt == '\n' || categ.is_N())) {
if (categ.is_L() || _get_categ(pos+1).is_L()) { // one or more letters
pos++;
while (_get_flags(pos).is_letter) {
while (_get_categ(pos).is_L()) {
pos++;
}
_add_token(pos);
@ -377,9 +379,9 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
}
// regex: \p{N}{1,3}
if (flags.is_number) {
if (categ.is_N()) {
size_t ini = pos;
while (_get_flags(pos).is_number) {
while (_get_categ(pos).is_N()) {
if (++pos - ini >= 3 ) {
_add_token(pos);
ini = pos;
@ -390,11 +392,11 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
}
// regex: <space>?[^\s\p{L}\p{N}]+[\r\n]*
auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags.as_uint()) {
auto categ2 = (cpt == ' ' ? _get_categ(pos+1) : categ);
if (!(categ2.is_whitespace() | categ2.is_L() | categ2.is_N()) && categ2 != SENTINEL) {
pos += (cpt == ' ');
while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
flags2 = _get_flags(++pos);
while (!(categ2.is_whitespace() | categ2.is_L() | categ2.is_N()) && categ2 != SENTINEL) {
categ2 = _get_categ(++pos);
}
uint32_t cpt2 = _get_cpt(pos);
while (cpt2 == '\r' || cpt2 == '\n') {
@ -406,7 +408,7 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
size_t num_whitespaces = 0;
size_t last_end_r_or_n = 0;
while (_get_flags(pos+num_whitespaces).is_whitespace) {
while (_get_categ(pos+num_whitespaces).is_whitespace()) {
uint32_t cpt2 = _get_cpt(pos+num_whitespaces);
if (cpt2 == '\r' || cpt2 == '\n') {
last_end_r_or_n = pos + num_whitespaces + 1;
@ -636,21 +638,21 @@ uint32_t unicode_tolower(uint32_t cp) {
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
// unicode categories
static const std::map<std::string, int> k_ucat_enum = {
{ "\\p{N}", codepoint_flags::NUMBER },
{ "\\p{L}", codepoint_flags::LETTER },
{ "\\p{P}", codepoint_flags::PUNCTUATION },
{ "\\p{N}", codepoint_categ::N },
{ "\\p{L}", codepoint_categ::L },
{ "\\p{P}", codepoint_categ::P },
};
static const std::map<int, int> k_ucat_cpt = {
{ codepoint_flags::NUMBER, 0xD1 },
{ codepoint_flags::LETTER, 0xD2 },
{ codepoint_flags::PUNCTUATION, 0xD3 },
{ codepoint_categ::N, 0xD1 },
{ codepoint_categ::L, 0xD2 },
{ codepoint_categ::P, 0xD3 },
};
static const std::map<int, std::string> k_ucat_map = {
{ codepoint_flags::NUMBER, "\x30-\x39" }, // 0-9
{ codepoint_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
{ codepoint_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
{ codepoint_categ::N, "\x30-\x39" }, // 0-9
{ codepoint_categ::L, "\x41-\x5A\x61-\x7A" }, // A-Za-z
{ codepoint_categ::P, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
};
// compute collapsed codepoints only if needed by at least one regex
@ -681,14 +683,14 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
continue;
}
const auto flags = unicode_cpt_flags(cpts[i]);
const auto categ = unicode_cpt_category(cpts[i]);
if (flags.is_whitespace) {
if (categ.is_whitespace()) {
//NOTE: C++ std::regex \s does not mach 0x85, Rust and Python regex does.
//text_collapsed[i] = (char) 0x85; // <Next Line> as whitespace fallback
text_collapsed[i] = (char) 0x0B; // <vertical tab> as whitespace fallback
} else if (k_ucat_cpt.find(flags.category_flag()) != k_ucat_cpt.end()) {
text_collapsed[i] = k_ucat_cpt.at(flags.category_flag());
} else if (k_ucat_cpt.find(categ.get_category()) != k_ucat_cpt.end()) {
text_collapsed[i] = k_ucat_cpt.at(categ.get_category());
} else {
text_collapsed[i] = (char) 0xD0; // fallback
}
@ -777,7 +779,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
// std::wregex \s does not mach non-ASCII whitespaces, using 0x0B as fallback
std::wstring wtext(cpts.begin(), cpts.end());
for (size_t i = 0; i < wtext.size(); ++i) {
if (wtext[i] > 0x7F && unicode_cpt_flags(wtext[i]).is_whitespace) {
if (wtext[i] > 0x7F && unicode_cpt_category(wtext[i]).is_whitespace()) {
wtext[i] = 0x0B;
}
}

View file

@ -121,6 +121,14 @@ struct codepoint_categ {
inline auto is_Zp() const { return (encoded & SUBMASK) == Zp; }
inline auto is_Zs() const { return (encoded & SUBMASK) == Zs; }
inline bool operator == (const codepoint_categ other) const {
return encoded == other.encoded;
}
inline bool operator != (const codepoint_categ other) const {
return encoded != other.encoded;
}
const char * c_str() const {
static const std::map<uint16_t, const char *> map = {
{UNDEF, "UNDEF"}, {C, "C"}, {L, "L"}, {M, "M"}, {N, "N"}, {P, "P"}, {S, "S"}, {Z, "Z"},