feat: remove extra complexity in NFD

Joan Martinez 2024-05-08 10:12:10 +02:00
parent 043f298775
commit 668e0d9b73
6 changed files with 15 additions and 51 deletions


@@ -16,10 +16,17 @@ Feature: llama.cpp server
     Then the server is starting
     Then the server is healthy
 
+  Scenario: Embedding
+    When embeddings are computed for:
+    """
+      What is the capital of Bulgaria ?
+    """
+    Then embeddings are generated
+
   Scenario: Tokenize / Detokenize complex
     When tokenizing:
     """
-      España is your's mine's l'heure èspciâl café über naïve résumé cañón élite cañas Barça ि ि ि ि ि, ि
+      España is your's mine's l'heure èspciâl café über naïve résumé cañón élite cañas Barça
     """
     Then tokens can be detokenize and is equivalent False


@@ -12456,8 +12456,7 @@ struct llm_tokenizer_wpm {
     }
 
     std::vector<std::string> preprocess(const std::string & text) {
-        auto unicode_cpts = unicode_cpts_from_utf8(text);
-        std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts);
+        std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
 
         // strip accents, strip control, uniformize whitespace,
         // to lowercase, pad chinese characters, pad punctuation

File diff suppressed because one or more lines are too long


@@ -14,4 +14,3 @@ extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_symbol;
 extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_control;
 extern const std::multimap<uint32_t, uint32_t> unicode_map_nfd;
 extern const std::map<char32_t, char32_t> unicode_map_lowercase;
-extern const std::map<uint32_t, uint32_t> unicode_canonical_class;


@@ -13,7 +13,6 @@
 #include <vector>
 #include <locale>
 #include <codecvt>
-#include <algorithm>
 
 static std::string unicode_cpts_to_utf8(const std::vector<uint32_t> & cps) {
     std::string result;
@@ -470,54 +469,21 @@ std::string unicode_cpt_to_utf8(uint32_t cp) {
     throw std::invalid_argument("invalid codepoint");
 }
 
-auto compareByCanonicalClass = [&](const uint32_t& a, const uint32_t& b) {
-    auto cc_a_it = unicode_canonical_class.find(a);
-    if (cc_a_it != unicode_canonical_class.end()) {
-        auto cc_b_it = unicode_canonical_class.find(b);
-        if (cc_b_it != unicode_canonical_class.end()) {
-            return cc_a_it->second < cc_b_it->second;
-        }
-    }
-    return false;
-};
-
-// Function to sort subsequences based on canonical class
-std::vector<uint32_t> sort_by_canonical_class(std::vector<uint32_t> & cpts) {
-    // Sort the sequence using the custom comparator function
-    sort(cpts.begin(), cpts.end(), compareByCanonicalClass);
-    return cpts;
-}
-
-std::vector<uint32_t> canonical_decomposition_cpts(std::vector<uint32_t> & cpts, uint32_t starting_offset) {
+std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts) {
     std::vector<uint32_t> result;
-    for (auto i = starting_offset; i < cpts.size(); i++) {
-        const auto& it = unicode_map_nfd.equal_range(cpts[i]);
+    for (uint32_t cpt : cpts) {
+        auto it = unicode_map_nfd.equal_range(cpt);
         if (it.first != it.second) {
-            uint offset = 0;
             for (auto jt = it.first; jt != it.second; jt++) {
-                if (offset == 0) {
-                    cpts[i] = jt->second;
-                } else {
-                    cpts.emplace(cpts.begin() + i + offset, jt->second);
-                }
-                offset++;
+                result.push_back(jt->second);
             }
-            const auto & inner_result = canonical_decomposition_cpts(cpts, i);
-            result.insert(result.end(), inner_result.begin(), inner_result.end());
-            break;
         } else {
-            result.push_back(cpts[i]);
+            result.push_back(cpt);
         }
     }
     return result;
 }
-
-std::vector<uint32_t> unicode_cpts_normalize_nfd(std::vector<uint32_t> & cpts) {
-    auto result = canonical_decomposition_cpts(cpts, 0);
-    return sort_by_canonical_class(result);
-}
 
 std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
     std::vector<uint32_t> result;
     size_t offset = 0;
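
For reference, a standalone sketch (not part of this commit) of the simplified decomposition pass, with a toy table standing in for the generated unicode_map_nfd; the entries below are illustrative values, not the real data. It shows the behavior after the change: a code point with an entry is expanded to its mapped sequence in table order, everything else passes through unchanged, and there is no recursive re-decomposition.

#include <cstdint>
#include <cstdio>
#include <map>
#include <vector>

// Toy stand-in for unicode_map_nfd (the real table is generated and
// lives in the data file whose diff is suppressed above).
static const std::multimap<uint32_t, uint32_t> toy_map_nfd = {
    {0x00E9, 0x0065}, {0x00E9, 0x0301}, // é -> e + U+0301 combining acute
    {0x00F1, 0x006E}, {0x00F1, 0x0303}, // ñ -> n + U+0303 combining tilde
};

static std::vector<uint32_t> normalize_nfd_sketch(const std::vector<uint32_t> & cpts) {
    std::vector<uint32_t> result;
    for (uint32_t cpt : cpts) {
        auto range = toy_map_nfd.equal_range(cpt);
        if (range.first != range.second) {
            // expand to the decomposition sequence, in table order
            for (auto it = range.first; it != range.second; ++it) {
                result.push_back(it->second);
            }
        } else {
            // no decomposition entry: keep the code point as-is
            result.push_back(cpt);
        }
    }
    return result;
}

int main() {
    // "café" as code points; prints U+0063 U+0061 U+0066 U+0065 U+0301
    for (uint32_t cpt : normalize_nfd_sketch({0x0063, 0x0061, 0x0066, 0x00E9})) {
        std::printf("U+%04X ", (unsigned) cpt);
    }
    std::printf("\n");
}

Because the single pass keeps combining marks in table order and never re-sorts them by canonical combining class, input with several combining marks on one base character is not guaranteed to come out in fully canonical NFD order; that reordering step is the complexity this commit removes.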


@@ -16,9 +16,7 @@
 std::string unicode_cpt_to_utf8(uint32_t cp);
 std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8);
 
-std::vector<uint32_t> unicode_cpts_normalize_nfd(std::vector<uint32_t> & cpts);
-std::vector<uint32_t> canonical_decomposition_cpts(std::vector<uint32_t> & cpts, uint32_t starting_offset);
-std::vector<uint32_t> sort_by_canonical_class(std::vector<uint32_t> & cpts);
+std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts);
 
 int unicode_cpt_type(uint32_t cp);
 int unicode_cpt_type(const std::string & utf8);
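
As a usage note (also not part of the commit), a minimal sketch of calling the updated API end to end, assuming the header with the declarations above is included; since the vector-to-UTF-8 helper is internal to the implementation file, the output string is reassembled one code point at a time with unicode_cpt_to_utf8:

#include <cstdint>
#include <string>
#include <vector>
// assumes the unicode header declaring the functions above is included

static std::string normalize_utf8_nfd(const std::string & text) {
    // UTF-8 -> code points -> single-pass decomposition -> UTF-8
    std::string out;
    for (uint32_t cpt : unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text))) {
        out += unicode_cpt_to_utf8(cpt);
    }
    return out;
}

With the previous non-const signature this chained call would not compile, since a temporary cannot bind to a non-const reference, which is presumably why the declaration changed alongside the simplification.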