From 23cf064e3bcd1ca6502484c47b23cb48f4cb4321 Mon Sep 17 00:00:00 2001 From: jaime-m-p <> Date: Sat, 20 Jul 2024 23:28:05 +0200 Subject: [PATCH] Replace 'codepoint_flags' with 'codepoint_categ' --- src/llama.cpp | 12 +++---- src/unicode.cpp | 84 +++++++++++++++++++++++++------------------------ src/unicode.h | 8 +++++ 3 files changed, 57 insertions(+), 47 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 7d68ed811..e8dcc9ff3 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -15836,22 +15836,22 @@ struct llm_tokenizer_wpm { std::vector words(1, ""); for (const uint32_t cpt : cpts_nfd) { - const auto flags = unicode_cpt_flags(cpt); + const auto categ = unicode_cpt_category(cpt); - if (flags.is_whitespace) { + if (categ.is_whitespace()) { if (words.back().size()) { // finish previous word if any words.emplace_back(); } continue; } - assert (!flags.is_separator); - if (cpt == 0 || cpt == 0xFFFD || flags.is_control) { + assert (!categ.is_Z()); /* separators are \p{Z}; \p{S} is symbols, handled below */ + if (cpt == 0 || cpt == 0xFFFD || categ.is_C()) { continue; } const std::string s = unicode_cpt_to_utf8(unicode_tolower(cpt)); - if (flags.is_punctuation || ( cpt < 0x7F && flags.is_symbol ) || is_chinese_char(cpt)) { + if (categ.is_P() || (cpt < 0x7F && categ.is_S()) || is_chinese_char(cpt)) { if (words.back().size()) { // finish previous word if any words.emplace_back(); } @@ -15869,7 +15869,7 @@ struct llm_tokenizer_wpm { return words; } - static bool is_chinese_char(uint32_t cpt) { + static bool is_chinese_char(uint32_t cpt) { //TODO: move to unicode-data.cpp? unicode_cpt_category(cpt).is_chinese()? return (cpt >= 0x04E00 && cpt <= 0x09FFF) || (cpt >= 0x03400 && cpt <= 0x04DBF) || diff --git a/src/unicode.cpp b/src/unicode.cpp index a78c59f74..4c3335974 100644 --- a/src/unicode.cpp +++ b/src/unicode.cpp @@ -203,8 +203,9 @@ static std::vector unicode_regex_split_custom_gpt2(const std::string & t return (offset_ini <= pos && pos < offset_end) ? 
cpts[pos] : OUT_OF_RANGE; }; - auto _get_flags = [&] (const size_t pos) -> codepoint_flags { - return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{}; + static const codepoint_categ SENTINEL = codepoint_categ::MASK + 1; + auto _get_categ = [&] (const size_t pos) -> codepoint_categ { + return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_category(cpts[pos]) : SENTINEL; }; size_t _prev_end = offset_ini; @@ -226,7 +227,7 @@ static std::vector unicode_regex_split_custom_gpt2(const std::string & t for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) { const uint32_t cpt = _get_cpt(pos); - const auto flags = _get_flags(pos); + const auto categ = _get_categ(pos); // regex: 's|'t|'re|'ve|'m|'ll|'d if (cpt == '\'' && pos+1 < offset_end) { @@ -246,37 +247,37 @@ static std::vector unicode_regex_split_custom_gpt2(const std::string & t } } - auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags); + auto categ2 = (cpt == ' ' ? _get_categ(pos+1) : categ); // regex: ?\p{L}+ - if (flags2.is_letter) { + if (categ2.is_L()) { pos += (cpt == ' '); - while (flags2.is_letter) { - flags2 = _get_flags(++pos); + while (categ2.is_L()) { + categ2 = _get_categ(++pos); } _add_token(pos); continue; } // regex: ?\p{N}+ - if (flags2.is_number) { + if (categ2.is_N()) { pos += (cpt == ' '); - while (flags2.is_number) { - flags2 = _get_flags(++pos); + while (categ2.is_N()) { + categ2 = _get_categ(++pos); } _add_token(pos); continue; } // regex: ?[^\s\p{L}\p{N}]+ - if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) { + if (!(categ2.is_whitespace() | categ2.is_L() | categ2.is_N()) && categ2 != SENTINEL) { pos += (cpt == ' '); - while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) { - flags2 = _get_flags(++pos); + while (!(categ2.is_whitespace() | categ2.is_L() | categ2.is_N()) && categ2 != SENTINEL) { + categ2 = _get_categ(++pos); } _add_token(pos); continue; } size_t 
num_whitespaces = 0; - while (_get_flags(pos+num_whitespaces).is_whitespace) { + while (_get_categ(pos+num_whitespaces).is_whitespace()) { num_whitespaces++; } @@ -321,8 +322,9 @@ static std::vector unicode_regex_split_custom_llama3(const std::string & return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE; }; - auto _get_flags = [&] (const size_t pos) -> codepoint_flags { - return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{}; + static const codepoint_categ SENTINEL = codepoint_categ::MASK + 1; + auto _get_categ = [&] (const size_t pos) -> codepoint_categ { + return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_category(cpts[pos]) : SENTINEL; }; size_t _prev_end = offset_ini; @@ -344,7 +346,7 @@ static std::vector unicode_regex_split_custom_llama3(const std::string & for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) { const uint32_t cpt = _get_cpt(pos); - const auto flags = _get_flags(pos); + const auto categ = _get_categ(pos); // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive if (cpt == '\'' && pos+1 < offset_end) { @@ -365,10 +367,10 @@ static std::vector unicode_regex_split_custom_llama3(const std::string & } // regex: [^\r\n\p{L}\p{N}]?\p{L}+ - if (!(cpt == '\r' || cpt == '\n' || flags.is_number)) { - if (flags.is_letter || _get_flags(pos+1).is_letter) { // one or more letters + if (!(cpt == '\r' || cpt == '\n' || categ.is_N())) { + if (categ.is_L() || _get_categ(pos+1).is_L()) { // one or more letters pos++; - while (_get_flags(pos).is_letter) { + while (_get_categ(pos).is_L()) { pos++; } _add_token(pos); @@ -377,9 +379,9 @@ static std::vector unicode_regex_split_custom_llama3(const std::string & } // regex: \p{N}{1,3} - if (flags.is_number) { + if (categ.is_N()) { size_t ini = pos; - while (_get_flags(pos).is_number) { + while (_get_categ(pos).is_N()) { if (++pos - ini >= 3 ) { _add_token(pos); ini = pos; @@ -390,11 +392,11 @@ static std::vector 
unicode_regex_split_custom_llama3(const std::string & } // regex: ?[^\s\p{L}\p{N}]+[\r\n]* - auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags); - if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags.as_uint()) { + auto categ2 = (cpt == ' ' ? _get_categ(pos+1) : categ); + if (!(categ2.is_whitespace() | categ2.is_L() | categ2.is_N()) && categ2 != SENTINEL) { pos += (cpt == ' '); - while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) { - flags2 = _get_flags(++pos); + while (!(categ2.is_whitespace() | categ2.is_L() | categ2.is_N()) && categ2 != SENTINEL) { + categ2 = _get_categ(++pos); } uint32_t cpt2 = _get_cpt(pos); while (cpt2 == '\r' || cpt2 == '\n') { @@ -406,7 +408,7 @@ static std::vector unicode_regex_split_custom_llama3(const std::string & size_t num_whitespaces = 0; size_t last_end_r_or_n = 0; - while (_get_flags(pos+num_whitespaces).is_whitespace) { + while (_get_categ(pos+num_whitespaces).is_whitespace()) { uint32_t cpt2 = _get_cpt(pos+num_whitespaces); if (cpt2 == '\r' || cpt2 == '\n') { last_end_r_or_n = pos + num_whitespaces + 1; @@ -636,21 +638,21 @@ uint32_t unicode_tolower(uint32_t cp) { std::vector unicode_regex_split(const std::string & text, const std::vector & regex_exprs) { // unicode categories static const std::map k_ucat_enum = { - { "\\p{N}", codepoint_flags::NUMBER }, - { "\\p{L}", codepoint_flags::LETTER }, - { "\\p{P}", codepoint_flags::PUNCTUATION }, + { "\\p{N}", codepoint_categ::N }, + { "\\p{L}", codepoint_categ::L }, + { "\\p{P}", codepoint_categ::P }, }; static const std::map k_ucat_cpt = { - { codepoint_flags::NUMBER, 0xD1 }, - { codepoint_flags::LETTER, 0xD2 }, - { codepoint_flags::PUNCTUATION, 0xD3 }, + { codepoint_categ::N, 0xD1 }, + { codepoint_categ::L, 0xD2 }, + { codepoint_categ::P, 0xD3 }, }; static const std::map k_ucat_map = { - { codepoint_flags::NUMBER, "\x30-\x39" }, // 0-9 - { codepoint_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z - { 
codepoint_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\} + { codepoint_categ::N, "\x30-\x39" }, // 0-9 + { codepoint_categ::L, "\x41-\x5A\x61-\x7A" }, // A-Za-z + { codepoint_categ::P, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\} }; // compute collapsed codepoints only if needed by at least one regex @@ -681,14 +683,14 @@ std::vector unicode_regex_split(const std::string & text, const std continue; } - const auto flags = unicode_cpt_flags(cpts[i]); + const auto categ = unicode_cpt_category(cpts[i]); - if (flags.is_whitespace) { + if (categ.is_whitespace()) { //NOTE: C++ std::regex \s does not mach 0x85, Rust and Python regex does. //text_collapsed[i] = (char) 0x85; // as whitespace fallback text_collapsed[i] = (char) 0x0B; // as whitespace fallback - } else if (k_ucat_cpt.find(flags.category_flag()) != k_ucat_cpt.end()) { - text_collapsed[i] = k_ucat_cpt.at(flags.category_flag()); + } else if (k_ucat_cpt.find(categ.get_category()) != k_ucat_cpt.end()) { + text_collapsed[i] = k_ucat_cpt.at(categ.get_category()); } else { text_collapsed[i] = (char) 0xD0; // fallback } @@ -777,7 +779,7 @@ std::vector unicode_regex_split(const std::string & text, const std // std::wregex \s does not mach non-ASCII whitespaces, using 0x0B as fallback std::wstring wtext(cpts.begin(), cpts.end()); for (size_t i = 0; i < wtext.size(); ++i) { - if (wtext[i] > 0x7F && unicode_cpt_flags(wtext[i]).is_whitespace) { + if (wtext[i] > 0x7F && unicode_cpt_category(wtext[i]).is_whitespace()) { wtext[i] = 0x0B; } } diff --git a/src/unicode.h b/src/unicode.h index e8928f261..339ef2893 100644 --- a/src/unicode.h +++ b/src/unicode.h @@ -121,6 +121,14 @@ struct codepoint_categ { inline auto is_Zp() const { return (encoded & SUBMASK) == Zp; } inline auto is_Zs() const { return (encoded & SUBMASK) == Zs; } + inline bool operator == (const codepoint_categ other) 
const { + return encoded == other.encoded; + } + + inline bool operator != (const codepoint_categ other) const { + return encoded != other.encoded; + } + const char * c_str() const { static const std::map map = { {UNDEF, "UNDEF"}, {C, "C"}, {L, "L"}, {M, "M"}, {N, "N"}, {P, "P"}, {S, "S"}, {Z, "Z"},