unicode : put nfd normalization behind API

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-03-11 13:19:55 +02:00
parent be12d8b12a
commit 6568c62bca
No known key found for this signature in database
GPG key ID: BF970631944C16B7
3 changed files with 37 additions and 43 deletions

View file

@ -9705,9 +9705,9 @@ private:
bpe_words.reserve(text.size());
bpe_encoded_words.reserve(text.size());
auto cps = unicode_cpts_from_utf8(text);
for (size_t i = 0; i < cps.size(); ++i)
text_utf.emplace_back(unicode_cpt_to_utf8(cps[i]));
const auto cpts = unicode_cpts_from_utf8(text);
for (size_t i = 0; i < cpts.size(); ++i)
text_utf.emplace_back(unicode_cpt_to_utf8(cpts[i]));
for (int i = 0; i < (int)text_utf.size(); i++) {
const std::string & utf_char = text_utf[i];
@ -9893,25 +9893,12 @@ struct llm_tokenizer_wpm {
}
std::vector<std::string> preprocess(const std::string & text) {
// normalalization form D
std::vector<uint32_t> cpts = unicode_cpts_from_utf8(text);
std::vector<uint32_t> nfd_cpts;
const auto & nfd_map = unicode_nfd_map();
for (uint32_t code : cpts) {
auto it = nfd_map.equal_range(code);
if (it.first != it.second) {
for (auto jt = it.first; jt != it.second; jt++) {
nfd_cpts.push_back(jt->second);
}
} else {
nfd_cpts.push_back(code);
}
}
std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
// strip accents, strip control, uniformize whitespace,
// to lowercase, pad chinese characters, pad punctuation
std::string new_str = "";
for (uint32_t code : nfd_cpts) {
for (uint32_t code : cpts_nfd) {
int type = unicode_cpt_type(code);
if (type == CODEPOINT_TYPE_ACCENT_MARK || type == CODEPOINT_TYPE_CONTROL) {
continue;
@ -9940,8 +9927,7 @@ struct llm_tokenizer_wpm {
if (r > l) words.push_back(new_str.substr(l, (r - l)));
l = r + 1;
r = l;
}
else {
} else {
r += 1;
}
}

View file

@ -7,7 +7,7 @@
#include <unordered_map>
#include <vector>
static const std::vector<std::pair<uint32_t, uint32_t>> digit_ranges = {
static const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_digit = {
{0x00000030, 0x00000039}, {0x000000B2, 0x000000B3}, {0x000000B9, 0x000000B9}, {0x00000660, 0x00000669},
{0x000006F0, 0x000006F9}, {0x000007C0, 0x000007C9}, {0x00000966, 0x0000096F}, {0x000009E6, 0x000009EF},
{0x00000A66, 0x00000A6F}, {0x00000AE6, 0x00000AEF}, {0x00000B66, 0x00000B6F}, {0x00000BE6, 0x00000BEF},
@ -30,7 +30,7 @@ static const std::vector<std::pair<uint32_t, uint32_t>> digit_ranges = {
{0x0001E2F0, 0x0001E2F9}, {0x0001E950, 0x0001E959}, {0x0001F100, 0x0001F10A}, {0x0001FBF0, 0x0001FBF9},
};
static const std::vector<std::pair<uint32_t, uint32_t>> letter_ranges = {
static const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_letter = {
{0x00000041, 0x0000005A}, {0x00000061, 0x0000007A}, {0x000000AA, 0x000000AA}, {0x000000B5, 0x000000B5},
{0x000000BA, 0x000000BA}, {0x000000C0, 0x000000D6}, {0x000000D8, 0x000000F6}, {0x000000F8, 0x000002C1},
{0x000002C6, 0x000002D1}, {0x000002E0, 0x000002E4}, {0x000002EC, 0x000002EC}, {0x000002EE, 0x000002EE},
@ -189,13 +189,13 @@ static const std::vector<std::pair<uint32_t, uint32_t>> letter_ranges = {
{0x0002F800, 0x0002FA1D}, {0x00030000, 0x0003134A},
};
static const std::vector<std::pair<uint32_t, uint32_t>> whitespace_ranges = {
static const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_whitespace = {
{0x00000009, 0x0000000D}, {0x0000001C, 0x00000020}, {0x00000085, 0x00000085}, {0x000000A0, 0x000000A0},
{0x00001680, 0x00001680}, {0x00002000, 0x0000200A}, {0x00002028, 0x00002029}, {0x0000202F, 0x0000202F},
{0x0000205F, 0x0000205F}, {0x00003000, 0x00003000},
};
static const std::vector<std::pair<uint32_t, uint32_t>> accent_mark_ranges = {
static const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_accent_mark = {
{0x00000300, 0x0000036F}, {0x00000483, 0x00000489}, {0x00000591, 0x000005BD}, {0x000005BF, 0x000005BF},
{0x000005C1, 0x000005C2}, {0x000005C4, 0x000005C5}, {0x000005C7, 0x000005C7}, {0x00000610, 0x0000061A},
{0x0000064B, 0x0000065F}, {0x00000670, 0x00000670}, {0x000006D6, 0x000006DC}, {0x000006DF, 0x000006E4},
@ -271,7 +271,7 @@ static const std::vector<std::pair<uint32_t, uint32_t>> accent_mark_ranges = {
{0x0001E944, 0x0001E94A}, {0x000E0100, 0x000E01EF},
};
static const std::vector<std::pair<uint32_t, uint32_t>> punctuation_ranges = {
static const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_punctuation = {
{0x00000021, 0x00000023}, {0x00000025, 0x0000002A}, {0x0000002C, 0x0000002F}, {0x0000003A, 0x0000003B},
{0x0000003F, 0x00000040}, {0x0000005B, 0x0000005D}, {0x0000005F, 0x0000005F}, {0x0000007B, 0x0000007B},
{0x0000007D, 0x0000007D}, {0x000000A1, 0x000000A1}, {0x000000A7, 0x000000A7}, {0x000000AB, 0x000000AB},
@ -321,7 +321,7 @@ static const std::vector<std::pair<uint32_t, uint32_t>> punctuation_ranges = {
{0x0001E95E, 0x0001E95F},
};
static const std::vector<std::pair<uint32_t, uint32_t>> symbol_ranges = {
static const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_symbol = {
{0x00000024, 0x00000024}, {0x0000002B, 0x0000002B}, {0x0000003C, 0x0000003E}, {0x0000005E, 0x0000005E},
{0x00000060, 0x00000060}, {0x0000007C, 0x0000007C}, {0x0000007E, 0x0000007E}, {0x000000A2, 0x000000A6},
{0x000000A8, 0x000000A9}, {0x000000AC, 0x000000AC}, {0x000000AE, 0x000000B1}, {0x000000B4, 0x000000B4},
@ -382,7 +382,7 @@ static const std::vector<std::pair<uint32_t, uint32_t>> symbol_ranges = {
{0x0001FB94, 0x0001FBCA},
};
static const std::vector<std::pair<uint32_t, uint32_t>> control_ranges = {
static const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_control = {
{0x00000000, 0x00000008}, {0x0000000E, 0x0000001B}, {0x0000007F, 0x00000084}, {0x00000086, 0x0000009F},
{0x000000AD, 0x000000AD}, {0x00000378, 0x00000379}, {0x00000380, 0x00000383}, {0x0000038B, 0x0000038B},
{0x0000038D, 0x0000038D}, {0x000003A2, 0x000003A2}, {0x00000530, 0x00000530}, {0x00000557, 0x00000558},
@ -556,7 +556,7 @@ static const std::vector<std::pair<uint32_t, uint32_t>> control_ranges = {
{0x000E01F0, 0x0010FFFF},
};
static const std::multimap<uint32_t, uint32_t> nfd_map = {
static const std::multimap<uint32_t, uint32_t> unicode_map_nfd = {
{0x000000C0, 0x00000041}, {0x000000C0, 0x00000300}, {0x000000C1, 0x00000041}, {0x000000C1, 0x00000301},
{0x000000C2, 0x00000041}, {0x000000C2, 0x00000302}, {0x000000C3, 0x00000041}, {0x000000C3, 0x00000303},
{0x000000C4, 0x00000041}, {0x000000C4, 0x00000308}, {0x000000C5, 0x00000041}, {0x000000C5, 0x0000030A},
@ -1507,37 +1507,37 @@ static uint32_t cpt_from_utf16(const std::vector<uint16_t> & utf16, size_t & off
static std::unordered_map<uint32_t, int> unicode_cpt_type_map() {
std::unordered_map<uint32_t, int> cpt_types;
for (auto p : digit_ranges) {
for (auto p : unicode_ranges_digit) {
for (auto i = p.first; i <= p.second; ++ i) {
cpt_types[i] = CODEPOINT_TYPE_DIGIT;
}
}
for (auto p : letter_ranges) {
for (auto p : unicode_ranges_letter) {
for (auto i = p.first; i <= p.second; ++ i) {
cpt_types[i] = CODEPOINT_TYPE_LETTER;
}
}
for (auto p : whitespace_ranges) {
for (auto p : unicode_ranges_whitespace) {
for (auto i = p.first; i <= p.second; ++ i) {
cpt_types[i] = CODEPOINT_TYPE_WHITESPACE;
}
}
for (auto p : accent_mark_ranges) {
for (auto p : unicode_ranges_accent_mark) {
for (auto i = p.first; i <= p.second; ++ i) {
cpt_types[i] = CODEPOINT_TYPE_ACCENT_MARK;
}
}
for (auto p : punctuation_ranges) {
for (auto p : unicode_ranges_punctuation) {
for (auto i = p.first; i <= p.second; ++ i) {
cpt_types[i] = CODEPOINT_TYPE_PUNCTUATION;
}
}
for (auto p : symbol_ranges) {
for (auto p : unicode_ranges_symbol) {
for (auto i = p.first; i <= p.second; ++i) {
cpt_types[i] = CODEPOINT_TYPE_SYMBOL;
}
}
for (auto p : control_ranges) {
for (auto p : unicode_ranges_control) {
for (auto i = p.first; i <= p.second; ++ i) {
cpt_types[i] = CODEPOINT_TYPE_CONTROL;
}
@ -1597,10 +1597,6 @@ static std::unordered_map<std::string, uint8_t> unicode_utf8_to_byte_map() {
// interface
//
const std::multimap<uint32_t, uint32_t> & unicode_nfd_map() {
return nfd_map;
}
std::string unicode_cpt_to_utf8(uint32_t cp) {
std::string result;
if (/* 0x00 <= cp && */ cp <= 0x7f) {
@ -1627,6 +1623,20 @@ std::string unicode_cpt_to_utf8(uint32_t cp) {
return result;
}
std::vector<uint32_t> unicode_cpts_normalize_nfd(std::vector<uint32_t> cpts) {
std::vector<uint32_t> result;
result.reserve(cpts.size());
for (size_t i = 0; i < cpts.size(); ++i) {
auto it = unicode_map_nfd.find(cpts[i]);
if (it == unicode_map_nfd.end()) {
result.push_back(cpts[i]);
} else {
result.push_back(it->second);
}
}
return result;
}
std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
std::vector<uint32_t> result;
size_t offset = 0;

View file

@ -1,7 +1,6 @@
#pragma once
#include <string>
#include <map>
#include <vector>
#define CODEPOINT_TYPE_UNIDENTIFIED 0
@ -13,12 +12,11 @@
#define CODEPOINT_TYPE_SYMBOL 6
#define CODEPOINT_TYPE_CONTROL 7
// TODO: remove
const std::multimap<uint32_t, uint32_t> & unicode_nfd_map();
std::string unicode_cpt_to_utf8(uint32_t cp);
std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8);
std::vector<uint32_t> unicode_cpts_normalize_nfd(std::vector<uint32_t> cpts);
int unicode_cpt_type(uint32_t cp);
int unicode_cpt_type(const std::string & utf8);