unicode : support \p{N}, \p{L} and \p{P} natively

This commit is contained in:
Georgi Gerganov 2024-04-27 17:48:38 +03:00
parent ce5485aee0
commit 91eaa414bf
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
7 changed files with 94 additions and 26 deletions

View file

@ -1678,6 +1678,18 @@ std::vector<std::string> string_split(std::string input, char separator) {
return parts;
}
std::string string_strip(const std::string & str) {
size_t start = 0;
size_t end = str.size();
while (start < end && std::isspace(str[start])) {
start++;
}
while (end > start && std::isspace(str[end - 1])) {
end--;
}
return str.substr(start, end - start);
}
std::vector<llama_sampler_type> sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
std::unordered_map<std::string, llama_sampler_type> sampler_canonical_name_map {
{"top_k", llama_sampler_type::TOP_K},

View file

@ -193,6 +193,7 @@ bool validate_file_name(const std::string & filename);
std::vector<llama_sampler_type> sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
std::vector<llama_sampler_type> sampler_types_from_chars(const std::string & names_string);
std::vector<std::string> string_split(std::string input, char separator);
std::string string_strip(const std::string & str);
std::string sampler_type_to_name_string(llama_sampler_type sampler_type);
//

View file

@ -12036,13 +12036,6 @@ struct llm_tokenizer_bpe {
// adapted: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2080233989
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
// TODO: this is not the same as the original regex:
// - need to use ReFlex and update unicode.cpp to support the regex above
// - or implement a custom function similar to unicode_gpt2_regex_preprocess()
//"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
//"\\p{N}+",
//"[0-9][0-9][0-9]"
});
break;
case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM:

View file

@ -111,7 +111,7 @@ if fname_tok:
# f.write(str(x) + ' \' ' + tokenizer.decode(x) + '\'\n')
# else:
# f.write(str(x) + ' \'' + tokenizer.decode(x) + '\'\n')
f.write(str(x) + ' \'' + tokenizer.decode(x) + '\'\n')
f.write(str(x) + ' \'' + tokenizer.decode(x).strip() + '\'\n')
print('len(res): ', len(res))
print('len(lines): ', len(lines))
print('results written to: ', fname_out)

View file

@ -183,7 +183,7 @@ int main(int argc, char **argv) {
}
for (const auto & tok : res) {
ofs << tok << " '" << llama_detokenize_bpe(ctx, std::vector<int>{tok}) << "'" << std::endl;
ofs << tok << " '" << string_strip(llama_detokenize_bpe(ctx, std::vector<int>{tok})) << "'" << std::endl;
}
}

File diff suppressed because one or more lines are too long

View file

@ -494,19 +494,81 @@ char32_t unicode_tolower(char32_t cp) {
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
std::wstring wtext = unicode_wstring_from_utf8(text);
// compute collapsed codepoints only if needed by at least one regex
bool need_collapse = false;
for (auto & regex_expr : regex_exprs) {
// search for \\p{L} or \\p{N}
if (std::string::npos != regex_expr.find("\\p{N}") ||
std::string::npos != regex_expr.find("\\p{L}") ||
std::string::npos != regex_expr.find("\\p{P}")) {
need_collapse = true;
break;
}
}
std::wstring wtext_collapsed;
if (need_collapse) {
// collapse all digit, letter and punctuation cpts to a single codepoint
// the collapsed codepoint is selected to be the one at the end of the range
//
// - convert text to cpts
// - collapse digit cpts to 0x0001FBF9
// - collapse letter cpts to 0x0003134A
// - collapse punctuation cpts to 0x0001E95F
// - convert back to text
auto cpts = unicode_cpts_from_utf8(text);
for (size_t i = 0; i < cpts.size(); ++i) {
if (unicode_cpt_type(cpts[i]) == CODEPOINT_TYPE_DIGIT) {
cpts[i] = 0x0001FBF9;
} else if (unicode_cpt_type(cpts[i]) == CODEPOINT_TYPE_LETTER) {
cpts[i] = 0x0003134A;
} else if (unicode_cpt_type(cpts[i]) == CODEPOINT_TYPE_PUNCTUATION) {
cpts[i] = 0x0001E95F;
}
}
wtext_collapsed = unicode_wstring_from_utf8(unicode_cpts_to_utf8(cpts));
}
std::vector<size_t> bpe_offsets = {wtext.size()};
for (auto & regex_expr : regex_exprs) {
if (unicode_regex_with_custom_preprocessor_exists(regex_expr)) {
bpe_offsets = unicode_regex_custom_preprocess(regex_expr, wtext, bpe_offsets);
} else if (unicode_regex_equivalent_wregex_exists(regex_expr)) {
const std::wstring & wregex_expr = unicode_regex_equivalent_wregex.at(regex_expr);
bpe_offsets = unicode_regex_preprocess(wtext, bpe_offsets, wregex_expr);
//} else if (unicode_regex_equivalent_wregex_exists(regex_expr)) {
// const std::wstring & wregex_expr = unicode_regex_equivalent_wregex.at(regex_expr);
// bpe_offsets = unicode_regex_preprocess(wtext, bpe_offsets, wregex_expr);
} else {
// fallback
try {
const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr);
bpe_offsets = unicode_regex_preprocess(wtext, bpe_offsets, wregex_expr);
// if a unicode category is used in the regex, we use the collapsed text and replace the unicode category
// with the corresponding collapsed codepoint
if (std::string::npos != regex_expr.find("\\p{N}") ||
std::string::npos != regex_expr.find("\\p{L}") ||
std::string::npos != regex_expr.find("\\p{P}")) {
// replace \\p{N} with \U0001FBF9
// replace \\p{L} with \U0003134A
// replace \\p{P} with \U0001E95F
std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr);
for (size_t i = 0; i < wregex_expr.size(); ++i) {
if (wregex_expr[i] == L'\\' && i + 1 < wregex_expr.size()) {
if (wregex_expr[i + 1] == L'p' && i + 3 < wregex_expr.size()) {
if (wregex_expr[i + 2] == L'{' && wregex_expr[i + 4] == L'}') {
if (wregex_expr[i + 3] == L'N') {
wregex_expr.replace(i, 5, L"\U0001FBF9");
} else if (wregex_expr[i + 3] == L'L') {
wregex_expr.replace(i, 5, L"\U0003134A");
} else if (wregex_expr[i + 3] == L'P') {
wregex_expr.replace(i, 5, L"\U0001E95F");
}
}
}
}
}
bpe_offsets = unicode_regex_preprocess(wtext_collapsed, bpe_offsets, wregex_expr);
} else {
const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr);
bpe_offsets = unicode_regex_preprocess(wtext, bpe_offsets, wregex_expr);
}
} catch (std::regex_error & e) {
fprintf(stderr, "Failed to process regex: '%s'\n", regex_expr.c_str());
fprintf(stderr, "Regex error: %s\n", e.what());