Merge db78320b4d into a876861455
Commit dffb4b1909
7 changed files with 5275 additions and 2746 deletions
scripts/gen-unicode-data.py

@@ -49,53 +49,42 @@ def unicode_data_iter():
         yield (cpt, cpt_lower, cpt_upper, categ, bidir)
 
 
-# see definition in unicode.h
-CODEPOINT_FLAG_UNDEFINED   = 0x0001  #
-CODEPOINT_FLAG_NUMBER      = 0x0002  # \p{N}
-CODEPOINT_FLAG_LETTER      = 0x0004  # \p{L}
-CODEPOINT_FLAG_SEPARATOR   = 0x0008  # \p{Z}
-CODEPOINT_FLAG_MARK        = 0x0010  # \p{M}
-CODEPOINT_FLAG_PUNCTUATION = 0x0020  # \p{P}
-CODEPOINT_FLAG_SYMBOL      = 0x0040  # \p{S}
-CODEPOINT_FLAG_CONTROL     = 0x0080  # \p{C}
-
-UNICODE_CATEGORY_TO_FLAG = {
-    "Cn": CODEPOINT_FLAG_UNDEFINED,    # Undefined
-    "Cc": CODEPOINT_FLAG_CONTROL,      # Control
-    "Cf": CODEPOINT_FLAG_CONTROL,      # Format
-    "Co": CODEPOINT_FLAG_CONTROL,      # Private Use
-    "Cs": CODEPOINT_FLAG_CONTROL,      # Surrogate
-    "Ll": CODEPOINT_FLAG_LETTER,       # Lowercase Letter
-    "Lm": CODEPOINT_FLAG_LETTER,       # Modifier Letter
-    "Lo": CODEPOINT_FLAG_LETTER,       # Other Letter
-    "Lt": CODEPOINT_FLAG_LETTER,       # Titlecase Letter
-    "Lu": CODEPOINT_FLAG_LETTER,       # Uppercase Letter
-    "L&": CODEPOINT_FLAG_LETTER,       # Cased Letter
-    "Mc": CODEPOINT_FLAG_MARK,         # Spacing Mark
-    "Me": CODEPOINT_FLAG_MARK,         # Enclosing Mark
-    "Mn": CODEPOINT_FLAG_MARK,         # Nonspacing Mark
-    "Nd": CODEPOINT_FLAG_NUMBER,       # Decimal Number
-    "Nl": CODEPOINT_FLAG_NUMBER,       # Letter Number
-    "No": CODEPOINT_FLAG_NUMBER,       # Other Number
-    "Pc": CODEPOINT_FLAG_PUNCTUATION,  # Connector Punctuation
-    "Pd": CODEPOINT_FLAG_PUNCTUATION,  # Dash Punctuation
-    "Pe": CODEPOINT_FLAG_PUNCTUATION,  # Close Punctuation
-    "Pf": CODEPOINT_FLAG_PUNCTUATION,  # Final Punctuation
-    "Pi": CODEPOINT_FLAG_PUNCTUATION,  # Initial Punctuation
-    "Po": CODEPOINT_FLAG_PUNCTUATION,  # Other Punctuation
-    "Ps": CODEPOINT_FLAG_PUNCTUATION,  # Open Punctuation
-    "Sc": CODEPOINT_FLAG_SYMBOL,       # Currency Symbol
-    "Sk": CODEPOINT_FLAG_SYMBOL,       # Modifier Symbol
-    "Sm": CODEPOINT_FLAG_SYMBOL,       # Math Symbol
-    "So": CODEPOINT_FLAG_SYMBOL,       # Other Symbol
-    "Zl": CODEPOINT_FLAG_SEPARATOR,    # Line Separator
-    "Zp": CODEPOINT_FLAG_SEPARATOR,    # Paragraph Separator
-    "Zs": CODEPOINT_FLAG_SEPARATOR,    # Space Separator
+# see codepoint_categ::from_index() in unicode.h
+UNICODE_CATEGORY_TO_INDEX = {
+    "Cn":  0,  # \p{Cn} Undefined
+    "Cc":  1,  # \p{Cc} Control
+    "Cf":  2,  # \p{Cf} Format
+    "Co":  3,  # \p{Co} Private Use
+    "Cs":  4,  # \p{Cs} Surrogate
+    "Ll":  5,  # \p{Ll} Lowercase Letter
+    "Lm":  6,  # \p{Lm} Modifier Letter
+    "Lo":  7,  # \p{Lo} Other Letter
+    "Lt":  8,  # \p{Lt} Titlecase Letter
+    "Lu":  9,  # \p{Lu} Uppercase Letter
+    "Mc": 10,  # \p{Mc} Spacing Mark
+    "Me": 11,  # \p{Me} Enclosing Mark
+    "Mn": 12,  # \p{Mn} Nonspacing Mark
+    "Nd": 13,  # \p{Nd} Decimal Number
+    "Nl": 14,  # \p{Nl} Letter Number
+    "No": 15,  # \p{No} Other Number
+    "Pc": 16,  # \p{Pc} Connector Punctuation
+    "Pd": 17,  # \p{Pd} Dash Punctuation
+    "Pe": 18,  # \p{Pe} Close Punctuation
+    "Pf": 19,  # \p{Pf} Final Punctuation
+    "Pi": 20,  # \p{Pi} Initial Punctuation
+    "Po": 21,  # \p{Po} Other Punctuation
+    "Ps": 22,  # \p{Ps} Open Punctuation
+    "Sc": 23,  # \p{Sc} Currency Symbol
+    "Sk": 24,  # \p{Sk} Modifier Symbol
+    "Sm": 25,  # \p{Sm} Math Symbol
+    "So": 26,  # \p{So} Other Symbol
+    "Zl": 27,  # \p{Zl} Line Separator
+    "Zp": 28,  # \p{Zp} Paragraph Separator
+    "Zs": 29,  # \p{Zs} Space Separator
 }
 
 
-codepoint_flags = array.array('H', [CODEPOINT_FLAG_UNDEFINED]) * MAX_CODEPOINTS
-table_whitespace = []
+codepoint_categs = array.array('B', [0]) * MAX_CODEPOINTS  # Undefined
 table_lowercase = []
 table_uppercase = []
 table_nfd = []
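These indices are consumed by codepoint_categ::from_index() in unicode.h and, later in this script, packed into the low 5 bits of a 16-bit run-length word. That only works because the largest index (29, "Zs") fits in 5 bits. A quick sanity sketch of the bit budget (illustrative only, constants copied from the table above):

#include <cassert>

int main() {
    const int max_index = 29;       // "Zs" in the table above
    assert(max_index < (1 << 5));   // 5-bit category index
    assert((0xFFFF >> 5) == 2047);  // 11 bits of a uint16_t remain for the run length
    return 0;
}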
@@ -105,7 +94,7 @@ for (cpt, cpt_lower, cpt_upper, categ, bidir) in unicode_data_iter():
     char = chr(cpt)
 
     # codepoint category flags
-    codepoint_flags[cpt] = UNICODE_CATEGORY_TO_FLAG[categ]
+    codepoint_categs[cpt] = UNICODE_CATEGORY_TO_INDEX[categ]
 
     # lowercase conversion
     if cpt_lower:
@@ -121,25 +110,31 @@ for (cpt, cpt_lower, cpt_upper, categ, bidir) in unicode_data_iter():
         table_nfd.append((cpt, norm))
 
 
-# whitespaces, see "<White_Space>" https://www.unicode.org/Public/UCD/latest/ucd/PropList.txt
-table_whitespace.extend(range(0x0009, 0x000D + 1))
-table_whitespace.extend(range(0x2000, 0x200A + 1))
-table_whitespace.extend([0x0020, 0x0085, 0x00A0, 0x1680, 0x2028, 0x2029, 0x202F, 0x205F, 0x3000])
-
-
 # sort by codepoint
-table_whitespace.sort()
 table_lowercase.sort()
 table_uppercase.sort()
 table_nfd.sort()
 
 
-# group ranges with same flags
-ranges_flags: list[tuple[int, int]] = [(0, codepoint_flags[0])]  # start, flags
-for codepoint, flags in enumerate(codepoint_flags):
-    if flags != ranges_flags[-1][1]:
-        ranges_flags.append((codepoint, flags))
-ranges_flags.append((MAX_CODEPOINTS, 0x0000))
+# whitespaces, see "<White_Space>" https://www.unicode.org/Public/UCD/latest/ucd/PropList.txt
+whitespace_ranges: list[tuple[int, int]] = []  # start, last
+whitespace_ranges.append((0x0009, 0x000D))
+whitespace_ranges.append((0x2000, 0x200A))
+for whitespace in [0x0020, 0x0085, 0x00A0, 0x1680, 0x2028, 0x2029, 0x202F, 0x205F, 0x3000]:
+    whitespace_ranges.append((whitespace, whitespace))
+
+
+# run length encoding, see unicode_cpt_category() in unicode.cpp
+assert (max(UNICODE_CATEGORY_TO_INDEX.values()) < 32)
+codepoint_categs_runs = [codepoint_categs[0]]  # 5 bits categ + 11 bits length
+for cpt, categ in enumerate(codepoint_categs[1:], 1):
+    prev = codepoint_categs_runs[-1]
+    if prev <= (0xFFFF - 32) and (prev & 31) == categ:
+        codepoint_categs_runs[-1] += 32  # increment run length
+    else:
+        codepoint_categs_runs.append(categ)  # new run value
+assert (codepoint_categs_runs[-1] < 0xFFFF)
+assert (MAX_CODEPOINTS == sum((rle >> 5) + 1 for rle in codepoint_categs_runs))
 
 
 # group ranges with same nfd
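Each 16-bit run word produced above packs a 5-bit category index and an 11-bit (length - 1) count. A standalone sketch of the decode side, using made-up run data rather than the real generated table:

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // two runs: 3 codepoints of category index 1 (Cc), then 1 of index 0 (Cn)
    // encoded word = (run_length - 1) << 5 | categ_index
    const std::vector<uint16_t> rle = { (2 << 5) | 1, (0 << 5) | 0 };
    uint32_t cpt = 0;
    for (uint16_t word : rle) {
        const uint32_t index = word & 31;  // low 5 bits: category index
        const uint32_t count = word >> 5;  // high 11 bits: run length - 1
        for (uint32_t i = 0; i <= count; ++i) {
            std::printf("codepoint %u -> categ index %u\n", cpt++, index);
        }
    }
    return 0;
}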
@@ -153,7 +148,7 @@ for codepoint, norm in table_nfd:
 
 
 # Generate 'unicode-data.cpp':
-#   python ./scripts//gen-unicode-data.py > unicode-data.cpp
+#   python ./scripts//gen-unicode-data.py > ./src/unicode-data.cpp
 
 def out(line=""):
     print(line, end='\n')  # noqa
@@ -167,17 +162,16 @@ out("""\
 #include <cstdint>
 #include <vector>
 #include <unordered_map>
 #include <unordered_set>
 """)
 
-out("const std::vector<std::pair<uint32_t, uint16_t>> unicode_ranges_flags = {  // start, flags // last=next_start-1")
-for codepoint, flags in ranges_flags:
-    out("{0x%06X, 0x%04X}," % (codepoint, flags))
+out("const std::vector<uint16_t> unicode_rle_codepoints_categs = {  // run length encoding, 5 bits categ + 11 bits length")
+for rle in codepoint_categs_runs:
+    out("0x%04X," % rle)
 out("};\n")
 
-out("const std::unordered_set<uint32_t> unicode_set_whitespace = {")
-for codepoint in table_whitespace:
-    out("0x%06X," % codepoint)
+out("const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_whitespace = {")
+for (start, last) in whitespace_ranges:
+    out("{0x%06X, 0x%06X}," % (start, last))
 out("};\n")
 
 out("const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase = {")
@@ -444,10 +444,8 @@ struct llm_tokenizer_bpe {
                 };
                 break;
             case LLAMA_VOCAB_PRE_TYPE_TEKKEN:
-                // original regex from tokenizer.json
-                // "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
                 regex_exprs = {
-                    "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+                    "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
                 };
                 break;
             default:
@@ -704,22 +702,21 @@ struct llm_tokenizer_wpm {
         std::vector<std::string> words(1, "");
 
         for (const uint32_t cpt : cpts_nfd) {
-            const auto flags = unicode_cpt_flags(cpt);
+            const auto categ = unicode_cpt_category(cpt);
 
-            if (flags.is_whitespace) {
+            if (categ.is_whitespace()) {
                 if (words.back().size()) {  // finish previous word if any
                     words.emplace_back();
                 }
                 continue;
             }
 
-            assert (!flags.is_separator);
-            if (cpt == 0 || cpt == 0xFFFD || flags.is_control) {
+            if (cpt == 0 || cpt == 0xFFFD || categ.is_C()) {
                 continue;
             }
 
             const std::string s = unicode_cpt_to_utf8(unicode_tolower(cpt));
-            if (flags.is_punctuation || ( cpt < 0x7F && flags.is_symbol ) || is_chinese_char(cpt)) {
+            if (categ.is_P() || (cpt < 0x7F && categ.is_S()) || is_chinese_char(cpt)) {
                 if (words.back().size()) {  // finish previous word if any
                     words.emplace_back();
                 }
@@ -737,7 +734,7 @@ struct llm_tokenizer_wpm {
         return words;
     }
 
-    static bool is_chinese_char(uint32_t cpt) {
+    static bool is_chinese_char(uint32_t cpt) {  //TODO: move to unicode-data.cpp? unicode_cpt_category(cpt).is_chinese()?
         return
             (cpt >= 0x04E00 && cpt <= 0x09FFF) ||
             (cpt >= 0x03400 && cpt <= 0x04DBF) ||
src/unicode-data.cpp (6832 lines changed; diff suppressed because it is too large)
src/unicode-data.h

@@ -3,7 +3,6 @@
 #include <cstdint>
 #include <vector>
 #include <unordered_map>
-#include <unordered_set>
 
 struct range_nfd {
     uint32_t first;

@@ -13,8 +12,8 @@ struct range_nfd {
 
 static const uint32_t MAX_CODEPOINTS = 0x110000;
 
-extern const std::vector<std::pair<uint32_t, uint16_t>> unicode_ranges_flags;
-extern const std::unordered_set<uint32_t> unicode_set_whitespace;
+extern const std::vector<uint16_t> unicode_rle_codepoints_categs;
+extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_whitespace;
 extern const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase;
 extern const std::unordered_map<uint32_t, uint32_t> unicode_map_uppercase;
 extern const std::vector<range_nfd> unicode_ranges_nfd;
src/unicode.cpp (625 lines changed)
@@ -2,10 +2,10 @@
 #define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
 #endif
 
+#include "ggml.h"
 #include "unicode.h"
 #include "unicode-data.h"
 
 #include <cassert>
 #include <cstddef>
 #include <cstdint>
 #include <map>
@@ -119,38 +119,6 @@ uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) {
 //     return result;
 //}
 
-static std::vector<codepoint_flags> unicode_cpt_flags_array() {
-    std::vector<codepoint_flags> cpt_flags(MAX_CODEPOINTS, codepoint_flags::UNDEFINED);
-
-    assert (unicode_ranges_flags.front().first == 0);
-    assert (unicode_ranges_flags.back().first == MAX_CODEPOINTS);
-    for (size_t i = 1; i < unicode_ranges_flags.size(); ++i) {
-        const auto range_ini = unicode_ranges_flags[i-1];  // codepoint_ini, flags
-        const auto range_end = unicode_ranges_flags[i];    // codepoint_end, flags
-        for (uint32_t cpt = range_ini.first; cpt < range_end.first; ++cpt) {
-            cpt_flags[cpt] = range_ini.second;
-        }
-    }
-
-    for (auto cpt : unicode_set_whitespace) {
-        cpt_flags[cpt].is_whitespace = true;
-    }
-
-    for (auto p : unicode_map_lowercase) {
-        cpt_flags[p.second].is_lowercase = true;
-    }
-
-    for (auto p : unicode_map_uppercase) {
-        cpt_flags[p.second].is_uppercase = true;
-    }
-
-    for (auto &range : unicode_ranges_nfd) {  // start, last, nfd
-        cpt_flags[range.nfd].is_nfd = true;
-    }
-
-    return cpt_flags;
-}
-
 static std::unordered_map<uint8_t, std::string> unicode_byte_to_utf8_map() {
     std::unordered_map<uint8_t, std::string> map;
     for (int ch = 0x21; ch <= 0x7E; ++ch) {  // u'!' to u'~'
@@ -233,7 +201,7 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
     for (auto offset : offsets) {
         const size_t offset_ini = start;
         const size_t offset_end = start + offset;
-        assert(offset_end <= cpts.size());
+        GGML_ASSERT(offset_end <= cpts.size());
         start = offset_end;
 
         static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
@@ -241,13 +209,14 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
             return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
         };
 
-        auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
-            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
+        static const codepoint_categ SENTINEL = codepoint_categ::UNDEF + 1;
+        auto _get_categ = [&] (const size_t pos) -> codepoint_categ {
+            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_category(cpts[pos]) : SENTINEL;
         };
 
         size_t _prev_end = offset_ini;
         auto _add_token = [&] (const size_t end) -> size_t {
-            assert(_prev_end <= end && end <= offset_end);
+            GGML_ASSERT(_prev_end <= end && end <= offset_end);
             size_t len = end - _prev_end;
             if (len > 0) {
                 bpe_offsets.push_back(len);
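Why UNDEF + 1 works as a sentinel: unicode_cpt_category() only returns encodings whose low 3 subcategory bits accompany a category bit, so the raw value 1 (subcategory bits only) never collides with a real category and every category predicate is false for it. A tiny sketch under that assumption (the struct here is a minimal stand-in, not the real codepoint_categ):

#include <cstdint>
#include <cassert>

// minimal stand-in for codepoint_categ, assuming the unicode.h encoding
struct categ {
    uint16_t encoded;
    bool is_L() const { return (encoded & (1 << 4)) != 0; }  // L = 1 << (1 + 3)
};

int main() {
    const categ SENTINEL = { 0 + 1 };         // codepoint_categ::UNDEF + 1
    const categ Lu       = { (1 << 4) | 5 };  // codepoint_categ::Lu = L | 5
    assert(!SENTINEL.is_L());                 // sentinel matches no category class
    assert(SENTINEL.encoded != Lu.encoded);   // and compares unequal to real categories
    return 0;
}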
@@ -264,7 +233,7 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
 
         for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
             const uint32_t cpt = _get_cpt(pos);
-            const auto flags = _get_flags(pos);
+            const auto categ = _get_categ(pos);
 
             // regex: 's|'t|'re|'ve|'m|'ll|'d
             if (cpt == '\'' && pos+1 < offset_end) {
@@ -284,37 +253,37 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
                 }
             }
 
-            auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
+            auto categ2 = (cpt == ' ' ? _get_categ(pos+1) : categ);
             // regex: <space>?\p{L}+
-            if (flags2.is_letter) {
+            if (categ2.is_L()) {
                 pos += (cpt == ' ');
-                while (flags2.is_letter) {
-                    flags2 = _get_flags(++pos);
+                while (categ2.is_L()) {
+                    categ2 = _get_categ(++pos);
                 }
                 _add_token(pos);
                 continue;
             }
             // regex: <space>?\p{N}+
-            if (flags2.is_number) {
+            if (categ2.is_N()) {
                 pos += (cpt == ' ');
-                while (flags2.is_number) {
-                    flags2 = _get_flags(++pos);
+                while (categ2.is_N()) {
+                    categ2 = _get_categ(++pos);
                 }
                 _add_token(pos);
                 continue;
             }
             // regex: <space>?[^\s\p{L}\p{N}]+
-            if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
+            if (!(categ2.is_whitespace() | categ2.is_L() | categ2.is_N()) && categ2 != SENTINEL) {
                 pos += (cpt == ' ');
-                while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
-                    flags2 = _get_flags(++pos);
+                while (!(categ2.is_whitespace() | categ2.is_L() | categ2.is_N()) && categ2 != SENTINEL) {
+                    categ2 = _get_categ(++pos);
                 }
                 _add_token(pos);
                 continue;
             }
 
             size_t num_whitespaces = 0;
-            while (_get_flags(pos+num_whitespaces).is_whitespace) {
+            while (_get_categ(pos+num_whitespaces).is_whitespace()) {
                 num_whitespaces++;
             }
@@ -351,7 +320,7 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
     for (auto offset : offsets) {
         const size_t offset_ini = start;
         const size_t offset_end = start + offset;
-        assert(offset_end <= cpts.size());
+        GGML_ASSERT(offset_end <= cpts.size());
         start = offset_end;
 
         static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
@@ -359,13 +328,14 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
             return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
         };
 
-        auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
-            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
+        static const codepoint_categ SENTINEL = codepoint_categ::UNDEF + 1;
+        auto _get_categ = [&] (const size_t pos) -> codepoint_categ {
+            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_category(cpts[pos]) : SENTINEL;
         };
 
         size_t _prev_end = offset_ini;
         auto _add_token = [&] (const size_t end) -> size_t {
-            assert(_prev_end <= end && end <= offset_end);
+            GGML_ASSERT(_prev_end <= end && end <= offset_end);
             size_t len = end - _prev_end;
             if (len > 0) {
                 bpe_offsets.push_back(len);
@@ -382,7 +352,7 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
 
         for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
             const uint32_t cpt = _get_cpt(pos);
-            const auto flags = _get_flags(pos);
+            const auto categ = _get_categ(pos);
 
             // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
             if (cpt == '\'' && pos+1 < offset_end) {
@@ -403,10 +373,10 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
             }
 
             // regex: [^\r\n\p{L}\p{N}]?\p{L}+
-            if (!(cpt == '\r' || cpt == '\n' || flags.is_number)) {
-                if (flags.is_letter || _get_flags(pos+1).is_letter) {  // one or more letters
+            if (!(cpt == '\r' || cpt == '\n' || categ.is_N())) {
+                if (categ.is_L() || _get_categ(pos+1).is_L()) {  // one or more letters
                     pos++;
-                    while (_get_flags(pos).is_letter) {
+                    while (_get_categ(pos).is_L()) {
                         pos++;
                     }
                     _add_token(pos);
@@ -415,9 +385,9 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
             }
 
             // regex: \p{N}{1,3}
-            if (flags.is_number) {
+            if (categ.is_N()) {
                 size_t ini = pos;
-                while (_get_flags(pos).is_number) {
+                while (_get_categ(pos).is_N()) {
                     if (++pos - ini >= 3 ) {
                         _add_token(pos);
                         ini = pos;
@@ -428,11 +398,11 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
             }
 
             // regex: <space>?[^\s\p{L}\p{N}]+[\r\n]*
-            auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
-            if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags.as_uint()) {
+            auto categ2 = (cpt == ' ' ? _get_categ(pos+1) : categ);
+            if (!(categ2.is_whitespace() | categ2.is_L() | categ2.is_N()) && categ2 != SENTINEL) {
                 pos += (cpt == ' ');
-                while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
-                    flags2 = _get_flags(++pos);
+                while (!(categ2.is_whitespace() | categ2.is_L() | categ2.is_N()) && categ2 != SENTINEL) {
+                    categ2 = _get_categ(++pos);
                 }
                 uint32_t cpt2 = _get_cpt(pos);
                 while (cpt2 == '\r' || cpt2 == '\n') {
@@ -444,7 +414,7 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
 
             size_t num_whitespaces = 0;
             size_t last_end_r_or_n = 0;
-            while (_get_flags(pos+num_whitespaces).is_whitespace) {
+            while (_get_categ(pos+num_whitespaces).is_whitespace()) {
                 uint32_t cpt2 = _get_cpt(pos+num_whitespaces);
                 if (cpt2 == '\r' || cpt2 == '\n') {
                     last_end_r_or_n = pos + num_whitespaces + 1;
@@ -481,66 +451,6 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
     return bpe_offsets;
 }
 
-// use std::wregex to split the text
-static std::vector<size_t> unicode_regex_split_stl(const std::wstring & wtext, const std::wstring & regex_expr, const std::vector<size_t> & offsets) {
-    std::wregex expr(regex_expr);
-    std::vector<size_t> bpe_offsets;  // store the offset of each word
-    bpe_offsets.reserve(offsets.size());  // Reserve memory for the approximate size
-    size_t start = 0;
-    for (auto offset : offsets) {
-        std::wcregex_iterator it(wtext.data() + start, wtext.data() + start + offset, expr);
-        std::wcregex_iterator end;
-
-        int64_t start_idx = 0;
-        while (it != end) {
-            std::wcmatch match = *it;
-            if (match.position() > start_idx) {
-                bpe_offsets.emplace_back(match.position() - start_idx);
-            }
-            bpe_offsets.emplace_back(match.length());
-            start_idx = match.position() + match.length();
-            ++it;
-        }
-
-        if (start_idx < (int64_t) offset) {
-            bpe_offsets.emplace_back(offset - start_idx);
-        }
-        start += offset;
-    }
-
-    return bpe_offsets;
-}
-
-// use std::regex to split the text
-static std::vector<size_t> unicode_regex_split_stl(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
-    std::regex expr(regex_expr);
-    std::vector<size_t> bpe_offsets;  // store the offset of each word
-    bpe_offsets.reserve(offsets.size());  // Reserve memory for the approximate size
-    size_t start = 0;
-    for (auto offset : offsets) {
-        std::cregex_iterator it(text.data() + start, text.data() + start + offset, expr);
-        std::cregex_iterator end;
-
-        int64_t start_idx = 0;
-        while (it != end) {
-            std::cmatch match = *it;
-            if (match.position() > start_idx) {
-                bpe_offsets.emplace_back(match.position() - start_idx);
-            }
-            bpe_offsets.emplace_back(match.length());
-            start_idx = match.position() + match.length();
-            ++it;
-        }
-
-        if (start_idx < (int64_t) offset) {
-            bpe_offsets.emplace_back(offset - start_idx);
-        }
-        start += offset;
-    }
-
-    return bpe_offsets;
-}
-
 static std::vector<size_t> unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
     std::vector<size_t> bpe_offsets;
@@ -556,6 +466,269 @@ static std::vector<size_t> unicode_regex_split_custom(const std::string & text,
     return bpe_offsets;
 }
 
+// Custom std::regex specializations for 32bit unicode codepoints
+// std::wregex does not support unicode categories: \p{N}, \p{L}, \p{Lu}, \p{Ll} ...
+// std::wregex does not support unicode whitespaces \s: 0x85, 0xA0, 0x001680 ... 0x003000.
+// std::wregex supports full 32 bit codepoints, not limited to standard max 0x110000.
+namespace std {
+
+// codepoint type for all template specializations
+#if (WCHAR_MAX > 0xFFFF)
+using codepoint = wchar_t;   // sizeof(wchar_t) == 4
+#else
+using codepoint = uint32_t;  // Windows: sizeof(wchar_t) == 2
+#define CUSTOM_CTYPE_CODEPOINT
+#endif
+
+#ifdef CUSTOM_CTYPE_CODEPOINT
+// Minimal required implementation for std::regex string processing
+template<>  // custom specialized std::ctype<codepoint>
+class ctype<codepoint> {
+    public:
+
+        using CharT = codepoint;
+        using char_type = CharT;
+
+        using mask = uint8_t;          //NOTE: see std::ctype_base
+        static const mask digit  = 1;  // required variable names
+        static const mask xdigit = 2;  // user defined values
+        static const mask alpha  = 3;  // used to be a bitmask
+        static const mask upper  = 4;  // we do not need a bitmask
+        static const mask lower  = 5;  // using a sequence instead
+
+        static locale::id id;  // required by std::locale::facet
+
+        bool is(mask m, char_type c) const {
+            switch (m) {
+                case digit:  return ('0' <= c && c <= '9');
+                case xdigit: return ('0' <= c && c <= '9') || ('A' <= c && c <= 'F');
+                case alpha:  return ('A' <= c && c <= 'Z') || ('a' <= c && c <= 'z');
+                case upper:  return ('A' <= c && c <= 'Z');
+                case lower:  return ('a' <= c && c <= 'z');
+                default:     return false;
+            }
+        }
+
+        char_type toupper(char_type c) const {
+            return ('a' <= c && c <= 'z') ? c - ('a' - 'A') : c;
+        }
+
+        char_type tolower(char_type c) const {
+            return ('A' <= c && c <= 'Z') ? c + ('a' - 'A') : c;
+        }
+
+        char_type widen(char c) const {  // char to codepoint
+            return (char_type) c;
+        }
+
+        char narrow(char_type c, char dfault) const {  // codepoint to char
+            return (c < 0x80 ? (char) c : dfault);
+        }
+};
+
+locale::id ctype<codepoint>::id = {};
+
+template<>  // specialization to use our custom specialized std::ctype<codepoint>
+const std::ctype<codepoint> & use_facet<std::ctype<codepoint>>(const std::locale &) {
+    static std::ctype<codepoint> ctype_uint32 = {};
+    return ctype_uint32;
+}
+
+template<>  // specialization to use our custom specialized std::ctype<codepoint>
+const std::ctype<codepoint> & use_facet<const std::ctype<codepoint>>(const std::locale & loc) {
+    return use_facet<std::ctype<codepoint>>(loc);
+}
+#endif
+
+// Minimal required implementation for std::regex string processing
+template<>  // custom specialized std::regex_traits<codepoint>
+class regex_traits<codepoint> {
+    public:
+
+        using CharT = codepoint;
+        using char_type = codepoint;
+        using size_type = size_t;
+        using string_type = std::basic_string<CharT>;
+        using locale_type = std::locale;
+        using char_class_type = uint64_t;
+
+        #if (defined(_WIN32) || defined(_WIN64))  // MSVC class _Regex_traits
+        using _Uelem = CharT;
+        static const auto _Ch_upper = std::ctype<CharT>::upper;
+        static const auto _Ch_alpha = std::ctype<CharT>::alpha;
+        #endif
+
+        CharT translate(CharT c) const {
+            return c;
+        }
+
+        CharT translate_nocase(CharT c) const {
+            return unicode_tolower(c);
+        }
+
+        template<typename It>
+        string_type transform(It first, It last) const {
+            GGML_ASSERT(false);    //TODO: not needed ?
+            return {first, last};  //TODO: not tested
+        }
+
+        template<typename It>
+        string_type transform_primary(It first, It last) const {
+            (void) first;
+            (void) last;
+            GGML_ASSERT((uint32_t) *first < MAX_CODEPOINTS);  // check valid codepoint
+            return {};
+        }
+
+        template<typename It>
+        string_type lookup_collatename(It first, It last) const {
+            (void) last;
+            GGML_ASSERT(*first & (1 << 31));
+            return {*first};
+        }
+
+        template<typename It>
+        char_class_type lookup_classname(It first, It last, bool icase = false) const {
+            (void) last;
+            (void) icase;
+            const uint32_t encoded = *first;
+            codepoint_categ categ = {};
+            switch (encoded) {
+                case 's':
+                case 'S':  // negation is internally tracked
+                    categ.set_flag(codepoint_categ::WHITESPACES);
+                    return categ.expand_bits();
+                case 'w':
+                case 'W':  // negation is internally tracked
+                    categ.set_flag(codepoint_categ::WORDS);
+                    return categ.expand_bits();
+                case 'd':
+                case 'D':  // negation is internally tracked
+                    categ.set_flag(codepoint_categ::DIGITS);
+                    return categ.expand_bits();
+                default: {  // unicode category \p{Xx} encoded in codepoint
+                    GGML_ASSERT(encoded & (1 << 31));  // make sure it's our custom codepoint encoding the category
+                    const bool negated = encoded & (1 << 30);  // negation of 'character class expression' is not internally tracked
+                    categ = {(uint16_t) encoded};
+                    return ((uint64_t) negated << 63) | categ.expand_bits(false);
+                }
+            }
+        }
+
+        bool isctype(CharT c, char_class_type mask) const {
+            const bool negated = mask & (1llu << 63);
+            mask &= unicode_cpt_category(c).expand_bits();
+            return negated ^ (bool) mask;
+        }
+
+        int value(CharT c, int radix) const {  // char to int value
+            switch (radix) {
+                case 8:  return ('0' <= c && c <= '7') ? (int) c - '0' : -1;
+                case 10: return ('0' <= c && c <= '9') ? (int) c - '0' : -1;
+                case 16: return ('0' <= c && c <= '9') ? (int) c - '0' : (('A' <= c && c <= 'F') ? (int) c - 'A' + 10 : -1);
+                default: return -1;
+            }
+        }
+
+        const locale_type & imbue(const locale_type &) {  // set locale  //NOTE: ignoring locales
+            return std::locale::classic();
+        }
+
+        const locale_type & getloc() const {  // get locale  //NOTE: ignoring locales
+            return std::locale::classic();
+        }
+};
+}
+
+static std::vector<uint32_t> unicode_regex_prepare(const std::string & regex) {
+    std::vector<uint32_t> regex_cpts;
+    regex_cpts.reserve(regex.size() * 12 / 10);  // estimate +20%
+
+    size_t offset = 0;
+    int inside_square = 0;
+    bool any_positive = false;
+    bool any_negative = false;
+
+    const size_t size = regex.size();
+    while (offset < size) {
+        inside_square += regex[offset] == '[';
+        inside_square -= regex[offset] == ']';
+        GGML_ASSERT(inside_square >= 0);
+        if (!inside_square) {
+            any_positive = false;
+            any_negative = false;
+        }
+
+        if (regex[offset] == '\\') {
+            const size_t i = offset + 1;
+            if (regex[i] == 'p' || regex[i] == 'P') {
+                // convert \p{Xx} to custom 'character class expression' [:Xy:]
+                if (regex[i + 1] == '{' && regex[i + 2] && regex[i + 3]) {
+                    codepoint_categ categ = {};
+                    if (regex[i + 3] == '}') {
+                        categ = codepoint_categ::from_chars(regex[i + 2]);
+                        offset += 5;
+                    } else if (regex[i + 3] != '}' && regex[i + 4] == '}') {
+                        categ = codepoint_categ::from_chars(regex[i + 2], regex[i + 3]);
+                        offset += 6;
+                    }
+                    bool negated = regex[i] == 'P';
+                    any_positive |= !negated;
+                    any_negative |= negated;
+                    GGML_ASSERT(any_positive != any_negative);  //BUG: can not mix 'p' and 'P' inside []
+                    GGML_ASSERT(sizeof(categ) <= 2);
+                    // encoded category in 32 bits codepoint
+                    uint32_t cpt_categ = (1 << 31) | (negated << 30) | categ.encoded;
+                    if (inside_square) {
+                        regex_cpts.insert(regex_cpts.end(), {'[', ':', cpt_categ, ':', ']'});
+                    } else {
+                        regex_cpts.insert(regex_cpts.end(), {'[', '[', ':', cpt_categ, ':', ']', ']'});
+                    }
+                    continue;
+                }
+            }
+        }
+
+        regex_cpts.push_back(unicode_cpt_from_utf8(regex, offset));
+    }
+
+    return regex_cpts;
+}
+
+// use std::basic_regex<uint32_t> to split the text codepoints
+static std::vector<size_t> unicode_regex_split_stl(const std::vector<uint32_t> & text_cpts, const std::vector<uint32_t> & regex_cpts, const std::vector<size_t> & offsets) {
+    GGML_ASSERT(sizeof(std::codepoint) == sizeof(uint32_t));
+    using regex_type = std::basic_regex<std::codepoint>;
+    using iter_type = std::regex_iterator<const std::codepoint *>;
+
+    const std::codepoint * text_data = (const std::codepoint *) text_cpts.data();
+    const std::codepoint * regex_data = (const std::codepoint *) regex_cpts.data();
+    regex_type regex(regex_data, regex_data + regex_cpts.size());
+    const iter_type end;
+
+    std::vector<size_t> bpe_offsets;  // store the offset of each word
+    bpe_offsets.reserve(offsets.size());  // reserve memory for the approximate size
+    for (auto offset : offsets) {
+        iter_type it(text_data, text_data + offset, regex);
+        int64_t start_idx = 0;
+        while (it != end) {
+            if (it->position() > start_idx) {
+                bpe_offsets.emplace_back(it->position() - start_idx);
+            }
+            bpe_offsets.emplace_back(it->length());
+            start_idx = it->position() + it->length();
+            ++it;
+        }
+
+        if (start_idx < (int64_t) offset) {
+            bpe_offsets.emplace_back(offset - start_idx);
+        }
+        text_data += offset;
+    }
+
+    return bpe_offsets;
+}
+
 //
 // interface
 //
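To see how unicode_regex_prepare() and lookup_classname() above hand a category through std::basic_regex, here is a sketch of the 32-bit encoding for \p{Lu}: bit 31 marks a custom "codepoint", bit 30 carries negation, and the low 16 bits are the codepoint_categ value. The constants mirror the diff; the program itself is illustrative only:

#include <cstdint>
#include <cstdio>

int main() {
    const uint16_t categ_Lu = (1 << (1 + 3)) | 5;  // codepoint_categ::Lu = L | 5
    const bool negated = false;                    // \P{Lu} would set this
    const uint32_t encoded = (1u << 31) | ((uint32_t) negated << 30) | categ_Lu;
    std::printf("\\p{Lu} -> 0x%08X\n", encoded);   // 0x80000015, far above MAX_CODEPOINTS
    return 0;
}

Because the encoded value can never collide with a real codepoint (all real codepoints are below 0x110000), the regex engine treats it like any other character inside the injected [[: :]] class, and isctype() resolves it against each text codepoint's expanded category bits.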
@@ -612,19 +785,46 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
     return result;
 }
 
-codepoint_flags unicode_cpt_flags(const uint32_t cp) {
-    static const codepoint_flags undef(codepoint_flags::UNDEFINED);
-    static const auto cpt_flags = unicode_cpt_flags_array();
-    return cp < cpt_flags.size() ? cpt_flags[cp] : undef;
+codepoint_categ unicode_cpt_category(const uint32_t cp) {
+    static const std::vector<codepoint_categ> cpt_categs = [] {
+        std::vector<codepoint_categ> cpt_categs(MAX_CODEPOINTS, codepoint_categ::UNDEF);
+        uint32_t cpt = 0;
+        for (uint16_t rle : unicode_rle_codepoints_categs) {
+            const uint32_t index = rle & 31;
+            const uint32_t count = rle >> 5;
+            auto categ = codepoint_categ::from_index(index);
+            //printf("Codepoints 0x%05X to 0x%05X categ %s\n", cpt, cpt + count, categ.c_str());
+            categ.set_flag(codepoint_categ::DIGITS, categ.is_Nd());               // \d --> \p{Nd}
+            categ.set_flag(codepoint_categ::WORDS, categ.is_L() | categ.is_N());  // \w --> \p{L} \p{N} _
+            for (uint32_t i = 0; i <= count; ++i) {
+                cpt_categs[cpt++] = categ;
+            }
+        }
+        GGML_ASSERT(cpt == MAX_CODEPOINTS);
+
+        cpt_categs['_'].set_flag(codepoint_categ::WORDS);  // \w --> \p{L} \p{N} _
+
+        for (auto p : unicode_ranges_whitespace) {
+            for (uint32_t cpt = p.first; cpt <= p.second; ++cpt) {
+                cpt_categs[cpt].set_flag(codepoint_categ::WHITESPACES);
+            }
+        }
+
+        //for (auto &range : unicode_ranges_nfd) {  // start, last, nfd
+        //    cpt_categs[cpt].set_flag(codepoint_categ::NORM_NFD);
+        //}
+
+        return cpt_categs;
+    }();
+    return cp < cpt_categs.size() ? cpt_categs[cp] : codepoint_categ{};
 }
 
-codepoint_flags unicode_cpt_flags(const std::string & utf8) {
-    static const codepoint_flags undef(codepoint_flags::UNDEFINED);
+codepoint_categ unicode_cpt_category(const std::string & utf8) {
     if (utf8.empty()) {
-        return undef;  // undefined
+        return codepoint_categ{};  // undefined
     }
     size_t offset = 0;
-    return unicode_cpt_flags(unicode_cpt_from_utf8(utf8, offset));
+    return unicode_cpt_category(unicode_cpt_from_utf8(utf8, offset));
 }
 
 std::string unicode_byte_to_utf8(uint8_t byte) {
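The table above is built inside an immediately-invoked lambda that initializes a function-local static, so the ~1.1M-entry vector is decoded from the RLE data once, on first call, with thread-safe initialization guaranteed since C++11. A minimal sketch of the same pattern with stand-in data:

#include <cstdio>
#include <vector>

static int lazy_lookup(int i) {
    // built exactly once, on first call; later calls only index the table
    static const std::vector<int> table = [] {
        std::vector<int> t(100);
        for (int k = 0; k < 100; ++k) {
            t[k] = k * k;  // stand-in for decoding the RLE category data
        }
        return t;
    }();
    return i < (int) table.size() ? table[i] : 0;
}

int main() {
    std::printf("%d\n", lazy_lookup(7));  // prints 49
    return 0;
}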
@@ -642,171 +842,28 @@ uint32_t unicode_tolower(uint32_t cp) {
     return it == unicode_map_lowercase.end() ? cp : it->second;
 }
 
-std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
-    // unicode categories
-    static const std::map<std::string, int> k_ucat_enum = {
-        { "\\p{N}", codepoint_flags::NUMBER },
-        { "\\p{L}", codepoint_flags::LETTER },
-        { "\\p{P}", codepoint_flags::PUNCTUATION },
-    };
-
-    static const std::map<int, int> k_ucat_cpt = {
-        { codepoint_flags::NUMBER,      0xD1 },
-        { codepoint_flags::LETTER,      0xD2 },
-        { codepoint_flags::PUNCTUATION, 0xD3 },
-    };
-
-    static const std::map<int, std::string> k_ucat_map = {
-        { codepoint_flags::NUMBER,      "\x30-\x39" },          // 0-9
-        { codepoint_flags::LETTER,      "\x41-\x5A\x61-\x7A" }, // A-Za-z
-        { codepoint_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
-    };
-
-    // compute collapsed codepoints only if needed by at least one regex
-    bool need_collapse = false;
-    for (auto & regex_expr : regex_exprs) {
-        // search for unicode categories
-        for (const auto & ucat : k_ucat_enum) {
-            if (std::string::npos != regex_expr.find(ucat.first)) {
-                need_collapse = true;
-                break;
-            }
-        }
-    }
-
-    const auto cpts = unicode_cpts_from_utf8(text);
-
-    // generate a "collapsed" representation of the text, where all codepoints are replaced by a single byte
-    // ref: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2081479935
-    std::string text_collapsed;
-    if (need_collapse) {
-        // collapse all unicode categories
-        text_collapsed.resize(cpts.size());
-
-        for (size_t i = 0; i < cpts.size(); ++i) {
-            // keep single-byte codepoints as is
-            if (cpts[i] < 128) {
-                text_collapsed[i] = cpts[i];
-                continue;
-            }
-
-            const auto flags = unicode_cpt_flags(cpts[i]);
-
-            if (flags.is_whitespace) {
-                //NOTE: C++ std::regex \s does not match 0x85, Rust and Python regex does.
-                //text_collapsed[i] = (char) 0x85;  // <Next Line> as whitespace fallback
-                text_collapsed[i] = (char) 0x0B;    // <vertical tab> as whitespace fallback
-            } else if (k_ucat_cpt.find(flags.category_flag()) != k_ucat_cpt.end()) {
-                text_collapsed[i] = k_ucat_cpt.at(flags.category_flag());
-            } else {
-                text_collapsed[i] = (char) 0xD0;  // fallback
-            }
-        }
-    }
-
-    std::vector<size_t> bpe_offsets = { cpts.size() };
+std::vector<std::string> unicode_regex_split(const std::string & text_utf8, const std::vector<std::string> & regex_exprs) {
+    const std::vector<uint32_t> cpts = unicode_cpts_from_utf8(text_utf8);
+    std::vector<size_t> offsets = { cpts.size() };
 
     for (auto & regex_expr : regex_exprs) {
         // first, see if we have an efficient custom regex implementation
-        auto tmp = unicode_regex_split_custom(text, regex_expr, bpe_offsets);
+        auto tmp = unicode_regex_split_custom(text_utf8, regex_expr, offsets);
 
         if (!tmp.empty()) {
-            bpe_offsets = std::move(tmp);
+            offsets = std::move(tmp);
             continue;
         }
 
-        // fallback to general-purpose std::regex / std::wregex
-        try {
-            // if a unicode category is used in the regex, we use the collapsed text and replace the unicode category
-            // with the corresponding collapsed representation
-            bool use_collapsed = false;
-            for (auto & ucat : k_ucat_enum) {
-                if (std::string::npos != regex_expr.find(ucat.first)) {
-                    use_collapsed = true;
-                    break;
-                }
-            }
-
-            if (use_collapsed) {
-                // sanity-check that the original regex does not contain any non-ASCII characters
-                const auto cpts_regex = unicode_cpts_from_utf8(regex_expr);
-                for (size_t i = 0; i < cpts_regex.size(); ++i) {
-                    if (cpts_regex[i] >= 128) {
-                        throw std::runtime_error("Regex includes both unicode categories and non-ASCII characters - not supported");
-                    }
-                }
-
-                // generate a collapsed representation of the regex
-                std::string regex_expr_collapsed;
-
-                // track if we are inside [], because nested [] are not allowed
-                bool inside = false;
-                for (size_t i = 0; i < regex_expr.size(); ++i) {
-                    if (regex_expr[i] == '[' && (i == 0 || regex_expr[i - 1] != '\\')) {
-                        regex_expr_collapsed += '[';
-                        inside = true;
-                        continue;
-                    }
-
-                    if (inside && regex_expr[i] == ']' && regex_expr[i - 1] != '\\') {
-                        regex_expr_collapsed += ']';
-                        inside = false;
-                        continue;
-                    }
-
-                    if (regex_expr[i + 0] == '\\' && i + 4 < regex_expr.size() &&
-                        regex_expr[i + 1] == 'p' &&
-                        regex_expr[i + 2] == '{' &&
-                        regex_expr[i + 4] == '}') {
-                        const std::string pat = regex_expr.substr(i, 5);
-                        if (k_ucat_enum.find(pat) != k_ucat_enum.end()) {
-                            if (!inside) {
-                                regex_expr_collapsed += '[';
-                            }
-                            regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat));
-                            regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat));
-                            if (!inside) {
-                                regex_expr_collapsed += ']';
-                            }
-                            i += 4;
-                            continue;
-                        }
-                    }
-
-                    regex_expr_collapsed += regex_expr[i];
-                }
-
-                //printf("text_collapsed: %s\n", text_collapsed.c_str());
-                //printf("regex_expr_collapsed: %s\n", regex_expr_collapsed.c_str());
-                bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets);
-            } else {
-                // no unicode category used, we can use std::wregex directly
-                const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr);
-
-                // std::wregex \s does not match non-ASCII whitespaces, using 0x0B as fallback
-                std::wstring wtext(cpts.begin(), cpts.end());
-                for (size_t i = 0; i < wtext.size(); ++i) {
-                    if (wtext[i] > 0x7F && unicode_cpt_flags(wtext[i]).is_whitespace) {
-                        wtext[i] = 0x0B;
-                    }
-                }
-
-                //printf("text: %s\n", text.c_str());
-                //printf("regex_expr: %s\n", regex_expr.c_str());
-                bpe_offsets = unicode_regex_split_stl(wtext, wregex_expr, bpe_offsets);
-            }
-        } catch (std::regex_error & e) {
-            fprintf(stderr, "Failed to process regex: '%s'\n", regex_expr.c_str());
-            fprintf(stderr, "Regex error: %s\n", e.what());
-            throw std::runtime_error("Failed to process regex");
-        }
+        const auto regex_cpts = unicode_regex_prepare(regex_expr);
+        offsets = unicode_regex_split_stl(cpts, regex_cpts, offsets);
     }
 
     std::vector<std::string> bpe_words;
-    bpe_words.reserve(bpe_offsets.size());  // reserve memory for the approximate size
+    bpe_words.reserve(offsets.size());  // reserve memory for the approximate size
 
     size_t start = 0;
-    for (size_t & offset : bpe_offsets) {
+    for (size_t & offset : offsets) {
         bpe_words.emplace_back();
         for (size_t i = start; i < start + offset; ++i) {
             bpe_words.back() += unicode_cpt_to_utf8(cpts[i]);
src/unicode.h (206 lines changed)
@@ -1,51 +1,185 @@
 #pragma once
 
 #include <cstdint>
+#include <cassert>
+#include <cstring>
 #include <string>
 #include <vector>
+#include <array>
+#include <map>
 
 // TODO: prefix all symbols with "llama_"
 
-struct codepoint_flags {
-    enum {
-        UNDEFINED       = 0x0001,
-        NUMBER          = 0x0002,  // regex: \p{N}
-        LETTER          = 0x0004,  // regex: \p{L}
-        SEPARATOR       = 0x0008,  // regex: \p{Z}
-        ACCENT_MARK     = 0x0010,  // regex: \p{M}
-        PUNCTUATION     = 0x0020,  // regex: \p{P}
-        SYMBOL          = 0x0040,  // regex: \p{S}
-        CONTROL         = 0x0080,  // regex: \p{C}
-        MASK_CATEGORIES = 0x00FF,
+struct codepoint_categ {
+    // 0bffffff'ccccccc'sss --> 6 bits flags + 7 bits category + 3 bits subcategory
+    enum _category : uint16_t {
+        UNDEF = 0,         // \p{Cn} Undefined
+        C = 1 << (0 + 3),  // \p{C} Control
+        L = 1 << (1 + 3),  // \p{L} Letter
+        M = 1 << (2 + 3),  // \p{M} Mark
+        N = 1 << (3 + 3),  // \p{N} Number
+        P = 1 << (4 + 3),  // \p{P} Punctuation
+        S = 1 << (5 + 3),  // \p{S} Symbol
+        Z = 1 << (6 + 3),  // \p{Z} Separator
+        Cc = C | 1,  // \p{Cc} Control
+        Cf = C | 2,  // \p{Cf} Format
+        Co = C | 3,  // \p{Co} Private Use
+        Cs = C | 4,  // \p{Cs} Surrogate
+        Ll = L | 1,  // \p{Ll} Lowercase Letter
+        Lm = L | 2,  // \p{Lm} Modifier Letter
+        Lo = L | 3,  // \p{Lo} Other Letter
+        Lt = L | 4,  // \p{Lt} Titlecase Letter
+        Lu = L | 5,  // \p{Lu} Uppercase Letter
+        Mc = M | 1,  // \p{Mc} Spacing Mark
+        Me = M | 2,  // \p{Me} Enclosing Mark
+        Mn = M | 3,  // \p{Mn} Nonspacing Mark
+        Nd = N | 1,  // \p{Nd} Decimal Number
+        Nl = N | 2,  // \p{Nl} Letter Number
+        No = N | 3,  // \p{No} Other Number
+        Pc = P | 1,  // \p{Pc} Connector Punctuation
+        Pd = P | 2,  // \p{Pd} Dash Punctuation
+        Pe = P | 3,  // \p{Pe} Close Punctuation
+        Pf = P | 4,  // \p{Pf} Final Punctuation
+        Pi = P | 5,  // \p{Pi} Initial Punctuation
+        Po = P | 6,  // \p{Po} Other Punctuation
+        Ps = P | 7,  // \p{Ps} Open Punctuation
+        Sc = S | 1,  // \p{Sc} Currency Symbol
+        Sk = S | 2,  // \p{Sk} Modifier Symbol
+        Sm = S | 3,  // \p{Sm} Math Symbol
+        So = S | 4,  // \p{So} Other Symbol
+        Zl = Z | 1,  // \p{Zl} Line Separator
+        Zp = Z | 2,  // \p{Zp} Paragraph Separator
+        Zs = Z | 3,  // \p{Zs} Space Separator
+        SUBMASK = (1 << 3) - 1,   // 3 bits    0b000000'0000000'111
+        MASK    = (1 << 10) - 1,  // 7+3 bits  0b000000'1111111'111
     };
 
-    // codepoint type
-    uint16_t is_undefined   : 1;
-    uint16_t is_number      : 1;  // regex: \p{N}
-    uint16_t is_letter      : 1;  // regex: \p{L}
-    uint16_t is_separator   : 1;  // regex: \p{Z}
-    uint16_t is_accent_mark : 1;  // regex: \p{M}
-    uint16_t is_punctuation : 1;  // regex: \p{P}
-    uint16_t is_symbol      : 1;  // regex: \p{S}
-    uint16_t is_control     : 1;  // regex: \p{C}
-    // helper flags
-    uint16_t is_whitespace  : 1;  // regex: \s
-    uint16_t is_lowercase   : 1;
-    uint16_t is_uppercase   : 1;
-    uint16_t is_nfd         : 1;
+    enum _flags : uint16_t {
+        WHITESPACES = (1 << 10),  // regex: \s
+        WORDS       = (1 << 11),  // regex: \w
+        DIGITS      = (1 << 12),  // regex: \d
+        //Norm NFD/NFC = ...,
+    };
 
-    // decode from uint16
-    inline codepoint_flags(const uint16_t flags=0) {
-        *reinterpret_cast<uint16_t*>(this) = flags;
+    inline codepoint_categ(const uint16_t categ=0) : encoded{categ} {}
+
+    inline void set_flag(_flags flags, bool value = true) {
+        flags = (_flags) (flags & ~MASK);  // do not modify category bits
+        encoded = value ? (encoded | flags) : (encoded & ~flags);
     }
 
-    inline uint16_t as_uint() const {
-        return *reinterpret_cast<const uint16_t*>(this);
+    inline uint16_t get_category() const { return encoded & MASK; }
+
+    inline bool is_undefined() const { return !encoded; }
+    inline bool is_defined() const { return encoded; }
+
+    inline uint16_t is_whitespace() const { return encoded & WHITESPACES; }
+    inline uint16_t is_word() const { return encoded & WORDS; }
+    inline uint16_t is_digit() const { return encoded & DIGITS; }
+
+    inline uint16_t is_C() const { return encoded & C; }
+    inline uint16_t is_L() const { return encoded & L; }
+    inline uint16_t is_M() const { return encoded & M; }
+    inline uint16_t is_N() const { return encoded & N; }
+    inline uint16_t is_P() const { return encoded & P; }
+    inline uint16_t is_S() const { return encoded & S; }
+    inline uint16_t is_Z() const { return encoded & Z; }
+
+    inline bool is_Cc() const { return (encoded & MASK) == Cc; }
+    inline bool is_Cf() const { return (encoded & MASK) == Cf; }
+    inline bool is_Co() const { return (encoded & MASK) == Co; }
+    inline bool is_Cs() const { return (encoded & MASK) == Cs; }
+    inline bool is_Ll() const { return (encoded & MASK) == Ll; }
+    inline bool is_Lm() const { return (encoded & MASK) == Lm; }
+    inline bool is_Lo() const { return (encoded & MASK) == Lo; }
+    inline bool is_Lt() const { return (encoded & MASK) == Lt; }
+    inline bool is_Lu() const { return (encoded & MASK) == Lu; }
+    inline bool is_Mc() const { return (encoded & MASK) == Mc; }
+    inline bool is_Me() const { return (encoded & MASK) == Me; }
+    inline bool is_Mn() const { return (encoded & MASK) == Mn; }
+    inline bool is_Nd() const { return (encoded & MASK) == Nd; }
+    inline bool is_Nl() const { return (encoded & MASK) == Nl; }
+    inline bool is_No() const { return (encoded & MASK) == No; }
+    inline bool is_Pc() const { return (encoded & MASK) == Pc; }
+    inline bool is_Pd() const { return (encoded & MASK) == Pd; }
+    inline bool is_Pe() const { return (encoded & MASK) == Pe; }
+    inline bool is_Pf() const { return (encoded & MASK) == Pf; }
+    inline bool is_Pi() const { return (encoded & MASK) == Pi; }
+    inline bool is_Po() const { return (encoded & MASK) == Po; }
+    inline bool is_Ps() const { return (encoded & MASK) == Ps; }
+    inline bool is_Sc() const { return (encoded & MASK) == Sc; }
+    inline bool is_Sk() const { return (encoded & MASK) == Sk; }
+    inline bool is_Sm() const { return (encoded & MASK) == Sm; }
+    inline bool is_So() const { return (encoded & MASK) == So; }
+    inline bool is_Zl() const { return (encoded & MASK) == Zl; }
+    inline bool is_Zp() const { return (encoded & MASK) == Zp; }
+    inline bool is_Zs() const { return (encoded & MASK) == Zs; }
+
+    inline uint64_t expand_bits(const bool add_categ=true) const {  // one bit for each category/subcategory and flags
+        const uint32_t subindex = encoded & SUBMASK;
+        const uint64_t bits = (encoded & MASK) >> 3;
+        const uint64_t flags = encoded >> 10;
+        return (flags << (7 * 8)) | (bits << (7 * subindex)) | (bits * add_categ);
     }
 
-    inline uint16_t category_flag() const {
-        return this->as_uint() & MASK_CATEGORIES;
+    inline bool is_in_range(const codepoint_categ other) const {  // this.first <= other <= this.last
+        if (encoded & SUBMASK) {
+            return encoded == other.encoded;  // no range
+        }
+        if (encoded & MASK) {
+            return encoded == (other.encoded & ~SUBMASK);  // from 0bffffff'ccccccc'000 to 0bffffff'ccccccc'111
+        }
+        return encoded == (other.encoded & ~MASK);  // from 0bffffff'0000000'000 to 0bffffff'1111111'111
+    }
+
+    inline bool operator == (const codepoint_categ other) const {
+        return encoded == other.encoded;
+    }
+
+    inline bool operator != (const codepoint_categ other) const {
+        return encoded != other.encoded;
+    }
+
+    const char * c_str() const {
+        static const std::map<uint16_t, const char *> map = {
+            {UNDEF, "UNDEF"}, {C, "C"}, {L, "L"}, {M, "M"}, {N, "N"}, {P, "P"}, {S, "S"}, {Z, "Z"},
+            {Cc, "Cc"}, {Cf, "Cf"}, {Co, "Co"}, {Cs, "Cs"}, {Ll, "Ll"}, {Lm, "Lm"}, {Lo, "Lo"}, {Lt, "Lt"},
+            {Lu, "Lu"}, {Mc, "Mc"}, {Me, "Me"}, {Mn, "Mn"}, {Nd, "Nd"}, {Nl, "Nl"}, {No, "No"}, {Pc, "Pc"},
+            {Pd, "Pd"}, {Pe, "Pe"}, {Pf, "Pf"}, {Pi, "Pi"}, {Po, "Po"}, {Ps, "Ps"}, {Sc, "Sc"}, {Sk, "Sk"},
+            {Sm, "Sm"}, {So, "So"}, {Zl, "Zl"}, {Zp, "Zp"}, {Zs, "Zs"},
+        };
+        const auto it = map.find(encoded & MASK);
+        return it == map.end() ? "INVALID" : it->second;
+    }
+
+    static codepoint_categ from_index(int index) {
+        static const std::array<codepoint_categ, 32> table = {
+            UNDEF, Cc, Cf, Co, Cs, Ll, Lm, Lo, Lt, Lu, Mc, Me, Mn, Nd, Nl, No, Pc, Pd, Pe, Pf, Pi, Po, Ps, Sc, Sk, Sm, So, Zl, Zp, Zs, UNDEF, UNDEF
+        };
+        return (size_t) index < table.size() ? table[index] : table[0];
+    }
+
+    static codepoint_categ from_chars(const char categ, const char subcateg = '\0') {
+        auto _subindex = [] (const char subcateg, const char subcategs[]) -> uint16_t {
+            if (!subcateg) {
+                return 0;
+            }
+            const char * p = strchr(subcategs, subcateg);
+            GGML_ASSERT(p);
+            return (uint16_t) (p - subcategs + 1);
+        };
+        switch (categ) {
+            case 'C': if (subcateg == 'n') return 0;  // undefined
+                      return C | _subindex(subcateg, "cfos");
+            case 'L': return L | _subindex(subcateg, "lmotu");
+            case 'M': return M | _subindex(subcateg, "cen");
+            case 'N': return N | _subindex(subcateg, "dlo");
+            case 'P': return P | _subindex(subcateg, "cdefios");
+            case 'S': return S | _subindex(subcateg, "ckmo");
+            case 'Z': return Z | _subindex(subcateg, "lps");
+            default:  GGML_ABORT("invalid category character");
+        }
+    }
+
+    uint16_t encoded;
 };
 
 size_t unicode_len_utf8(char src);
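The expand_bits() trick above is easiest to see with concrete values: each category becomes a one-hot bit replicated into one of eight 7-bit groups selected by the subcategory index, so class-membership tests in regex_traits::isctype() reduce to a single AND. A sketch with a free-standing copy of the function (constants copied from the struct above; the program is illustrative only):

#include <cstdint>
#include <cassert>

static uint64_t expand_bits(uint16_t encoded, bool add_categ) {
    const uint32_t subindex = encoded & 7;             // 3-bit subcategory index
    const uint64_t bits     = (encoded & 0x3FF) >> 3;  // 7-bit one-hot category
    const uint64_t flags    = encoded >> 10;           // \s \w \d helper flags
    return (flags << (7 * 8)) | (bits << (7 * subindex)) | (bits * add_categ);
}

int main() {
    const uint16_t L  = 1 << 4;  // codepoint_categ::L
    const uint16_t Lu = L | 5;   // codepoint_categ::Lu
    const uint16_t Ll = L | 1;   // codepoint_categ::Ll
    const uint64_t cpt_mask = expand_bits(Lu, true);  // a codepoint of category Lu
    assert(  cpt_mask & expand_bits(L,  false));      // matches the class \p{L}
    assert(  cpt_mask & expand_bits(Lu, false));      // matches the class \p{Lu}
    assert(!(cpt_mask & expand_bits(Ll, false)));     // does not match \p{Ll}
    return 0;
}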
@@ -56,8 +190,8 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8);
 
 std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts);
 
-codepoint_flags unicode_cpt_flags(const uint32_t cp);
-codepoint_flags unicode_cpt_flags(const std::string & utf8);
+codepoint_categ unicode_cpt_category(const uint32_t cp);
+codepoint_categ unicode_cpt_category(const std::string & utf8);
 
 std::string unicode_byte_to_utf8(uint8_t byte);
 uint8_t unicode_utf8_to_byte(const std::string & utf8);
tests/test-tokenizer-random.py

@@ -116,9 +116,24 @@ class LibLlamaModel:
         num = self.lib.llama_detokenize(self.model, self.token_ids, len(ids), self.text_buff, len(self.text_buff), remove_special, unparse_special)
         return str(cast(Buffer, self.ffi.buffer(self.text_buff, num)), encoding="utf-8", errors="replace")  # replace errors with '\uFFFD'
 
+    def get_vocab(self, detokenize=False) -> list[str]:
+        vocab: list[str] = []
+        num_tokens = self.lib.llama_n_vocab(self.model)
+        for id in range(num_tokens):
+            if detokenize:
+                text = self.detokenize([id], remove_special=False, unparse_special=True)
+            else:
+                text = self.lib.llama_token_get_text(self.model, id)
+                text = str(cast(bytes, self.ffi.string(text)), encoding="utf-8", errors="replace")  # replace errors with '\uFFFD'
+            vocab.append(text)
+        return vocab
+
 
 class Tokenizer:
 
+    def get_vocab(self, detokenize=False) -> list[str]:
+        raise NotImplementedError
+
     def encode(self, text: str) -> list[int]:
         raise NotImplementedError
@@ -129,7 +144,7 @@ class Tokenizer:
 class TokenizerGroundtruth (Tokenizer):
 
     def __init__(self, dir_tokenizer: str):
-        self.model: PreTrainedTokenizer = AutoTokenizer.from_pretrained(dir_tokenizer)
+        self.model: PreTrainedTokenizer = AutoTokenizer.from_pretrained(dir_tokenizer, trust_remote_code=False)
         # guess BOS and EOS
         ids = self.encode("a")
         assert 1 <= len(ids) <= 3
@@ -138,15 +153,25 @@ class TokenizerGroundtruth (Tokenizer):
         self.add_bos_token = getattr(self.model, "add_bos_token", add_bos_token)
         self.add_eos_token = getattr(self.model, "add_eos_token", add_eos_token)
         # build vocab
-        tokens = list(self.model.get_vocab().values())
-        self.vocab = self.model.batch_decode(tokens, skip_special_tokens=True)
-        self.vocab = list(sorted(self.vocab))
+        self.vocab = self.get_vocab(detokenize=True)
         # tokens and lists
-        self.special_tokens = list(self.model.all_special_tokens)
-        self.added_tokens = self.model.batch_decode(self.model.added_tokens_encoder.values(), skip_special_tokens=False)
+        self.special_tokens = [self.vocab[i] for i in sorted(self.model.all_special_ids)]
+        self.added_tokens = [self.vocab[i] for i in sorted(self.model.added_tokens_encoder.values())]
         self.bos_token = self.model.bos_token
         self.eos_token = self.model.eos_token
 
+    def get_vocab(self, detokenize=False) -> list[str]:
+        vocab: list[str] = []
+        max_token_id = max(self.model.get_vocab().values())
+        if detokenize:
+            ids = list(range(max_token_id + 1))
+            vocab = self.model.batch_decode(ids, skip_special_tokens=False)
+        else:
+            vocab = [""] * (max_token_id + 1)
+            for text, id in self.model.get_vocab().items():
+                vocab[id] = text
+        return vocab
+
     def encode(self, text: str) -> list[int]:
         return self.model.encode(text, add_special_tokens=True)
@@ -163,6 +188,9 @@ class TokenizerLlamaCpp (Tokenizer):
         self.libllama = LibLlama()
         self.model = LibLlamaModel(self.libllama, vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
 
+    def get_vocab(self, detokenize=False) -> list[str]:
+        return self.model.get_vocab(detokenize)
+
     def encode(self, text: str) -> list[int]:
         return self.model.tokenize(text, add_special=True, parse_special=True)
@@ -253,6 +281,23 @@ def generator_vocab_words(tokenizer: TokenizerGroundtruth) -> Iterator[str]:
    yield from tokenizer.vocab


def generator_byte_tokens() -> Iterator[str]:
    """Brute-force check of common byte-token spellings"""
    for a, b in ["<>", "[]", "()", ("\\", "")]:
        yield from [f"{a}{i}{b}" for i in range(256)]
        yield from [f"{a}{i:x}{b}" for i in range(256)]
        yield from [f"{a}{i:X}{b}" for i in range(256)]
        yield from [f"{a}x{i:x}{b}" for i in range(256)]
        yield from [f"{a}x{i:X}{b}" for i in range(256)]
        yield from [f"{a}x{i:02x}{b}" for i in range(256)]
        yield from [f"{a}x{i:02X}{b}" for i in range(256)]
        yield from [f"{a}0x{i:x}{b}" for i in range(256)]
        yield from [f"{a}0x{i:X}{b}" for i in range(256)]
        yield from [f"{a}0x{i:02x}{b}" for i in range(256)]
        yield from [f"{a}0x{i:02X}{b}" for i in range(256)]
        yield from [f"{a}{chr(i)}{b}" for i in range(256)]


def generator_ascii_lr_strip() -> Iterator[str]:
    WHITESPACES = ["", " ", "  "]
    CHARACTERS = list(chr(i) for i in range(1, 0x80)) + [""]

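For a sense of what generator_byte_tokens() emits, these are a few of the spellings produced for (a, b) == ("<", ">") and i == 171, following the format specs above:

i = 171
assert f"<{i}>" == "<171>"
assert f"<{i:x}>" == "<ab>"
assert f"<0x{i:02X}>" == "<0xAB>"
assert f"<{chr(i)}>" == "<\u00ab>"  # chr(171) is '«'
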
@@ -275,10 +320,11 @@ def generator_apostrophe() -> Iterator[str]:
            yield char1 + lstrip + "'" + rstrip + char2
            yield char1 + char2 + lstrip + "'" + rstrip + "z"
            yield "a" + lstrip + "'" + rstrip + char1 + char2
            yield "a" + lstrip + "'" + char1 + char2 + rstrip + "z"


def generator_added_lr_strip(tokenizer: TokenizerGroundtruth) -> Iterator[str]:
    WHITESPACES = ["", " ", "  ", "\n", "\r\n", "\n\n", "\t", "\t\t"]
    WHITESPACES = ["", " ", "  ", "\n", "\r\n", "\n\n", "\t", "\t\t", "    "]
    all_tokens = list(sorted(set(tokenizer.special_tokens + tokenizer.added_tokens)))
    for token in all_tokens:
        for lstrip in WHITESPACES:

@@ -409,14 +455,6 @@ def generator_random_vocab_words(tokenizer: TokenizerGroundtruth, iterations=100


def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp, generator: Iterator[str]):

    def find_first_mismatch(ids1: list[int] | str, ids2: list[int] | str):
        for i, (a, b) in enumerate(zip(ids1, ids2)):
            if a != b:
                return i
        if len(ids1) == len(ids2):
            return -1
        return min(len(ids1), len(ids2))

    def check_detokenizer(text: str, text1: str, text2: str) -> bool:
        if text1 == text2:  # equal to TokenizerGroundtruth?
            return True

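Behavior of find_first_mismatch, for reference (a sketch that assumes the helper is hoisted out of compare_tokenizers so it can be called directly):

assert find_first_mismatch([1, 2, 3], [1, 2, 3]) == -1  # identical sequences
assert find_first_mismatch([1, 2, 3], [1, 9, 3]) == 1   # first difference at index 1
assert find_first_mismatch([1, 2], [1, 2, 3]) == 2      # proper prefix: shorter length
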
@@ -434,9 +472,11 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl
    t_decode1 = 0
    t_decode2 = 0
    t_start = time.perf_counter()
    total_tests = 0
    failing_texts = set()
    encode_errors = 0
    decode_errors = 0
    MAX_ERRORS = 10
    MAX_ERRORS = 5

    logger.info("%s: %s" % (generator.__qualname__, "ini"))
    for text in generator:

@@ -455,31 +495,100 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl
        t_encode2 += t2 - t1
        t_decode1 += t3 - t2
        t_decode2 += t4 - t3
        if encode_errors < MAX_ERRORS and ids1 != ids2:
            i = find_first_mismatch(ids1, ids2)
            ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1]
            ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1]
            logger.error(" Expected: " + str(ids1))
            logger.error(" Result: " + str(ids2))
            encode_errors += 1
            logger.error(f" {encode_errors=}")
        if decode_errors < MAX_ERRORS and not check_detokenizer(text, text1, text2):
            i = find_first_mismatch(text1, text2)
            text1 = list(text1[max(0, i - 2) : i + 5 + 1])
            text2 = list(text2[max(0, i - 2) : i + 5 + 1])
            logger.error(" Expected: " + " ".join(hex(ord(x)) for x in text1))
            logger.error(" Result: " + " ".join(hex(ord(x)) for x in text2))
            decode_errors += 1
            logger.error(f" {decode_errors=}")
        if encode_errors >= MAX_ERRORS and decode_errors >= MAX_ERRORS:
            logger.error(f" EXIT: {encode_errors=} {decode_errors=}")
            # raise Exception()
            break
        total_tests += 1
        # compare
        encode_ok = ids1 == ids2
        decode_ok = check_detokenizer(text, text1, text2)
        if not (encode_ok and decode_ok):

            def _compare(text: str):
                ids1 = tokenizer1.encode(text)
                ids2 = tokenizer2.encode(text)
                text1 = tokenizer1.decode(ids1)
                text2 = tokenizer2.decode(ids1)
                encode_ok = ids1 == ids2
                decode_ok = check_detokenizer(text, text1, text2)
                ok = encode_ok and decode_ok
                return ok, ids1, ids2, text1, text2

            # binary search upper and lower failing range
            a, b = 0, len(text)
            step = b
            while step > 1:
                step = (step + 1) // 2
                t = max(a, b - step)
                if not _compare(text[a : t])[0]:
                    b = t
            step = b
            while step > 1:
                step = (step + 1) // 2
                t = min(a + step, b)
                if not _compare(text[t : b])[0]:
                    a = t
            ok, ids1, ids2, text1, text2 = _compare(text[a : b])
            assert a <= b and not ok
            # show unique failing texts differences
            failing_text = text[a : b]
            if failing_text not in failing_texts:
                failing_texts.add(failing_text)
                if encode_errors < MAX_ERRORS and not encode_ok:
                    encode_errors += 1
                    logger.error(f" {encode_errors=}")
                    logger.error(" Text:" + repr(failing_text))
                    logger.error(" " + " ".join(repr(x) + ":" + hex(ord(x)) for x in failing_text))
                    logger.error(" Expected: " + str(ids1))
                    logger.error(" Result: " + str(ids2))
                if decode_errors < MAX_ERRORS and not decode_ok:
                    decode_errors += 1
                    logger.error(f" {decode_errors=}")
                    logger.error(" Text:" + repr(failing_text))
                    logger.error(" " + " ".join(repr(x) + ":" + hex(ord(x)) for x in failing_text))
                    logger.error(" Expected: " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text1))
                    logger.error(" Result: " + " ".join(repr(x) + ":" + hex(ord(x)) for x in text2))
            if encode_errors >= MAX_ERRORS and decode_errors >= MAX_ERRORS:
                logger.error(f" EXIT: {encode_errors=} {decode_errors=}")
                # raise Exception()
                break

    t_total = time.perf_counter() - t_start
    logger.info(f"{generator.__qualname__}: end, {t_encode1=:.3f} {t_encode2=:.3f} {t_decode1=:.3f} {t_decode2=:.3f} {t_total=:.3f}")

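The failing-range reduction above, restated as a standalone helper (a minimal sketch; is_failing stands in for _compare(...)[0] being False). Each pass halves the step size, so the shrink costs O(log n) re-tokenizations per side rather than one per character:

def shrink_failing_range(text: str, is_failing) -> str:
    # precondition: is_failing(text) is True; the loops only move a or b
    # while the slice still fails, so text[a:b] keeps reproducing the bug
    a, b = 0, len(text)
    step = b
    while step > 1:  # shrink from the right while a shorter prefix still fails
        step = (step + 1) // 2
        t = max(a, b - step)
        if is_failing(text[a:t]):
            b = t
    step = b
    while step > 1:  # then shrink from the left
        step = (step + 1) // 2
        t = min(a + step, b)
        if is_failing(text[t:b]):
            a = t
    return text[a:b]

assert shrink_failing_range("abcXdef", lambda s: "X" in s) == "X"
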
def compare_vocabs(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp):

    MAX_PRINT_ERRORS = 10

    logger.info("compare_vocabs: ini")

    t_start = time.perf_counter()

    for detokenize in (False, True):
        vocab1 = tokenizer1.get_vocab(detokenize)
        vocab2 = tokenizer2.get_vocab(detokenize)
        if vocab1 != vocab2:
            num_errors = 0
            for i in range(max(len(vocab1), len(vocab2))):
                text1 = vocab1[i] if i < len(vocab1) else None
                text2 = vocab2[i] if i < len(vocab2) else None
                if text1 != text2:
                    # are "[UNUSED_TOKEN_" and "[PAD" equivalent for all models?  # TODO: use toktypes
                    if text1 is not None:
                        text1 = text1.replace("[UNUSED_TOKEN_", "[PAD")
                    if text2 is not None:
                        text2 = text2.replace("[UNUSED_TOKEN_", "[PAD")
                    if text1 is None and (text2 or "").startswith('[PAD'):
                        text2 = None
                    if text2 is None and (text1 or "").startswith('[PAD'):
                        text1 = None
                    if text1 != text2:
                        num_errors += 1
                        if num_errors < MAX_PRINT_ERRORS:
                            logger.error(f" {detokenize=} id={i} expected={repr(text1)} result={repr(text2)}")
            if num_errors:
                logger.error(f" {num_errors=}")

    t_total = time.perf_counter() - t_start
    logger.info(f"compare_vocabs: end, {t_total=:.3f}")

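An illustration of the name normalization above, which treats InternLM-style filler tokens and padding spellings as equivalent before comparing (assumption: the two spellings differ only cosmetically, which is what the TODO about toktypes would verify properly):

assert "[UNUSED_TOKEN_145]".replace("[UNUSED_TOKEN_", "[PAD") == "[PAD145]"
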
def main(argv: list[str] | None = None):
    parser = argparse.ArgumentParser()
    parser.add_argument("vocab_file", type=str, help="path to vocab 'gguf' file")

@@ -493,18 +602,21 @@ def main(argv: list[str] | None = None):
    tokenizer1 = TokenizerGroundtruth(args.dir_tokenizer)
    tokenizer2 = TokenizerLlamaCpp(args.vocab_file)

    # compare_tokenizers(tokenizer1, tokenizer2, generator_custom_text())
    # compare_tokenizers(tokenizer1, tokenizer2, generator_custom_text_edge_cases())
    compare_vocabs(tokenizer1, tokenizer2)

    compare_tokenizers(tokenizer1, tokenizer2, generator_custom_text())
    compare_tokenizers(tokenizer1, tokenizer2, generator_custom_text_edge_cases())
    compare_tokenizers(tokenizer1, tokenizer2, generator_byte_tokens())
    compare_tokenizers(tokenizer1, tokenizer2, generator_ascii_lr_strip())
    compare_tokenizers(tokenizer1, tokenizer2, generator_apostrophe())
    compare_tokenizers(tokenizer1, tokenizer2, generator_unicodes())
    compare_tokenizers(tokenizer1, tokenizer2, generator_vocab_words(tokenizer1))
    compare_tokenizers(tokenizer1, tokenizer2, generator_added_lr_strip(tokenizer1))
    # compare_tokenizers(tokenizer1, tokenizer2, generator_random_added_tokens(tokenizer1, 10_000))
    # compare_tokenizers(tokenizer1, tokenizer2, generator_random_chars(10_000))
    # compare_tokenizers(tokenizer1, tokenizer2, generator_random_unicodes(10_000))
    # compare_tokenizers(tokenizer1, tokenizer2, generator_random_vocab_chars(tokenizer1, 10_000))
    # compare_tokenizers(tokenizer1, tokenizer2, generator_random_vocab_words(tokenizer1, 5_000))
    compare_tokenizers(tokenizer1, tokenizer2, generator_random_added_tokens(tokenizer1, 10_000))
    compare_tokenizers(tokenizer1, tokenizer2, generator_random_chars(10_000))
    compare_tokenizers(tokenizer1, tokenizer2, generator_random_unicodes(10_000))
    compare_tokenizers(tokenizer1, tokenizer2, generator_random_vocab_chars(tokenizer1, 10_000))
    compare_tokenizers(tokenizer1, tokenizer2, generator_random_vocab_words(tokenizer1, 5_000))

    tokenizer2.model.free()

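A hypothetical invocation of the comparison above (a sketch: the paths are placeholders, and it assumes dir_tokenizer is the parser's second positional argument, whose add_argument line falls outside this hunk):

# python test-tokenizer-random.py ./vocab.gguf ./hf_tokenizer_dir
main(["./vocab.gguf", "./hf_tokenizer_dir"])
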
@@ -533,21 +645,19 @@ if __name__ == "__main__":
        "phi-3",           # SPM
        "gemma",           # SPM
        "gemma-2",         # SPM
        "baichuan",        # SPM
        # "baichuan",      # SPM
        "bert-bge",        # WPM
        "jina-v2-en",      # WPM
        # "t5",            # UGM
        "llama-bpe",       # BPE
        "phi-2",           # BPE
        "deepseek-llm",    # BPE
        "deepseek-coder",  # BPE
        "falcon",          # BPE
        "mpt",             # BPE
        "starcoder",       # BPE
        "gpt-2",           # BPE
        "stablelm2",       # BPE
        "refact",          # BPE
        "qwen2",           # BPE
        "olmo",            # BPE
        "jina-v2-es",      # BPE
        "jina-v2-de",      # BPE
        "smaug-bpe",       # BPE

@@ -555,6 +665,14 @@
        "jina-v2-code",  # BPE
        "viking",        # BPE
        "jais",          # BPE
        "codeshell",     # BPE
        "tekken",        # BPE
        "smollm",        # BPE
        "mpt",           # BPE NFC
        "command-r",     # BPE NFC
        "qwen2",         # BPE NFC
        "olmo",          # BPE NFC
        "gpt-neox",      # BPE NFC
    ]

    logger.info("=" * 50)