Replace CODEPOINT_TYPE_* with codepoint_flags

This commit is contained in:
jaime-m-p 2024-05-12 17:49:11 +02:00
parent 3b3963c55c
commit e44e608239
6 changed files with 5361 additions and 2357 deletions

View file

@ -12575,16 +12575,16 @@ struct llm_tokenizer_wpm {
// to lowercase, pad chinese characters, pad punctuation // to lowercase, pad chinese characters, pad punctuation
std::string new_str = ""; std::string new_str = "";
for (uint32_t code : cpts_nfd) { for (uint32_t code : cpts_nfd) {
int type = unicode_cpt_type(code); const codepoint_flags flags = unicode_cpt_flags(code);
if (type == CODEPOINT_TYPE_ACCENT_MARK || type == CODEPOINT_TYPE_CONTROL) { if (flags.is_accent_mark || flags.is_control) {
continue; continue;
} }
code = unicode_tolower(code); code = unicode_tolower(code);
if (type == CODEPOINT_TYPE_SEPARATOR) { if (flags.is_separator || flags.is_whitespace) { //####FIXME: is_separator ?
code = ' '; code = ' ';
} }
std::string s = unicode_cpt_to_utf8(code); std::string s = unicode_cpt_to_utf8(code);
if (type == CODEPOINT_TYPE_PUNCTUATION || is_ascii_punct(code) || is_chinese_char(code)) { if (flags.is_punctuation || is_ascii_punct(code) || is_chinese_char(code)) {
new_str += " "; new_str += " ";
new_str += s; new_str += s;
new_str += " "; new_str += " ";

View file

@ -1,64 +1,105 @@
import regex import regex
import ctypes
def get_matches(regex_expr): class CoodepointFlags (ctypes.Structure):
regex_expr_compiled = regex.compile(regex_expr) _fields_ = [ # see definition in unicode.h
unicode_ranges = [] ("is_undefined", ctypes.c_uint16, 1),
current_range = None ("is_number", ctypes.c_uint16, 1), # regex: \p{N}
("is_letter", ctypes.c_uint16, 1), # regex: \p{L}
("is_separator", ctypes.c_uint16, 1), # regex: \p{Z}
("is_accent_mark", ctypes.c_uint16, 1), # regex: \p{M}
("is_punctuation", ctypes.c_uint16, 1), # regex: \p{P}
("is_symbol", ctypes.c_uint16, 1), # regex: \p{S}
("is_control", ctypes.c_uint16, 1), # regex: \p{C}
]
for codepoint in range(0x110000): assert(ctypes.sizeof(CoodepointFlags) == 2)
MAX_CODEPOINTS = 0x110000
regex_number = regex.compile(r'\p{N}')
regex_letter = regex.compile(r'\p{L}')
regex_separator = regex.compile(r'\p{Z}')
regex_accent_mark = regex.compile(r'\p{M}')
regex_punctuation = regex.compile(r'\p{P}')
regex_symbol = regex.compile(r'\p{S}')
regex_control = regex.compile(r'\p{C}')
regex_whitespace = regex.compile(r'\s')
codepoint_flags = (CoodepointFlags * MAX_CODEPOINTS)()
table_whitespace = []
table_lowercase = []
table_uppercase = []
for codepoint in range(MAX_CODEPOINTS):
# convert codepoint to unicode character
char = chr(codepoint) char = chr(codepoint)
if regex_expr_compiled.match(char):
if current_range is None:
current_range = [codepoint, codepoint]
else:
current_range[1] = codepoint
elif current_range is not None:
unicode_ranges.append(tuple(current_range))
current_range = None
if current_range is not None: # regex categories
unicode_ranges.append(tuple(current_range)) flags = codepoint_flags[codepoint]
flags.is_number = bool(regex_number.match(char))
flags.is_letter = bool(regex_letter.match(char))
flags.is_separator = bool(regex_separator.match(char))
flags.is_accent_mark = bool(regex_accent_mark.match(char))
flags.is_punctuation = bool(regex_punctuation.match(char))
flags.is_symbol = bool(regex_symbol.match(char))
flags.is_control = bool(regex_control.match(char))
flags.is_undefined = bytes(flags)[0] == 0
assert(not flags.is_undefined)
return unicode_ranges # whitespaces
if bool(regex_whitespace.match(char)):
table_whitespace.append(codepoint)
# lowercase conversion
def print_cat(mode, cat, ranges):
if mode == "range":
print("const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_{} = {{".format(cat)) # noqa: NP100
if mode == "map":
print("const std::map<uint32_t, uint32_t> unicode_map_{} = {{".format(cat)) # noqa: NP100
for i, values in enumerate(ranges):
end = ",\n" if (i % 4 == 3 or i + 1 == len(ranges)) else ", "
values = ["0x%08X" % value for value in values]
print("{" + ", ".join(values) + "}", end=end) # noqa: NP100
print("};") # noqa: NP100
print("") # noqa: NP100
print_cat("range", "number", get_matches(r'\p{N}'))
print_cat("range", "letter", get_matches(r'\p{L}'))
print_cat("range", "separator", get_matches(r'\p{Z}'))
print_cat("range", "accent_mark", get_matches(r'\p{M}'))
print_cat("range", "punctuation", get_matches(r'\p{P}'))
print_cat("range", "symbol", get_matches(r'\p{S}'))
print_cat("range", "control", get_matches(r'\p{C}'))
print_cat("range", "whitespace", get_matches(r'\s'))
map_lowercase = []
map_uppercase = []
for codepoint in range(0x110000):
char = chr(codepoint)
lower = ord(char.lower()[0]) lower = ord(char.lower()[0])
upper = ord(char.upper()[0])
if codepoint != lower: if codepoint != lower:
map_lowercase.append((codepoint, lower)) table_lowercase.append((codepoint, lower))
# uppercase conversion
upper = ord(char.upper()[0])
if codepoint != upper: if codepoint != upper:
map_uppercase.append((codepoint, upper)) table_uppercase.append((codepoint, upper))
print_cat("map", "lowercase", map_lowercase)
print_cat("map", "uppercase", map_uppercase)
# TODO: generate unicode_map_nfd ranges_flags = [(0, codepoint_flags[0])]
for codepoint, flags in enumerate(codepoint_flags):
if bytes(flags) != bytes(ranges_flags[-1][1]):
ranges_flags.append((codepoint, flags))
ranges_flags.append((MAX_CODEPOINTS, CoodepointFlags()))
# Generate 'unicode-data.cpp'
print("""\
// generated with scripts/gen-unicode-data.py
#include "unicode-data.h"
#include <cstdint>
#include <vector>
#include <unordered_map>
#include <unordered_set>
""")
print("const std::vector<std::pair<uint32_t, uint16_t>> unicode_ranges_flags = { // start, flags // last=next_start-1")
for codepoint, flags in ranges_flags:
flags = int.from_bytes(bytes(flags), "little")
print("{0x%06X, 0x%04X}," % (codepoint, flags))
print("};\n")
print("const std::unordered_set<uint32_t> unicode_set_whitespace = {")
print(", ".join("0x%06X" % cpt for cpt in table_whitespace))
print("};\n")
print("const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase = {")
for tuple in table_lowercase:
print("{0x%06X, 0x%06X}," % tuple)
print("};\n")
print("const std::unordered_map<uint32_t, uint32_t> unicode_map_uppercase = {")
for tuple in table_uppercase:
print("{0x%06X, 0x%06X}," % tuple)
print("};\n")

File diff suppressed because it is too large Load diff

View file

@ -1,17 +1,13 @@
#pragma once #pragma once
#include <cstdint> #include <cstdint>
#include <map>
#include <utility>
#include <vector> #include <vector>
#include <unordered_map>
#include <unordered_set>
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_number; static const size_t MAX_CODEPOINTS = 0x110000;
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_letter;
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_separator; extern const std::vector<std::pair<uint32_t, uint16_t>> unicode_ranges_flags;
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_whitespace; extern const std::unordered_set<uint32_t> unicode_set_whitespace;
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_accent_mark; extern const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase;
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_punctuation; extern const std::unordered_map<uint32_t, uint32_t> unicode_map_uppercase;
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_symbol;
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_control;
extern const std::multimap<uint32_t, uint32_t> unicode_map_nfd;
extern const std::map<char32_t, char32_t> unicode_map_lowercase;

View file

@ -1,4 +1,4 @@
#include "unicode.h" #include "unicode.h"
#include "unicode-data.h" #include "unicode-data.h"
#include <cassert> #include <cassert>
@ -109,44 +109,32 @@ static uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset)
// return result; // return result;
//} //}
static std::unordered_map<uint32_t, int> unicode_cpt_type_map() { static std::array<codepoint_flags, MAX_CODEPOINTS> unicode_cpt_flags_array() {
std::unordered_map<uint32_t, int> cpt_types; std::array<codepoint_flags, MAX_CODEPOINTS> cpt_flags;
for (auto p : unicode_ranges_number) {
for (auto i = p.first; i <= p.second; ++i) { assert (unicode_ranges_flags.front().first == 0);
cpt_types[i] = CODEPOINT_TYPE_NUMBER; assert (unicode_ranges_flags.back().first == MAX_CODEPOINTS);
for (size_t i = 1; i < unicode_ranges_flags.size(); ++i) {
const auto range_ini = unicode_ranges_flags[i-1]; // codepoint_ini, flags
const auto range_end = unicode_ranges_flags[i]; // codepoint_end, flags
for (uint32_t cpt = range_ini.first; cpt < range_end.first; ++cpt) {
cpt_flags[cpt] = range_ini.second;
} }
} }
for (auto p : unicode_ranges_letter) {
for (auto i = p.first; i <= p.second; ++i) { for (auto cpt : unicode_set_whitespace) {
cpt_types[i] = CODEPOINT_TYPE_LETTER; cpt_flags[cpt].is_whitespace = true;
} }
for (auto p : unicode_map_lowercase) {
cpt_flags[p.second].is_lowercase = true;
} }
for (auto p : unicode_ranges_separator) {
for (auto i = p.first; i <= p.second; ++i) { for (auto p : unicode_map_uppercase) {
cpt_types[i] = CODEPOINT_TYPE_SEPARATOR; cpt_flags[p.second].is_uppercase = true;
} }
}
for (auto p : unicode_ranges_accent_mark) { return cpt_flags;
for (auto i = p.first; i <= p.second; ++i) {
cpt_types[i] = CODEPOINT_TYPE_ACCENT_MARK;
}
}
for (auto p : unicode_ranges_punctuation) {
for (auto i = p.first; i <= p.second; ++i) {
cpt_types[i] = CODEPOINT_TYPE_PUNCTUATION;
}
}
for (auto p : unicode_ranges_symbol) {
for (auto i = p.first; i <= p.second; ++i) {
cpt_types[i] = CODEPOINT_TYPE_SYMBOL;
}
}
for (auto p : unicode_ranges_control) {
for (auto i = p.first; i <= p.second; ++i) {
cpt_types[i] = CODEPOINT_TYPE_CONTROL;
}
}
return cpt_types;
} }
static std::unordered_map<uint8_t, std::string> unicode_byte_to_utf8_map() { static std::unordered_map<uint8_t, std::string> unicode_byte_to_utf8_map() {
@ -238,8 +226,9 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0; return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
}; };
auto _get_cpt_type = [&] (const size_t pos) -> int { auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_type(cpts[pos]) : CODEPOINT_TYPE_UNIDENTIFIED; static const codepoint_flags undef(codepoint_flags::UNDEFINED);
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : undef;
}; };
size_t _prev_end = offset_ini; size_t _prev_end = offset_ini;
@ -261,7 +250,7 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) { for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
const char32_t cpt = _get_cpt(pos); const char32_t cpt = _get_cpt(pos);
const int cpt_type = _get_cpt_type(pos); const auto flags = _get_flags(pos);
// regex: 's|'t|'re|'ve|'m|'ll|'d // regex: 's|'t|'re|'ve|'m|'ll|'d
if (cpt == '\'' && pos+1 < offset_end) { if (cpt == '\'' && pos+1 < offset_end) {
@ -281,39 +270,37 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
} }
} }
char32_t cpt2 = (cpt == ' ' ? _get_cpt(pos+1) : cpt); auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
int cpt2_type = (cpt == ' ' ? _get_cpt_type(pos+1) : cpt_type);
// regex: <space>?\p{L}+ // regex: <space>?\p{L}+
if (cpt2_type == CODEPOINT_TYPE_LETTER) { if (flags2.is_letter) {
pos += (cpt == ' '); pos += (cpt == ' ');
while (cpt2_type == CODEPOINT_TYPE_LETTER) { while (flags2.is_letter) {
cpt2_type = _get_cpt_type(++pos); flags2 = _get_flags(++pos);
} }
_add_token(pos); _add_token(pos);
continue; continue;
} }
// regex: <space>?\p{N}+ // regex: <space>?\p{N}+
if (cpt2_type == CODEPOINT_TYPE_NUMBER) { if (flags2.is_number) {
pos += (cpt == ' '); pos += (cpt == ' ');
while (cpt2_type == CODEPOINT_TYPE_NUMBER) { while (flags2.is_number) {
cpt2_type = _get_cpt_type(++pos); flags2 = _get_flags(++pos);
} }
_add_token(pos); _add_token(pos);
continue; continue;
} }
// regex: <space>?[^\s\p{L}\p{N}]+ // regex: <space>?[^\s\p{L}\p{N}]+
if (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) { if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
pos += (cpt == ' '); pos += (cpt == ' ');
while (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) { while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
cpt2_type = _get_cpt_type(++pos); flags2 = _get_flags(++pos);
cpt2 = _get_cpt(pos);
} }
_add_token(pos); _add_token(pos);
continue; continue;
} }
size_t num_whitespaces = 0; size_t num_whitespaces = 0;
while (unicode_cpt_is_whitespace(_get_cpt(pos+num_whitespaces))) { while (_get_flags(pos+num_whitespaces).is_whitespace) {
num_whitespaces++; num_whitespaces++;
} }
@ -357,8 +344,9 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0; return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
}; };
auto _get_cpt_type = [&] (const size_t pos) -> int { auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_type(cpts[pos]) : CODEPOINT_TYPE_UNIDENTIFIED; static const codepoint_flags undef(codepoint_flags::UNDEFINED);
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : undef;
}; };
size_t _prev_end = offset_ini; size_t _prev_end = offset_ini;
@ -380,7 +368,7 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) { for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
const char32_t cpt = _get_cpt(pos); const char32_t cpt = _get_cpt(pos);
const int cpt_type = _get_cpt_type(pos); const auto flags = _get_flags(pos);
// regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
if (cpt == '\'' && pos+1 < offset_end) { if (cpt == '\'' && pos+1 < offset_end) {
@ -401,10 +389,10 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
} }
// regex: [^\r\n\p{L}\p{N}]?\p{L}+ //####FIXME: the first \p{L} is correct? // regex: [^\r\n\p{L}\p{N}]?\p{L}+ //####FIXME: the first \p{L} is correct?
if (cpt != '\r' && cpt != '\n' && /*cpt_type != CODEPOINT_TYPE_LETTER &&*/ cpt_type != CODEPOINT_TYPE_NUMBER) { if (!(cpt == '\r' || cpt == '\n' || /*flags.is_letter |*/ flags.is_number)) {
if (cpt_type == CODEPOINT_TYPE_LETTER || _get_cpt_type(pos+1) == CODEPOINT_TYPE_LETTER) { // one or more letters if (flags.is_letter || _get_flags(pos+1).is_letter) { // one or more letters
pos++; pos++;
while (_get_cpt_type(pos) == CODEPOINT_TYPE_LETTER) { while (_get_flags(pos).is_letter) {
pos++; pos++;
} }
_add_token(pos); _add_token(pos);
@ -413,9 +401,9 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
} }
// regex: \p{N}{1,3} // regex: \p{N}{1,3}
if (cpt_type == CODEPOINT_TYPE_NUMBER) { if (flags.is_number) {
size_t ini = pos; size_t ini = pos;
while (_get_cpt_type(pos) == CODEPOINT_TYPE_NUMBER) { while (_get_flags(pos).is_number) {
if (++pos - ini >= 3 ) { if (++pos - ini >= 3 ) {
_add_token(pos); _add_token(pos);
ini = pos; ini = pos;
@ -426,14 +414,13 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
} }
// regex: <space>?[^\s\p{L}\p{N}]+[\r\n]* // regex: <space>?[^\s\p{L}\p{N}]+[\r\n]*
char32_t cpt2 = (cpt == ' ' ? _get_cpt(pos+1) : cpt); auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
int cpt2_type = (cpt == ' ' ? _get_cpt_type(pos+1) : cpt_type); if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
if (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
pos += (cpt == ' '); pos += (cpt == ' ');
while (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) { while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
cpt2_type = _get_cpt_type(++pos); flags2 = _get_flags(++pos);
cpt2 = _get_cpt(pos);
} }
char32_t cpt2 = _get_cpt(pos);
while (cpt2 == '\r' || cpt2 == '\n') { while (cpt2 == '\r' || cpt2 == '\n') {
cpt2 = _get_cpt(++pos); cpt2 = _get_cpt(++pos);
} }
@ -443,7 +430,7 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
size_t num_whitespaces = 0; size_t num_whitespaces = 0;
size_t last_end_r_or_n = 0; size_t last_end_r_or_n = 0;
while (unicode_cpt_is_whitespace(_get_cpt(pos+num_whitespaces))) { while (_get_flags(pos+num_whitespaces).is_whitespace) {
char32_t cpt2 = _get_cpt(pos+num_whitespaces); char32_t cpt2 = _get_cpt(pos+num_whitespaces);
if (cpt2 == '\r' || cpt2 == '\n') { if (cpt2 == '\r' || cpt2 == '\n') {
last_end_r_or_n = pos + num_whitespaces + 1; last_end_r_or_n = pos + num_whitespaces + 1;
@ -589,17 +576,8 @@ std::string unicode_cpt_to_utf8(uint32_t cp) {
} }
std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts) { std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts) {
std::vector<uint32_t> result; (void) cpts;
result.reserve(cpts.size()); return {}; //####WIP
for (size_t i = 0; i < cpts.size(); ++i) {
auto it = unicode_map_nfd.find(cpts[i]);
if (it == unicode_map_nfd.end()) {
result.push_back(cpts[i]);
} else {
result.push_back(it->second);
}
}
return result;
} }
std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) { std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
@ -611,31 +589,19 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
return result; return result;
} }
int unicode_cpt_type(uint32_t cp) { codepoint_flags unicode_cpt_flags(const uint32_t cp) {
static std::unordered_map<uint32_t, int> cpt_types = unicode_cpt_type_map(); static const codepoint_flags undef(codepoint_flags::UNDEFINED);
const auto it = cpt_types.find(cp); static const auto cpt_flags = unicode_cpt_flags_array();
return it == cpt_types.end() ? CODEPOINT_TYPE_UNIDENTIFIED : it->second; return cp < cpt_flags.size() ? cpt_flags[cp] : undef;
} }
int unicode_cpt_type(const std::string & utf8) { codepoint_flags unicode_cpt_flags(const std::string & utf8) {
if (utf8.length() == 0) { static const codepoint_flags undef(codepoint_flags::UNDEFINED);
return CODEPOINT_TYPE_UNIDENTIFIED; if (utf8.empty()) {
return undef; // undefined
} }
size_t offset = 0; size_t offset = 0;
return unicode_cpt_type(unicode_cpt_from_utf8(utf8, offset)); return unicode_cpt_flags(unicode_cpt_from_utf8(utf8, offset));
}
bool unicode_cpt_is_whitespace(uint32_t cp) {
static const std::unordered_set<uint32_t> is_whitespace = [] {
std::unordered_set<uint32_t> is_whitespace;
for (auto p : unicode_ranges_whitespace) {
for (auto i = p.first; i <= p.second; ++i) {
is_whitespace.insert(i);
}
}
return is_whitespace;
}();
return (bool)is_whitespace.count(cp);
} }
std::string unicode_byte_to_utf8(uint8_t byte) { std::string unicode_byte_to_utf8(uint8_t byte) {
@ -656,21 +622,21 @@ char32_t unicode_tolower(char32_t cp) {
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) { std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
// unicode categories // unicode categories
static const std::map<std::string, int> k_ucat_enum = { static const std::map<std::string, int> k_ucat_enum = {
{ "\\p{N}", CODEPOINT_TYPE_NUMBER }, { "\\p{N}", codepoint_flags::NUMBER },
{ "\\p{L}", CODEPOINT_TYPE_LETTER }, { "\\p{L}", codepoint_flags::LETTER },
{ "\\p{P}", CODEPOINT_TYPE_PUNCTUATION }, { "\\p{P}", codepoint_flags::PUNCTUATION },
}; };
static const std::map<int, int> k_ucat_cpt = { static const std::map<int, int> k_ucat_cpt = {
{ CODEPOINT_TYPE_NUMBER, 0xD1 }, { codepoint_flags::NUMBER, 0xD1 },
{ CODEPOINT_TYPE_LETTER, 0xD2 }, { codepoint_flags::LETTER, 0xD2 },
{ CODEPOINT_TYPE_PUNCTUATION, 0xD3 }, { codepoint_flags::PUNCTUATION, 0xD3 },
}; };
static const std::map<int, std::string> k_ucat_map = { static const std::map<int, std::string> k_ucat_map = {
{ CODEPOINT_TYPE_NUMBER, "\x30-\x39" }, // 0-9 { codepoint_flags::NUMBER, "\x30-\x39" }, // 0-9
{ CODEPOINT_TYPE_LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z { codepoint_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
{ CODEPOINT_TYPE_PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\} { codepoint_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
}; };
// compute collapsed codepoints only if needed by at least one regex // compute collapsed codepoints only if needed by at least one regex
@ -701,10 +667,10 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
continue; continue;
} }
const int cpt_type = unicode_cpt_type(cpts[i]); const int cpt_flag = unicode_cpt_flags(cpts[i]).category_flag();
if (k_ucat_cpt.find(cpt_type) != k_ucat_cpt.end()) { if (k_ucat_cpt.find(cpt_flag) != k_ucat_cpt.end()) {
text_collapsed[i] = k_ucat_cpt.at(cpt_type); text_collapsed[i] = k_ucat_cpt.at(cpt_flag);
} else { } else {
text_collapsed[i] = (char) 0xD0; // fallback text_collapsed[i] = (char) 0xD0; // fallback
} }

View file

@ -4,24 +4,56 @@
#include <string> #include <string>
#include <vector> #include <vector>
#define CODEPOINT_TYPE_UNIDENTIFIED 0 struct codepoint_flags {
#define CODEPOINT_TYPE_NUMBER 1 enum {
#define CODEPOINT_TYPE_LETTER 2 UNDEFINED = 0x0001,
#define CODEPOINT_TYPE_SEPARATOR 3 NUMBER = 0x0002, // regex: \p{N}
#define CODEPOINT_TYPE_ACCENT_MARK 4 LETTER = 0x0004, // regex: \p{L}
#define CODEPOINT_TYPE_PUNCTUATION 5 SEPARATOR = 0x0008, // regex: \p{Z}
#define CODEPOINT_TYPE_SYMBOL 6 ACCENT_MARK = 0x0010, // regex: \p{M}
#define CODEPOINT_TYPE_CONTROL 7 PUNCTUATION = 0x0020, // regex: \p{P}
SYMBOL = 0x0040, // regex: \p{S}
CONTROL = 0x0080, // regex: \p{C}
MASK_CATEGORIES = 0x00FF,
};
// codepoint type
uint16_t is_undefined : 1;
uint16_t is_number : 1; // regex: \p{N}
uint16_t is_letter : 1; // regex: \p{L}
uint16_t is_separator : 1; // regex: \p{Z}
uint16_t is_accent_mark : 1; // regex: \p{M}
uint16_t is_punctuation : 1; // regex: \p{P}
uint16_t is_symbol : 1; // regex: \p{S}
uint16_t is_control : 1; // regex: \p{C}
// helper flags
uint16_t is_whitespace : 1; // regex: \s
uint16_t is_lowercase : 1;
uint16_t is_uppercase : 1;
uint16_t is_nfd : 1;
// decode from uint16
inline codepoint_flags(const uint16_t flags=0) {
*reinterpret_cast<uint16_t*>(this) = flags;
}
inline uint16_t as_uint() const {
return *reinterpret_cast<const uint16_t*>(this);
}
inline uint16_t category_flag() const {
return this->as_uint() & MASK_CATEGORIES;
}
};
std::string unicode_cpt_to_utf8(uint32_t cp); std::string unicode_cpt_to_utf8(uint32_t cp);
std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8); std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8);
std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts); std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts);
int unicode_cpt_type(uint32_t cp); codepoint_flags unicode_cpt_flags(const uint32_t cp);
int unicode_cpt_type(const std::string & utf8); codepoint_flags unicode_cpt_flags(const std::string & utf8);
bool unicode_cpt_is_whitespace(uint32_t cp);
std::string unicode_byte_to_utf8(uint8_t byte); std::string unicode_byte_to_utf8(uint8_t byte);
uint8_t unicode_utf8_to_byte(const std::string & utf8); uint8_t unicode_utf8_to_byte(const std::string & utf8);