From 88e943f9219b1162cc38a3c99c52afd0056cd2e0 Mon Sep 17 00:00:00 2001 From: Joan Martinez Date: Tue, 7 May 2024 14:12:55 +0200 Subject: [PATCH] fix: fix NFD computation --- llama.cpp | 3 ++- unicode-data.cpp | 8 ++++++ unicode-data.h | 1 + unicode.cpp | 70 +++++++++++++++++++++++++++++++++++++++++++----- unicode.h | 3 ++- 5 files changed, 76 insertions(+), 9 deletions(-) diff --git a/llama.cpp b/llama.cpp index aeb5c08df..177b928e3 100644 --- a/llama.cpp +++ b/llama.cpp @@ -12456,7 +12456,8 @@ struct llm_tokenizer_wpm { } std::vector preprocess(const std::string & text) { - std::vector cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text)); + auto unicode_cpts = unicode_cpts_from_utf8(text); + std::vector cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts); // strip accents, strip control, uniformize whitespace, // to lowercase, pad chinese characters, pad punctuation diff --git a/unicode-data.cpp b/unicode-data.cpp index 07bf02c45..b8aafc4ad 100644 --- a/unicode-data.cpp +++ b/unicode-data.cpp @@ -1691,3 +1691,11 @@ const std::map unicode_map_lowercase = { {0x1E917, 0x1E939}, {0x1E918, 0x1E93A}, {0x1E919, 0x1E93B}, {0x1E91A, 0x1E93C}, {0x1E91B, 0x1E93D}, {0x1E91C, 0x1E93E}, {0x1E91D, 0x1E93F}, {0x1E91E, 0x1E940}, {0x1E91F, 0x1E941}, {0x1E920, 0x1E942}, {0x1E921, 0x1E943}, }; + + +const std::map unicode_canonical_class = { + {65, 0}, // Example: Unicode point A has canonical class 0 + {769, 1}, // Example: Combining acute accent has canonical class 1 + {99, 0}, // Example: Unicode point c has canonical class 0 + {807, 1} // Example: Combining cedilla has canonical class 1 +}; \ No newline at end of file diff --git a/unicode-data.h b/unicode-data.h index 3cf84117c..794859935 100644 --- a/unicode-data.h +++ b/unicode-data.h @@ -14,3 +14,4 @@ extern const std::vector> unicode_ranges_symbol; extern const std::vector> unicode_ranges_control; extern const std::multimap unicode_map_nfd; extern const std::map unicode_map_lowercase; +extern const std::map unicode_canonical_class; diff --git a/unicode.cpp b/unicode.cpp index 955c56965..37a5fbde4 100644 --- a/unicode.cpp +++ b/unicode.cpp @@ -469,20 +469,76 @@ std::string unicode_cpt_to_utf8(uint32_t cp) { throw std::invalid_argument("invalid codepoint"); } -std::vector unicode_cpts_normalize_nfd(const std::vector & cpts) { +// Function to sort subsequences based on canonical class +std::vector sort_by_canonical_class(const std::vector & cpts) { + std::vector subsequence; std::vector result; - result.reserve(cpts.size()); - for (size_t i = 0; i < cpts.size(); ++i) { - auto it = unicode_map_nfd.find(cpts[i]); - if (it == unicode_map_nfd.end()) { - result.push_back(cpts[i]); + auto compareByCanonicalClass = [&](const uint32_t& a, const uint32_t& b) { + auto cc_a_it = unicode_canonical_class.find(a); + if (cc_a_it != unicode_canonical_class.end()) { + auto cc_b_it = unicode_canonical_class.find(b); + if (cc_b_it != unicode_canonical_class.end()) { + return cc_a_it->second < cc_b_it->second; + } + + } + return false; + }; + + for (const auto& cpt : cpts) { + auto it = unicode_canonical_class.find(cpt); + if (it != unicode_canonical_class.end()) { + if (it->second > 0) { + subsequence.push_back(cpt); + } else { + if (!subsequence.empty()) { + sort(subsequence.begin(), subsequence.end(), compareByCanonicalClass); + for (const auto& codepoint : subsequence) { + result.push_back(codepoint); + } + subsequence.clear(); + } + + result.push_back(cpt); + } + } + } + + if (!subsequence.empty()) { + sort(subsequence.begin(), subsequence.end(), compareByCanonicalClass); + for (const auto& codepoint : subsequence) { + result.push_back(codepoint); + } + } + + return result; +} + +std::vector canonical_decomposition_cpts(std::vector & cpts, const std::vector::iterator& cpt_begin, const std::vector::iterator& cpt_end) { + std::vector result; + for (auto cpt_it = cpt_begin; cpt_it != cpt_end; ++cpt_it) { + auto it = unicode_map_nfd.equal_range(*cpt_it); + if (it.first != it.second) { + uint offset = 0; + for (auto jt = it.first; jt != it.second; jt++) { + cpts.insert(cpt_it + offset, jt->second); + offset++; + } + const auto & inner_result = canonical_decomposition_cpts(cpts, cpt_it, cpt_end); + result.insert(result.end(), inner_result.begin(), inner_result.end()); + break; } else { - result.push_back(it->second); + result.push_back(*cpt_it); } } return result; } +std::vector unicode_cpts_normalize_nfd(std::vector & cpts) { + auto result = canonical_decomposition_cpts(cpts, cpts.begin(), cpts.end()); + return sort_by_canonical_class(result); +} + std::vector unicode_cpts_from_utf8(const std::string & utf8) { std::vector result; size_t offset = 0; diff --git a/unicode.h b/unicode.h index e9026dc81..923758008 100644 --- a/unicode.h +++ b/unicode.h @@ -16,7 +16,8 @@ std::string unicode_cpt_to_utf8(uint32_t cp); std::vector unicode_cpts_from_utf8(const std::string & utf8); -std::vector unicode_cpts_normalize_nfd(const std::vector & cpts); +std::vector unicode_cpts_normalize_nfd(std::vector & cpts); +std::vector canonical_decomposition_cpts(std::vector & cpts, const std::vector::iterator& cpt_begin, const std::vector::iterator& cpt_end); int unicode_cpt_type(uint32_t cp); int unicode_cpt_type(const std::string & utf8);