llama : refactor unicode stuff (#5992)
* llama : refactor unicode stuff ggml-ci * unicode : names * make : fix c++ compiler * unicode : names * unicode : straighten tables * zig : fix build * unicode : put nfd normalization behind API ggml-ci * swift : fix build * unicode : add BOM * unicode : add <cstdint> ggml-ci * unicode : pass cpts as const ref
This commit is contained in:
parent
828defefb6
commit
83796e62bc
9 changed files with 1744 additions and 836 deletions
87
llama.cpp
87
llama.cpp
|
@ -3703,7 +3703,7 @@ static void llm_load_vocab(
|
|||
|
||||
for (int i = 0; i < n_merges; i++) {
|
||||
const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i);
|
||||
GGML_ASSERT(codepoints_from_utf8(word).size() > 0);
|
||||
GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
|
||||
|
||||
std::string first;
|
||||
std::string second;
|
||||
|
@ -3748,7 +3748,7 @@ static void llm_load_vocab(
|
|||
|
||||
for (uint32_t i = 0; i < n_vocab; i++) {
|
||||
std::string word = gguf_get_arr_str(ctx, token_idx, i);
|
||||
GGML_ASSERT(codepoints_from_utf8(word).size() > 0);
|
||||
GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
|
||||
|
||||
vocab.token_to_id[word] = i;
|
||||
|
||||
|
@ -9340,7 +9340,7 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
|
|||
}
|
||||
case LLAMA_VOCAB_TYPE_BPE: {
|
||||
GGML_ASSERT(false);
|
||||
return unicode_to_bytes_bpe(token_data.text);
|
||||
return unicode_utf8_to_byte(token_data.text);
|
||||
}
|
||||
case LLAMA_VOCAB_TYPE_WPM: {
|
||||
GGML_ASSERT(false);
|
||||
|
@ -9365,7 +9365,7 @@ static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) {
|
|||
}
|
||||
case LLAMA_VOCAB_TYPE_WPM:
|
||||
case LLAMA_VOCAB_TYPE_BPE: {
|
||||
return vocab.token_to_id.at(bytes_to_unicode_bpe(ch));
|
||||
return vocab.token_to_id.at(unicode_byte_to_utf8(ch));
|
||||
}
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
|
@ -9705,9 +9705,9 @@ private:
|
|||
bpe_words.reserve(text.size());
|
||||
bpe_encoded_words.reserve(text.size());
|
||||
|
||||
auto cps = codepoints_from_utf8(text);
|
||||
for (size_t i = 0; i < cps.size(); ++i)
|
||||
text_utf.emplace_back(codepoint_to_utf8(cps[i]));
|
||||
const auto cpts = unicode_cpts_from_utf8(text);
|
||||
for (size_t i = 0; i < cpts.size(); ++i)
|
||||
text_utf.emplace_back(unicode_cpt_to_utf8(cpts[i]));
|
||||
|
||||
for (int i = 0; i < (int)text_utf.size(); i++) {
|
||||
const std::string & utf_char = text_utf[i];
|
||||
|
@ -9757,40 +9757,40 @@ private:
|
|||
}
|
||||
|
||||
if (!split_condition && !collecting) {
|
||||
if (codepoint_type(utf_char) == CODEPOINT_TYPE_LETTER || (!token.size() && utf_char == " " && codepoint_type(utf_char_next) == CODEPOINT_TYPE_LETTER)) {
|
||||
if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || (!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER)) {
|
||||
collecting_letter = true;
|
||||
collecting = true;
|
||||
}
|
||||
else if (codepoint_type(utf_char) == CODEPOINT_TYPE_DIGIT || (!token.size() && utf_char == " " && codepoint_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
|
||||
else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || (!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
|
||||
collecting_numeric = true;
|
||||
collecting = true;
|
||||
}
|
||||
else if (
|
||||
((codepoint_type(utf_char) != CODEPOINT_TYPE_LETTER && codepoint_type(utf_char) != CODEPOINT_TYPE_DIGIT) && (codepoint_type(utf_char) != CODEPOINT_TYPE_WHITESPACE)) ||
|
||||
(!token.size() && utf_char == " " && codepoint_type(utf_char_next) != CODEPOINT_TYPE_LETTER && codepoint_type(utf_char_next) != CODEPOINT_TYPE_DIGIT && codepoint_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE)
|
||||
((unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) && (unicode_cpt_type(utf_char) != CODEPOINT_TYPE_WHITESPACE)) ||
|
||||
(!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_DIGIT && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE)
|
||||
) {
|
||||
collecting_special = true;
|
||||
collecting = true;
|
||||
}
|
||||
else if (codepoint_type(utf_char) == CODEPOINT_TYPE_WHITESPACE && codepoint_type(utf_char_next) == CODEPOINT_TYPE_WHITESPACE) {
|
||||
else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_WHITESPACE) {
|
||||
collecting_whitespace_lookahead = true;
|
||||
collecting = true;
|
||||
}
|
||||
else if (codepoint_type(utf_char) == CODEPOINT_TYPE_WHITESPACE) {
|
||||
else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE) {
|
||||
split_condition = true;
|
||||
}
|
||||
}
|
||||
else if (!split_condition && collecting) {
|
||||
if (collecting_letter && codepoint_type(utf_char) != CODEPOINT_TYPE_LETTER) {
|
||||
if (collecting_letter && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER) {
|
||||
split_condition = true;
|
||||
}
|
||||
else if (collecting_numeric && codepoint_type(utf_char) != CODEPOINT_TYPE_DIGIT) {
|
||||
else if (collecting_numeric && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) {
|
||||
split_condition = true;
|
||||
}
|
||||
else if (collecting_special && (codepoint_type(utf_char) == CODEPOINT_TYPE_LETTER || codepoint_type(utf_char) == CODEPOINT_TYPE_DIGIT || codepoint_type(utf_char) == CODEPOINT_TYPE_WHITESPACE)) {
|
||||
else if (collecting_special && (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE)) {
|
||||
split_condition = true;
|
||||
}
|
||||
else if (collecting_whitespace_lookahead && (codepoint_type(utf_char_next) == CODEPOINT_TYPE_LETTER || codepoint_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
|
||||
else if (collecting_whitespace_lookahead && (unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
|
||||
split_condition = true;
|
||||
}
|
||||
}
|
||||
|
@ -9819,7 +9819,7 @@ private:
|
|||
for (std::string & word : bpe_words) {
|
||||
std::string encoded_token = "";
|
||||
for (char & c : word) {
|
||||
encoded_token += bytes_to_unicode_bpe(c);
|
||||
encoded_token += unicode_byte_to_utf8(c);
|
||||
}
|
||||
bpe_encoded_words.emplace_back(encoded_token);
|
||||
}
|
||||
|
@ -9893,25 +9893,13 @@ struct llm_tokenizer_wpm {
|
|||
}
|
||||
|
||||
std::vector<std::string> preprocess(const std::string & text) {
|
||||
// normalization form D
|
||||
std::vector<uint32_t> codepoints = codepoints_from_utf8(text);
|
||||
std::vector<uint32_t> nfd_codepoints;
|
||||
for (uint32_t code : codepoints) {
|
||||
auto it = nfd_map.equal_range(code);
|
||||
if (it.first != it.second) {
|
||||
for (auto jt = it.first; jt != it.second; jt++) {
|
||||
nfd_codepoints.push_back(jt->second);
|
||||
}
|
||||
} else {
|
||||
nfd_codepoints.push_back(code);
|
||||
}
|
||||
}
|
||||
std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
|
||||
|
||||
// strip accents, strip control, uniformize whitespace,
|
||||
// to lowercase, pad chinese characters, pad punctuation
|
||||
std::string new_str = "";
|
||||
for (uint32_t code : nfd_codepoints) {
|
||||
int type = codepoint_type(code);
|
||||
for (uint32_t code : cpts_nfd) {
|
||||
int type = unicode_cpt_type(code);
|
||||
if (type == CODEPOINT_TYPE_ACCENT_MARK || type == CODEPOINT_TYPE_CONTROL) {
|
||||
continue;
|
||||
}
|
||||
|
@ -9919,7 +9907,7 @@ struct llm_tokenizer_wpm {
|
|||
if (type == CODEPOINT_TYPE_WHITESPACE) {
|
||||
code = ' ';
|
||||
}
|
||||
std::string s = codepoint_to_utf8(code);
|
||||
std::string s = unicode_cpt_to_utf8(code);
|
||||
if (type == CODEPOINT_TYPE_PUNCTUATION || is_ascii_punct(code) || is_chinese_char(code)) {
|
||||
new_str += " ";
|
||||
new_str += s;
|
||||
|
@ -9939,8 +9927,7 @@ struct llm_tokenizer_wpm {
|
|||
if (r > l) words.push_back(new_str.substr(l, (r - l)));
|
||||
l = r + 1;
|
||||
r = l;
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
r += 1;
|
||||
}
|
||||
}
|
||||
|
@ -9964,17 +9951,17 @@ struct llm_tokenizer_wpm {
|
|||
return code < 256 && ispunct(code);
|
||||
}
|
||||
|
||||
bool is_chinese_char(uint32_t codepoint) {
|
||||
if ((codepoint >= 0x4E00 && codepoint <= 0x9FFF) ||
|
||||
(codepoint >= 0x3400 && codepoint <= 0x4DBF) ||
|
||||
(codepoint >= 0x20000 && codepoint <= 0x2A6DF) ||
|
||||
(codepoint >= 0x2A700 && codepoint <= 0x2B73F) ||
|
||||
(codepoint >= 0x2B740 && codepoint <= 0x2B81F) ||
|
||||
(codepoint >= 0x2B920 && codepoint <= 0x2CEAF) || // this should be 0x2B820 but in hf rust code it is 0x2B920
|
||||
(codepoint >= 0xF900 && codepoint <= 0xFAFF) ||
|
||||
(codepoint >= 0x2F800 && codepoint <= 0x2FA1F) ||
|
||||
(codepoint >= 0x3000 && codepoint <= 0x303F) ||
|
||||
(codepoint >= 0xFF00 && codepoint <= 0xFFEF)) {
|
||||
bool is_chinese_char(uint32_t cpt) {
|
||||
if ((cpt >= 0x4E00 && cpt <= 0x9FFF) ||
|
||||
(cpt >= 0x3400 && cpt <= 0x4DBF) ||
|
||||
(cpt >= 0x20000 && cpt <= 0x2A6DF) ||
|
||||
(cpt >= 0x2A700 && cpt <= 0x2B73F) ||
|
||||
(cpt >= 0x2B740 && cpt <= 0x2B81F) ||
|
||||
(cpt >= 0x2B920 && cpt <= 0x2CEAF) || // this should be 0x2B820 but in hf rust code it is 0x2B920
|
||||
(cpt >= 0xF900 && cpt <= 0xFAFF) ||
|
||||
(cpt >= 0x2F800 && cpt <= 0x2FA1F) ||
|
||||
(cpt >= 0x3000 && cpt <= 0x303F) ||
|
||||
(cpt >= 0xFF00 && cpt <= 0xFFEF)) {
|
||||
return true; // NOLINT
|
||||
}
|
||||
return false;
|
||||
|
@ -13953,9 +13940,9 @@ int32_t llama_tokenize(
|
|||
|
||||
static std::string llama_decode_text(const std::string & text) {
|
||||
std::string decoded_text;
|
||||
auto unicode_sequences = codepoints_from_utf8(text);
|
||||
for (auto& unicode_sequence : unicode_sequences) {
|
||||
decoded_text += unicode_to_bytes_bpe(codepoint_to_utf8(unicode_sequence));
|
||||
auto unicode_sequences = unicode_cpts_from_utf8(text);
|
||||
for (auto & unicode_sequence : unicode_sequences) {
|
||||
decoded_text += unicode_utf8_to_byte(unicode_cpt_to_utf8(unicode_sequence));
|
||||
}
|
||||
|
||||
return decoded_text;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue