llama : improve BERT tokenization (#5740)

* implement nfd for stripping accents in wpm tokenizer

* sort nfd map; reuse iterator

* use builtin tolower

* add locale include

* Simplify to_lower cases

Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>

---------

Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>
This commit is contained in:
Douglas Hanley 2024-02-28 02:51:11 -06:00 committed by GitHub
parent 6c4416868d
commit 177628bfd8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 305 additions and 94 deletions

137
llama.cpp
View file

@ -68,10 +68,12 @@
#include <cstdio>
#include <cstring>
#include <ctime>
#include <cwctype>
#include <forward_list>
#include <fstream>
#include <functional>
#include <initializer_list>
#include <locale>
#include <map>
#include <memory>
#include <mutex>
@ -8941,37 +8943,46 @@ struct llm_tokenizer_wpm {
}
std::vector<std::string> preprocess(const std::string & text) {
std::string ori_str = normalize(text);
uint64_t ori_size = ori_str.size();
// normalalization form D
std::vector<uint32_t> codepoints = codepoints_from_utf8(text);
std::vector<uint32_t> nfd_codepoints;
for (uint32_t code : codepoints) {
auto it = nfd_map.find(code);
if (it != nfd_map.end()) {
for (uint32_t c : it->second) {
nfd_codepoints.push_back(c);
}
} else {
nfd_codepoints.push_back(code);
}
}
// single punct / single symbol / single digit
// baseline: add whitespace on the left and right of punct and chinese characters
std::vector<std::string> words;
// strip accents, strip control, uniformize whitespace,
// to lowercase, pad chinese characters, pad punctuation
std::string new_str = "";
uint64_t i = 0;
while (i < ori_size) {
int utf_char_len = utf8_len(ori_str[i]);
if ((utf_char_len == 1) && ispunct(ori_str[i])) {
new_str += " ";
new_str += ori_str[i];
new_str += " ";
i += 1;
for (uint32_t code : nfd_codepoints) {
int type = codepoint_type(code);
if (type == CODEPOINT_TYPE_ACCENT_MARK || type == CODEPOINT_TYPE_CONTROL) {
continue;
}
else if ((utf_char_len == 3) && is_chinese_char(ori_str.substr(i, 3))) {
new_str += " ";
new_str += ori_str.substr(i, 3);
new_str += " ";
i += 3;
code = to_lower(code);
if (type == CODEPOINT_TYPE_WHITESPACE) {
code = ' ';
}
else {
new_str += ori_str[i];
i += 1;
std::string s = codepoint_to_utf8(code);
if (type == CODEPOINT_TYPE_PUNCTUATION || is_ascii_punct(code) || is_chinese_char(code)) {
new_str += " ";
new_str += s;
new_str += " ";
} else {
new_str += s;
}
}
// split by whitespace
uint64_t l = 0;
uint64_t r = 0;
std::vector<std::string> words;
while (r < new_str.size()) {
// if is whitespace
if (isspace(new_str[r])) {
@ -8989,47 +9000,20 @@ struct llm_tokenizer_wpm {
return words;
}
std::string normalize(const std::string & text) {
// TODO: handle chinese characters? https://github.com/huggingface/tokenizers/blob/ef5f50605ddf9f8caef1598c0e4853862b9707a7/tokenizers/src/normalizers/bert.rs#L98
std::string text2 = strip_accents(text);
for (size_t i = 0; i < text2.size(); i += utf8_len(text2[i])) {
char c = text2[i];
if (c >= 'A' && c <= 'Z') {
text2[i] = c - 'A' + 'a';
}
uint32_t to_lower(uint32_t code) {
#if defined(_WIN32)
if (code > 0xFFFF) {
return code;
}
return text2;
#endif
return std::tolower(wchar_t(code), std::locale("en_US.UTF-8"));
}
bool is_chinese_char(const std::string & str) {
int len = str.length();
unsigned int codepoint = 0;
int num_bytes = 0;
int i = 0;
unsigned char ch = static_cast<unsigned char>(str[i]);
if (ch <= 0x7f) {
codepoint = ch;
num_bytes = 1;
} else if ((ch >> 5) == 0x06) {
codepoint = ch & 0x1f;
num_bytes = 2;
} else if ((ch >> 4) == 0x0e) {
codepoint = ch & 0x0f;
num_bytes = 3;
} else if ((ch >> 3) == 0x1e) {
codepoint = ch & 0x07;
num_bytes = 4;
}
for (int j = 1; j < num_bytes; ++j) {
if (i + j >= len) {
return false; // incomplete UTF-8 character
}
unsigned char next_ch = static_cast<unsigned char>(str[i + j]);
if ((next_ch >> 6) != 0x02) {
return false; // invalid trailing byte
}
codepoint = (codepoint << 6) | (next_ch & 0x3f);
}
bool is_ascii_punct(uint32_t code) {
return code < 256 && ispunct(code);
}
bool is_chinese_char(uint32_t codepoint) {
if ((codepoint >= 0x4E00 && codepoint <= 0x9FFF) ||
(codepoint >= 0x3400 && codepoint <= 0x4DBF) ||
(codepoint >= 0x20000 && codepoint <= 0x2A6DF) ||
@ -9045,41 +9029,6 @@ struct llm_tokenizer_wpm {
return false;
}
std::string strip_accents(const std::string & input_string) {
std::string resultString;
std::map<std::string, char> accent_map = {
{"À", 'A'}, {"Á", 'A'}, {"Â", 'A'}, {"Ã", 'A'}, {"Ä", 'A'}, {"Å", 'A'},
{"à", 'a'}, {"á", 'a'}, {"â", 'a'}, {"ã", 'a'}, {"ä", 'a'}, {"å", 'a'},
{"È", 'E'}, {"É", 'E'}, {"Ê", 'E'}, {"Ë", 'E'}, {"è", 'e'}, {"é", 'e'},
{"ê", 'e'}, {"ë", 'e'}, {"Ì", 'I'}, {"Í", 'I'}, {"Î", 'I'}, {"Ï", 'I'},
{"ì", 'i'}, {"í", 'i'}, {"î", 'i'}, {"ï", 'i'}, {"Ò", 'O'}, {"Ó", 'O'},
{"Ô", 'O'}, {"Õ", 'O'}, {"Ö", 'O'}, {"ò", 'o'}, {"ó", 'o'}, {"ô", 'o'},
{"õ", 'o'}, {"ö", 'o'}, {"Ù", 'U'}, {"Ú", 'U'}, {"Û", 'U'}, {"Ü", 'U'},
{"ù", 'u'}, {"ú", 'u'}, {"û", 'u'}, {"ü", 'u'}, {"Ý", 'Y'}, {"ý", 'y'},
{"Ç", 'C'}, {"ç", 'c'}, {"Ñ", 'N'}, {"ñ", 'n'},
};
for (size_t i = 0; i < input_string.length();) {
int len = utf8_len(input_string[i]);
std::string curChar = input_string.substr(i, len);
auto iter = accent_map.find(curChar);
if (iter != accent_map.end()) {
resultString += iter->second;
} else {
resultString += curChar;
}
i += len;
}
return resultString;
}
static size_t utf8_len(char src) {
const size_t lookup[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4};
uint8_t highbits = static_cast<uint8_t>(src) >> 4;
return lookup[highbits];
}
const llama_vocab & vocab;
};