Unicode tables: separator, lowercase, uppercase and whitespace

This commit is contained in:
jaime-m-p 2024-05-05 01:19:20 +02:00
parent 69a49ac3a1
commit 8fd849eb90
6 changed files with 987 additions and 433 deletions

View file

@ -12463,7 +12463,7 @@ struct llm_tokenizer_wpm {
continue;
}
code = unicode_tolower(code);
if (type == CODEPOINT_TYPE_WHITESPACE) {
if (type == CODEPOINT_TYPE_SEPARATOR) {
code = ' ';
}
std::string s = unicode_cpt_to_utf8(code);

View file

@ -1,31 +1,51 @@
import regex
import unicodedata
def cpt_to_utf8_str(cpt):
if cpt <= 0xFF:
return bytes([cpt, 0, 0, 0])
elif cpt <= 0xFFFF:
return bytes([cpt & 0xFF, cpt >> 8, 0, 0])
elif cpt <= 0xFFFFFF:
return bytes([cpt & 0xFF, (cpt >> 8) & 0xFF, (cpt >> 16) & 0xFF, 0])
else:
return bytes([cpt & 0xFF, (cpt >> 8) & 0xFF, (cpt >> 16) & 0xFF, cpt >> 24])
if False:
# This code is equivalent to: cpt.to_bytes(4, "little")
def cpt_to_utf8_str(cpt):
if cpt <= 0xFF:
return bytes([cpt, 0, 0, 0])
elif cpt <= 0xFFFF:
return bytes([cpt & 0xFF, cpt >> 8, 0, 0])
elif cpt <= 0xFFFFFF:
return bytes([cpt & 0xFF, (cpt >> 8) & 0xFF, (cpt >> 16) & 0xFF, 0])
else:
return bytes([cpt & 0xFF, (cpt >> 8) & 0xFF, (cpt >> 16) & 0xFF, cpt >> 24])
def is_match(codepoint, regex_expr):
try:
res = regex.match(regex_expr, cpt_to_utf8_str(codepoint).decode('utf-32'))
return res is not None
except Exception:
return False
# This code is equivalent to: regex_expr_compiled.match(chr(codepoint))
def is_match(codepoint, regex_expr):
try:
res = regex_expr.match(cpt_to_utf8_str(codepoint).decode('utf-32'))
return res is not None
except Exception:
return False
# Verify previous statements, using chr() and ord()
for codepoint in range(0x110000):
temp = cpt_to_utf8_str(codepoint)
assert(temp == codepoint.to_bytes(4, "little"))
try:
char = temp.decode('utf-32')
if codepoint == 0xFEFF: # BOM
assert(char == "") # why?
char = "\uFEFF"
except UnicodeDecodeError:
continue
assert(char == chr(codepoint) )
assert(ord(char) == codepoint )
def get_matches(regex_expr):
regex_expr_compiled = regex.compile(regex_expr)
unicode_ranges = []
current_range = None
for codepoint in range(0x110000):
if is_match(codepoint, regex_expr):
char = chr(codepoint)
if regex_expr_compiled.match(char):
if current_range is None:
current_range = [codepoint, codepoint]
else:
@ -40,27 +60,54 @@ def get_matches(regex_expr):
return unicode_ranges
def print_cat(cat, ranges):
print("const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_{} = {{".format(cat))
cnt = 0
for start, end in ranges:
if cnt % 4 != 0:
print(" ", end="")
print("{{0x{:08X}, 0x{:08X}}},".format(start, end), end="")
if cnt % 4 == 3:
print("")
cnt += 1
if cnt % 4 != 0:
print("")
def print_cat(mode, cat, ranges):
    """Print one C++ table definition to stdout.

    mode   -- selects the emitted C++ container:
              "range"       -> std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_<cat>
              "range_value" -> std::vector<std::tuple<uint32_t, uint32_t, uint32_t>> unicode_ranges_<cat>
              "map"         -> std::map<uint32_t, uint32_t> unicode_map_<cat>
    cat    -- suffix appended to the generated variable name.
    ranges -- sized sequence of integer tuples; each tuple becomes one
              brace-enclosed initializer with its values printed as 0x%08X.

    Raises ValueError for an unrecognised mode, before anything is printed,
    so a typo cannot silently produce an initializer list with no declaration.
    """
    headers = {
        "range": "const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_{} = {{",
        "range_value": "const std::vector<std::tuple<uint32_t, uint32_t, uint32_t>> unicode_ranges_{} = {{",
        "map": "const std::map<uint32_t, uint32_t> unicode_map_{} = {{",
    }
    if mode not in headers:
        raise ValueError("unknown print_cat mode: {!r}".format(mode))
    print(headers[mode].format(cat))
    for i, values in enumerate(ranges):
        # Four entries per output line; the last entry always ends its line.
        end = ",\n" if (i % 4 == 3 or i + 1 == len(ranges)) else ", "
        print("{" + ", ".join("0x%08X" % value for value in values) + "}", end=end)
    print("};")
    print("")
print_cat("number", get_matches(r'\p{N}'))
print_cat("letter", get_matches(r'\p{L}'))
print_cat("whitespace", get_matches(r'\p{Z}'))
print_cat("accent_mark", get_matches(r'\p{M}'))
print_cat("punctuation", get_matches(r'\p{P}'))
print_cat("symbol", get_matches(r'\p{S}'))
print_cat("control", get_matches(r'\p{C}'))
print_cat("range", "number", get_matches(r'\p{N}'))
print_cat("range", "letter", get_matches(r'\p{L}'))
print_cat("range", "separator", get_matches(r'\p{Z}'))
print_cat("range", "accent_mark", get_matches(r'\p{M}'))
print_cat("range", "punctuation", get_matches(r'\p{P}'))
print_cat("range", "symbol", get_matches(r'\p{S}'))
print_cat("range", "control", get_matches(r'\p{C}'))
print_cat("range", "whitespace", get_matches(r'\s'))
# Case-mapping tables: one (codepoint, mapped) pair for every codepoint whose
# lower/upper form differs from itself. Only the first character of the
# str.lower() / str.upper() result is kept.
map_lowercase = []
map_uppercase = []
for cpt in range(0x110000):
    text = chr(cpt)
    for table, cased in ((map_lowercase, text.lower()), (map_uppercase, text.upper())):
        mapped = ord(cased[0])
        if mapped != cpt:
            table.append((cpt, mapped))
print_cat("map", "lowercase", map_lowercase)
print_cat("map", "uppercase", map_uppercase)
# NFD table: (first, last, nfd) triples where every codepoint in the inclusive
# span [first, last] has `nfd` as the first codepoint of its NFD decomposition.
#
# A span is only extended while codepoints are adjacent AND share the same
# base. Grouping by base alone and keeping (min, max) would produce spans that
# swallow unrelated codepoints (e.g. everything between U+00C0 and U+212B for
# base 'A') and that overlap each other.
nfd_ranges = []
for codepoint in range(0x110000):
    char = chr(codepoint)
    norm = ord(unicodedata.normalize('NFD', char)[0])
    if codepoint == norm:
        continue
    if nfd_ranges and nfd_ranges[-1][1] == codepoint - 1 and nfd_ranges[-1][2] == norm:
        nfd_ranges[-1] = (nfd_ranges[-1][0], codepoint, norm)
    else:
        nfd_ranges.append((codepoint, codepoint, norm))
# Codepoints are visited in ascending order, so the list is already sorted.
print_cat("range_value", "nfd", nfd_ranges)

File diff suppressed because it is too large Load diff

View file

@ -7,6 +7,7 @@
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_number;
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_letter;
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_separator;
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_whitespace;
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_accent_mark;
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_punctuation;

View file

@ -1,4 +1,4 @@
#include "unicode.h"
#include "unicode.h"
#include "unicode-data.h"
#include <cassert>
@ -9,6 +9,7 @@
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include <locale>
@ -120,9 +121,9 @@ static std::unordered_map<uint32_t, int> unicode_cpt_type_map() {
cpt_types[i] = CODEPOINT_TYPE_LETTER;
}
}
for (auto p : unicode_ranges_whitespace) {
for (auto p : unicode_ranges_separator) {
for (auto i = p.first; i <= p.second; ++ i) {
cpt_types[i] = CODEPOINT_TYPE_WHITESPACE;
cpt_types[i] = CODEPOINT_TYPE_SEPARATOR;
}
}
for (auto p : unicode_ranges_accent_mark) {
@ -300,9 +301,9 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
continue;
}
// regex: <space>?[^\s\p{L}\p{N}]+
if (cpt2_type != CODEPOINT_TYPE_WHITESPACE && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
if (cpt2_type != CODEPOINT_TYPE_SEPARATOR && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
pos += (cpt == ' ');
while (cpt2_type != CODEPOINT_TYPE_WHITESPACE && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
while (cpt2_type != CODEPOINT_TYPE_SEPARATOR && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
cpt2_type = _get_cpt_type(++pos);
}
_add_token(pos);
@ -310,7 +311,7 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
}
size_t num_whitespaces = 0;
while (_get_cpt_type(pos+num_whitespaces) == CODEPOINT_TYPE_WHITESPACE) {
while (_get_cpt_type(pos+num_whitespaces) == CODEPOINT_TYPE_SEPARATOR) {
num_whitespaces++;
}
@ -424,9 +425,9 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
// regex: <space>?[^\s\p{L}\p{N}]+[\r\n]*
int cpt2_type = (cpt == ' ' ? _get_cpt_type(pos+1) : cpt_type);
if (cpt2_type != CODEPOINT_TYPE_WHITESPACE && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
if (cpt2_type != CODEPOINT_TYPE_SEPARATOR && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
pos += (cpt == ' ');
while (cpt2_type != CODEPOINT_TYPE_WHITESPACE && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
while (cpt2_type != CODEPOINT_TYPE_SEPARATOR && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
cpt2_type = _get_cpt_type(++pos);
}
char32_t cpt2 = _get_cpt(pos);
@ -439,7 +440,7 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
size_t num_whitespaces = 0;
size_t last_end_r_or_n = 0;
while (_get_cpt_type(pos+num_whitespaces) == CODEPOINT_TYPE_WHITESPACE) {
while (_get_cpt_type(pos+num_whitespaces) == CODEPOINT_TYPE_SEPARATOR) {
char32_t cpt2 = _get_cpt(pos+num_whitespaces);
if (cpt2 == '\r' || cpt2 == '\n') {
last_end_r_or_n = pos + num_whitespaces + 1;
@ -621,6 +622,19 @@ int unicode_cpt_type(const std::string & utf8) {
return unicode_cpt_type(unicode_cpt_from_utf8(utf8, offset));
}
// Returns true when `cp` lies inside one of the inclusive [first, second]
// ranges listed in unicode_ranges_whitespace.
bool unicode_cpt_is_whitespace(uint32_t cp) {
    // Function-local static: the range table is expanded into a flat lookup
    // set once, the first time this function is called.
    static const std::unordered_set<uint32_t> whitespace_cpts = [] {
        std::unordered_set<uint32_t> expanded;
        for (const auto & range : unicode_ranges_whitespace) {
            for (auto c = range.first; c <= range.second; ++c) {
                expanded.insert(c);
            }
        }
        return expanded;
    }();
    return whitespace_cpts.find(cp) != whitespace_cpts.end();
}
std::string unicode_byte_to_utf8(uint8_t byte) {
static std::unordered_map<uint8_t, std::string> map = unicode_byte_to_utf8_map();
return map.at(byte);

View file

@ -7,7 +7,7 @@
#define CODEPOINT_TYPE_UNIDENTIFIED 0
#define CODEPOINT_TYPE_NUMBER 1
#define CODEPOINT_TYPE_LETTER 2
#define CODEPOINT_TYPE_WHITESPACE 3
#define CODEPOINT_TYPE_SEPARATOR 3
#define CODEPOINT_TYPE_ACCENT_MARK 4
#define CODEPOINT_TYPE_PUNCTUATION 5
#define CODEPOINT_TYPE_SYMBOL 6
@ -21,6 +21,8 @@ std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & c
int unicode_cpt_type(uint32_t cp);
int unicode_cpt_type(const std::string & utf8);
bool unicode_cpt_is_whitespace(uint32_t cp);
std::string unicode_byte_to_utf8(uint8_t byte);
uint8_t unicode_utf8_to_byte(const std::string & utf8);