commit dffb4b1909
jaime-m-p, 2024-09-08 10:06:51 +02:00 (committed by GitHub)
7 changed files with 5275 additions and 2746 deletions

scripts/gen-unicode-data.py

@@ -49,53 +49,42 @@ def unicode_data_iter():
yield (cpt, cpt_lower, cpt_upper, categ, bidir)
# see definition in unicode.h
CODEPOINT_FLAG_UNDEFINED = 0x0001 #
CODEPOINT_FLAG_NUMBER = 0x0002 # \p{N}
CODEPOINT_FLAG_LETTER = 0x0004 # \p{L}
CODEPOINT_FLAG_SEPARATOR = 0x0008 # \p{Z}
CODEPOINT_FLAG_MARK = 0x0010 # \p{M}
CODEPOINT_FLAG_PUNCTUATION = 0x0020 # \p{P}
CODEPOINT_FLAG_SYMBOL = 0x0040 # \p{S}
CODEPOINT_FLAG_CONTROL = 0x0080 # \p{C}
UNICODE_CATEGORY_TO_FLAG = {
"Cn": CODEPOINT_FLAG_UNDEFINED, # Undefined
"Cc": CODEPOINT_FLAG_CONTROL, # Control
"Cf": CODEPOINT_FLAG_CONTROL, # Format
"Co": CODEPOINT_FLAG_CONTROL, # Private Use
"Cs": CODEPOINT_FLAG_CONTROL, # Surrrogate
"Ll": CODEPOINT_FLAG_LETTER, # Lowercase Letter
"Lm": CODEPOINT_FLAG_LETTER, # Modifier Letter
"Lo": CODEPOINT_FLAG_LETTER, # Other Letter
"Lt": CODEPOINT_FLAG_LETTER, # Titlecase Letter
"Lu": CODEPOINT_FLAG_LETTER, # Uppercase Letter
"L&": CODEPOINT_FLAG_LETTER, # Cased Letter
"Mc": CODEPOINT_FLAG_MARK, # Spacing Mark
"Me": CODEPOINT_FLAG_MARK, # Enclosing Mark
"Mn": CODEPOINT_FLAG_MARK, # Nonspacing Mark
"Nd": CODEPOINT_FLAG_NUMBER, # Decimal Number
"Nl": CODEPOINT_FLAG_NUMBER, # Letter Number
"No": CODEPOINT_FLAG_NUMBER, # Other Number
"Pc": CODEPOINT_FLAG_PUNCTUATION, # Connector Punctuation
"Pd": CODEPOINT_FLAG_PUNCTUATION, # Dash Punctuation
"Pe": CODEPOINT_FLAG_PUNCTUATION, # Close Punctuation
"Pf": CODEPOINT_FLAG_PUNCTUATION, # Final Punctuation
"Pi": CODEPOINT_FLAG_PUNCTUATION, # Initial Punctuation
"Po": CODEPOINT_FLAG_PUNCTUATION, # Other Punctuation
"Ps": CODEPOINT_FLAG_PUNCTUATION, # Open Punctuation
"Sc": CODEPOINT_FLAG_SYMBOL, # Currency Symbol
"Sk": CODEPOINT_FLAG_SYMBOL, # Modifier Symbol
"Sm": CODEPOINT_FLAG_SYMBOL, # Math Symbol
"So": CODEPOINT_FLAG_SYMBOL, # Other Symbol
"Zl": CODEPOINT_FLAG_SEPARATOR, # Line Separator
"Zp": CODEPOINT_FLAG_SEPARATOR, # Paragraph Separator
"Zs": CODEPOINT_FLAG_SEPARATOR, # Space Separator
# see codepoint_categ::from_index() in unicode.h
UNICODE_CATEGORY_TO_INDEX = {
"Cn": 0, # \p{Cn} Undefined
"Cc": 1, # \p{Cc} Control
"Cf": 2, # \p{Cf} Format
"Co": 3, # \p{Co} Private Use
"Cs": 4, # \p{Cs} Surrrogate
"Ll": 5, # \p{Ll} Lowercase Letter
"Lm": 6, # \p{Lm} Modifier Letter
"Lo": 7, # \p{Lo} Other Letter
"Lt": 8, # \p{Lt} Titlecase Letter
"Lu": 9, # \p{Lu} Uppercase Letter
"Mc": 10, # \p{Mc} Spacing Mark
"Me": 11, # \p{Me} Enclosing Mark
"Mn": 12, # \p{Mn} Nonspacing Mark
"Nd": 13, # \p{Nd} Decimal Number
"Nl": 14, # \p{Nl} Letter Number
"No": 15, # \p{No} Other Number
"Pc": 16, # \p{Pc} Connector Punctuation
"Pd": 17, # \p{Pd} Dash Punctuation
"Pe": 18, # \p{Pe} Close Punctuation
"Pf": 19, # \p{Pf} Final Punctuation
"Pi": 20, # \p{Pi} Initial Punctuation
"Po": 21, # \p{Po} Other Punctuation
"Ps": 22, # \p{Ps} Open Punctuation
"Sc": 23, # \p{Sc} Currency Symbol
"Sk": 24, # \p{Sk} Modifier Symbol
"Sm": 25, # \p{Sm} Math Symbol
"So": 26, # \p{So} Other Symbol
"Zl": 27, # \p{Zl} Line Separator
"Zp": 28, # \p{Zp} Paragraph Separator
"Zs": 29, # \p{Zs} Space Separator
}
codepoint_flags = array.array('H', [CODEPOINT_FLAG_UNDEFINED]) * MAX_CODEPOINTS
table_whitespace = []
codepoint_categs = array.array('B', [0]) * MAX_CODEPOINTS # Undefined
table_lowercase = []
table_uppercase = []
table_nfd = []
@@ -105,7 +94,7 @@ for (cpt, cpt_lower, cpt_upper, categ, bidir) in unicode_data_iter():
char = chr(cpt)
# codepoint category flags
codepoint_flags[cpt] = UNICODE_CATEGORY_TO_FLAG[categ]
codepoint_categs[cpt] = UNICODE_CATEGORY_TO_INDEX[categ]
# lowercase conversion
if cpt_lower:
@@ -121,25 +110,31 @@ for (cpt, cpt_lower, cpt_upper, categ, bidir) in unicode_data_iter():
table_nfd.append((cpt, norm))
# whitespaces, see "<White_Space>" https://www.unicode.org/Public/UCD/latest/ucd/PropList.txt
table_whitespace.extend(range(0x0009, 0x000D + 1))
table_whitespace.extend(range(0x2000, 0x200A + 1))
table_whitespace.extend([0x0020, 0x0085, 0x00A0, 0x1680, 0x2028, 0x2029, 0x202F, 0x205F, 0x3000])
# sort by codepoint
table_whitespace.sort()
table_lowercase.sort()
table_uppercase.sort()
table_nfd.sort()
# group ranges with same flags
ranges_flags: list[tuple[int, int]] = [(0, codepoint_flags[0])] # start, flags
for codepoint, flags in enumerate(codepoint_flags):
if flags != ranges_flags[-1][1]:
ranges_flags.append((codepoint, flags))
ranges_flags.append((MAX_CODEPOINTS, 0x0000))
# whitespaces, see "<White_Space>" https://www.unicode.org/Public/UCD/latest/ucd/PropList.txt
whitespace_ranges: list[tuple[int, int]] = [] # start, last
whitespace_ranges.append((0x0009, 0x000D))
whitespace_ranges.append((0x2000, 0x200A))
for whitespace in [0x0020, 0x0085, 0x00A0, 0x1680, 0x2028, 0x2029, 0x202F, 0x205F, 0x3000]:
whitespace_ranges.append((whitespace, whitespace))
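
As an aside, equivalent ranges could be derived mechanically from a flat sorted codepoint list; a minimal sketch (hypothetical helper, not part of the generator):

def to_ranges(codepoints: list[int]) -> list[tuple[int, int]]:
    # collapse a sorted, duplicate-free codepoint list into inclusive (start, last) ranges
    ranges: list[tuple[int, int]] = []
    for cpt in codepoints:
        if ranges and ranges[-1][1] + 1 == cpt:
            ranges[-1] = (ranges[-1][0], cpt)
        else:
            ranges.append((cpt, cpt))
    return ranges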
# run length encoding, see unicode_cpt_category() in unicode.cpp
assert (max(UNICODE_CATEGORY_TO_INDEX.values()) < 32)
codepoint_categs_runs = [codepoint_categs[0]] # 5 bits categ + 11 bits length
for cpt, categ in enumerate(codepoint_categs[1:], 1):
prev = codepoint_categs_runs[-1]
if prev <= (0xFFFF - 32) and (prev & 31) == categ:
codepoint_categs_runs[-1] += 32 # increment run length
else:
codepoint_categs_runs.append(categ) # new run value
assert (codepoint_categs_runs[-1] < 0xFFFF)
assert (MAX_CODEPOINTS == sum((rle >> 5) + 1 for rle in codepoint_categs_runs))
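
For illustration, a minimal decode sketch (not part of the generator) that unpacks the runs back into per-codepoint category indices, mirroring unicode_cpt_category() in unicode.cpp:

decoded: list[int] = []
for rle in codepoint_categs_runs:
    index = rle & 31             # low 5 bits: category index
    length = (rle >> 5) + 1      # high 11 bits: run length minus one
    decoded.extend([index] * length)
assert decoded == list(codepoint_categs)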
# group ranges with same nfd
@@ -153,7 +148,7 @@ for codepoint, norm in table_nfd:
# Generate 'unicode-data.cpp':
# python ./scripts/gen-unicode-data.py > unicode-data.cpp
# python ./scripts/gen-unicode-data.py > ./src/unicode-data.cpp
def out(line=""):
print(line, end='\n') # noqa
@@ -167,17 +162,16 @@ out("""\
#include <cstdint>
#include <vector>
#include <unordered_map>
#include <unordered_set>
""")
out("const std::vector<std::pair<uint32_t, uint16_t>> unicode_ranges_flags = { // start, flags // last=next_start-1")
for codepoint, flags in ranges_flags:
out("{0x%06X, 0x%04X}," % (codepoint, flags))
out("const std::vector<uint16_t> unicode_rle_codepoints_categs = { // run length encoding, 5 bits categ + 11 bits length")
for rle in codepoint_categs_runs:
out("0x%04X," % rle)
out("};\n")
out("const std::unordered_set<uint32_t> unicode_set_whitespace = {")
for codepoint in table_whitespace:
out("0x%06X," % codepoint)
out("const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_whitespace = {")
for (start, last) in whitespace_ranges:
out("{0x%06X, 0x%06X}," % (start, last))
out("};\n")
out("const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase = {")

src/llama-vocab.cpp

@@ -444,10 +444,8 @@ struct llm_tokenizer_bpe {
};
break;
case LLAMA_VOCAB_PRE_TYPE_TEKKEN:
// original regex from tokenizer.json
// "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
regex_exprs = {
"[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
"[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
};
break;
default:
@@ -704,22 +702,21 @@ struct llm_tokenizer_wpm {
std::vector<std::string> words(1, "");
for (const uint32_t cpt : cpts_nfd) {
const auto flags = unicode_cpt_flags(cpt);
const auto categ = unicode_cpt_category(cpt);
if (flags.is_whitespace) {
if (categ.is_whitespace()) {
if (words.back().size()) { // finish previous word if any
words.emplace_back();
}
continue;
}
assert (!flags.is_separator);
if (cpt == 0 || cpt == 0xFFFD || flags.is_control) {
if (cpt == 0 || cpt == 0xFFFD || categ.is_C()) {
continue;
}
const std::string s = unicode_cpt_to_utf8(unicode_tolower(cpt));
if (flags.is_punctuation || ( cpt < 0x7F && flags.is_symbol ) || is_chinese_char(cpt)) {
if (categ.is_P() || (cpt < 0x7F && categ.is_S()) || is_chinese_char(cpt)) {
if (words.back().size()) { // finish previous word if any
words.emplace_back();
}
@@ -737,7 +734,7 @@ struct llm_tokenizer_wpm {
return words;
}
static bool is_chinese_char(uint32_t cpt) {
static bool is_chinese_char(uint32_t cpt) { //TODO: move to unicode-data.cpp? unicode_cpt_category(cpt).is_chinese()?
return
(cpt >= 0x04E00 && cpt <= 0x09FFF) ||
(cpt >= 0x03400 && cpt <= 0x04DBF) ||

src/unicode-data.cpp (file diff suppressed because it is too large)

src/unicode-data.h

@@ -3,7 +3,6 @@
#include <cstdint>
#include <vector>
#include <unordered_map>
#include <unordered_set>
struct range_nfd {
uint32_t first;
@@ -13,8 +12,8 @@ struct range_nfd {
static const uint32_t MAX_CODEPOINTS = 0x110000;
extern const std::vector<std::pair<uint32_t, uint16_t>> unicode_ranges_flags;
extern const std::unordered_set<uint32_t> unicode_set_whitespace;
extern const std::vector<uint16_t> unicode_rle_codepoints_categs;
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_whitespace;
extern const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase;
extern const std::unordered_map<uint32_t, uint32_t> unicode_map_uppercase;
extern const std::vector<range_nfd> unicode_ranges_nfd;

src/unicode.cpp

@@ -2,10 +2,10 @@
#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
#endif
#include "ggml.h"
#include "unicode.h"
#include "unicode-data.h"
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
@@ -119,38 +119,6 @@ uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) {
// return result;
//}
static std::vector<codepoint_flags> unicode_cpt_flags_array() {
std::vector<codepoint_flags> cpt_flags(MAX_CODEPOINTS, codepoint_flags::UNDEFINED);
assert (unicode_ranges_flags.front().first == 0);
assert (unicode_ranges_flags.back().first == MAX_CODEPOINTS);
for (size_t i = 1; i < unicode_ranges_flags.size(); ++i) {
const auto range_ini = unicode_ranges_flags[i-1]; // codepoint_ini, flags
const auto range_end = unicode_ranges_flags[i]; // codepoint_end, flags
for (uint32_t cpt = range_ini.first; cpt < range_end.first; ++cpt) {
cpt_flags[cpt] = range_ini.second;
}
}
for (auto cpt : unicode_set_whitespace) {
cpt_flags[cpt].is_whitespace = true;
}
for (auto p : unicode_map_lowercase) {
cpt_flags[p.second].is_lowercase = true;
}
for (auto p : unicode_map_uppercase) {
cpt_flags[p.second].is_uppercase = true;
}
for (auto &range : unicode_ranges_nfd) { // start, last, nfd
cpt_flags[range.nfd].is_nfd = true;
}
return cpt_flags;
}
static std::unordered_map<uint8_t, std::string> unicode_byte_to_utf8_map() {
std::unordered_map<uint8_t, std::string> map;
for (int ch = 0x21; ch <= 0x7E; ++ch) { // u'!' to u'~'
@@ -233,7 +201,7 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
for (auto offset : offsets) {
const size_t offset_ini = start;
const size_t offset_end = start + offset;
assert(offset_end <= cpts.size());
GGML_ASSERT(offset_end <= cpts.size());
start = offset_end;
static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
@@ -241,13 +209,14 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
};
auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
static const codepoint_categ SENTINEL = codepoint_categ::UNDEF + 1;
auto _get_categ = [&] (const size_t pos) -> codepoint_categ {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_category(cpts[pos]) : SENTINEL;
};
size_t _prev_end = offset_ini;
auto _add_token = [&] (const size_t end) -> size_t {
assert(_prev_end <= end && end <= offset_end);
GGML_ASSERT(_prev_end <= end && end <= offset_end);
size_t len = end - _prev_end;
if (len > 0) {
bpe_offsets.push_back(len);
@@ -264,7 +233,7 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
const uint32_t cpt = _get_cpt(pos);
const auto flags = _get_flags(pos);
const auto categ = _get_categ(pos);
// regex: 's|'t|'re|'ve|'m|'ll|'d
if (cpt == '\'' && pos+1 < offset_end) {
@@ -284,37 +253,37 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
}
}
auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
auto categ2 = (cpt == ' ' ? _get_categ(pos+1) : categ);
// regex: <space>?\p{L}+
if (flags2.is_letter) {
if (categ2.is_L()) {
pos += (cpt == ' ');
while (flags2.is_letter) {
flags2 = _get_flags(++pos);
while (categ2.is_L()) {
categ2 = _get_categ(++pos);
}
_add_token(pos);
continue;
}
// regex: <space>?\p{N}+
if (flags2.is_number) {
if (categ2.is_N()) {
pos += (cpt == ' ');
while (flags2.is_number) {
flags2 = _get_flags(++pos);
while (categ2.is_N()) {
categ2 = _get_categ(++pos);
}
_add_token(pos);
continue;
}
// regex: <space>?[^\s\p{L}\p{N}]+
if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
if (!(categ2.is_whitespace() | categ2.is_L() | categ2.is_N()) && categ2 != SENTINEL) {
pos += (cpt == ' ');
while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
flags2 = _get_flags(++pos);
while (!(categ2.is_whitespace() | categ2.is_L() | categ2.is_N()) && categ2 != SENTINEL) {
categ2 = _get_categ(++pos);
}
_add_token(pos);
continue;
}
size_t num_whitespaces = 0;
while (_get_flags(pos+num_whitespaces).is_whitespace) {
while (_get_categ(pos+num_whitespaces).is_whitespace()) {
num_whitespaces++;
}
@@ -351,7 +320,7 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
for (auto offset : offsets) {
const size_t offset_ini = start;
const size_t offset_end = start + offset;
assert(offset_end <= cpts.size());
GGML_ASSERT(offset_end <= cpts.size());
start = offset_end;
static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
@@ -359,13 +328,14 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
};
auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
static const codepoint_categ SENTINEL = codepoint_categ::UNDEF + 1;
auto _get_categ = [&] (const size_t pos) -> codepoint_categ {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_category(cpts[pos]) : SENTINEL;
};
size_t _prev_end = offset_ini;
auto _add_token = [&] (const size_t end) -> size_t {
assert(_prev_end <= end && end <= offset_end);
GGML_ASSERT(_prev_end <= end && end <= offset_end);
size_t len = end - _prev_end;
if (len > 0) {
bpe_offsets.push_back(len);
@@ -382,7 +352,7 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
const uint32_t cpt = _get_cpt(pos);
const auto flags = _get_flags(pos);
const auto categ = _get_categ(pos);
// regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
if (cpt == '\'' && pos+1 < offset_end) {
@@ -403,10 +373,10 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
}
// regex: [^\r\n\p{L}\p{N}]?\p{L}+
if (!(cpt == '\r' || cpt == '\n' || flags.is_number)) {
if (flags.is_letter || _get_flags(pos+1).is_letter) { // one or more letters
if (!(cpt == '\r' || cpt == '\n' || categ.is_N())) {
if (categ.is_L() || _get_categ(pos+1).is_L()) { // one or more letters
pos++;
while (_get_flags(pos).is_letter) {
while (_get_categ(pos).is_L()) {
pos++;
}
_add_token(pos);
@@ -415,9 +385,9 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
}
// regex: \p{N}{1,3}
if (flags.is_number) {
if (categ.is_N()) {
size_t ini = pos;
while (_get_flags(pos).is_number) {
while (_get_categ(pos).is_N()) {
if (++pos - ini >= 3 ) {
_add_token(pos);
ini = pos;
@@ -428,11 +398,11 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
}
// regex: <space>?[^\s\p{L}\p{N}]+[\r\n]*
auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags.as_uint()) {
auto categ2 = (cpt == ' ' ? _get_categ(pos+1) : categ);
if (!(categ2.is_whitespace() | categ2.is_L() | categ2.is_N()) && categ2 != SENTINEL) {
pos += (cpt == ' ');
while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
flags2 = _get_flags(++pos);
while (!(categ2.is_whitespace() | categ2.is_L() | categ2.is_N()) && categ2 != SENTINEL) {
categ2 = _get_categ(++pos);
}
uint32_t cpt2 = _get_cpt(pos);
while (cpt2 == '\r' || cpt2 == '\n') {
@@ -444,7 +414,7 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
size_t num_whitespaces = 0;
size_t last_end_r_or_n = 0;
while (_get_flags(pos+num_whitespaces).is_whitespace) {
while (_get_categ(pos+num_whitespaces).is_whitespace()) {
uint32_t cpt2 = _get_cpt(pos+num_whitespaces);
if (cpt2 == '\r' || cpt2 == '\n') {
last_end_r_or_n = pos + num_whitespaces + 1;
@@ -481,66 +451,6 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
return bpe_offsets;
}
// use std::wregex to split the text
static std::vector<size_t> unicode_regex_split_stl(const std::wstring & wtext, const std::wstring & regex_expr, const std::vector<size_t> & offsets) {
std::wregex expr(regex_expr);
std::vector<size_t> bpe_offsets; // store the offset of each word
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
size_t start = 0;
for (auto offset : offsets) {
std::wcregex_iterator it(wtext.data() + start, wtext.data() + start + offset, expr);
std::wcregex_iterator end;
int64_t start_idx = 0;
while (it != end) {
std::wcmatch match = *it;
if (match.position() > start_idx) {
bpe_offsets.emplace_back(match.position() - start_idx);
}
bpe_offsets.emplace_back(match.length());
start_idx = match.position() + match.length();
++it;
}
if (start_idx < (int64_t) offset) {
bpe_offsets.emplace_back(offset - start_idx);
}
start += offset;
}
return bpe_offsets;
}
// use std::regex to split the text
static std::vector<size_t> unicode_regex_split_stl(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
std::regex expr(regex_expr);
std::vector<size_t> bpe_offsets; // store the offset of each word
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
size_t start = 0;
for (auto offset : offsets) {
std::cregex_iterator it(text.data() + start, text.data() + start + offset, expr);
std::cregex_iterator end;
int64_t start_idx = 0;
while (it != end) {
std::cmatch match = *it;
if (match.position() > start_idx) {
bpe_offsets.emplace_back(match.position() - start_idx);
}
bpe_offsets.emplace_back(match.length());
start_idx = match.position() + match.length();
++it;
}
if (start_idx < (int64_t) offset) {
bpe_offsets.emplace_back(offset - start_idx);
}
start += offset;
}
return bpe_offsets;
}
static std::vector<size_t> unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
std::vector<size_t> bpe_offsets;
@@ -556,6 +466,269 @@ static std::vector<size_t> unicode_regex_split_custom(const std::string & text,
return bpe_offsets;
}
// Custom std::regex specializations for 32-bit Unicode codepoints
// std::wregex does not support Unicode categories: \p{N}, \p{L}, \p{Lu}, \p{Ll} ...
// std::wregex does not support Unicode whitespaces \s: 0x85, 0xA0, 0x001680 ... 0x003000.
// std::wregex supports full 32 bit codepoints, not limited to standard max 0x110000.
namespace std {
// codepoint type for all template specializations
#if (WCHAR_MAX > 0xFFFF)
using codepoint = wchar_t; // sizeof(wchar_t) == 4
#else
using codepoint = uint32_t; // Windows: sizeof(wchar_t) == 2
#define CUSTOM_CTYPE_CODEPOINT
#endif
#ifdef CUSTOM_CTYPE_CODEPOINT
// Minimal required implementation for std::regex string processing
template<> // custom specialized std::ctype<codepoint>
class ctype<codepoint> {
public:
using CharT = codepoint;
using char_type = CharT;
using mask = uint8_t; //NOTE: see std::ctype_base
static const mask digit = 1; // required variable names
static const mask xdigit = 2; // user defined values
static const mask alpha = 3; // used to be a bitmask
static const mask upper = 4; // we do not need a bitmask
static const mask lower = 5; // using a sequence instead
static locale::id id; // required by std::locale::facet
bool is(mask m, char_type c) const {
switch (m) {
case digit: return ('0' <= c && c <= '9');
case xdigit: return ('0' <= c && c <= '9') || ('A' <= c && c <= 'F');
case alpha: return ('A' <= c && c <= 'Z') || ('a' <= c && c <= 'z');
case upper: return ('A' <= c && c <= 'Z');
case lower: return ('a' <= c && c <= 'z');
default: return false;
}
}
char_type toupper(char_type c) const {
return ('a' <= c && c <= 'z') ? c - ('a' - 'A') : c;
}
char_type tolower(char_type c) const {
return ('A' <= c && c <= 'Z') ? c + ('a' - 'A') : c;
}
char_type widen(char c) const { // char to codepoint
return (char_type) c;
}
char narrow(char_type c, char dfault) const { // codepoint to char
return (c < 0x80 ? (char)c : dfault);
}
};
locale::id ctype<codepoint>::id = {};
template<> // specialization to use our custom specialized std::ctype<codepoint>
const std::ctype<codepoint> & use_facet<std::ctype<codepoint>>(const std::locale &) {
static std::ctype<codepoint> ctype_uint32 = {};
return ctype_uint32;
}
template<> // specialization to use our custom specialized std::ctype<codepoint>
const std::ctype<codepoint> & use_facet<const std::ctype<codepoint>>(const std::locale & loc) {
return use_facet<std::ctype<codepoint>>(loc);
}
#endif
// Minimal required implementation for std::regex string processing
template<> // custom specialized std::regex_traits<codepoint>
class regex_traits<codepoint> {
public:
using CharT = codepoint;
using char_type = codepoint;
using size_type = size_t;
using string_type = std::basic_string<CharT>;
using locale_type = std::locale;
using char_class_type = uint64_t;
#if (defined(_WIN32) || defined(_WIN64)) // MSVC class _Regex_traits
using _Uelem = CharT;
static const auto _Ch_upper = std::ctype<CharT>::upper;
static const auto _Ch_alpha = std::ctype<CharT>::alpha;
#endif
CharT translate(CharT c) const {
return c;
}
CharT translate_nocase(CharT c) const {
return unicode_tolower(c);
}
template<typename It>
string_type transform(It first, It last) const {
GGML_ASSERT(false); //TODO: not needed ?
return {first, last}; //TODO: not tested
}
template<typename It>
string_type transform_primary(It first, It last) const {
(void) first;
(void) last;
GGML_ASSERT((uint32_t) *first < MAX_CODEPOINTS); // check valid codepoint
return {};
}
template<typename It>
string_type lookup_collatename(It first, It last) const {
(void) last;
GGML_ASSERT(*first & (1 << 31));
return {*first};
}
template<typename It>
char_class_type lookup_classname(It first, It last, bool icase = false) const {
(void) last;
(void) icase;
const uint32_t encoded = *first;
codepoint_categ categ = {};
switch(encoded) {
case 's':
case 'S': // negation is internally tracked
categ.set_flag(codepoint_categ::WHITESPACES);
return categ.expand_bits();
case 'w':
case 'W': // negation is internally tracked
categ.set_flag(codepoint_categ::WORDS);
return categ.expand_bits();
case 'd':
case 'D': // negation is internally tracked
categ.set_flag(codepoint_categ::DIGITS);
return categ.expand_bits();
default: { // unicode category \p{Xx} encoded in codepoint
GGML_ASSERT(encoded & (1 << 31)); // make sure it's our custom codepoint encoding the category
const bool negated = encoded & (1 << 30); // negation of a 'character class expression' is not internally tracked
categ = {(uint16_t) encoded};
return ((uint64_t) negated << 63) | categ.expand_bits(false);
}
}
}
bool isctype(CharT c, char_class_type mask) const {
const bool negated = mask & (1llu << 63);
mask &= unicode_cpt_category(c).expand_bits();
return negated ^ (bool) mask;
}
int value(CharT c, int radix) const { // char to int value
switch (radix) {
case 8: return ('0' <= c && c <= '7') ? (int)c - '0' : -1;
case 10: return ('0' <= c && c <= '9') ? (int)c - '0' : -1;
case 16: return ('0' <= c && c <= '9') ? (int)c - '0' : (('A' <= c && c <= 'F') ? (int)c - 'A' + 10 : -1);
default: return -1;
}
}
const locale_type & imbue(const locale_type &) { // set locale //NOTE: ignoring locales
return std::locale::classic();
}
const locale_type & getloc() const { // get locale //NOTE: ignoring locales
return std::locale::classic();
}
};
}
static std::vector<uint32_t> unicode_regex_prepare(const std::string & regex) {
std::vector<uint32_t> regex_cpts;
regex_cpts.reserve(regex.size() * 12 / 10); // estimate +20%
size_t offset = 0;
int inside_square = 0;
bool any_positive = false;
bool any_negative = false;
const size_t size = regex.size();
while (offset < size) {
inside_square += regex[offset] == '[';
inside_square -= regex[offset] == ']';
GGML_ASSERT(inside_square >= 0);
if (!inside_square) {
any_positive = false;
any_negative = false;
}
if (regex[offset] == '\\') {
const size_t i = offset + 1;
if (regex[i] == 'p' || regex[i] == 'P') {
// convert \p{Xx} to custom 'character class expression' [:Xy:]
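// e.g. "\p{Lu}+" outside [] becomes the codepoints: '[' '[' ':' cpt ':' ']' ']' '+'
// and "[\p{Lu}\p{Lt}]" inside [] becomes: '[' '[' ':' cpt(Lu) ':' ']' '[' ':' cpt(Lt) ':' ']' ']'
// where cpt = (1 << 31) | (negated << 30) | categ.encoded (see below)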
if (regex[i + 1] == '{' && regex[i + 2] && regex[i + 3]) {
codepoint_categ categ = {};
if (regex[i + 3] == '}') {
categ = codepoint_categ::from_chars(regex[i + 2]);
offset += 5;
} else if (regex[i + 3] != '}' && regex[i + 4] == '}') {
categ = codepoint_categ::from_chars(regex[i + 2], regex[i + 3]);
offset += 6;
}
bool negated = regex[i] == 'P';
any_positive |= !negated;
any_negative |= negated;
GGML_ASSERT(any_positive != any_negative); //BUG: cannot mix 'p' and 'P' inside []
GGML_ASSERT(sizeof(categ) <= 2);
// encode the category in a 32-bit codepoint
uint32_t cpt_categ = (1 << 31) | (negated << 30) | categ.encoded;
if (inside_square) {
regex_cpts.insert(regex_cpts.end(), {'[', ':', cpt_categ, ':', ']'});
} else {
regex_cpts.insert(regex_cpts.end(), {'[', '[', ':', cpt_categ, ':', ']', ']'});
}
continue;
}
}
}
regex_cpts.push_back(unicode_cpt_from_utf8(regex, offset));
}
return regex_cpts;
}
// use std::basic_regex<uint32_t> to split the text codepoints
static std::vector<size_t> unicode_regex_split_stl(const std::vector<uint32_t> & text_cpts, const std::vector<uint32_t> & regex_cpts, const std::vector<size_t> & offsets) {
GGML_ASSERT(sizeof(std::codepoint) == sizeof(uint32_t));
using regex_type = std::basic_regex<std::codepoint>;
using iter_type = std::regex_iterator<const std::codepoint *>;
const std::codepoint * text_data = (const std::codepoint *) text_cpts.data();
const std::codepoint * regex_data = (const std::codepoint *) regex_cpts.data();
regex_type regex(regex_data, regex_data+regex_cpts.size());
const iter_type end;
std::vector<size_t> bpe_offsets; // store the offset of each word
bpe_offsets.reserve(offsets.size()); // reserve memory for the approximate size
for (auto offset : offsets) {
iter_type it(text_data, text_data + offset, regex);
int64_t start_idx = 0;
while (it != end) {
if (it->position() > start_idx) {
bpe_offsets.emplace_back(it->position() - start_idx);
}
bpe_offsets.emplace_back(it->length());
start_idx = it->position() + it->length();
++it;
}
if (start_idx < (int64_t) offset) {
bpe_offsets.emplace_back(offset - start_idx);
}
text_data += offset;
}
return bpe_offsets;
}
//
// interface
//
@@ -612,19 +785,46 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
return result;
}
codepoint_flags unicode_cpt_flags(const uint32_t cp) {
static const codepoint_flags undef(codepoint_flags::UNDEFINED);
static const auto cpt_flags = unicode_cpt_flags_array();
return cp < cpt_flags.size() ? cpt_flags[cp] : undef;
codepoint_categ unicode_cpt_category(const uint32_t cp) {
static const std::vector<codepoint_categ> cpt_categs = [] {
std::vector<codepoint_categ> cpt_categs(MAX_CODEPOINTS, codepoint_categ::UNDEF);
uint32_t cpt = 0;
for (uint16_t rle : unicode_rle_codepoints_categs) {
const uint32_t index = rle & 31;
const uint32_t count = rle >> 5;
auto categ = codepoint_categ::from_index(index);
//printf("Codepoints 0x%05X to 0x%05X categ %s\n", cpt, cpt + count, categ.c_str());
categ.set_flag(codepoint_categ::DIGITS, categ.is_Nd()); // \d --> \p{Nd}
categ.set_flag(codepoint_categ::WORDS, categ.is_L() | categ.is_N()); // \w --> \p{L} \p{N} _
for (uint32_t i = 0; i <= count; ++i) {
cpt_categs[cpt++] = categ;
}
}
GGML_ASSERT(cpt == MAX_CODEPOINTS);
cpt_categs['_'].set_flag(codepoint_categ::WORDS); // \w --> \p{L} \p{N} _
for (auto p : unicode_ranges_whitespace) {
for (uint32_t cpt = p.first; cpt <= p.second; ++cpt) {
cpt_categs[cpt].set_flag(codepoint_categ::WHITESPACES);
}
}
//for (auto &range : unicode_ranges_nfd) { // start, last, nfd
// cpt_categs[cpt].set_flag(codepoint_categ::NORM_NFD);
//}
return cpt_categs;
}();
return cp < cpt_categs.size() ? cpt_categs[cp] : codepoint_categ{};
}
codepoint_flags unicode_cpt_flags(const std::string & utf8) {
static const codepoint_flags undef(codepoint_flags::UNDEFINED);
codepoint_categ unicode_cpt_category(const std::string & utf8) {
if (utf8.empty()) {
return undef; // undefined
return codepoint_categ{}; // undefined
}
size_t offset = 0;
return unicode_cpt_flags(unicode_cpt_from_utf8(utf8, offset));
return unicode_cpt_category(unicode_cpt_from_utf8(utf8, offset));
}
std::string unicode_byte_to_utf8(uint8_t byte) {
@@ -642,171 +842,28 @@ uint32_t unicode_tolower(uint32_t cp) {
return it == unicode_map_lowercase.end() ? cp : it->second;
}
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
// unicode categories
static const std::map<std::string, int> k_ucat_enum = {
{ "\\p{N}", codepoint_flags::NUMBER },
{ "\\p{L}", codepoint_flags::LETTER },
{ "\\p{P}", codepoint_flags::PUNCTUATION },
};
static const std::map<int, int> k_ucat_cpt = {
{ codepoint_flags::NUMBER, 0xD1 },
{ codepoint_flags::LETTER, 0xD2 },
{ codepoint_flags::PUNCTUATION, 0xD3 },
};
static const std::map<int, std::string> k_ucat_map = {
{ codepoint_flags::NUMBER, "\x30-\x39" }, // 0-9
{ codepoint_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
{ codepoint_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
};
// compute collapsed codepoints only if needed by at least one regex
bool need_collapse = false;
for (auto & regex_expr : regex_exprs) {
// search for unicode categories
for (const auto & ucat : k_ucat_enum) {
if (std::string::npos != regex_expr.find(ucat.first)) {
need_collapse = true;
break;
}
}
}
const auto cpts = unicode_cpts_from_utf8(text);
// generate a "collapsed" representation of the text, where all codepoints are replaced by a single byte
// ref: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2081479935
std::string text_collapsed;
if (need_collapse) {
// collapse all unicode categories
text_collapsed.resize(cpts.size());
for (size_t i = 0; i < cpts.size(); ++i) {
// keep single-byte codepoints as is
if (cpts[i] < 128) {
text_collapsed[i] = cpts[i];
continue;
}
const auto flags = unicode_cpt_flags(cpts[i]);
if (flags.is_whitespace) {
//NOTE: C++ std::regex \s does not match 0x85; Rust and Python regex do.
//text_collapsed[i] = (char) 0x85; // <Next Line> as whitespace fallback
text_collapsed[i] = (char) 0x0B; // <vertical tab> as whitespace fallback
} else if (k_ucat_cpt.find(flags.category_flag()) != k_ucat_cpt.end()) {
text_collapsed[i] = k_ucat_cpt.at(flags.category_flag());
} else {
text_collapsed[i] = (char) 0xD0; // fallback
}
}
}
std::vector<size_t> bpe_offsets = { cpts.size() };
std::vector<std::string> unicode_regex_split(const std::string & text_utf8, const std::vector<std::string> & regex_exprs) {
const std::vector<uint32_t> cpts = unicode_cpts_from_utf8(text_utf8);
std::vector<size_t> offsets = { cpts.size() };
for (auto & regex_expr : regex_exprs) {
// first, see if we have an efficient custom regex implementation
auto tmp = unicode_regex_split_custom(text, regex_expr, bpe_offsets);
auto tmp = unicode_regex_split_custom(text_utf8, regex_expr, offsets);
if (!tmp.empty()) {
bpe_offsets = std::move(tmp);
offsets = std::move(tmp);
continue;
}
// fallback to general-purpose std::regex / std::wregex
try {
// if a unicode category is used in the regex, we use the collapsed text and replace the unicode category
// with the corresponding collapsed representation
bool use_collapsed = false;
for (auto & ucat : k_ucat_enum) {
if (std::string::npos != regex_expr.find(ucat.first)) {
use_collapsed = true;
break;
}
}
if (use_collapsed) {
// sanity-check that the original regex does not contain any non-ASCII characters
const auto cpts_regex = unicode_cpts_from_utf8(regex_expr);
for (size_t i = 0; i < cpts_regex.size(); ++i) {
if (cpts_regex[i] >= 128) {
throw std::runtime_error("Regex includes both unicode categories and non-ASCII characters - not supported");
}
}
// generate a collapsed representation of the regex
std::string regex_expr_collapsed;
// track if we are inside [], because nested [] are not allowed
bool inside = false;
for (size_t i = 0; i < regex_expr.size(); ++i) {
if (regex_expr[i] == '[' && (i == 0 || regex_expr[i - 1] != '\\')) {
regex_expr_collapsed += '[';
inside = true;
continue;
}
if (inside && regex_expr[i] == ']' && regex_expr[i - 1] != '\\') {
regex_expr_collapsed += ']';
inside = false;
continue;
}
if (regex_expr[i + 0] == '\\' && i + 4 < regex_expr.size() &&
regex_expr[i + 1] == 'p' &&
regex_expr[i + 2] == '{' &&
regex_expr[i + 4] == '}') {
const std::string pat = regex_expr.substr(i, 5);
if (k_ucat_enum.find(pat) != k_ucat_enum.end()) {
if (!inside) {
regex_expr_collapsed += '[';
}
regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat));
regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat));
if (!inside) {
regex_expr_collapsed += ']';
}
i += 4;
continue;
}
}
regex_expr_collapsed += regex_expr[i];
}
//printf("text_collapsed: %s\n", text_collapsed.c_str());
//printf("regex_expr_collapsed: %s\n", regex_expr_collapsed.c_str());
bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets);
} else {
// no unicode category used, we can use std::wregex directly
const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr);
// std::wregex \s does not match non-ASCII whitespaces, using 0x0B as fallback
std::wstring wtext(cpts.begin(), cpts.end());
for (size_t i = 0; i < wtext.size(); ++i) {
if (wtext[i] > 0x7F && unicode_cpt_flags(wtext[i]).is_whitespace) {
wtext[i] = 0x0B;
}
}
//printf("text: %s\n", text.c_str());
//printf("regex_expr: %s\n", regex_expr.c_str());
bpe_offsets = unicode_regex_split_stl(wtext, wregex_expr, bpe_offsets);
}
} catch (std::regex_error & e) {
fprintf(stderr, "Failed to process regex: '%s'\n", regex_expr.c_str());
fprintf(stderr, "Regex error: %s\n", e.what());
throw std::runtime_error("Failed to process regex");
}
const auto regex_cpts = unicode_regex_prepare(regex_expr);
offsets = unicode_regex_split_stl(cpts, regex_cpts, offsets);
}
std::vector<std::string> bpe_words;
bpe_words.reserve(bpe_offsets.size()); // reserve memory for the approximate size
bpe_words.reserve(offsets.size()); // reserve memory for the approximate size
size_t start = 0;
for (size_t & offset : bpe_offsets) {
for (size_t & offset : offsets) {
bpe_words.emplace_back();
for (size_t i = start; i < start + offset; ++i) {
bpe_words.back() += unicode_cpt_to_utf8(cpts[i]);

src/unicode.h

@@ -1,51 +1,185 @@
#pragma once
#include <cstdint>
#include <cassert>
#include <cstring>
#include <string>
#include <vector>
#include <array>
#include <map>
// TODO: prefix all symbols with "llama_"
struct codepoint_flags {
enum {
UNDEFINED = 0x0001,
NUMBER = 0x0002, // regex: \p{N}
LETTER = 0x0004, // regex: \p{L}
SEPARATOR = 0x0008, // regex: \p{Z}
ACCENT_MARK = 0x0010, // regex: \p{M}
PUNCTUATION = 0x0020, // regex: \p{P}
SYMBOL = 0x0040, // regex: \p{S}
CONTROL = 0x0080, // regex: \p{C}
MASK_CATEGORIES = 0x00FF,
struct codepoint_categ {
// 0bffffff'ccccccc'sss --> 6 bits flags + 7 bits category + 3 bits subcategory
enum _category : uint16_t {
UNDEF = 0, // \p{Cn} Undefined
C = 1 << (0 + 3), // \p{C} Control
L = 1 << (1 + 3), // \p{L} Letter
M = 1 << (2 + 3), // \p{M} Mark
N = 1 << (3 + 3), // \p{N} Number
P = 1 << (4 + 3), // \p{P} Punctuation
S = 1 << (5 + 3), // \p{S} Symbol
Z = 1 << (6 + 3), // \p{Z} Separator
Cc = C | 1, // \p{Cc} Control
Cf = C | 2, // \p{Cf} Format
Co = C | 3, // \p{Co} Private Use
Cs = C | 4, // \p{Cs} Surrogate
Ll = L | 1, // \p{Ll} Lowercase Letter
Lm = L | 2, // \p{Lm} Modifier Letter
Lo = L | 3, // \p{Lo} Other Letter
Lt = L | 4, // \p{Lt} Titlecase Letter
Lu = L | 5, // \p{Lu} Uppercase Letter
Mc = M | 1, // \p{Mc} Spacing Mark
Me = M | 2, // \p{Me} Enclosing Mark
Mn = M | 3, // \p{Mn} Nonspacing Mark
Nd = N | 1, // \p{Nd} Decimal Number
Nl = N | 2, // \p{Nl} Letter Number
No = N | 3, // \p{No} Other Number
Pc = P | 1, // \p{Pc} Connector Punctuation
Pd = P | 2, // \p{Pd} Dash Punctuation
Pe = P | 3, // \p{Pe} Close Punctuation
Pf = P | 4, // \p{Pf} Final Punctuation
Pi = P | 5, // \p{Pi} Initial Punctuation
Po = P | 6, // \p{Po} Other Punctuation
Ps = P | 7, // \p{Ps} Open Punctuation
Sc = S | 1, // \p{Sc} Currency Symbol
Sk = S | 2, // \p{Sk} Modifier Symbol
Sm = S | 3, // \p{Sm} Math Symbol
So = S | 4, // \p{So} Other Symbol
Zl = Z | 1, // \p{Zl} Line Separator
Zp = Z | 2, // \p{Zp} Paragraph Separator
Zs = Z | 3, // \p{Zs} Space Separator
SUBMASK = (1 << 3) - 1, // 3 bits 0b000000'0000000'111
MASK = (1 << 10) - 1, // 7+3 bits 0b000000'1111111'111
};
// codepoint type
uint16_t is_undefined : 1;
uint16_t is_number : 1; // regex: \p{N}
uint16_t is_letter : 1; // regex: \p{L}
uint16_t is_separator : 1; // regex: \p{Z}
uint16_t is_accent_mark : 1; // regex: \p{M}
uint16_t is_punctuation : 1; // regex: \p{P}
uint16_t is_symbol : 1; // regex: \p{S}
uint16_t is_control : 1; // regex: \p{C}
// helper flags
uint16_t is_whitespace : 1; // regex: \s
uint16_t is_lowercase : 1;
uint16_t is_uppercase : 1;
uint16_t is_nfd : 1;
enum _flags : uint16_t {
WHITESPACES = (1 << 10), // regex: \s
WORDS = (1 << 11), // regex: \w
DIGITS = (1 << 12), // regex: \d
//Norm NFD/NFC = ...,
};
// decode from uint16
inline codepoint_flags(const uint16_t flags=0) {
*reinterpret_cast<uint16_t*>(this) = flags;
inline codepoint_categ(const uint16_t categ=0) : encoded{categ} {}
inline void set_flag(_flags flags, bool value = true) {
flags = (_flags) (flags & ~MASK); // do not modify category bits
encoded = value ? (encoded | flags) : (encoded & ~flags);
}
inline uint16_t as_uint() const {
return *reinterpret_cast<const uint16_t*>(this);
inline uint16_t get_category() const { return encoded & MASK; }
inline bool is_undefined() const { return !encoded; }
inline bool is_defined() const { return encoded; }
inline uint16_t is_whitespace() const { return encoded & WHITESPACES; }
inline uint16_t is_word() const { return encoded & WORDS; }
inline uint16_t is_digit() const { return encoded & DIGITS; }
inline uint16_t is_C() const { return encoded & C; }
inline uint16_t is_L() const { return encoded & L; }
inline uint16_t is_M() const { return encoded & M; }
inline uint16_t is_N() const { return encoded & N; }
inline uint16_t is_P() const { return encoded & P; }
inline uint16_t is_S() const { return encoded & S; }
inline uint16_t is_Z() const { return encoded & Z; }
inline bool is_Cc() const { return (encoded & MASK) == Cc; }
inline bool is_Cf() const { return (encoded & MASK) == Cf; }
inline bool is_Co() const { return (encoded & MASK) == Co; }
inline bool is_Cs() const { return (encoded & MASK) == Cs; }
inline bool is_Ll() const { return (encoded & MASK) == Ll; }
inline bool is_Lm() const { return (encoded & MASK) == Lm; }
inline bool is_Lo() const { return (encoded & MASK) == Lo; }
inline bool is_Lt() const { return (encoded & MASK) == Lt; }
inline bool is_Lu() const { return (encoded & MASK) == Lu; }
inline bool is_Mc() const { return (encoded & MASK) == Mc; }
inline bool is_Me() const { return (encoded & MASK) == Me; }
inline bool is_Mn() const { return (encoded & MASK) == Mn; }
inline bool is_Nd() const { return (encoded & MASK) == Nd; }
inline bool is_Nl() const { return (encoded & MASK) == Nl; }
inline bool is_No() const { return (encoded & MASK) == No; }
inline bool is_Pc() const { return (encoded & MASK) == Pc; }
inline bool is_Pd() const { return (encoded & MASK) == Pd; }
inline bool is_Pe() const { return (encoded & MASK) == Pe; }
inline bool is_Pf() const { return (encoded & MASK) == Pf; }
inline bool is_Pi() const { return (encoded & MASK) == Pi; }
inline bool is_Po() const { return (encoded & MASK) == Po; }
inline bool is_Ps() const { return (encoded & MASK) == Ps; }
inline bool is_Sc() const { return (encoded & MASK) == Sc; }
inline bool is_Sk() const { return (encoded & MASK) == Sk; }
inline bool is_Sm() const { return (encoded & MASK) == Sm; }
inline bool is_So() const { return (encoded & MASK) == So; }
inline bool is_Zl() const { return (encoded & MASK) == Zl; }
inline bool is_Zp() const { return (encoded & MASK) == Zp; }
inline bool is_Zs() const { return (encoded & MASK) == Zs; }
inline uint64_t expand_bits(const bool add_categ=true) const { // one bit for each category/subcategory, plus flags
const uint32_t subindex = encoded & SUBMASK;
const uint64_t bits = (encoded & MASK) >> 3;
const uint64_t flags = encoded >> 10;
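// e.g. \p{Lu} = L|5: subindex = 5, bits = 0b10 --> (0b10 << 35) | 0b10
//      \p{L}  = L:   subindex = 0, bits = 0b10 --> 0b10
// so expand_bits(Lu) & expand_bits(L) != 0, i.e. an Lu codepoint matches \p{L}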
return (flags << (7 * 8)) | (bits << (7 * subindex)) | (bits * add_categ);
}
inline uint16_t category_flag() const {
return this->as_uint() & MASK_CATEGORIES;
inline bool is_in_range(const codepoint_categ other) const { // this.first <= other <= this.last
if (encoded & SUBMASK) {
return encoded == other.encoded; // no range
}
if (encoded & MASK) {
return encoded == (other.encoded & ~SUBMASK); // from 0bffffff'ccccccc'000 to 0bffffff'ccccccc'111
}
return encoded == (other.encoded & ~MASK); // from 0bffffff'0000000'000 to 0bffffff'1111111'111
}
inline bool operator == (const codepoint_categ other) const {
return encoded == other.encoded;
}
inline bool operator != (const codepoint_categ other) const {
return encoded != other.encoded;
}
const char * c_str() const {
static const std::map<uint16_t, const char *> map = {
{UNDEF, "UNDEF"}, {C, "C"}, {L, "L"}, {M, "M"}, {N, "N"}, {P, "P"}, {S, "S"}, {Z, "Z"},
{Cc, "Cc"}, {Cf, "Cf"}, {Co, "Co"}, {Cs, "Cs"}, {Ll, "Ll"}, {Lm, "Lm"}, {Lo, "Lo"}, {Lt, "Lt"},
{Lu, "Lu"}, {Mc, "Mc"}, {Me, "Me"}, {Mn, "Mn"}, {Nd, "Nd"}, {Nl, "Nl"}, {No, "No"}, {Pc, "Pc"},
{Pd, "Pd"}, {Pe, "Pe"}, {Pf, "Pf"}, {Pi, "Pi"}, {Po, "Po"}, {Ps, "Ps"}, {Sc, "Sc"}, {Sk, "Sk"},
{Sm, "Sm"}, {So, "So"}, {Zl, "Zl"}, {Zp, "Zp"}, {Zs, "Zs"},
};
const auto it = map.find(encoded & MASK);
return it == map.end() ? "INVALID" : it->second;
}
static codepoint_categ from_index(int index) {
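// note: index order must match UNICODE_CATEGORY_TO_INDEX in scripts/gen-unicode-data.py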
static const std::array<codepoint_categ, 32> table = {
UNDEF, Cc, Cf, Co, Cs, Ll, Lm, Lo, Lt, Lu, Mc, Me, Mn, Nd, Nl, No, Pc, Pd, Pe, Pf, Pi, Po, Ps, Sc, Sk, Sm, So, Zl, Zp, Zs, UNDEF, UNDEF
};
return (size_t)index < table.size() ? table[index] : table[0];
}
static codepoint_categ from_chars(const char categ, const char subcateg = '\0') {
auto _subindex = [] (const char subcateg, const char subcategs[]) -> uint16_t {
if (!subcateg) {
return 0;
}
const char * p = strchr(subcategs, subcateg);
GGML_ASSERT(p);
return (uint16_t) (p - subcategs + 1);
};
switch(categ) {
case 'C': if(subcateg == 'n') return 0; // undefined
return C | _subindex(subcateg, "cfos" );
case 'L': return L | _subindex(subcateg, "lmotu" );
case 'M': return M | _subindex(subcateg, "cen" );
case 'N': return N | _subindex(subcateg, "dlo" );
case 'P': return P | _subindex(subcateg, "cdefios");
case 'S': return S | _subindex(subcateg, "ckmo" );
case 'Z': return Z | _subindex(subcateg, "lps" );
default: GGML_ABORT("invalid category character");
}
}
uint16_t encoded;
};
size_t unicode_len_utf8(char src);
@@ -56,8 +190,8 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8);
std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts);
codepoint_flags unicode_cpt_flags(const uint32_t cp);
codepoint_flags unicode_cpt_flags(const std::string & utf8);
codepoint_categ unicode_cpt_category(const uint32_t cp);
codepoint_categ unicode_cpt_category(const std::string & utf8);
std::string unicode_byte_to_utf8(uint8_t byte);
uint8_t unicode_utf8_to_byte(const std::string & utf8);

tests/test-tokenizer-random.py

@@ -116,9 +116,24 @@ class LibLlamaModel:
num = self.lib.llama_detokenize(self.model, self.token_ids, len(ids), self.text_buff, len(self.text_buff), remove_special, unparse_special)
return str(cast(Buffer, self.ffi.buffer(self.text_buff, num)), encoding="utf-8", errors="replace") # replace errors with '\uFFFD'
def get_vocab(self, detokenize=False) -> list[str]:
vocab: list[str] = []
num_tokens = self.lib.llama_n_vocab(self.model)
for id in range(num_tokens):
if detokenize:
text = self.detokenize([id], remove_special=False, unparse_special=True)
else:
text = self.lib.llama_token_get_text(self.model, id)
text = str(cast(bytes, self.ffi.string(text)), encoding="utf-8", errors="replace") # replace errors with '\uFFFD'
vocab.append(text)
return vocab
class Tokenizer:
def get_vocab(self, detokenize=False) -> list[str]:
raise NotImplementedError
def encode(self, text: str) -> list[int]:
raise NotImplementedError
@@ -129,7 +144,7 @@ class Tokenizer:
class TokenizerGroundtruth (Tokenizer):
def __init__(self, dir_tokenizer: str):
self.model: PreTrainedTokenizer = AutoTokenizer.from_pretrained(dir_tokenizer)
self.model: PreTrainedTokenizer = AutoTokenizer.from_pretrained(dir_tokenizer, trust_remote_code=False)
# guess BOS and EOS
ids = self.encode("a")
assert 1 <= len(ids) <= 3
@@ -138,15 +153,25 @@ class TokenizerGroundtruth (Tokenizer):
self.add_bos_token = getattr(self.model, "add_bos_token", add_bos_token)
self.add_eos_token = getattr(self.model, "add_eos_token", add_eos_token)
# build vocab
tokens = list(self.model.get_vocab().values())
self.vocab = self.model.batch_decode(tokens, skip_special_tokens=True)
self.vocab = list(sorted(self.vocab))
self.vocab = self.get_vocab(detokenize=True)
# tokens and lists
self.special_tokens = list(self.model.all_special_tokens)
self.added_tokens = self.model.batch_decode(self.model.added_tokens_encoder.values(), skip_special_tokens=False)
self.special_tokens = [self.vocab[i] for i in sorted(self.model.all_special_ids)]
self.added_tokens = [self.vocab[i] for i in sorted(self.model.added_tokens_encoder.values())]
self.bos_token = self.model.bos_token
self.eos_token = self.model.eos_token
def get_vocab(self, detokenize=False) -> list[str]:
vocab: list[str] = []
max_token_id = max(self.model.get_vocab().values())
if detokenize:
ids = list(range(max_token_id + 1))
vocab = self.model.batch_decode(ids, skip_special_tokens=False)
else:
vocab = [""] * (max_token_id + 1)
for text, id in self.model.get_vocab().items():
vocab[id] = text
return vocab
def encode(self, text: str) -> list[int]:
return self.model.encode(text, add_special_tokens=True)
@@ -163,6 +188,9 @@ class TokenizerLlamaCpp (Tokenizer):
self.libllama = LibLlama()
self.model = LibLlamaModel(self.libllama, vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
def get_vocab(self, detokenize=False) -> list[str]:
return self.model.get_vocab(detokenize)
def encode(self, text: str) -> list[int]:
return self.model.tokenize(text, add_special=True, parse_special=True)
@@ -253,6 +281,23 @@ def generator_vocab_words(tokenizer: TokenizerGroundtruth) -> Iterator[str]:
yield from tokenizer.vocab
def generator_byte_tokens() -> Iterator[str]:
"""Brute force check common byte encoding"""
for a, b in ["<>", "[]", "()", ("\\", "")]:
yield from [f"{a}{i}{b}" for i in range(256)]
yield from [f"{a}{i:x}{b}" for i in range(256)]
yield from [f"{a}{i:X}{b}" for i in range(256)]
yield from [f"{a}x{i:x}{b}" for i in range(256)]
yield from [f"{a}x{i:X}{b}" for i in range(256)]
yield from [f"{a}x{i:02x}{b}" for i in range(256)]
yield from [f"{a}x{i:02X}{b}" for i in range(256)]
yield from [f"{a}0x{i:x}{b}" for i in range(256)]
yield from [f"{a}0x{i:X}{b}" for i in range(256)]
yield from [f"{a}0x{i:02x}{b}" for i in range(256)]
yield from [f"{a}0x{i:02X}{b}" for i in range(256)]
yield from [f"{a}{chr(i)}{b}" for i in range(256)]
def generator_ascii_lr_strip() -> Iterator[str]:
WHITESPACES = ["", " ", " "]
CHARACTERS = list(chr(i) for i in range(1, 0x80)) + [""]
@@ -275,10 +320,11 @@ def generator_apostrophe() -> Iterator[str]:
yield char1 + lstrip + "'" + rstrip + char2
yield char1 + char2 + lstrip + "'" + rstrip + "z"
yield "a" + lstrip + "'" + rstrip + char1 + char2
yield "a" + lstrip + "'" + char1 + char2 + rstrip + "z"
def generator_added_lr_strip(tokenizer: TokenizerGroundtruth) -> Iterator[str]:
WHITESPACES = ["", " ", " ", "\n", "\r\n", "\n\n", "\t", "\t\t"]
WHITESPACES = ["", " ", " ", "\n", "\r\n", "\n\n", "\t", "\t\t", " "]
all_tokens = list(sorted(set(tokenizer.special_tokens + tokenizer.added_tokens)))
for token in all_tokens:
for lstrip in WHITESPACES:
@@ -409,14 +455,6 @@ def generator_random_vocab_words(tokenizer: TokenizerGroundtruth, iterations=100
def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp, generator: Iterator[str]):
def find_first_mismatch(ids1: list[int] | str, ids2: list[int] | str):
for i, (a, b) in enumerate(zip(ids1, ids2)):
if a != b:
return i
if len(ids1) == len(ids2):
return -1
return min(len(ids1), len(ids2))
def check_detokenizer(text: str, text1: str, text2: str) -> bool:
if text1 == text2: # equal to TokenizerGroundtruth?
return True
@@ -434,9 +472,11 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl
t_decode1 = 0
t_decode2 = 0
t_start = time.perf_counter()
total_tests = 0
failing_texts = set()
encode_errors = 0
decode_errors = 0
MAX_ERRORS = 10
MAX_ERRORS = 5
logger.info("%s: %s" % (generator.__qualname__, "ini"))
for text in generator:
@@ -455,31 +495,100 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl
t_encode2 += t2 - t1
t_decode1 += t3 - t2
t_decode2 += t4 - t3
if encode_errors < MAX_ERRORS and ids1 != ids2:
i = find_first_mismatch(ids1, ids2)
ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1]
ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1]
logger.error(" Expected: " + str(ids1))
logger.error(" Result: " + str(ids2))
encode_errors += 1
logger.error(f" {encode_errors=}")
if decode_errors < MAX_ERRORS and not check_detokenizer(text, text1, text2):
i = find_first_mismatch(text1, text2)
text1 = list(text1[max(0, i - 2) : i + 5 + 1])
text2 = list(text2[max(0, i - 2) : i + 5 + 1])
logger.error(" Expected: " + " ".join(hex(ord(x)) for x in text1))
logger.error(" Result: " + " ".join(hex(ord(x)) for x in text2))
decode_errors += 1
logger.error(f" {decode_errors=}")
if encode_errors >= MAX_ERRORS and decode_errors >= MAX_ERRORS:
logger.error(f" EXIT: {encode_errors=} {decode_errors=}")
# raise Exception()
break
total_tests += 1
# compare
encode_ok = ids1 == ids2
decode_ok = check_detokenizer(text, text1, text2)
if not (encode_ok and decode_ok):
def _compare(text: str):
ids1 = tokenizer1.encode(text)
ids2 = tokenizer2.encode(text)
text1 = tokenizer1.decode(ids1)
text2 = tokenizer2.decode(ids1)
encode_ok = ids1 == ids2
decode_ok = check_detokenizer(text, text1, text2)
ok = encode_ok and decode_ok
return ok, ids1, ids2, text1, text2
# binary search upper and lower failing range
a, b = 0, len(text)
step = b
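# shrink from the right: find the smallest b such that text[a:b] still fails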
while step > 1:
step = (step + 1) // 2
t = max(a, b - step)
if not _compare(text[a : t])[0]:
b = t
step = b
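# then shrink from the left: find the largest a such that text[a:b] still fails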
while step > 1:
step = (step + 1) // 2
t = min(a + step, b)
if not _compare(text[t : b])[0]:
a = t
ok, ids1, ids2, text1, text2 = _compare(text[a : b])
assert a <= b and not ok
# show unique failing texts differences
failing_text = text[a : b]
if failing_text not in failing_texts:
failing_texts.add(failing_text)
if encode_errors < MAX_ERRORS and not encode_ok:
encode_errors += 1
logger.error(f" {encode_errors=}")
logger.error(" Text:" + repr(failing_text))
logger.error(" " + " ".join(repr(x) + ":" + hex(ord(x)) for x in failing_text))
logger.error(" Expected: " + str(ids1))
logger.error(" Result: " + str(ids2))
if decode_errors < MAX_ERRORS and not decode_ok:
decode_errors += 1
logger.error(f" {decode_errors=}")
logger.error(" Text:" + repr(failing_text))
logger.error(" " + " ".join(repr(x) + ":" + hex(ord(x)) for x in failing_text))
logger.error(" Expected: " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text1))
logger.error(" Result: " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text2))
if encode_errors >= MAX_ERRORS and decode_errors >= MAX_ERRORS:
logger.error(f" EXIT: {encode_errors=} {decode_errors=}")
# raise Exception()
break
t_total = time.perf_counter() - t_start
logger.info(f"{generator.__qualname__}: end, {t_encode1=:.3f} {t_encode2=:.3f} {t_decode1=:.3f} {t_decode2=:.3f} {t_total=:.3f}")
def compare_vocabs(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp):
MAX_PRINT_ERRORS = 10
logger.info("compare_vocabs: ini")
t_start = time.perf_counter()
for detokenize in (False, True):
vocab1 = tokenizer1.get_vocab(detokenize)
vocab2 = tokenizer2.get_vocab(detokenize)
if vocab1 != vocab2:
num_errors = 0
for i in range(max(len(vocab1), len(vocab2))):
text1 = vocab1[i] if i < len(vocab1) else None
text2 = vocab2[i] if i < len(vocab2) else None
if text1 != text2:
# are "[UNUSED_TOKEN_" and "[PAD" valid for all models? #TODO: use toktypes
if text1 is not None:
text1 = text1.replace("[UNUSED_TOKEN_", "[PAD")
if text2 is not None:
text2 = text2.replace("[UNUSED_TOKEN_", "[PAD")
if text1 is None and (text2 or "").startswith('[PAD'):
text2 = None
if text2 is None and (text1 or "").startswith('[PAD'):
text1 = None
if text1 != text2:
num_errors += 1
if num_errors < MAX_PRINT_ERRORS:
logger.error(f" {detokenize=} id={i} expected={repr(text1)} result={repr(text2)}")
if num_errors:
logger.error(f" {num_errors=}")
t_total = time.perf_counter() - t_start
logger.info(f"compare_vocabs: end, {t_total=:.3f}")
def main(argv: list[str] | None = None):
parser = argparse.ArgumentParser()
parser.add_argument("vocab_file", type=str, help="path to vocab 'gguf' file")
@@ -493,18 +602,21 @@ def main(argv: list[str] | None = None):
tokenizer1 = TokenizerGroundtruth(args.dir_tokenizer)
tokenizer2 = TokenizerLlamaCpp(args.vocab_file)
# compare_tokenizers(tokenizer1, tokenizer2, generator_custom_text())
# compare_tokenizers(tokenizer1, tokenizer2, generator_custom_text_edge_cases())
compare_vocabs(tokenizer1, tokenizer2)
compare_tokenizers(tokenizer1, tokenizer2, generator_custom_text())
compare_tokenizers(tokenizer1, tokenizer2, generator_custom_text_edge_cases())
compare_tokenizers(tokenizer1, tokenizer2, generator_byte_tokens())
compare_tokenizers(tokenizer1, tokenizer2, generator_ascii_lr_strip())
compare_tokenizers(tokenizer1, tokenizer2, generator_apostrophe())
compare_tokenizers(tokenizer1, tokenizer2, generator_unicodes())
compare_tokenizers(tokenizer1, tokenizer2, generator_vocab_words(tokenizer1))
compare_tokenizers(tokenizer1, tokenizer2, generator_added_lr_strip(tokenizer1))
# compare_tokenizers(tokenizer1, tokenizer2, generator_random_added_tokens(tokenizer1, 10_000))
# compare_tokenizers(tokenizer1, tokenizer2, generator_random_chars(10_000))
# compare_tokenizers(tokenizer1, tokenizer2, generator_random_unicodes(10_000))
# compare_tokenizers(tokenizer1, tokenizer2, generator_random_vocab_chars(tokenizer1, 10_000))
# compare_tokenizers(tokenizer1, tokenizer2, generator_random_vocab_words(tokenizer1, 5_000))
compare_tokenizers(tokenizer1, tokenizer2, generator_random_added_tokens(tokenizer1, 10_000))
compare_tokenizers(tokenizer1, tokenizer2, generator_random_chars(10_000))
compare_tokenizers(tokenizer1, tokenizer2, generator_random_unicodes(10_000))
compare_tokenizers(tokenizer1, tokenizer2, generator_random_vocab_chars(tokenizer1, 10_000))
compare_tokenizers(tokenizer1, tokenizer2, generator_random_vocab_words(tokenizer1, 5_000))
tokenizer2.model.free()
@@ -533,21 +645,19 @@ if __name__ == "__main__":
"phi-3", # SPM
"gemma", # SPM
"gemma-2", # SPM
"baichuan", # SPM
# "baichuan", # SPM
"bert-bge", # WPM
"jina-v2-en", # WPM
# "t5", # UGM
"llama-bpe", # BPE
"phi-2", # BPE
"deepseek-llm", # BPE
"deepseek-coder", # BPE
"falcon", # BPE
"mpt", # BPE
"starcoder", # BPE
"gpt-2", # BPE
"stablelm2", # BPE
"refact", # BPE
"qwen2", # BPE
"olmo", # BPE
"jina-v2-es", # BPE
"jina-v2-de", # BPE
"smaug-bpe", # BPE
@@ -555,6 +665,14 @@
"jina-v2-code", # BPE
"viking", # BPE
"jais", # BPE
"codeshell", # BPE
"tekken", # BPE
"smollm", # BPE
"mpt", # BPE NFC
"command-r", # BPE NFC
"qwen2", # BPE NFC
"olmo", # BPE NFC
"gpt-neox", # BPE NFC
]
logger.info("=" * 50)