Replace 'codepoint_flags' with 'codepoint_categ'
This commit is contained in:
parent
2636cb6170
commit
23cf064e3b
3 changed files with 57 additions and 47 deletions
|
@ -15836,22 +15836,22 @@ struct llm_tokenizer_wpm {
|
|||
std::vector<std::string> words(1, "");
|
||||
|
||||
for (const uint32_t cpt : cpts_nfd) {
|
||||
const auto flags = unicode_cpt_flags(cpt);
|
||||
const auto categ = unicode_cpt_category(cpt);
|
||||
|
||||
if (flags.is_whitespace) {
|
||||
if (categ.is_whitespace()) {
|
||||
if (words.back().size()) { // finish previous word if any
|
||||
words.emplace_back();
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
assert (!flags.is_separator);
|
||||
if (cpt == 0 || cpt == 0xFFFD || flags.is_control) {
|
||||
assert (!categ.is_S());
|
||||
if (cpt == 0 || cpt == 0xFFFD || categ.is_C()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const std::string s = unicode_cpt_to_utf8(unicode_tolower(cpt));
|
||||
if (flags.is_punctuation || ( cpt < 0x7F && flags.is_symbol ) || is_chinese_char(cpt)) {
|
||||
if (categ.is_P() || (cpt < 0x7F && categ.is_S()) || is_chinese_char(cpt)) {
|
||||
if (words.back().size()) { // finish previous word if any
|
||||
words.emplace_back();
|
||||
}
|
||||
|
@ -15869,7 +15869,7 @@ struct llm_tokenizer_wpm {
|
|||
return words;
|
||||
}
|
||||
|
||||
static bool is_chinese_char(uint32_t cpt) {
|
||||
static bool is_chinese_char(uint32_t cpt) { //TODO: move to unicode-data.cpp? unicode_cpt_category(cpt).is_chinese()?
|
||||
return
|
||||
(cpt >= 0x04E00 && cpt <= 0x09FFF) ||
|
||||
(cpt >= 0x03400 && cpt <= 0x04DBF) ||
|
||||
|
|
|
@ -203,8 +203,9 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
|
|||
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
|
||||
};
|
||||
|
||||
auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
|
||||
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
|
||||
static const codepoint_categ SENTINEL = codepoint_categ::MASK + 1;
|
||||
auto _get_categ = [&] (const size_t pos) -> codepoint_categ {
|
||||
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_category(cpts[pos]) : SENTINEL;
|
||||
};
|
||||
|
||||
size_t _prev_end = offset_ini;
|
||||
|
@ -226,7 +227,7 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
|
|||
|
||||
for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
|
||||
const uint32_t cpt = _get_cpt(pos);
|
||||
const auto flags = _get_flags(pos);
|
||||
const auto categ = _get_categ(pos);
|
||||
|
||||
// regex: 's|'t|'re|'ve|'m|'ll|'d
|
||||
if (cpt == '\'' && pos+1 < offset_end) {
|
||||
|
@ -246,37 +247,37 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
|
|||
}
|
||||
}
|
||||
|
||||
auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
|
||||
auto categ2 = (cpt == ' ' ? _get_categ(pos+1) : categ);
|
||||
// regex: <space>?\p{L}+
|
||||
if (flags2.is_letter) {
|
||||
if (categ2.is_L()) {
|
||||
pos += (cpt == ' ');
|
||||
while (flags2.is_letter) {
|
||||
flags2 = _get_flags(++pos);
|
||||
while (categ2.is_L()) {
|
||||
categ2 = _get_categ(++pos);
|
||||
}
|
||||
_add_token(pos);
|
||||
continue;
|
||||
}
|
||||
// regex: <space>?\p{N}+
|
||||
if (flags2.is_number) {
|
||||
if (categ2.is_N()) {
|
||||
pos += (cpt == ' ');
|
||||
while (flags2.is_number) {
|
||||
flags2 = _get_flags(++pos);
|
||||
while (categ2.is_N()) {
|
||||
categ2 = _get_categ(++pos);
|
||||
}
|
||||
_add_token(pos);
|
||||
continue;
|
||||
}
|
||||
// regex: <space>?[^\s\p{L}\p{N}]+
|
||||
if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
|
||||
if (!(categ2.is_whitespace() | categ2.is_L() | categ2.is_N()) && categ2 != SENTINEL) {
|
||||
pos += (cpt == ' ');
|
||||
while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
|
||||
flags2 = _get_flags(++pos);
|
||||
while (!(categ2.is_whitespace() | categ2.is_L() | categ2.is_N()) && categ2 != SENTINEL) {
|
||||
categ2 = _get_categ(++pos);
|
||||
}
|
||||
_add_token(pos);
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t num_whitespaces = 0;
|
||||
while (_get_flags(pos+num_whitespaces).is_whitespace) {
|
||||
while (_get_categ(pos+num_whitespaces).is_whitespace()) {
|
||||
num_whitespaces++;
|
||||
}
|
||||
|
||||
|
@ -321,8 +322,9 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
|
|||
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
|
||||
};
|
||||
|
||||
auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
|
||||
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
|
||||
static const codepoint_categ SENTINEL = codepoint_categ::MASK + 1;
|
||||
auto _get_categ = [&] (const size_t pos) -> codepoint_categ {
|
||||
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_category(cpts[pos]) : SENTINEL;
|
||||
};
|
||||
|
||||
size_t _prev_end = offset_ini;
|
||||
|
@ -344,7 +346,7 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
|
|||
|
||||
for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
|
||||
const uint32_t cpt = _get_cpt(pos);
|
||||
const auto flags = _get_flags(pos);
|
||||
const auto categ = _get_categ(pos);
|
||||
|
||||
// regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
|
||||
if (cpt == '\'' && pos+1 < offset_end) {
|
||||
|
@ -365,10 +367,10 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
|
|||
}
|
||||
|
||||
// regex: [^\r\n\p{L}\p{N}]?\p{L}+
|
||||
if (!(cpt == '\r' || cpt == '\n' || flags.is_number)) {
|
||||
if (flags.is_letter || _get_flags(pos+1).is_letter) { // one or more letters
|
||||
if (!(cpt == '\r' || cpt == '\n' || categ.is_N())) {
|
||||
if (categ.is_L() || _get_categ(pos+1).is_L()) { // one or more letters
|
||||
pos++;
|
||||
while (_get_flags(pos).is_letter) {
|
||||
while (_get_categ(pos).is_L()) {
|
||||
pos++;
|
||||
}
|
||||
_add_token(pos);
|
||||
|
@ -377,9 +379,9 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
|
|||
}
|
||||
|
||||
// regex: \p{N}{1,3}
|
||||
if (flags.is_number) {
|
||||
if (categ.is_N()) {
|
||||
size_t ini = pos;
|
||||
while (_get_flags(pos).is_number) {
|
||||
while (_get_categ(pos).is_N()) {
|
||||
if (++pos - ini >= 3 ) {
|
||||
_add_token(pos);
|
||||
ini = pos;
|
||||
|
@ -390,11 +392,11 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
|
|||
}
|
||||
|
||||
// regex: <space>?[^\s\p{L}\p{N}]+[\r\n]*
|
||||
auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
|
||||
if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags.as_uint()) {
|
||||
auto categ2 = (cpt == ' ' ? _get_categ(pos+1) : categ);
|
||||
if (!(categ2.is_whitespace() | categ2.is_L() | categ2.is_N()) && categ2 != SENTINEL) {
|
||||
pos += (cpt == ' ');
|
||||
while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
|
||||
flags2 = _get_flags(++pos);
|
||||
while (!(categ2.is_whitespace() | categ2.is_L() | categ2.is_N()) && categ2 != SENTINEL) {
|
||||
categ2 = _get_categ(++pos);
|
||||
}
|
||||
uint32_t cpt2 = _get_cpt(pos);
|
||||
while (cpt2 == '\r' || cpt2 == '\n') {
|
||||
|
@ -406,7 +408,7 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
|
|||
|
||||
size_t num_whitespaces = 0;
|
||||
size_t last_end_r_or_n = 0;
|
||||
while (_get_flags(pos+num_whitespaces).is_whitespace) {
|
||||
while (_get_categ(pos+num_whitespaces).is_whitespace()) {
|
||||
uint32_t cpt2 = _get_cpt(pos+num_whitespaces);
|
||||
if (cpt2 == '\r' || cpt2 == '\n') {
|
||||
last_end_r_or_n = pos + num_whitespaces + 1;
|
||||
|
@ -636,21 +638,21 @@ uint32_t unicode_tolower(uint32_t cp) {
|
|||
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
|
||||
// unicode categories
|
||||
static const std::map<std::string, int> k_ucat_enum = {
|
||||
{ "\\p{N}", codepoint_flags::NUMBER },
|
||||
{ "\\p{L}", codepoint_flags::LETTER },
|
||||
{ "\\p{P}", codepoint_flags::PUNCTUATION },
|
||||
{ "\\p{N}", codepoint_categ::N },
|
||||
{ "\\p{L}", codepoint_categ::L },
|
||||
{ "\\p{P}", codepoint_categ::P },
|
||||
};
|
||||
|
||||
static const std::map<int, int> k_ucat_cpt = {
|
||||
{ codepoint_flags::NUMBER, 0xD1 },
|
||||
{ codepoint_flags::LETTER, 0xD2 },
|
||||
{ codepoint_flags::PUNCTUATION, 0xD3 },
|
||||
{ codepoint_categ::N, 0xD1 },
|
||||
{ codepoint_categ::L, 0xD2 },
|
||||
{ codepoint_categ::P, 0xD3 },
|
||||
};
|
||||
|
||||
static const std::map<int, std::string> k_ucat_map = {
|
||||
{ codepoint_flags::NUMBER, "\x30-\x39" }, // 0-9
|
||||
{ codepoint_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
|
||||
{ codepoint_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
|
||||
{ codepoint_categ::N, "\x30-\x39" }, // 0-9
|
||||
{ codepoint_categ::L, "\x41-\x5A\x61-\x7A" }, // A-Za-z
|
||||
{ codepoint_categ::P, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
|
||||
};
|
||||
|
||||
// compute collapsed codepoints only if needed by at least one regex
|
||||
|
@ -681,14 +683,14 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
|
|||
continue;
|
||||
}
|
||||
|
||||
const auto flags = unicode_cpt_flags(cpts[i]);
|
||||
const auto categ = unicode_cpt_category(cpts[i]);
|
||||
|
||||
if (flags.is_whitespace) {
|
||||
if (categ.is_whitespace()) {
|
||||
//NOTE: C++ std::regex \s does not mach 0x85, Rust and Python regex does.
|
||||
//text_collapsed[i] = (char) 0x85; // <Next Line> as whitespace fallback
|
||||
text_collapsed[i] = (char) 0x0B; // <vertical tab> as whitespace fallback
|
||||
} else if (k_ucat_cpt.find(flags.category_flag()) != k_ucat_cpt.end()) {
|
||||
text_collapsed[i] = k_ucat_cpt.at(flags.category_flag());
|
||||
} else if (k_ucat_cpt.find(categ.get_category()) != k_ucat_cpt.end()) {
|
||||
text_collapsed[i] = k_ucat_cpt.at(categ.get_category());
|
||||
} else {
|
||||
text_collapsed[i] = (char) 0xD0; // fallback
|
||||
}
|
||||
|
@ -777,7 +779,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
|
|||
// std::wregex \s does not mach non-ASCII whitespaces, using 0x0B as fallback
|
||||
std::wstring wtext(cpts.begin(), cpts.end());
|
||||
for (size_t i = 0; i < wtext.size(); ++i) {
|
||||
if (wtext[i] > 0x7F && unicode_cpt_flags(wtext[i]).is_whitespace) {
|
||||
if (wtext[i] > 0x7F && unicode_cpt_category(wtext[i]).is_whitespace()) {
|
||||
wtext[i] = 0x0B;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -121,6 +121,14 @@ struct codepoint_categ {
|
|||
inline auto is_Zp() const { return (encoded & SUBMASK) == Zp; }
|
||||
inline auto is_Zs() const { return (encoded & SUBMASK) == Zs; }
|
||||
|
||||
inline bool operator == (const codepoint_categ other) const {
|
||||
return encoded == other.encoded;
|
||||
}
|
||||
|
||||
inline bool operator != (const codepoint_categ other) const {
|
||||
return encoded != other.encoded;
|
||||
}
|
||||
|
||||
const char * c_str() const {
|
||||
static const std::map<uint16_t, const char *> map = {
|
||||
{UNDEF, "UNDEF"}, {C, "C"}, {L, "L"}, {M, "M"}, {N, "N"}, {P, "P"}, {S, "S"}, {Z, "Z"},
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue