From 6568c62bca315105968be5347ef1e95af9c20149 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 11 Mar 2024 13:19:55 +0200 Subject: [PATCH] unicode : put nfd normalization behind API ggml-ci --- llama.cpp | 26 ++++++-------------------- unicode.cpp | 48 +++++++++++++++++++++++++++++------------------- unicode.h | 6 ++---- 3 files changed, 37 insertions(+), 43 deletions(-) diff --git a/llama.cpp b/llama.cpp index 76f44aa45..9a0d28c89 100644 --- a/llama.cpp +++ b/llama.cpp @@ -9705,9 +9705,9 @@ private: bpe_words.reserve(text.size()); bpe_encoded_words.reserve(text.size()); - auto cps = unicode_cpts_from_utf8(text); - for (size_t i = 0; i < cps.size(); ++i) - text_utf.emplace_back(unicode_cpt_to_utf8(cps[i])); + const auto cpts = unicode_cpts_from_utf8(text); + for (size_t i = 0; i < cpts.size(); ++i) + text_utf.emplace_back(unicode_cpt_to_utf8(cpts[i])); for (int i = 0; i < (int)text_utf.size(); i++) { const std::string & utf_char = text_utf[i]; @@ -9893,25 +9893,12 @@ struct llm_tokenizer_wpm { } std::vector preprocess(const std::string & text) { - // normalalization form D - std::vector cpts = unicode_cpts_from_utf8(text); - std::vector nfd_cpts; - const auto & nfd_map = unicode_nfd_map(); - for (uint32_t code : cpts) { - auto it = nfd_map.equal_range(code); - if (it.first != it.second) { - for (auto jt = it.first; jt != it.second; jt++) { - nfd_cpts.push_back(jt->second); - } - } else { - nfd_cpts.push_back(code); - } - } + std::vector cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text)); // strip accents, strip control, uniformize whitespace, // to lowercase, pad chinese characters, pad punctuation std::string new_str = ""; - for (uint32_t code : nfd_cpts) { + for (uint32_t code : cpts_nfd) { int type = unicode_cpt_type(code); if (type == CODEPOINT_TYPE_ACCENT_MARK || type == CODEPOINT_TYPE_CONTROL) { continue; @@ -9940,8 +9927,7 @@ struct llm_tokenizer_wpm { if (r > l) words.push_back(new_str.substr(l, (r - l))); l = r + 1; r = l; - } - else { + } else { r += 1; } } diff --git a/unicode.cpp b/unicode.cpp index 5eab0806b..01c7c1391 100644 --- a/unicode.cpp +++ b/unicode.cpp @@ -7,7 +7,7 @@ #include #include -static const std::vector> digit_ranges = { +static const std::vector> unicode_ranges_digit = { {0x00000030, 0x00000039}, {0x000000B2, 0x000000B3}, {0x000000B9, 0x000000B9}, {0x00000660, 0x00000669}, {0x000006F0, 0x000006F9}, {0x000007C0, 0x000007C9}, {0x00000966, 0x0000096F}, {0x000009E6, 0x000009EF}, {0x00000A66, 0x00000A6F}, {0x00000AE6, 0x00000AEF}, {0x00000B66, 0x00000B6F}, {0x00000BE6, 0x00000BEF}, @@ -30,7 +30,7 @@ static const std::vector> digit_ranges = { {0x0001E2F0, 0x0001E2F9}, {0x0001E950, 0x0001E959}, {0x0001F100, 0x0001F10A}, {0x0001FBF0, 0x0001FBF9}, }; -static const std::vector> letter_ranges = { +static const std::vector> unicode_ranges_letter = { {0x00000041, 0x0000005A}, {0x00000061, 0x0000007A}, {0x000000AA, 0x000000AA}, {0x000000B5, 0x000000B5}, {0x000000BA, 0x000000BA}, {0x000000C0, 0x000000D6}, {0x000000D8, 0x000000F6}, {0x000000F8, 0x000002C1}, {0x000002C6, 0x000002D1}, {0x000002E0, 0x000002E4}, {0x000002EC, 0x000002EC}, {0x000002EE, 0x000002EE}, @@ -189,13 +189,13 @@ static const std::vector> letter_ranges = { {0x0002F800, 0x0002FA1D}, {0x00030000, 0x0003134A}, }; -static const std::vector> whitespace_ranges = { +static const std::vector> unicode_ranges_whitespace = { {0x00000009, 0x0000000D}, {0x0000001C, 0x00000020}, {0x00000085, 0x00000085}, {0x000000A0, 0x000000A0}, {0x00001680, 0x00001680}, {0x00002000, 0x0000200A}, {0x00002028, 0x00002029}, {0x0000202F, 0x0000202F}, {0x0000205F, 0x0000205F}, {0x00003000, 0x00003000}, }; -static const std::vector> accent_mark_ranges = { +static const std::vector> unicode_ranges_accent_mark = { {0x00000300, 0x0000036F}, {0x00000483, 0x00000489}, {0x00000591, 0x000005BD}, {0x000005BF, 0x000005BF}, {0x000005C1, 0x000005C2}, {0x000005C4, 0x000005C5}, {0x000005C7, 0x000005C7}, {0x00000610, 0x0000061A}, {0x0000064B, 0x0000065F}, {0x00000670, 0x00000670}, {0x000006D6, 0x000006DC}, {0x000006DF, 0x000006E4}, @@ -271,7 +271,7 @@ static const std::vector> accent_mark_ranges = { {0x0001E944, 0x0001E94A}, {0x000E0100, 0x000E01EF}, }; -static const std::vector> punctuation_ranges = { +static const std::vector> unicode_ranges_punctuation = { {0x00000021, 0x00000023}, {0x00000025, 0x0000002A}, {0x0000002C, 0x0000002F}, {0x0000003A, 0x0000003B}, {0x0000003F, 0x00000040}, {0x0000005B, 0x0000005D}, {0x0000005F, 0x0000005F}, {0x0000007B, 0x0000007B}, {0x0000007D, 0x0000007D}, {0x000000A1, 0x000000A1}, {0x000000A7, 0x000000A7}, {0x000000AB, 0x000000AB}, @@ -321,7 +321,7 @@ static const std::vector> punctuation_ranges = { {0x0001E95E, 0x0001E95F}, }; -static const std::vector> symbol_ranges = { +static const std::vector> unicode_ranges_symbol = { {0x00000024, 0x00000024}, {0x0000002B, 0x0000002B}, {0x0000003C, 0x0000003E}, {0x0000005E, 0x0000005E}, {0x00000060, 0x00000060}, {0x0000007C, 0x0000007C}, {0x0000007E, 0x0000007E}, {0x000000A2, 0x000000A6}, {0x000000A8, 0x000000A9}, {0x000000AC, 0x000000AC}, {0x000000AE, 0x000000B1}, {0x000000B4, 0x000000B4}, @@ -382,7 +382,7 @@ static const std::vector> symbol_ranges = { {0x0001FB94, 0x0001FBCA}, }; -static const std::vector> control_ranges = { +static const std::vector> unicode_ranges_control = { {0x00000000, 0x00000008}, {0x0000000E, 0x0000001B}, {0x0000007F, 0x00000084}, {0x00000086, 0x0000009F}, {0x000000AD, 0x000000AD}, {0x00000378, 0x00000379}, {0x00000380, 0x00000383}, {0x0000038B, 0x0000038B}, {0x0000038D, 0x0000038D}, {0x000003A2, 0x000003A2}, {0x00000530, 0x00000530}, {0x00000557, 0x00000558}, @@ -556,7 +556,7 @@ static const std::vector> control_ranges = { {0x000E01F0, 0x0010FFFF}, }; -static const std::multimap nfd_map = { +static const std::multimap unicode_map_nfd = { {0x000000C0, 0x00000041}, {0x000000C0, 0x00000300}, {0x000000C1, 0x00000041}, {0x000000C1, 0x00000301}, {0x000000C2, 0x00000041}, {0x000000C2, 0x00000302}, {0x000000C3, 0x00000041}, {0x000000C3, 0x00000303}, {0x000000C4, 0x00000041}, {0x000000C4, 0x00000308}, {0x000000C5, 0x00000041}, {0x000000C5, 0x0000030A}, @@ -1507,37 +1507,37 @@ static uint32_t cpt_from_utf16(const std::vector & utf16, size_t & off static std::unordered_map unicode_cpt_type_map() { std::unordered_map cpt_types; - for (auto p : digit_ranges) { + for (auto p : unicode_ranges_digit) { for (auto i = p.first; i <= p.second; ++ i) { cpt_types[i] = CODEPOINT_TYPE_DIGIT; } } - for (auto p : letter_ranges) { + for (auto p : unicode_ranges_letter) { for (auto i = p.first; i <= p.second; ++ i) { cpt_types[i] = CODEPOINT_TYPE_LETTER; } } - for (auto p : whitespace_ranges) { + for (auto p : unicode_ranges_whitespace) { for (auto i = p.first; i <= p.second; ++ i) { cpt_types[i] = CODEPOINT_TYPE_WHITESPACE; } } - for (auto p : accent_mark_ranges) { + for (auto p : unicode_ranges_accent_mark) { for (auto i = p.first; i <= p.second; ++ i) { cpt_types[i] = CODEPOINT_TYPE_ACCENT_MARK; } } - for (auto p : punctuation_ranges) { + for (auto p : unicode_ranges_punctuation) { for (auto i = p.first; i <= p.second; ++ i) { cpt_types[i] = CODEPOINT_TYPE_PUNCTUATION; } } - for (auto p : symbol_ranges) { + for (auto p : unicode_ranges_symbol) { for (auto i = p.first; i <= p.second; ++i) { cpt_types[i] = CODEPOINT_TYPE_SYMBOL; } } - for (auto p : control_ranges) { + for (auto p : unicode_ranges_control) { for (auto i = p.first; i <= p.second; ++ i) { cpt_types[i] = CODEPOINT_TYPE_CONTROL; } @@ -1597,10 +1597,6 @@ static std::unordered_map unicode_utf8_to_byte_map() { // interface // -const std::multimap & unicode_nfd_map() { - return nfd_map; -} - std::string unicode_cpt_to_utf8(uint32_t cp) { std::string result; if (/* 0x00 <= cp && */ cp <= 0x7f) { @@ -1627,6 +1623,20 @@ std::string unicode_cpt_to_utf8(uint32_t cp) { return result; } +std::vector unicode_cpts_normalize_nfd(std::vector cpts) { + std::vector result; + result.reserve(cpts.size()); + for (size_t i = 0; i < cpts.size(); ++i) { + auto it = unicode_map_nfd.find(cpts[i]); + if (it == unicode_map_nfd.end()) { + result.push_back(cpts[i]); + } else { + result.push_back(it->second); + } + } + return result; +} + std::vector unicode_cpts_from_utf8(const std::string & utf8) { std::vector result; size_t offset = 0; diff --git a/unicode.h b/unicode.h index ac5c66bba..f6c226194 100644 --- a/unicode.h +++ b/unicode.h @@ -1,7 +1,6 @@ #pragma once #include -#include #include #define CODEPOINT_TYPE_UNIDENTIFIED 0 @@ -13,12 +12,11 @@ #define CODEPOINT_TYPE_SYMBOL 6 #define CODEPOINT_TYPE_CONTROL 7 -// TODO: remove -const std::multimap & unicode_nfd_map(); - std::string unicode_cpt_to_utf8(uint32_t cp); std::vector unicode_cpts_from_utf8(const std::string & utf8); +std::vector unicode_cpts_normalize_nfd(std::vector cpts); + int unicode_cpt_type(uint32_t cp); int unicode_cpt_type(const std::string & utf8);