commit dffb4b1909
jaime-m-p, 2024-09-08 10:06:51 +02:00 (committed by GitHub)
7 changed files with 5275 additions and 2746 deletions


@@ -49,53 +49,42 @@ def unicode_data_iter():
         yield (cpt, cpt_lower, cpt_upper, categ, bidir)
 
-# see definition in unicode.h
-CODEPOINT_FLAG_UNDEFINED   = 0x0001  #
-CODEPOINT_FLAG_NUMBER      = 0x0002  # \p{N}
-CODEPOINT_FLAG_LETTER      = 0x0004  # \p{L}
-CODEPOINT_FLAG_SEPARATOR   = 0x0008  # \p{Z}
-CODEPOINT_FLAG_MARK        = 0x0010  # \p{M}
-CODEPOINT_FLAG_PUNCTUATION = 0x0020  # \p{P}
-CODEPOINT_FLAG_SYMBOL      = 0x0040  # \p{S}
-CODEPOINT_FLAG_CONTROL     = 0x0080  # \p{C}
-
-UNICODE_CATEGORY_TO_FLAG = {
-    "Cn": CODEPOINT_FLAG_UNDEFINED,    # Undefined
-    "Cc": CODEPOINT_FLAG_CONTROL,      # Control
-    "Cf": CODEPOINT_FLAG_CONTROL,      # Format
-    "Co": CODEPOINT_FLAG_CONTROL,      # Private Use
-    "Cs": CODEPOINT_FLAG_CONTROL,      # Surrogate
-    "Ll": CODEPOINT_FLAG_LETTER,       # Lowercase Letter
-    "Lm": CODEPOINT_FLAG_LETTER,       # Modifier Letter
-    "Lo": CODEPOINT_FLAG_LETTER,       # Other Letter
-    "Lt": CODEPOINT_FLAG_LETTER,       # Titlecase Letter
-    "Lu": CODEPOINT_FLAG_LETTER,       # Uppercase Letter
-    "L&": CODEPOINT_FLAG_LETTER,       # Cased Letter
-    "Mc": CODEPOINT_FLAG_MARK,         # Spacing Mark
-    "Me": CODEPOINT_FLAG_MARK,         # Enclosing Mark
-    "Mn": CODEPOINT_FLAG_MARK,         # Nonspacing Mark
-    "Nd": CODEPOINT_FLAG_NUMBER,       # Decimal Number
-    "Nl": CODEPOINT_FLAG_NUMBER,       # Letter Number
-    "No": CODEPOINT_FLAG_NUMBER,       # Other Number
-    "Pc": CODEPOINT_FLAG_PUNCTUATION,  # Connector Punctuation
-    "Pd": CODEPOINT_FLAG_PUNCTUATION,  # Dash Punctuation
-    "Pe": CODEPOINT_FLAG_PUNCTUATION,  # Close Punctuation
-    "Pf": CODEPOINT_FLAG_PUNCTUATION,  # Final Punctuation
-    "Pi": CODEPOINT_FLAG_PUNCTUATION,  # Initial Punctuation
-    "Po": CODEPOINT_FLAG_PUNCTUATION,  # Other Punctuation
-    "Ps": CODEPOINT_FLAG_PUNCTUATION,  # Open Punctuation
-    "Sc": CODEPOINT_FLAG_SYMBOL,       # Currency Symbol
-    "Sk": CODEPOINT_FLAG_SYMBOL,       # Modifier Symbol
-    "Sm": CODEPOINT_FLAG_SYMBOL,       # Math Symbol
-    "So": CODEPOINT_FLAG_SYMBOL,       # Other Symbol
-    "Zl": CODEPOINT_FLAG_SEPARATOR,    # Line Separator
-    "Zp": CODEPOINT_FLAG_SEPARATOR,    # Paragraph Separator
-    "Zs": CODEPOINT_FLAG_SEPARATOR,    # Space Separator
-}
+# see codepoint_categ::from_index() in unicode.h
+UNICODE_CATEGORY_TO_INDEX = {
+    "Cn":  0,  # \p{Cn} Undefined
+    "Cc":  1,  # \p{Cc} Control
+    "Cf":  2,  # \p{Cf} Format
+    "Co":  3,  # \p{Co} Private Use
+    "Cs":  4,  # \p{Cs} Surrogate
+    "Ll":  5,  # \p{Ll} Lowercase Letter
+    "Lm":  6,  # \p{Lm} Modifier Letter
+    "Lo":  7,  # \p{Lo} Other Letter
+    "Lt":  8,  # \p{Lt} Titlecase Letter
+    "Lu":  9,  # \p{Lu} Uppercase Letter
+    "Mc": 10,  # \p{Mc} Spacing Mark
+    "Me": 11,  # \p{Me} Enclosing Mark
+    "Mn": 12,  # \p{Mn} Nonspacing Mark
+    "Nd": 13,  # \p{Nd} Decimal Number
+    "Nl": 14,  # \p{Nl} Letter Number
+    "No": 15,  # \p{No} Other Number
+    "Pc": 16,  # \p{Pc} Connector Punctuation
+    "Pd": 17,  # \p{Pd} Dash Punctuation
+    "Pe": 18,  # \p{Pe} Close Punctuation
+    "Pf": 19,  # \p{Pf} Final Punctuation
+    "Pi": 20,  # \p{Pi} Initial Punctuation
+    "Po": 21,  # \p{Po} Other Punctuation
+    "Ps": 22,  # \p{Ps} Open Punctuation
+    "Sc": 23,  # \p{Sc} Currency Symbol
+    "Sk": 24,  # \p{Sk} Modifier Symbol
+    "Sm": 25,  # \p{Sm} Math Symbol
+    "So": 26,  # \p{So} Other Symbol
+    "Zl": 27,  # \p{Zl} Line Separator
+    "Zp": 28,  # \p{Zp} Paragraph Separator
+    "Zs": 29,  # \p{Zs} Space Separator
+}
 
-codepoint_flags = array.array('H', [CODEPOINT_FLAG_UNDEFINED]) * MAX_CODEPOINTS
-table_whitespace = []
+codepoint_categs = array.array('B', [0]) * MAX_CODEPOINTS  # Undefined
 table_lowercase = []
 table_uppercase = []
 table_nfd = []
@@ -105,7 +94,7 @@ for (cpt, cpt_lower, cpt_upper, categ, bidir) in unicode_data_iter():
     char = chr(cpt)
 
     # codepoint category flags
-    codepoint_flags[cpt] = UNICODE_CATEGORY_TO_FLAG[categ]
+    codepoint_categs[cpt] = UNICODE_CATEGORY_TO_INDEX[categ]
 
     # lowercase conversion
     if cpt_lower:
@@ -121,25 +110,31 @@ for (cpt, cpt_lower, cpt_upper, categ, bidir) in unicode_data_iter():
         table_nfd.append((cpt, norm))
 
-# whitespaces, see "<White_Space>" https://www.unicode.org/Public/UCD/latest/ucd/PropList.txt
-table_whitespace.extend(range(0x0009, 0x000D + 1))
-table_whitespace.extend(range(0x2000, 0x200A + 1))
-table_whitespace.extend([0x0020, 0x0085, 0x00A0, 0x1680, 0x2028, 0x2029, 0x202F, 0x205F, 0x3000])
-
 # sort by codepoint
-table_whitespace.sort()
 table_lowercase.sort()
 table_uppercase.sort()
 table_nfd.sort()
 
-# group ranges with same flags
-ranges_flags: list[tuple[int, int]] = [(0, codepoint_flags[0])]  # start, flags
-for codepoint, flags in enumerate(codepoint_flags):
-    if flags != ranges_flags[-1][1]:
-        ranges_flags.append((codepoint, flags))
-ranges_flags.append((MAX_CODEPOINTS, 0x0000))
+# whitespaces, see "<White_Space>" https://www.unicode.org/Public/UCD/latest/ucd/PropList.txt
+whitespace_ranges: list[tuple[int, int]] = []  # start, last
+whitespace_ranges.append((0x0009, 0x000D))
+whitespace_ranges.append((0x2000, 0x200A))
+for whitespace in [0x0020, 0x0085, 0x00A0, 0x1680, 0x2028, 0x2029, 0x202F, 0x205F, 0x3000]:
+    whitespace_ranges.append((whitespace, whitespace))
+
+# run length encoding, see unicode_cpt_category() in unicode.cpp
+assert (max(UNICODE_CATEGORY_TO_INDEX.values()) < 32)
+codepoint_categs_runs = [codepoint_categs[0]]  # 5 bits categ + 11 bits length
+for cpt, categ in enumerate(codepoint_categs[1:], 1):
+    prev = codepoint_categs_runs[-1]
+    if prev <= (0xFFFF - 32) and (prev & 31) == categ:
+        codepoint_categs_runs[-1] += 32  # increment run length
+    else:
+        codepoint_categs_runs.append(categ)  # new run value
+assert (codepoint_categs_runs[-1] < 0xFFFF)
+assert (MAX_CODEPOINTS == sum((rle >> 5) + 1 for rle in codepoint_categs_runs))
 
 # group ranges with same nfd
@@ -153,7 +148,7 @@ for codepoint, norm in table_nfd:
 
 # Generate 'unicode-data.cpp':
-#   python ./scripts//gen-unicode-data.py > unicode-data.cpp
+#   python ./scripts//gen-unicode-data.py > ./src/unicode-data.cpp
 
 def out(line=""):
     print(line, end='\n')  # noqa
@@ -167,17 +162,16 @@ out("""\
 #include <cstdint>
 #include <vector>
 #include <unordered_map>
-#include <unordered_set>
 """)
 
-out("const std::vector<std::pair<uint32_t, uint16_t>> unicode_ranges_flags = {  // start, flags // last=next_start-1")
-for codepoint, flags in ranges_flags:
-    out("{0x%06X, 0x%04X}," % (codepoint, flags))
+out("const std::vector<uint16_t> unicode_rle_codepoints_categs = {  // run length encoding, 5 bits categ + 11 bits length")
+for rle in codepoint_categs_runs:
+    out("0x%04X," % rle)
 out("};\n")
 
-out("const std::unordered_set<uint32_t> unicode_set_whitespace = {")
-for codepoint in table_whitespace:
-    out("0x%06X," % codepoint)
+out("const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_whitespace = {")
+for (start, last) in whitespace_ranges:
+    out("{0x%06X, 0x%06X}," % (start, last))
out("};\n")
 
 out("const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase = {")


@@ -444,10 +444,8 @@ struct llm_tokenizer_bpe {
                 };
                 break;
             case LLAMA_VOCAB_PRE_TYPE_TEKKEN:
-                // original regex from tokenizer.json
-                // "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
                 regex_exprs = {
-                    "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+                    "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
                 };
                 break;
             default:
@@ -704,22 +702,21 @@ struct llm_tokenizer_wpm {
         std::vector<std::string> words(1, "");
         for (const uint32_t cpt : cpts_nfd) {
-            const auto flags = unicode_cpt_flags(cpt);
+            const auto categ = unicode_cpt_category(cpt);
 
-            if (flags.is_whitespace) {
+            if (categ.is_whitespace()) {
                 if (words.back().size()) {  // finish previous word if any
                     words.emplace_back();
                 }
                 continue;
             }
 
-            assert (!flags.is_separator);
-            if (cpt == 0 || cpt == 0xFFFD || flags.is_control) {
+            if (cpt == 0 || cpt == 0xFFFD || categ.is_C()) {
                 continue;
             }
 
             const std::string s = unicode_cpt_to_utf8(unicode_tolower(cpt));
-            if (flags.is_punctuation || ( cpt < 0x7F && flags.is_symbol ) || is_chinese_char(cpt)) {
+            if (categ.is_P() || (cpt < 0x7F && categ.is_S()) || is_chinese_char(cpt)) {
                 if (words.back().size()) {  // finish previous word if any
                     words.emplace_back();
                 }
@@ -737,7 +734,7 @@ struct llm_tokenizer_wpm {
         return words;
     }
 
-    static bool is_chinese_char(uint32_t cpt) {
+    static bool is_chinese_char(uint32_t cpt) {  //TODO: move to unicode-data.cpp? unicode_cpt_category(cpt).is_chinese()?
        return
            (cpt >= 0x04E00 && cpt <= 0x09FFF) ||
            (cpt >= 0x03400 && cpt <= 0x04DBF) ||
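The WPM preprocessing rule above is small enough to illustrate standalone: whitespace closes the current word, control codepoints are dropped, and punctuation/symbols become single-codepoint words. A simplified ASCII-only sketch of the same split logic, using <cctype> as a stand-in classifier (this is not the llama.cpp API):

#include <cctype>
#include <cstdio>
#include <string>
#include <vector>

int main() {
    const std::string text = "Don't stop, believin'!";
    std::vector<std::string> words(1, "");
    for (const unsigned char c : text) {
        if (std::isspace(c)) {                    // whitespace finishes the previous word
            if (!words.back().empty()) { words.emplace_back(); }
            continue;
        }
        if (c == 0 || std::iscntrl(c)) {          // control codepoints are dropped
            continue;
        }
        if (std::ispunct(c)) {                    // punctuation/symbols become isolated words
            if (!words.back().empty()) { words.emplace_back(); }
            words.back() = std::string(1, (char) c);
            words.emplace_back();
            continue;
        }
        words.back() += (char) std::tolower(c);   // everything else: lowercase and append
    }
    for (const auto & w : words) {
        if (!w.empty()) { std::printf("[%s] ", w.c_str()); }
    }
    std::printf("\n");
    return 0;
}

For the sample input this prints [don] ['] [t] [stop] [,] [believin] ['] [!].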

File diff suppressed because it is too large.


@@ -3,7 +3,6 @@
 #include <cstdint>
 #include <vector>
 #include <unordered_map>
-#include <unordered_set>
 
 struct range_nfd {
     uint32_t first;
@@ -13,8 +12,8 @@ struct range_nfd {
 
 static const uint32_t MAX_CODEPOINTS = 0x110000;
 
-extern const std::vector<std::pair<uint32_t, uint16_t>> unicode_ranges_flags;
-extern const std::unordered_set<uint32_t> unicode_set_whitespace;
+extern const std::vector<uint16_t> unicode_rle_codepoints_categs;
+extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_whitespace;
 extern const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase;
 extern const std::unordered_map<uint32_t, uint32_t> unicode_map_uppercase;
 extern const std::vector<range_nfd> unicode_ranges_nfd;


@@ -2,10 +2,10 @@
 #define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
 #endif
 
+#include "ggml.h"
 #include "unicode.h"
 #include "unicode-data.h"
 
-#include <cassert>
 #include <cstddef>
 #include <cstdint>
 #include <map>
@ -119,38 +119,6 @@ uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) {
// return result; // return result;
//} //}
static std::vector<codepoint_flags> unicode_cpt_flags_array() {
std::vector<codepoint_flags> cpt_flags(MAX_CODEPOINTS, codepoint_flags::UNDEFINED);
assert (unicode_ranges_flags.front().first == 0);
assert (unicode_ranges_flags.back().first == MAX_CODEPOINTS);
for (size_t i = 1; i < unicode_ranges_flags.size(); ++i) {
const auto range_ini = unicode_ranges_flags[i-1]; // codepoint_ini, flags
const auto range_end = unicode_ranges_flags[i]; // codepoint_end, flags
for (uint32_t cpt = range_ini.first; cpt < range_end.first; ++cpt) {
cpt_flags[cpt] = range_ini.second;
}
}
for (auto cpt : unicode_set_whitespace) {
cpt_flags[cpt].is_whitespace = true;
}
for (auto p : unicode_map_lowercase) {
cpt_flags[p.second].is_lowercase = true;
}
for (auto p : unicode_map_uppercase) {
cpt_flags[p.second].is_uppercase = true;
}
for (auto &range : unicode_ranges_nfd) { // start, last, nfd
cpt_flags[range.nfd].is_nfd = true;
}
return cpt_flags;
}
static std::unordered_map<uint8_t, std::string> unicode_byte_to_utf8_map() { static std::unordered_map<uint8_t, std::string> unicode_byte_to_utf8_map() {
std::unordered_map<uint8_t, std::string> map; std::unordered_map<uint8_t, std::string> map;
for (int ch = 0x21; ch <= 0x7E; ++ch) { // u'!' to u'~' for (int ch = 0x21; ch <= 0x7E; ++ch) { // u'!' to u'~'
@ -233,7 +201,7 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
for (auto offset : offsets) { for (auto offset : offsets) {
const size_t offset_ini = start; const size_t offset_ini = start;
const size_t offset_end = start + offset; const size_t offset_end = start + offset;
assert(offset_end <= cpts.size()); GGML_ASSERT(offset_end <= cpts.size());
start = offset_end; start = offset_end;
static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF; static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
@@ -241,13 +209,14 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
             return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
         };
 
-        auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
-            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
+        static const codepoint_categ SENTINEL = codepoint_categ::UNDEF + 1;
+        auto _get_categ = [&] (const size_t pos) -> codepoint_categ {
+            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_category(cpts[pos]) : SENTINEL;
         };
 
         size_t _prev_end = offset_ini;
         auto _add_token = [&] (const size_t end) -> size_t {
-            assert(_prev_end <= end && end <= offset_end);
+            GGML_ASSERT(_prev_end <= end && end <= offset_end);
             size_t len = end - _prev_end;
             if (len > 0) {
                 bpe_offsets.push_back(len);
@@ -264,7 +233,7 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
         for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
             const uint32_t cpt = _get_cpt(pos);
-            const auto flags = _get_flags(pos);
+            const auto categ = _get_categ(pos);
 
             // regex: 's|'t|'re|'ve|'m|'ll|'d
             if (cpt == '\'' && pos+1 < offset_end) {
@@ -284,37 +253,37 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
                 }
             }
 
-            auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
+            auto categ2 = (cpt == ' ' ? _get_categ(pos+1) : categ);
 
             // regex: <space>?\p{L}+
-            if (flags2.is_letter) {
+            if (categ2.is_L()) {
                 pos += (cpt == ' ');
-                while (flags2.is_letter) {
-                    flags2 = _get_flags(++pos);
+                while (categ2.is_L()) {
+                    categ2 = _get_categ(++pos);
                 }
                 _add_token(pos);
                 continue;
             }
 
             // regex: <space>?\p{N}+
-            if (flags2.is_number) {
+            if (categ2.is_N()) {
                 pos += (cpt == ' ');
-                while (flags2.is_number) {
-                    flags2 = _get_flags(++pos);
+                while (categ2.is_N()) {
+                    categ2 = _get_categ(++pos);
                 }
                 _add_token(pos);
                 continue;
             }
 
             // regex: <space>?[^\s\p{L}\p{N}]+
-            if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
+            if (!(categ2.is_whitespace() | categ2.is_L() | categ2.is_N()) && categ2 != SENTINEL) {
                 pos += (cpt == ' ');
-                while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
-                    flags2 = _get_flags(++pos);
+                while (!(categ2.is_whitespace() | categ2.is_L() | categ2.is_N()) && categ2 != SENTINEL) {
+                    categ2 = _get_categ(++pos);
                }
                 _add_token(pos);
                 continue;
             }
 
             size_t num_whitespaces = 0;
-            while (_get_flags(pos+num_whitespaces).is_whitespace) {
+            while (_get_categ(pos+num_whitespaces).is_whitespace()) {
                 num_whitespaces++;
             }
@@ -351,7 +320,7 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
     for (auto offset : offsets) {
         const size_t offset_ini = start;
         const size_t offset_end = start + offset;
-        assert(offset_end <= cpts.size());
+        GGML_ASSERT(offset_end <= cpts.size());
         start = offset_end;
 
         static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
@@ -359,13 +328,14 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
             return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
         };
 
-        auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
-            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
+        static const codepoint_categ SENTINEL = codepoint_categ::UNDEF + 1;
+        auto _get_categ = [&] (const size_t pos) -> codepoint_categ {
+            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_category(cpts[pos]) : SENTINEL;
         };
 
         size_t _prev_end = offset_ini;
         auto _add_token = [&] (const size_t end) -> size_t {
-            assert(_prev_end <= end && end <= offset_end);
+            GGML_ASSERT(_prev_end <= end && end <= offset_end);
             size_t len = end - _prev_end;
             if (len > 0) {
                 bpe_offsets.push_back(len);
@@ -382,7 +352,7 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
         for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
             const uint32_t cpt = _get_cpt(pos);
-            const auto flags = _get_flags(pos);
+            const auto categ = _get_categ(pos);
 
             // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
             if (cpt == '\'' && pos+1 < offset_end) {
@@ -403,10 +373,10 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
             }
 
             // regex: [^\r\n\p{L}\p{N}]?\p{L}+
-            if (!(cpt == '\r' || cpt == '\n' || flags.is_number)) {
-                if (flags.is_letter || _get_flags(pos+1).is_letter) {  // one or more letters
+            if (!(cpt == '\r' || cpt == '\n' || categ.is_N())) {
+                if (categ.is_L() || _get_categ(pos+1).is_L()) {  // one or more letters
                     pos++;
-                    while (_get_flags(pos).is_letter) {
+                    while (_get_categ(pos).is_L()) {
                         pos++;
                     }
                     _add_token(pos);
@@ -415,9 +385,9 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
             }
 
             // regex: \p{N}{1,3}
-            if (flags.is_number) {
+            if (categ.is_N()) {
                 size_t ini = pos;
-                while (_get_flags(pos).is_number) {
+                while (_get_categ(pos).is_N()) {
                     if (++pos - ini >= 3 ) {
                         _add_token(pos);
                         ini = pos;
@@ -428,11 +398,11 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
             }
 
             // regex: <space>?[^\s\p{L}\p{N}]+[\r\n]*
-            auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
-            if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags.as_uint()) {
+            auto categ2 = (cpt == ' ' ? _get_categ(pos+1) : categ);
+            if (!(categ2.is_whitespace() | categ2.is_L() | categ2.is_N()) && categ2 != SENTINEL) {
                 pos += (cpt == ' ');
-                while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
-                    flags2 = _get_flags(++pos);
+                while (!(categ2.is_whitespace() | categ2.is_L() | categ2.is_N()) && categ2 != SENTINEL) {
+                    categ2 = _get_categ(++pos);
                 }
                 uint32_t cpt2 = _get_cpt(pos);
                 while (cpt2 == '\r' || cpt2 == '\n') {
@@ -444,7 +414,7 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
 
             size_t num_whitespaces = 0;
             size_t last_end_r_or_n = 0;
-            while (_get_flags(pos+num_whitespaces).is_whitespace) {
+            while (_get_categ(pos+num_whitespaces).is_whitespace()) {
                 uint32_t cpt2 = _get_cpt(pos+num_whitespaces);
                 if (cpt2 == '\r' || cpt2 == '\n') {
                     last_end_r_or_n = pos + num_whitespaces + 1;
@@ -481,66 +451,6 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
     return bpe_offsets;
 }
 
-// use std::wregex to split the text
-static std::vector<size_t> unicode_regex_split_stl(const std::wstring & wtext, const std::wstring & regex_expr, const std::vector<size_t> & offsets) {
-    std::wregex expr(regex_expr);
-    std::vector<size_t> bpe_offsets; // store the offset of each word
-    bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
-    size_t start = 0;
-    for (auto offset : offsets) {
-        std::wcregex_iterator it(wtext.data() + start, wtext.data() + start + offset, expr);
-        std::wcregex_iterator end;
-        int64_t start_idx = 0;
-
-        while (it != end) {
-            std::wcmatch match = *it;
-            if (match.position() > start_idx) {
-                bpe_offsets.emplace_back(match.position() - start_idx);
-            }
-            bpe_offsets.emplace_back(match.length());
-            start_idx = match.position() + match.length();
-            ++it;
-        }
-
-        if (start_idx < (int64_t) offset) {
-            bpe_offsets.emplace_back(offset - start_idx);
-        }
-        start += offset;
-    }
-
-    return bpe_offsets;
-}
-
-// use std::regex to split the text
-static std::vector<size_t> unicode_regex_split_stl(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
-    std::regex expr(regex_expr);
-    std::vector<size_t> bpe_offsets; // store the offset of each word
-    bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
-    size_t start = 0;
-    for (auto offset : offsets) {
-        std::cregex_iterator it(text.data() + start, text.data() + start + offset, expr);
-        std::cregex_iterator end;
-        int64_t start_idx = 0;
-
-        while (it != end) {
-            std::cmatch match = *it;
-            if (match.position() > start_idx) {
-                bpe_offsets.emplace_back(match.position() - start_idx);
-            }
-            bpe_offsets.emplace_back(match.length());
-            start_idx = match.position() + match.length();
-            ++it;
-        }
-
-        if (start_idx < (int64_t) offset) {
-            bpe_offsets.emplace_back(offset - start_idx);
-        }
-        start += offset;
-    }
-
-    return bpe_offsets;
-}
-
 static std::vector<size_t> unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
     std::vector<size_t> bpe_offsets;
@@ -556,6 +466,269 @@ static std::vector<size_t> unicode_regex_split_custom(const std::string & text,
 
     return bpe_offsets;
 }
 
+// Custom std::regex specializations for 32bit unicode codepoints
+// std::wregex does not support unicode categories: \p{N}, \p{L}, \p{Lu}, \p{Ll} ...
+// std::wregex does not support unicode whitespaces \s: 0x85, 0xA0, 0x001680 ... 0x003000.
+// std::wregex supports full 32 bit codepoints, not limited to standard max 0x110000.
+namespace std {
+
+// codepoint type for all template specializations
+#if (WCHAR_MAX > 0xFFFF)
+using codepoint = wchar_t;  // sizeof(wchar_t) == 4
+#else
+using codepoint = uint32_t;  // Windows: sizeof(wchar_t) == 2
+#define CUSTOM_CTYPE_CODEPOINT
+#endif
+
+#ifdef CUSTOM_CTYPE_CODEPOINT
+// Minimal required implementation for std::regex string processing
+template<>  // custom specialized std::ctype<codepoint>
+class ctype<codepoint> {
+public:
+    using CharT = codepoint;
+    using char_type = CharT;
+    using mask = uint8_t;  //NOTE: see std::ctype_base
+    static const mask digit  = 1;  // required variable names
+    static const mask xdigit = 2;  // user defined values
+    static const mask alpha  = 3;  // used to be a bitmask
+    static const mask upper  = 4;  // we do not need a bitmask
+    static const mask lower  = 5;  // using a sequence instead
+    static locale::id id;  // required by std::locale::facet
+
+    bool is(mask m, char_type c) const {
+        switch (m) {
+            case digit:  return ('0' <= c && c <= '9');
+            case xdigit: return ('0' <= c && c <= '9') || ('A' <= c && c <= 'F');
+            case alpha:  return ('A' <= c && c <= 'Z') || ('a' <= c && c <= 'z');
+            case upper:  return ('A' <= c && c <= 'Z');
+            case lower:  return ('a' <= c && c <= 'z');
+            default:     return false;
+        }
+    }
+
+    char_type toupper(char_type c) const {
+        return ('a' <= c && c <= 'z') ? c - ('a' - 'A') : c;
+    }
+
+    char_type tolower(char_type c) const {
+        return ('A' <= c && c <= 'Z') ? c + ('a' - 'A') : c;
+    }
+
+    char_type widen(char c) const {  // char to codepoint
+        return (char_type) c;
+    }
+
+    char narrow(char_type c, char dfault) const {  // codepoint to char
+        return (c < 0x80 ? (char) c : dfault);
+    }
+};
+
+locale::id ctype<codepoint>::id = {};
+
+template<>  // specialization to use our custom specialized std::ctype<codepoint>
+const std::ctype<codepoint> & use_facet<std::ctype<codepoint>>(const std::locale &) {
+    static std::ctype<codepoint> ctype_uint32 = {};
+    return ctype_uint32;
+}
+
+template<>  // specialization to use our custom specialized std::ctype<codepoint>
+const std::ctype<codepoint> & use_facet<const std::ctype<codepoint>>(const std::locale & loc) {
+    return use_facet<std::ctype<codepoint>>(loc);
+}
+#endif
+
+// Minimal required implementation for std::regex string processing
+template<>  // custom specialized std::regex_traits<codepoint>
+class regex_traits<codepoint> {
+public:
+    using CharT = codepoint;
+    using char_type = codepoint;
+    using size_type = size_t;
+    using string_type = std::basic_string<CharT>;
+    using locale_type = std::locale;
+    using char_class_type = uint64_t;
+
+#if (defined(_WIN32) || defined(_WIN64))  // MSVC class _Regex_traits
+    using _Uelem = CharT;
+    static const auto _Ch_upper = std::ctype<CharT>::upper;
+    static const auto _Ch_alpha = std::ctype<CharT>::alpha;
+#endif
+
+    CharT translate(CharT c) const {
+        return c;
+    }
+
+    CharT translate_nocase(CharT c) const {
+        return unicode_tolower(c);
+    }
+
+    template<typename It>
+    string_type transform(It first, It last) const {
+        GGML_ASSERT(false);  //TODO: not needed ?
+        return {first, last};  //TODO: not tested
+    }
+
+    template<typename It>
+    string_type transform_primary(It first, It last) const {
+        (void) first;
+        (void) last;
+        GGML_ASSERT((uint32_t) *first < MAX_CODEPOINTS);  // check valid codepoint
+        return {};
+    }
+
+    template<typename It>
+    string_type lookup_collatename(It first, It last) const {
+        (void) last;
+        GGML_ASSERT(*first & (1 << 31));
+        return {*first};
+    }
+
+    template<typename It>
+    char_class_type lookup_classname(It first, It last, bool icase = false) const {
+        (void) last;
+        (void) icase;
+        const uint32_t encoded = *first;
+        codepoint_categ categ = {};
+        switch (encoded) {
+            case 's':
+            case 'S':  // negation is internally tracked
+                categ.set_flag(codepoint_categ::WHITESPACES);
+                return categ.expand_bits();
+            case 'w':
+            case 'W':  // negation is internally tracked
+                categ.set_flag(codepoint_categ::WORDS);
+                return categ.expand_bits();
+            case 'd':
+            case 'D':  // negation is internally tracked
+                categ.set_flag(codepoint_categ::DIGITS);
+                return categ.expand_bits();
+            default: {  // unicode category \p{Xx} encoded in codepoint
+                GGML_ASSERT(encoded & (1 << 31));  // make sure it's our custom codepoint encoding the category
+                const bool negated = encoded & (1 << 30);  // negation of 'character class expression' is not internally tracked
+                categ = {(uint16_t) encoded};
+                return ((uint64_t) negated << 63) | categ.expand_bits(false);
+            }
+        }
+    }
+
+    bool isctype(CharT c, char_class_type mask) const {
+        const bool negated = mask & (1llu << 63);
+        mask &= unicode_cpt_category(c).expand_bits();
+        return negated ^ (bool) mask;
+    }
+
+    int value(CharT c, int radix) const {  // char to int value
+        switch (radix) {
+            case 8:  return ('0' <= c && c <= '7') ? (int) c - '0' : -1;
+            case 10: return ('0' <= c && c <= '9') ? (int) c - '0' : -1;
+            case 16: return ('0' <= c && c <= '9') ? (int) c - '0' : (('A' <= c && c <= 'F') ? (int) c - 'A' + 10 : -1);
+            default: return -1;
+        }
+    }
+
+    const locale_type & imbue(const locale_type &) {  // set locale  //NOTE: ignoring locales
+        return std::locale::classic();
+    }
+
+    const locale_type & getloc() const {  // get locale  //NOTE: ignoring locales
+        return std::locale::classic();
+    }
+};
+
+}
+
+static std::vector<uint32_t> unicode_regex_prepare(const std::string & regex) {
+    std::vector<uint32_t> regex_cpts;
+    regex_cpts.reserve(regex.size() * 12 / 10);  // estimate +20%
+    size_t offset = 0;
+    int inside_square = 0;
+    bool any_positive = false;
+    bool any_negative = false;
+    const size_t size = regex.size();
+    while (offset < size) {
+        inside_square += regex[offset] == '[';
+        inside_square -= regex[offset] == ']';
+        GGML_ASSERT(inside_square >= 0);
+        if (!inside_square) {
+            any_positive = false;
+            any_negative = false;
+        }
+        if (regex[offset] == '\\') {
+            const size_t i = offset + 1;
+            if (regex[i] == 'p' || regex[i] == 'P') {
+                // convert \p{Xx} to custom 'character class expression' [:Xy:]
+                if (regex[i + 1] == '{' && regex[i + 2] && regex[i + 3]) {
+                    codepoint_categ categ = {};
+                    if (regex[i + 3] == '}') {
+                        categ = codepoint_categ::from_chars(regex[i + 2]);
+                        offset += 5;
+                    } else if (regex[i + 3] != '}' && regex[i + 4] == '}') {
+                        categ = codepoint_categ::from_chars(regex[i + 2], regex[i + 3]);
+                        offset += 6;
+                    }
+                    bool negated = regex[i] == 'P';
+                    any_positive |= !negated;
+                    any_negative |= negated;
+                    GGML_ASSERT(any_positive != any_negative);  //BUG: cannot mix 'p' and 'P' inside []
+                    GGML_ASSERT(sizeof(categ) <= 2);
+                    // encoded category in 32 bits codepoint
+                    uint32_t cpt_categ = (1 << 31) | (negated << 30) | categ.encoded;
+                    if (inside_square) {
+                        regex_cpts.insert(regex_cpts.end(), {'[', ':', cpt_categ, ':', ']'});
+                    } else {
+                        regex_cpts.insert(regex_cpts.end(), {'[', '[', ':', cpt_categ, ':', ']', ']'});
+                    }
+                    continue;
+                }
+            }
+        }
+        regex_cpts.push_back(unicode_cpt_from_utf8(regex, offset));
+    }
+    return regex_cpts;
+}
+
+// use std::basic_regex<uint32_t> to split the text codepoints
+static std::vector<size_t> unicode_regex_split_stl(const std::vector<uint32_t> & text_cpts, const std::vector<uint32_t> & regex_cpts, const std::vector<size_t> & offsets) {
+    GGML_ASSERT(sizeof(std::codepoint) == sizeof(uint32_t));
+    using regex_type = std::basic_regex<std::codepoint>;
+    using iter_type = std::regex_iterator<const std::codepoint *>;
+    const std::codepoint * text_data = (const std::codepoint *) text_cpts.data();
+    const std::codepoint * regex_data = (const std::codepoint *) regex_cpts.data();
+    regex_type regex(regex_data, regex_data + regex_cpts.size());
+    const iter_type end;
+    std::vector<size_t> bpe_offsets;  // store the offset of each word
+    bpe_offsets.reserve(offsets.size());  // reserve memory for the approximate size
+    for (auto offset : offsets) {
+        iter_type it(text_data, text_data + offset, regex);
+        int64_t start_idx = 0;
+        while (it != end) {
+            if (it->position() > start_idx) {
+                bpe_offsets.emplace_back(it->position() - start_idx);
+            }
+            bpe_offsets.emplace_back(it->length());
+            start_idx = it->position() + it->length();
+            ++it;
+        }
+        if (start_idx < (int64_t) offset) {
+            bpe_offsets.emplace_back(offset - start_idx);
+        }
+        text_data += offset;
+    }
+    return bpe_offsets;
+}
 //
 // interface
 //
 
@@ -612,19 +785,46 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
     return result;
 }
 
-codepoint_flags unicode_cpt_flags(const uint32_t cp) {
-    static const codepoint_flags undef(codepoint_flags::UNDEFINED);
-    static const auto cpt_flags = unicode_cpt_flags_array();
-    return cp < cpt_flags.size() ? cpt_flags[cp] : undef;
+codepoint_categ unicode_cpt_category(const uint32_t cp) {
+    static const std::vector<codepoint_categ> cpt_categs = [] {
+        std::vector<codepoint_categ> cpt_categs(MAX_CODEPOINTS, codepoint_categ::UNDEF);
+        uint32_t cpt = 0;
+        for (uint16_t rle : unicode_rle_codepoints_categs) {
+            const uint32_t index = rle & 31;
+            const uint32_t count = rle >> 5;
+            auto categ = codepoint_categ::from_index(index);
+            //printf("Codepoints 0x%05X to 0x%05X categ %s\n", cpt, cpt + count, categ.c_str());
+            categ.set_flag(codepoint_categ::DIGITS, categ.is_Nd());               // \d --> \p{Nd}
+            categ.set_flag(codepoint_categ::WORDS, categ.is_L() | categ.is_N());  // \w --> \p{L} \p{N} _
+            for (uint32_t i = 0; i <= count; ++i) {
+                cpt_categs[cpt++] = categ;
+            }
+        }
+        GGML_ASSERT(cpt == MAX_CODEPOINTS);
+
+        cpt_categs['_'].set_flag(codepoint_categ::WORDS);  // \w --> \p{L} \p{N} _
+
+        for (auto p : unicode_ranges_whitespace) {
+            for (uint32_t cpt = p.first; cpt <= p.second; ++cpt) {
+                cpt_categs[cpt].set_flag(codepoint_categ::WHITESPACES);
+            }
+        }
+
+        //for (auto &range : unicode_ranges_nfd) {  // start, last, nfd
+        //    cpt_categs[cpt].set_flag(codepoint_categ::NORM_NFD);
+        //}
+
+        return cpt_categs;
+    }();
+    return cp < cpt_categs.size() ? cpt_categs[cp] : codepoint_categ{};
 }
 
-codepoint_flags unicode_cpt_flags(const std::string & utf8) {
-    static const codepoint_flags undef(codepoint_flags::UNDEFINED);
+codepoint_categ unicode_cpt_category(const std::string & utf8) {
     if (utf8.empty()) {
-        return undef;  // undefined
+        return codepoint_categ{};  // undefined
     }
     size_t offset = 0;
-    return unicode_cpt_flags(unicode_cpt_from_utf8(utf8, offset));
+    return unicode_cpt_category(unicode_cpt_from_utf8(utf8, offset));
 }
 
 std::string unicode_byte_to_utf8(uint8_t byte) {
@@ -642,171 +842,28 @@ uint32_t unicode_tolower(uint32_t cp) {
     return it == unicode_map_lowercase.end() ? cp : it->second;
 }
 
-std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
-    // unicode categories
-    static const std::map<std::string, int> k_ucat_enum = {
-        { "\\p{N}", codepoint_flags::NUMBER },
-        { "\\p{L}", codepoint_flags::LETTER },
-        { "\\p{P}", codepoint_flags::PUNCTUATION },
-    };
-
-    static const std::map<int, int> k_ucat_cpt = {
-        { codepoint_flags::NUMBER,      0xD1 },
-        { codepoint_flags::LETTER,      0xD2 },
-        { codepoint_flags::PUNCTUATION, 0xD3 },
-    };
-
-    static const std::map<int, std::string> k_ucat_map = {
-        { codepoint_flags::NUMBER,      "\x30-\x39" },           // 0-9
-        { codepoint_flags::LETTER,      "\x41-\x5A\x61-\x7A" },  // A-Za-z
-        { codepoint_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" },  // !-#%-*,-/:-;?-@\[-\]_\{\}
-    };
-
-    // compute collapsed codepoints only if needed by at least one regex
-    bool need_collapse = false;
-    for (auto & regex_expr : regex_exprs) {
-        // search for unicode categories
-        for (const auto & ucat : k_ucat_enum) {
-            if (std::string::npos != regex_expr.find(ucat.first)) {
-                need_collapse = true;
-                break;
-            }
-        }
-    }
-
-    const auto cpts = unicode_cpts_from_utf8(text);
-
-    // generate a "collapsed" representation of the text, where all codepoints are replaced by a single byte
-    // ref: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2081479935
-    std::string text_collapsed;
-    if (need_collapse) {
-        // collapse all unicode categories
-        text_collapsed.resize(cpts.size());
-
-        for (size_t i = 0; i < cpts.size(); ++i) {
-            // keep single-byte codepoints as is
-            if (cpts[i] < 128) {
-                text_collapsed[i] = cpts[i];
-                continue;
-            }
-
-            const auto flags = unicode_cpt_flags(cpts[i]);
-
-            if (flags.is_whitespace) {
-                //NOTE: C++ std::regex \s does not match 0x85, Rust and Python regex does.
-                //text_collapsed[i] = (char) 0x85;  // <Next Line> as whitespace fallback
-                text_collapsed[i] = (char) 0x0B;    // <vertical tab> as whitespace fallback
-            } else if (k_ucat_cpt.find(flags.category_flag()) != k_ucat_cpt.end()) {
-                text_collapsed[i] = k_ucat_cpt.at(flags.category_flag());
-            } else {
-                text_collapsed[i] = (char) 0xD0;  // fallback
-            }
-        }
-    }
-
-    std::vector<size_t> bpe_offsets = { cpts.size() };
+std::vector<std::string> unicode_regex_split(const std::string & text_utf8, const std::vector<std::string> & regex_exprs) {
+    const std::vector<uint32_t> cpts = unicode_cpts_from_utf8(text_utf8);
+    std::vector<size_t> offsets = { cpts.size() };
 
     for (auto & regex_expr : regex_exprs) {
         // first, see if we have an efficient custom regex implementation
-        auto tmp = unicode_regex_split_custom(text, regex_expr, bpe_offsets);
+        auto tmp = unicode_regex_split_custom(text_utf8, regex_expr, offsets);
 
         if (!tmp.empty()) {
-            bpe_offsets = std::move(tmp);
+            offsets = std::move(tmp);
             continue;
         }
 
-        // fallback to general-purpose std::regex / std::wregex
-        try {
-            // if a unicode category is used in the regex, we use the collapsed text and replace the unicode category
-            // with the corresponding collapsed representation
-            bool use_collapsed = false;
-            for (auto & ucat : k_ucat_enum) {
-                if (std::string::npos != regex_expr.find(ucat.first)) {
-                    use_collapsed = true;
-                    break;
-                }
-            }
-
-            if (use_collapsed) {
-                // sanity-check that the original regex does not contain any non-ASCII characters
-                const auto cpts_regex = unicode_cpts_from_utf8(regex_expr);
-                for (size_t i = 0; i < cpts_regex.size(); ++i) {
-                    if (cpts_regex[i] >= 128) {
-                        throw std::runtime_error("Regex includes both unicode categories and non-ASCII characters - not supported");
-                    }
-                }
-
-                // generate a collapsed representation of the regex
-                std::string regex_expr_collapsed;
-
-                // track if we are inside [], because nested [] are not allowed
-                bool inside = false;
-                for (size_t i = 0; i < regex_expr.size(); ++i) {
-                    if (regex_expr[i] == '[' && (i == 0 || regex_expr[i - 1] != '\\')) {
-                        regex_expr_collapsed += '[';
-                        inside = true;
-                        continue;
-                    }
-
-                    if (inside && regex_expr[i] == ']' && regex_expr[i - 1] != '\\') {
-                        regex_expr_collapsed += ']';
-                        inside = false;
-                        continue;
-                    }
-
-                    if (regex_expr[i + 0] == '\\' && i + 4 < regex_expr.size() &&
-                        regex_expr[i + 1] == 'p' &&
-                        regex_expr[i + 2] == '{' &&
-                        regex_expr[i + 4] == '}') {
-                        const std::string pat = regex_expr.substr(i, 5);
-                        if (k_ucat_enum.find(pat) != k_ucat_enum.end()) {
-                            if (!inside) {
-                                regex_expr_collapsed += '[';
-                            }
-                            regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat));
-                            regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat));
-                            if (!inside) {
-                                regex_expr_collapsed += ']';
-                            }
-                            i += 4;
-                            continue;
-                        }
-                    }
-
-                    regex_expr_collapsed += regex_expr[i];
-                }
-
-                //printf("text_collapsed: %s\n", text_collapsed.c_str());
-                //printf("regex_expr_collapsed: %s\n", regex_expr_collapsed.c_str());
-                bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets);
-            } else {
-                // no unicode category used, we can use std::wregex directly
-                const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr);
-
-                // std::wregex \s does not match non-ASCII whitespaces, using 0x0B as fallback
-                std::wstring wtext(cpts.begin(), cpts.end());
-                for (size_t i = 0; i < wtext.size(); ++i) {
-                    if (wtext[i] > 0x7F && unicode_cpt_flags(wtext[i]).is_whitespace) {
-                        wtext[i] = 0x0B;
-                    }
-                }
-
-                //printf("text: %s\n", text.c_str());
-                //printf("regex_expr: %s\n", regex_expr.c_str());
-                bpe_offsets = unicode_regex_split_stl(wtext, wregex_expr, bpe_offsets);
-            }
-        } catch (std::regex_error & e) {
-            fprintf(stderr, "Failed to process regex: '%s'\n", regex_expr.c_str());
-            fprintf(stderr, "Regex error: %s\n", e.what());
-            throw std::runtime_error("Failed to process regex");
-        }
+        const auto regex_cpts = unicode_regex_prepare(regex_expr);
+        offsets = unicode_regex_split_stl(cpts, regex_cpts, offsets);
     }
 
     std::vector<std::string> bpe_words;
-    bpe_words.reserve(bpe_offsets.size());  // reserve memory for the approximate size
+    bpe_words.reserve(offsets.size());  // reserve memory for the approximate size
     size_t start = 0;
-    for (size_t & offset : bpe_offsets) {
+    for (size_t & offset : offsets) {
         bpe_words.emplace_back();
         for (size_t i = start; i < start + offset; ++i) {
             bpe_words.back() += unicode_cpt_to_utf8(cpts[i]);
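The conversion done by unicode_regex_prepare() is worth one standalone illustration: a \p{Lu} escape is rewritten into a character-class expression [[:X:]] whose single "name" X is a 32-bit codepoint with bit 31 set as a marker, bit 30 carrying \P negation, and the low 16 bits holding codepoint_categ::encoded; lookup_classname() above unpacks exactly these fields. A hedged sketch of that round trip (the constants mirror the unicode.h layout; this is not the actual call path):

#include <cstdint>
#include <cstdio>

int main() {
    // encode \p{Lu}: Lu = L | 5 with L = 1 << (1 + 3), per the unicode.h layout
    const uint16_t categ_Lu  = (1 << (1 + 3)) | 5;
    const bool     negated   = false;  // would be true for \P{Lu}
    const uint32_t cpt_categ = (1u << 31) | ((uint32_t) negated << 30) | categ_Lu;
    std::printf("\\p{Lu} -> [[:0x%08X:]]\n", (unsigned) cpt_categ);
    // decode, as lookup_classname() does: marker bit, negation bit, category bits
    std::printf("marker=%u negated=%u categ=0x%04X\n",
                (unsigned) (cpt_categ >> 31), (unsigned) ((cpt_categ >> 30) & 1), (unsigned) (cpt_categ & 0xFFFF));
    return 0;
}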


@@ -1,51 +1,185 @@
 #pragma once
 
 #include <cstdint>
+#include <cassert>
+#include <cstring>
 #include <string>
 #include <vector>
+#include <array>
+#include <map>
 
-// TODO: prefix all symbols with "llama_"
-
-struct codepoint_flags {
-    enum {
-        UNDEFINED       = 0x0001,
-        NUMBER          = 0x0002,  // regex: \p{N}
-        LETTER          = 0x0004,  // regex: \p{L}
-        SEPARATOR       = 0x0008,  // regex: \p{Z}
-        ACCENT_MARK     = 0x0010,  // regex: \p{M}
-        PUNCTUATION     = 0x0020,  // regex: \p{P}
-        SYMBOL          = 0x0040,  // regex: \p{S}
-        CONTROL         = 0x0080,  // regex: \p{C}
-        MASK_CATEGORIES = 0x00FF,
-    };
-
-    // codepoint type
-    uint16_t is_undefined   : 1;
-    uint16_t is_number      : 1;  // regex: \p{N}
-    uint16_t is_letter      : 1;  // regex: \p{L}
-    uint16_t is_separator   : 1;  // regex: \p{Z}
-    uint16_t is_accent_mark : 1;  // regex: \p{M}
-    uint16_t is_punctuation : 1;  // regex: \p{P}
-    uint16_t is_symbol      : 1;  // regex: \p{S}
-    uint16_t is_control     : 1;  // regex: \p{C}
-
-    // helper flags
-    uint16_t is_whitespace : 1;  // regex: \s
-    uint16_t is_lowercase  : 1;
-    uint16_t is_uppercase  : 1;
-    uint16_t is_nfd        : 1;
-
-    // decode from uint16
-    inline codepoint_flags(const uint16_t flags=0) {
-        *reinterpret_cast<uint16_t*>(this) = flags;
-    }
-
-    inline uint16_t as_uint() const {
-        return *reinterpret_cast<const uint16_t*>(this);
-    }
-
-    inline uint16_t category_flag() const {
-        return this->as_uint() & MASK_CATEGORIES;
-    }
-};
+struct codepoint_categ {
+    // 0bffffff'ccccccc'sss --> 6 bits flags + 7 bits category + 3 bits subcategory
+    enum _category : uint16_t {
+        UNDEF = 0,         // \p{Cn} Undefined
+        C = 1 << (0 + 3),  // \p{C} Control
+        L = 1 << (1 + 3),  // \p{L} Letter
+        M = 1 << (2 + 3),  // \p{M} Mark
+        N = 1 << (3 + 3),  // \p{N} Number
+        P = 1 << (4 + 3),  // \p{P} Punctuation
+        S = 1 << (5 + 3),  // \p{S} Symbol
+        Z = 1 << (6 + 3),  // \p{Z} Separator
+        Cc = C | 1,  // \p{Cc} Control
+        Cf = C | 2,  // \p{Cf} Format
+        Co = C | 3,  // \p{Co} Private Use
+        Cs = C | 4,  // \p{Cs} Surrogate
+        Ll = L | 1,  // \p{Ll} Lowercase Letter
+        Lm = L | 2,  // \p{Lm} Modifier Letter
+        Lo = L | 3,  // \p{Lo} Other Letter
+        Lt = L | 4,  // \p{Lt} Titlecase Letter
+        Lu = L | 5,  // \p{Lu} Uppercase Letter
+        Mc = M | 1,  // \p{Mc} Spacing Mark
+        Me = M | 2,  // \p{Me} Enclosing Mark
+        Mn = M | 3,  // \p{Mn} Nonspacing Mark
+        Nd = N | 1,  // \p{Nd} Decimal Number
+        Nl = N | 2,  // \p{Nl} Letter Number
+        No = N | 3,  // \p{No} Other Number
+        Pc = P | 1,  // \p{Pc} Connector Punctuation
+        Pd = P | 2,  // \p{Pd} Dash Punctuation
+        Pe = P | 3,  // \p{Pe} Close Punctuation
+        Pf = P | 4,  // \p{Pf} Final Punctuation
+        Pi = P | 5,  // \p{Pi} Initial Punctuation
+        Po = P | 6,  // \p{Po} Other Punctuation
+        Ps = P | 7,  // \p{Ps} Open Punctuation
+        Sc = S | 1,  // \p{Sc} Currency Symbol
+        Sk = S | 2,  // \p{Sk} Modifier Symbol
+        Sm = S | 3,  // \p{Sm} Math Symbol
+        So = S | 4,  // \p{So} Other Symbol
+        Zl = Z | 1,  // \p{Zl} Line Separator
+        Zp = Z | 2,  // \p{Zp} Paragraph Separator
+        Zs = Z | 3,  // \p{Zs} Space Separator
+        SUBMASK = (1 << 3) - 1,   // 3 bits   0b000000'0000000'111
+        MASK    = (1 << 10) - 1,  // 7+3 bits 0b000000'1111111'111
+    };
+
+    enum _flags : uint16_t {
+        WHITESPACES = (1 << 10),  // regex: \s
+        WORDS       = (1 << 11),  // regex: \w
+        DIGITS      = (1 << 12),  // regex: \d
+        //Norm NFD/NFC = ...,
+    };
+
+    inline codepoint_categ(const uint16_t categ=0) : encoded{categ} {}
+
+    inline void set_flag(_flags flags, bool value = true) {
+        flags = (_flags) (flags & ~MASK);  // do not modify category bits
+        encoded = value ? (encoded | flags) : (encoded & ~flags);
+    }
+
+    inline uint16_t get_category() const { return encoded & MASK; }
+
+    inline bool is_undefined() const { return !encoded; }
+    inline bool is_defined()   const { return encoded; }
+
+    inline uint16_t is_whitespace() const { return encoded & WHITESPACES; }
+    inline uint16_t is_word()       const { return encoded & WORDS; }
+    inline uint16_t is_digit()      const { return encoded & DIGITS; }
+
+    inline uint16_t is_C() const { return encoded & C; }
+    inline uint16_t is_L() const { return encoded & L; }
+    inline uint16_t is_M() const { return encoded & M; }
+    inline uint16_t is_N() const { return encoded & N; }
+    inline uint16_t is_P() const { return encoded & P; }
+    inline uint16_t is_S() const { return encoded & S; }
+    inline uint16_t is_Z() const { return encoded & Z; }
+
+    inline bool is_Cc() const { return (encoded & MASK) == Cc; }
+    inline bool is_Cf() const { return (encoded & MASK) == Cf; }
+    inline bool is_Co() const { return (encoded & MASK) == Co; }
+    inline bool is_Cs() const { return (encoded & MASK) == Cs; }
+    inline bool is_Ll() const { return (encoded & MASK) == Ll; }
+    inline bool is_Lm() const { return (encoded & MASK) == Lm; }
+    inline bool is_Lo() const { return (encoded & MASK) == Lo; }
+    inline bool is_Lt() const { return (encoded & MASK) == Lt; }
+    inline bool is_Lu() const { return (encoded & MASK) == Lu; }
+    inline bool is_Mc() const { return (encoded & MASK) == Mc; }
+    inline bool is_Me() const { return (encoded & MASK) == Me; }
+    inline bool is_Mn() const { return (encoded & MASK) == Mn; }
+    inline bool is_Nd() const { return (encoded & MASK) == Nd; }
+    inline bool is_Nl() const { return (encoded & MASK) == Nl; }
+    inline bool is_No() const { return (encoded & MASK) == No; }
+    inline bool is_Pc() const { return (encoded & MASK) == Pc; }
+    inline bool is_Pd() const { return (encoded & MASK) == Pd; }
+    inline bool is_Pe() const { return (encoded & MASK) == Pe; }
+    inline bool is_Pf() const { return (encoded & MASK) == Pf; }
+    inline bool is_Pi() const { return (encoded & MASK) == Pi; }
+    inline bool is_Po() const { return (encoded & MASK) == Po; }
+    inline bool is_Ps() const { return (encoded & MASK) == Ps; }
+    inline bool is_Sc() const { return (encoded & MASK) == Sc; }
+    inline bool is_Sk() const { return (encoded & MASK) == Sk; }
+    inline bool is_Sm() const { return (encoded & MASK) == Sm; }
+    inline bool is_So() const { return (encoded & MASK) == So; }
+    inline bool is_Zl() const { return (encoded & MASK) == Zl; }
+    inline bool is_Zp() const { return (encoded & MASK) == Zp; }
+    inline bool is_Zs() const { return (encoded & MASK) == Zs; }
+
+    inline uint64_t expand_bits(const bool add_categ=true) const {  // one bit for each category/subcategory and flags
+        const uint32_t subindex = encoded & SUBMASK;
+        const uint64_t bits  = (encoded & MASK) >> 3;
+        const uint64_t flags = encoded >> 10;
+        return (flags << (7 * 8)) | (bits << (7 * subindex)) | (bits * add_categ);
+    }
+
+    inline bool is_in_range(const codepoint_categ other) const {  // this.first <= other <= this.last
+        if (encoded & SUBMASK) {
+            return encoded == other.encoded;  // no range
+        }
+        if (encoded & MASK) {
+            return encoded == (other.encoded & ~SUBMASK);  // from 0bffffff'ccccccc'000 to 0bffffff'ccccccc'111
+        }
+        return encoded == (other.encoded & ~MASK);  // from 0bffffff'0000000'000 to 0bffffff'1111111'111
+    }
+
+    inline bool operator == (const codepoint_categ other) const {
+        return encoded == other.encoded;
+    }
+
+    inline bool operator != (const codepoint_categ other) const {
+        return encoded != other.encoded;
+    }
+
+    const char * c_str() const {
+        static const std::map<uint16_t, const char *> map = {
+            {UNDEF, "UNDEF"}, {C, "C"}, {L, "L"}, {M, "M"}, {N, "N"}, {P, "P"}, {S, "S"}, {Z, "Z"},
+            {Cc, "Cc"}, {Cf, "Cf"}, {Co, "Co"}, {Cs, "Cs"}, {Ll, "Ll"}, {Lm, "Lm"}, {Lo, "Lo"}, {Lt, "Lt"},
+            {Lu, "Lu"}, {Mc, "Mc"}, {Me, "Me"}, {Mn, "Mn"}, {Nd, "Nd"}, {Nl, "Nl"}, {No, "No"}, {Pc, "Pc"},
+            {Pd, "Pd"}, {Pe, "Pe"}, {Pf, "Pf"}, {Pi, "Pi"}, {Po, "Po"}, {Ps, "Ps"}, {Sc, "Sc"}, {Sk, "Sk"},
+            {Sm, "Sm"}, {So, "So"}, {Zl, "Zl"}, {Zp, "Zp"}, {Zs, "Zs"},
+        };
+        const auto it = map.find(encoded & MASK);
+        return it == map.end() ? "INVALID" : it->second;
+    }
+
+    static codepoint_categ from_index(int index) {
+        static const std::array<codepoint_categ, 32> table = {
+            UNDEF, Cc, Cf, Co, Cs, Ll, Lm, Lo, Lt, Lu, Mc, Me, Mn, Nd, Nl, No, Pc, Pd, Pe, Pf, Pi, Po, Ps, Sc, Sk, Sm, So, Zl, Zp, Zs, UNDEF, UNDEF
+        };
+        return (size_t) index < table.size() ? table[index] : table[0];
+    }
+
+    static codepoint_categ from_chars(const char categ, const char subcateg = '\0') {
+        auto _subindex = [] (const char subcateg, const char subcategs[]) -> uint16_t {
+            if (!subcateg) {
+                return 0;
+            }
+            const char * p = strchr(subcategs, subcateg);
+            GGML_ASSERT(p);
+            return (uint16_t) (p - subcategs + 1);
+        };
+        switch (categ) {
+            case 'C': if (subcateg == 'n') return 0;  // undefined
+                      return C | _subindex(subcateg, "cfos");
+            case 'L': return L | _subindex(subcateg, "lmotu");
+            case 'M': return M | _subindex(subcateg, "cen");
+            case 'N': return N | _subindex(subcateg, "dlo");
+            case 'P': return P | _subindex(subcateg, "cdefios");
+            case 'S': return S | _subindex(subcateg, "ckmo");
+            case 'Z': return Z | _subindex(subcateg, "lps");
+            default:  GGML_ABORT("invalid category character");
+        }
+    }
+
+    uint16_t encoded;
+};
 
 size_t unicode_len_utf8(char src);
 
@@ -56,8 +190,8 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8);
 
 std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts);
 
-codepoint_flags unicode_cpt_flags(const uint32_t cp);
-codepoint_flags unicode_cpt_flags(const std::string & utf8);
+codepoint_categ unicode_cpt_category(const uint32_t cp);
+codepoint_categ unicode_cpt_category(const std::string & utf8);
 
 std::string unicode_byte_to_utf8(uint8_t byte);
 uint8_t unicode_utf8_to_byte(const std::string & utf8);
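One property of this packing worth spelling out: each major category owns a single bit and subcategories differ only in the low 3 bits, so is_L() is one AND, and an exact \p{Lu} test is an AND with MASK plus a compare. A small standalone sketch of the arithmetic (constants copied from the enum above, not including the header):

#include <cstdint>
#include <cstdio>

int main() {
    const uint16_t L  = 1 << (1 + 3);       // major category bit: \p{L}
    const uint16_t Lu = L | 5;              // subcategory: \p{Lu}
    const uint16_t SUBMASK = (1 << 3) - 1;  // low 3 bits select the subcategory
    std::printf("is_L(Lu)     = %d\n", (Lu & L) != 0);                   // major-category test: one AND
    std::printf("major(Lu)    = 0x%04X\n", (unsigned) (Lu & ~SUBMASK));  // strip subindex: back to L
    std::printf("subindex(Lu) = %d\n", Lu & SUBMASK);                    // 5 == position of 'u' in "lmotu" + 1
    return 0;
}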


@@ -116,9 +116,24 @@ class LibLlamaModel:
         num = self.lib.llama_detokenize(self.model, self.token_ids, len(ids), self.text_buff, len(self.text_buff), remove_special, unparse_special)
         return str(cast(Buffer, self.ffi.buffer(self.text_buff, num)), encoding="utf-8", errors="replace")  # replace errors with '\uFFFD'
 
+    def get_vocab(self, detokenize=False) -> list[str]:
+        vocab: list[str] = []
+        num_tokens = self.lib.llama_n_vocab(self.model)
+        for id in range(num_tokens):
+            if detokenize:
+                text = self.detokenize([id], remove_special=False, unparse_special=True)
+            else:
+                text = self.lib.llama_token_get_text(self.model, id)
+                text = str(cast(bytes, self.ffi.string(text)), encoding="utf-8", errors="replace")  # replace errors with '\uFFFD'
+            vocab.append(text)
+        return vocab
+
 
 class Tokenizer:
 
+    def get_vocab(self, detokenize=False) -> list[str]:
+        raise NotImplementedError
+
     def encode(self, text: str) -> list[int]:
         raise NotImplementedError
@@ -129,7 +144,7 @@ class Tokenizer:
 class TokenizerGroundtruth (Tokenizer):
 
     def __init__(self, dir_tokenizer: str):
-        self.model: PreTrainedTokenizer = AutoTokenizer.from_pretrained(dir_tokenizer)
+        self.model: PreTrainedTokenizer = AutoTokenizer.from_pretrained(dir_tokenizer, trust_remote_code=False)
         # guess BOS and EOS
         ids = self.encode("a")
         assert 1 <= len(ids) <= 3
@@ -138,15 +153,25 @@ class TokenizerGroundtruth (Tokenizer):
         self.add_bos_token = getattr(self.model, "add_bos_token", add_bos_token)
         self.add_eos_token = getattr(self.model, "add_eos_token", add_eos_token)
         # build vocab
-        tokens = list(self.model.get_vocab().values())
-        self.vocab = self.model.batch_decode(tokens, skip_special_tokens=True)
-        self.vocab = list(sorted(self.vocab))
+        self.vocab = self.get_vocab(detokenize=True)
         # tokens and lists
-        self.special_tokens = list(self.model.all_special_tokens)
-        self.added_tokens = self.model.batch_decode(self.model.added_tokens_encoder.values(), skip_special_tokens=False)
+        self.special_tokens = [self.vocab[i] for i in sorted(self.model.all_special_ids)]
+        self.added_tokens = [self.vocab[i] for i in sorted(self.model.added_tokens_encoder.values())]
         self.bos_token = self.model.bos_token
         self.eos_token = self.model.eos_token
 
+    def get_vocab(self, detokenize=False) -> list[str]:
+        vocab: list[str] = []
+        max_token_id = max(self.model.get_vocab().values())
+        if detokenize:
+            ids = list(range(max_token_id + 1))
+            vocab = self.model.batch_decode(ids, skip_special_tokens=False)
+        else:
+            vocab = [""] * (max_token_id + 1)
+            for text, id in self.model.get_vocab().items():
+                vocab[id] = text
+        return vocab
+
     def encode(self, text: str) -> list[int]:
         return self.model.encode(text, add_special_tokens=True)
@@ -163,6 +188,9 @@ class TokenizerLlamaCpp (Tokenizer):
         self.libllama = LibLlama()
         self.model = LibLlamaModel(self.libllama, vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
 
+    def get_vocab(self, detokenize=False) -> list[str]:
+        return self.model.get_vocab(detokenize)
+
     def encode(self, text: str) -> list[int]:
         return self.model.tokenize(text, add_special=True, parse_special=True)
@@ -253,6 +281,23 @@ def generator_vocab_words(tokenizer: TokenizerGroundtruth) -> Iterator[str]:
     yield from tokenizer.vocab
 
 
+def generator_byte_tokens() -> Iterator[str]:
+    """Brute force check common byte encoding"""
+    for a, b in ["<>", "[]", "()", ("\\", "")]:
+        yield from [f"{a}{i}{b}" for i in range(256)]
+        yield from [f"{a}{i:x}{b}" for i in range(256)]
+        yield from [f"{a}{i:X}{b}" for i in range(256)]
+        yield from [f"{a}x{i:x}{b}" for i in range(256)]
+        yield from [f"{a}x{i:X}{b}" for i in range(256)]
+        yield from [f"{a}x{i:02x}{b}" for i in range(256)]
+        yield from [f"{a}x{i:02X}{b}" for i in range(256)]
+        yield from [f"{a}0x{i:x}{b}" for i in range(256)]
+        yield from [f"{a}0x{i:X}{b}" for i in range(256)]
+        yield from [f"{a}0x{i:02x}{b}" for i in range(256)]
+        yield from [f"{a}0x{i:02X}{b}" for i in range(256)]
+        yield from [f"{a}{chr(i)}{b}" for i in range(256)]
+
+
 def generator_ascii_lr_strip() -> Iterator[str]:
     WHITESPACES = ["", " ", "  "]
     CHARACTERS = list(chr(i) for i in range(1, 0x80)) + [""]
@@ -275,10 +320,11 @@ def generator_apostrophe() -> Iterator[str]:
             yield char1 + lstrip + "'" + rstrip + char2
             yield char1 + char2 + lstrip + "'" + rstrip + "z"
             yield "a" + lstrip + "'" + rstrip + char1 + char2
+            yield "a" + lstrip + "'" + char1 + char2 + rstrip + "z"
 
 
 def generator_added_lr_strip(tokenizer: TokenizerGroundtruth) -> Iterator[str]:
-    WHITESPACES = ["", " ", "  ", "\n", "\r\n", "\n\n", "\t", "\t\t"]
+    WHITESPACES = ["", " ", "  ", "\n", "\r\n", "\n\n", "\t", "\t\t", " "]
     all_tokens = list(sorted(set(tokenizer.special_tokens + tokenizer.added_tokens)))
     for token in all_tokens:
         for lstrip in WHITESPACES:
@@ -409,14 +455,6 @@ def generator_random_vocab_words(tokenizer: TokenizerGroundtruth, iterations=100
 
 def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp, generator: Iterator[str]):
 
-    def find_first_mismatch(ids1: list[int] | str, ids2: list[int] | str):
-        for i, (a, b) in enumerate(zip(ids1, ids2)):
-            if a != b:
-                return i
-        if len(ids1) == len(ids2):
-            return -1
-        return min(len(ids1), len(ids2))
-
     def check_detokenizer(text: str, text1: str, text2: str) -> bool:
         if text1 == text2:  # equal to TokenizerGroundtruth?
             return True
@@ -434,9 +472,11 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl
     t_decode1 = 0
     t_decode2 = 0
     t_start = time.perf_counter()
+    total_tests = 0
+    failing_texts = set()
     encode_errors = 0
     decode_errors = 0
-    MAX_ERRORS = 10
+    MAX_ERRORS = 5
 
     logger.info("%s: %s" % (generator.__qualname__, "ini"))
     for text in generator:
@@ -455,22 +495,54 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl
         t_encode2 += t2 - t1
         t_decode1 += t3 - t2
         t_decode2 += t4 - t3
-        if encode_errors < MAX_ERRORS and ids1 != ids2:
-            i = find_first_mismatch(ids1, ids2)
-            ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1]
-            ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1]
-            logger.error(" Expected: " + str(ids1))
-            logger.error(" Result:   " + str(ids2))
-            encode_errors += 1
-            logger.error(f" {encode_errors=}")
-        if decode_errors < MAX_ERRORS and not check_detokenizer(text, text1, text2):
-            i = find_first_mismatch(text1, text2)
-            text1 = list(text1[max(0, i - 2) : i + 5 + 1])
-            text2 = list(text2[max(0, i - 2) : i + 5 + 1])
-            logger.error(" Expected: " + " ".join(hex(ord(x)) for x in text1))
-            logger.error(" Result:   " + " ".join(hex(ord(x)) for x in text2))
-            decode_errors += 1
-            logger.error(f" {decode_errors=}")
+        total_tests += 1
+        # compare
+        encode_ok = ids1 == ids2
+        decode_ok = check_detokenizer(text, text1, text2)
+        if not (encode_ok and decode_ok):
+            def _compare(text: str):
+                ids1 = tokenizer1.encode(text)
+                ids2 = tokenizer2.encode(text)
+                text1 = tokenizer1.decode(ids1)
+                text2 = tokenizer2.decode(ids1)
+                encode_ok = ids1 == ids2
+                decode_ok = check_detokenizer(text, text1, text2)
+                ok = encode_ok and decode_ok
+                return ok, ids1, ids2, text1, text2
+            # binary search upper and lower failing range
+            a, b = 0, len(text)
+            step = b
+            while step > 1:
+                step = (step + 1) // 2
+                t = max(a, b - step)
+                if not _compare(text[a : t])[0]:
+                    b = t
+            step = b
+            while step > 1:
+                step = (step + 1) // 2
+                t = min(a + step, b)
+                if not _compare(text[t : b])[0]:
+                    a = t
+            ok, ids1, ids2, text1, text2 = _compare(text[a : b])
+            assert a <= b and not ok
+            # show unique failing texts differences
+            failing_text = text[a : b]
+            if failing_text not in failing_texts:
+                failing_texts.add(failing_text)
+                if encode_errors < MAX_ERRORS and not encode_ok:
+                    encode_errors += 1
+                    logger.error(f" {encode_errors=}")
+                    logger.error(" Text:" + repr(failing_text))
+                    logger.error(" " + " ".join(repr(x) + ":" + hex(ord(x)) for x in failing_text))
+                    logger.error(" Expected: " + str(ids1))
+                    logger.error(" Result:   " + str(ids2))
+                if decode_errors < MAX_ERRORS and not decode_ok:
+                    decode_errors += 1
+                    logger.error(f" {decode_errors=}")
+                    logger.error(" Text:" + repr(failing_text))
+                    logger.error(" " + " ".join(repr(x) + ":" + hex(ord(x)) for x in failing_text))
+                    logger.error(" Expected: " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text1))
+                    logger.error(" Result:   " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text2))
         if encode_errors >= MAX_ERRORS and decode_errors >= MAX_ERRORS:
             logger.error(f" EXIT: {encode_errors=} {decode_errors=}")
             # raise Exception()
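The replacement failure path above shrinks each failing input with two binary searches: the first loop pulls the right edge `b` inward while the prefix still fails, the second pulls the left edge `a` inward while the remainder still fails, so only a minimal, de-duplicated failing substring is logged. The same shrinking loop in isolation, with a toy predicate standing in for the tokenizer comparison (`fails` and `shrink` are illustrative names, not from the commit):

def shrink(text: str, fails) -> str:
    """Return a smaller substring of `text` that still fails, using the
    same two-phase binary search as the failure path above."""
    assert fails(text)
    a, b = 0, len(text)
    step = b
    while step > 1:                  # phase 1: shrink from the right
        step = (step + 1) // 2
        t = max(0, b - step)
        if fails(text[a : t]):
            b = t
    step = b
    while step > 1:                  # phase 2: shrink from the left
        step = (step + 1) // 2
        t = min(a + step, b)
        if fails(text[t : b]):
            a = t
    return text[a : b]

# Toy predicate: a text "fails" whenever it contains the pair "ab".
print(shrink("xxxxabyyyy", lambda s: "ab" in s))  # -> 'ab'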
@@ -480,6 +552,43 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl
     logger.info(f"{generator.__qualname__}: end, {t_encode1=:.3f} {t_encode2=:.3f} {t_decode1=:.3f} {t_decode2=:.3f} {t_total=:.3f}")
+def compare_vocabs(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp):
+    MAX_PRINT_ERRORS = 10
+
+    logger.info("compare_vocabs: ini")
+
+    t_start = time.perf_counter()
+
+    for detokenize in (False, True):
+        vocab1 = tokenizer1.get_vocab(detokenize)
+        vocab2 = tokenizer2.get_vocab(detokenize)
+        if vocab1 != vocab2:
+            num_errors = 0
+            for i in range(max(len(vocab1), len(vocab2))):
+                text1 = vocab1[i] if i < len(vocab1) else None
+                text2 = vocab2[i] if i < len(vocab2) else None
+                if text1 != text2:
+                    # is "[UNUSED_TOKEN_" and "[PAD" valid for all models ?  #TODO: use toktypes
+                    if text1 is not None:
+                        text1 = text1.replace("[UNUSED_TOKEN_", "[PAD")
+                    if text2 is not None:
+                        text2 = text2.replace("[UNUSED_TOKEN_", "[PAD")
+                    if text1 is None and (text2 or "").startswith('[PAD'):
+                        text2 = None
+                    if text2 is None and (text1 or "").startswith('[PAD'):
+                        text1 = None
+                    if text1 != text2:
+                        num_errors += 1
+                        if num_errors < MAX_PRINT_ERRORS:
+                            logger.error(f" {detokenize=} id={i} expected={repr(text1)} result={repr(text2)}")
+            if num_errors:
+                logger.error(f" {num_errors=}")
+
+    t_total = time.perf_counter() - t_start
+    logger.info(f"compare_vocabs: end, {t_total=:.3f}")
 def main(argv: list[str] | None = None):
     parser = argparse.ArgumentParser()
     parser.add_argument("vocab_file", type=str, help="path to vocab 'gguf' file")
@@ -493,18 +602,21 @@ def main(argv: list[str] | None = None):
     tokenizer1 = TokenizerGroundtruth(args.dir_tokenizer)
     tokenizer2 = TokenizerLlamaCpp(args.vocab_file)
 
-    # compare_tokenizers(tokenizer1, tokenizer2, generator_custom_text())
-    # compare_tokenizers(tokenizer1, tokenizer2, generator_custom_text_edge_cases())
+    compare_vocabs(tokenizer1, tokenizer2)
+
+    compare_tokenizers(tokenizer1, tokenizer2, generator_custom_text())
+    compare_tokenizers(tokenizer1, tokenizer2, generator_custom_text_edge_cases())
+    compare_tokenizers(tokenizer1, tokenizer2, generator_byte_tokens())
     compare_tokenizers(tokenizer1, tokenizer2, generator_ascii_lr_strip())
     compare_tokenizers(tokenizer1, tokenizer2, generator_apostrophe())
     compare_tokenizers(tokenizer1, tokenizer2, generator_unicodes())
     compare_tokenizers(tokenizer1, tokenizer2, generator_vocab_words(tokenizer1))
     compare_tokenizers(tokenizer1, tokenizer2, generator_added_lr_strip(tokenizer1))
-    # compare_tokenizers(tokenizer1, tokenizer2, generator_random_added_tokens(tokenizer1, 10_000))
-    # compare_tokenizers(tokenizer1, tokenizer2, generator_random_chars(10_000))
-    # compare_tokenizers(tokenizer1, tokenizer2, generator_random_unicodes(10_000))
-    # compare_tokenizers(tokenizer1, tokenizer2, generator_random_vocab_chars(tokenizer1, 10_000))
-    # compare_tokenizers(tokenizer1, tokenizer2, generator_random_vocab_words(tokenizer1, 5_000))
+    compare_tokenizers(tokenizer1, tokenizer2, generator_random_added_tokens(tokenizer1, 10_000))
+    compare_tokenizers(tokenizer1, tokenizer2, generator_random_chars(10_000))
+    compare_tokenizers(tokenizer1, tokenizer2, generator_random_unicodes(10_000))
+    compare_tokenizers(tokenizer1, tokenizer2, generator_random_vocab_chars(tokenizer1, 10_000))
+    compare_tokenizers(tokenizer1, tokenizer2, generator_random_vocab_words(tokenizer1, 5_000))
 
     tokenizer2.model.free()
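For reference, a plausible invocation of the updated script: the first positional argument is the GGUF vocab file declared above, and the second is assumed to be the Hugging Face tokenizer directory consumed by TokenizerGroundtruth (paths are placeholders):

import subprocess

# Hypothetical run; argument order inferred from the argparse setup above.
subprocess.run([
    "python", "tests/test-tokenizer-random.py",
    "models/ggml-vocab-llama-bpe.gguf",  # vocab_file: llama.cpp GGUF vocab
    "models/tokenizers/llama-bpe",       # dir_tokenizer: HF tokenizer files
], check=True)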
@ -533,21 +645,19 @@ if __name__ == "__main__":
"phi-3", # SPM "phi-3", # SPM
"gemma", # SPM "gemma", # SPM
"gemma-2", # SPM "gemma-2", # SPM
"baichuan", # SPM # "baichuan", # SPM
"bert-bge", # WPM "bert-bge", # WPM
"jina-v2-en", # WPM "jina-v2-en", # WPM
# "t5", # UGM
"llama-bpe", # BPE "llama-bpe", # BPE
"phi-2", # BPE "phi-2", # BPE
"deepseek-llm", # BPE "deepseek-llm", # BPE
"deepseek-coder", # BPE "deepseek-coder", # BPE
"falcon", # BPE "falcon", # BPE
"mpt", # BPE
"starcoder", # BPE "starcoder", # BPE
"gpt-2", # BPE "gpt-2", # BPE
"stablelm2", # BPE "stablelm2", # BPE
"refact", # BPE "refact", # BPE
"qwen2", # BPE
"olmo", # BPE
"jina-v2-es", # BPE "jina-v2-es", # BPE
"jina-v2-de", # BPE "jina-v2-de", # BPE
"smaug-bpe", # BPE "smaug-bpe", # BPE
@@ -555,6 +665,14 @@ if __name__ == "__main__":
         "jina-v2-code",    # BPE
         "viking",          # BPE
         "jais",            # BPE
+        "codeshell",       # BPE
+        "tekken",          # BPE
+        "smollm",          # BPE
+        "mpt",             # BPE NFC
+        "command-r",       # BPE NFC
+        "qwen2",           # BPE NFC
+        "olmo",            # BPE NFC
+        "gpt-neox",        # BPE NFC
     ]
     logger.info("=" * 50)