fix: fix NFD computation

This commit is contained in:
Joan Martinez 2024-05-07 14:12:55 +02:00
parent bcdee0daa7
commit 88e943f921
5 changed files with 76 additions and 9 deletions

View file

@ -12456,7 +12456,8 @@ struct llm_tokenizer_wpm {
} }
std::vector<std::string> preprocess(const std::string & text) { std::vector<std::string> preprocess(const std::string & text) {
std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text)); auto unicode_cpts = unicode_cpts_from_utf8(text);
std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts);
// strip accents, strip control, uniformize whitespace, // strip accents, strip control, uniformize whitespace,
// to lowercase, pad chinese characters, pad punctuation // to lowercase, pad chinese characters, pad punctuation

View file

@ -1691,3 +1691,11 @@ const std::map<char32_t, char32_t> unicode_map_lowercase = {
{0x1E917, 0x1E939}, {0x1E918, 0x1E93A}, {0x1E919, 0x1E93B}, {0x1E91A, 0x1E93C}, {0x1E91B, 0x1E93D}, {0x1E91C, 0x1E93E}, {0x1E917, 0x1E939}, {0x1E918, 0x1E93A}, {0x1E919, 0x1E93B}, {0x1E91A, 0x1E93C}, {0x1E91B, 0x1E93D}, {0x1E91C, 0x1E93E},
{0x1E91D, 0x1E93F}, {0x1E91E, 0x1E940}, {0x1E91F, 0x1E941}, {0x1E920, 0x1E942}, {0x1E921, 0x1E943}, {0x1E91D, 0x1E93F}, {0x1E91E, 0x1E940}, {0x1E91F, 0x1E941}, {0x1E920, 0x1E942}, {0x1E921, 0x1E943},
}; };
const std::map<uint32_t, uint32_t> unicode_canonical_class = {
{65, 0}, // Example: Unicode point A has canonical class 0
{769, 1}, // Example: Combining acute accent has canonical class 1
{99, 0}, // Example: Unicode point c has canonical class 0
{807, 1} // Example: Combining cedilla has canonical class 1
};

View file

@ -14,3 +14,4 @@ extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_symbol;
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_control; extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_control;
extern const std::multimap<uint32_t, uint32_t> unicode_map_nfd; extern const std::multimap<uint32_t, uint32_t> unicode_map_nfd;
extern const std::map<char32_t, char32_t> unicode_map_lowercase; extern const std::map<char32_t, char32_t> unicode_map_lowercase;
extern const std::map<uint32_t, uint32_t> unicode_canonical_class;

View file

@ -469,20 +469,76 @@ std::string unicode_cpt_to_utf8(uint32_t cp) {
throw std::invalid_argument("invalid codepoint"); throw std::invalid_argument("invalid codepoint");
} }
std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts) { // Function to sort subsequences based on canonical class
std::vector<uint32_t> sort_by_canonical_class(const std::vector<uint32_t> & cpts) {
std::vector<uint32_t> subsequence;
std::vector<uint32_t> result; std::vector<uint32_t> result;
result.reserve(cpts.size()); auto compareByCanonicalClass = [&](const uint32_t& a, const uint32_t& b) {
for (size_t i = 0; i < cpts.size(); ++i) { auto cc_a_it = unicode_canonical_class.find(a);
auto it = unicode_map_nfd.find(cpts[i]); if (cc_a_it != unicode_canonical_class.end()) {
if (it == unicode_map_nfd.end()) { auto cc_b_it = unicode_canonical_class.find(b);
result.push_back(cpts[i]); if (cc_b_it != unicode_canonical_class.end()) {
return cc_a_it->second < cc_b_it->second;
}
}
return false;
};
for (const auto& cpt : cpts) {
auto it = unicode_canonical_class.find(cpt);
if (it != unicode_canonical_class.end()) {
if (it->second > 0) {
subsequence.push_back(cpt);
} else { } else {
result.push_back(it->second); if (!subsequence.empty()) {
sort(subsequence.begin(), subsequence.end(), compareByCanonicalClass);
for (const auto& codepoint : subsequence) {
result.push_back(codepoint);
}
subsequence.clear();
}
result.push_back(cpt);
}
}
}
if (!subsequence.empty()) {
sort(subsequence.begin(), subsequence.end(), compareByCanonicalClass);
for (const auto& codepoint : subsequence) {
result.push_back(codepoint);
}
}
return result;
}
std::vector<uint32_t> canonical_decomposition_cpts(std::vector<uint32_t> & cpts, const std::vector<uint32_t>::iterator& cpt_begin, const std::vector<uint32_t>::iterator& cpt_end) {
std::vector<uint32_t> result;
for (auto cpt_it = cpt_begin; cpt_it != cpt_end; ++cpt_it) {
auto it = unicode_map_nfd.equal_range(*cpt_it);
if (it.first != it.second) {
uint offset = 0;
for (auto jt = it.first; jt != it.second; jt++) {
cpts.insert(cpt_it + offset, jt->second);
offset++;
}
const auto & inner_result = canonical_decomposition_cpts(cpts, cpt_it, cpt_end);
result.insert(result.end(), inner_result.begin(), inner_result.end());
break;
} else {
result.push_back(*cpt_it);
} }
} }
return result; return result;
} }
std::vector<uint32_t> unicode_cpts_normalize_nfd(std::vector<uint32_t> & cpts) {
auto result = canonical_decomposition_cpts(cpts, cpts.begin(), cpts.end());
return sort_by_canonical_class(result);
}
std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) { std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
std::vector<uint32_t> result; std::vector<uint32_t> result;
size_t offset = 0; size_t offset = 0;

View file

@ -16,7 +16,8 @@
std::string unicode_cpt_to_utf8(uint32_t cp); std::string unicode_cpt_to_utf8(uint32_t cp);
std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8); std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8);
std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts); std::vector<uint32_t> unicode_cpts_normalize_nfd(std::vector<uint32_t> & cpts);
std::vector<uint32_t> canonical_decomposition_cpts(std::vector<uint32_t> & cpts, const std::vector<uint32_t>::iterator& cpt_begin, const std::vector<uint32_t>::iterator& cpt_end);
int unicode_cpt_type(uint32_t cp); int unicode_cpt_type(uint32_t cp);
int unicode_cpt_type(const std::string & utf8); int unicode_cpt_type(const std::string & utf8);